├── 1.ae.ipynb
├── 10.diffusion.ipynb
├── 11.pix2pix.ipynb
├── 2.vae.ipynb
├── 3.dcgan.ipynb
├── 4.wgan.ipynb
├── 5.wgangp.ipynb
├── 6.cyclegan.ipynb
├── 7.musegan.ipynb
├── 8.style transfer.ipynb
├── 9.simple_style transfer.ipynb
├── README.md
├── datas
│   ├── content.jpeg
│   ├── style.jpg
│   └── temp.midi
└── keras
    ├── 1.ae画mnist.ipynb
    ├── 10.diffusion.ipynb
    ├── 2.vae画mnist.ipynb
    ├── 3.vae画celeba.ipynb
    ├── 4.gan画quick_draw.ipynb
    ├── 5.wgan画cifar10.ipynb
    ├── 6.wgangp画celeba.ipynb
    ├── 7.cyclegan画apple2orange.ipynb
    ├── 8.lstm创作cello.ipynb
    ├── 9.musegan创作chorales.ipynb
    └── README.md
/7.musegan.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "e93fca6e",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "Using custom data configuration lansinuote--gen.2.chorales-2bf7c47eabbdde89\n",
14 | "Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--gen.2.chorales-2bf7c47eabbdde89/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
15 | ]
16 | },
17 | {
18 | "data": {
19 | "text/plain": [
20 | "((229, 4, 2, 16, 84), -1.0, 1.0)"
21 | ]
22 | },
23 | "execution_count": 1,
24 | "metadata": {},
25 | "output_type": "execute_result"
26 | }
27 | ],
28 | "source": [
29 | "#加载全部数据到内存中\n",
30 | "def get_data():\n",
31 | " from datasets import load_dataset\n",
32 | " import numpy as np\n",
33 | "\n",
34 | " #加载\n",
35 | " dataset = load_dataset('lansinuote/gen.2.chorales', split='train')\n",
36 | "\n",
37 | " #加载为numpy数据\n",
38 | " data = np.empty((229, 4, 2, 16, 84), dtype=np.float32)\n",
39 | " for i in range(len(dataset)):\n",
40 | " data[i] = dataset[i]['data']\n",
41 | "\n",
42 | " return data\n",
43 | "\n",
44 | "\n",
45 | "data = get_data()\n",
46 | "\n",
47 | "data.shape, data.min(), data.max()"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "id": "9e6c9a3e",
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "data": {
58 | "text/plain": [
59 | "(3, torch.Size([64, 4, 2, 16, 84]))"
60 | ]
61 | },
62 | "execution_count": 2,
63 | "metadata": {},
64 | "output_type": "execute_result"
65 | }
66 | ],
67 | "source": [
68 | "import torch\n",
69 | "\n",
70 | "loader = torch.utils.data.DataLoader(\n",
71 | " dataset=data,\n",
72 | " batch_size=64,\n",
73 | " shuffle=True,\n",
74 | " drop_last=True,\n",
75 | ")\n",
76 | "\n",
77 | "len(loader), next(iter(loader)).shape"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 3,
83 | "id": "f6100c70",
84 | "metadata": {},
85 | "outputs": [
86 | {
87 | "data": {
88 | "text/html": [
89 | "\n",
90 | "
\n",
91 | " \n",
93 | " \n",
94 | " "
104 | ],
105 | "text/plain": [
106 | ""
107 | ]
108 | },
109 | "metadata": {},
110 | "output_type": "display_data"
111 | },
112 | {
113 | "data": {
114 | "text/html": [
115 | "\n",
116 | " \n",
117 | " \n",
119 | " \n",
120 | " "
130 | ],
131 | "text/plain": [
132 | ""
133 | ]
134 | },
135 | "metadata": {},
136 | "output_type": "display_data"
137 | },
138 | {
139 | "data": {
140 | "text/html": [
141 | "\n",
142 | " \n",
143 | " \n",
145 | " \n",
146 | " "
156 | ],
157 | "text/plain": [
158 | ""
159 | ]
160 | },
161 | "metadata": {},
162 | "output_type": "display_data"
163 | }
164 | ],
165 | "source": [
166 | "import music21\n",
167 | "\n",
168 | "\n",
169 | "#工具类,不重要\n",
170 | "class Show():\n",
171 | " #工具函数,不重要\n",
172 | " def __merge_note(self, note, duration=None):\n",
173 | " import numpy as np\n",
174 | "\n",
175 | " if duration is None:\n",
176 | " duration = np.full(note.shape, fill_value=0.25, dtype=np.float32)\n",
177 | "\n",
178 | " #从前往后遍历\n",
179 | " for i in range(len(note) - 1):\n",
180 | " j = i + 1\n",
181 | "\n",
182 | " #判断相连的两个note是否相同,并且duration相加不大于1.0\n",
183 | " if note[i] == note[j] and duration[i] + duration[j] <= 1.0:\n",
184 | "\n",
185 | " #duration合并\n",
186 | " duration[i] += duration[j]\n",
187 | "\n",
188 | " #删除重复的note\n",
189 | " note = np.delete(note, j, axis=0)\n",
190 | " duration = np.delete(duration, j, axis=0)\n",
191 | "\n",
192 | " #递归调用\n",
193 | " return self.__merge_note(note, duration)\n",
194 | "\n",
195 | " return note, duration\n",
196 | "\n",
197 | " #工具函数,不重要\n",
198 | " def __save_to_mid(self, data):\n",
199 | " #data -> [32, 4]\n",
200 | " stream = music21.stream.Score()\n",
201 | " stream.append(music21.tempo.MetronomeMark(number=66))\n",
202 | "\n",
203 | " for i in range(4):\n",
204 | " channel = music21.stream.Part()\n",
205 | "\n",
206 | " notes, durations = self.__merge_note(data[:, i])\n",
207 | " notes, durations = notes.tolist(), durations.tolist()\n",
208 | " for n, d in zip(notes, durations):\n",
209 | " note = music21.note.Note(n)\n",
210 | " note.duration = music21.duration.Duration(d)\n",
211 | " channel.append(note)\n",
212 | "\n",
213 | " stream.append(channel)\n",
214 | "\n",
215 | " stream.write('midi', fp='./datas/temp.midi')\n",
216 | "\n",
217 | " def __call__(self, data):\n",
218 | " #[4, 2, 16, 84] -> [4, 2, 16] -> [32, 4]\n",
219 | " data = data.argmax(dim=-1).reshape(32, 4)\n",
220 | " data = data.to('cpu').detach().numpy()\n",
221 | " self.__save_to_mid(data)\n",
222 | "\n",
223 | " f = music21.midi.MidiFile()\n",
224 | " f.open('./datas/temp.midi')\n",
225 | " f.read()\n",
226 | " f.close()\n",
227 | " music21.midi.translate.midiFileToStream(f).show('midi')\n",
228 | "\n",
229 | "\n",
230 | "show = Show()\n",
231 | "\n",
232 | "for _ in range(3):\n",
233 | " show(next(iter(loader))[0])"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 4,
239 | "id": "6fa83528",
240 | "metadata": {},
241 | "outputs": [
242 | {
243 | "data": {
244 | "text/plain": [
245 | "torch.Size([2, 1, 1, 16, 84])"
246 | ]
247 | },
248 | "execution_count": 4,
249 | "metadata": {},
250 | "output_type": "execute_result"
251 | }
252 | ],
253 | "source": [
254 | "def get_gen_track():\n",
255 | " return torch.nn.Sequential(\n",
256 | " torch.nn.Linear(4 * 32, 1024),\n",
257 | " torch.nn.BatchNorm1d(1024),\n",
258 | " torch.nn.ReLU(inplace=True),\n",
259 | " torch.nn.Unflatten(unflattened_size=(512, 2, 1), dim=1),\n",
260 | " torch.nn.ConvTranspose2d(512,\n",
261 | " 512,\n",
262 | " kernel_size=(2, 1),\n",
263 | " stride=(2, 1),\n",
264 | " padding=0),\n",
265 | " torch.nn.BatchNorm2d(512),\n",
266 | " torch.nn.ReLU(inplace=True),\n",
267 | " torch.nn.ConvTranspose2d(512,\n",
268 | " 256,\n",
269 | " kernel_size=(2, 1),\n",
270 | " stride=(2, 1),\n",
271 | " padding=0),\n",
272 | " torch.nn.BatchNorm2d(256),\n",
273 | " torch.nn.ReLU(inplace=True),\n",
274 | " torch.nn.ConvTranspose2d(256,\n",
275 | " 256,\n",
276 | " kernel_size=(2, 1),\n",
277 | " stride=(2, 1),\n",
278 | " padding=0),\n",
279 | " torch.nn.BatchNorm2d(256),\n",
280 | " torch.nn.ReLU(inplace=True),\n",
281 | " torch.nn.ConvTranspose2d(256,\n",
282 | " 256,\n",
283 | " kernel_size=(1, 7),\n",
284 | " stride=(1, 7),\n",
285 | " padding=0),\n",
286 | " torch.nn.BatchNorm2d(256),\n",
287 | " torch.nn.ReLU(inplace=True),\n",
288 | " torch.nn.ConvTranspose2d(256,\n",
289 | " 1,\n",
290 | " kernel_size=(1, 12),\n",
291 | " stride=(1, 12),\n",
292 | " padding=0),\n",
293 | " torch.nn.Unflatten(unflattened_size=(1, 1), dim=1),\n",
294 | " )\n",
295 | "\n",
296 | "\n",
297 | "get_gen_track()(torch.randn(2, 128)).shape"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": 5,
303 | "id": "3529afe3",
304 | "metadata": {},
305 | "outputs": [
306 | {
307 | "data": {
308 | "text/plain": [
309 | "torch.Size([2, 32, 2])"
310 | ]
311 | },
312 | "execution_count": 5,
313 | "metadata": {},
314 | "output_type": "execute_result"
315 | }
316 | ],
317 | "source": [
318 | "def get_gen_block():\n",
319 | " return torch.nn.Sequential(\n",
320 | " torch.nn.Unflatten(unflattened_size=(32, 1, 1), dim=1),\n",
321 | " torch.nn.ConvTranspose2d(32,\n",
322 | " 1024,\n",
323 | " kernel_size=(2, 1),\n",
324 | " stride=(1, 1),\n",
325 | " padding=0), torch.nn.BatchNorm2d(1024),\n",
326 | " torch.nn.ReLU(inplace=True),\n",
327 | " torch.nn.ConvTranspose2d(1024,\n",
328 | " 32,\n",
329 | " kernel_size=(2 - 1, 1),\n",
330 | " stride=(1, 1),\n",
331 | " padding=0), torch.nn.BatchNorm2d(32),\n",
332 | " torch.nn.ReLU(inplace=True), torch.nn.Flatten(start_dim=2))\n",
333 | "\n",
334 | "\n",
335 | "get_gen_block()(torch.randn(2, 32)).shape"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 6,
341 | "id": "7f6371b9",
342 | "metadata": {},
343 | "outputs": [
344 | {
345 | "data": {
346 | "text/plain": [
347 | "torch.Size([2, 4, 2, 16, 84])"
348 | ]
349 | },
350 | "execution_count": 6,
351 | "metadata": {},
352 | "output_type": "execute_result"
353 | }
354 | ],
355 | "source": [
356 | "class GEN(torch.nn.Module):\n",
357 | "\n",
358 | " def __init__(self):\n",
359 | " super().__init__()\n",
360 | "\n",
361 | " self.gen_chord = get_gen_block()\n",
362 | "\n",
363 | " self.gen_melody = torch.nn.ModuleList(\n",
364 | " [get_gen_block() for _ in range(4)])\n",
365 | "\n",
366 | " self.gen_track = torch.nn.ModuleList(\n",
367 | " [get_gen_track() for _ in range(4)])\n",
368 | "\n",
369 | " def forward(self, chord, style, melody, groove):\n",
370 | " #chord -> [b, 32]\n",
371 | " #style -> [b, 32]\n",
372 | " #melody -> [b, 4, 32]\n",
373 | " #groove -> [b, 4, 32]\n",
374 | "\n",
375 | " #[b, 32] -> [b, 32, 2]\n",
376 | " out_chord = self.gen_chord(chord)\n",
377 | "\n",
378 | " out_i = []\n",
379 | " for i in range(2):\n",
380 | "\n",
381 | " out_j = []\n",
382 | " for j in range(4):\n",
383 | "\n",
384 | " #[b, 32] -> [b, 32, 2] -> [b, 32]\n",
385 | " out_melody = self.gen_melody[j](melody[:, j])[:, :, i]\n",
386 | "\n",
387 | " #[b, 32+32+32+32] -> [b, 128]\n",
388 | " out = torch.cat(\n",
389 | " [out_chord[:, :, i], style, out_melody, groove[:, j]],\n",
390 | " dim=1)\n",
391 | "\n",
392 | " #[b, 128] -> [b, 1, 1, 16, 84]\n",
393 | " out = self.gen_track[j](out)\n",
394 | "\n",
395 | " out_j.append(out)\n",
396 | "\n",
397 | " #[b, 1*4, 1, 16, 84] -> [b, 4, 1, 16, 84]\n",
398 | " out_i.append(torch.cat(out_j, dim=1))\n",
399 | "\n",
400 | " #[b, 4, 1*2, 16, 84] -> [b, 4, 2, 16, 84]\n",
401 | " out = torch.cat(out_i, dim=2)\n",
402 | "\n",
403 | " return out\n",
404 | "\n",
405 | "\n",
406 | "gen = GEN()\n",
407 | "\n",
408 | "gen(torch.randn(2, 32), torch.randn(2, 32), torch.randn(2, 4, 32),\n",
409 | " torch.randn(2, 4, 32)).shape"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": 7,
415 | "id": "c9570ee1",
416 | "metadata": {},
417 | "outputs": [
418 | {
419 | "data": {
420 | "text/plain": [
421 | "tensor([[0.0234],\n",
422 | " [0.0236]], grad_fn=)"
423 | ]
424 | },
425 | "execution_count": 7,
426 | "metadata": {},
427 | "output_type": "execute_result"
428 | }
429 | ],
430 | "source": [
431 | "def get_cls():\n",
432 | " return torch.nn.Sequential(\n",
433 | " torch.nn.Conv3d(4, 128, (2, 1, 1), (1, 1, 1), padding=0),\n",
434 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
435 | " torch.nn.Conv3d(128, 128, (2 - 1, 1, 1), (1, 1, 1), padding=0),\n",
436 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
437 | " torch.nn.Conv3d(128, 128, (1, 1, 12), (1, 1, 12), padding=0),\n",
438 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
439 | " torch.nn.Conv3d(128, 128, (1, 1, 7), (1, 1, 7), padding=0),\n",
440 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
441 | " torch.nn.Conv3d(128, 128, (1, 2, 1), (1, 2, 1), padding=0),\n",
442 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
443 | " torch.nn.Conv3d(128, 128, (1, 2, 1), (1, 2, 1), padding=0),\n",
444 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
445 | " torch.nn.Conv3d(128, 2 * 128, (1, 4, 1), (1, 2, 1), padding=(0, 1, 0)),\n",
446 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
447 | " torch.nn.Conv3d(2 * 128,\n",
448 | " 4 * 128, (1, 3, 1), (1, 2, 1),\n",
449 | " padding=(0, 1, 0)),\n",
450 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
451 | " torch.nn.Flatten(),\n",
452 | " torch.nn.Linear(4 * 128, 1024),\n",
453 | " torch.nn.LeakyReLU(0.3, inplace=True),\n",
454 | " torch.nn.Linear(1024, 1),\n",
455 | " )\n",
456 | "\n",
457 | "\n",
458 | "cls = get_cls()\n",
459 | "\n",
460 | "cls(torch.randn(2, 4, 2, 16, 84))"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": 8,
466 | "id": "minus-uniform",
467 | "metadata": {},
468 | "outputs": [
469 | {
470 | "data": {
471 | "text/plain": [
472 | "'cuda'"
473 | ]
474 | },
475 | "execution_count": 8,
476 | "metadata": {},
477 | "output_type": "execute_result"
478 | }
479 | ],
480 | "source": [
481 | "def set_requires_grad(model, requires_grad):\n",
482 | " for param in model.parameters():\n",
483 | " param.requires_grad_(requires_grad)\n",
484 | "\n",
485 | "def wasserstein(pred, label):\n",
486 | " return -(pred * label).mean()\n",
487 | "\n",
488 | "\n",
489 | "optimizer_cls = torch.optim.Adam(cls.parameters(), lr=1e-3, betas=(0.5, 0.9))\n",
490 | "optimizer_gen = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.9))\n",
491 | "\n",
492 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
493 | "\n",
494 | "gen.to(device)\n",
495 | "cls.to(device)\n",
496 | "\n",
497 | "device"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": 9,
503 | "id": "1170a748",
504 | "metadata": {},
505 | "outputs": [
506 | {
507 | "data": {
508 | "text/plain": [
509 | "tensor(0.9996, device='cuda:0', grad_fn=)"
510 | ]
511 | },
512 | "execution_count": 9,
513 | "metadata": {},
514 | "output_type": "execute_result"
515 | }
516 | ],
517 | "source": [
518 | "def get_gradient_penalty(real, fake):\n",
519 | " #real -> [64, 4, 2, 16, 84]\n",
520 | " #fake -> [64, 4, 2, 16, 84]\n",
521 | "\n",
522 | " r = torch.rand((64, 1, 1, 1, 1), device=device)\n",
523 | " r.requires_grad = True\n",
524 | "\n",
525 | " #[64, 4, 2, 16, 84]\n",
526 | " merge = r * real + (1 - r) * fake\n",
527 | "\n",
528 | " #[64, 4, 2, 16, 84] -> [64, 1]\n",
529 | " pred_merge = cls(merge)\n",
530 | "\n",
531 | " grad = torch.autograd.grad(inputs=merge,\n",
532 | " outputs=pred_merge,\n",
533 | " grad_outputs=torch.ones(64, 1, device=device),\n",
534 | " create_graph=True,\n",
535 | " retain_graph=True)\n",
536 | "\n",
537 | " #[64, 4, 2, 16, 84] -> [64, 10752]\n",
538 | " grad = grad[0].reshape(64, -1)\n",
539 | "\n",
540 | " #[64, 10752] -> [64]\n",
541 | " grad = grad.norm(p=2, dim=1)\n",
542 | "\n",
543 | " #[64] -> scala\n",
544 | " return (1 - grad).pow(2).mean()\n",
545 | "\n",
546 | "\n",
547 | "get_gradient_penalty(torch.randn(64, 4, 2, 16, 84, device=device),\n",
548 | " torch.randn(64, 4, 2, 16, 84, device=device))"
549 | ]
550 | },
551 | {
552 | "cell_type": "code",
553 | "execution_count": 10,
554 | "id": "31ad7ed8",
555 | "metadata": {},
556 | "outputs": [
557 | {
558 | "data": {
559 | "text/plain": [
560 | "9.99548053741455"
561 | ]
562 | },
563 | "execution_count": 10,
564 | "metadata": {},
565 | "output_type": "execute_result"
566 | }
567 | ],
568 | "source": [
569 | "def train_cls():\n",
570 | " set_requires_grad(cls, True)\n",
571 | " set_requires_grad(gen, False)\n",
572 | " \n",
573 | " #得到三份数据\n",
574 | " real = next(iter(loader)).to(device)\n",
575 | "\n",
576 | " with torch.no_grad():\n",
577 | " cord = torch.randn(64, 32, device=device)\n",
578 | " style = torch.randn(64, 32, device=device)\n",
579 | " melody = torch.randn(64, 4, 32, device=device)\n",
580 | " groove = torch.randn(64, 4, 32, device=device)\n",
581 | " fake = gen(cord, style, melody, groove)\n",
582 | "\n",
583 | " #分别计算\n",
584 | " pred_fake = cls(fake)\n",
585 | " pred_real = cls(real)\n",
586 | "\n",
587 | " #求loss,加权求和\n",
588 | " loss_fake = wasserstein(pred_fake, -torch.ones(64, 1, device=device))\n",
589 | " loss_real = wasserstein(pred_real, torch.ones(64, 1, device=device))\n",
590 | " loss_grad = get_gradient_penalty(real, fake)\n",
591 | "\n",
592 | " loss = loss_fake + loss_real + loss_grad * 10\n",
593 | "\n",
594 | " loss.backward()\n",
595 | " optimizer_cls.step()\n",
596 | " optimizer_cls.zero_grad()\n",
597 | "\n",
598 | " return loss.item()\n",
599 | "\n",
600 | "\n",
601 | "train_cls()"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": 11,
607 | "id": "414d78b2",
608 | "metadata": {},
609 | "outputs": [
610 | {
611 | "data": {
612 | "text/plain": [
613 | "0.015252873301506042"
614 | ]
615 | },
616 | "execution_count": 11,
617 | "metadata": {},
618 | "output_type": "execute_result"
619 | }
620 | ],
621 | "source": [
622 | "def train_gen():\n",
623 | " set_requires_grad(cls, False)\n",
624 | " set_requires_grad(gen, True)\n",
625 | " \n",
626 | " cord = torch.randn(64, 32, device=device)\n",
627 | " style = torch.randn(64, 32, device=device)\n",
628 | " melody = torch.randn(64, 4, 32, device=device)\n",
629 | " groove = torch.randn(64, 4, 32, device=device)\n",
630 | "\n",
631 | " fake = gen(cord, style, melody, groove)\n",
632 | " fake_pred = cls(fake)\n",
633 | "\n",
634 | " loss = wasserstein(fake_pred, torch.ones(64, 1, device=device))\n",
635 | " loss.backward()\n",
636 | " optimizer_gen.step()\n",
637 | " optimizer_gen.zero_grad()\n",
638 | "\n",
639 | " return loss.item()\n",
640 | "\n",
641 | "\n",
642 | "train_gen()"
643 | ]
644 | },
645 | {
646 | "cell_type": "code",
647 | "execution_count": 12,
648 | "id": "taken-cover",
649 | "metadata": {
650 | "scrolled": false
651 | },
652 | "outputs": [
653 | {
654 | "name": "stdout",
655 | "output_type": "stream",
656 | "text": [
657 | "0 -88.02761840820312 365.5087890625\n"
658 | ]
659 | },
660 | {
661 | "data": {
662 | "text/html": [
663 | "\n",
664 | " \n",
665 | " \n",
667 | " \n",
668 | " "
678 | ],
679 | "text/plain": [
680 | ""
681 | ]
682 | },
683 | "metadata": {},
684 | "output_type": "display_data"
685 | },
686 | {
687 | "name": "stdout",
688 | "output_type": "stream",
689 | "text": [
690 | "2000 -18.66831398010254 4.117001533508301\n"
691 | ]
692 | },
693 | {
694 | "data": {
695 | "text/html": [
696 | "\n",
697 | " \n",
698 | " \n",
700 | " \n",
701 | " "
711 | ],
712 | "text/plain": [
713 | ""
714 | ]
715 | },
716 | "metadata": {},
717 | "output_type": "display_data"
718 | },
719 | {
720 | "name": "stdout",
721 | "output_type": "stream",
722 | "text": [
723 | "4000 -18.55518341064453 -1.6856918334960938\n"
724 | ]
725 | },
726 | {
727 | "data": {
728 | "text/html": [
729 | "\n",
730 | " \n",
731 | " \n",
733 | " \n",
734 | " "
744 | ],
745 | "text/plain": [
746 | ""
747 | ]
748 | },
749 | "metadata": {},
750 | "output_type": "display_data"
751 | },
752 | {
753 | "name": "stdout",
754 | "output_type": "stream",
755 | "text": [
756 | "6000 -16.17269515991211 -4.4504075050354\n"
757 | ]
758 | },
759 | {
760 | "data": {
761 | "text/html": [
762 | "\n",
763 | " \n",
764 | " \n",
766 | " \n",
767 | " "
777 | ],
778 | "text/plain": [
779 | ""
780 | ]
781 | },
782 | "metadata": {},
783 | "output_type": "display_data"
784 | },
785 | {
786 | "name": "stdout",
787 | "output_type": "stream",
788 | "text": [
789 | "8000 -12.522842407226562 -3.564983367919922\n"
790 | ]
791 | },
792 | {
793 | "data": {
794 | "text/html": [
795 | "\n",
796 | " \n",
797 | " \n",
799 | " \n",
800 | " "
810 | ],
811 | "text/plain": [
812 | ""
813 | ]
814 | },
815 | "metadata": {},
816 | "output_type": "display_data"
817 | },
818 | {
819 | "name": "stdout",
820 | "output_type": "stream",
821 | "text": [
822 | "10000 -13.41166877746582 -3.798163890838623\n"
823 | ]
824 | },
825 | {
826 | "data": {
827 | "text/html": [
828 | "\n",
829 | " \n",
830 | " \n",
832 | " \n",
833 | " "
843 | ],
844 | "text/plain": [
845 | ""
846 | ]
847 | },
848 | "metadata": {},
849 | "output_type": "display_data"
850 | },
851 | {
852 | "name": "stdout",
853 | "output_type": "stream",
854 | "text": [
855 | "12000 -9.745172500610352 -6.956740379333496\n"
856 | ]
857 | },
858 | {
859 | "data": {
860 | "text/html": [
861 | "\n",
862 | " \n",
863 | " \n",
865 | " \n",
866 | " "
876 | ],
877 | "text/plain": [
878 | ""
879 | ]
880 | },
881 | "metadata": {},
882 | "output_type": "display_data"
883 | },
884 | {
885 | "name": "stdout",
886 | "output_type": "stream",
887 | "text": [
888 | "14000 -8.481831550598145 -3.7079086303710938\n"
889 | ]
890 | },
891 | {
892 | "data": {
893 | "text/html": [
894 | "\n",
895 | " \n",
896 | " \n",
898 | " \n",
899 | " "
909 | ],
910 | "text/plain": [
911 | ""
912 | ]
913 | },
914 | "metadata": {},
915 | "output_type": "display_data"
916 | },
917 | {
918 | "name": "stdout",
919 | "output_type": "stream",
920 | "text": [
921 | "16000 -8.238203048706055 -1.697916030883789\n"
922 | ]
923 | },
924 | {
925 | "data": {
926 | "text/html": [
927 | "\n",
928 | " \n",
929 | " \n",
931 | " \n",
932 | " "
942 | ],
943 | "text/plain": [
944 | ""
945 | ]
946 | },
947 | "metadata": {},
948 | "output_type": "display_data"
949 | },
950 | {
951 | "name": "stdout",
952 | "output_type": "stream",
953 | "text": [
954 | "18000 -7.125642776489258 -2.4775843620300293\n"
955 | ]
956 | },
957 | {
958 | "data": {
959 | "text/html": [
960 | "\n",
961 | " \n",
962 | " \n",
964 | " \n",
965 | " "
975 | ],
976 | "text/plain": [
977 | ""
978 | ]
979 | },
980 | "metadata": {},
981 | "output_type": "display_data"
982 | }
983 | ],
984 | "source": [
985 | "def train():\n",
986 | " for epoch in range(2_0000):\n",
987 | " for _ in range(5):\n",
988 | " loss_cls = train_cls()\n",
989 | "\n",
990 | " loss_gen = train_gen()\n",
991 | "\n",
992 | " if epoch % 2000 == 0:\n",
993 | " print(epoch, loss_cls, loss_gen)\n",
994 | "\n",
995 | " #这里的b必须要大于1,否则BatchNorm层的计算会出错\n",
996 | " chord = torch.rand(2, 32, device=device)\n",
997 | " style = torch.rand(2, 32, device=device)\n",
998 | " melody = torch.rand(2, 4, 32, device=device)\n",
999 | " groove = torch.rand(2, 4, 32, device=device)\n",
1000 | "\n",
1001 | " #[2, 4, 2, 16, 84]\n",
1002 | " pred = gen(chord, style, melody, groove)\n",
1003 | " show(pred[0])\n",
1004 | "\n",
1005 | "\n",
1006 | "local_training = True\n",
1007 | "\n",
1008 | "if local_training:\n",
1009 | " train()"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "code",
1014 | "execution_count": 13,
1015 | "id": "973e9776",
1016 | "metadata": {},
1017 | "outputs": [
1018 | {
1019 | "data": {
1020 | "application/vnd.jupyter.widget-view+json": {
1021 | "model_id": "3ae555b3113b4bd991eab096b47251ba",
1022 | "version_major": 2,
1023 | "version_minor": 0
1024 | },
1025 | "text/plain": [
1026 | "pytorch_model.bin: 0%| | 0.00/32.3M [00:00, ?B/s]"
1027 | ]
1028 | },
1029 | "metadata": {},
1030 | "output_type": "display_data"
1031 | },
1032 | {
1033 | "data": {
1034 | "application/vnd.jupyter.widget-view+json": {
1035 | "model_id": "4c8255497cfe4e4d90168b4cbd79638b",
1036 | "version_major": 2,
1037 | "version_minor": 0
1038 | },
1039 | "text/plain": [
1040 | "Upload 1 LFS files: 0%| | 0/1 [00:00, ?it/s]"
1041 | ]
1042 | },
1043 | "metadata": {},
1044 | "output_type": "display_data"
1045 | }
1046 | ],
1047 | "source": [
1048 | "from transformers import PreTrainedModel, PretrainedConfig\n",
1049 | "\n",
1050 | "\n",
1051 | "class Model(PreTrainedModel):\n",
1052 | " config_class = PretrainedConfig\n",
1053 | "\n",
1054 | " def __init__(self, config):\n",
1055 | " super().__init__(config)\n",
1056 | " self.cls = cls.to('cpu')\n",
1057 | " self.gen = gen.to('cpu')\n",
1058 | "\n",
1059 | "\n",
1060 | "if local_training:\n",
1061 | " #保存训练好的模型到hub\n",
1062 | " Model(PretrainedConfig()).push_to_hub(\n",
1063 | " repo_id='lansinuote/gen.7.musegan',\n",
1064 | " use_auth_token=open('/root/hub_token.txt').read().strip())"
1065 | ]
1066 | },
1067 | {
1068 | "cell_type": "code",
1069 | "execution_count": 14,
1070 | "id": "e933c6de",
1071 | "metadata": {},
1072 | "outputs": [
1073 | {
1074 | "data": {
1075 | "application/vnd.jupyter.widget-view+json": {
1076 | "model_id": "c34bf31ec95d4307ad8a37f98d0c5272",
1077 | "version_major": 2,
1078 | "version_minor": 0
1079 | },
1080 | "text/plain": [
1081 | "Downloading (…)\"pytorch_model.bin\";: 0%| | 0.00/32.3M [00:00, ?B/s]"
1082 | ]
1083 | },
1084 | "metadata": {},
1085 | "output_type": "display_data"
1086 | },
1087 | {
1088 | "data": {
1089 | "text/html": [
1090 | "\n",
1091 | " \n",
1092 | " \n",
1094 | " \n",
1095 | " "
1105 | ],
1106 | "text/plain": [
1107 | ""
1108 | ]
1109 | },
1110 | "metadata": {},
1111 | "output_type": "display_data"
1112 | }
1113 | ],
1114 | "source": [
1115 | "#加载训练好的模型\n",
1116 | "gen = Model.from_pretrained('lansinuote/gen.7.musegan').gen\n",
1117 | "with torch.no_grad():\n",
1118 | " #这里的b必须要大于1,否则BatchNorm层的计算会出错\n",
1119 | " chord = torch.rand(2, 32)\n",
1120 | " style = torch.rand(2, 32)\n",
1121 | " melody = torch.rand(2, 4, 32)\n",
1122 | " groove = torch.rand(2, 4, 32)\n",
1123 | "\n",
1124 | " #[2, 4, 2, 16, 84]\n",
1125 | " pred = gen(chord, style, melody, groove)\n",
1126 | " show(pred[0])"
1127 | ]
1128 | }
1129 | ],
1130 | "metadata": {
1131 | "kernelspec": {
1132 | "display_name": "Python [conda env:pt39]",
1133 | "language": "python",
1134 | "name": "conda-env-pt39-py"
1135 | },
1136 | "language_info": {
1137 | "codemirror_mode": {
1138 | "name": "ipython",
1139 | "version": 3
1140 | },
1141 | "file_extension": ".py",
1142 | "mimetype": "text/x-python",
1143 | "name": "python",
1144 | "nbconvert_exporter": "python",
1145 | "pygments_lexer": "ipython3",
1146 | "version": "3.9.13"
1147 | }
1148 | },
1149 | "nbformat": 4,
1150 | "nbformat_minor": 5
1151 | }
1152 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Video course: https://www.bilibili.com/video/BV1hs4y1P7T3
2 |
3 | Environment:
4 |
5 | python==3.9
6 |
7 | torch==1.12.1+cu113
8 |
9 | transformers==4.26.1
10 |
11 | datasets==2.9.0
12 |
13 | music21==8.1.0
14 |
15 | Update on April 27, 2023:
16 |
17 | For the five tasks 1.ae, 2.vae, 3.dcgan, 4.wgan, 5.wgangp, the latent code of the generative models was reduced from 768 to 128 dimensions.
18 |
--------------------------------------------------------------------------------
/datas/content.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lansinuote/Simple_Generative_in_PyTorch/d349f0efc7062ac258258613fe98d31c13e3495a/datas/content.jpeg
--------------------------------------------------------------------------------
/datas/style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lansinuote/Simple_Generative_in_PyTorch/d349f0efc7062ac258258613fe98d31c13e3495a/datas/style.jpg
--------------------------------------------------------------------------------
/datas/temp.midi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lansinuote/Simple_Generative_in_PyTorch/d349f0efc7062ac258258613fe98d31c13e3495a/datas/temp.midi
--------------------------------------------------------------------------------
/keras/8.lstm创作cello.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/html": [
11 | "\n",
12 | " \n",
13 | " \n",
15 | " "
25 | ],
26 | "text/plain": [
27 | ""
28 | ]
29 | },
30 | "metadata": {},
31 | "output_type": "display_data"
32 | }
33 | ],
34 | "source": [
35 | "import music21\n",
36 | "\n",
37 | "\n",
38 | "def show(file):\n",
39 | " f = music21.midi.MidiFile()\n",
40 | " f.open(file)\n",
41 | " f.read()\n",
42 | " f.close()\n",
43 | " music21.midi.translate.midiFileToStream(f).show('midi')\n",
44 | "\n",
45 | "\n",
46 | "show('../datas/cello/cs2-2all.mid')"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "metadata": {},
53 | "outputs": [
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "([['START', 'A3', 'D2.A2.F3.A3', 'B-3', 'A3', 'G3', 'F3', 'E3', 'D3', 'D3', 'C#3', 'D3', 'E3', 'A2', 'B-2', 'G2', 'F2', 'A2', 'D3', 'F2', 'E2', 'C#3', 'D2.A2.D3', 'E3', 'F3', 'G3', 'A3', 'B-3', 'D2.A2.F#3.C4', 'D4', 'E-4', 'D4', 'C4', 'B-3', 'A3', 'C4', 'B-3', 'A3', 'G3', 'D4', 'F3', 'E3', 'G3', 'B-3', 'D4', 'C4', 'B-3', 'A3', 'G3', 'B-3', 'A3', 'G3', 'F3', 'F3.A3', 'F3.A3', 'B3', 'F3', 'E3', 'D3', 'E3', 'C#4', 'D4', 'C#4', 'D3.D4', 'E4', 'F4', 'E4', 'D4', 'E4', 'D4', 'C4', 'B3', 'C4', 'B3', 'A3', 'G#3', 'A3', 'G#3', 'F#3', 'E3', 'E4', 'C4', 'A3', 'G3', 'F3.E4', 'A3', 'F3', 'D3', 'B2.D3', 'F3', 'D3', 'B2', 'G#2', 'B2', 'E3', 'G#3', 'B3', 'D4', 'C4', 'B3', 'C4', 'A3', 'F3', 'E3', 'D3', 'F3', 'E3', 'D3', 'G#3', 'A3', 'B3', 'D4', 'E3', 'D3', 'C3', 'E3', 'A3', 'D4', 'E3.B3', 'A3', 'E-3.A3', 'E-3', 'E3', 'F#3.G#3', 'A3.B3', 'C4', 'D4', 'C4', 'B3', 'A3.C4', 'D3', 'G#3', 'A3', 'B3', 'A3', 'G#3', 'F#3', 'E3', 'C3.E3.A3', 'F3', 'E3', 'D3', 'C3', 'B2', 'A2', 'G#2.D3.B3', 'E3', 'F3', 'E3', 'D3', 'C3', 'B2', 'D4', 'B3', 'C4', 'A3', 'E3', 'G#3', 'A2', 'C#3', 'E3', 'G3', 'F3', 'E3', 'F3', 'A3', 'D4', 'G#3', 'A3', 'A3', 'D2.A2.F3.A3', 'B-3', 'A3', 'G3', 'F3', 'E3', 'D3', 'D3', 'C#3', 'D3', 'E3', 'A2', 'B-2', 'G2', 'F2', 'A2', 'D3', 'F2', 'E2', 'C#3', 'D2.A2.D3', 'E3', 'F3', 'G3', 'A3', 'B-3', 'D2.A2.F#3.C4', 'D4', 'E-4', 'D4', 'C4', 'B-3', 'A3', 'C4', 'B-3', 'A3', 'G3', 'D4', 'F3', 'E3', 'G3', 'B-3', 'D4', 'C4', 'B-3', 'A3', 'G3', 'B-3', 'A3', 'G3', 'F3', 'F3.A3', 'F3.A3', 'B3', 'F3', 'E3', 'D3', 'E3', 'C#4', 'D4', 'C#4', 'D3.D4', 'E4', 'F4', 'E4', 'D4', 'E4', 'D4', 'C4', 'B3', 'C4', 'B3', 'A3', 'G#3', 'A3', 'G#3', 'F#3', 'E3', 'E4', 'C4', 'A3', 'G3', 'F3.E4', 'A3', 'F3', 'D3', 'B2.D3', 'F3', 'D3', 'B2', 'G#2', 'B2', 'E3', 'G#3', 'B3', 'D4', 'C4', 'B3', 'C4', 'A3', 'F3', 'E3', 'D3', 'F3', 'E3', 'D3', 'G#3', 'A3', 'B3', 'D4', 'E3', 'D3', 'C3', 'E3', 'A3', 'D4', 'E3.B3', 'A3', 'E-3.A3', 'E-3', 'E3', 'F#3.G#3', 'A3.B3', 'C4', 'D4', 'C4', 'B3', 'A3.C4', 'D3', 'G#3', 'A3', 'B3', 'A3', 'G#3', 'F#3', 'E3', 'C3.E3.A3', 'F3', 'E3', 'D3', 'C3', 'B2', 'A2', 'G#2.D3.B3', 'E3', 'F3', 'E3', 'D3', 'C3', 'B2', 'D4', 'B3', 'C4', 'A3', 'E3', 'G#3', 'A2', 'C#3', 'E3', 'G3', 'F3', 'E3', 'F3', 'A3', 'D4', 'G#3', 'A3', 'E3', 'A2.E3.C#4', 'F3', 'G3', 'E3', 'F3', 'A3', 'C#3', 'D3', 'E3', 'B-2', 'A2', 'G2', 'F2', 'A3', 'F3', 'D3', 'G3', 'B2', 'C#3', 'A3', 'G3', 'F3', 'E3', 'D3', 'F#3', 'D3', 'E-3', 'C3', 'B-2', 'G3', 'A2', 'G2', 'F#2', 'A2', 'D3', 'C4', 'B-3', 'F#3', 'G3', 'B-3', 'D4', 'A3', 'B-3', 'G3', 'E-3', 'D3', 'E-3', 'G3', 'C4', 'A3', 'B-3', 'G3', 'D3', 'C3', 'D3', 'G3', 'B-3', 'F#3', 'G3', 'E-3', 'C3', 'B-2', 'C3', 'B-3', 'A3', 'C4', 'E-4', 'G3', 'C3.F#3', 'G3', 'A3', 'D3', 'E-3', 'C3', 'B-2', 'D3', 'G3', 'B-2', 'D2', 'F#3', 'G2.G3', 'A3', 'B-3', 'D4', 'G3', 'F3', 'B-2.E3', 'F3', 'G3', 'E3', 'C3', 'B-2', 'A2', 'F3', 'G2', 'F2', 'E2', 'G3', 'A3', 'B-3', 'B-3', 'A3', 'G3', 'F3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'F3', 'A3', 'D4', 'A3', 'B-3', 'G3', 'A2', 'G3', 'C#4', 'D4', 'E4', 'G3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'G#2', 'F3', 'E3', 'D3', 'D3', 'C#3', 'B2', 'A2', 'C3', 'A2', 'F#2', 'D3', 'C3', 'A2', 'B2', 'D3', 'F3', 'D3', 'G#2', 'D3', 'C#3', 'E3', 'G3', 'B-3', 'E4', 'A3', 'B-3', 'G3', 'F3', 'C#3', 'D3', 'G#2', 'A2', 'C#3', 'D2', 'D4', 'C4', 'A3', 'B-3', 'G3', 'E3', 'C#4', 'D4', 'A3', 'F3', 'D3', 'D2', 'E3', 'A2.E3.C#4', 'F3', 'G3', 'E3', 'F3', 'A3', 'C#3', 'D3', 'E3', 'B-2', 'A2', 'G2', 'F2', 'A3', 'F3', 'D3', 'G3', 'B2', 'C#3', 'A3', 'G3', 'F3', 'E3', 'D3', 'F#3', 'D3', 'E-3', 'C3', 'B-2', 'G3', 'A2', 
'G2', 'F#2', 'A2', 'D3', 'C4', 'B-3', 'F#3', 'G3', 'B-3', 'D4', 'A3', 'B-3', 'G3', 'E-3', 'D3', 'E-3', 'G3', 'C4', 'A3', 'B-3', 'G3', 'D3', 'C3', 'D3', 'G3', 'B-3', 'F#3', 'G3', 'E-3', 'C3', 'B-2', 'C3', 'B-3', 'A3', 'C4', 'E-4', 'G3', 'C3.F#3', 'G3', 'A3', 'D3', 'E-3', 'C3', 'B-2', 'D3', 'G3', 'B-2', 'D2', 'F#3', 'G2.G3', 'A3', 'B-3', 'D4', 'G3', 'F3', 'B-2.E3', 'F3', 'G3', 'E3', 'C3', 'B-2', 'A2', 'F3', 'G2', 'F2', 'E2', 'G3', 'A3', 'B-3', 'B-3', 'A3', 'G3', 'F3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'F3', 'A3', 'D4', 'A3', 'B-3', 'G3', 'A2', 'G3', 'C#4', 'D4', 'E4', 'G3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'G#2', 'F3', 'E3', 'D3', 'D3', 'C#3', 'B2', 'A2', 'C3', 'A2', 'F#2', 'D3', 'C3', 'A2', 'B2', 'D3', 'F3', 'D3', 'G#2', 'D3', 'C#3', 'E3', 'G3', 'B-3', 'E4', 'A3', 'B-3', 'G3', 'F3', 'C#3', 'D3', 'G#2', 'A2', 'C#3', 'D2', 'D4', 'C4', 'A3', 'B-3', 'G3', 'E3', 'C#4', 'D4', 'A3', 'F3', 'D3', 'D2', 'START'], ['START', 'B-3', 'E-2.B-2.F#3.B-3', 'B3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'E-3', 'F3', 'B-2', 'B2', 'G#2', 'F#2', 'B-2', 'E-3', 'F#2', 'F2', 'D3', 'E-2.B-2.E-3', 'F3', 'F#3', 'G#3', 'B-3', 'B3', 'E-2.B-2.G3.C#4', 'E-4', 'E4', 'E-4', 'C#4', 'B3', 'B-3', 'C#4', 'B3', 'B-3', 'G#3', 'E-4', 'F#3', 'F3', 'G#3', 'B3', 'E-4', 'C#4', 'B3', 'B-3', 'G#3', 'B3', 'B-3', 'G#3', 'F#3', 'F#3.B-3', 'F#3.B-3', 'C4', 'F#3', 'F3', 'E-3', 'F3', 'D4', 'E-4', 'D4', 'E-3.E-4', 'F4', 'F#4', 'F4', 'E-4', 'F4', 'E-4', 'C#4', 'C4', 'C#4', 'C4', 'B-3', 'A3', 'B-3', 'A3', 'G3', 'F3', 'F4', 'C#4', 'B-3', 'G#3', 'F#3.F4', 'B-3', 'F#3', 'E-3', 'C3.E-3', 'F#3', 'E-3', 'C3', 'A2', 'C3', 'F3', 'A3', 'C4', 'E-4', 'C#4', 'C4', 'C#4', 'B-3', 'F#3', 'F3', 'E-3', 'F#3', 'F3', 'E-3', 'A3', 'B-3', 'C4', 'E-4', 'F3', 'E-3', 'C#3', 'F3', 'B-3', 'E-4', 'F3.C4', 'B-3', 'E3.B-3', 'E3', 'F3', 'G3.A3', 'B-3.C4', 'C#4', 'E-4', 'C#4', 'C4', 'B-3.C#4', 'E-3', 'A3', 'B-3', 'C4', 'B-3', 'A3', 'G3', 'F3', 'C#3.F3.B-3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'B-2', 'A2.E-3.C4', 'F3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'E-4', 'C4', 'C#4', 'B-3', 'F3', 'A3', 'B-2', 'D3', 'F3', 'G#3', 'F#3', 'F3', 'F#3', 'B-3', 'E-4', 'A3', 'B-3', 'B-3', 'E-2.B-2.F#3.B-3', 'B3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'E-3', 'F3', 'B-2', 'B2', 'G#2', 'F#2', 'B-2', 'E-3', 'F#2', 'F2', 'D3', 'E-2.B-2.E-3', 'F3', 'F#3', 'G#3', 'B-3', 'B3', 'E-2.B-2.G3.C#4', 'E-4', 'E4', 'E-4', 'C#4', 'B3', 'B-3', 'C#4', 'B3', 'B-3', 'G#3', 'E-4', 'F#3', 'F3', 'G#3', 'B3', 'E-4', 'C#4', 'B3', 'B-3', 'G#3', 'B3', 'B-3', 'G#3', 'F#3', 'F#3.B-3', 'F#3.B-3', 'C4', 'F#3', 'F3', 'E-3', 'F3', 'D4', 'E-4', 'D4', 'E-3.E-4', 'F4', 'F#4', 'F4', 'E-4', 'F4', 'E-4', 'C#4', 'C4', 'C#4', 'C4', 'B-3', 'A3', 'B-3', 'A3', 'G3', 'F3', 'F4', 'C#4', 'B-3', 'G#3', 'F#3.F4', 'B-3', 'F#3', 'E-3', 'C3.E-3', 'F#3', 'E-3', 'C3', 'A2', 'C3', 'F3', 'A3', 'C4', 'E-4', 'C#4', 'C4', 'C#4', 'B-3', 'F#3', 'F3', 'E-3', 'F#3', 'F3', 'E-3', 'A3', 'B-3', 'C4', 'E-4', 'F3', 'E-3', 'C#3', 'F3', 'B-3', 'E-4', 'F3.C4', 'B-3', 'E3.B-3', 'E3', 'F3', 'G3.A3', 'B-3.C4', 'C#4', 'E-4', 'C#4', 'C4', 'B-3.C#4', 'E-3', 'A3', 'B-3', 'C4', 'B-3', 'A3', 'G3', 'F3', 'C#3.F3.B-3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'B-2', 'A2.E-3.C4', 'F3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'E-4', 'C4', 'C#4', 'B-3', 'F3', 'A3', 'B-2', 'D3', 'F3', 'G#3', 'F#3', 'F3', 'F#3', 'B-3', 'E-4', 'A3', 'B-3', 'F3', 'B-2.F3.D4', 'F#3', 'G#3', 'F3', 'F#3', 'B-3', 'D3', 'E-3', 'F3', 'B2', 'B-2', 'G#2', 'F#2', 'B-3', 'F#3', 'E-3', 'G#3', 'C3', 'D3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'G3', 'E-3', 'E3', 'C#3', 'B2', 'G#3', 'B-2', 'G#2', 'G2', 'B-2', 'E-3', 
'C#4', 'B3', 'G3', 'G#3', 'B3', 'E-4', 'B-3', 'B3', 'G#3', 'E3', 'E-3', 'E3', 'G#3', 'C#4', 'B-3', 'B3', 'G#3', 'E-3', 'C#3', 'E-3', 'G#3', 'B3', 'G3', 'G#3', 'E3', 'C#3', 'B2', 'C#3', 'B3', 'B-3', 'C#4', 'E4', 'G#3', 'C#3.G3', 'G#3', 'B-3', 'E-3', 'E3', 'C#3', 'B2', 'E-3', 'G#3', 'B2', 'E-2', 'G3', 'G#2.G#3', 'B-3', 'B3', 'E-4', 'G#3', 'F#3', 'B2.F3', 'F#3', 'G#3', 'F3', 'C#3', 'B2', 'B-2', 'F#3', 'G#2', 'F#2', 'F2', 'G#3', 'B-3', 'B3', 'B3', 'B-3', 'G#3', 'F#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'F#3', 'B-3', 'E-4', 'B-3', 'B3', 'G#3', 'B-2', 'G#3', 'D4', 'E-4', 'F4', 'G#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'A2', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'C3', 'B-2', 'C#3', 'B-2', 'G2', 'E-3', 'C#3', 'B-2', 'C3', 'E-3', 'F#3', 'E-3', 'A2', 'E-3', 'D3', 'F3', 'G#3', 'B3', 'F4', 'B-3', 'B3', 'G#3', 'F#3', 'D3', 'E-3', 'A2', 'B-2', 'D3', 'E-2', 'E-4', 'C#4', 'B-3', 'B3', 'G#3', 'F3', 'D4', 'E-4', 'B-3', 'F#3', 'E-3', 'E-2', 'F3', 'B-2.F3.D4', 'F#3', 'G#3', 'F3', 'F#3', 'B-3', 'D3', 'E-3', 'F3', 'B2', 'B-2', 'G#2', 'F#2', 'B-3', 'F#3', 'E-3', 'G#3', 'C3', 'D3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'G3', 'E-3', 'E3', 'C#3', 'B2', 'G#3', 'B-2', 'G#2', 'G2', 'B-2', 'E-3', 'C#4', 'B3', 'G3', 'G#3', 'B3', 'E-4', 'B-3', 'B3', 'G#3', 'E3', 'E-3', 'E3', 'G#3', 'C#4', 'B-3', 'B3', 'G#3', 'E-3', 'C#3', 'E-3', 'G#3', 'B3', 'G3', 'G#3', 'E3', 'C#3', 'B2', 'C#3', 'B3', 'B-3', 'C#4', 'E4', 'G#3', 'C#3.G3', 'G#3', 'B-3', 'E-3', 'E3', 'C#3', 'B2', 'E-3', 'G#3', 'B2', 'E-2', 'G3', 'G#2.G#3', 'B-3', 'B3', 'E-4', 'G#3', 'F#3', 'B2.F3', 'F#3', 'G#3', 'F3', 'C#3', 'B2', 'B-2', 'F#3', 'G#2', 'F#2', 'F2', 'G#3', 'B-3', 'B3', 'B3', 'B-3', 'G#3', 'F#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'F#3', 'B-3', 'E-4', 'B-3', 'B3', 'G#3', 'B-2', 'G#3', 'D4', 'E-4', 'F4', 'G#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'A2', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'C3', 'B-2', 'C#3', 'B-2', 'G2', 'E-3', 'C#3', 'B-2', 'C3', 'E-3', 'F#3', 'E-3', 'A2', 'E-3', 'D3', 'F3', 'G#3', 'B3', 'F4', 'B-3', 'B3', 'G#3', 'F#3', 'D3', 'E-3', 'A2', 'B-2', 'D3', 'E-2', 'E-4', 'C#4', 'B-3', 'B3', 'G#3', 'F3', 'D4', 'E-4', 'B-3', 'F#3', 'E-3', 'E-2', 'START']], [[0, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 
0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0], [0, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), 
Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0]])\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "def get_note_duration(file):\n",
64 | " #读取数据\n",
65 | " file = music21.converter.parse(file).chordify()\n",
66 | "\n",
67 | " note = []\n",
68 | " duration = []\n",
69 | "\n",
70 | " #不知道为什么是0和1,总之1个mid文件里能解析出两条音轨\n",
71 | " for i in [0, 1]:\n",
72 | "\n",
73 | " #开始符号\n",
74 | " n = ['START']\n",
75 | " d = [0]\n",
76 | "\n",
77 | " #读取音符和持续时间\n",
78 | " for j in file.transpose(i).flat:\n",
79 | " if not isinstance(j, music21.chord.Chord):\n",
80 | " continue\n",
81 | "\n",
82 | " #在同一个时间点可能有多个音符,把他们都前后拼合在一起,以\".\"间隔\n",
83 | " n_join = [k.nameWithOctave for k in j.pitches]\n",
84 | " n_join = '.'.join(n_join)\n",
85 | " n.append(n_join)\n",
86 | "\n",
87 | " #取持续时间\n",
88 | " d.append(j.duration.quarterLength)\n",
89 | "\n",
90 | " #结束符号\n",
91 | " n.append('START')\n",
92 | " d.append(0)\n",
93 | "\n",
94 | " note.append(n)\n",
95 | " duration.append(d)\n",
96 | "\n",
97 | " return note, duration\n",
98 | "\n",
99 | "\n",
100 | "print(get_note_duration('../datas/cello/cs2-2all.mid'))"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 3,
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "name": "stdout",
110 | "output_type": "stream",
111 | "text": [
112 | "['START', 'C4', 'B3', 'A3', 'G3', 'F3', 'E3', 'D3', 'C3', 'G2', 'E2', 'G2', 'C2', 'D2', 'E2', 'F2', 'G2', 'A2', 'B2', 'C3']\n",
113 | "[0, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 1.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]\n"
114 | ]
115 | },
116 | {
117 | "data": {
118 | "text/plain": [
119 | "(72, 72)"
120 | ]
121 | },
122 | "execution_count": 3,
123 | "metadata": {},
124 | "output_type": "execute_result"
125 | }
126 | ],
127 | "source": [
128 | "import os\n",
129 | "\n",
130 | "\n",
131 | "def load_datas():\n",
132 | " note = []\n",
133 | " duration = []\n",
134 | "\n",
135 | " #读取文件列表\n",
136 | " files = ['../datas/cello/' + i for i in os.listdir('../datas/cello')]\n",
137 | "\n",
138 | " for i in files:\n",
139 | " n, d = get_note_duration(i)\n",
140 | " note.extend(n)\n",
141 | " duration.extend(d)\n",
142 | "\n",
143 | " return note, duration\n",
144 | "\n",
145 | "\n",
146 | "note, duration = load_datas()\n",
147 | "\n",
148 | "print(note[0][:20])\n",
149 | "print(duration[0][:20])\n",
150 | "\n",
151 | "len(note), len(duration)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 4,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "name": "stdout",
161 | "output_type": "stream",
162 | "text": [
163 | "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11, 12, 10, 13, 9, 14, 15, 8]\n",
164 | "[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2]\n",
165 | "72 72\n",
166 | "{'START': 0, 'C4': 1, 'B3': 2, 'A3': 3, 'G3': 4, 'F3': 5, 'E3': 6, 'D3': 7, 'C3': 8, 'G2': 9, 'E2': 10, 'C2': 11, 'D2': 12, 'F2': 13, 'A2': 14, 'B2': 15, 'D4': 16, 'E4': 17, 'F4': 18, 'F#3': 19, 'C#3': 20, 'G#3': 21, 'E-3': 22, 'B-3': 23, 'G4': 24, 'E-4': 25, 'F#2': 26, 'B-2': 27, 'F2.G2.D3.B3': 28, 'E-2.G2.G3.A3': 29, 'D2.G2.F3.B3': 30, 'C2.G2.E3.C4': 31, 'G2.D3.C4': 32, 'G2.D3.B3': 33, 'C2.G2.E3.B-3': 34, 'C2.A2.F3.A3': 35, 'C2.G#2.D3.B3': 36, 'D3.B3': 37, 'E3.C4': 38, 'C#4': 39, 'G#2': 40, 'C#2': 41, 'E-2': 42, 'F#4': 43, 'G#4': 44, 'F#2.G#2.E-3.C4': 45, 'E2.G#2.G#3.B-3': 46, 'E-2.G#2.F#3.C4': 47, 'C#2.G#2.F3.C#4': 48, 'G#2.E-3.C#4': 49, 'G#2.E-3.C4': 50, 'C#2.G#2.F3.B3': 51, 'C#2.B-2.F#3.B-3': 52, 'C#2.A2.E-3.C4': 53, 'E-3.C4': 54, 'F3.C#4': 55, 'D2.A2.F3.A3': 56, 'D2.A2.D3': 57, 'D2.A2.F#3.C4': 58, 'F3.A3': 59, 'D3.D4': 60, 'F3.E4': 61, 'B2.D3': 62, 'E3.B3': 63, 'E-3.A3': 64, 'F#3.G#3': 65, 'A3.B3': 66, 'A3.C4': 67, 'C3.E3.A3': 68, 'G#2.D3.B3': 69, 'A2.E3.C#4': 70, 'C3.F#3': 71, 'G2.G3': 72, 'B-2.E3': 73, 'E-2.B-2.F#3.B-3': 74, 'E-2.B-2.E-3': 75, 'E-2.B-2.G3.C#4': 76, 'F#3.B-3': 77, 'E-3.E-4': 78, 'F#3.F4': 79, 'C3.E-3': 80, 'F3.C4': 81, 'E3.B-3': 82, 'G3.A3': 83, 'B-3.C4': 84, 'B-3.C#4': 85, 'C#3.F3.B-3': 86, 'A2.E-3.C4': 87, 'B-2.F3.D4': 88, 'C#3.G3': 89, 'G#2.G#3': 90, 'B2.F3': 91, 'C2.C3': 92, 'C2.B2.F3.G#3': 93, 'C2.E-3': 94, 'G#2.F3': 95, 'G#2.F3.C4': 96, 'C2.G2.F3.C4': 97, 'E3.G3': 98, 'F3.G#3': 99, 'D3.E-3': 100, 'C2.G2.E-3': 101, 'G2.D3': 102, 'G2.F#3.E-4': 103, 'G2.F3.B3': 104, 'G2.E3.C#4': 105, 'D2.B-2.F3.G#3': 106, 'E-2.B-2.E-3.G3': 107, 'F#3.D4': 108, 'G2.D3.B-3': 109, 'C2.G2.E-3.B-3': 110, 'D3.G3': 111, 'G2.D3.G3': 112, 'G3.B-3': 113, 'G3.C4': 114, 'D3.A3': 115, 'D3.B-3': 116, 'D3.C4': 117, 'C3.F3': 118, 'G2.E-3': 119, 'G2.F3': 120, 'B-2.F3': 121, 'B-2.E-3': 122, 'C3.D3': 123, 'B-2.D3': 124, 'A2.D3': 125, 'B2.E-3': 126, 'D3.G#3': 127, 'E-3.G3': 128, 'C3.G3': 129, 'F3.D4': 130, 'F3.B3': 131, 'G3.D4': 132, 'D4.E-4': 133, 'C2.B-2': 134, 'C2.G#2': 135, 'C2.G2': 136, 'D2.B2': 137, 'D2.E-3': 138, 'E-2.G2': 139, 'E-2.G3': 140, 'E2.C3.G3.B-3': 141, 'F#2.C3.E-3.C4': 142, 'C#2.C#3': 143, 'C#2.C3.F#3.A3': 144, 'C#3.E3': 145, 'C#2.E3': 146, 'A2.F#3': 147, 'A2.F#3.C#4': 148, 'C#2.G#2.F#3.C#4': 149, 'F#3.A3': 150, 'E-3.E3': 151, 'C#2.G#2.E3': 152, 'G#2.E-3': 153, 'G#2.G3.E4': 154, 'B3.C#4': 155, 'G#2.F#3.C4': 156, 'G#2.F3.D4': 157, 'E-2.B2.F#3.A3': 158, 'E2.B2.E3.G#3': 159, 'G3.E-4': 160, 'G#2.E-3.B3': 161, 'C#2.G#2.E3.B3': 162, 'E-3.G#3': 163, 'G#2.E-3.G#3': 164, 'G#3.B3': 165, 'G#3.C#4': 166, 'E-3.B-3': 167, 'E-3.B3': 168, 'E-3.C#4': 169, 'C#3.F#3': 170, 'G#2.E3': 171, 'G#2.F#3': 172, 'B2.F#3': 173, 'B2.E3': 174, 'E2.B2.E3': 175, 'C#3.E-3': 176, 'C3.E3': 177, 'E3.G#3': 178, 'C#3.G#3': 179, 'F#3.E-4': 180, 'F#3.C4': 181, 'G#3.E-4': 182, 'E-4.E4': 183, 'C#2.B2': 184, 'C#2.A2': 185, 'C#2.G#2': 186, 'E-2.C3': 187, 'E-2.E3': 188, 'E2.G#2': 189, 'E2.G#3': 190, 'F2.C#3.G#3.B3': 191, 'G2.C#3.E3.C#4': 192, 'F3.B-3': 193, 'F3.G3': 194, 'C2.G2.F3': 195, 'E2.E3.G3': 196, 'A2.G3': 197, 'A2.E3': 198, 'G2.D3.A3': 199, 'C#3.G3.A3': 200, 'G2.E3.B3': 201, 'G#2.F3.B3': 202, 'G#2.E3.B3': 203, 'B3.C4': 204, 'D2.A2.F3.D4': 205, 'C2.A2.F#3': 206, 'C2.A2.G3.A3': 207, 'A2.F#3.C4': 208, 'A2.F#3.D4.E4': 209, 'A2.F3': 210, 'B2.A3': 211, 'B2.G3': 212, 'F#3.B3': 213, 'C#2.G#2.F#3': 214, 'F2.F3.G#3': 215, 'B-2.G#3': 216, 'G#2.E-3.B-3': 217, 'D3.G#3.B-3': 218, 'A2.F3.C4': 219, 'C4.C#4': 220, 'E-2.B-2.F#3.E-4': 221, 'C#2.B-2.G3': 222, 'C#2.B-2.G#3.B-3': 223, 'B-2.G3.C#4': 224, 'B-2.G3.E-4.F4': 225, 
'B-2.F#3': 226, 'C3.B-3': 227, 'C3.G#3': 228, 'D2.C3.F#3.E-4': 229, 'D2.B-2.G3.D4': 230, 'D2.B-2.G3.C4': 231, 'C#3.G3.B-3': 232, 'D3.G3.A3': 233, 'E-2.B-2.G3.E-4': 234, 'E-2.C#3.G3.E4': 235, 'E-2.B2.G#3.E-4': 236, 'E-2.B2.G#3.C#4': 237, 'D3.G#3.B3': 238, 'E-3.G#3.B-3': 239, 'E2.B2.G#3.E4': 240, 'D3.E-4': 241, 'D3.F#4': 242, 'D3.G4': 243, 'D3.C#4': 244, 'G3.B3': 245, 'G2.F#3': 246, 'G2.A3': 247, 'G2.C4': 248, 'G2.G#3': 249, 'G2.B3': 250, 'E-3.E4': 251, 'E-3.G4': 252, 'E-3.G#4': 253, 'E-3.D4': 254, 'G3.C#4': 255, 'G#3.C4': 256, 'G#2.G3': 257, 'G#2.B-3': 258, 'G#2.C#4': 259, 'G#2.A3': 260, 'G#2.C4': 261, 'C2.E-3.G3': 262, 'C3.G3.A3': 263, 'C3.G3.B-3': 264, 'B-2.G3': 265, 'D3.F#3': 266, 'C2.G2.E3': 267, 'C2.E3.G3': 268, 'G2.E-3.B-3': 269, 'F2.G#2.E-3': 270, 'F2.G#2.D3': 271, 'E-2.B-2.G3': 272, 'F2.B2.G3': 273, 'F2.C3.D3': 274, 'G#2.D3': 275, 'F2.C3.E-3': 276, 'C#2.E3.G#3': 277, 'C#3.G#3.B-3': 278, 'C#3.G#3.B3': 279, 'B2.G#3': 280, 'C#2.G#2.F3': 281, 'C#2.F3.G#3': 282, 'C#3.A3': 283, 'F#2.A2.E3': 284, 'F#2.A2.E-3': 285, 'E3.A3': 286, 'E2.B2.G#3': 287, 'F#2.C3.G#3': 288, 'E2.C#3': 289, 'F#2.C#3.E-3': 290, 'A2.E-3': 291, 'F#2.C#3.E3': 292, 'G2.E3.C4': 293, 'B2.C3': 294, 'D2.A2.F#3': 295, 'A2.B2': 296, 'E3.F3': 297, 'A2.E3.C4': 298, 'G#2.F3.C#4': 299, 'C3.C#3': 300, 'B-2.C3': 301, 'F3.F#3': 302, 'B-2.F3.C#4': 303, 'B-2.F#3.C#4': 304, 'D3.A3.F#4': 305, 'G2.D3.B3.F#4': 306, 'G3.B3.E4': 307, 'G3.B3.F#4': 308, 'E3.D4': 309, 'E3.C#4': 310, 'G2.E3.C#4.A4': 311, 'A4': 312, 'F#3.D4.A4': 313, 'B4': 314, 'A3.G4': 315, 'A3.F#4': 316, 'D3.A3.G4': 317, 'F#3.C#4.A4': 318, 'B2.F#3.E-4.A4': 319, 'B3.G4': 320, 'B3.F#4': 321, 'E3.B3.A4': 322, 'E3.B3.G4': 323, 'D3.B3.G4': 324, 'B2.F#3.B3': 325, 'B2.F#3.D4': 326, 'A2.F#3.D4': 327, 'G#2.E3.D4': 328, 'A2.E3.D4': 329, 'A3.E4': 330, 'B2.D3.B3.F#4': 331, 'G2.D3.B3.A4': 332, 'A3.D4': 333, 'D2.A2.F#3.D4': 334, 'F#3.G3': 335, 'D3.C#4.E4': 336, 'D3.E4': 337, 'D3.A4': 338, 'E-3.B-3.G4': 339, 'G#2.E-3.C4.G4': 340, 'G#3.C4.F4': 341, 'G#3.C4.G4': 342, 'F3.E-4': 343, 'G#2.F3.D4.B-4': 344, 'B-4': 345, 'G3.E-4.B-4': 346, 'C5': 347, 'B-3.G#4': 348, 'B-3.G4': 349, 'E-3.B-3.G#4': 350, 'G3.D4.B-4': 351, 'C3.G3.E4.B-4': 352, 'C4.G#4': 353, 'C4.G4': 354, 'F3.C4.B-4': 355, 'F3.C4.G#4': 356, 'E-3.C4.G#4': 357, 'C3.G3.C4': 358, 'C3.G3.E-4': 359, 'B-2.G3.E-4': 360, 'A2.F3.E-4': 361, 'B-2.F3.E-4': 362, 'B-3.F4': 363, 'C3.E-3.C4.G4': 364, 'G#2.E-3.C4.B-4': 365, 'B-3.E-4': 366, 'G3.G#3': 367, 'E-3.D4.F4': 368, 'E-3.F4': 369, 'E-3.B-4': 370, 'D3.E3': 371, 'F2.A2': 372, 'F#2.A2.D3.A3': 373, 'B3.D4': 374, 'C4.E4': 375, 'C4.F#4': 376, 'G2.B2': 377, 'E2.G2': 378, 'D4.E4': 379, 'E3.F#3': 380, 'G#3.A3': 381, 'A2.E3.A3': 382, 'C4.D4': 383, 'C#3.D3': 384, 'A2.C3': 385, 'E-3.F3': 386, 'F#2.B-2': 387, 'G2.B-2.E-3.B-3': 388, 'C#3.F3': 389, 'C4.E-4': 390, 'C#4.F4': 391, 'C#4.G4': 392, 'G#2.C3': 393, 'F2.G#2': 394, 'G#3.B-3': 395, 'E-4.F4': 396, 'A3.B-3': 397, 'B-2.F3.B-3': 398, 'C#4.E-4': 399, 'B-2.C#3': 400, 'D3.C#4.F#4': 401, 'C#3.E3.B3.E4': 402, 'E2.B2.G#3.D4': 403, 'C#3.E3.A3': 404, 'C#4.D4': 405, 'G#3.D4.E4': 406, 'E3.D4.G#4': 407, 'D4.G#4': 408, 'D4.B4': 409, 'B3.E4': 410, 'C#4.A4': 411, 'C#4.E4': 412, 'A3.E4.F#4': 413, 'A3.E4.G4': 414, 'E4.G4': 415, 'F#4.G4': 416, 'C#5': 417, 'D5': 418, 'F#2.G2': 419, 'E2.B2.G3': 420, 'D2.B2.G3': 421, 'F#4.D5': 422, 'F#4.A4': 423, 'D4.A4': 424, 'E-3.D4.G4': 425, 'D3.F3.C4.F4': 426, 'F2.C3.A3.E-4': 427, 'D3.F3.B-3': 428, 'A3.E-4.F4': 429, 'F3.E-4.A4': 430, 'E-4.A4': 431, 'E-4.C5': 432, 'C4.F4': 433, 'D4.B-4': 434, 'D4.F4': 435, 'B-3.F4.G4': 436, 'B-3.F4.G#4': 437, 
'F4.G#4': 438, 'G4.G#4': 439, 'E-5': 440, 'G2.G#2': 441, 'F2.C3.G#3': 442, 'E-2.C3.G#3': 443, 'G4.E-5': 444, 'G4.B-4': 445, 'E-4.B-4': 446, 'C2.G2.E-3.C4': 447, 'E-2.G2.D3': 448, 'E2.C3.G3': 449, 'D2.B-2.G#3': 450, 'G#3.F4': 451, 'E-3.F3.G3': 452, 'C2.A2.F#3.D4': 453, 'G2.B3.D4': 454, 'B-2.D3.G#3': 455, 'B-2.D3.E-3.G#3': 456, 'G#2.B-2': 457, 'C2.B-2.E3': 458, 'D3.G3.G#3': 459, 'C#2.G#2.E3.C#4': 460, 'E2.G#2.E-3': 461, 'F2.C#3.G#3': 462, 'F#2.C#3.A3': 463, 'E-2.B2.A3': 464, 'E3.F#3.G#3': 465, 'C#2.B-2.G3.E-4': 466, 'G#2.C4.E-4': 467, 'E4.F#4': 468, 'B-3.B3': 469, 'B2.E-3.A3': 470, 'B2.E-3.E3.A3': 471, 'C#2.B2.F3': 472, 'E-3.G#3.A3': 473, 'D2.A2.F3': 474, 'F2.A2.D3': 475, 'F2.A2.E3': 476, 'A2.A3': 477, 'B-2.D3.A3': 478, 'B-2.A3': 479, 'B-2.B-3': 480, 'G2.F3.B-3': 481, 'G2.F3.C4': 482, 'G2.F3.D4': 483, 'G2.E3': 484, 'F2.A2.D3.A3': 485, 'G2.E3.F3': 486, 'G2.F3.G3': 487, 'G2.D3.C#4': 488, 'G2.D3.D4': 489, 'G#3.D4': 490, 'E-2.B-2.F#3': 491, 'F#2.B-2.E-3': 492, 'F#2.B-2.F3': 493, 'B2.E-3.B-3': 494, 'B2.B-3': 495, 'B2.B3': 496, 'G#2.F#3.B3': 497, 'G#2.F#3.C#4': 498, 'G#2.F#3.E-4': 499, 'D2.B2.G#3': 500, 'F#2.B-2.E-3.B-3': 501, 'G#2.F3.F#3': 502, 'G#2.F#3.G#3': 503, 'F#3.C#4': 504, 'G#2.E-3.D4': 505, 'G#2.E-3.E-4': 506, 'A3.E-4': 507, 'C3.A3.E-4': 508, 'F2.C3': 509, 'F2.B-2': 510, 'F2.D3.G#3': 511, 'C#3.B-3.E4': 512, 'F#2.C#3': 513, 'F#2.B2': 514, 'F#2.E-3.A3': 515, 'G2.B-3': 516, 'F2.A3': 517, 'D3.F4': 518, 'C3.E3.B3': 519, 'G#2.B3': 520, 'F#2.B-3': 521, 'E-3.F#4': 522, 'C#3.F3.C4': 523, 'E5': 524, 'F#5': 525, 'G5': 526, 'B2.G#3.D4': 527, 'B-2.G3.D4': 528, 'F5': 529, 'G#5': 530, 'B2.G#3.E-4': 531, 'A2.F#3.E-4': 532, 'G2.G3.B3.E4': 533, 'G2.G3.B3.C#4': 534, 'A2.F#3.D4.A4': 535, 'A2.F#3.D4.F#4': 536, 'E3.C#4.G4': 537, 'G#2.E3.D4.B4': 538, 'G#2.E3.C#4': 539, 'A2.E3.C#4.E4': 540, 'F#2.E3.C#4.E4': 541, 'F#3.A3.E4': 542, 'E-3.C4.F#4': 543, 'E3.B3.F#4': 544, 'D4.C5': 545, 'G3.D4.B4': 546, 'B2.E3.D4': 547, 'B2.E3.E4': 548, 'C3.E3.E4': 549, 'C3.E3.F#4': 550, 'G2.D3.C4.G4': 551, 'G2.D3.A3.G4': 552, 'G3.B3.B4': 553, 'F#3.B3.A4': 554, 'A2.E3.C#4.G#4': 555, 'A2.E3.C#4.A4': 556, 'A2.F3.D4.A4': 557, 'A3.F4': 558, 'B-2.F#3.C#4.E4': 559, 'D3.C4.F#4': 560, 'G3.E4': 561, 'G2.F#3.D4': 562, 'G2.E3.D4': 563, 'A2.G3.D4': 564, 'D2.A2.F#3.C#4': 565, 'G#2.G#3.C4.F4': 566, 'G#2.G#3.C4.D4': 567, 'B-2.G3.E-4.B-4': 568, 'B-2.G3.E-4.G4': 569, 'F3.D4.G#4': 570, 'A2.F3.E-4.C5': 571, 'A2.F3.D4': 572, 'B-2.F3.D4.F4': 573, 'G2.F3.D4.F4': 574, 'G3.B-3.F4': 575, 'F3.C4.G4': 576, 'E-4.C#5': 577, 'G#3.E-4.C5': 578, 'C3.F3.E-4': 579, 'C3.F3.F4': 580, 'C#3.F3.F4': 581, 'C#3.F3.G4': 582, 'G#2.E-3.C#4.G#4': 583, 'G#2.E-3.B-3.G#4': 584, 'G#3.C4.C5': 585, 'G3.C4.B-4': 586, 'B-2.F3.D4.A4': 587, 'B-2.F3.D4.B-4': 588, 'B-2.F#3.E-4.B-4': 589, 'B-3.F#4': 590, 'B2.G3.D4.F4': 591, 'E-3.C#4.G4': 592, 'G#2.G3.E-4': 593, 'G#2.F3.E-4': 594, 'B-2.G#3.E-4': 595, 'E-2.B-2.G3.D4': 596, 'E-2.B-2.G#3': 597, 'E-2.B-2.F3': 598, 'B-2.F3.C4': 599, 'C2.A2.E-3': 600, 'A2.E3.B3': 601, 'E2.B2.A3': 602, 'E2.B2.F#3': 603, 'E3.E-4': 604, 'C3.A3': 605, 'C#3.G#3.E4': 606, 'B2.F#3.E-4': 607, 'B2.F#3.E4': 608, 'B2.F#3.C#4': 609, 'C#2.B-2.E3': 610, 'A2.G#3': 611, 'C#3.B3': 612, 'A2.F#3.E4': 613, 'G2.B3.G4': 614, 'G#2.C4.G#4': 615, 'D2.B2.F3': 616, 'C2.G2.D3': 617, 'C3.G#3.E-4': 618, 'E-2.C3.F#3': 619, 'C#2.G#2.E-3': 620, 'C#3.A3.E4': 621, 'G#3.E4': 622, 'C#4.D4.E4': 623, 'E3.C#4.D4.E4': 624, 'E3.C#4.E4': 625, 'E3.D4.E4': 626, 'B3.D4.E4': 627, 'A3.D4.E4': 628, 'B3.C#4.D4.E4': 629, 'F#3.B3.C#4.D4': 630, 'B3.G4.A4': 631, 'C#4.D4.E4.F#4.G4': 632, 'E3.B3.C#4.D4.E4.F#4.G4': 633, 
'D4.E4.F#4.G4': 634, 'E4.F#4.G4': 635, 'E4.F#4.G4.B4': 636, 'A3.E4.F#4.G4': 637, 'A3.F#4.G4': 638, 'D4.F#4': 639, 'B-4.B4.C#5.D5': 640, 'B-3.B-4.B4.C#5.D5': 641, 'B-3.E4': 642, 'B3.C#4.E4': 643, 'B3.C#4.F#4.G#4.A4': 644, 'F#3.B3.C#4.G#4.A4': 645, 'G#3.G#4.A4': 646, 'A3.G#4.A4': 647, 'B3.G#4.A4': 648, 'A3.C#4': 649, 'G#4.A4': 650, 'F#3.C#4.E4': 651, 'E3.D4.G#4.A4': 652, 'D4.E4.F#4': 653, 'C#4.D4.F#4': 654, 'B3.D4.F#4': 655, 'A3.D4.F#4': 656, 'G#3.D4.F#4': 657, 'G#4.B4': 658, 'A3.B3.C#4': 659, 'G2.E3.A3.C#4.A4': 660, 'D3.E3.F#4.G4': 661, 'D3.E3.F#3.F#4.G4': 662, 'D3.E3.F#3': 663, 'C2.D3.E3': 664, 'B2.G3.B3.C4.D4': 665, 'A2.C3.D3.E3': 666, 'G#2.C3.D3': 667, 'G#2.F#4': 668, 'F4.F#4': 669, 'B2.F#3.C#4.D4.E4': 670, 'A3.C#4.E4': 671, 'F#3.G3.A3': 672, 'D3.E3.F#3.A3': 673, 'B3.C#4.D4': 674, 'F3.C#4.D4': 675, 'B2.C#3': 676, 'A2.F#4': 677, 'C#4.F#4': 678, 'A4.B4': 679, 'B2.F#3.A3.B3': 680, 'B2.F#3.A3.B3.C4': 681, 'A3.B3.C4': 682, 'A3.B3.A4': 683, 'A3.B3.G4': 684, 'A3.B3.F#4.G4': 685, 'F#4.G4.A4': 686, 'E3.B3.F#4.G4': 687, 'G4.A4': 688, 'E-3.F#3': 689, 'E4.G#4': 690, 'D4.E-4.F4': 691, 'F3.D4.E-4.F4': 692, 'F3.D4.F4': 693, 'F3.E-4.F4': 694, 'C4.E-4.F4': 695, 'B-3.E-4.F4': 696, 'C4.D4.E-4.F4': 697, 'G3.C4.D4.E-4': 698, 'C4.G#4.B-4': 699, 'D4.E-4.F4.G4.G#4': 700, 'F3.C4.D4.E-4.F4.G4.G#4': 701, 'E-4.F4.G4.G#4': 702, 'F4.G4.G#4': 703, 'F4.G4.G#4.C5': 704, 'B-3.F4.G4.G#4': 705, 'B-3.G4.G#4': 706, 'E-4.G4': 707, 'B4.C5.D5.E-5': 708, 'B3.B4.C5.D5.E-5': 709, 'B3.F4': 710, 'C4.D4.F4': 711, 'C4.D4.G4.A4.B-4': 712, 'G3.C4.D4.A4.B-4': 713, 'A3.A4.B-4': 714, 'B-3.A4.B-4': 715, 'C4.A4.B-4': 716, 'B-3.D4': 717, 'A4.B-4': 718, 'G3.D4.F4': 719, 'F3.E-4.A4.B-4': 720, 'E-4.F4.G4': 721, 'D4.E-4.G4': 722, 'C4.E-4.G4': 723, 'B-3.E-4.G4': 724, 'A3.E-4.G4': 725, 'A4.C5': 726, 'F4.G4': 727, 'B-3.C4.D4': 728, 'G#2.F3.B-3.D4.B-4': 729, 'E-3.F3.G4.G#4': 730, 'E-3.F3.G3.G4.G#4': 731, 'C#2.E-3.F3': 732, 'C3.G#3.C4.C#4.E-4': 733, 'B-2.C#3.E-3.F3': 734, 'A2.C#3.E-3': 735, 'A2.G4': 736, 'C3.G3.D4.E-4.F4': 737, 'B-3.D4.F4': 738, 'G3.G#3.B-3': 739, 'E-3.F3.G3.B-3': 740, 'C4.D4.E-4': 741, 'F#3.D4.E-4': 742, 'B-2.G4': 743, 'D4.G4': 744, 'B-4.C5': 745, 'E4.F4': 746, 'C3.G3.E4': 747, 'C3.G3.B-3.C4': 748, 'C3.G3.B-3.C4.C#4': 749, 'B-3.C4.C#4': 750, 'B-3.C4.B-4': 751, 'B-3.C4.G#4': 752, 'B-3.C4.G4.G#4': 753, 'G4.G#4.B-4': 754, 'F3.C4.G4.G#4': 755, 'G#4.B-4': 756, 'C#3.B-3': 757, 'F4.A4': 758, 'D3.F3.A3': 759, 'C3.E3.B-3': 760, 'G2.G3.E4': 761, 'B-2.E3.D4': 762, 'D3.F3': 763, 'E-3.F#3.B-3': 764, 'C#3.F3.B3': 765, 'G#2.G#3.F4': 766, 'B2.F3.E-4': 767, 'C#3.G#3.C#4': 768, 'C#2.B-2.F3': 769, 'A2.G3.C#4': 770, 'D2.B2.F#3': 771, 'B-2.G#3.D4': 772, 'B-2.F#3.E-4': 773}\n",
167 | "{0: 0, 0.5: 1, 0.25: 2, 1.25: 3, 1.0: 4, 3.0: 5, 0.75: 6, Fraction(1, 6): 7, Fraction(1, 12): 8, 1.5: 9, Fraction(2, 3): 10, 2.0: 11, Fraction(1, 3): 12, Fraction(4, 3): 13, 2.25: 14, 1.75: 15, 2.5: 16, 4.0: 17, Fraction(5, 12): 18}\n",
168 | "774 19\n"
169 | ]
170 | }
171 | ],
172 | "source": [
173 | "def encode(data):\n",
174 | "    #build the vocabulary\n",
175 | " vocab = {}\n",
176 | " for i in data:\n",
177 | " for j in i:\n",
178 | " if j not in vocab:\n",
179 | " vocab[j] = len(vocab)\n",
180 | "\n",
181 | "    #encode the data with the vocabulary\n",
182 | "    new_data = []\n",
183 | "    for line in data:\n",
184 | "        new_line = [vocab[node] for node in line]\n",
185 | "        new_data.append(new_line)\n",
186 | "\n",
187 | "    return new_data, vocab\n",
188 | "\n",
189 | "\n",
190 | "note, note_vocab = encode(note)\n",
191 | "duration, duration_vocab = encode(duration)\n",
192 | "\n",
193 | "print(note[0][:20])\n",
194 | "print(duration[0][:20])\n",
195 | "print(len(note), len(duration))\n",
196 | "\n",
197 | "print(note_vocab)\n",
198 | "print(duration_vocab)\n",
199 | "print(len(note_vocab), len(duration_vocab))"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 5,
205 | "metadata": {},
206 | "outputs": [
207 | {
208 | "name": "stderr",
209 | "output_type": "stream",
210 | "text": [
211 | "Using TensorFlow backend.\n"
212 | ]
213 | },
214 | {
215 | "data": {
216 | "text/plain": [
217 | "(array([[ 0, 1, 2, ..., 6, 5, 7],\n",
218 | " [ 1, 2, 3, ..., 5, 7, 6],\n",
219 | " [ 2, 3, 4, ..., 7, 6, 5],\n",
220 | " ...,\n",
221 | " [27, 4, 5, ..., 4, 27, 7],\n",
222 | " [ 4, 5, 21, ..., 27, 7, 40],\n",
223 | " [ 5, 21, 4, ..., 7, 40, 42]]),\n",
224 | " array([[0., 0., 0., ..., 0., 0., 0.],\n",
225 | " [0., 0., 0., ..., 0., 0., 0.],\n",
226 | " [0., 0., 0., ..., 0., 0., 0.],\n",
227 | " ...,\n",
228 | " [0., 0., 0., ..., 0., 0., 0.],\n",
229 | " [0., 0., 0., ..., 0., 0., 0.],\n",
230 | " [1., 0., 0., ..., 0., 0., 0.]], dtype=float32))"
231 | ]
232 | },
233 | "execution_count": 5,
234 | "metadata": {},
235 | "output_type": "execute_result"
236 | }
237 | ],
238 | "source": [
239 | "import numpy as np\n",
240 | "import keras\n",
241 | "\n",
242 | "\n",
243 | "#slice each 1-D sequence into windows: the preceding tokens predict the final token\n",
244 | "def prepare_sequences(data, num_classes):\n",
245 | " input = []\n",
246 | " output = []\n",
247 | "\n",
248 | " for line in data:\n",
249 | " for i in range(len(line) - 32):\n",
250 | " input.append(line[i:i + 32])\n",
251 | " output.append(line[i + 32])\n",
252 | "\n",
253 | " input = np.array(input)\n",
254 | " output = keras.utils.np_utils.to_categorical(output,\n",
255 | " num_classes=num_classes)\n",
256 | "\n",
257 | " return input, output\n",
258 | "\n",
259 | "\n",
260 | "prepare_sequences(note, len(note_vocab))"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 6,
266 | "metadata": {},
267 | "outputs": [
268 | {
269 | "data": {
270 | "text/plain": [
271 | "((53162, 32), (53162, 774), (53162, 32), (53162, 19))"
272 | ]
273 | },
274 | "execution_count": 6,
275 | "metadata": {},
276 | "output_type": "execute_result"
277 | }
278 | ],
279 | "source": [
280 | "def get_input_output():\n",
281 | " note_input, note_output = prepare_sequences(note, len(note_vocab))\n",
282 | " duration_input, duration_output = prepare_sequences(\n",
283 | " duration, len(duration_vocab))\n",
284 | "\n",
285 | " input = [note_input, duration_input]\n",
286 | " output = [note_output, duration_output]\n",
287 | "\n",
288 | " return input, output\n",
289 | "\n",
290 | "\n",
291 | "input, output = get_input_output()\n",
292 | "\n",
293 | "input[0].shape, output[0].shape, input[1].shape, output[1].shape"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 7,
299 | "metadata": {},
300 | "outputs": [
301 | {
302 | "name": "stdout",
303 | "output_type": "stream",
304 | "text": [
305 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
306 | "\n",
307 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
308 | "\n",
309 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n",
310 | "\n",
311 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
312 | "\n",
313 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3295: The name tf.log is deprecated. Please use tf.math.log instead.\n",
314 | "\n"
315 | ]
316 | },
317 | {
318 | "data": {
319 | "text/plain": [
320 | ""
321 | ]
322 | },
323 | "execution_count": 7,
324 | "metadata": {},
325 | "output_type": "execute_result"
326 | }
327 | ],
328 | "source": [
329 | "def build_model():\n",
330 | " note_input = keras.layers.Input(shape=(None, ))\n",
331 | " duration_input = keras.layers.Input(shape=(None, ))\n",
332 | "\n",
333 | " x1 = keras.layers.Embedding(len(note_vocab), 100)(note_input)\n",
334 | " x2 = keras.layers.Embedding(len(duration_vocab), 100)(duration_input)\n",
335 | "\n",
336 | " x = keras.layers.Concatenate()([x1, x2])\n",
337 | "\n",
338 | " x = keras.models.Sequential([\n",
339 | " keras.layers.LSTM(256, return_sequences=True),\n",
340 | " keras.layers.LSTM(256, return_sequences=True),\n",
341 | " ])(x)\n",
342 | "\n",
343 | " e = keras.models.Sequential([\n",
344 | " keras.layers.Dense(1, activation='tanh'),\n",
345 | " keras.layers.Reshape([-1]),\n",
346 | " keras.layers.Activation('softmax'),\n",
347 | " keras.layers.RepeatVector(256),\n",
348 | " keras.layers.Permute([2, 1]),\n",
349 | " ])(x)\n",
350 | "\n",
351 | " x = keras.layers.Multiply()([x, e])\n",
352 | " x = keras.layers.Lambda(lambda i: keras.backend.sum(i, axis=1),\n",
353 | " output_shape=(256, ))(x)\n",
354 | " note_output = keras.layers.Dense(len(note_vocab), activation='softmax')(x)\n",
355 | " duration_output = keras.layers.Dense(len(duration_vocab),\n",
356 | " activation='softmax')(x)\n",
357 | "\n",
358 | " model = keras.models.Model([note_input, duration_input],\n",
359 | " [note_output, duration_output])\n",
360 | "\n",
361 | " model.compile(\n",
362 | " loss=['categorical_crossentropy', 'categorical_crossentropy'],\n",
363 | " optimizer=keras.optimizers.RMSprop(lr=0.001))\n",
364 | "\n",
365 | " return model\n",
366 | "\n",
367 | "\n",
368 | "model = build_model()\n",
369 | "\n",
370 | "model"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": 8,
376 | "metadata": {},
377 | "outputs": [
378 | {
379 | "name": "stdout",
380 | "output_type": "stream",
381 | "text": [
382 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
383 | "Instructions for updating:\n",
384 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
385 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.\n",
386 | "\n",
387 | "Train on 42529 samples, validate on 10633 samples\n",
388 | "Epoch 1/20\n",
389 | "42529/42529 [==============================] - 99s 2ms/step - loss: 4.2513 - dense_2_loss: 3.5614 - dense_3_loss: 0.6899 - val_loss: 4.7055 - val_dense_2_loss: 3.8565 - val_dense_3_loss: 0.8491\n",
390 | "0\n",
391 | "Epoch 2/20\n",
392 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.7205 - dense_2_loss: 3.1797 - dense_3_loss: 0.5408 - val_loss: 4.5344 - val_dense_2_loss: 3.7185 - val_dense_3_loss: 0.8159\n",
393 | "1\n",
394 | "Epoch 3/20\n",
395 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.4867 - dense_2_loss: 3.0231 - dense_3_loss: 0.4636 - val_loss: 4.8023 - val_dense_2_loss: 3.8016 - val_dense_3_loss: 1.0006\n",
396 | "2\n",
397 | "Epoch 4/20\n",
398 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.3193 - dense_2_loss: 2.9080 - dense_3_loss: 0.4113 - val_loss: 4.4798 - val_dense_2_loss: 3.5864 - val_dense_3_loss: 0.8934\n",
399 | "3\n",
400 | "Epoch 5/20\n",
401 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.1565 - dense_2_loss: 2.7972 - dense_3_loss: 0.3593 - val_loss: 4.5762 - val_dense_2_loss: 3.5712 - val_dense_3_loss: 1.0049\n",
402 | "4\n",
403 | "Epoch 6/20\n",
404 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.0020 - dense_2_loss: 2.6898 - dense_3_loss: 0.3121 - val_loss: 4.7727 - val_dense_2_loss: 3.7600 - val_dense_3_loss: 1.0127\n",
405 | "5\n",
406 | "Epoch 7/20\n",
407 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.8425 - dense_2_loss: 2.5738 - dense_3_loss: 0.2687 - val_loss: 4.9168 - val_dense_2_loss: 3.6715 - val_dense_3_loss: 1.2454\n",
408 | "6\n",
409 | "Epoch 8/20\n",
410 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.6804 - dense_2_loss: 2.4557 - dense_3_loss: 0.2247 - val_loss: 5.1645 - val_dense_2_loss: 3.7302 - val_dense_3_loss: 1.4343\n",
411 | "7\n",
412 | "Epoch 9/20\n",
413 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.5197 - dense_2_loss: 2.3274 - dense_3_loss: 0.1923 - val_loss: 5.2515 - val_dense_2_loss: 3.8443 - val_dense_3_loss: 1.4072\n",
414 | "8\n",
415 | "Epoch 10/20\n",
416 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.3798 - dense_2_loss: 2.2173 - dense_3_loss: 0.1625 - val_loss: 5.8140 - val_dense_2_loss: 4.1481 - val_dense_3_loss: 1.6659\n",
417 | "9\n",
418 | "Epoch 11/20\n",
419 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.2637 - dense_2_loss: 2.1261 - dense_3_loss: 0.1376 - val_loss: 6.0465 - val_dense_2_loss: 4.2798 - val_dense_3_loss: 1.7667\n",
420 | "10\n",
421 | "Epoch 12/20\n",
422 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.1609 - dense_2_loss: 2.0378 - dense_3_loss: 0.1232 - val_loss: 6.1596 - val_dense_2_loss: 4.4052 - val_dense_3_loss: 1.7544\n",
423 | "11\n",
424 | "Epoch 13/20\n",
425 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.0491 - dense_2_loss: 1.9415 - dense_3_loss: 0.1076 - val_loss: 6.3409 - val_dense_2_loss: 4.4816 - val_dense_3_loss: 1.8593\n",
426 | "12\n",
427 | "Epoch 14/20\n",
428 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.9393 - dense_2_loss: 1.8408 - dense_3_loss: 0.0985 - val_loss: 6.5004 - val_dense_2_loss: 4.5988 - val_dense_3_loss: 1.9016\n",
429 | "13\n",
430 | "Epoch 15/20\n",
431 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.8155 - dense_2_loss: 1.7267 - dense_3_loss: 0.0888 - val_loss: 6.6667 - val_dense_2_loss: 4.6960 - val_dense_3_loss: 1.9707\n",
432 | "14\n",
433 | "Epoch 16/20\n",
434 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.7020 - dense_2_loss: 1.6204 - dense_3_loss: 0.0816 - val_loss: 6.8498 - val_dense_2_loss: 4.7874 - val_dense_3_loss: 2.0624\n",
435 | "15\n",
436 | "Epoch 17/20\n",
437 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.5955 - dense_2_loss: 1.5201 - dense_3_loss: 0.0754 - val_loss: 6.9344 - val_dense_2_loss: 4.8593 - val_dense_3_loss: 2.0751\n",
438 | "16\n",
439 | "Epoch 18/20\n",
440 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.4926 - dense_2_loss: 1.4226 - dense_3_loss: 0.0700 - val_loss: 7.2345 - val_dense_2_loss: 5.1542 - val_dense_3_loss: 2.0803\n",
441 | "17\n",
442 | "Epoch 19/20\n",
443 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.4049 - dense_2_loss: 1.3408 - dense_3_loss: 0.0642 - val_loss: 7.5128 - val_dense_2_loss: 5.3212 - val_dense_3_loss: 2.1916\n",
444 | "18\n",
445 | "Epoch 20/20\n",
446 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.3287 - dense_2_loss: 1.2664 - dense_3_loss: 0.0624 - val_loss: 7.5138 - val_dense_2_loss: 5.3279 - val_dense_3_loss: 2.1859\n",
447 | "19\n"
448 | ]
449 | },
450 | {
451 | "data": {
452 | "text/plain": [
453 | ""
454 | ]
455 | },
456 | "execution_count": 8,
457 | "metadata": {},
458 | "output_type": "execute_result"
459 | }
460 | ],
461 | "source": [
462 | "#callback that prints the epoch number at the end of each training epoch\n",
463 | "class CustomCallback(keras.callbacks.Callback):\n",
464 | "\n",
465 | " def on_epoch_end(self, epoch, logs):\n",
466 | " if epoch % 1 == 0:\n",
467 | " print(epoch)\n",
468 | "\n",
469 | "\n",
470 | "model.fit(input,\n",
471 | " output,\n",
472 | " epochs=20,\n",
473 | " batch_size=32,\n",
474 | " validation_split=0.2,\n",
475 | " callbacks=[\n",
476 | " keras.callbacks.EarlyStopping(monitor='loss',\n",
477 | " restore_best_weights=True,\n",
478 | " patience=10),\n",
479 | " CustomCallback()\n",
480 | " ],\n",
481 | " shuffle=True,\n",
482 | " verbose=1)"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "execution_count": 9,
488 | "metadata": {},
489 | "outputs": [
490 | {
491 | "name": "stdout",
492 | "output_type": "stream",
493 | "text": [
494 | "[[22, 1], [22, 1], [7, 1], [22, 1], [7, 1], [22, 1], [7, 1], [22, 1], [22, 1], [7, 1], [22, 1], [22, 1], [7, 1], [7, 1], [22, 1], [22, 1], [7, 1], [5, 1], [5, 1], [22, 1], [7, 1], [22, 1], [5, 1], [7, 1], [22, 1], [7, 1], [5, 1], [22, 1], [4, 1], [5, 1], [21, 1], [4, 1], [5, 1], [22, 1], [5, 1], [4, 1], [22, 1], [5, 1], [4, 1], [5, 1], [22, 1], [4, 1], [23, 1], [23, 1], [4, 1], [21, 1], [4, 1], [23, 1], [1, 1], [23, 1]]\n"
495 | ]
496 | },
497 | {
498 | "data": {
499 | "text/plain": [
500 | "50"
501 | ]
502 | },
503 | "execution_count": 9,
504 | "metadata": {},
505 | "output_type": "execute_result"
506 | }
507 | ],
508 | "source": [
509 | "def get_pred():\n",
510 | " pred = []\n",
511 | "\n",
512 | " def random_sample(data):\n",
513 | " data = np.log(data) * 2\n",
514 | " data = np.exp(data)\n",
515 | " data = data / np.sum(data)\n",
516 | " return np.random.choice(len(data), p=data)\n",
517 | "\n",
518 | " note = [note_vocab['START']] * 32\n",
519 | " duration = [duration_vocab[0]] * 32\n",
520 | "\n",
521 | " for _ in range(50):\n",
522 | " input = [np.array([note]), np.array([duration])]\n",
523 | "\n",
524 | " output_note, output_duration = model.predict(input, verbose=0)\n",
525 | "\n",
526 | " output_note = random_sample(output_note[0])\n",
527 | " output_duration = random_sample(output_duration[0])\n",
528 | "\n",
529 | " pred.append([output_note, output_duration])\n",
530 | "\n",
531 | " note.append(output_note)\n",
532 | " duration.append(output_duration)\n",
533 | "\n",
534 | " if len(note) > 32:\n",
535 | " note = note[-32:]\n",
536 | " duration = duration[-32:]\n",
537 | "\n",
538 | " if note_vocab['START'] == output_note:\n",
539 | " break\n",
540 | "\n",
541 | " return pred\n",
542 | "\n",
543 | "\n",
544 | "pred = get_pred()\n",
545 | "\n",
546 | "print(pred)\n",
547 | "\n",
548 | "len(pred)"
549 | ]
550 | },
551 | {
552 | "cell_type": "code",
553 | "execution_count": 10,
554 | "metadata": {},
555 | "outputs": [
556 | {
557 | "data": {
558 | "text/html": [
559 | "\n",
560 | " \n",
561 | " \n",
563 | " "
573 | ],
574 | "text/plain": [
575 | ""
576 | ]
577 | },
578 | "metadata": {},
579 | "output_type": "display_data"
580 | }
581 | ],
582 | "source": [
583 | "def save_midi():\n",
584 | " stream = music21.stream.Stream()\n",
585 | "\n",
586 | "    #inverse vocabularies (index -> token)\n",
587 | " note_vocab_r = {v: k for k, v in note_vocab.items()}\n",
588 | " duration_vocab_r = {v: k for k, v in duration_vocab.items()}\n",
589 | "\n",
590 | " for (n, d) in pred:\n",
591 | " n = note_vocab_r[n]\n",
592 | " d = duration_vocab_r[d]\n",
593 | "\n",
594 | "        #chord (notes joined by '.')\n",
595 | " if ('.' in n):\n",
596 | " chord_note = []\n",
597 | " for i in n.split('.'):\n",
598 | " note_i = music21.note.Note(i)\n",
599 | " note_i.duration = music21.duration.Duration(d)\n",
600 | " note_i.storedInstrument = music21.instrument.Violoncello()\n",
601 | " chord_note.append(note_i)\n",
602 | " stream.append(music21.chord.Chord(chord_note))\n",
603 | "        #rest\n",
604 | " elif n == 'rest':\n",
605 | " new_note = music21.note.Rest()\n",
606 | " new_note.duration = music21.duration.Duration(d)\n",
607 | " new_note.storedInstrument = music21.instrument.Violoncello()\n",
608 | " stream.append(new_note)\n",
609 | "        #single note\n",
610 | " elif n != 'START':\n",
611 | " new_note = music21.note.Note(n)\n",
612 | " new_note.duration = music21.duration.Duration(d)\n",
613 | " new_note.storedInstrument = music21.instrument.Violoncello()\n",
614 | " stream.append(new_note)\n",
615 | "\n",
616 | " stream = stream.chordify()\n",
617 | " stream.write('midi', fp='pred.mid')\n",
618 | "\n",
619 | " show('pred.mid')\n",
620 | "\n",
621 | "\n",
622 | "save_midi()"
623 | ]
624 | }
625 | ],
626 | "metadata": {
627 | "kernelspec": {
628 | "display_name": "Python 3 (ipykernel)",
629 | "language": "python",
630 | "name": "python3"
631 | },
632 | "language_info": {
633 | "codemirror_mode": {
634 | "name": "ipython",
635 | "version": 3
636 | },
637 | "file_extension": ".py",
638 | "mimetype": "text/x-python",
639 | "name": "python",
640 | "nbconvert_exporter": "python",
641 | "pygments_lexer": "ipython3",
642 | "version": "3.9.12"
643 | }
644 | },
645 | "nbformat": 4,
646 | "nbformat_minor": 2
647 | }
648 |
--------------------------------------------------------------------------------
/keras/9.musegan创作chorales.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | "array([[[ 1, -1, -1, -1, -1, -1, -1, -1, -1],\n",
12 | " [-1, 1, -1, -1, -1, -1, -1, -1, -1],\n",
13 | " [-1, -1, 1, -1, -1, -1, -1, -1, -1]],\n",
14 | "\n",
15 | " [[-1, -1, -1, 1, -1, -1, -1, -1, -1],\n",
16 | " [-1, -1, -1, -1, 1, -1, -1, -1, -1],\n",
17 | " [-1, -1, -1, -1, -1, 1, -1, -1, -1]],\n",
18 | "\n",
19 | " [[-1, -1, -1, -1, -1, -1, 1, -1, -1],\n",
20 | " [-1, -1, -1, -1, -1, -1, -1, 1, -1],\n",
21 | " [-1, -1, -1, -1, -1, -1, -1, -1, 1]]], dtype=int32)"
22 | ]
23 | },
24 | "execution_count": 1,
25 | "metadata": {},
26 | "output_type": "execute_result"
27 | }
28 | ],
29 | "source": [
30 | "import numpy as np\n",
31 | "\n",
32 | "\n",
33 | "#convert an integer matrix to one-hot encoding (zeros become -1)\n",
34 | "def build_one_hot(data, max_value):\n",
35 | " data = np.eye(max_value, dtype=np.int32)[data]\n",
36 | " data[data == 0] = -1\n",
37 | "\n",
38 | " return data\n",
39 | "\n",
40 | "\n",
41 | "build_one_hot(np.arange(9).reshape(3, 3), 9)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 2,
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "name": "stdout",
51 | "output_type": "stream",
52 | "text": [
53 | "data[0]= (192, 4) float16\n",
54 | "data[1]= (228, 4) float16\n",
55 | "data[2]= (208, 4) float16\n",
56 | "data[3]= (432, 4) float16\n",
57 | "data[4]= (260, 4) float16\n",
58 | "data[5]= (212, 4) float16\n",
59 | "data[6]= (292, 4) float16\n",
60 | "data[7]= (180, 4) float16\n",
61 | "data[8]= (132, 4) float16\n",
62 | "data[9]= (192, 4) float16\n",
63 | "data= (229,) object\n",
64 | "new_data= 229 (192, 4) (228, 4)\n",
65 | "data_cut= (229, 32, 4) int32\n"
66 | ]
67 | },
68 | {
69 | "data": {
70 | "text/plain": [
71 | "((229, 2, 16, 84, 4), dtype('int32'))"
72 | ]
73 | },
74 | "execution_count": 2,
75 | "metadata": {},
76 | "output_type": "execute_result"
77 | }
78 | ],
79 | "source": [
80 | "def get_data():\n",
81 | "    #load the data\n",
82 | " data = np.load('../datas/chorales/Jsb16thSeparated.npz',\n",
83 | " encoding='bytes')['train']\n",
84 | "\n",
85 | "    #229 pieces in total, each of variable length, all with 4 voices\n",
86 | " for i in range(10):\n",
87 | " print('data[%d]=' % i, data[i].shape, data[i].dtype)\n",
88 | "\n",
89 | " print('data=', data.shape, data.dtype)\n",
90 | "\n",
91 | "    #filter out the NaN values -- the raw dataset is quite messy\n",
92 | " new_data = []\n",
93 | " for song in data:\n",
94 | " new_song = []\n",
95 | " for time in song:\n",
96 | " #time -> [4]\n",
97 | "\n",
98 | " if np.isnan(time).any():\n",
99 | " continue\n",
100 | "\n",
101 | " new_song.append(time)\n",
102 | "\n",
103 | " new_song = np.array(new_song, dtype=np.int32)\n",
104 | " new_data.append(new_song)\n",
105 | "\n",
106 | " print('new_data=', len(new_data), new_data[0].shape, new_data[1].shape)\n",
107 | "\n",
108 | "    #keep only the first 32 beats of each piece\n",
109 | " data_cut = []\n",
110 | " for song in new_data:\n",
111 | " data_cut.append(song[:32])\n",
112 | "\n",
113 | " #[229, 32, 4]\n",
114 | " data_cut = np.array(data_cut)\n",
115 | "\n",
116 | " print('data_cut=', data_cut.shape, data_cut.dtype)\n",
117 | "\n",
118 | "    #split into two tracks of 16 beats each\n",
119 | " #[229, 32, 4] -> [229, 2, 16, 4]\n",
120 | " data_cut = data_cut.reshape([229, 2, 16, 4])\n",
121 | "\n",
122 | "    #convert to one-hot encoding\n",
123 | " #[229, 2, 16, 4] -> [229, 2, 16, 4, 84]\n",
124 | " data_cut = build_one_hot(data_cut, max_value=84)\n",
125 | "\n",
126 | "    #swap the last two dimensions\n",
127 | " #[229, 2, 16, 4, 84] -> [229, 2, 16, 84, 4]\n",
128 | " data_cut = data_cut.transpose([0, 1, 2, 4, 3])\n",
129 | "\n",
130 | " return data_cut\n",
131 | "\n",
132 | "\n",
133 | "data = get_data()\n",
134 | "\n",
135 | "data.shape, data.dtype"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 3,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "name": "stdout",
145 | "output_type": "stream",
146 | "text": [
147 | "(array([0, 1, 2, 3, 4]), array([0.25, 0.25, 0.25, 0.25, 0.25], dtype=float32))\n",
148 | "(array([1., 1.]), array([1. , 0.25], dtype=float32))\n",
149 | "(array([0, 1, 2, 3, 4, 5, 6]), array([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25], dtype=float32))\n",
150 | "(array([1., 1.]), array([1. , 0.75], dtype=float32))\n"
151 | ]
152 | }
153 | ],
154 | "source": [
155 | "def merge_note(note, duration=None):\n",
156 | " if duration is None:\n",
157 | " duration = np.full(note.shape, fill_value=0.25, dtype=np.float32)\n",
158 | "\n",
159 | "    #scan the sequence from front to back\n",
160 | " for i in range(len(note) - 1):\n",
161 | " j = i + 1\n",
162 | "\n",
163 | "        #merge two adjacent notes if they are identical and the combined duration is at most 1.0\n",
164 | " if note[i] == note[j] and duration[i] + duration[j] <= 1.0:\n",
165 | "\n",
166 | "            #merge the durations\n",
167 | " duration[i] += duration[j]\n",
168 | "\n",
169 | "            #remove the duplicated note\n",
170 | " note = np.delete(note, j, axis=0)\n",
171 | " duration = np.delete(duration, j, axis=0)\n",
172 | "\n",
173 | "            #recurse to keep merging\n",
174 | " return merge_note(note, duration)\n",
175 | "\n",
176 | " return note, duration\n",
177 | "\n",
178 | "\n",
179 | "print(merge_note(np.arange(5)))\n",
180 | "print(merge_note(np.ones(5)))\n",
181 | "\n",
182 | "print(merge_note(np.arange(7)))\n",
183 | "print(merge_note(np.ones(7)))"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 4,
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "import music21\n",
193 | "\n",
194 | "\n",
195 | "def save_to_mid(data, filename):\n",
196 | " #data -> [32, 4]\n",
197 | " stream = music21.stream.Score()\n",
198 | " stream.append(music21.tempo.MetronomeMark(number=66))\n",
199 | "\n",
200 | " for i in range(4):\n",
201 | " channel = music21.stream.Part()\n",
202 | "\n",
203 | " notes, durations = merge_note(data[:, i])\n",
204 | " notes, durations = notes.tolist(), durations.tolist()\n",
205 | " for n, d in zip(notes, durations):\n",
206 | " note = music21.note.Note(n)\n",
207 | " note.duration = music21.duration.Duration(d)\n",
208 | " channel.append(note)\n",
209 | "\n",
210 | " stream.append(channel)\n",
211 | "\n",
212 | " stream.write('midi', fp=filename)\n",
213 | "\n",
214 | "\n",
215 | "save_to_mid(data[0].argmax(axis=2).reshape(32, 4), 'sample.mid')"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 5,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/html": [
226 | "\n",
227 | " \n",
228 | " \n",
230 | " "
240 | ],
241 | "text/plain": [
242 | ""
243 | ]
244 | },
245 | "metadata": {},
246 | "output_type": "display_data"
247 | }
248 | ],
249 | "source": [
250 | "def show(file):\n",
251 | " f = music21.midi.MidiFile()\n",
252 | " f.open(file)\n",
253 | " f.read()\n",
254 | " f.close()\n",
255 | " music21.midi.translate.midiFileToStream(f).show('midi')\n",
256 | "\n",
257 | "\n",
258 | "show('sample.mid')"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 6,
264 | "metadata": {},
265 | "outputs": [
266 | {
267 | "name": "stderr",
268 | "output_type": "stream",
269 | "text": [
270 | "Using TensorFlow backend.\n"
271 | ]
272 | },
273 | {
274 | "name": "stdout",
275 | "output_type": "stream",
276 | "text": [
277 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
278 | "\n",
279 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
280 | "\n",
281 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4115: The name tf.random_normal is deprecated. Please use tf.random.normal instead.\n",
282 | "\n"
283 | ]
284 | },
285 | {
286 | "data": {
287 | "text/plain": [
288 | ""
289 | ]
290 | },
291 | "execution_count": 6,
292 | "metadata": {},
293 | "output_type": "execute_result"
294 | }
295 | ],
296 | "source": [
297 | "import keras\n",
298 | "\n",
299 | "weight_init = keras.initializers.RandomNormal(mean=0., stddev=0.02)\n",
300 | "\n",
301 | "cls = keras.models.Sequential([\n",
302 | " keras.layers.Conv3D(filters=128,\n",
303 | " kernel_size=(2, 1, 1),\n",
304 | " padding='valid',\n",
305 | " strides=(1, 1, 1),\n",
306 | " kernel_initializer=weight_init,\n",
307 | " input_shape=(2, 16, 84, 4)),\n",
308 | " keras.layers.LeakyReLU(),\n",
309 | " keras.layers.Conv3D(filters=128,\n",
310 | " kernel_size=(1, 1, 1),\n",
311 | " padding='valid',\n",
312 | " strides=(1, 1, 1),\n",
313 | " kernel_initializer=weight_init),\n",
314 | " keras.layers.LeakyReLU(),\n",
315 | " keras.layers.Conv3D(filters=128,\n",
316 | " kernel_size=(1, 1, 12),\n",
317 | " padding='same',\n",
318 | " strides=(1, 1, 12),\n",
319 | " kernel_initializer=weight_init),\n",
320 | " keras.layers.LeakyReLU(),\n",
321 | " keras.layers.Conv3D(filters=128,\n",
322 | " kernel_size=(1, 1, 7),\n",
323 | " padding='same',\n",
324 | " strides=(1, 1, 7),\n",
325 | " kernel_initializer=weight_init),\n",
326 | " keras.layers.LeakyReLU(),\n",
327 | " keras.layers.Conv3D(filters=128,\n",
328 | " kernel_size=(1, 2, 1),\n",
329 | " padding='same',\n",
330 | " strides=(1, 2, 1),\n",
331 | " kernel_initializer=weight_init),\n",
332 | " keras.layers.LeakyReLU(),\n",
333 | " keras.layers.Conv3D(filters=128,\n",
334 | " kernel_size=(1, 2, 1),\n",
335 | " padding='same',\n",
336 | " strides=(1, 2, 1),\n",
337 | " kernel_initializer=weight_init),\n",
338 | " keras.layers.LeakyReLU(),\n",
339 | " keras.layers.Conv3D(filters=256,\n",
340 | " kernel_size=(1, 4, 1),\n",
341 | " padding='same',\n",
342 | " strides=(1, 2, 1),\n",
343 | " kernel_initializer=weight_init),\n",
344 | " keras.layers.LeakyReLU(),\n",
345 | " keras.layers.Conv3D(filters=512,\n",
346 | " kernel_size=(1, 3, 1),\n",
347 | " padding='same',\n",
348 | " strides=(1, 2, 1),\n",
349 | " kernel_initializer=weight_init),\n",
350 | " keras.layers.LeakyReLU(),\n",
351 | " keras.layers.Flatten(),\n",
352 | " keras.layers.Dense(1024, kernel_initializer=weight_init),\n",
353 | " keras.layers.LeakyReLU(),\n",
354 | " keras.layers.Dense(1, activation=None, kernel_initializer=weight_init),\n",
355 | "])\n",
356 | "\n",
357 | "cls"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 7,
363 | "metadata": {},
364 | "outputs": [
365 | {
366 | "name": "stdout",
367 | "output_type": "stream",
368 | "text": [
369 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n",
370 | "\n",
371 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
372 | "\n",
373 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.\n",
374 | "\n"
375 | ]
376 | },
377 | {
378 | "data": {
379 | "text/plain": [
380 | ""
381 | ]
382 | },
383 | "execution_count": 7,
384 | "metadata": {},
385 | "output_type": "execute_result"
386 | }
387 | ],
388 | "source": [
389 | "def get_gen():\n",
390 | "\n",
391 | " def TemporalNetwork():\n",
392 | " return keras.models.Sequential([\n",
393 | " keras.layers.Reshape([1, 1, 32], input_shape=(32, )),\n",
394 | " keras.layers.Conv2DTranspose(filters=1024,\n",
395 | " kernel_size=(2, 1),\n",
396 | " padding='valid',\n",
397 | " strides=(1, 1),\n",
398 | " kernel_initializer=weight_init),\n",
399 | " keras.layers.BatchNormalization(momentum=0.9),\n",
400 | " keras.layers.Activation('relu'),\n",
401 | " keras.layers.Conv2DTranspose(filters=32,\n",
402 | " kernel_size=(1, 1),\n",
403 | " padding='valid',\n",
404 | " strides=(1, 1),\n",
405 | " kernel_initializer=weight_init),\n",
406 | " keras.layers.BatchNormalization(momentum=0.9),\n",
407 | " keras.layers.Activation('relu'),\n",
408 | " keras.layers.Reshape([2, 32]),\n",
409 | " ])\n",
410 | "\n",
411 | " def BarGenerator():\n",
412 | " return keras.models.Sequential([\n",
413 | " keras.layers.Dense(1024, input_shape=(128, )),\n",
414 | " keras.layers.BatchNormalization(momentum=0.9),\n",
415 | " keras.layers.Activation('relu'),\n",
416 | " keras.layers.Reshape([2, 1, 512]),\n",
417 | " keras.layers.Conv2DTranspose(filters=512,\n",
418 | " kernel_size=(2, 1),\n",
419 | " padding='same',\n",
420 | " strides=(2, 1),\n",
421 | " kernel_initializer=weight_init),\n",
422 | " keras.layers.BatchNormalization(momentum=0.9),\n",
423 | " keras.layers.Activation('relu'),\n",
424 | " keras.layers.Conv2DTranspose(filters=256,\n",
425 | " kernel_size=(2, 1),\n",
426 | " padding='same',\n",
427 | " strides=(2, 1),\n",
428 | " kernel_initializer=weight_init),\n",
429 | " keras.layers.BatchNormalization(momentum=0.9),\n",
430 | " keras.layers.Activation('relu'),\n",
431 | " keras.layers.Conv2DTranspose(filters=256,\n",
432 | " kernel_size=(2, 1),\n",
433 | " padding='same',\n",
434 | " strides=(2, 1),\n",
435 | " kernel_initializer=weight_init),\n",
436 | " keras.layers.BatchNormalization(momentum=0.9),\n",
437 | " keras.layers.Activation('relu'),\n",
438 | " keras.layers.Conv2DTranspose(filters=256,\n",
439 | " kernel_size=(1, 7),\n",
440 | " padding='same',\n",
441 | " strides=(1, 7),\n",
442 | " kernel_initializer=weight_init),\n",
443 | " keras.layers.BatchNormalization(momentum=0.9),\n",
444 | " keras.layers.Activation('relu'),\n",
445 | " keras.layers.Conv2DTranspose(filters=1,\n",
446 | " kernel_size=(1, 12),\n",
447 | " padding='same',\n",
448 | " strides=(1, 12),\n",
449 | " kernel_initializer=weight_init),\n",
450 | " keras.layers.Activation('tanh'),\n",
451 | " keras.layers.Reshape([1, 16, 84, 1]),\n",
452 | " ])\n",
453 | "\n",
454 | " input_chord = keras.layers.Input(shape=(32, ))\n",
455 | " input_style = keras.layers.Input(shape=(32, ))\n",
456 | " input_melody = keras.layers.Input(shape=(4, 32))\n",
457 | " input_groove = keras.layers.Input(shape=(4, 32))\n",
458 | "\n",
459 | " output_chord = TemporalNetwork()(input_chord)\n",
460 | "\n",
461 | " output = []\n",
462 | " for i in range(2):\n",
463 | " output_c = []\n",
464 | "\n",
465 | " for j in range(4):\n",
466 | "\n",
467 | " output_melody = keras.models.Sequential([\n",
468 | " keras.layers.Lambda(lambda x: x[:, j, :]),\n",
469 | " TemporalNetwork(),\n",
470 | " keras.layers.Lambda(lambda x: x[:, i, :])\n",
471 | " ])(input_melody)\n",
472 | "\n",
473 | " concat = keras.layers.Concatenate(axis=1)([\n",
474 | " keras.layers.Lambda(lambda x: x[:, i, :])(output_chord),\n",
475 | " input_style, output_melody,\n",
476 | " keras.layers.Lambda(lambda x: x[:, j, :])(input_groove)\n",
477 | " ])\n",
478 | " output_c.append(BarGenerator()(concat))\n",
479 | "\n",
480 | " output.append(keras.layers.Concatenate(axis=-1)(output_c))\n",
481 | "\n",
482 | " output = keras.layers.Concatenate(axis=1)(output)\n",
483 | "\n",
484 | " gen = keras.models.Model(\n",
485 | " [input_chord, input_style, input_melody, input_groove], output)\n",
486 | "\n",
487 | " return gen\n",
488 | "\n",
489 | "\n",
490 | "gen = get_gen()\n",
491 | "\n",
492 | "gen"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": 8,
498 | "metadata": {},
499 | "outputs": [
500 | {
501 | "name": "stdout",
502 | "output_type": "stream",
503 | "text": [
504 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
505 | "\n"
506 | ]
507 | },
508 | {
509 | "data": {
510 | "text/plain": [
511 | "(,\n",
512 | " )"
513 | ]
514 | },
515 | "execution_count": 8,
516 | "metadata": {},
517 | "output_type": "execute_result"
518 | }
519 | ],
520 | "source": [
521 | "from functools import partial\n",
522 | "\n",
523 | "\n",
524 | "def get_gan():\n",
525 | "\n",
526 | " class RandomMerge(keras.layers.merge._Merge):\n",
527 | "\n",
528 | " def __init__(self):\n",
529 | " super().__init__()\n",
530 | "\n",
531 | " def _merge_function(self, inputs):\n",
532 | " alpha = keras.backend.random_uniform((64, 1, 1, 1, 1))\n",
533 | " return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])\n",
534 | "\n",
535 | " def set_trainable(model, trainable):\n",
536 | " model.trainable = trainable\n",
537 | " for layer in model.layers:\n",
538 | " layer.trainable = trainable\n",
539 | "\n",
540 | " set_trainable(gen, False)\n",
541 | "\n",
542 | " input_cls = keras.layers.Input(shape=[2, 16, 84, 4])\n",
543 | " input_chord = keras.layers.Input(shape=(32, ))\n",
544 | " input_style = keras.layers.Input(shape=(32, ))\n",
545 | " input_melody = keras.layers.Input(shape=(4, 32))\n",
546 | " input_groove = keras.layers.Input(shape=(4, 32))\n",
547 | "\n",
548 | " output_gen = gen([input_chord, input_style, input_melody, input_groove])\n",
549 | "\n",
550 | " output_cls_fake = cls(output_gen)\n",
551 | " output_cls_real = cls(input_cls)\n",
552 | "\n",
553 | " input_merge = RandomMerge()([input_cls, output_gen])\n",
554 | "\n",
555 | " output_cls_merge = cls(input_merge)\n",
556 | "\n",
557 | " def get_grads_loss(y_true, y_pred, input_merge):\n",
558 | " grads = keras.backend.gradients(y_pred, input_merge)[0]\n",
559 | " grads = keras.backend.square(grads)\n",
560 | " grads = keras.backend.sum(grads, axis=np.arange(1, len(grads.shape)))\n",
561 | " grads = keras.backend.sqrt(grads)\n",
562 | " grads = keras.backend.square(1 - grads)\n",
563 | " return keras.backend.mean(grads)\n",
564 | "\n",
565 | " grads_loss = partial(get_grads_loss, input_merge=input_merge)\n",
566 | "\n",
567 | " def wasserstein(y_true, y_pred):\n",
568 | " return -keras.backend.mean(y_true * y_pred)\n",
569 | "\n",
570 | " cls_model = keras.models.Model(\n",
571 | " inputs=[\n",
572 | " input_cls, input_chord, input_style, input_melody, input_groove\n",
573 | " ],\n",
574 | " outputs=[output_cls_real, output_cls_fake, output_cls_merge])\n",
575 | "\n",
576 | " cls_model.compile(loss=[wasserstein, wasserstein, grads_loss],\n",
577 | " optimizer=keras.optimizers.Adam(lr=0.001,\n",
578 | " beta_1=0.5,\n",
579 | " beta_2=0.9),\n",
580 | " loss_weights=[1, 1, 10])\n",
581 | "\n",
582 | " set_trainable(cls, False)\n",
583 | " set_trainable(gen, True)\n",
584 | "\n",
585 | " gan = keras.models.Model(\n",
586 | " [input_chord, input_style, input_melody, input_groove],\n",
587 | " output_cls_fake)\n",
588 | "\n",
589 | " gan.compile(optimizer=keras.optimizers.Adam(lr=0.001,\n",
590 | " beta_1=0.5,\n",
591 | " beta_2=0.9),\n",
592 | " loss=wasserstein)\n",
593 | "\n",
594 | " set_trainable(cls, True)\n",
595 | "\n",
596 | " return gan, cls_model\n",
597 | "\n",
598 | "\n",
599 | "gan, cls_model = get_gan()\n",
600 | "\n",
601 | "gan, cls_model"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": 9,
607 | "metadata": {},
608 | "outputs": [
609 | {
610 | "data": {
611 | "text/html": [
612 | "\n",
613 | " \n",
614 | " \n",
616 | " "
626 | ],
627 | "text/plain": [
628 | ""
629 | ]
630 | },
631 | "metadata": {},
632 | "output_type": "display_data"
633 | }
634 | ],
635 | "source": [
636 | "def test():\n",
637 | " chord = np.random.normal(0, 1, (1, 32))\n",
638 | " style = np.random.normal(0, 1, (1, 32))\n",
639 | " melody = np.random.normal(0, 1, (1, 4, 32))\n",
640 | " groove = np.random.normal(0, 1, (1, 4, 32))\n",
641 | "\n",
642 | " #[1, 2, 16, 84, 4]\n",
643 | " pred = gen.predict([chord, style, melody, groove])\n",
644 | "\n",
645 | " #[1, 2, 16, 84, 4] -> [1, 2, 16, 4]\n",
646 | " pred = pred.argmax(axis=3)\n",
647 | "\n",
648 | " #[1, 2, 16, 4] -> [32, 4]\n",
649 | " pred = pred.reshape(32, 4)\n",
650 | "\n",
651 | " save_to_mid(pred, 'pred.mid')\n",
652 | "\n",
653 | " show('pred.mid')\n",
654 | "\n",
655 | "\n",
656 | "test()"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": 10,
662 | "metadata": {},
663 | "outputs": [
664 | {
665 | "name": "stdout",
666 | "output_type": "stream",
667 | "text": [
668 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
669 | "Instructions for updating:\n",
670 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
671 | ]
672 | },
673 | {
674 | "name": "stderr",
675 | "output_type": "stream",
676 | "text": [
677 | "/root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
678 | " 'Discrepancy between trainable weights and collected trainable'\n",
679 | "/root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
680 | " 'Discrepancy between trainable weights and collected trainable'\n"
681 | ]
682 | },
683 | {
684 | "name": "stdout",
685 | "output_type": "stream",
686 | "text": [
687 | "0 [8.912887, -0.85698175, -0.034562703, 0.98044306] 0.0040728305\n"
688 | ]
689 | },
690 | {
691 | "name": "stderr",
692 | "output_type": "stream",
693 | "text": [
694 | "/root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
695 | " 'Discrepancy between trainable weights and collected trainable'\n"
696 | ]
697 | },
698 | {
699 | "name": "stdout",
700 | "output_type": "stream",
701 | "text": [
702 | "50 [-27.792698, -267.06592, 231.58945, 0.76837736] -282.16797\n",
703 | "100 [-27.88805, -257.02512, 218.39185, 1.0745221] -225.14125\n",
704 | "150 [-18.23226, -43.738167, 19.566328, 0.59395784] -33.739716\n",
705 | "200 [-16.835743, -75.27756, 53.224827, 0.52169865] -45.045284\n",
706 | "250 [-15.013928, -43.07686, 21.97362, 0.60893106] -14.732184\n",
707 | "300 [-14.297857, -106.17346, 89.668396, 0.22072089] -97.43548\n",
708 | "350 [-14.1342125, -32.94745, 15.056252, 0.37569845] -12.34396\n",
709 | "400 [-12.606506, -67.773705, 48.30908, 0.68581194] -37.58142\n",
710 | "450 [-11.947504, -14.487357, 1.7038689, 0.08359844] -15.795556\n",
711 | "500 [-11.203049, -20.237303, 8.629093, 0.040516045] -6.5508084\n",
712 | "550 [-12.3838825, -23.784214, 9.745614, 0.16547178] -12.909263\n",
713 | "600 [-11.5849285, -35.62585, 22.524584, 0.15163384] -24.991646\n",
714 | "650 [-11.719839, -28.288511, 14.372057, 0.21966158] -16.3292\n",
715 | "700 [-10.081033, -24.602768, 12.748704, 0.17730309] -10.602691\n",
716 | "750 [-10.870434, -20.221205, 8.281688, 0.10690833] -7.7463045\n",
717 | "800 [-9.710244, -26.209225, 14.81569, 0.16832903] -12.453541\n",
718 | "850 [-9.430994, -15.119096, 4.843739, 0.08443623] -7.6122823\n",
719 | "900 [-9.68703, -20.41964, 9.861892, 0.08707178] -2.0840075\n",
720 | "950 [-10.114071, -17.73894, 6.2960553, 0.13288136] -6.283645\n"
721 | ]
722 | }
723 | ],
724 | "source": [
725 | "def train():\n",
726 | "\n",
727 | " def train_cls():\n",
728 | " pos = np.ones((64, 1), dtype=np.int32)\n",
729 | " neg = -np.ones((64, 1), dtype=np.int32)\n",
730 | " dummy = np.zeros((64, 1), dtype=np.int32)\n",
731 | "\n",
732 | " chord = np.random.normal(0, 1, (64, 32))\n",
733 | " style = np.random.normal(0, 1, (64, 32))\n",
734 | " melody = np.random.normal(0, 1, (64, 4, 32))\n",
735 | " groove = np.random.normal(0, 1, (64, 4, 32))\n",
736 | "\n",
737 | " data_sub = data[np.random.randint(0, data.shape[0], 64)]\n",
738 | "\n",
739 | " loss_cls = cls_model.train_on_batch(\n",
740 | " [data_sub, chord, style, melody, groove], [pos, neg, dummy])\n",
741 | "\n",
742 | " return loss_cls\n",
743 | "\n",
744 | " def train_gen():\n",
745 | " pos = np.ones((64, 1), dtype=np.int32)\n",
746 | "\n",
747 | " chord = np.random.normal(0, 1, (64, 32))\n",
748 | " style = np.random.normal(0, 1, (64, 32))\n",
749 | " melody = np.random.normal(0, 1, (64, 4, 32))\n",
750 | " groove = np.random.normal(0, 1, (64, 4, 32))\n",
751 | "\n",
752 | " loss_gen = gan.train_on_batch([chord, style, melody, groove], pos)\n",
753 | "\n",
754 | " return loss_gen\n",
755 | "\n",
756 | " for epoch in range(1000):\n",
757 | " for _ in range(5):\n",
758 | " loss_cls = train_cls()\n",
759 | "\n",
760 | " loss_gen = train_gen()\n",
761 | "\n",
762 | " if epoch % 50 == 0:\n",
763 | " print(epoch, loss_cls, loss_gen)\n",
764 | "\n",
765 | "\n",
766 | "train()"
767 | ]
768 | },
769 | {
770 | "cell_type": "code",
771 | "execution_count": 13,
772 | "metadata": {},
773 | "outputs": [
774 | {
775 | "data": {
776 | "text/html": [
777 | "\n",
778 | " \n",
779 | " \n",
781 | " "
791 | ],
792 | "text/plain": [
793 | ""
794 | ]
795 | },
796 | "metadata": {},
797 | "output_type": "display_data"
798 | }
799 | ],
800 | "source": [
801 | "test()"
802 | ]
803 | }
804 | ],
805 | "metadata": {
806 | "kernelspec": {
807 | "display_name": "Python 3 (ipykernel)",
808 | "language": "python",
809 | "name": "python3"
810 | },
811 | "language_info": {
812 | "codemirror_mode": {
813 | "name": "ipython",
814 | "version": 3
815 | },
816 | "file_extension": ".py",
817 | "mimetype": "text/x-python",
818 | "name": "python",
819 | "nbconvert_exporter": "python",
820 | "pygments_lexer": "ipython3",
821 | "version": "3.9.12"
822 | }
823 | },
824 | "nbformat": 4,
825 | "nbformat_minor": 2
826 | }
827 |
--------------------------------------------------------------------------------
/keras/README.md:
--------------------------------------------------------------------------------
1 | Environment info:
2 |
3 | keras==2.2.4
4 |
5 | tensorflow==1.14.0
6 |
7 |
8 | Adapted from: https://github.com/davidADSP/GDL_code
9 |
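10 | A minimal sketch of a matching environment (assumptions: the `gdl` conda env with Python 3.6 that the notebook paths reference, and `music21`, which the notebooks import but which is not pinned above):
11 | 
12 | ```
13 | # hypothetical requirements.txt assembled from the pins above, plus music21
14 | keras==2.2.4
15 | tensorflow==1.14.0
16 | music21
17 | ```
18 | 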
--------------------------------------------------------------------------------