13 | One-shot - inference setup for voices unseen at training time, where prompts and speaker embeddings are provided as additional model inputs.
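
A minimal sketch of this one-shot setup, using the `TTSConformer.inference` API defined in `modules/s2a_model.py`. The hyperparameter object `hp`, the token shapes, and the speaker-embedding size below are illustrative assumptions, not the repo's exact pipeline:

```python
# Hypothetical one-shot usage; shapes and values are placeholder assumptions.
import torch

from modules.s2a_model import Pheme

pheme = Pheme(hp)  # assumes `hp` has `pretrained_path` pointing at a checkpoint
pheme.eval()

# Acoustic codes of the voice prompt (B, T, n_cluster_groups), semantic tokens
# (B, T) covering prompt plus target text, and a fixed-size speaker embedding.
codes = torch.randint(0, hp.n_codes, (1, 250, hp.n_cluster_groups))
semantic_tokens = torch.randint(0, hp.n_semantic_codes, (1, 250))
speaker_emb = torch.randn(1, 512)  # assumes c.SPKR_EMB_SIZE == 512

with torch.no_grad():
    out_codes = pheme.model.inference(
        codes, semantic_tokens,
        torch.tensor([codes.shape[1]]),
        rvq_levels=7,
        speaker_emb=speaker_emb,
    )
```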
14 |
15 |
16 |
17 |
18 |
19 | | Prompt audio | Reference audio | PHEME (100M) | PHEME (300M) no speaker embeddings | PHEME (300M) | Prompt text | Reference text |
20 | | :----------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
21 | | | | | | | let's just say in her own words, once i sat down and watched it i never moved, i was enthralled by it. | and she told me the next time she went back she would take me with her. and i waited, of course, like i said, thirteen years. |
22 | | | | | | | in early twenty-twenty, blue apron put the word out that it was interested in possibly getting scooped up. maybe by a big grocery chain. or someone else with deep pockets who wanted to own a meal kit delivery business. | at the same time, garcia says, the company acted like it was in turnaround mode. it decided to streamline operations, including shutting down its fulfillment center in texas |
23 | | | | | | | aside from influencing basically everyone who matters he was one of the first, if not in fact the first, artist to bring an electric guitar player with him on to the grand ole opry stage. | if you want to call it a honky tonk, and it happened after ernest tubb. it was influenced by ernest tubb. before i get to the story and episode, i'd like to address one other thing. |
24 | | | | | | | so it's ah i think there's a range of risks, but generally speaking ah there's going to be a steady increase in the floor of the skill level as these ah a i technologies diffuse. | that is, there will be more and more ah capabilities available to people at the bottom of the scale, that is individuals as well as people with more access to computing power, ah money, and data at the higher end. |
25 | | | | | | | so after they put in their name, phone number, email address onto your landing page. where would you like to send them? would you like to send them to your facebook page your website? | book an appointment to a buyer on facebook messenger bot, a seller messenger bot. where would you like to send them? so for this example i'm just gonna say book an appointment. |
26 |
27 |
28 |
29 | ### Artificial Voice TTS Examples
30 |
31 | | Prompt audio | Reference audio | PHEME (300M) no training on artificial voice | PHEME (300M) | Prompt text | Reference text |
32 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
33 | | | | | | Our garden terrace is a lovely spot for afternoon tea. | The city’s ghost walk is a spooky and fascinating evening adventure. |
34 | | | | | | If you need a quiet place to work, our library is just perfect. | Our hotel’s evening bonfires are a great place to socialize. |
35 | | | | | | There’s a delightful chocolate factory tour, great for families. | Our rooftop jazz nights feature some of the best local talent. |
36 | | | | | | The rooftop bar hosts a live DJ on Friday nights. | Our in-house sommelier leads an exquisite wine and cheese pairing event. |
37 | | | | | | The comedy club in town is known for its hilarious acts. | The annual food fair showcases the best of local cuisine. |
38 |
39 | ### Inference speed with Triton-LLM (RTF, real-time factor; lower is better) for short and long sentences
40 |
41 | | Model | *short* | *long* | GPU |
42 | | ------------------ | --------- | --------- | --------- |
43 | | MQTTS (100M) | 1.930 | 1.842 | A100 |
44 | | PHEME-SMALL (100M) | **0.133** | **0.133** | A100 |
45 | | PHEME-LARGE (300M) | 0.143 | 0.143 | A100 |
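
RTF (real-time factor) is synthesis time divided by the duration of the generated audio, so values below 1.0 mean faster-than-real-time generation. A small illustrative computation (hypothetical timings, not the benchmark script):

```python
def real_time_factor(synth_seconds: float, audio_seconds: float) -> float:
    """RTF = time spent generating / duration of the generated audio."""
    return synth_seconds / audio_seconds

# E.g., if generating a 10 s utterance takes 1.33 s, the RTF is 0.133,
# matching the PHEME-SMALL row above.
print(real_time_factor(1.33, 10.0))  # 0.133
```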
46 |
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution 4.0 International Public License
58 |
59 | By exercising the Licensed Rights (defined below), You accept and agree
60 | to be bound by the terms and conditions of this Creative Commons
61 | Attribution 4.0 International Public License ("Public License"). To the
62 | extent this Public License may be interpreted as a contract, You are
63 | granted the Licensed Rights in consideration of Your acceptance of
64 | these terms and conditions, and the Licensor grants You such rights in
65 | consideration of benefits the Licensor receives from making the
66 | Licensed Material available under these terms and conditions.
67 |
68 |
69 | Section 1 -- Definitions.
70 |
71 | a. Adapted Material means material subject to Copyright and Similar
72 | Rights that is derived from or based upon the Licensed Material
73 | and in which the Licensed Material is translated, altered,
74 | arranged, transformed, or otherwise modified in a manner requiring
75 | permission under the Copyright and Similar Rights held by the
76 | Licensor. For purposes of this Public License, where the Licensed
77 | Material is a musical work, performance, or sound recording,
78 | Adapted Material is always produced where the Licensed Material is
79 | synched in timed relation with a moving image.
80 |
81 | b. Adapter's License means the license You apply to Your Copyright
82 | and Similar Rights in Your contributions to Adapted Material in
83 | accordance with the terms and conditions of this Public License.
84 |
85 | c. Copyright and Similar Rights means copyright and/or similar rights
86 | closely related to copyright including, without limitation,
87 | performance, broadcast, sound recording, and Sui Generis Database
88 | Rights, without regard to how the rights are labeled or
89 | categorized. For purposes of this Public License, the rights
90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
91 | Rights.
92 |
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. Share means to provide material to the public by any means or
116 | process that requires permission under the Licensed Rights, such
117 | as reproduction, public display, public performance, distribution,
118 | dissemination, communication, or importation, and to make material
119 | available to the public including in ways that members of the
120 | public may access the material from a place and at a time
121 | individually chosen by them.
122 |
123 | j. Sui Generis Database Rights means rights other than copyright
124 | resulting from Directive 96/9/EC of the European Parliament and of
125 | the Council of 11 March 1996 on the legal protection of databases,
126 | as amended and/or succeeded, as well as other essentially
127 | equivalent rights anywhere in the world.
128 |
129 | k. You means the individual or entity exercising the Licensed Rights
130 | under this Public License. Your has a corresponding meaning.
131 |
132 |
133 | Section 2 -- Scope.
134 |
135 | a. License grant.
136 |
137 | 1. Subject to the terms and conditions of this Public License,
138 | the Licensor hereby grants You a worldwide, royalty-free,
139 | non-sublicensable, non-exclusive, irrevocable license to
140 | exercise the Licensed Rights in the Licensed Material to:
141 |
142 | a. reproduce and Share the Licensed Material, in whole or
143 | in part; and
144 |
145 | b. produce, reproduce, and Share Adapted Material.
146 |
147 | 2. Exceptions and Limitations. For the avoidance of doubt, where
148 | Exceptions and Limitations apply to Your use, this Public
149 | License does not apply, and You do not need to comply with
150 | its terms and conditions.
151 |
152 | 3. Term. The term of this Public License is specified in Section
153 | 6(a).
154 |
155 | 4. Media and formats; technical modifications allowed. The
156 | Licensor authorizes You to exercise the Licensed Rights in
157 | all media and formats whether now known or hereafter created,
158 | and to make technical modifications necessary to do so. The
159 | Licensor waives and/or agrees not to assert any right or
160 | authority to forbid You from making technical modifications
161 | necessary to exercise the Licensed Rights, including
162 | technical modifications necessary to circumvent Effective
163 | Technological Measures. For purposes of this Public License,
164 | simply making modifications authorized by this Section 2(a)
165 | (4) never produces Adapted Material.
166 |
167 | 5. Downstream recipients.
168 |
169 | a. Offer from the Licensor -- Licensed Material. Every
170 | recipient of the Licensed Material automatically
171 | receives an offer from the Licensor to exercise the
172 | Licensed Rights under the terms and conditions of this
173 | Public License.
174 |
175 | b. No downstream restrictions. You may not offer or impose
176 | any additional or different terms or conditions on, or
177 | apply any Effective Technological Measures to, the
178 | Licensed Material if doing so restricts exercise of the
179 | Licensed Rights by any recipient of the Licensed
180 | Material.
181 |
182 | 6. No endorsement. Nothing in this Public License constitutes or
183 | may be construed as permission to assert or imply that You
184 | are, or that Your use of the Licensed Material is, connected
185 | with, or sponsored, endorsed, or granted official status by,
186 | the Licensor or others designated to receive attribution as
187 | provided in Section 3(a)(1)(A)(i).
188 |
189 | b. Other rights.
190 |
191 | 1. Moral rights, such as the right of integrity, are not
192 | licensed under this Public License, nor are publicity,
193 | privacy, and/or other similar personality rights; however, to
194 | the extent possible, the Licensor waives and/or agrees not to
195 | assert any such rights held by the Licensor to the limited
196 | extent necessary to allow You to exercise the Licensed
197 | Rights, but not otherwise.
198 |
199 | 2. Patent and trademark rights are not licensed under this
200 | Public License.
201 |
202 | 3. To the extent possible, the Licensor waives any right to
203 | collect royalties from You for the exercise of the Licensed
204 | Rights, whether directly or through a collecting society
205 | under any voluntary or waivable statutory or compulsory
206 | licensing scheme. In all other cases the Licensor expressly
207 | reserves any right to collect such royalties.
208 |
209 |
210 | Section 3 -- License Conditions.
211 |
212 | Your exercise of the Licensed Rights is expressly made subject to the
213 | following conditions.
214 |
215 | a. Attribution.
216 |
217 | 1. If You Share the Licensed Material (including in modified
218 | form), You must:
219 |
220 | a. retain the following if it is supplied by the Licensor
221 | with the Licensed Material:
222 |
223 | i. identification of the creator(s) of the Licensed
224 | Material and any others designated to receive
225 | attribution, in any reasonable manner requested by
226 | the Licensor (including by pseudonym if
227 | designated);
228 |
229 | ii. a copyright notice;
230 |
231 | iii. a notice that refers to this Public License;
232 |
233 | iv. a notice that refers to the disclaimer of
234 | warranties;
235 |
236 | v. a URI or hyperlink to the Licensed Material to the
237 | extent reasonably practicable;
238 |
239 | b. indicate if You modified the Licensed Material and
240 | retain an indication of any previous modifications; and
241 |
242 | c. indicate the Licensed Material is licensed under this
243 | Public License, and include the text of, or the URI or
244 | hyperlink to, this Public License.
245 |
246 | 2. You may satisfy the conditions in Section 3(a)(1) in any
247 | reasonable manner based on the medium, means, and context in
248 | which You Share the Licensed Material. For example, it may be
249 | reasonable to satisfy the conditions by providing a URI or
250 | hyperlink to a resource that includes the required
251 | information.
252 |
253 | 3. If requested by the Licensor, You must remove any of the
254 | information required by Section 3(a)(1)(A) to the extent
255 | reasonably practicable.
256 |
257 | 4. If You Share Adapted Material You produce, the Adapter's
258 | License You apply must not prevent recipients of the Adapted
259 | Material from complying with this Public License.
260 |
261 |
262 | Section 4 -- Sui Generis Database Rights.
263 |
264 | Where the Licensed Rights include Sui Generis Database Rights that
265 | apply to Your use of the Licensed Material:
266 |
267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
268 | to extract, reuse, reproduce, and Share all or a substantial
269 | portion of the contents of the database;
270 |
271 | b. if You include all or a substantial portion of the database
272 | contents in a database in which You have Sui Generis Database
273 | Rights, then the database in which You have Sui Generis Database
274 | Rights (but not its individual contents) is Adapted Material; and
275 |
276 | c. You must comply with the conditions in Section 3(a) if You Share
277 | all or a substantial portion of the contents of the database.
278 |
279 | For the avoidance of doubt, this Section 4 supplements and does not
280 | replace Your obligations under this Public License where the Licensed
281 | Rights include other Copyright and Similar Rights.
282 |
283 |
284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
285 |
286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
296 |
297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
306 |
307 | c. The disclaimer of warranties and limitation of liability provided
308 | above shall be interpreted in a manner that, to the extent
309 | possible, most closely approximates an absolute disclaimer and
310 | waiver of all liability.
311 |
312 |
313 | Section 6 -- Term and Termination.
314 |
315 | a. This Public License applies for the term of the Copyright and
316 | Similar Rights licensed here. However, if You fail to comply with
317 | this Public License, then Your rights under this Public License
318 | terminate automatically.
319 |
320 | b. Where Your right to use the Licensed Material has terminated under
321 | Section 6(a), it reinstates:
322 |
323 | 1. automatically as of the date the violation is cured, provided
324 | it is cured within 30 days of Your discovery of the
325 | violation; or
326 |
327 | 2. upon express reinstatement by the Licensor.
328 |
329 | For the avoidance of doubt, this Section 6(b) does not affect any
330 | right the Licensor may have to seek remedies for Your violations
331 | of this Public License.
332 |
333 | c. For the avoidance of doubt, the Licensor may also offer the
334 | Licensed Material under separate terms or conditions or stop
335 | distributing the Licensed Material at any time; however, doing so
336 | will not terminate this Public License.
337 |
338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
339 | License.
340 |
341 |
342 | Section 7 -- Other Terms and Conditions.
343 |
344 | a. The Licensor shall not be bound by any additional or different
345 | terms or conditions communicated by You unless expressly agreed.
346 |
347 | b. Any arrangements, understandings, or agreements regarding the
348 | Licensed Material not stated herein are separate from and
349 | independent of the terms and conditions of this Public License.
350 |
351 |
352 | Section 8 -- Interpretation.
353 |
354 | a. For the avoidance of doubt, this Public License does not, and
355 | shall not be interpreted to, reduce, limit, restrict, or impose
356 | conditions on any use of the Licensed Material that could lawfully
357 | be made without permission under this Public License.
358 |
359 | b. To the extent possible, if any provision of this Public License is
360 | deemed unenforceable, it shall be automatically reformed to the
361 | minimum extent necessary to make it enforceable. If the provision
362 | cannot be reformed, it shall be severed from this Public License
363 | without affecting the enforceability of the remaining terms and
364 | conditions.
365 |
366 | c. No term or condition of this Public License will be waived and no
367 | failure to comply consented to unless expressly agreed to by the
368 | Licensor.
369 |
370 | d. Nothing in this Public License constitutes or may be interpreted
371 | as a limitation upon, or waiver of, any privileges and immunities
372 | that apply to the Licensor or You, including from the legal
373 | processes of any jurisdiction or authority.
374 |
375 |
376 | =======================================================================
377 |
378 | Creative Commons is not a party to its public
379 | licenses. Notwithstanding, Creative Commons may elect to apply one of
380 | its public licenses to material it publishes and in those instances
381 | will be considered the "Licensor." The text of the Creative Commons
382 | public licenses is dedicated to the public domain under the CC0 Public
383 | Domain Dedication. Except for the limited purpose of indicating that
384 | material is shared under a Creative Commons public license or as
385 | otherwise permitted by the Creative Commons policies published at
386 | creativecommons.org/policies, Creative Commons does not authorize the
387 | use of the trademark "Creative Commons" or any other trademark or logo
388 | of Creative Commons without its prior written consent including,
389 | without limitation, in connection with any unauthorized modifications
390 | to any of its public licenses or any other arrangements,
391 | understandings, or agreements concerning use of licensed material. For
392 | the avoidance of doubt, this paragraph does not form part of the
393 | public licenses.
394 |
395 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/modules/conformer.py:
--------------------------------------------------------------------------------
1 | """Conformer definition adjusted given the Lucidrain's repo.
2 | https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa
3 |
4 | Copyright PolyAI Limited.
5 | """
6 | import math
from collections import namedtuple
7 | from functools import wraps
8 | from typing import Any, Dict, Union
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from einops import rearrange, reduce
13 | from einops.layers.torch import EinMix, Rearrange
14 | from torch import einsum, nn
15 |
16 |
17 | # rotary embedding
18 | class RotaryEmbedding(nn.Module):
19 | def __init__(self, dim, theta = 10000):
20 | super().__init__()
21 | inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
22 | self.register_buffer("inv_freq", inv_freq, persistent = False)
23 |
24 | @property
25 | def device(self):
26 | return next(self.buffers()).device
27 |
28 | def forward(self, seq_len):
29 | t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
30 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
31 | freqs = torch.cat((freqs, freqs), dim = -1)
32 | return freqs
33 |
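# rotate_half splits the feature dim in half and maps (x1, x2) -> (-x2, x1);
# apply_rotary_pos_emb combines it with the cos/sin terms so that each feature
# pair is rotated by a position-dependent angle (rotary position embedding).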
34 | def rotate_half(x):
35 | x1, x2 = x.chunk(2, dim=-1)
36 | return torch.cat((-x2, x1), dim=-1)
37 |
38 | def apply_rotary_pos_emb(pos, t):
39 | return (t * pos.cos()) + (rotate_half(t) * pos.sin())
40 |
41 |
42 | # constants
43 | EfficientAttentionConfig = namedtuple(
44 | 'EfficientAttentionConfig',
45 | ['enable_flash', 'enable_math', 'enable_mem_efficient']
46 | )
47 |
48 | # helpers
49 | def exists(val):
50 | return val is not None
51 |
52 | def default(val, d):
53 | return val if exists(val) else d
54 |
55 | def divisible_by(numer, denom):
56 | return (numer % denom) == 0
57 |
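# Returns (left, right) padding that preserves sequence length for stride-1
# convs; for even kernel sizes the right pad is one less than the left.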
58 | def calc_same_padding(kernel_size):
59 | pad = kernel_size // 2
60 | return (pad, pad - (kernel_size + 1) % 2)
61 |
62 | def eval_decorator(fn):
63 | @wraps(fn)
64 | def inner(model, *args, **kwargs):
65 | was_training = model.training
66 | model.eval()
67 | out = fn(model, *args, **kwargs)
68 | model.train(was_training)
69 | return out
70 | return inner
71 |
72 |
73 | def once(fn):
74 | called = False
75 | @wraps(fn)
76 | def inner(x):
77 | nonlocal called
78 | if called:
79 | return
80 | called = True
81 | return fn(x)
82 | return inner
83 |
84 | print_once = once(print)
85 |
86 |
87 | # t5 relative positional bias
88 | class T5RelativePositionBias(nn.Module):
89 | def __init__(
90 | self,
91 | scale = 1.,
92 | num_buckets = 32,
93 | max_distance = 128,
94 | heads = 8
95 | ):
96 | super().__init__()
97 | self.scale = scale
98 | self.num_buckets = num_buckets
99 | self.max_distance = max_distance
100 | self.relative_attention_bias = nn.Embedding(num_buckets, heads)
101 |
102 | @staticmethod
103 | def _relative_position_bucket(
104 | relative_position,
105 | num_buckets = 32,
106 | max_distance = 128
107 | ):
108 | ret = 0
109 | n = -relative_position
110 |
111 | num_buckets //= 2
112 | ret += (n < 0).long() * num_buckets
113 | n = torch.abs(n)
114 |
115 | max_exact = num_buckets // 2
116 | is_small = n < max_exact
117 |
118 | val_if_large = max_exact + (
119 | torch.log(n.float() / max_exact) / math.log(
120 | max_distance / max_exact) * (num_buckets - max_exact)
121 | ).long()
122 |
123 | val_if_large = torch.min(
124 | val_if_large,
125 | torch.full_like(val_if_large, num_buckets - 1)
126 | )
127 |
128 | ret += torch.where(is_small, n, val_if_large)
129 | return ret
130 |
131 | @property
132 | def device(self):
133 | return next(self.parameters()).device
134 |
135 | def forward(self, n):
136 | pos = torch.arange(n, device = self.device).long()
137 | rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
138 |
139 | rp_bucket = self._relative_position_bucket(
140 | rel_pos, num_buckets = self.num_buckets,
141 | max_distance = self.max_distance)
142 | values = self.relative_attention_bias(rp_bucket)
143 |
144 | bias = rearrange(values, 'i j h -> h i j')
145 | return bias * self.scale
146 |
147 |
148 | # main class
149 | class Attend(nn.Module):
150 | def __init__(
151 | self,
152 | causal = False,
153 | dropout = 0.,
154 | flash = False
155 | ):
156 | super().__init__()
157 | self.dropout = dropout
158 | self.attn_dropout = nn.Dropout(dropout)
159 |
160 | self.causal = causal
161 | self.flash = flash
162 |
163 | # determine efficient attention configs for cuda and cpu
164 | self.cpu_config = EfficientAttentionConfig(True, True, True)
165 | self.cuda_config = None
166 |
167 | if not torch.cuda.is_available() or not flash:
168 | return
169 |
170 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
171 |
172 | if device_properties.major == 8 and device_properties.minor == 0:
173 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') # noqa
174 | self.cuda_config = EfficientAttentionConfig(True, True, True)
175 | else:
176 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') # noqa
177 | self.cuda_config = EfficientAttentionConfig(False, True, True)
178 |
179 | def get_mask(self, i, j, device):
180 | return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) # noqa
181 |
182 | def flash_attn(self, q, k, v, mask = None, attn_bias = None):
183 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # noqa
184 |
185 | # single headed key / values
186 |
187 | if k.ndim == 3:
188 | k = rearrange(k, 'b n d -> b 1 n d')
189 |
190 | if v.ndim == 3:
191 | v = rearrange(v, 'b n d -> b 1 n d')
192 |
193 | # Check if mask exists and expand to compatible shape
194 | # The mask is B L, so it would have to be expanded to B H N L
195 | if exists(mask) and mask.ndim != 4:
196 | mask = rearrange(mask, 'b j -> b 1 1 j')
197 | mask = mask.expand(-1, heads, q_len, -1)
198 |
199 | # Check if there is a compatible device for flash attention
200 | config = self.cuda_config if is_cuda else self.cpu_config
201 | causal = self.causal
202 |
203 | # handle attention bias
204 | if exists(attn_bias):
205 | mask_value = -torch.finfo(q.dtype).max // 2
206 | causal_mask = self.get_mask(q_len, k_len, device)
207 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
208 |
209 | if exists(mask):
210 | attn_bias = attn_bias.masked_fill(~mask, mask_value)
211 |
212 | mask = attn_bias
213 | causal = False
214 |
215 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
216 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
217 | out = F.scaled_dot_product_attention(
218 | q, k, v,
219 | attn_mask = mask,
220 | dropout_p = self.dropout if self.training else 0.,
221 | is_causal = causal
222 | )
223 |
224 | return out
225 |
226 | def forward(self, q, k, v, mask = None, attn_bias = None):
227 | """
228 | einstein notation
229 | b - batch
230 | h - heads
231 | n, i, j - sequence length (base sequence length, source, target)
232 | d - feature dimension
233 | """
234 |
235 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
236 |
237 | scale = q.shape[-1] ** -0.5
238 |
239 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
240 |
241 | if self.flash:
242 | assert not exists(attn_bias)
243 | return self.flash_attn(q, k, v, mask = mask)
244 |
245 | # similarity
246 |
247 | sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
248 |
249 | # attention bias
250 |
251 | if exists(attn_bias):
252 | sim = sim + attn_bias
253 |
254 | # causal mask
255 | if self.causal:
256 | causal_mask = self.get_mask(q_len, k_len, device)
257 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
258 |
259 | # key padding mask
260 | if exists(mask):
261 | if mask.ndim != 4:
262 | mask = rearrange(mask, 'b j -> b 1 1 j')
263 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
264 |
265 | # attention
266 | attn = sim.softmax(dim=-1)
267 | attn = self.attn_dropout(attn)
268 |
269 | # aggregate values
270 | out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
271 |
272 | return out
273 |
274 |
275 | class Swish(nn.Module):
276 | def forward(self, x):
277 | return x * x.sigmoid()
278 |
279 |
280 | class GLU(nn.Module):
281 | def __init__(self, dim):
282 | super().__init__()
283 | self.dim = dim
284 |
285 | def forward(self, x):
286 | out, gate = x.chunk(2, dim=self.dim)
287 | return out * gate.sigmoid()
288 |
289 |
290 | class DepthWiseConv1d(nn.Module):
291 | def __init__(self, chan_in, chan_out, kernel_size, padding):
292 | super().__init__()
293 | self.padding = padding
294 | self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
295 |
296 | def forward(self, x):
297 | x = F.pad(x, self.padding)
298 | return self.conv(x)
299 |
300 |
301 | class Scale(nn.Module):
302 | def __init__(self, scale, fn):
303 | super().__init__()
304 | self.fn = fn
305 | self.scale = scale
306 |
307 | def forward(self, x, **kwargs):
308 | return self.fn(x, **kwargs) * self.scale
309 |
310 |
311 | class ChanLayerNorm(nn.Module):
312 | def __init__(self, dim):
313 | super().__init__()
314 | self.gamma = nn.Parameter(torch.ones(1, dim, 1))
315 |
316 | def forward(self, x):
317 | eps = 1e-6 if x.dtype == torch.float32 else 1e-4
318 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
319 | mean = torch.mean(x, dim = 1, keepdim = True)
320 | return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma
321 |
322 |
323 | class PreNorm(nn.Module):
324 | def __init__(self, dim, fn):
325 | super().__init__()
326 | self.fn = fn
327 | self.norm = nn.LayerNorm(dim)
328 |
329 | def forward(self, x, **kwargs):
330 | x = self.norm(x)
331 | return self.fn(x, **kwargs)
332 |
333 |
334 | class Attention(nn.Module):
335 | def __init__(
336 | self,
337 | dim,
338 | heads = 8,
339 | dim_head = 64,
340 | dropout = 0.,
341 | flash = True
342 | ):
343 | super().__init__()
344 | inner_dim = dim_head * heads
345 |         self.heads = heads
346 | self.scale = dim_head ** -0.5
347 |
348 | self.attend = Attend(
349 | flash = flash,
350 | dropout = dropout
351 | )
352 |
353 | self.dropout = nn.Dropout(dropout)
354 |
355 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
356 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
357 | self.to_out = nn.Linear(inner_dim, dim)
358 |
359 | def forward(
360 | self,
361 | x,
362 | context = None,
363 | mask = None,
364 | rotary_emb = None,
365 | attn_bias = None
366 | ):
367 | n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
368 | context = default(context, x)
369 |
370 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
371 | q, k, v = map(
372 | lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
373 |
374 | if exists(rotary_emb):
375 | q = apply_rotary_pos_emb(rotary_emb, q)
376 | k = apply_rotary_pos_emb(rotary_emb, k)
377 |
378 | out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
379 |
380 | out = rearrange(out, 'b h n d -> b n (h d)')
381 | return self.to_out(out)
382 |
383 |
384 | class FeedForward(nn.Module):
385 | def __init__(
386 | self,
387 | dim,
388 | mult = 4,
389 | dropout = 0.
390 | ):
391 | super().__init__()
392 | self.net = nn.Sequential(
393 | nn.Linear(dim, dim * mult),
394 | Swish(),
395 | nn.Dropout(dropout),
396 | nn.Linear(dim * mult, dim),
397 | nn.Dropout(dropout)
398 | )
399 |
400 | def forward(self, x):
401 | return self.net(x)
402 |
403 |
404 | class ConformerConvModule(nn.Module):
405 | def __init__(
406 | self,
407 | dim,
408 | causal = False,
409 | expansion_factor = 2,
410 | kernel_size = 31,
411 | dropout = 0.
412 | ):
413 | super().__init__()
414 |
415 | inner_dim = dim * expansion_factor
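        # Causal mode pads only on the left so a frame never sees the future;
        # otherwise pad symmetrically to preserve the sequence length.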
416 | padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
417 |
418 | self.net = nn.Sequential(
419 | nn.LayerNorm(dim),
420 | Rearrange('b n c -> b c n'),
421 | nn.Conv1d(dim, inner_dim * 2, 1),
422 | GLU(dim=1),
423 | DepthWiseConv1d(
424 | inner_dim, inner_dim, kernel_size = kernel_size,
425 | padding = padding
426 | ),
427 | Swish(),
428 | ChanLayerNorm(inner_dim),
429 | nn.Conv1d(inner_dim, dim, 1),
430 | Rearrange('b c n -> b n c'),
431 | nn.Dropout(dropout)
432 | )
433 |
434 | def forward(self, x):
435 | return self.net(x)
436 |
437 |
438 | # Conformer Block
439 | class ConformerBlock(nn.Module):
440 | def __init__(
441 | self,
442 | *,
443 | dim,
444 | dim_head = 64,
445 | heads = 8,
446 | ff_mult = 4,
447 | conv_expansion_factor = 2,
448 | conv_kernel_size = 31,
449 | attn_dropout = 0.,
450 | attn_flash = True,
451 | ff_dropout = 0.,
452 | conv_dropout = 0.,
453 | conv_causal = False
454 | ):
455 | super().__init__()
456 | self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
457 | self.attn = Attention(
458 | dim = dim, dim_head = dim_head, heads = heads,
459 | dropout = attn_dropout, flash = attn_flash
460 | )
461 | self.conv = ConformerConvModule(
462 | dim = dim, causal = conv_causal,
463 | expansion_factor = conv_expansion_factor,
464 | kernel_size = conv_kernel_size, dropout = conv_dropout
465 | )
466 | self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
467 |
468 | self.attn = PreNorm(dim, self.attn)
469 | self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
470 | self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
471 |
472 | self.post_norm = nn.LayerNorm(dim)
473 |
474 | def forward(
475 | self,
476 | x,
477 | mask = None,
478 | rotary_emb = None,
479 | attn_bias = None
480 | ):
481 | x = self.ff1(x) + x
482 | x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x # noqa
483 | x = self.conv(x) + x
484 | x = self.ff2(x) + x
485 | x = self.post_norm(x)
486 | return x
487 |
488 |
489 | # Conformer
490 | class Conformer(nn.Module):
491 | def __init__(
492 | self,
493 | dim,
494 | *,
495 | num_layers,
496 | dim_head = 64,
497 | heads = 8,
498 | ff_mult = 4,
499 | conv_expansion_factor = 2,
500 | conv_kernel_size = 31,
501 | attn_dropout = 0.,
502 | ff_dropout = 0.,
503 | conv_dropout = 0.,
504 | conv_causal = False,
505 | attn_flash = True,
506 | t5_rel_pos_bias = False
507 | ):
508 | super().__init__()
509 |
510 | assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' # noqa
511 |
512 | self.dim = dim
513 | self.layers = nn.ModuleList([])
514 |
515 | self.rotary_emb = RotaryEmbedding(
516 | dim_head) if not t5_rel_pos_bias else None
517 | self.rel_pos_bias = T5RelativePositionBias(
518 | dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None
519 |
520 | for _ in range(num_layers):
521 | self.layers.append(ConformerBlock(
522 | dim = dim,
523 | dim_head = dim_head,
524 | heads = heads,
525 | ff_mult = ff_mult,
526 | conv_expansion_factor = conv_expansion_factor,
527 | conv_kernel_size = conv_kernel_size,
528 | attn_dropout = attn_dropout,
529 | ff_dropout = ff_dropout,
530 | conv_dropout = conv_dropout,
531 | conv_causal = conv_causal,
532 | attn_flash = attn_flash
533 | ))
534 |
535 | def forward(self, x, mask = None):
536 | seq_len = x.shape[-2]
537 |
538 | rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None # noqa
539 | attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None #noqa
540 |
541 | for block in self.layers:
542 | x = block(
543 | x,
544 | mask = mask,
545 | rotary_emb = rotary_emb,
546 | attn_bias = attn_bias
547 | )
548 | return x
549 |
550 |
551 | # conformer with sum reduction across quantized tokens at the beginning,
552 | # along with heads
553 | class ConformerWrapper(nn.Module):
554 | def __init__(
555 | self,
556 | *,
557 | codebook_size,
558 | num_quantizers,
559 |         conformer: Union[Conformer, Dict[str, Any]],
560 | grouped_quantizers = 1
561 | ):
562 | super().__init__()
563 | self.conformer = conformer
564 |
565 | if isinstance(conformer, dict):
566 | self.conformer = Conformer(**self.conformer)
567 |
568 | dim = self.conformer.dim
569 |
570 | self.embedding_proj = nn.Sequential(
571 | nn.Linear(dim * grouped_quantizers, dim),
572 | nn.LayerNorm(dim)
573 | ) if grouped_quantizers > 1 else nn.Identity()
574 |
575 | num_codes_with_mask = codebook_size + 1
576 | num_effective_quantizers = num_quantizers * grouped_quantizers
577 |
578 | self.code_embeds = nn.Embedding(
579 | num_codes_with_mask * num_effective_quantizers, dim)
580 |
581 | self.register_buffer(
582 | 'quantizer_offsets',
583 | torch.arange(num_effective_quantizers) * num_codes_with_mask,
584 | persistent = False
585 | )
586 | self.register_buffer(
587 | 'mask_tokens', self.quantizer_offsets + num_codes_with_mask,
588 | persistent = False
589 | )
590 |
591 | self.dim = dim
592 | self.codebook_size = codebook_size
593 |
594 | self.num_codes_with_mask = num_codes_with_mask
595 | self.num_quantizers = num_quantizers
596 | self.grouped_quantizers = grouped_quantizers
597 |
598 | self.heads = nn.Sequential(
599 | nn.Linear(dim, dim * num_effective_quantizers),
600 | Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
601 | )
602 |
603 | # each quantizer codebook would require its own logits weight
604 | # and bias matrices
605 | # the amazing einops makes this easy with 'EinMix'
606 | self.to_logits = nn.Sequential(
607 | nn.LayerNorm(dim),
608 | Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
609 | EinMix(
610 | 'b n gq d -> b n gq l',
611 | weight_shape = 'gq d l',
612 | bias_shape = 'gq l',
613 | gq = num_effective_quantizers,
614 | l = codebook_size,
615 | d = dim
616 | ),
617 | Rearrange('b ... d -> b (...) d')
618 | )
619 |
620 | def forward(
621 | self,
622 | x,
623 | *,
624 | mask = None,
625 | cond = None,
626 | sum_embeds = None,
627 | return_embeddings = False,
628 | return_logits_and_embeddings = False
629 | ):
630 | """
631 | einops notation:
632 | b - batch
633 | n - sequence
634 | g - groups
635 | q - quantizers
636 | d - feature dimension
637 | """
638 |
639 | n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
640 | assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers' # noqa
641 |
642 | x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
643 | x = x + self.quantizer_offsets
644 |
645 | x = self.code_embeds(x)
646 |
647 | x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)
648 |
649 | x = self.embedding_proj(x)
650 |
651 | if exists(sum_embeds):
652 | x = x + sum_embeds
653 |
654 | if exists(cond):
655 | if cond.ndim == 2:
656 | cond = rearrange(cond, 'b d -> b 1 d')
657 |
658 | x = x + cond
659 |
660 | x = self.conformer(x, mask = mask)
661 | embeds = self.heads(x)
662 |
663 | if return_embeddings or not exists(self.to_logits):
664 | return embeds
665 |
666 | logits = self.to_logits(embeds)
667 |
668 | if return_logits_and_embeddings:
669 | return logits, embeds
670 |
671 | return logits
672 |
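
# Minimal usage sketch (illustrative values only, not part of the repo code):
#   wrapper = ConformerWrapper(
#       codebook_size=1024, num_quantizers=8,
#       conformer=dict(dim=512, num_layers=2))
#   tokens = torch.randint(0, 1024, (1, 8 * 50))  # 50 frames x 8 quantizers
#   logits = wrapper(tokens)                      # -> (1, 8 * 50, 1024)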
--------------------------------------------------------------------------------
/modules/s2a_model.py:
--------------------------------------------------------------------------------
1 | """A2S model definition.
2 |
3 | Copyright PolyAI Limited.
4 | """
5 | from typing import Union
6 |
7 | import pytorch_lightning as pl
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from einops import rearrange
13 |
14 | import constants as c
15 | from modules import masking_logic
16 | from modules.conformer import Conformer
17 | from modules.masking_logic import (State, mask_by_random_topk,
18 | sample_from_logits, state_init)
19 | from utils import load_checkpoint
20 |
21 |
22 | class Pheme(pl.LightningModule):
23 | def __init__(self, hp):
24 | super().__init__()
25 | self.hp = hp
26 | self.model = TTSConformer(hp)
27 | self.cross_entropy = nn.CrossEntropyLoss(
28 | label_smoothing=self.hp.label_smoothing,
29 | ignore_index=self.hp.n_codes
30 | )
31 | if self.hp.pretrained_path:
32 | self.load()
33 | else:
34 | self.apply(self.init_weights)
35 |
36 | if self.hp.only_inference:
37 | self.model.eval()
38 |
39 | self.save_hyperparameters()
40 |
41 | def load(self):
42 | state_dict = load_checkpoint(self.hp.pretrained_path)
43 | print(f"Parameters loaded from {self.hp.pretrained_path}")
44 | self.load_state_dict(state_dict, strict=True)
45 |
46 | def init_weights(self, module):
47 | if isinstance(module, nn.Linear):
48 | module.weight.data.normal_(mean=0.0, std=0.02)
49 | if module.bias is not None:
50 | module.bias.data.zero_()
51 | if isinstance(module, nn.Embedding):
52 | module.weight.data.normal_(mean=0.0, std=0.02)
53 | module._fill_padding_idx_with_zero()
54 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
55 | module.bias.data.zero_()
56 | module.weight.data.fill_(1.0)
57 | elif isinstance(module, nn.Conv1d):
58 | module.weight.data.normal_(mean=0.0, std=0.02)
59 | if module.bias is not None:
60 | module.bias.data.zero_()
61 |
62 | def configure_optimizers(self):
63 | optimizer_adam = optim.AdamW(
64 | self.parameters(), lr=self.hp.lr,
65 | betas=(self.hp.adam_beta1, self.hp.adam_beta2))
66 |
67 | # Learning rate scheduler
68 | num_training_steps = self.hp.training_step
69 | num_warmup_steps = self.hp.warmup_step
70 | num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps)
71 |
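        # Piecewise schedule: linear warmup -> flat plateau -> linear decay to 0.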
72 | def lambda_lr(current_step: int):
73 | if current_step < num_warmup_steps:
74 | return float(current_step) / float(max(1, num_warmup_steps))
75 | elif current_step < (num_warmup_steps + num_flat_steps):
76 | return 1.0
77 | return max(
78 | 0.0,
79 | float(num_training_steps - current_step)
80 | / float(
81 | max(1, num_training_steps - (num_warmup_steps + num_flat_steps)) # noqa
82 | ),
83 | )
84 |
85 | scheduler_adam = {
86 | "scheduler": optim.lr_scheduler.LambdaLR(
87 | optimizer_adam, lambda_lr),
88 | "interval": "step",
89 | }
90 | return [optimizer_adam], [scheduler_adam]
91 |
92 | def top_k_accuracy(self, y_true, y_pred_probabilities, k):
93 | _, sorted_indices = torch.sort(y_pred_probabilities, descending=True)
94 |
95 | # Get the top-k predictions
96 | top_k_indices = sorted_indices[:, :k]
97 | expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices)
98 |
99 | # Check if true labels exist in top-k predictions
100 | hits = torch.sum(torch.eq(top_k_indices, expanded_y_true))
101 | accuracy = hits.item() / (len(y_true) + 1e-7)
102 |
103 | return accuracy
104 |
105 | def training_step(self, batch, batch_idx):
106 | # Sample training level
107 | rvq_level = torch.randint(
108 |             0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)).item()
109 |
110 | target, chosen_tokens, _, _ = self.model(
111 | batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"],
112 | batch["quantization_lengths"],
113 | speaker_emb=batch["speaker"],
114 | min_seq_length=batch["quantization_lengths"].min().item())
115 |
116 | # Mask targets and labels
117 | mask = chosen_tokens
118 | target = target[mask]
119 |
120 | labels = batch["tts_quantize_input"][:, :, rvq_level]
121 | labels = labels[mask]
122 |
123 | loss = self.cross_entropy(target, labels)
124 | acc = (target.argmax(-1) == labels).float().mean()
125 | self.log("train/loss", loss, on_step=True, prog_bar=True)
126 | self.log("train/acc", acc, on_step=True, prog_bar=True)
127 | self.log(
128 | f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False)
129 |
130 | return loss
131 |
132 | def validation_step(self, batch, batch_idx, dataloader_idx=0):
133 | speaker_emb = batch["speaker"]
134 | acoustic_tokens = batch["tts_quantize_input"]
135 | semantic_tokens = batch["semantic_tokens"]
136 |
137 | if self.hp.only_inference:
138 |             self.model.inference(
139 |                 acoustic_tokens, semantic_tokens,
                torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
                rvq_levels=self.hp.first_n_lvls)
140 | else:
141 | rvq_level = torch.randint(
142 |                 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
143 | ).item()
144 |
145 | # FIXME: edge case
146 | if len(semantic_tokens.shape) == 3:
147 | semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T")
148 |
149 | target, chosen_tokens, _, _ = self.model(
150 | acoustic_tokens, rvq_level, semantic_tokens,
151 | torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
152 | speaker_emb=speaker_emb,
153 | min_seq_length=acoustic_tokens.shape[1]
154 | )
155 |
156 | target = target[chosen_tokens]
157 | labels = acoustic_tokens[:, :, rvq_level][chosen_tokens]
158 | loss = self.cross_entropy(target, labels)
159 |
160 | acc = (target.argmax(-1) == labels).float().mean()
161 | acc_5 = self.top_k_accuracy(labels, target, 5)
162 |
163 | self.log(
164 | f"val/dataset_{dataloader_idx}/loss",
165 | loss,
166 | on_epoch=True,
167 | logger=True,
168 | add_dataloader_idx=False,
169 | )
170 | self.log(
171 | f"val/dataset_{dataloader_idx}/acc_lvl",
172 | acc,
173 | on_epoch=True,
174 | logger=True,
175 | add_dataloader_idx=False,
176 | )
177 | self.log(
178 | f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}",
179 | acc,
180 | on_epoch=True,
181 | logger=True,
182 | add_dataloader_idx=False,
183 | )
184 | self.log(
185 | f"val/dataset_{dataloader_idx}/acc_top_5",
186 | acc_5,
187 | on_epoch=True,
188 | logger=True,
189 | add_dataloader_idx=False,
190 | )
191 | self.log(
192 | f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}",
193 | acc_5,
194 | on_epoch=True,
195 | logger=True,
196 | add_dataloader_idx=False,
197 | )
198 |
199 | def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0):
200 | acc = (logits.argmax(-1) == labels).float().mean()
201 | acc_5 = self.top_k_accuracy(labels, logits, 5)
202 | acc_10 = self.top_k_accuracy(labels, logits, 10)
203 |
204 | idx = torch.randperm(logits.shape[0])
205 | logits_shuffled = logits[idx]
206 | random = self.top_k_accuracy(labels, logits_shuffled, 10)
207 |         print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc}, "
208 | f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}")
209 |
210 |
211 | class TTSConformer(pl.LightningModule):
212 | def __init__(self, hp):
213 | super().__init__()
214 | self.hp = hp
215 | self.padding_id = self.hp.n_codes
216 |
217 | additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2]
218 |
219 | self.embedding = nn.ModuleList(
220 | [
221 | nn.Embedding(
222 | self.hp.n_codes + len(additional_codes),
223 | self.hp.hidden_size,
224 | padding_idx=self.padding_id)
225 | for _ in range(self.hp.n_cluster_groups)
226 | ]
227 | )
228 |
229 | # Additional modules
230 | self.semantic_embedding = nn.Embedding(
231 | self.hp.n_semantic_codes + len(additional_codes),
232 | self.hp.hidden_size,
233 | padding_idx=self.padding_id)
234 |
235 | if self.hp.use_spkr_emb:
236 | self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size)
237 |
238 | self.conformer = Conformer(
239 | dim=self.hp.hidden_size,
240 | num_layers=self.hp.enc_nlayers,
241 | heads=self.hp.nheads,
242 | dim_head=64,
243 | ff_mult=4, # 512*4=2048
244 | conv_expansion_factor=2,
245 | conv_kernel_size=self.hp.depthwise_conv_kernel_size,
246 | attn_dropout=self.hp.dropout,
247 | ff_dropout=self.hp.dropout,
248 | conv_dropout=self.hp.dropout,
249 | attn_flash=True,
250 | t5_rel_pos_bias=False
251 | )
252 |
253 | self.heads = nn.ModuleList(
254 | [
255 | nn.Linear(
256 | self.hp.hidden_size,
257 | self.hp.n_codes + len(additional_codes)
258 | )
259 | for _ in range(self.hp.n_cluster_groups)
260 | ]
261 | )
262 |
263 | def build_mask_from_lengths(self, length, max_len=None):
264 | max_len = max_len or length.max().item()
265 | mask = torch.arange(
266 | max_len, device=length.device)[None, :] >= length[:, None]
267 | return mask.bool()
268 |
269 | @torch.no_grad()
270 | def create_mask(
271 | self, B, T, lengths, mask_ratio=None, start_t=None,
272 | min_seq_length=None
273 | ):
274 | # 1. Define the random length of condition tokens given the shortest
275 | # audio in the batch
276 | if start_t is None:
277 | start_t = torch.randint(1, min_seq_length - 1, (1,)).item()
278 |
279 |         # 2. Mask other tokens - sample a different masking level per batch
280 | if mask_ratio is None:
281 | ratio = torch.rand(1).item()
282 | mask_ratio = masking_logic.schedule(ratio)
283 |
284 | # Create a random tensor with values between 0 and 1
285 | random_tensor = torch.rand(
286 | (B, T - start_t), dtype=torch.float).to(self.device)
287 | # Create a mask where values less than p are set to True
288 | initial_mask = random_tensor < mask_ratio
289 | length_mask = self.build_mask_from_lengths(
290 | lengths - start_t, T - start_t)
291 | # we can't pick up tokens past token lengths
292 | initial_mask = torch.logical_and(initial_mask, ~length_mask)
293 |
294 | # Constrain ratio to always include some samples
295 | # If all are False let's pick up at least one:
296 | if torch.sum(initial_mask) == 0:
297 | choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,))
298 | initial_mask[torch.arange(B), choose_steps] = torch.tensor(
299 | True, device=self.device)
300 |
301 | # 3. Add condition tokens containing information
302 | acoustic_token_mask = torch.cat(
303 | (torch.full((B, start_t), False, device=self.device), initial_mask), # noqa
304 | 1
305 | )
306 |
307 | return acoustic_token_mask, start_t, mask_ratio
308 |
309 | def process_input(
310 | self, data, lengths, rvq_level, min_seq_length=None,
311 | mask_ratio=None, start_t=None, acoustic_token_mask=None
312 | ):
313 | """
314 | data: (B, T, code_level, D)
315 | rvq_level: int
316 | """
317 | B = data.size(0)
318 | T = data.size(1)
319 | level_data = data[:, :, rvq_level, :] # [B, T, C, D] -> [B, T, D]
320 |
321 | # Choose acoustic tokens to mask
322 | if acoustic_token_mask is None:
323 | acoustic_token_mask, start_t, mask_ratio = self.create_mask(
324 | B, T, lengths, mask_ratio=mask_ratio, start_t=start_t,
325 | min_seq_length=min_seq_length)
326 | # Remove code information from chosen tokens
327 | level_data[acoustic_token_mask, :] = 0
328 |
329 | # Embed only lower rvq_level
330 | lower_code_data = data[:, :, :rvq_level, :].sum(dim=2)
331 |
332 | # Combine with chosen tokens at rvq_level.
333 | # Note: all tokens at rvq_level+1: will be discarded.
334 | summed_data = torch.add(lower_code_data, level_data)
335 |
336 | return summed_data, acoustic_token_mask, mask_ratio, start_t
337 |
338 | def forward(
339 | self, x, code_level, semantic_tokens, lengths,
340 | speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None,
341 | acoustic_token_mask=None
342 | ):
343 | # FIXME: parallelize this
344 | batch = []
345 | for lvl, embed in enumerate(self.embedding[:(code_level + 1)]):
346 | batch.append(embed(x[:, :, lvl])) # [B T D]
347 |
348 | x = torch.stack(batch, dim=2) # [B T C D]
349 | x, acoustic_token_mask, mask_ratio, start_t = self.process_input(
350 | x, lengths, code_level, min_seq_length=min_seq_length,
351 | mask_ratio=mask_ratio, start_t=start_t,
352 | acoustic_token_mask=acoustic_token_mask
353 | )
354 |
355 | # Add phoneme embeddings
356 | # Cross attention for all tokens?
357 |
358 | # Add semantic tokens
359 | # HACK ME
360 | semantic_emb = self.semantic_embedding(semantic_tokens)
361 | x = torch.add(x, semantic_emb)
362 | # FIXME pfb30
363 |
364 | # Merge different modalities
365 | if self.hp.use_spkr_emb:
366 | spkr_emb = F.normalize(speaker_emb, dim=-1)
367 | spkr_emb = self.spkr_linear(
368 | F.dropout(spkr_emb, self.hp.speaker_embed_dropout)
369 | )
370 | x = torch.add(x, spkr_emb)
371 |
372 | output_frames = self.conformer(x, None)
373 |
374 | x = self.heads[code_level](output_frames)
375 |
376 | return x, acoustic_token_mask, mask_ratio, start_t
377 |
378 | @torch.no_grad()
379 | def inference(
380 | self, codes, semantic_tokens,
381 | length: torch.LongTensor, rvq_levels=7,
382 | mask_ratio=0.99, maskgit_inference=True,
383 | start_t: Union[torch.LongTensor, None] = None,
384 | speaker_emb=None, steps=16
385 | ):
386 | # Use half of the recording for the conditioning
387 | if start_t is None:
388 | start_t = torch.tensor(int((codes.shape[1]) / 2)).long()
389 |
390 | start_t = start_t.item()
391 |
392 | for rvq_level in range(rvq_levels):
393 | original_codes = torch.clone(codes)
394 | if rvq_level == 0 and maskgit_inference:
395 | codes = self.multi_step_inference(
396 | original_codes, semantic_tokens, length,
397 | start_t=start_t, vamp_filtering=False,
398 | speaker_emb=speaker_emb, steps=16
399 | )
400 | else:
401 | codes = self.one_step_inference(
402 | original_codes, semantic_tokens, length,
403 | code_level=rvq_level,
404 | mask_ratio=mask_ratio, start_t=start_t,
405 | speaker_emb=speaker_emb
406 | )
407 |
408 | codes = rearrange(codes, 'T C -> 1 T C')
409 |
410 | # Remove any padding left
411 | codes = rearrange(codes, '1 T C -> 1 C T')
412 | codes = torch.where(codes >= self.hp.n_codes, 0, codes)
413 | acoustic_tokens = codes
414 | semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c')
415 | semantic_tokens = torch.where(
416 | semantic_tokens >= self.hp.n_codes, 0, semantic_tokens)
417 | codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1)
418 |
419 | return codes
420 |
421 | @torch.no_grad()
422 | def one_step_inference(
423 | self, original_codes, semantic_tokens, lengths, code_level=0,
424 | mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None
425 | ):
426 | codes = torch.clone(original_codes)
427 | logits, _, _, _ = self.forward(
428 | codes, code_level, semantic_tokens, lengths,
429 | mask_ratio=mask_ratio, start_t=start_t,
430 | speaker_emb=speaker_emb, acoustic_token_mask=False)
431 |
432 | if inference_setup == "argmax":
433 | probs = torch.nn.functional.softmax(logits, dim=-1)
434 |             top_indices = torch.argmax(probs, dim=-1)
435 | 
436 |         if inference_setup == "sampling":
437 |             top_indices = torch.distributions.Categorical(
438 |                 logits=logits).sample()
439 | 
440 |         codes = rearrange(codes, '1 T C -> T C')
441 |         codes[start_t:, code_level] = top_indices[0, start_t:]
442 |
443 | return codes
444 |
445 | @torch.no_grad()
446 | def multi_step_inference(
447 | self, original_codes, semantic_tokens, lengths,
448 | start_t: torch.LongTensor=None,
449 | choice_temperature=1.0, start_iter=0,
450 | steps=16, vamp_filtering=False, speaker_emb=None
451 | ):
452 | codes = torch.clone(original_codes)
453 | code_level = 0
454 | _, seq_len, _ = original_codes.shape
455 | mask_token_id = self.padding_id
456 |
457 | # Get true codes for the prompt
458 | prompt_mask = codes[:, :start_t, code_level]
459 |
460 | # Fill up rest with masks
461 | mask = torch.full(
462 | (1, seq_len - start_t), mask_token_id, device=self.device)
463 | inputs = torch.cat((prompt_mask, mask), 1)
464 |
465 | num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, axis=-1)
466 |
467 | # Initializes state
468 | state = state_init(inputs, steps, start_iter=start_iter)
469 |
470 | def loop_cond_fn(state):
471 |             """Iterative decoding loop termination condition."""
472 | not_at_end = (state.cur_index < steps)
473 | return not_at_end
474 |
475 | while loop_cond_fn(state):
476 |             # MaskGIT-style iterative decoding: one refinement step per loop.
477 | step = state.cur_index
478 | # Current input ids: [batch_size, seq_length].
479 | cur_ids = state.cur_seqs
480 |
481 | # Calls model on current seqs to get next-iteration seqs.
482 | with torch.no_grad():
483 | logits, _, _, _ = self.forward(
484 |                     rearrange(cur_ids, 'B T -> B T 1'),
485 | code_level,
486 | semantic_tokens, lengths,
487 | acoustic_token_mask=False,
488 | speaker_emb=speaker_emb)
489 |
490 | # Samples the ids using categorical sampling:
491 | if vamp_filtering:
492 | typical_mass = 0.2
493 | typical_min_tokens = 1
494 | top_p = None
495 | sample_cutoff = 0.5
496 | typical_filtering = False
497 | sampled_ids, selected_probs = sample_from_logits(
498 | logits, sample=((step / steps) <= sample_cutoff),
499 | temperature=choice_temperature,
500 | typical_filtering=typical_filtering,
501 | typical_mass=typical_mass,
502 | typical_min_tokens=typical_min_tokens,
503 | top_k=None, top_p=top_p, return_probs=True,
504 | )
505 | else:
506 | sampled_ids = torch.distributions.Categorical(
507 | logits=logits).sample()
508 |
509 | # Just updates the masked tokens.
510 | unknown_map = (cur_ids == mask_token_id)
511 | sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
512 | # Defines the mask ratio for the next round. The number to mask out
513 | # is determined by mask_ratio * unknown_number_in_the_beginning.
514 | ratio = 1. * (step + 1) / steps
515 | mask_ratio = masking_logic.schedule(ratio)
516 |
517 | # Updates final seqs with the current sampled_ids.
518 | final_seqs = torch.clone(state.final_seqs)
519 | final_seqs[:, step, :] = sampled_ids
520 | # Computes the probabilities of each selected tokens.
521 | probs = torch.nn.functional.softmax(logits, dim=-1)
522 | # Extract the probabilities of sampled ids
523 | selected_probs = torch.squeeze(
524 | torch.take_along_dim(
525 | probs, torch.unsqueeze(sampled_ids, -1) , -1),
526 | -1
527 | )
528 |
529 | # Ignores the tokens given in the input
530 | # by overwriting their confidence.
531 | selected_probs = torch.where(
532 | unknown_map, selected_probs, torch.inf)
533 | # Gets mask lens for each sample in the
534 | # batch according to the mask ratio.
535 | num_to_mask = torch.unsqueeze(
536 | torch.floor(num_mask_tokens_at_start * mask_ratio), 1)
537 |
538 |             # Keeps at least one prediction in this
539 |             # round and also masks out at least
540 |             # one token for the next iteration
541 | num_to_mask = torch.maximum(
542 | torch.tensor(1),
543 | torch.minimum(
544 | torch.sum(unknown_map, dim=-1, keepdim=True) - 1,
545 | num_to_mask)
546 | )
547 | # Adds noise for randomness
548 | masking = mask_by_random_topk(
549 | num_to_mask, selected_probs, choice_temperature * (1. - ratio))
550 | # Masks tokens with lower confidence.
551 | sampled_ids = torch.where(masking, mask_token_id, sampled_ids)
552 |
553 | state = State(
554 | cur_index=state.cur_index + 1,
555 | cur_seqs=sampled_ids,
556 | final_seqs=final_seqs
557 | )
558 |
559 | codes = torch.clone(original_codes)
560 | codes = rearrange(codes, '1 T C -> T C')
561 | codes[:, 0] = state.final_seqs[0][-1]
562 |
563 | return codes
564 |
--------------------------------------------------------------------------------