├── 1.ae.ipynb ├── 10.diffusion.ipynb ├── 11.pix2pix.ipynb ├── 2.vae.ipynb ├── 3.dcgan.ipynb ├── 4.wgan.ipynb ├── 5.wgangp.ipynb ├── 6.cyclegan.ipynb ├── 7.musegan.ipynb ├── 8.style transfer.ipynb ├── 9.simple_style transfer.ipynb ├── README.md ├── datas ├── content.jpeg ├── style.jpg └── temp.midi └── keras ├── 1.ae画mnist.ipynb ├── 10.diffusion.ipynb ├── 2.vae画mnist.ipynb ├── 3.vae画celeba.ipynb ├── 4.gan画quick_draw.ipynb ├── 5.wgan画cifar10.ipynb ├── 6.wgangp画celeba.ipynb ├── 7.cyclegan画apple2orange.ipynb ├── 8.lstm创作cello.ipynb ├── 9.musegan创作chorales.ipynb └── README.md /7.musegan.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e93fca6e", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "Using custom data configuration lansinuote--gen.2.chorales-2bf7c47eabbdde89\n", 14 | "Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--gen.2.chorales-2bf7c47eabbdde89/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" 15 | ] 16 | }, 17 | { 18 | "data": { 19 | "text/plain": [ 20 | "((229, 4, 2, 16, 84), -1.0, 1.0)" 21 | ] 22 | }, 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "output_type": "execute_result" 26 | } 27 | ], 28 | "source": [ 29 | "#加载全部数据到内存中\n", 30 | "def get_data():\n", 31 | " from datasets import load_dataset\n", 32 | " import numpy as np\n", 33 | "\n", 34 | " #加载\n", 35 | " dataset = load_dataset('lansinuote/gen.2.chorales', split='train')\n", 36 | "\n", 37 | " #加载为numpy数据\n", 38 | " data = np.empty((229, 4, 2, 16, 84), dtype=np.float32)\n", 39 | " for i in range(len(dataset)):\n", 40 | " data[i] = dataset[i]['data']\n", 41 | "\n", 42 | " return data\n", 43 | "\n", 44 | "\n", 45 | "data = get_data()\n", 46 | "\n", 47 | "data.shape, data.min(), data.max()" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "id": "9e6c9a3e", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "(3, torch.Size([64, 4, 2, 16, 84]))" 60 | ] 61 | }, 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | } 66 | ], 67 | "source": [ 68 | "import torch\n", 69 | "\n", 70 | "loader = torch.utils.data.DataLoader(\n", 71 | " dataset=data,\n", 72 | " batch_size=64,\n", 73 | " shuffle=True,\n", 74 | " drop_last=True,\n", 75 | ")\n", 76 | "\n", 77 | "len(loader), next(iter(loader)).shape" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 3, 83 | "id": "f6100c70", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/html": [ 89 | "\n", 90 | "
\n", 91 | " \n", 93 | " \n", 94 | " " 104 | ], 105 | "text/plain": [ 106 | "" 107 | ] 108 | }, 109 | "metadata": {}, 110 | "output_type": "display_data" 111 | }, 112 | { 113 | "data": { 114 | "text/html": [ 115 | "\n", 116 | "
\n", 117 | " \n", 119 | " \n", 120 | " " 130 | ], 131 | "text/plain": [ 132 | "" 133 | ] 134 | }, 135 | "metadata": {}, 136 | "output_type": "display_data" 137 | }, 138 | { 139 | "data": { 140 | "text/html": [ 141 | "\n", 142 | "
\n", 143 | " \n", 145 | " \n", 146 | " " 156 | ], 157 | "text/plain": [ 158 | "" 159 | ] 160 | }, 161 | "metadata": {}, 162 | "output_type": "display_data" 163 | } 164 | ], 165 | "source": [ 166 | "import music21\n", 167 | "\n", 168 | "\n", 169 | "#工具类,不重要\n", 170 | "class Show():\n", 171 | " #工具函数,不重要\n", 172 | " def __merge_note(self, note, duration=None):\n", 173 | " import numpy as np\n", 174 | "\n", 175 | " if duration is None:\n", 176 | " duration = np.full(note.shape, fill_value=0.25, dtype=np.float32)\n", 177 | "\n", 178 | " #从前往后遍历\n", 179 | " for i in range(len(note) - 1):\n", 180 | " j = i + 1\n", 181 | "\n", 182 | " #判断相连的两个note是否相同,并且duration相加不大于1.0\n", 183 | " if note[i] == note[j] and duration[i] + duration[j] <= 1.0:\n", 184 | "\n", 185 | " #duration合并\n", 186 | " duration[i] += duration[j]\n", 187 | "\n", 188 | " #删除重复的note\n", 189 | " note = np.delete(note, j, axis=0)\n", 190 | " duration = np.delete(duration, j, axis=0)\n", 191 | "\n", 192 | " #递归调用\n", 193 | " return self.__merge_note(note, duration)\n", 194 | "\n", 195 | " return note, duration\n", 196 | "\n", 197 | " #工具函数,不重要\n", 198 | " def __save_to_mid(self, data):\n", 199 | " #data -> [32, 4]\n", 200 | " stream = music21.stream.Score()\n", 201 | " stream.append(music21.tempo.MetronomeMark(number=66))\n", 202 | "\n", 203 | " for i in range(4):\n", 204 | " channel = music21.stream.Part()\n", 205 | "\n", 206 | " notes, durations = self.__merge_note(data[:, i])\n", 207 | " notes, durations = notes.tolist(), durations.tolist()\n", 208 | " for n, d in zip(notes, durations):\n", 209 | " note = music21.note.Note(n)\n", 210 | " note.duration = music21.duration.Duration(d)\n", 211 | " channel.append(note)\n", 212 | "\n", 213 | " stream.append(channel)\n", 214 | "\n", 215 | " stream.write('midi', fp='./datas/temp.midi')\n", 216 | "\n", 217 | " def __call__(self, data):\n", 218 | " #[4, 2, 16, 84] -> [4, 2, 16] -> [32, 4]\n", 219 | " data = data.argmax(dim=-1).reshape(32, 4)\n", 220 | " data = data.to('cpu').detach().numpy()\n", 221 | " self.__save_to_mid(data)\n", 222 | "\n", 223 | " f = music21.midi.MidiFile()\n", 224 | " f.open('./datas/temp.midi')\n", 225 | " f.read()\n", 226 | " f.close()\n", 227 | " music21.midi.translate.midiFileToStream(f).show('midi')\n", 228 | "\n", 229 | "\n", 230 | "show = Show()\n", 231 | "\n", 232 | "for _ in range(3):\n", 233 | " show(next(iter(loader))[0])" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 4, 239 | "id": "6fa83528", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "torch.Size([2, 1, 1, 16, 84])" 246 | ] 247 | }, 248 | "execution_count": 4, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "def get_gen_track():\n", 255 | " return torch.nn.Sequential(\n", 256 | " torch.nn.Linear(4 * 32, 1024),\n", 257 | " torch.nn.BatchNorm1d(1024),\n", 258 | " torch.nn.ReLU(inplace=True),\n", 259 | " torch.nn.Unflatten(unflattened_size=(512, 2, 1), dim=1),\n", 260 | " torch.nn.ConvTranspose2d(512,\n", 261 | " 512,\n", 262 | " kernel_size=(2, 1),\n", 263 | " stride=(2, 1),\n", 264 | " padding=0),\n", 265 | " torch.nn.BatchNorm2d(512),\n", 266 | " torch.nn.ReLU(inplace=True),\n", 267 | " torch.nn.ConvTranspose2d(512,\n", 268 | " 256,\n", 269 | " kernel_size=(2, 1),\n", 270 | " stride=(2, 1),\n", 271 | " padding=0),\n", 272 | " torch.nn.BatchNorm2d(256),\n", 273 | " torch.nn.ReLU(inplace=True),\n", 274 | " torch.nn.ConvTranspose2d(256,\n", 275 | " 256,\n", 276 | " 
kernel_size=(2, 1),\n", 277 | " stride=(2, 1),\n", 278 | " padding=0),\n", 279 | " torch.nn.BatchNorm2d(256),\n", 280 | " torch.nn.ReLU(inplace=True),\n", 281 | " torch.nn.ConvTranspose2d(256,\n", 282 | " 256,\n", 283 | " kernel_size=(1, 7),\n", 284 | " stride=(1, 7),\n", 285 | " padding=0),\n", 286 | " torch.nn.BatchNorm2d(256),\n", 287 | " torch.nn.ReLU(inplace=True),\n", 288 | " torch.nn.ConvTranspose2d(256,\n", 289 | " 1,\n", 290 | " kernel_size=(1, 12),\n", 291 | " stride=(1, 12),\n", 292 | " padding=0),\n", 293 | " torch.nn.Unflatten(unflattened_size=(1, 1), dim=1),\n", 294 | " )\n", 295 | "\n", 296 | "\n", 297 | "get_gen_track()(torch.randn(2, 128)).shape" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 5, 303 | "id": "3529afe3", 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "data": { 308 | "text/plain": [ 309 | "torch.Size([2, 32, 2])" 310 | ] 311 | }, 312 | "execution_count": 5, 313 | "metadata": {}, 314 | "output_type": "execute_result" 315 | } 316 | ], 317 | "source": [ 318 | "def get_gen_block():\n", 319 | " return torch.nn.Sequential(\n", 320 | " torch.nn.Unflatten(unflattened_size=(32, 1, 1), dim=1),\n", 321 | " torch.nn.ConvTranspose2d(32,\n", 322 | " 1024,\n", 323 | " kernel_size=(2, 1),\n", 324 | " stride=(1, 1),\n", 325 | " padding=0), torch.nn.BatchNorm2d(1024),\n", 326 | " torch.nn.ReLU(inplace=True),\n", 327 | " torch.nn.ConvTranspose2d(1024,\n", 328 | " 32,\n", 329 | " kernel_size=(2 - 1, 1),\n", 330 | " stride=(1, 1),\n", 331 | " padding=0), torch.nn.BatchNorm2d(32),\n", 332 | " torch.nn.ReLU(inplace=True), torch.nn.Flatten(start_dim=2))\n", 333 | "\n", 334 | "\n", 335 | "get_gen_block()(torch.randn(2, 32)).shape" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 6, 341 | "id": "7f6371b9", 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "data": { 346 | "text/plain": [ 347 | "torch.Size([2, 4, 2, 16, 84])" 348 | ] 349 | }, 350 | "execution_count": 6, 351 | "metadata": {}, 352 | "output_type": "execute_result" 353 | } 354 | ], 355 | "source": [ 356 | "class GEN(torch.nn.Module):\n", 357 | "\n", 358 | " def __init__(self):\n", 359 | " super().__init__()\n", 360 | "\n", 361 | " self.gen_chord = get_gen_block()\n", 362 | "\n", 363 | " self.gen_melody = torch.nn.ModuleList(\n", 364 | " [get_gen_block() for _ in range(4)])\n", 365 | "\n", 366 | " self.gen_track = torch.nn.ModuleList(\n", 367 | " [get_gen_track() for _ in range(4)])\n", 368 | "\n", 369 | " def forward(self, chord, style, melody, groove):\n", 370 | " #chord -> [b, 32]\n", 371 | " #style -> [b, 32]\n", 372 | " #melody -> [b, 4, 32]\n", 373 | " #groove -> [b, 4, 32]\n", 374 | "\n", 375 | " #[b, 32] -> [b, 32, 2]\n", 376 | " out_chord = self.gen_chord(chord)\n", 377 | "\n", 378 | " out_i = []\n", 379 | " for i in range(2):\n", 380 | "\n", 381 | " out_j = []\n", 382 | " for j in range(4):\n", 383 | "\n", 384 | " #[b, 32] -> [b, 32, 2] -> [b, 32]\n", 385 | " out_melody = self.gen_melody[j](melody[:, j])[:, :, i]\n", 386 | "\n", 387 | " #[b, 32+32+32+32] -> [b, 128]\n", 388 | " out = torch.cat(\n", 389 | " [out_chord[:, :, i], style, out_melody, groove[:, j]],\n", 390 | " dim=1)\n", 391 | "\n", 392 | " #[b, 128] -> [b, 1, 1, 16, 84]\n", 393 | " out = self.gen_track[j](out)\n", 394 | "\n", 395 | " out_j.append(out)\n", 396 | "\n", 397 | " #[b, 1*4, 1, 16, 84] -> [b, 4, 1, 16, 84]\n", 398 | " out_i.append(torch.cat(out_j, dim=1))\n", 399 | "\n", 400 | " #[b, 4, 1*2, 16, 84] -> [b, 4, 2, 16, 84]\n", 401 | " out = torch.cat(out_i, dim=2)\n", 402 | 
"\n", 403 | " return out\n", 404 | "\n", 405 | "\n", 406 | "gen = GEN()\n", 407 | "\n", 408 | "gen(torch.randn(2, 32), torch.randn(2, 32), torch.randn(2, 4, 32),\n", 409 | " torch.randn(2, 4, 32)).shape" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 7, 415 | "id": "c9570ee1", 416 | "metadata": {}, 417 | "outputs": [ 418 | { 419 | "data": { 420 | "text/plain": [ 421 | "tensor([[0.0234],\n", 422 | " [0.0236]], grad_fn=)" 423 | ] 424 | }, 425 | "execution_count": 7, 426 | "metadata": {}, 427 | "output_type": "execute_result" 428 | } 429 | ], 430 | "source": [ 431 | "def get_cls():\n", 432 | " return torch.nn.Sequential(\n", 433 | " torch.nn.Conv3d(4, 128, (2, 1, 1), (1, 1, 1), padding=0),\n", 434 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 435 | " torch.nn.Conv3d(128, 128, (2 - 1, 1, 1), (1, 1, 1), padding=0),\n", 436 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 437 | " torch.nn.Conv3d(128, 128, (1, 1, 12), (1, 1, 12), padding=0),\n", 438 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 439 | " torch.nn.Conv3d(128, 128, (1, 1, 7), (1, 1, 7), padding=0),\n", 440 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 441 | " torch.nn.Conv3d(128, 128, (1, 2, 1), (1, 2, 1), padding=0),\n", 442 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 443 | " torch.nn.Conv3d(128, 128, (1, 2, 1), (1, 2, 1), padding=0),\n", 444 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 445 | " torch.nn.Conv3d(128, 2 * 128, (1, 4, 1), (1, 2, 1), padding=(0, 1, 0)),\n", 446 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 447 | " torch.nn.Conv3d(2 * 128,\n", 448 | " 4 * 128, (1, 3, 1), (1, 2, 1),\n", 449 | " padding=(0, 1, 0)),\n", 450 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 451 | " torch.nn.Flatten(),\n", 452 | " torch.nn.Linear(4 * 128, 1024),\n", 453 | " torch.nn.LeakyReLU(0.3, inplace=True),\n", 454 | " torch.nn.Linear(1024, 1),\n", 455 | " )\n", 456 | "\n", 457 | "\n", 458 | "cls = get_cls()\n", 459 | "\n", 460 | "cls(torch.randn(2, 4, 2, 16, 84))" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 8, 466 | "id": "minus-uniform", 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "data": { 471 | "text/plain": [ 472 | "'cuda'" 473 | ] 474 | }, 475 | "execution_count": 8, 476 | "metadata": {}, 477 | "output_type": "execute_result" 478 | } 479 | ], 480 | "source": [ 481 | "def set_requires_grad(model, requires_grad):\n", 482 | " for param in model.parameters():\n", 483 | " param.requires_grad_(requires_grad)\n", 484 | "\n", 485 | "def wasserstein(pred, label):\n", 486 | " return -(pred * label).mean()\n", 487 | "\n", 488 | "\n", 489 | "optimizer_cls = torch.optim.Adam(cls.parameters(), lr=1e-3, betas=(0.5, 0.9))\n", 490 | "optimizer_gen = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.9))\n", 491 | "\n", 492 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 493 | "\n", 494 | "gen.to(device)\n", 495 | "cls.to(device)\n", 496 | "\n", 497 | "device" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 9, 503 | "id": "1170a748", 504 | "metadata": {}, 505 | "outputs": [ 506 | { 507 | "data": { 508 | "text/plain": [ 509 | "tensor(0.9996, device='cuda:0', grad_fn=)" 510 | ] 511 | }, 512 | "execution_count": 9, 513 | "metadata": {}, 514 | "output_type": "execute_result" 515 | } 516 | ], 517 | "source": [ 518 | "def get_gradient_penalty(real, fake):\n", 519 | " #real -> [64, 4, 2, 16, 84]\n", 520 | " #fake -> [64, 4, 2, 16, 84]\n", 521 | "\n", 522 | " r = torch.rand((64, 1, 1, 1, 1), device=device)\n", 523 | " 
r.requires_grad = True\n", 524 | "\n", 525 | " #[64, 4, 2, 16, 84]\n", 526 | " merge = r * real + (1 - r) * fake\n", 527 | "\n", 528 | " #[64, 4, 2, 16, 84] -> [64, 1]\n", 529 | " pred_merge = cls(merge)\n", 530 | "\n", 531 | " grad = torch.autograd.grad(inputs=merge,\n", 532 | " outputs=pred_merge,\n", 533 | " grad_outputs=torch.ones(64, 1, device=device),\n", 534 | " create_graph=True,\n", 535 | " retain_graph=True)\n", 536 | "\n", 537 | " #[64, 4, 2, 16, 84] -> [64, 10752]\n", 538 | " grad = grad[0].reshape(64, -1)\n", 539 | "\n", 540 | " #[64, 10752] -> [64]\n", 541 | " grad = grad.norm(p=2, dim=1)\n", 542 | "\n", 543 | " #[64] -> scala\n", 544 | " return (1 - grad).pow(2).mean()\n", 545 | "\n", 546 | "\n", 547 | "get_gradient_penalty(torch.randn(64, 4, 2, 16, 84, device=device),\n", 548 | " torch.randn(64, 4, 2, 16, 84, device=device))" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": 10, 554 | "id": "31ad7ed8", 555 | "metadata": {}, 556 | "outputs": [ 557 | { 558 | "data": { 559 | "text/plain": [ 560 | "9.99548053741455" 561 | ] 562 | }, 563 | "execution_count": 10, 564 | "metadata": {}, 565 | "output_type": "execute_result" 566 | } 567 | ], 568 | "source": [ 569 | "def train_cls():\n", 570 | " set_requires_grad(cls, True)\n", 571 | " set_requires_grad(gen, False)\n", 572 | " \n", 573 | " #得到三份数据\n", 574 | " real = next(iter(loader)).to(device)\n", 575 | "\n", 576 | " with torch.no_grad():\n", 577 | " cord = torch.randn(64, 32, device=device)\n", 578 | " style = torch.randn(64, 32, device=device)\n", 579 | " melody = torch.randn(64, 4, 32, device=device)\n", 580 | " groove = torch.randn(64, 4, 32, device=device)\n", 581 | " fake = gen(cord, style, melody, groove)\n", 582 | "\n", 583 | " #分别计算\n", 584 | " pred_fake = cls(fake)\n", 585 | " pred_real = cls(real)\n", 586 | "\n", 587 | " #求loss,加权求和\n", 588 | " loss_fake = wasserstein(pred_fake, -torch.ones(64, 1, device=device))\n", 589 | " loss_real = wasserstein(pred_real, torch.ones(64, 1, device=device))\n", 590 | " loss_grad = get_gradient_penalty(real, fake)\n", 591 | "\n", 592 | " loss = loss_fake + loss_real + loss_grad * 10\n", 593 | "\n", 594 | " loss.backward()\n", 595 | " optimizer_cls.step()\n", 596 | " optimizer_cls.zero_grad()\n", 597 | "\n", 598 | " return loss.item()\n", 599 | "\n", 600 | "\n", 601 | "train_cls()" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 11, 607 | "id": "414d78b2", 608 | "metadata": {}, 609 | "outputs": [ 610 | { 611 | "data": { 612 | "text/plain": [ 613 | "0.015252873301506042" 614 | ] 615 | }, 616 | "execution_count": 11, 617 | "metadata": {}, 618 | "output_type": "execute_result" 619 | } 620 | ], 621 | "source": [ 622 | "def train_gen():\n", 623 | " set_requires_grad(cls, False)\n", 624 | " set_requires_grad(gen, True)\n", 625 | " \n", 626 | " cord = torch.randn(64, 32, device=device)\n", 627 | " style = torch.randn(64, 32, device=device)\n", 628 | " melody = torch.randn(64, 4, 32, device=device)\n", 629 | " groove = torch.randn(64, 4, 32, device=device)\n", 630 | "\n", 631 | " fake = gen(cord, style, melody, groove)\n", 632 | " fake_pred = cls(fake)\n", 633 | "\n", 634 | " loss = wasserstein(fake_pred, torch.ones(64, 1, device=device))\n", 635 | " loss.backward()\n", 636 | " optimizer_gen.step()\n", 637 | " optimizer_gen.zero_grad()\n", 638 | "\n", 639 | " return loss.item()\n", 640 | "\n", 641 | "\n", 642 | "train_gen()" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 12, 648 | "id": "taken-cover", 649 | 
"metadata": { 650 | "scrolled": false 651 | }, 652 | "outputs": [ 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "0 -88.02761840820312 365.5087890625\n" 658 | ] 659 | }, 660 | { 661 | "data": { 662 | "text/html": [ 663 | "\n", 664 | "
\n", 665 | " \n", 667 | " \n", 668 | " " 678 | ], 679 | "text/plain": [ 680 | "" 681 | ] 682 | }, 683 | "metadata": {}, 684 | "output_type": "display_data" 685 | }, 686 | { 687 | "name": "stdout", 688 | "output_type": "stream", 689 | "text": [ 690 | "2000 -18.66831398010254 4.117001533508301\n" 691 | ] 692 | }, 693 | { 694 | "data": { 695 | "text/html": [ 696 | "\n", 697 | "
\n", 698 | " \n", 700 | " \n", 701 | " " 711 | ], 712 | "text/plain": [ 713 | "" 714 | ] 715 | }, 716 | "metadata": {}, 717 | "output_type": "display_data" 718 | }, 719 | { 720 | "name": "stdout", 721 | "output_type": "stream", 722 | "text": [ 723 | "4000 -18.55518341064453 -1.6856918334960938\n" 724 | ] 725 | }, 726 | { 727 | "data": { 728 | "text/html": [ 729 | "\n", 730 | "
\n", 731 | " \n", 733 | " \n", 734 | " " 744 | ], 745 | "text/plain": [ 746 | "" 747 | ] 748 | }, 749 | "metadata": {}, 750 | "output_type": "display_data" 751 | }, 752 | { 753 | "name": "stdout", 754 | "output_type": "stream", 755 | "text": [ 756 | "6000 -16.17269515991211 -4.4504075050354\n" 757 | ] 758 | }, 759 | { 760 | "data": { 761 | "text/html": [ 762 | "\n", 763 | "
\n", 764 | " \n", 766 | " \n", 767 | " " 777 | ], 778 | "text/plain": [ 779 | "" 780 | ] 781 | }, 782 | "metadata": {}, 783 | "output_type": "display_data" 784 | }, 785 | { 786 | "name": "stdout", 787 | "output_type": "stream", 788 | "text": [ 789 | "8000 -12.522842407226562 -3.564983367919922\n" 790 | ] 791 | }, 792 | { 793 | "data": { 794 | "text/html": [ 795 | "\n", 796 | "
\n", 797 | " \n", 799 | " \n", 800 | " " 810 | ], 811 | "text/plain": [ 812 | "" 813 | ] 814 | }, 815 | "metadata": {}, 816 | "output_type": "display_data" 817 | }, 818 | { 819 | "name": "stdout", 820 | "output_type": "stream", 821 | "text": [ 822 | "10000 -13.41166877746582 -3.798163890838623\n" 823 | ] 824 | }, 825 | { 826 | "data": { 827 | "text/html": [ 828 | "\n", 829 | "
\n", 830 | " \n", 832 | " \n", 833 | " " 843 | ], 844 | "text/plain": [ 845 | "" 846 | ] 847 | }, 848 | "metadata": {}, 849 | "output_type": "display_data" 850 | }, 851 | { 852 | "name": "stdout", 853 | "output_type": "stream", 854 | "text": [ 855 | "12000 -9.745172500610352 -6.956740379333496\n" 856 | ] 857 | }, 858 | { 859 | "data": { 860 | "text/html": [ 861 | "\n", 862 | "
\n", 863 | " \n", 865 | " \n", 866 | " " 876 | ], 877 | "text/plain": [ 878 | "" 879 | ] 880 | }, 881 | "metadata": {}, 882 | "output_type": "display_data" 883 | }, 884 | { 885 | "name": "stdout", 886 | "output_type": "stream", 887 | "text": [ 888 | "14000 -8.481831550598145 -3.7079086303710938\n" 889 | ] 890 | }, 891 | { 892 | "data": { 893 | "text/html": [ 894 | "\n", 895 | "
\n", 896 | " \n", 898 | " \n", 899 | " " 909 | ], 910 | "text/plain": [ 911 | "" 912 | ] 913 | }, 914 | "metadata": {}, 915 | "output_type": "display_data" 916 | }, 917 | { 918 | "name": "stdout", 919 | "output_type": "stream", 920 | "text": [ 921 | "16000 -8.238203048706055 -1.697916030883789\n" 922 | ] 923 | }, 924 | { 925 | "data": { 926 | "text/html": [ 927 | "\n", 928 | "
\n", 929 | " \n", 931 | " \n", 932 | " " 942 | ], 943 | "text/plain": [ 944 | "" 945 | ] 946 | }, 947 | "metadata": {}, 948 | "output_type": "display_data" 949 | }, 950 | { 951 | "name": "stdout", 952 | "output_type": "stream", 953 | "text": [ 954 | "18000 -7.125642776489258 -2.4775843620300293\n" 955 | ] 956 | }, 957 | { 958 | "data": { 959 | "text/html": [ 960 | "\n", 961 | "
\n", 962 | " \n", 964 | " \n", 965 | " " 975 | ], 976 | "text/plain": [ 977 | "" 978 | ] 979 | }, 980 | "metadata": {}, 981 | "output_type": "display_data" 982 | } 983 | ], 984 | "source": [ 985 | "def train():\n", 986 | " for epoch in range(2_0000):\n", 987 | " for _ in range(5):\n", 988 | " loss_cls = train_cls()\n", 989 | "\n", 990 | " loss_gen = train_gen()\n", 991 | "\n", 992 | " if epoch % 2000 == 0:\n", 993 | " print(epoch, loss_cls, loss_gen)\n", 994 | "\n", 995 | " #这里的b必须要大于1,否则BatchNorm层的计算会出错\n", 996 | " chord = torch.rand(2, 32, device=device)\n", 997 | " style = torch.rand(2, 32, device=device)\n", 998 | " melody = torch.rand(2, 4, 32, device=device)\n", 999 | " groove = torch.rand(2, 4, 32, device=device)\n", 1000 | "\n", 1001 | " #[2, 4, 2, 16, 84]\n", 1002 | " pred = gen(chord, style, melody, groove)\n", 1003 | " show(pred[0])\n", 1004 | "\n", 1005 | "\n", 1006 | "local_training = True\n", 1007 | "\n", 1008 | "if local_training:\n", 1009 | " train()" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "execution_count": 13, 1015 | "id": "973e9776", 1016 | "metadata": {}, 1017 | "outputs": [ 1018 | { 1019 | "data": { 1020 | "application/vnd.jupyter.widget-view+json": { 1021 | "model_id": "3ae555b3113b4bd991eab096b47251ba", 1022 | "version_major": 2, 1023 | "version_minor": 0 1024 | }, 1025 | "text/plain": [ 1026 | "pytorch_model.bin: 0%| | 0.00/32.3M [00:00\n", 1092 | " \n", 1094 | " \n", 1095 | " " 1105 | ], 1106 | "text/plain": [ 1107 | "" 1108 | ] 1109 | }, 1110 | "metadata": {}, 1111 | "output_type": "display_data" 1112 | } 1113 | ], 1114 | "source": [ 1115 | "#加载训练好的模型\n", 1116 | "gen = Model.from_pretrained('lansinuote/gen.7.musegan').gen\n", 1117 | "with torch.no_grad():\n", 1118 | " #这里的b必须要大于1,否则BatchNorm层的计算会出错\n", 1119 | " chord = torch.rand(2, 32)\n", 1120 | " style = torch.rand(2, 32)\n", 1121 | " melody = torch.rand(2, 4, 32)\n", 1122 | " groove = torch.rand(2, 4, 32)\n", 1123 | "\n", 1124 | " #[2, 4, 2, 16, 84]\n", 1125 | " pred = gen(chord, style, melody, groove)\n", 1126 | " show(pred[0])" 1127 | ] 1128 | } 1129 | ], 1130 | "metadata": { 1131 | "kernelspec": { 1132 | "display_name": "Python [conda env:pt39]", 1133 | "language": "python", 1134 | "name": "conda-env-pt39-py" 1135 | }, 1136 | "language_info": { 1137 | "codemirror_mode": { 1138 | "name": "ipython", 1139 | "version": 3 1140 | }, 1141 | "file_extension": ".py", 1142 | "mimetype": "text/x-python", 1143 | "name": "python", 1144 | "nbconvert_exporter": "python", 1145 | "pygments_lexer": "ipython3", 1146 | "version": "3.9.13" 1147 | } 1148 | }, 1149 | "nbformat": 4, 1150 | "nbformat_minor": 5 1151 | } 1152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 视频课程:https://www.bilibili.com/video/BV1hs4y1P7T3 2 |

3 | Environment info: 4 |
5 | python==3.9 6 |
7 | torch==1.12.1+cu113 8 |
9 | transformers==4.26.1 10 |
11 | datasets==2.9.0 12 |
13 | music21==8.1.0 14 |
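For example, 7.musegan.ipynb loads its training set from the Hugging Face Hub with the `datasets` release pinned above. A minimal, slightly condensed sketch of that notebook's first cell (dataset id, field name, and shapes are taken from the notebook; the array-building loop is rewritten with `np.stack`):

```python
from datasets import load_dataset
import numpy as np

# chorales dataset used by 7.musegan.ipynb; per the notebook the result
# has shape (229, 4, 2, 16, 84) with values in [-1.0, 1.0]
dataset = load_dataset('lansinuote/gen.2.chorales', split='train')
data = np.stack([np.array(item['data'], dtype=np.float32) for item in dataset])
print(data.shape, data.min(), data.max())
```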

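The critic objective in 7.musegan.ipynb is the standard WGAN-GP recipe: Wasserstein loss, gradient penalty weighted by 10, and five critic updates per generator update. Below is a condensed, self-contained sketch of that recipe, not the notebook's exact code (its critic is named `cls`, and the interpolation weights `r` are drawn the same way):

```python
import torch

def wasserstein(pred, label):
    # label is +1 where the critic should score high, -1 where it should score low
    return -(pred * label).mean()

def gradient_penalty(critic, real, fake):
    # score random interpolations of real/fake samples and push the critic's
    # gradient norm at those points towards 1
    b = real.size(0)
    r = torch.rand((b,) + (1,) * (real.dim() - 1),
                   device=real.device, requires_grad=True)
    merge = r * real + (1 - r) * fake
    pred = critic(merge)
    grad = torch.autograd.grad(outputs=pred, inputs=merge,
                               grad_outputs=torch.ones_like(pred),
                               create_graph=True, retain_graph=True)[0]
    return (1 - grad.reshape(b, -1).norm(p=2, dim=1)).pow(2).mean()

# critic step (run 5 times per generator step in the notebook):
#   loss = wasserstein(critic(fake), -1) + wasserstein(critic(real), +1) \
#          + 10 * gradient_penalty(critic, real, fake)
# generator step:
#   loss = wasserstein(critic(gen(noise)), +1)
```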
15 | Update (2023-04-27): 16 |
17 | 1.ae, 2.vae, 3.dcgan, 4.wgan, 5.wgangp: for these 5 tasks, the generative model's latent code was reduced from 768 dimensions to 128. 18 | -------------------------------------------------------------------------------- /datas/content.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_Generative_in_PyTorch/d349f0efc7062ac258258613fe98d31c13e3495a/datas/content.jpeg -------------------------------------------------------------------------------- /datas/style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_Generative_in_PyTorch/d349f0efc7062ac258258613fe98d31c13e3495a/datas/style.jpg -------------------------------------------------------------------------------- /datas/temp.midi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_Generative_in_PyTorch/d349f0efc7062ac258258613fe98d31c13e3495a/datas/temp.midi -------------------------------------------------------------------------------- /keras/8.lstm创作cello.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/html": [ 11 | "\n", 12 | "
\n", 13 | " \n", 15 | " " 25 | ], 26 | "text/plain": [ 27 | "" 28 | ] 29 | }, 30 | "metadata": {}, 31 | "output_type": "display_data" 32 | } 33 | ], 34 | "source": [ 35 | "import music21\n", 36 | "\n", 37 | "\n", 38 | "def show(file):\n", 39 | " f = music21.midi.MidiFile()\n", 40 | " f.open(file)\n", 41 | " f.read()\n", 42 | " f.close()\n", 43 | " music21.midi.translate.midiFileToStream(f).show('midi')\n", 44 | "\n", 45 | "\n", 46 | "show('../datas/cello/cs2-2all.mid')" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "([['START', 'A3', 'D2.A2.F3.A3', 'B-3', 'A3', 'G3', 'F3', 'E3', 'D3', 'D3', 'C#3', 'D3', 'E3', 'A2', 'B-2', 'G2', 'F2', 'A2', 'D3', 'F2', 'E2', 'C#3', 'D2.A2.D3', 'E3', 'F3', 'G3', 'A3', 'B-3', 'D2.A2.F#3.C4', 'D4', 'E-4', 'D4', 'C4', 'B-3', 'A3', 'C4', 'B-3', 'A3', 'G3', 'D4', 'F3', 'E3', 'G3', 'B-3', 'D4', 'C4', 'B-3', 'A3', 'G3', 'B-3', 'A3', 'G3', 'F3', 'F3.A3', 'F3.A3', 'B3', 'F3', 'E3', 'D3', 'E3', 'C#4', 'D4', 'C#4', 'D3.D4', 'E4', 'F4', 'E4', 'D4', 'E4', 'D4', 'C4', 'B3', 'C4', 'B3', 'A3', 'G#3', 'A3', 'G#3', 'F#3', 'E3', 'E4', 'C4', 'A3', 'G3', 'F3.E4', 'A3', 'F3', 'D3', 'B2.D3', 'F3', 'D3', 'B2', 'G#2', 'B2', 'E3', 'G#3', 'B3', 'D4', 'C4', 'B3', 'C4', 'A3', 'F3', 'E3', 'D3', 'F3', 'E3', 'D3', 'G#3', 'A3', 'B3', 'D4', 'E3', 'D3', 'C3', 'E3', 'A3', 'D4', 'E3.B3', 'A3', 'E-3.A3', 'E-3', 'E3', 'F#3.G#3', 'A3.B3', 'C4', 'D4', 'C4', 'B3', 'A3.C4', 'D3', 'G#3', 'A3', 'B3', 'A3', 'G#3', 'F#3', 'E3', 'C3.E3.A3', 'F3', 'E3', 'D3', 'C3', 'B2', 'A2', 'G#2.D3.B3', 'E3', 'F3', 'E3', 'D3', 'C3', 'B2', 'D4', 'B3', 'C4', 'A3', 'E3', 'G#3', 'A2', 'C#3', 'E3', 'G3', 'F3', 'E3', 'F3', 'A3', 'D4', 'G#3', 'A3', 'A3', 'D2.A2.F3.A3', 'B-3', 'A3', 'G3', 'F3', 'E3', 'D3', 'D3', 'C#3', 'D3', 'E3', 'A2', 'B-2', 'G2', 'F2', 'A2', 'D3', 'F2', 'E2', 'C#3', 'D2.A2.D3', 'E3', 'F3', 'G3', 'A3', 'B-3', 'D2.A2.F#3.C4', 'D4', 'E-4', 'D4', 'C4', 'B-3', 'A3', 'C4', 'B-3', 'A3', 'G3', 'D4', 'F3', 'E3', 'G3', 'B-3', 'D4', 'C4', 'B-3', 'A3', 'G3', 'B-3', 'A3', 'G3', 'F3', 'F3.A3', 'F3.A3', 'B3', 'F3', 'E3', 'D3', 'E3', 'C#4', 'D4', 'C#4', 'D3.D4', 'E4', 'F4', 'E4', 'D4', 'E4', 'D4', 'C4', 'B3', 'C4', 'B3', 'A3', 'G#3', 'A3', 'G#3', 'F#3', 'E3', 'E4', 'C4', 'A3', 'G3', 'F3.E4', 'A3', 'F3', 'D3', 'B2.D3', 'F3', 'D3', 'B2', 'G#2', 'B2', 'E3', 'G#3', 'B3', 'D4', 'C4', 'B3', 'C4', 'A3', 'F3', 'E3', 'D3', 'F3', 'E3', 'D3', 'G#3', 'A3', 'B3', 'D4', 'E3', 'D3', 'C3', 'E3', 'A3', 'D4', 'E3.B3', 'A3', 'E-3.A3', 'E-3', 'E3', 'F#3.G#3', 'A3.B3', 'C4', 'D4', 'C4', 'B3', 'A3.C4', 'D3', 'G#3', 'A3', 'B3', 'A3', 'G#3', 'F#3', 'E3', 'C3.E3.A3', 'F3', 'E3', 'D3', 'C3', 'B2', 'A2', 'G#2.D3.B3', 'E3', 'F3', 'E3', 'D3', 'C3', 'B2', 'D4', 'B3', 'C4', 'A3', 'E3', 'G#3', 'A2', 'C#3', 'E3', 'G3', 'F3', 'E3', 'F3', 'A3', 'D4', 'G#3', 'A3', 'E3', 'A2.E3.C#4', 'F3', 'G3', 'E3', 'F3', 'A3', 'C#3', 'D3', 'E3', 'B-2', 'A2', 'G2', 'F2', 'A3', 'F3', 'D3', 'G3', 'B2', 'C#3', 'A3', 'G3', 'F3', 'E3', 'D3', 'F#3', 'D3', 'E-3', 'C3', 'B-2', 'G3', 'A2', 'G2', 'F#2', 'A2', 'D3', 'C4', 'B-3', 'F#3', 'G3', 'B-3', 'D4', 'A3', 'B-3', 'G3', 'E-3', 'D3', 'E-3', 'G3', 'C4', 'A3', 'B-3', 'G3', 'D3', 'C3', 'D3', 'G3', 'B-3', 'F#3', 'G3', 'E-3', 'C3', 'B-2', 'C3', 'B-3', 'A3', 'C4', 'E-4', 'G3', 'C3.F#3', 'G3', 'A3', 'D3', 'E-3', 'C3', 'B-2', 'D3', 'G3', 'B-2', 'D2', 'F#3', 'G2.G3', 'A3', 'B-3', 'D4', 'G3', 'F3', 'B-2.E3', 'F3', 'G3', 'E3', 'C3', 'B-2', 'A2', 'F3', 'G2', 'F2', 'E2', 'G3', 'A3', 'B-3', 'B-3', 
'A3', 'G3', 'F3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'F3', 'A3', 'D4', 'A3', 'B-3', 'G3', 'A2', 'G3', 'C#4', 'D4', 'E4', 'G3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'G#2', 'F3', 'E3', 'D3', 'D3', 'C#3', 'B2', 'A2', 'C3', 'A2', 'F#2', 'D3', 'C3', 'A2', 'B2', 'D3', 'F3', 'D3', 'G#2', 'D3', 'C#3', 'E3', 'G3', 'B-3', 'E4', 'A3', 'B-3', 'G3', 'F3', 'C#3', 'D3', 'G#2', 'A2', 'C#3', 'D2', 'D4', 'C4', 'A3', 'B-3', 'G3', 'E3', 'C#4', 'D4', 'A3', 'F3', 'D3', 'D2', 'E3', 'A2.E3.C#4', 'F3', 'G3', 'E3', 'F3', 'A3', 'C#3', 'D3', 'E3', 'B-2', 'A2', 'G2', 'F2', 'A3', 'F3', 'D3', 'G3', 'B2', 'C#3', 'A3', 'G3', 'F3', 'E3', 'D3', 'F#3', 'D3', 'E-3', 'C3', 'B-2', 'G3', 'A2', 'G2', 'F#2', 'A2', 'D3', 'C4', 'B-3', 'F#3', 'G3', 'B-3', 'D4', 'A3', 'B-3', 'G3', 'E-3', 'D3', 'E-3', 'G3', 'C4', 'A3', 'B-3', 'G3', 'D3', 'C3', 'D3', 'G3', 'B-3', 'F#3', 'G3', 'E-3', 'C3', 'B-2', 'C3', 'B-3', 'A3', 'C4', 'E-4', 'G3', 'C3.F#3', 'G3', 'A3', 'D3', 'E-3', 'C3', 'B-2', 'D3', 'G3', 'B-2', 'D2', 'F#3', 'G2.G3', 'A3', 'B-3', 'D4', 'G3', 'F3', 'B-2.E3', 'F3', 'G3', 'E3', 'C3', 'B-2', 'A2', 'F3', 'G2', 'F2', 'E2', 'G3', 'A3', 'B-3', 'B-3', 'A3', 'G3', 'F3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'F3', 'A3', 'D4', 'A3', 'B-3', 'G3', 'A2', 'G3', 'C#4', 'D4', 'E4', 'G3', 'A3', 'E3', 'F3', 'D3', 'B-2', 'D3', 'G#2', 'F3', 'E3', 'D3', 'D3', 'C#3', 'B2', 'A2', 'C3', 'A2', 'F#2', 'D3', 'C3', 'A2', 'B2', 'D3', 'F3', 'D3', 'G#2', 'D3', 'C#3', 'E3', 'G3', 'B-3', 'E4', 'A3', 'B-3', 'G3', 'F3', 'C#3', 'D3', 'G#2', 'A2', 'C#3', 'D2', 'D4', 'C4', 'A3', 'B-3', 'G3', 'E3', 'C#4', 'D4', 'A3', 'F3', 'D3', 'D2', 'START'], ['START', 'B-3', 'E-2.B-2.F#3.B-3', 'B3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'E-3', 'F3', 'B-2', 'B2', 'G#2', 'F#2', 'B-2', 'E-3', 'F#2', 'F2', 'D3', 'E-2.B-2.E-3', 'F3', 'F#3', 'G#3', 'B-3', 'B3', 'E-2.B-2.G3.C#4', 'E-4', 'E4', 'E-4', 'C#4', 'B3', 'B-3', 'C#4', 'B3', 'B-3', 'G#3', 'E-4', 'F#3', 'F3', 'G#3', 'B3', 'E-4', 'C#4', 'B3', 'B-3', 'G#3', 'B3', 'B-3', 'G#3', 'F#3', 'F#3.B-3', 'F#3.B-3', 'C4', 'F#3', 'F3', 'E-3', 'F3', 'D4', 'E-4', 'D4', 'E-3.E-4', 'F4', 'F#4', 'F4', 'E-4', 'F4', 'E-4', 'C#4', 'C4', 'C#4', 'C4', 'B-3', 'A3', 'B-3', 'A3', 'G3', 'F3', 'F4', 'C#4', 'B-3', 'G#3', 'F#3.F4', 'B-3', 'F#3', 'E-3', 'C3.E-3', 'F#3', 'E-3', 'C3', 'A2', 'C3', 'F3', 'A3', 'C4', 'E-4', 'C#4', 'C4', 'C#4', 'B-3', 'F#3', 'F3', 'E-3', 'F#3', 'F3', 'E-3', 'A3', 'B-3', 'C4', 'E-4', 'F3', 'E-3', 'C#3', 'F3', 'B-3', 'E-4', 'F3.C4', 'B-3', 'E3.B-3', 'E3', 'F3', 'G3.A3', 'B-3.C4', 'C#4', 'E-4', 'C#4', 'C4', 'B-3.C#4', 'E-3', 'A3', 'B-3', 'C4', 'B-3', 'A3', 'G3', 'F3', 'C#3.F3.B-3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'B-2', 'A2.E-3.C4', 'F3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'E-4', 'C4', 'C#4', 'B-3', 'F3', 'A3', 'B-2', 'D3', 'F3', 'G#3', 'F#3', 'F3', 'F#3', 'B-3', 'E-4', 'A3', 'B-3', 'B-3', 'E-2.B-2.F#3.B-3', 'B3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'E-3', 'F3', 'B-2', 'B2', 'G#2', 'F#2', 'B-2', 'E-3', 'F#2', 'F2', 'D3', 'E-2.B-2.E-3', 'F3', 'F#3', 'G#3', 'B-3', 'B3', 'E-2.B-2.G3.C#4', 'E-4', 'E4', 'E-4', 'C#4', 'B3', 'B-3', 'C#4', 'B3', 'B-3', 'G#3', 'E-4', 'F#3', 'F3', 'G#3', 'B3', 'E-4', 'C#4', 'B3', 'B-3', 'G#3', 'B3', 'B-3', 'G#3', 'F#3', 'F#3.B-3', 'F#3.B-3', 'C4', 'F#3', 'F3', 'E-3', 'F3', 'D4', 'E-4', 'D4', 'E-3.E-4', 'F4', 'F#4', 'F4', 'E-4', 'F4', 'E-4', 'C#4', 'C4', 'C#4', 'C4', 'B-3', 'A3', 'B-3', 'A3', 'G3', 'F3', 'F4', 'C#4', 'B-3', 'G#3', 'F#3.F4', 'B-3', 'F#3', 'E-3', 'C3.E-3', 'F#3', 'E-3', 'C3', 'A2', 'C3', 'F3', 'A3', 'C4', 'E-4', 'C#4', 'C4', 'C#4', 'B-3', 'F#3', 'F3', 'E-3', 'F#3', 'F3', 'E-3', 'A3', 
'B-3', 'C4', 'E-4', 'F3', 'E-3', 'C#3', 'F3', 'B-3', 'E-4', 'F3.C4', 'B-3', 'E3.B-3', 'E3', 'F3', 'G3.A3', 'B-3.C4', 'C#4', 'E-4', 'C#4', 'C4', 'B-3.C#4', 'E-3', 'A3', 'B-3', 'C4', 'B-3', 'A3', 'G3', 'F3', 'C#3.F3.B-3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'B-2', 'A2.E-3.C4', 'F3', 'F#3', 'F3', 'E-3', 'C#3', 'C3', 'E-4', 'C4', 'C#4', 'B-3', 'F3', 'A3', 'B-2', 'D3', 'F3', 'G#3', 'F#3', 'F3', 'F#3', 'B-3', 'E-4', 'A3', 'B-3', 'F3', 'B-2.F3.D4', 'F#3', 'G#3', 'F3', 'F#3', 'B-3', 'D3', 'E-3', 'F3', 'B2', 'B-2', 'G#2', 'F#2', 'B-3', 'F#3', 'E-3', 'G#3', 'C3', 'D3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'G3', 'E-3', 'E3', 'C#3', 'B2', 'G#3', 'B-2', 'G#2', 'G2', 'B-2', 'E-3', 'C#4', 'B3', 'G3', 'G#3', 'B3', 'E-4', 'B-3', 'B3', 'G#3', 'E3', 'E-3', 'E3', 'G#3', 'C#4', 'B-3', 'B3', 'G#3', 'E-3', 'C#3', 'E-3', 'G#3', 'B3', 'G3', 'G#3', 'E3', 'C#3', 'B2', 'C#3', 'B3', 'B-3', 'C#4', 'E4', 'G#3', 'C#3.G3', 'G#3', 'B-3', 'E-3', 'E3', 'C#3', 'B2', 'E-3', 'G#3', 'B2', 'E-2', 'G3', 'G#2.G#3', 'B-3', 'B3', 'E-4', 'G#3', 'F#3', 'B2.F3', 'F#3', 'G#3', 'F3', 'C#3', 'B2', 'B-2', 'F#3', 'G#2', 'F#2', 'F2', 'G#3', 'B-3', 'B3', 'B3', 'B-3', 'G#3', 'F#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'F#3', 'B-3', 'E-4', 'B-3', 'B3', 'G#3', 'B-2', 'G#3', 'D4', 'E-4', 'F4', 'G#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'A2', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'C3', 'B-2', 'C#3', 'B-2', 'G2', 'E-3', 'C#3', 'B-2', 'C3', 'E-3', 'F#3', 'E-3', 'A2', 'E-3', 'D3', 'F3', 'G#3', 'B3', 'F4', 'B-3', 'B3', 'G#3', 'F#3', 'D3', 'E-3', 'A2', 'B-2', 'D3', 'E-2', 'E-4', 'C#4', 'B-3', 'B3', 'G#3', 'F3', 'D4', 'E-4', 'B-3', 'F#3', 'E-3', 'E-2', 'F3', 'B-2.F3.D4', 'F#3', 'G#3', 'F3', 'F#3', 'B-3', 'D3', 'E-3', 'F3', 'B2', 'B-2', 'G#2', 'F#2', 'B-3', 'F#3', 'E-3', 'G#3', 'C3', 'D3', 'B-3', 'G#3', 'F#3', 'F3', 'E-3', 'G3', 'E-3', 'E3', 'C#3', 'B2', 'G#3', 'B-2', 'G#2', 'G2', 'B-2', 'E-3', 'C#4', 'B3', 'G3', 'G#3', 'B3', 'E-4', 'B-3', 'B3', 'G#3', 'E3', 'E-3', 'E3', 'G#3', 'C#4', 'B-3', 'B3', 'G#3', 'E-3', 'C#3', 'E-3', 'G#3', 'B3', 'G3', 'G#3', 'E3', 'C#3', 'B2', 'C#3', 'B3', 'B-3', 'C#4', 'E4', 'G#3', 'C#3.G3', 'G#3', 'B-3', 'E-3', 'E3', 'C#3', 'B2', 'E-3', 'G#3', 'B2', 'E-2', 'G3', 'G#2.G#3', 'B-3', 'B3', 'E-4', 'G#3', 'F#3', 'B2.F3', 'F#3', 'G#3', 'F3', 'C#3', 'B2', 'B-2', 'F#3', 'G#2', 'F#2', 'F2', 'G#3', 'B-3', 'B3', 'B3', 'B-3', 'G#3', 'F#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'F#3', 'B-3', 'E-4', 'B-3', 'B3', 'G#3', 'B-2', 'G#3', 'D4', 'E-4', 'F4', 'G#3', 'B-3', 'F3', 'F#3', 'E-3', 'B2', 'E-3', 'A2', 'F#3', 'F3', 'E-3', 'E-3', 'D3', 'C3', 'B-2', 'C#3', 'B-2', 'G2', 'E-3', 'C#3', 'B-2', 'C3', 'E-3', 'F#3', 'E-3', 'A2', 'E-3', 'D3', 'F3', 'G#3', 'B3', 'F4', 'B-3', 'B3', 'G#3', 'F#3', 'D3', 'E-3', 'A2', 'B-2', 'D3', 'E-2', 'E-4', 'C#4', 'B-3', 'B3', 'G#3', 'F3', 'D4', 'E-4', 'B-3', 'F#3', 'E-3', 'E-2', 'START']], [[0, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 
12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0], [0, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 
0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.5, Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, Fraction(1, 12), Fraction(1, 6), Fraction(1, 6), Fraction(1, 12), 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.5, 0.25, 0.25, 
0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.75, 0]])\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "def get_note_duration(file):\n", 64 | " #读取数据\n", 65 | " file = music21.converter.parse(file).chordify()\n", 66 | "\n", 67 | " note = []\n", 68 | " duration = []\n", 69 | "\n", 70 | " #不知道为什么是0和1,总之1个mid文件里能解析出两条音轨\n", 71 | " for i in [0, 1]:\n", 72 | "\n", 73 | " #开始符号\n", 74 | " n = ['START']\n", 75 | " d = [0]\n", 76 | "\n", 77 | " #读取音符和持续时间\n", 78 | " for j in file.transpose(i).flat:\n", 79 | " if not isinstance(j, music21.chord.Chord):\n", 80 | " continue\n", 81 | "\n", 82 | " #在同一个时间点可能有多个音符,把他们都前后拼合在一起,以\".\"间隔\n", 83 | " n_join = [k.nameWithOctave for k in j.pitches]\n", 84 | " n_join = '.'.join(n_join)\n", 85 | " n.append(n_join)\n", 86 | "\n", 87 | " #取持续时间\n", 88 | " d.append(j.duration.quarterLength)\n", 89 | "\n", 90 | " #结束符号\n", 91 | " n.append('START')\n", 92 | " d.append(0)\n", 93 | "\n", 94 | " note.append(n)\n", 95 | " duration.append(d)\n", 96 | "\n", 97 | " return note, duration\n", 98 | "\n", 99 | "\n", 100 | "print(get_note_duration('../datas/cello/cs2-2all.mid'))" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "['START', 'C4', 'B3', 'A3', 'G3', 'F3', 'E3', 'D3', 'C3', 'G2', 'E2', 'G2', 'C2', 'D2', 'E2', 'F2', 'G2', 'A2', 'B2', 'C3']\n", 113 | "[0, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 1.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]\n" 114 | ] 115 | }, 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "(72, 72)" 120 | ] 121 | }, 122 | "execution_count": 3, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "import os\n", 129 | "\n", 130 | "\n", 131 | "def load_datas():\n", 132 | " note = []\n", 133 | " duration = []\n", 134 | "\n", 135 | " #读取文件列表\n", 136 | " files = ['../datas/cello/' + i for i in os.listdir('../datas/cello')]\n", 137 | "\n", 138 | " for i in files:\n", 139 | " n, d = get_note_duration(i)\n", 140 | " note.extend(n)\n", 141 | " duration.extend(d)\n", 142 | "\n", 143 | " return note, duration\n", 144 | "\n", 145 | "\n", 146 | "note, duration = load_datas()\n", 147 | "\n", 148 | "print(note[0][:20])\n", 149 | "print(duration[0][:20])\n", 150 | "\n", 151 | "len(note), len(duration)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 4, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11, 12, 10, 13, 9, 14, 15, 8]\n", 164 | "[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2]\n", 165 | "72 72\n", 166 | "{'START': 0, 'C4': 1, 'B3': 2, 'A3': 3, 'G3': 4, 'F3': 5, 'E3': 6, 'D3': 7, 'C3': 8, 'G2': 9, 'E2': 10, 'C2': 11, 'D2': 12, 'F2': 13, 'A2': 14, 'B2': 15, 'D4': 16, 'E4': 17, 'F4': 18, 'F#3': 19, 'C#3': 20, 'G#3': 21, 'E-3': 22, 
'B-3': 23, 'G4': 24, 'E-4': 25, 'F#2': 26, 'B-2': 27, 'F2.G2.D3.B3': 28, 'E-2.G2.G3.A3': 29, 'D2.G2.F3.B3': 30, 'C2.G2.E3.C4': 31, 'G2.D3.C4': 32, 'G2.D3.B3': 33, 'C2.G2.E3.B-3': 34, 'C2.A2.F3.A3': 35, 'C2.G#2.D3.B3': 36, 'D3.B3': 37, 'E3.C4': 38, 'C#4': 39, 'G#2': 40, 'C#2': 41, 'E-2': 42, 'F#4': 43, 'G#4': 44, 'F#2.G#2.E-3.C4': 45, 'E2.G#2.G#3.B-3': 46, 'E-2.G#2.F#3.C4': 47, 'C#2.G#2.F3.C#4': 48, 'G#2.E-3.C#4': 49, 'G#2.E-3.C4': 50, 'C#2.G#2.F3.B3': 51, 'C#2.B-2.F#3.B-3': 52, 'C#2.A2.E-3.C4': 53, 'E-3.C4': 54, 'F3.C#4': 55, 'D2.A2.F3.A3': 56, 'D2.A2.D3': 57, 'D2.A2.F#3.C4': 58, 'F3.A3': 59, 'D3.D4': 60, 'F3.E4': 61, 'B2.D3': 62, 'E3.B3': 63, 'E-3.A3': 64, 'F#3.G#3': 65, 'A3.B3': 66, 'A3.C4': 67, 'C3.E3.A3': 68, 'G#2.D3.B3': 69, 'A2.E3.C#4': 70, 'C3.F#3': 71, 'G2.G3': 72, 'B-2.E3': 73, 'E-2.B-2.F#3.B-3': 74, 'E-2.B-2.E-3': 75, 'E-2.B-2.G3.C#4': 76, 'F#3.B-3': 77, 'E-3.E-4': 78, 'F#3.F4': 79, 'C3.E-3': 80, 'F3.C4': 81, 'E3.B-3': 82, 'G3.A3': 83, 'B-3.C4': 84, 'B-3.C#4': 85, 'C#3.F3.B-3': 86, 'A2.E-3.C4': 87, 'B-2.F3.D4': 88, 'C#3.G3': 89, 'G#2.G#3': 90, 'B2.F3': 91, 'C2.C3': 92, 'C2.B2.F3.G#3': 93, 'C2.E-3': 94, 'G#2.F3': 95, 'G#2.F3.C4': 96, 'C2.G2.F3.C4': 97, 'E3.G3': 98, 'F3.G#3': 99, 'D3.E-3': 100, 'C2.G2.E-3': 101, 'G2.D3': 102, 'G2.F#3.E-4': 103, 'G2.F3.B3': 104, 'G2.E3.C#4': 105, 'D2.B-2.F3.G#3': 106, 'E-2.B-2.E-3.G3': 107, 'F#3.D4': 108, 'G2.D3.B-3': 109, 'C2.G2.E-3.B-3': 110, 'D3.G3': 111, 'G2.D3.G3': 112, 'G3.B-3': 113, 'G3.C4': 114, 'D3.A3': 115, 'D3.B-3': 116, 'D3.C4': 117, 'C3.F3': 118, 'G2.E-3': 119, 'G2.F3': 120, 'B-2.F3': 121, 'B-2.E-3': 122, 'C3.D3': 123, 'B-2.D3': 124, 'A2.D3': 125, 'B2.E-3': 126, 'D3.G#3': 127, 'E-3.G3': 128, 'C3.G3': 129, 'F3.D4': 130, 'F3.B3': 131, 'G3.D4': 132, 'D4.E-4': 133, 'C2.B-2': 134, 'C2.G#2': 135, 'C2.G2': 136, 'D2.B2': 137, 'D2.E-3': 138, 'E-2.G2': 139, 'E-2.G3': 140, 'E2.C3.G3.B-3': 141, 'F#2.C3.E-3.C4': 142, 'C#2.C#3': 143, 'C#2.C3.F#3.A3': 144, 'C#3.E3': 145, 'C#2.E3': 146, 'A2.F#3': 147, 'A2.F#3.C#4': 148, 'C#2.G#2.F#3.C#4': 149, 'F#3.A3': 150, 'E-3.E3': 151, 'C#2.G#2.E3': 152, 'G#2.E-3': 153, 'G#2.G3.E4': 154, 'B3.C#4': 155, 'G#2.F#3.C4': 156, 'G#2.F3.D4': 157, 'E-2.B2.F#3.A3': 158, 'E2.B2.E3.G#3': 159, 'G3.E-4': 160, 'G#2.E-3.B3': 161, 'C#2.G#2.E3.B3': 162, 'E-3.G#3': 163, 'G#2.E-3.G#3': 164, 'G#3.B3': 165, 'G#3.C#4': 166, 'E-3.B-3': 167, 'E-3.B3': 168, 'E-3.C#4': 169, 'C#3.F#3': 170, 'G#2.E3': 171, 'G#2.F#3': 172, 'B2.F#3': 173, 'B2.E3': 174, 'E2.B2.E3': 175, 'C#3.E-3': 176, 'C3.E3': 177, 'E3.G#3': 178, 'C#3.G#3': 179, 'F#3.E-4': 180, 'F#3.C4': 181, 'G#3.E-4': 182, 'E-4.E4': 183, 'C#2.B2': 184, 'C#2.A2': 185, 'C#2.G#2': 186, 'E-2.C3': 187, 'E-2.E3': 188, 'E2.G#2': 189, 'E2.G#3': 190, 'F2.C#3.G#3.B3': 191, 'G2.C#3.E3.C#4': 192, 'F3.B-3': 193, 'F3.G3': 194, 'C2.G2.F3': 195, 'E2.E3.G3': 196, 'A2.G3': 197, 'A2.E3': 198, 'G2.D3.A3': 199, 'C#3.G3.A3': 200, 'G2.E3.B3': 201, 'G#2.F3.B3': 202, 'G#2.E3.B3': 203, 'B3.C4': 204, 'D2.A2.F3.D4': 205, 'C2.A2.F#3': 206, 'C2.A2.G3.A3': 207, 'A2.F#3.C4': 208, 'A2.F#3.D4.E4': 209, 'A2.F3': 210, 'B2.A3': 211, 'B2.G3': 212, 'F#3.B3': 213, 'C#2.G#2.F#3': 214, 'F2.F3.G#3': 215, 'B-2.G#3': 216, 'G#2.E-3.B-3': 217, 'D3.G#3.B-3': 218, 'A2.F3.C4': 219, 'C4.C#4': 220, 'E-2.B-2.F#3.E-4': 221, 'C#2.B-2.G3': 222, 'C#2.B-2.G#3.B-3': 223, 'B-2.G3.C#4': 224, 'B-2.G3.E-4.F4': 225, 'B-2.F#3': 226, 'C3.B-3': 227, 'C3.G#3': 228, 'D2.C3.F#3.E-4': 229, 'D2.B-2.G3.D4': 230, 'D2.B-2.G3.C4': 231, 'C#3.G3.B-3': 232, 'D3.G3.A3': 233, 'E-2.B-2.G3.E-4': 234, 'E-2.C#3.G3.E4': 235, 'E-2.B2.G#3.E-4': 236, 'E-2.B2.G#3.C#4': 237, 
'D3.G#3.B3': 238, 'E-3.G#3.B-3': 239, 'E2.B2.G#3.E4': 240, 'D3.E-4': 241, 'D3.F#4': 242, 'D3.G4': 243, 'D3.C#4': 244, 'G3.B3': 245, 'G2.F#3': 246, 'G2.A3': 247, 'G2.C4': 248, 'G2.G#3': 249, 'G2.B3': 250, 'E-3.E4': 251, 'E-3.G4': 252, 'E-3.G#4': 253, 'E-3.D4': 254, 'G3.C#4': 255, 'G#3.C4': 256, 'G#2.G3': 257, 'G#2.B-3': 258, 'G#2.C#4': 259, 'G#2.A3': 260, 'G#2.C4': 261, 'C2.E-3.G3': 262, 'C3.G3.A3': 263, 'C3.G3.B-3': 264, 'B-2.G3': 265, 'D3.F#3': 266, 'C2.G2.E3': 267, 'C2.E3.G3': 268, 'G2.E-3.B-3': 269, 'F2.G#2.E-3': 270, 'F2.G#2.D3': 271, 'E-2.B-2.G3': 272, 'F2.B2.G3': 273, 'F2.C3.D3': 274, 'G#2.D3': 275, 'F2.C3.E-3': 276, 'C#2.E3.G#3': 277, 'C#3.G#3.B-3': 278, 'C#3.G#3.B3': 279, 'B2.G#3': 280, 'C#2.G#2.F3': 281, 'C#2.F3.G#3': 282, 'C#3.A3': 283, 'F#2.A2.E3': 284, 'F#2.A2.E-3': 285, 'E3.A3': 286, 'E2.B2.G#3': 287, 'F#2.C3.G#3': 288, 'E2.C#3': 289, 'F#2.C#3.E-3': 290, 'A2.E-3': 291, 'F#2.C#3.E3': 292, 'G2.E3.C4': 293, 'B2.C3': 294, 'D2.A2.F#3': 295, 'A2.B2': 296, 'E3.F3': 297, 'A2.E3.C4': 298, 'G#2.F3.C#4': 299, 'C3.C#3': 300, 'B-2.C3': 301, 'F3.F#3': 302, 'B-2.F3.C#4': 303, 'B-2.F#3.C#4': 304, 'D3.A3.F#4': 305, 'G2.D3.B3.F#4': 306, 'G3.B3.E4': 307, 'G3.B3.F#4': 308, 'E3.D4': 309, 'E3.C#4': 310, 'G2.E3.C#4.A4': 311, 'A4': 312, 'F#3.D4.A4': 313, 'B4': 314, 'A3.G4': 315, 'A3.F#4': 316, 'D3.A3.G4': 317, 'F#3.C#4.A4': 318, 'B2.F#3.E-4.A4': 319, 'B3.G4': 320, 'B3.F#4': 321, 'E3.B3.A4': 322, 'E3.B3.G4': 323, 'D3.B3.G4': 324, 'B2.F#3.B3': 325, 'B2.F#3.D4': 326, 'A2.F#3.D4': 327, 'G#2.E3.D4': 328, 'A2.E3.D4': 329, 'A3.E4': 330, 'B2.D3.B3.F#4': 331, 'G2.D3.B3.A4': 332, 'A3.D4': 333, 'D2.A2.F#3.D4': 334, 'F#3.G3': 335, 'D3.C#4.E4': 336, 'D3.E4': 337, 'D3.A4': 338, 'E-3.B-3.G4': 339, 'G#2.E-3.C4.G4': 340, 'G#3.C4.F4': 341, 'G#3.C4.G4': 342, 'F3.E-4': 343, 'G#2.F3.D4.B-4': 344, 'B-4': 345, 'G3.E-4.B-4': 346, 'C5': 347, 'B-3.G#4': 348, 'B-3.G4': 349, 'E-3.B-3.G#4': 350, 'G3.D4.B-4': 351, 'C3.G3.E4.B-4': 352, 'C4.G#4': 353, 'C4.G4': 354, 'F3.C4.B-4': 355, 'F3.C4.G#4': 356, 'E-3.C4.G#4': 357, 'C3.G3.C4': 358, 'C3.G3.E-4': 359, 'B-2.G3.E-4': 360, 'A2.F3.E-4': 361, 'B-2.F3.E-4': 362, 'B-3.F4': 363, 'C3.E-3.C4.G4': 364, 'G#2.E-3.C4.B-4': 365, 'B-3.E-4': 366, 'G3.G#3': 367, 'E-3.D4.F4': 368, 'E-3.F4': 369, 'E-3.B-4': 370, 'D3.E3': 371, 'F2.A2': 372, 'F#2.A2.D3.A3': 373, 'B3.D4': 374, 'C4.E4': 375, 'C4.F#4': 376, 'G2.B2': 377, 'E2.G2': 378, 'D4.E4': 379, 'E3.F#3': 380, 'G#3.A3': 381, 'A2.E3.A3': 382, 'C4.D4': 383, 'C#3.D3': 384, 'A2.C3': 385, 'E-3.F3': 386, 'F#2.B-2': 387, 'G2.B-2.E-3.B-3': 388, 'C#3.F3': 389, 'C4.E-4': 390, 'C#4.F4': 391, 'C#4.G4': 392, 'G#2.C3': 393, 'F2.G#2': 394, 'G#3.B-3': 395, 'E-4.F4': 396, 'A3.B-3': 397, 'B-2.F3.B-3': 398, 'C#4.E-4': 399, 'B-2.C#3': 400, 'D3.C#4.F#4': 401, 'C#3.E3.B3.E4': 402, 'E2.B2.G#3.D4': 403, 'C#3.E3.A3': 404, 'C#4.D4': 405, 'G#3.D4.E4': 406, 'E3.D4.G#4': 407, 'D4.G#4': 408, 'D4.B4': 409, 'B3.E4': 410, 'C#4.A4': 411, 'C#4.E4': 412, 'A3.E4.F#4': 413, 'A3.E4.G4': 414, 'E4.G4': 415, 'F#4.G4': 416, 'C#5': 417, 'D5': 418, 'F#2.G2': 419, 'E2.B2.G3': 420, 'D2.B2.G3': 421, 'F#4.D5': 422, 'F#4.A4': 423, 'D4.A4': 424, 'E-3.D4.G4': 425, 'D3.F3.C4.F4': 426, 'F2.C3.A3.E-4': 427, 'D3.F3.B-3': 428, 'A3.E-4.F4': 429, 'F3.E-4.A4': 430, 'E-4.A4': 431, 'E-4.C5': 432, 'C4.F4': 433, 'D4.B-4': 434, 'D4.F4': 435, 'B-3.F4.G4': 436, 'B-3.F4.G#4': 437, 'F4.G#4': 438, 'G4.G#4': 439, 'E-5': 440, 'G2.G#2': 441, 'F2.C3.G#3': 442, 'E-2.C3.G#3': 443, 'G4.E-5': 444, 'G4.B-4': 445, 'E-4.B-4': 446, 'C2.G2.E-3.C4': 447, 'E-2.G2.D3': 448, 'E2.C3.G3': 449, 'D2.B-2.G#3': 450, 'G#3.F4': 451, 
'E-3.F3.G3': 452, 'C2.A2.F#3.D4': 453, 'G2.B3.D4': 454, 'B-2.D3.G#3': 455, 'B-2.D3.E-3.G#3': 456, 'G#2.B-2': 457, 'C2.B-2.E3': 458, 'D3.G3.G#3': 459, 'C#2.G#2.E3.C#4': 460, 'E2.G#2.E-3': 461, 'F2.C#3.G#3': 462, 'F#2.C#3.A3': 463, 'E-2.B2.A3': 464, 'E3.F#3.G#3': 465, 'C#2.B-2.G3.E-4': 466, 'G#2.C4.E-4': 467, 'E4.F#4': 468, 'B-3.B3': 469, 'B2.E-3.A3': 470, 'B2.E-3.E3.A3': 471, 'C#2.B2.F3': 472, 'E-3.G#3.A3': 473, 'D2.A2.F3': 474, 'F2.A2.D3': 475, 'F2.A2.E3': 476, 'A2.A3': 477, 'B-2.D3.A3': 478, 'B-2.A3': 479, 'B-2.B-3': 480, 'G2.F3.B-3': 481, 'G2.F3.C4': 482, 'G2.F3.D4': 483, 'G2.E3': 484, 'F2.A2.D3.A3': 485, 'G2.E3.F3': 486, 'G2.F3.G3': 487, 'G2.D3.C#4': 488, 'G2.D3.D4': 489, 'G#3.D4': 490, 'E-2.B-2.F#3': 491, 'F#2.B-2.E-3': 492, 'F#2.B-2.F3': 493, 'B2.E-3.B-3': 494, 'B2.B-3': 495, 'B2.B3': 496, 'G#2.F#3.B3': 497, 'G#2.F#3.C#4': 498, 'G#2.F#3.E-4': 499, 'D2.B2.G#3': 500, 'F#2.B-2.E-3.B-3': 501, 'G#2.F3.F#3': 502, 'G#2.F#3.G#3': 503, 'F#3.C#4': 504, 'G#2.E-3.D4': 505, 'G#2.E-3.E-4': 506, 'A3.E-4': 507, 'C3.A3.E-4': 508, 'F2.C3': 509, 'F2.B-2': 510, 'F2.D3.G#3': 511, 'C#3.B-3.E4': 512, 'F#2.C#3': 513, 'F#2.B2': 514, 'F#2.E-3.A3': 515, 'G2.B-3': 516, 'F2.A3': 517, 'D3.F4': 518, 'C3.E3.B3': 519, 'G#2.B3': 520, 'F#2.B-3': 521, 'E-3.F#4': 522, 'C#3.F3.C4': 523, 'E5': 524, 'F#5': 525, 'G5': 526, 'B2.G#3.D4': 527, 'B-2.G3.D4': 528, 'F5': 529, 'G#5': 530, 'B2.G#3.E-4': 531, 'A2.F#3.E-4': 532, 'G2.G3.B3.E4': 533, 'G2.G3.B3.C#4': 534, 'A2.F#3.D4.A4': 535, 'A2.F#3.D4.F#4': 536, 'E3.C#4.G4': 537, 'G#2.E3.D4.B4': 538, 'G#2.E3.C#4': 539, 'A2.E3.C#4.E4': 540, 'F#2.E3.C#4.E4': 541, 'F#3.A3.E4': 542, 'E-3.C4.F#4': 543, 'E3.B3.F#4': 544, 'D4.C5': 545, 'G3.D4.B4': 546, 'B2.E3.D4': 547, 'B2.E3.E4': 548, 'C3.E3.E4': 549, 'C3.E3.F#4': 550, 'G2.D3.C4.G4': 551, 'G2.D3.A3.G4': 552, 'G3.B3.B4': 553, 'F#3.B3.A4': 554, 'A2.E3.C#4.G#4': 555, 'A2.E3.C#4.A4': 556, 'A2.F3.D4.A4': 557, 'A3.F4': 558, 'B-2.F#3.C#4.E4': 559, 'D3.C4.F#4': 560, 'G3.E4': 561, 'G2.F#3.D4': 562, 'G2.E3.D4': 563, 'A2.G3.D4': 564, 'D2.A2.F#3.C#4': 565, 'G#2.G#3.C4.F4': 566, 'G#2.G#3.C4.D4': 567, 'B-2.G3.E-4.B-4': 568, 'B-2.G3.E-4.G4': 569, 'F3.D4.G#4': 570, 'A2.F3.E-4.C5': 571, 'A2.F3.D4': 572, 'B-2.F3.D4.F4': 573, 'G2.F3.D4.F4': 574, 'G3.B-3.F4': 575, 'F3.C4.G4': 576, 'E-4.C#5': 577, 'G#3.E-4.C5': 578, 'C3.F3.E-4': 579, 'C3.F3.F4': 580, 'C#3.F3.F4': 581, 'C#3.F3.G4': 582, 'G#2.E-3.C#4.G#4': 583, 'G#2.E-3.B-3.G#4': 584, 'G#3.C4.C5': 585, 'G3.C4.B-4': 586, 'B-2.F3.D4.A4': 587, 'B-2.F3.D4.B-4': 588, 'B-2.F#3.E-4.B-4': 589, 'B-3.F#4': 590, 'B2.G3.D4.F4': 591, 'E-3.C#4.G4': 592, 'G#2.G3.E-4': 593, 'G#2.F3.E-4': 594, 'B-2.G#3.E-4': 595, 'E-2.B-2.G3.D4': 596, 'E-2.B-2.G#3': 597, 'E-2.B-2.F3': 598, 'B-2.F3.C4': 599, 'C2.A2.E-3': 600, 'A2.E3.B3': 601, 'E2.B2.A3': 602, 'E2.B2.F#3': 603, 'E3.E-4': 604, 'C3.A3': 605, 'C#3.G#3.E4': 606, 'B2.F#3.E-4': 607, 'B2.F#3.E4': 608, 'B2.F#3.C#4': 609, 'C#2.B-2.E3': 610, 'A2.G#3': 611, 'C#3.B3': 612, 'A2.F#3.E4': 613, 'G2.B3.G4': 614, 'G#2.C4.G#4': 615, 'D2.B2.F3': 616, 'C2.G2.D3': 617, 'C3.G#3.E-4': 618, 'E-2.C3.F#3': 619, 'C#2.G#2.E-3': 620, 'C#3.A3.E4': 621, 'G#3.E4': 622, 'C#4.D4.E4': 623, 'E3.C#4.D4.E4': 624, 'E3.C#4.E4': 625, 'E3.D4.E4': 626, 'B3.D4.E4': 627, 'A3.D4.E4': 628, 'B3.C#4.D4.E4': 629, 'F#3.B3.C#4.D4': 630, 'B3.G4.A4': 631, 'C#4.D4.E4.F#4.G4': 632, 'E3.B3.C#4.D4.E4.F#4.G4': 633, 'D4.E4.F#4.G4': 634, 'E4.F#4.G4': 635, 'E4.F#4.G4.B4': 636, 'A3.E4.F#4.G4': 637, 'A3.F#4.G4': 638, 'D4.F#4': 639, 'B-4.B4.C#5.D5': 640, 'B-3.B-4.B4.C#5.D5': 641, 'B-3.E4': 642, 'B3.C#4.E4': 643, 'B3.C#4.F#4.G#4.A4': 644, 
'F#3.B3.C#4.G#4.A4': 645, 'G#3.G#4.A4': 646, 'A3.G#4.A4': 647, 'B3.G#4.A4': 648, 'A3.C#4': 649, 'G#4.A4': 650, 'F#3.C#4.E4': 651, 'E3.D4.G#4.A4': 652, 'D4.E4.F#4': 653, 'C#4.D4.F#4': 654, 'B3.D4.F#4': 655, 'A3.D4.F#4': 656, 'G#3.D4.F#4': 657, 'G#4.B4': 658, 'A3.B3.C#4': 659, 'G2.E3.A3.C#4.A4': 660, 'D3.E3.F#4.G4': 661, 'D3.E3.F#3.F#4.G4': 662, 'D3.E3.F#3': 663, 'C2.D3.E3': 664, 'B2.G3.B3.C4.D4': 665, 'A2.C3.D3.E3': 666, 'G#2.C3.D3': 667, 'G#2.F#4': 668, 'F4.F#4': 669, 'B2.F#3.C#4.D4.E4': 670, 'A3.C#4.E4': 671, 'F#3.G3.A3': 672, 'D3.E3.F#3.A3': 673, 'B3.C#4.D4': 674, 'F3.C#4.D4': 675, 'B2.C#3': 676, 'A2.F#4': 677, 'C#4.F#4': 678, 'A4.B4': 679, 'B2.F#3.A3.B3': 680, 'B2.F#3.A3.B3.C4': 681, 'A3.B3.C4': 682, 'A3.B3.A4': 683, 'A3.B3.G4': 684, 'A3.B3.F#4.G4': 685, 'F#4.G4.A4': 686, 'E3.B3.F#4.G4': 687, 'G4.A4': 688, 'E-3.F#3': 689, 'E4.G#4': 690, 'D4.E-4.F4': 691, 'F3.D4.E-4.F4': 692, 'F3.D4.F4': 693, 'F3.E-4.F4': 694, 'C4.E-4.F4': 695, 'B-3.E-4.F4': 696, 'C4.D4.E-4.F4': 697, 'G3.C4.D4.E-4': 698, 'C4.G#4.B-4': 699, 'D4.E-4.F4.G4.G#4': 700, 'F3.C4.D4.E-4.F4.G4.G#4': 701, 'E-4.F4.G4.G#4': 702, 'F4.G4.G#4': 703, 'F4.G4.G#4.C5': 704, 'B-3.F4.G4.G#4': 705, 'B-3.G4.G#4': 706, 'E-4.G4': 707, 'B4.C5.D5.E-5': 708, 'B3.B4.C5.D5.E-5': 709, 'B3.F4': 710, 'C4.D4.F4': 711, 'C4.D4.G4.A4.B-4': 712, 'G3.C4.D4.A4.B-4': 713, 'A3.A4.B-4': 714, 'B-3.A4.B-4': 715, 'C4.A4.B-4': 716, 'B-3.D4': 717, 'A4.B-4': 718, 'G3.D4.F4': 719, 'F3.E-4.A4.B-4': 720, 'E-4.F4.G4': 721, 'D4.E-4.G4': 722, 'C4.E-4.G4': 723, 'B-3.E-4.G4': 724, 'A3.E-4.G4': 725, 'A4.C5': 726, 'F4.G4': 727, 'B-3.C4.D4': 728, 'G#2.F3.B-3.D4.B-4': 729, 'E-3.F3.G4.G#4': 730, 'E-3.F3.G3.G4.G#4': 731, 'C#2.E-3.F3': 732, 'C3.G#3.C4.C#4.E-4': 733, 'B-2.C#3.E-3.F3': 734, 'A2.C#3.E-3': 735, 'A2.G4': 736, 'C3.G3.D4.E-4.F4': 737, 'B-3.D4.F4': 738, 'G3.G#3.B-3': 739, 'E-3.F3.G3.B-3': 740, 'C4.D4.E-4': 741, 'F#3.D4.E-4': 742, 'B-2.G4': 743, 'D4.G4': 744, 'B-4.C5': 745, 'E4.F4': 746, 'C3.G3.E4': 747, 'C3.G3.B-3.C4': 748, 'C3.G3.B-3.C4.C#4': 749, 'B-3.C4.C#4': 750, 'B-3.C4.B-4': 751, 'B-3.C4.G#4': 752, 'B-3.C4.G4.G#4': 753, 'G4.G#4.B-4': 754, 'F3.C4.G4.G#4': 755, 'G#4.B-4': 756, 'C#3.B-3': 757, 'F4.A4': 758, 'D3.F3.A3': 759, 'C3.E3.B-3': 760, 'G2.G3.E4': 761, 'B-2.E3.D4': 762, 'D3.F3': 763, 'E-3.F#3.B-3': 764, 'C#3.F3.B3': 765, 'G#2.G#3.F4': 766, 'B2.F3.E-4': 767, 'C#3.G#3.C#4': 768, 'C#2.B-2.F3': 769, 'A2.G3.C#4': 770, 'D2.B2.F#3': 771, 'B-2.G#3.D4': 772, 'B-2.F#3.E-4': 773}\n", 167 | "{0: 0, 0.5: 1, 0.25: 2, 1.25: 3, 1.0: 4, 3.0: 5, 0.75: 6, Fraction(1, 6): 7, Fraction(1, 12): 8, 1.5: 9, Fraction(2, 3): 10, 2.0: 11, Fraction(1, 3): 12, Fraction(4, 3): 13, 2.25: 14, 1.75: 15, 2.5: 16, 4.0: 17, Fraction(5, 12): 18}\n", 168 | "774 19\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "def encode(data):\n", 174 | " #编字典\n", 175 | " vocab = {}\n", 176 | " for i in data:\n", 177 | " for j in i:\n", 178 | " if j not in vocab:\n", 179 | " vocab[j] = len(vocab)\n", 180 | "\n", 181 | " #用字典编码\n", 182 | " new_date = []\n", 183 | " for line in data:\n", 184 | " new_line = [vocab[node] for node in line]\n", 185 | " new_date.append(new_line)\n", 186 | "\n", 187 | " return new_date, vocab\n", 188 | "\n", 189 | "\n", 190 | "note, note_vocab = encode(note)\n", 191 | "duration, duration_vocab = encode(duration)\n", 192 | "\n", 193 | "print(note[0][:20])\n", 194 | "print(duration[0][:20])\n", 195 | "print(len(note), len(duration))\n", 196 | "\n", 197 | "print(note_vocab)\n", 198 | "print(duration_vocab)\n", 199 | "print(len(note_vocab), len(duration_vocab))" 200 | ] 201 | }, 202 | { 203 | 
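# Illustrative sketch (not part of the original notebook, toy inputs only):
# encode() above assigns each distinct token an id in order of first
# appearance; decoding later just inverts the dictionary, as save_midi() does.
def _encode_demo(data):
    vocab = {}
    for line in data:
        for token in line:
            if token not in vocab:
                vocab[token] = len(vocab)
    encoded = [[vocab[token] for token in line] for line in data]
    return encoded, vocab

_toy = [['C3', 'E3', 'G3', 'C3'], ['rest', 'C3', 'G3']]   # hypothetical notes
_enc, _vocab = _encode_demo(_toy)
print(_enc)     # [[0, 1, 2, 0], [3, 0, 2]]
print(_vocab)   # {'C3': 0, 'E3': 1, 'G3': 2, 'rest': 3}
_vocab_r = {v: k for k, v in _vocab.items()}
print([[_vocab_r[i] for i in line] for line in _enc])   # round trip back to tokens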
"cell_type": "code", 204 | "execution_count": 5, 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "name": "stderr", 209 | "output_type": "stream", 210 | "text": [ 211 | "Using TensorFlow backend.\n" 212 | ] 213 | }, 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "(array([[ 0, 1, 2, ..., 6, 5, 7],\n", 218 | " [ 1, 2, 3, ..., 5, 7, 6],\n", 219 | " [ 2, 3, 4, ..., 7, 6, 5],\n", 220 | " ...,\n", 221 | " [27, 4, 5, ..., 4, 27, 7],\n", 222 | " [ 4, 5, 21, ..., 27, 7, 40],\n", 223 | " [ 5, 21, 4, ..., 7, 40, 42]]),\n", 224 | " array([[0., 0., 0., ..., 0., 0., 0.],\n", 225 | " [0., 0., 0., ..., 0., 0., 0.],\n", 226 | " [0., 0., 0., ..., 0., 0., 0.],\n", 227 | " ...,\n", 228 | " [0., 0., 0., ..., 0., 0., 0.],\n", 229 | " [0., 0., 0., ..., 0., 0., 0.],\n", 230 | " [1., 0., 0., ..., 0., 0., 0.]], dtype=float32))" 231 | ] 232 | }, 233 | "execution_count": 5, 234 | "metadata": {}, 235 | "output_type": "execute_result" 236 | } 237 | ], 238 | "source": [ 239 | "import numpy as np\n", 240 | "import keras\n", 241 | "\n", 242 | "\n", 243 | "#把一维的数据切成段,以前面的词,预测最后一个词\n", 244 | "def prepare_sequences(data, num_classes):\n", 245 | " input = []\n", 246 | " output = []\n", 247 | "\n", 248 | " for line in data:\n", 249 | " for i in range(len(line) - 32):\n", 250 | " input.append(line[i:i + 32])\n", 251 | " output.append(line[i + 32])\n", 252 | "\n", 253 | " input = np.array(input)\n", 254 | " output = keras.utils.np_utils.to_categorical(output,\n", 255 | " num_classes=num_classes)\n", 256 | "\n", 257 | " return input, output\n", 258 | "\n", 259 | "\n", 260 | "prepare_sequences(note, len(note_vocab))" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 6, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "((53162, 32), (53162, 774), (53162, 32), (53162, 19))" 272 | ] 273 | }, 274 | "execution_count": 6, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "def get_input_output():\n", 281 | " note_input, note_output = prepare_sequences(note, len(note_vocab))\n", 282 | " duration_input, duration_output = prepare_sequences(\n", 283 | " duration, len(duration_vocab))\n", 284 | "\n", 285 | " input = [note_input, duration_input]\n", 286 | " output = [note_output, duration_output]\n", 287 | "\n", 288 | " return input, output\n", 289 | "\n", 290 | "\n", 291 | "input, output = get_input_output()\n", 292 | "\n", 293 | "input[0].shape, output[0].shape, input[1].shape, output[1].shape" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 7, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 306 | "\n", 307 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", 308 | "\n", 309 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n", 310 | "\n", 311 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. 
Please use tf.compat.v1.train.Optimizer instead.\n", 312 | "\n", 313 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3295: The name tf.log is deprecated. Please use tf.math.log instead.\n", 314 | "\n" 315 | ] 316 | }, 317 | { 318 | "data": { 319 | "text/plain": [ 320 | "" 321 | ] 322 | }, 323 | "execution_count": 7, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "def build_model():\n", 330 | " note_input = keras.layers.Input(shape=(None, ))\n", 331 | " duration_input = keras.layers.Input(shape=(None, ))\n", 332 | "\n", 333 | " x1 = keras.layers.Embedding(len(note_vocab), 100)(note_input)\n", 334 | " x2 = keras.layers.Embedding(len(duration_vocab), 100)(duration_input)\n", 335 | "\n", 336 | " x = keras.layers.Concatenate()([x1, x2])\n", 337 | "\n", 338 | " x = keras.models.Sequential([\n", 339 | " keras.layers.LSTM(256, return_sequences=True),\n", 340 | " keras.layers.LSTM(256, return_sequences=True),\n", 341 | " ])(x)\n", 342 | "\n", 343 | " e = keras.models.Sequential([\n", 344 | " keras.layers.Dense(1, activation='tanh'),\n", 345 | " keras.layers.Reshape([-1]),\n", 346 | " keras.layers.Activation('softmax'),\n", 347 | " keras.layers.RepeatVector(256),\n", 348 | " keras.layers.Permute([2, 1]),\n", 349 | " ])(x)\n", 350 | "\n", 351 | " x = keras.layers.Multiply()([x, e])\n", 352 | " x = keras.layers.Lambda(lambda i: keras.backend.sum(i, axis=1),\n", 353 | " output_shape=(256, ))(x)\n", 354 | " note_output = keras.layers.Dense(len(note_vocab), activation='softmax')(x)\n", 355 | " duration_output = keras.layers.Dense(len(duration_vocab),\n", 356 | " activation='softmax')(x)\n", 357 | "\n", 358 | " model = keras.models.Model([note_input, duration_input],\n", 359 | " [note_output, duration_output])\n", 360 | "\n", 361 | " model.compile(\n", 362 | " loss=['categorical_crossentropy', 'categorical_crossentropy'],\n", 363 | " optimizer=keras.optimizers.RMSprop(lr=0.001))\n", 364 | "\n", 365 | " return model\n", 366 | "\n", 367 | "\n", 368 | "model = build_model()\n", 369 | "\n", 370 | "model" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 8, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "name": "stdout", 380 | "output_type": "stream", 381 | "text": [ 382 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 383 | "Instructions for updating:\n", 384 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", 385 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. 
Please use tf.compat.v1.assign_add instead.\n", 386 | "\n", 387 | "Train on 42529 samples, validate on 10633 samples\n", 388 | "Epoch 1/20\n", 389 | "42529/42529 [==============================] - 99s 2ms/step - loss: 4.2513 - dense_2_loss: 3.5614 - dense_3_loss: 0.6899 - val_loss: 4.7055 - val_dense_2_loss: 3.8565 - val_dense_3_loss: 0.8491\n", 390 | "0\n", 391 | "Epoch 2/20\n", 392 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.7205 - dense_2_loss: 3.1797 - dense_3_loss: 0.5408 - val_loss: 4.5344 - val_dense_2_loss: 3.7185 - val_dense_3_loss: 0.8159\n", 393 | "1\n", 394 | "Epoch 3/20\n", 395 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.4867 - dense_2_loss: 3.0231 - dense_3_loss: 0.4636 - val_loss: 4.8023 - val_dense_2_loss: 3.8016 - val_dense_3_loss: 1.0006\n", 396 | "2\n", 397 | "Epoch 4/20\n", 398 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.3193 - dense_2_loss: 2.9080 - dense_3_loss: 0.4113 - val_loss: 4.4798 - val_dense_2_loss: 3.5864 - val_dense_3_loss: 0.8934\n", 399 | "3\n", 400 | "Epoch 5/20\n", 401 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.1565 - dense_2_loss: 2.7972 - dense_3_loss: 0.3593 - val_loss: 4.5762 - val_dense_2_loss: 3.5712 - val_dense_3_loss: 1.0049\n", 402 | "4\n", 403 | "Epoch 6/20\n", 404 | "42529/42529 [==============================] - 98s 2ms/step - loss: 3.0020 - dense_2_loss: 2.6898 - dense_3_loss: 0.3121 - val_loss: 4.7727 - val_dense_2_loss: 3.7600 - val_dense_3_loss: 1.0127\n", 405 | "5\n", 406 | "Epoch 7/20\n", 407 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.8425 - dense_2_loss: 2.5738 - dense_3_loss: 0.2687 - val_loss: 4.9168 - val_dense_2_loss: 3.6715 - val_dense_3_loss: 1.2454\n", 408 | "6\n", 409 | "Epoch 8/20\n", 410 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.6804 - dense_2_loss: 2.4557 - dense_3_loss: 0.2247 - val_loss: 5.1645 - val_dense_2_loss: 3.7302 - val_dense_3_loss: 1.4343\n", 411 | "7\n", 412 | "Epoch 9/20\n", 413 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.5197 - dense_2_loss: 2.3274 - dense_3_loss: 0.1923 - val_loss: 5.2515 - val_dense_2_loss: 3.8443 - val_dense_3_loss: 1.4072\n", 414 | "8\n", 415 | "Epoch 10/20\n", 416 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.3798 - dense_2_loss: 2.2173 - dense_3_loss: 0.1625 - val_loss: 5.8140 - val_dense_2_loss: 4.1481 - val_dense_3_loss: 1.6659\n", 417 | "9\n", 418 | "Epoch 11/20\n", 419 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.2637 - dense_2_loss: 2.1261 - dense_3_loss: 0.1376 - val_loss: 6.0465 - val_dense_2_loss: 4.2798 - val_dense_3_loss: 1.7667\n", 420 | "10\n", 421 | "Epoch 12/20\n", 422 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.1609 - dense_2_loss: 2.0378 - dense_3_loss: 0.1232 - val_loss: 6.1596 - val_dense_2_loss: 4.4052 - val_dense_3_loss: 1.7544\n", 423 | "11\n", 424 | "Epoch 13/20\n", 425 | "42529/42529 [==============================] - 98s 2ms/step - loss: 2.0491 - dense_2_loss: 1.9415 - dense_3_loss: 0.1076 - val_loss: 6.3409 - val_dense_2_loss: 4.4816 - val_dense_3_loss: 1.8593\n", 426 | "12\n", 427 | "Epoch 14/20\n", 428 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.9393 - dense_2_loss: 1.8408 - dense_3_loss: 0.0985 - val_loss: 6.5004 - val_dense_2_loss: 4.5988 - val_dense_3_loss: 1.9016\n", 429 | "13\n", 430 | "Epoch 15/20\n", 431 | "42529/42529 
[==============================] - 98s 2ms/step - loss: 1.8155 - dense_2_loss: 1.7267 - dense_3_loss: 0.0888 - val_loss: 6.6667 - val_dense_2_loss: 4.6960 - val_dense_3_loss: 1.9707\n", 432 | "14\n", 433 | "Epoch 16/20\n", 434 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.7020 - dense_2_loss: 1.6204 - dense_3_loss: 0.0816 - val_loss: 6.8498 - val_dense_2_loss: 4.7874 - val_dense_3_loss: 2.0624\n", 435 | "15\n", 436 | "Epoch 17/20\n", 437 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.5955 - dense_2_loss: 1.5201 - dense_3_loss: 0.0754 - val_loss: 6.9344 - val_dense_2_loss: 4.8593 - val_dense_3_loss: 2.0751\n", 438 | "16\n", 439 | "Epoch 18/20\n", 440 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.4926 - dense_2_loss: 1.4226 - dense_3_loss: 0.0700 - val_loss: 7.2345 - val_dense_2_loss: 5.1542 - val_dense_3_loss: 2.0803\n", 441 | "17\n", 442 | "Epoch 19/20\n", 443 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.4049 - dense_2_loss: 1.3408 - dense_3_loss: 0.0642 - val_loss: 7.5128 - val_dense_2_loss: 5.3212 - val_dense_3_loss: 2.1916\n", 444 | "18\n", 445 | "Epoch 20/20\n", 446 | "42529/42529 [==============================] - 98s 2ms/step - loss: 1.3287 - dense_2_loss: 1.2664 - dense_3_loss: 0.0624 - val_loss: 7.5138 - val_dense_2_loss: 5.3279 - val_dense_3_loss: 2.1859\n", 447 | "19\n" 448 | ] 449 | }, 450 | { 451 | "data": { 452 | "text/plain": [ 453 | "" 454 | ] 455 | }, 456 | "execution_count": 8, 457 | "metadata": {}, 458 | "output_type": "execute_result" 459 | } 460 | ], 461 | "source": [ 462 | "#在训练过程中打印预测图片\n", 463 | "class CustomCallback(keras.callbacks.Callback):\n", 464 | "\n", 465 | " def on_epoch_end(self, epoch, logs):\n", 466 | " if epoch % 1 == 0:\n", 467 | " print(epoch)\n", 468 | "\n", 469 | "\n", 470 | "model.fit(input,\n", 471 | " output,\n", 472 | " epochs=20,\n", 473 | " batch_size=32,\n", 474 | " validation_split=0.2,\n", 475 | " callbacks=[\n", 476 | " keras.callbacks.EarlyStopping(monitor='loss',\n", 477 | " restore_best_weights=True,\n", 478 | " patience=10),\n", 479 | " CustomCallback()\n", 480 | " ],\n", 481 | " shuffle=True,\n", 482 | " verbose=1)" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 9, 488 | "metadata": {}, 489 | "outputs": [ 490 | { 491 | "name": "stdout", 492 | "output_type": "stream", 493 | "text": [ 494 | "[[22, 1], [22, 1], [7, 1], [22, 1], [7, 1], [22, 1], [7, 1], [22, 1], [22, 1], [7, 1], [22, 1], [22, 1], [7, 1], [7, 1], [22, 1], [22, 1], [7, 1], [5, 1], [5, 1], [22, 1], [7, 1], [22, 1], [5, 1], [7, 1], [22, 1], [7, 1], [5, 1], [22, 1], [4, 1], [5, 1], [21, 1], [4, 1], [5, 1], [22, 1], [5, 1], [4, 1], [22, 1], [5, 1], [4, 1], [5, 1], [22, 1], [4, 1], [23, 1], [23, 1], [4, 1], [21, 1], [4, 1], [23, 1], [1, 1], [23, 1]]\n" 495 | ] 496 | }, 497 | { 498 | "data": { 499 | "text/plain": [ 500 | "50" 501 | ] 502 | }, 503 | "execution_count": 9, 504 | "metadata": {}, 505 | "output_type": "execute_result" 506 | } 507 | ], 508 | "source": [ 509 | "def get_pred():\n", 510 | " pred = []\n", 511 | "\n", 512 | " def random_sample(data):\n", 513 | " data = np.log(data) * 2\n", 514 | " data = np.exp(data)\n", 515 | " data = data / np.sum(data)\n", 516 | " return np.random.choice(len(data), p=data)\n", 517 | "\n", 518 | " note = [note_vocab['START']] * 32\n", 519 | " duration = [duration_vocab[0]] * 32\n", 520 | "\n", 521 | " for _ in range(50):\n", 522 | " input = [np.array([note]), np.array([duration])]\n", 523 | "\n", 524 | " 
output_note, output_duration = model.predict(input, verbose=0)\n", 525 | "\n", 526 | " output_note = random_sample(output_note[0])\n", 527 | " output_duration = random_sample(output_duration[0])\n", 528 | "\n", 529 | " pred.append([output_note, output_duration])\n", 530 | "\n", 531 | " note.append(output_note)\n", 532 | " duration.append(output_duration)\n", 533 | "\n", 534 | " if len(note) > 32:\n", 535 | " note = note[-32:]\n", 536 | " duration = duration[-32:]\n", 537 | "\n", 538 | " if note_vocab['START'] == output_note:\n", 539 | " break\n", 540 | "\n", 541 | " return pred\n", 542 | "\n", 543 | "\n", 544 | "pred = get_pred()\n", 545 | "\n", 546 | "print(pred)\n", 547 | "\n", 548 | "len(pred)" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": 10, 554 | "metadata": {}, 555 | "outputs": [ 556 | { 557 | "data": { 558 | "text/html": [ 559 | "\n", 560 | "
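# Illustrative sketch (not part of the original notebook): random_sample()
# above is temperature sampling -- multiplying the log-probabilities by 2
# before re-normalising is the same as sampling with temperature 0.5. A
# generic version with an explicit temperature parameter (toy values only):
import numpy as np

def sample_with_temperature(probs, temperature=0.5):
    logits = np.log(probs) / temperature   # T < 1 sharpens, T > 1 flattens
    probs = np.exp(logits)
    probs = probs / np.sum(probs)
    return np.random.choice(len(probs), p=probs)

_p = np.array([0.7, 0.2, 0.1])
print(sample_with_temperature(_p, temperature=0.5))   # usually index 0
print(sample_with_temperature(_p, temperature=2.0))   # spread more evenly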
\n", 561 | " \n", 563 | " " 573 | ], 574 | "text/plain": [ 575 | "" 576 | ] 577 | }, 578 | "metadata": {}, 579 | "output_type": "display_data" 580 | } 581 | ], 582 | "source": [ 583 | "def save_midi():\n", 584 | " stream = music21.stream.Stream()\n", 585 | "\n", 586 | " #反字典\n", 587 | " note_vocab_r = {v: k for k, v in note_vocab.items()}\n", 588 | " duration_vocab_r = {v: k for k, v in duration_vocab.items()}\n", 589 | "\n", 590 | " for (n, d) in pred:\n", 591 | " n = note_vocab_r[n]\n", 592 | " d = duration_vocab_r[d]\n", 593 | "\n", 594 | " #复合音符\n", 595 | " if ('.' in n):\n", 596 | " chord_note = []\n", 597 | " for i in n.split('.'):\n", 598 | " note_i = music21.note.Note(i)\n", 599 | " note_i.duration = music21.duration.Duration(d)\n", 600 | " note_i.storedInstrument = music21.instrument.Violoncello()\n", 601 | " chord_note.append(note_i)\n", 602 | " stream.append(music21.chord.Chord(chord_note))\n", 603 | " #rest音符\n", 604 | " elif n == 'rest':\n", 605 | " new_note = music21.note.Rest()\n", 606 | " new_note.duration = music21.duration.Duration(d)\n", 607 | " new_note.storedInstrument = music21.instrument.Violoncello()\n", 608 | " stream.append(new_note)\n", 609 | " #单音符\n", 610 | " elif n != 'START':\n", 611 | " new_note = music21.note.Note(n)\n", 612 | " new_note.duration = music21.duration.Duration(d)\n", 613 | " new_note.storedInstrument = music21.instrument.Violoncello()\n", 614 | " stream.append(new_note)\n", 615 | "\n", 616 | " stream = stream.chordify()\n", 617 | " stream.write('midi', fp='pred.mid')\n", 618 | "\n", 619 | " show('pred.mid')\n", 620 | "\n", 621 | "\n", 622 | "save_midi()" 623 | ] 624 | } 625 | ], 626 | "metadata": { 627 | "kernelspec": { 628 | "display_name": "Python 3 (ipykernel)", 629 | "language": "python", 630 | "name": "python3" 631 | }, 632 | "language_info": { 633 | "codemirror_mode": { 634 | "name": "ipython", 635 | "version": 3 636 | }, 637 | "file_extension": ".py", 638 | "mimetype": "text/x-python", 639 | "name": "python", 640 | "nbconvert_exporter": "python", 641 | "pygments_lexer": "ipython3", 642 | "version": "3.9.12" 643 | } 644 | }, 645 | "nbformat": 4, 646 | "nbformat_minor": 2 647 | } 648 | -------------------------------------------------------------------------------- /keras/9.musegan创作chorales.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "array([[[ 1, -1, -1, -1, -1, -1, -1, -1, -1],\n", 12 | " [-1, 1, -1, -1, -1, -1, -1, -1, -1],\n", 13 | " [-1, -1, 1, -1, -1, -1, -1, -1, -1]],\n", 14 | "\n", 15 | " [[-1, -1, -1, 1, -1, -1, -1, -1, -1],\n", 16 | " [-1, -1, -1, -1, 1, -1, -1, -1, -1],\n", 17 | " [-1, -1, -1, -1, -1, 1, -1, -1, -1]],\n", 18 | "\n", 19 | " [[-1, -1, -1, -1, -1, -1, 1, -1, -1],\n", 20 | " [-1, -1, -1, -1, -1, -1, -1, 1, -1],\n", 21 | " [-1, -1, -1, -1, -1, -1, -1, -1, 1]]], dtype=int32)" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "import numpy as np\n", 31 | "\n", 32 | "\n", 33 | "#数字矩阵转one hot编码的函数\n", 34 | "def build_one_hot(data, max_value):\n", 35 | " data = np.eye(max_value, dtype=np.int32)[data]\n", 36 | " data[data == 0] = -1\n", 37 | "\n", 38 | " return data\n", 39 | "\n", 40 | "\n", 41 | "build_one_hot(np.arange(9).reshape(3, 3), 9)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": 
{}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "data[0]= (192, 4) float16\n", 54 | "data[1]= (228, 4) float16\n", 55 | "data[2]= (208, 4) float16\n", 56 | "data[3]= (432, 4) float16\n", 57 | "data[4]= (260, 4) float16\n", 58 | "data[5]= (212, 4) float16\n", 59 | "data[6]= (292, 4) float16\n", 60 | "data[7]= (180, 4) float16\n", 61 | "data[8]= (132, 4) float16\n", 62 | "data[9]= (192, 4) float16\n", 63 | "data= (229,) object\n", 64 | "new_data= 229 (192, 4) (228, 4)\n", 65 | "data_cut= (229, 32, 4) int32\n" 66 | ] 67 | }, 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "((229, 2, 16, 84, 4), dtype('int32'))" 72 | ] 73 | }, 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "def get_data():\n", 81 | " #加载数据\n", 82 | " data = np.load('../datas/chorales/Jsb16thSeparated.npz',\n", 83 | " encoding='bytes')['train']\n", 84 | "\n", 85 | " #一共229首曲子,每个曲子长度不定,都是4个声部\n", 86 | " for i in range(10):\n", 87 | " print('data[%d]=' % i, data[i].shape, data[i].dtype)\n", 88 | "\n", 89 | " print('data=', data.shape, data.dtype)\n", 90 | "\n", 91 | " #筛除数据中的nan,这数据集做的简直是一坨屎\n", 92 | " new_data = []\n", 93 | " for song in data:\n", 94 | " new_song = []\n", 95 | " for time in song:\n", 96 | " #time -> [4]\n", 97 | "\n", 98 | " if np.isnan(time).any():\n", 99 | " continue\n", 100 | "\n", 101 | " new_song.append(time)\n", 102 | "\n", 103 | " new_song = np.array(new_song, dtype=np.int32)\n", 104 | " new_data.append(new_song)\n", 105 | "\n", 106 | " print('new_data=', len(new_data), new_data[0].shape, new_data[1].shape)\n", 107 | "\n", 108 | " #截取每首曲子的前32个拍子\n", 109 | " data_cut = []\n", 110 | " for song in new_data:\n", 111 | " data_cut.append(song[:32])\n", 112 | "\n", 113 | " #[229, 32, 4]\n", 114 | " data_cut = np.array(data_cut)\n", 115 | "\n", 116 | " print('data_cut=', data_cut.shape, data_cut.dtype)\n", 117 | "\n", 118 | " #分成两条音轨,每条音轨16个拍子\n", 119 | " #[229, 32, 4] -> [229, 2, 16, 4]\n", 120 | " data_cut = data_cut.reshape([229, 2, 16, 4])\n", 121 | "\n", 122 | " #转one hot编码\n", 123 | " #[229, 2, 16, 4] -> [229, 2, 16, 4, 84]\n", 124 | " data_cut = build_one_hot(data_cut, max_value=84)\n", 125 | "\n", 126 | " #交换最后两个维度\n", 127 | " #[229, 2, 16, 4, 84] -> [229, 2, 16, 84, 4]\n", 128 | " data_cut = data_cut.transpose([0, 1, 2, 4, 3])\n", 129 | "\n", 130 | " return data_cut\n", 131 | "\n", 132 | "\n", 133 | "data = get_data()\n", 134 | "\n", 135 | "data.shape, data.dtype" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 3, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stdout", 145 | "output_type": "stream", 146 | "text": [ 147 | "(array([0, 1, 2, 3, 4]), array([0.25, 0.25, 0.25, 0.25, 0.25], dtype=float32))\n", 148 | "(array([1., 1.]), array([1. , 0.25], dtype=float32))\n", 149 | "(array([0, 1, 2, 3, 4, 5, 6]), array([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25], dtype=float32))\n", 150 | "(array([1., 1.]), array([1. 
, 0.75], dtype=float32))\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "def merge_note(note, duration=None):\n", 156 | " if duration is None:\n", 157 | " duration = np.full(note.shape, fill_value=0.25, dtype=np.float32)\n", 158 | "\n", 159 | " #从前往后遍历\n", 160 | " for i in range(len(note) - 1):\n", 161 | " j = i + 1\n", 162 | "\n", 163 | " #判断相连的两个note是否相同,并且duration相加不大于1.0\n", 164 | " if note[i] == note[j] and duration[i] + duration[j] <= 1.0:\n", 165 | "\n", 166 | " #duration合并\n", 167 | " duration[i] += duration[j]\n", 168 | "\n", 169 | " #删除重复的note\n", 170 | " note = np.delete(note, j, axis=0)\n", 171 | " duration = np.delete(duration, j, axis=0)\n", 172 | "\n", 173 | " #递归调用\n", 174 | " return merge_note(note, duration)\n", 175 | "\n", 176 | " return note, duration\n", 177 | "\n", 178 | "\n", 179 | "print(merge_note(np.arange(5)))\n", 180 | "print(merge_note(np.ones(5)))\n", 181 | "\n", 182 | "print(merge_note(np.arange(7)))\n", 183 | "print(merge_note(np.ones(7)))" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 4, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "import music21\n", 193 | "\n", 194 | "\n", 195 | "def save_to_mid(data, filename):\n", 196 | " #data -> [32, 4]\n", 197 | " stream = music21.stream.Score()\n", 198 | " stream.append(music21.tempo.MetronomeMark(number=66))\n", 199 | "\n", 200 | " for i in range(4):\n", 201 | " channel = music21.stream.Part()\n", 202 | "\n", 203 | " notes, durations = merge_note(data[:, i])\n", 204 | " notes, durations = notes.tolist(), durations.tolist()\n", 205 | " for n, d in zip(notes, durations):\n", 206 | " note = music21.note.Note(n)\n", 207 | " note.duration = music21.duration.Duration(d)\n", 208 | " channel.append(note)\n", 209 | "\n", 210 | " stream.append(channel)\n", 211 | "\n", 212 | " stream.write('midi', fp=filename)\n", 213 | "\n", 214 | "\n", 215 | "save_to_mid(data[0].argmax(axis=2).reshape(32, 4), 'sample.mid')" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 5, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/html": [ 226 | "\n", 227 | "
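# Illustrative sketch (not part of the original notebook): an iterative
# equivalent of merge_note(). The recursive version rescans from the start
# after every merge, which is quadratic and deepens the call stack on long
# sequences; a single forward pass reproduces the outputs printed above.
import numpy as np

def merge_note_iterative(note, duration=None):
    if duration is None:
        duration = np.full(note.shape, fill_value=0.25, dtype=np.float32)
    notes, durations = [note[0]], [duration[0]]
    for n, d in zip(note[1:], duration[1:]):
        # extend the previous note while it is identical and still fits in 1.0
        if n == notes[-1] and durations[-1] + d <= 1.0:
            durations[-1] = durations[-1] + d
        else:
            notes.append(n)
            durations.append(d)
    return np.array(notes), np.array(durations, dtype=np.float32)

print(merge_note_iterative(np.ones(5)))   # (array([1., 1.]), array([1.  , 0.25], dtype=float32))
print(merge_note_iterative(np.ones(7)))   # (array([1., 1.]), array([1.  , 0.75], dtype=float32))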
\n", 228 | " \n", 230 | " " 240 | ], 241 | "text/plain": [ 242 | "" 243 | ] 244 | }, 245 | "metadata": {}, 246 | "output_type": "display_data" 247 | } 248 | ], 249 | "source": [ 250 | "def show(file):\n", 251 | " f = music21.midi.MidiFile()\n", 252 | " f.open(file)\n", 253 | " f.read()\n", 254 | " f.close()\n", 255 | " music21.midi.translate.midiFileToStream(f).show('midi')\n", 256 | "\n", 257 | "\n", 258 | "show('sample.mid')" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 6, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stderr", 268 | "output_type": "stream", 269 | "text": [ 270 | "Using TensorFlow backend.\n" 271 | ] 272 | }, 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 278 | "\n", 279 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", 280 | "\n", 281 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4115: The name tf.random_normal is deprecated. Please use tf.random.normal instead.\n", 282 | "\n" 283 | ] 284 | }, 285 | { 286 | "data": { 287 | "text/plain": [ 288 | "" 289 | ] 290 | }, 291 | "execution_count": 6, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "import keras\n", 298 | "\n", 299 | "weight_init = keras.initializers.RandomNormal(mean=0., stddev=0.02)\n", 300 | "\n", 301 | "cls = keras.models.Sequential([\n", 302 | " keras.layers.Conv3D(filters=128,\n", 303 | " kernel_size=(2, 1, 1),\n", 304 | " padding='valid',\n", 305 | " strides=(1, 1, 1),\n", 306 | " kernel_initializer=weight_init,\n", 307 | " input_shape=(2, 16, 84, 4)),\n", 308 | " keras.layers.LeakyReLU(),\n", 309 | " keras.layers.Conv3D(filters=128,\n", 310 | " kernel_size=(1, 1, 1),\n", 311 | " padding='valid',\n", 312 | " strides=(1, 1, 1),\n", 313 | " kernel_initializer=weight_init),\n", 314 | " keras.layers.LeakyReLU(),\n", 315 | " keras.layers.Conv3D(filters=128,\n", 316 | " kernel_size=(1, 1, 12),\n", 317 | " padding='same',\n", 318 | " strides=(1, 1, 12),\n", 319 | " kernel_initializer=weight_init),\n", 320 | " keras.layers.LeakyReLU(),\n", 321 | " keras.layers.Conv3D(filters=128,\n", 322 | " kernel_size=(1, 1, 7),\n", 323 | " padding='same',\n", 324 | " strides=(1, 1, 7),\n", 325 | " kernel_initializer=weight_init),\n", 326 | " keras.layers.LeakyReLU(),\n", 327 | " keras.layers.Conv3D(filters=128,\n", 328 | " kernel_size=(1, 2, 1),\n", 329 | " padding='same',\n", 330 | " strides=(1, 2, 1),\n", 331 | " kernel_initializer=weight_init),\n", 332 | " keras.layers.LeakyReLU(),\n", 333 | " keras.layers.Conv3D(filters=128,\n", 334 | " kernel_size=(1, 2, 1),\n", 335 | " padding='same',\n", 336 | " strides=(1, 2, 1),\n", 337 | " kernel_initializer=weight_init),\n", 338 | " keras.layers.LeakyReLU(),\n", 339 | " keras.layers.Conv3D(filters=256,\n", 340 | " kernel_size=(1, 4, 1),\n", 341 | " padding='same',\n", 342 | " strides=(1, 2, 1),\n", 343 | " kernel_initializer=weight_init),\n", 344 | " keras.layers.LeakyReLU(),\n", 345 | " keras.layers.Conv3D(filters=512,\n", 346 | " kernel_size=(1, 3, 1),\n", 347 | " padding='same',\n", 348 | " strides=(1, 2, 
1),\n", 349 | " kernel_initializer=weight_init),\n", 350 | " keras.layers.LeakyReLU(),\n", 351 | " keras.layers.Flatten(),\n", 352 | " keras.layers.Dense(1024, kernel_initializer=weight_init),\n", 353 | " keras.layers.LeakyReLU(),\n", 354 | " keras.layers.Dense(1, activation=None, kernel_initializer=weight_init),\n", 355 | "])\n", 356 | "\n", 357 | "cls" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 7, 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n", 370 | "\n", 371 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n", 372 | "\n", 373 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.\n", 374 | "\n" 375 | ] 376 | }, 377 | { 378 | "data": { 379 | "text/plain": [ 380 | "" 381 | ] 382 | }, 383 | "execution_count": 7, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "def get_gen():\n", 390 | "\n", 391 | " def TemporalNetwork():\n", 392 | " return keras.models.Sequential([\n", 393 | " keras.layers.Reshape([1, 1, 32], input_shape=(32, )),\n", 394 | " keras.layers.Conv2DTranspose(filters=1024,\n", 395 | " kernel_size=(2, 1),\n", 396 | " padding='valid',\n", 397 | " strides=(1, 1),\n", 398 | " kernel_initializer=weight_init),\n", 399 | " keras.layers.BatchNormalization(momentum=0.9),\n", 400 | " keras.layers.Activation('relu'),\n", 401 | " keras.layers.Conv2DTranspose(filters=32,\n", 402 | " kernel_size=(1, 1),\n", 403 | " padding='valid',\n", 404 | " strides=(1, 1),\n", 405 | " kernel_initializer=weight_init),\n", 406 | " keras.layers.BatchNormalization(momentum=0.9),\n", 407 | " keras.layers.Activation('relu'),\n", 408 | " keras.layers.Reshape([2, 32]),\n", 409 | " ])\n", 410 | "\n", 411 | " def BarGenerator():\n", 412 | " return keras.models.Sequential([\n", 413 | " keras.layers.Dense(1024, input_shape=(128, )),\n", 414 | " keras.layers.BatchNormalization(momentum=0.9),\n", 415 | " keras.layers.Activation('relu'),\n", 416 | " keras.layers.Reshape([2, 1, 512]),\n", 417 | " keras.layers.Conv2DTranspose(filters=512,\n", 418 | " kernel_size=(2, 1),\n", 419 | " padding='same',\n", 420 | " strides=(2, 1),\n", 421 | " kernel_initializer=weight_init),\n", 422 | " keras.layers.BatchNormalization(momentum=0.9),\n", 423 | " keras.layers.Activation('relu'),\n", 424 | " keras.layers.Conv2DTranspose(filters=256,\n", 425 | " kernel_size=(2, 1),\n", 426 | " padding='same',\n", 427 | " strides=(2, 1),\n", 428 | " kernel_initializer=weight_init),\n", 429 | " keras.layers.BatchNormalization(momentum=0.9),\n", 430 | " keras.layers.Activation('relu'),\n", 431 | " keras.layers.Conv2DTranspose(filters=256,\n", 432 | " kernel_size=(2, 1),\n", 433 | " padding='same',\n", 434 | " strides=(2, 1),\n", 435 | " kernel_initializer=weight_init),\n", 436 | " keras.layers.BatchNormalization(momentum=0.9),\n", 437 | " keras.layers.Activation('relu'),\n", 438 | " keras.layers.Conv2DTranspose(filters=256,\n", 439 | " kernel_size=(1, 
7),\n", 440 | " padding='same',\n", 441 | " strides=(1, 7),\n", 442 | " kernel_initializer=weight_init),\n", 443 | " keras.layers.BatchNormalization(momentum=0.9),\n", 444 | " keras.layers.Activation('relu'),\n", 445 | " keras.layers.Conv2DTranspose(filters=1,\n", 446 | " kernel_size=(1, 12),\n", 447 | " padding='same',\n", 448 | " strides=(1, 12),\n", 449 | " kernel_initializer=weight_init),\n", 450 | " keras.layers.Activation('tanh'),\n", 451 | " keras.layers.Reshape([1, 16, 84, 1]),\n", 452 | " ])\n", 453 | "\n", 454 | " input_chord = keras.layers.Input(shape=(32, ))\n", 455 | " input_style = keras.layers.Input(shape=(32, ))\n", 456 | " input_melody = keras.layers.Input(shape=(4, 32))\n", 457 | " input_groove = keras.layers.Input(shape=(4, 32))\n", 458 | "\n", 459 | " output_chord = TemporalNetwork()(input_chord)\n", 460 | "\n", 461 | " output = []\n", 462 | " for i in range(2):\n", 463 | " output_c = []\n", 464 | "\n", 465 | " for j in range(4):\n", 466 | "\n", 467 | " output_melody = keras.models.Sequential([\n", 468 | " keras.layers.Lambda(lambda x: x[:, j, :]),\n", 469 | " TemporalNetwork(),\n", 470 | " keras.layers.Lambda(lambda x: x[:, i, :])\n", 471 | " ])(input_melody)\n", 472 | "\n", 473 | " concat = keras.layers.Concatenate(axis=1)([\n", 474 | " keras.layers.Lambda(lambda x: x[:, i, :])(output_chord),\n", 475 | " input_style, output_melody,\n", 476 | " keras.layers.Lambda(lambda x: x[:, j, :])(input_groove)\n", 477 | " ])\n", 478 | " output_c.append(BarGenerator()(concat))\n", 479 | "\n", 480 | " output.append(keras.layers.Concatenate(axis=-1)(output_c))\n", 481 | "\n", 482 | " output = keras.layers.Concatenate(axis=1)(output)\n", 483 | "\n", 484 | " gen = keras.models.Model(\n", 485 | " [input_chord, input_style, input_melody, input_groove], output)\n", 486 | "\n", 487 | " return gen\n", 488 | "\n", 489 | "\n", 490 | "gen = get_gen()\n", 491 | "\n", 492 | "gen" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 8, 498 | "metadata": {}, 499 | "outputs": [ 500 | { 501 | "name": "stdout", 502 | "output_type": "stream", 503 | "text": [ 504 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. 
Please use tf.compat.v1.train.Optimizer instead.\n", 505 | "\n" 506 | ] 507 | }, 508 | { 509 | "data": { 510 | "text/plain": [ 511 | "(,\n", 512 | " )" 513 | ] 514 | }, 515 | "execution_count": 8, 516 | "metadata": {}, 517 | "output_type": "execute_result" 518 | } 519 | ], 520 | "source": [ 521 | "from functools import partial\n", 522 | "\n", 523 | "\n", 524 | "def get_gan():\n", 525 | "\n", 526 | " class RandomMerge(keras.layers.merge._Merge):\n", 527 | "\n", 528 | " def __init__(self):\n", 529 | " super().__init__()\n", 530 | "\n", 531 | " def _merge_function(self, inputs):\n", 532 | " alpha = keras.backend.random_uniform((64, 1, 1, 1, 1))\n", 533 | " return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])\n", 534 | "\n", 535 | " def set_trainable(model, trainable):\n", 536 | " model.trainable = trainable\n", 537 | " for layer in model.layers:\n", 538 | " layer.trainable = trainable\n", 539 | "\n", 540 | " set_trainable(gen, False)\n", 541 | "\n", 542 | " input_cls = keras.layers.Input(shape=[2, 16, 84, 4])\n", 543 | " input_chord = keras.layers.Input(shape=(32, ))\n", 544 | " input_style = keras.layers.Input(shape=(32, ))\n", 545 | " input_melody = keras.layers.Input(shape=(4, 32))\n", 546 | " input_groove = keras.layers.Input(shape=(4, 32))\n", 547 | "\n", 548 | " output_gen = gen([input_chord, input_style, input_melody, input_groove])\n", 549 | "\n", 550 | " output_cls_fake = cls(output_gen)\n", 551 | " output_cls_real = cls(input_cls)\n", 552 | "\n", 553 | " input_merge = RandomMerge()([input_cls, output_gen])\n", 554 | "\n", 555 | " output_cls_merge = cls(input_merge)\n", 556 | "\n", 557 | " def get_grads_loss(y_true, y_pred, input_merge):\n", 558 | " grads = keras.backend.gradients(y_pred, input_merge)[0]\n", 559 | " grads = keras.backend.square(grads)\n", 560 | " grads = keras.backend.sum(grads, axis=np.arange(1, len(grads.shape)))\n", 561 | " grads = keras.backend.sqrt(grads)\n", 562 | " grads = keras.backend.square(1 - grads)\n", 563 | " return keras.backend.mean(grads)\n", 564 | "\n", 565 | " grads_loss = partial(get_grads_loss, input_merge=input_merge)\n", 566 | "\n", 567 | " def wasserstein(y_true, y_pred):\n", 568 | " return -keras.backend.mean(y_true * y_pred)\n", 569 | "\n", 570 | " cls_model = keras.models.Model(\n", 571 | " inputs=[\n", 572 | " input_cls, input_chord, input_style, input_melody, input_groove\n", 573 | " ],\n", 574 | " outputs=[output_cls_real, output_cls_fake, output_cls_merge])\n", 575 | "\n", 576 | " cls_model.compile(loss=[wasserstein, wasserstein, grads_loss],\n", 577 | " optimizer=keras.optimizers.Adam(lr=0.001,\n", 578 | " beta_1=0.5,\n", 579 | " beta_2=0.9),\n", 580 | " loss_weights=[1, 1, 10])\n", 581 | "\n", 582 | " set_trainable(cls, False)\n", 583 | " set_trainable(gen, True)\n", 584 | "\n", 585 | " gan = keras.models.Model(\n", 586 | " [input_chord, input_style, input_melody, input_groove],\n", 587 | " output_cls_fake)\n", 588 | "\n", 589 | " gan.compile(optimizer=keras.optimizers.Adam(lr=0.001,\n", 590 | " beta_1=0.5,\n", 591 | " beta_2=0.9),\n", 592 | " loss=wasserstein)\n", 593 | "\n", 594 | " set_trainable(cls, True)\n", 595 | "\n", 596 | " return gan, cls_model\n", 597 | "\n", 598 | "\n", 599 | "gan, cls_model = get_gan()\n", 600 | "\n", 601 | "gan, cls_model" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 9, 607 | "metadata": {}, 608 | "outputs": [ 609 | { 610 | "data": { 611 | "text/html": [ 612 | "\n", 613 | "
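# Illustrative sketch (not part of the original notebook): what the pieces of
# the WGAN-GP objective assembled above compute, spelled out with plain NumPy.
# The critic scores below are made-up numbers; only the formulas mirror the code.
import numpy as np

def wasserstein(y_true, y_pred):
    # y_true is +1 for real batches and -1 for generated ones, so the critic
    # is pushed to score real data high and generated data low
    return -np.mean(y_true * y_pred)

_real_scores = np.array([1.5, 0.8, 1.1])
_fake_scores = np.array([-0.9, -0.4, -1.2])
print(wasserstein(+1, _real_scores) + wasserstein(-1, _fake_scores))   # critic loss
print(wasserstein(+1, _fake_scores))                                   # generator loss

# RandomMerge() picks a random point on the line between each real and fake
# sample; the gradient penalty mean((1 - ||grad of critic||)^2), weighted by
# 10, is evaluated at these interpolated points.
_alpha = np.random.uniform(size=(3, 1))
_real = np.random.uniform(-1, 1, size=(3, 4))
_fake = np.random.uniform(-1, 1, size=(3, 4))
print((_alpha * _real + (1 - _alpha) * _fake).shape)   # (3, 4)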
\n", 614 | " \n", 616 | " " 626 | ], 627 | "text/plain": [ 628 | "" 629 | ] 630 | }, 631 | "metadata": {}, 632 | "output_type": "display_data" 633 | } 634 | ], 635 | "source": [ 636 | "def test():\n", 637 | " chord = np.random.normal(0, 1, (1, 32))\n", 638 | " style = np.random.normal(0, 1, (1, 32))\n", 639 | " melody = np.random.normal(0, 1, (1, 4, 32))\n", 640 | " groove = np.random.normal(0, 1, (1, 4, 32))\n", 641 | "\n", 642 | " #[1, 2, 16, 84, 4]\n", 643 | " pred = gen.predict([chord, style, melody, groove])\n", 644 | "\n", 645 | " #[1, 2, 16, 84, 4] -> [1, 2, 16, 4]\n", 646 | " pred = pred.argmax(axis=3)\n", 647 | "\n", 648 | " #[1, 2, 16, 4] -> [32, 4]\n", 649 | " pred = pred.reshape(32, 4)\n", 650 | "\n", 651 | " save_to_mid(pred, 'pred.mid')\n", 652 | "\n", 653 | " show('pred.mid')\n", 654 | "\n", 655 | "\n", 656 | "test()" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 10, 662 | "metadata": {}, 663 | "outputs": [ 664 | { 665 | "name": "stdout", 666 | "output_type": "stream", 667 | "text": [ 668 | "WARNING:tensorflow:From /root/anaconda3/envs/gdl/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 669 | "Instructions for updating:\n", 670 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" 671 | ] 672 | }, 673 | { 674 | "name": "stderr", 675 | "output_type": "stream", 676 | "text": [ 677 | "/root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n", 678 | " 'Discrepancy between trainable weights and collected trainable'\n", 679 | "/root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n", 680 | " 'Discrepancy between trainable weights and collected trainable'\n" 681 | ] 682 | }, 683 | { 684 | "name": "stdout", 685 | "output_type": "stream", 686 | "text": [ 687 | "0 [8.912887, -0.85698175, -0.034562703, 0.98044306] 0.0040728305\n" 688 | ] 689 | }, 690 | { 691 | "name": "stderr", 692 | "output_type": "stream", 693 | "text": [ 694 | "/root/anaconda3/envs/gdl/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n", 695 | " 'Discrepancy between trainable weights and collected trainable'\n" 696 | ] 697 | }, 698 | { 699 | "name": "stdout", 700 | "output_type": "stream", 701 | "text": [ 702 | "50 [-27.792698, -267.06592, 231.58945, 0.76837736] -282.16797\n", 703 | "100 [-27.88805, -257.02512, 218.39185, 1.0745221] -225.14125\n", 704 | "150 [-18.23226, -43.738167, 19.566328, 0.59395784] -33.739716\n", 705 | "200 [-16.835743, -75.27756, 53.224827, 0.52169865] -45.045284\n", 706 | "250 [-15.013928, -43.07686, 21.97362, 0.60893106] -14.732184\n", 707 | "300 [-14.297857, -106.17346, 89.668396, 0.22072089] -97.43548\n", 708 | "350 [-14.1342125, -32.94745, 15.056252, 0.37569845] -12.34396\n", 709 | "400 [-12.606506, -67.773705, 48.30908, 0.68581194] -37.58142\n", 710 | "450 [-11.947504, -14.487357, 1.7038689, 0.08359844] -15.795556\n", 711 | "500 [-11.203049, -20.237303, 8.629093, 
0.040516045] -6.5508084\n", 712 | "550 [-12.3838825, -23.784214, 9.745614, 0.16547178] -12.909263\n", 713 | "600 [-11.5849285, -35.62585, 22.524584, 0.15163384] -24.991646\n", 714 | "650 [-11.719839, -28.288511, 14.372057, 0.21966158] -16.3292\n", 715 | "700 [-10.081033, -24.602768, 12.748704, 0.17730309] -10.602691\n", 716 | "750 [-10.870434, -20.221205, 8.281688, 0.10690833] -7.7463045\n", 717 | "800 [-9.710244, -26.209225, 14.81569, 0.16832903] -12.453541\n", 718 | "850 [-9.430994, -15.119096, 4.843739, 0.08443623] -7.6122823\n", 719 | "900 [-9.68703, -20.41964, 9.861892, 0.08707178] -2.0840075\n", 720 | "950 [-10.114071, -17.73894, 6.2960553, 0.13288136] -6.283645\n" 721 | ] 722 | } 723 | ], 724 | "source": [ 725 | "def train():\n", 726 | "\n", 727 | " def train_cls():\n", 728 | " pos = np.ones((64, 1), dtype=np.int32)\n", 729 | " neg = -np.ones((64, 1), dtype=np.int32)\n", 730 | " dummy = np.zeros((64, 1), dtype=np.int32)\n", 731 | "\n", 732 | " chord = np.random.normal(0, 1, (64, 32))\n", 733 | " style = np.random.normal(0, 1, (64, 32))\n", 734 | " melody = np.random.normal(0, 1, (64, 4, 32))\n", 735 | " groove = np.random.normal(0, 1, (64, 4, 32))\n", 736 | "\n", 737 | " data_sub = data[np.random.randint(0, data.shape[0], 64)]\n", 738 | "\n", 739 | " loss_cls = cls_model.train_on_batch(\n", 740 | " [data_sub, chord, style, melody, groove], [pos, neg, dummy])\n", 741 | "\n", 742 | " return loss_cls\n", 743 | "\n", 744 | " def train_gen():\n", 745 | " pos = np.ones((64, 1), dtype=np.int32)\n", 746 | "\n", 747 | " chord = np.random.normal(0, 1, (64, 32))\n", 748 | " style = np.random.normal(0, 1, (64, 32))\n", 749 | " melody = np.random.normal(0, 1, (64, 4, 32))\n", 750 | " groove = np.random.normal(0, 1, (64, 4, 32))\n", 751 | "\n", 752 | " loss_gen = gan.train_on_batch([chord, style, melody, groove], pos)\n", 753 | "\n", 754 | " return loss_gen\n", 755 | "\n", 756 | " for epoch in range(1000):\n", 757 | " for _ in range(5):\n", 758 | " loss_cls = train_cls()\n", 759 | "\n", 760 | " loss_gen = train_gen()\n", 761 | "\n", 762 | " if epoch % 50 == 0:\n", 763 | " print(epoch, loss_cls, loss_gen)\n", 764 | "\n", 765 | "\n", 766 | "train()" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": 13, 772 | "metadata": {}, 773 | "outputs": [ 774 | { 775 | "data": { 776 | "text/html": [ 777 | "\n", 778 | "
\n", 779 | " \n", 781 | " " 791 | ], 792 | "text/plain": [ 793 | "" 794 | ] 795 | }, 796 | "metadata": {}, 797 | "output_type": "display_data" 798 | } 799 | ], 800 | "source": [ 801 | "test()" 802 | ] 803 | } 804 | ], 805 | "metadata": { 806 | "kernelspec": { 807 | "display_name": "Python 3 (ipykernel)", 808 | "language": "python", 809 | "name": "python3" 810 | }, 811 | "language_info": { 812 | "codemirror_mode": { 813 | "name": "ipython", 814 | "version": 3 815 | }, 816 | "file_extension": ".py", 817 | "mimetype": "text/x-python", 818 | "name": "python", 819 | "nbconvert_exporter": "python", 820 | "pygments_lexer": "ipython3", 821 | "version": "3.9.12" 822 | } 823 | }, 824 | "nbformat": 4, 825 | "nbformat_minor": 2 826 | } 827 | -------------------------------------------------------------------------------- /keras/README.md: -------------------------------------------------------------------------------- 1 | 环境信息: 2 |
3 | keras==2.2.4 4 |
5 | tensorflow==1.14.0 6 |
7 |
8 | Adapted from: https://github.com/davidADSP/GDL_code 9 | --------------------------------------------------------------------------------