├── README.md └── Music_generation.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Generate-Music-Using-Neural-Networks 2 | In contrast to the popular belief that deep learning is a black box and hard to train, the effectiveness of deep learning has amazed everyone. 3 | Let us explore one of the wonders of deep learning which is producing art, which was always assumed beyond the scope for machines, generate music using deep learning. 4 | 5 | Here are a few generated pieces of music, 6 | 7 | https://soundcloud.com/ramya-vidiyala-850882745/gen-music-1 8 | 9 | https://soundcloud.com/ramya-vidiyala-850882745/gen-music-2 10 | 11 | https://soundcloud.com/ramya-vidiyala-850882745/gen-music-3 12 | 13 | https://soundcloud.com/ramya-vidiyala-850882745/gen-music-4 14 | 15 | https://soundcloud.com/ramya-vidiyala-850882745/gen-music-5 16 | 17 | We generated these pleasant melodies using LSTMs. 18 | 19 | -------------------------------------------------------------------------------- /Music_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import json\n", 11 | "import argparse\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "\n", 15 | "from model import build_model, save_weights\n", 16 | "import numpy as np\n", 17 | "\n", 18 | "from model import build_model, load_weights\n", 19 | "\n", 20 | "from keras.models import Sequential, load_model\n", 21 | "from keras.layers import LSTM, Dropout, TimeDistributed, Dense, Activation, Embedding\n", 22 | "\n", 23 | "\n", 24 | "DATA_DIR = './data'\n", 25 | "LOG_DIR = './logs'\n", 26 | "\n", 27 | "BATCH_SIZE = 16\n", 28 | "SEQ_LENGTH = 64" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "class TrainLogger(object):\n", 38 | " def __init__(self, file):\n", 39 | " self.file = os.path.join(LOG_DIR, file)\n", 40 | " self.epochs = 0\n", 41 | " with open(self.file, 'w') as f:\n", 42 | " f.write('epoch,loss,acc\\n')\n", 43 | "\n", 44 | " def add_entry(self, loss, acc):\n", 45 | " self.epochs += 1\n", 46 | " s = '{},{},{}\\n'.format(self.epochs, loss, acc)\n", 47 | " with open(self.file, 'a') as f:\n", 48 | " f.write(s)\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "\n", 58 | "def read_batches(T, vocab_size):\n", 59 | " length = T.shape[0]; #129,665\n", 60 | " batch_chars = int(length / BATCH_SIZE); # 8,104\n", 61 | "\n", 62 | " for start in range(0, batch_chars - SEQ_LENGTH, SEQ_LENGTH): # (0, 8040, 64)\n", 63 | " X = np.zeros((BATCH_SIZE, SEQ_LENGTH)) # 16X64\n", 64 | " Y = np.zeros((BATCH_SIZE, SEQ_LENGTH, vocab_size)) # 16X64X86\n", 65 | " for batch_idx in range(0, BATCH_SIZE): # (0,16)\n", 66 | " for i in range(0, SEQ_LENGTH): #(0,64)\n", 67 | " X[batch_idx, i] = T[batch_chars * batch_idx + start + i] # \n", 68 | " Y[batch_idx, i, T[batch_chars * batch_idx + start + i + 1]] = 1\n", 69 | " yield X, Y" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 6, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "processing\n", 82 | "Number of unique characters: 86\n", 83 | "processing done\n", 84 | "creating model\n", 85 | "model created\n", 86 | "Model: \"sequential_1\"\n", 87 | "_________________________________________________________________\n", 88 | "Layer (type) Output Shape Param # \n", 89 | "=================================================================\n", 90 | "embedding (Embedding) (16, 64, 512) 44032 \n", 91 | "_________________________________________________________________\n", 92 | "lstm (LSTM) (16, 64, 256) 787456 \n", 93 | "_________________________________________________________________\n", 94 | "dropout (Dropout) (16, 64, 256) 0 \n", 95 | "_________________________________________________________________\n", 96 | "lstm_1 (LSTM) (16, 64, 256) 525312 \n", 97 | "_________________________________________________________________\n", 98 | "dropout_1 (Dropout) (16, 64, 256) 0 \n", 99 | "_________________________________________________________________\n", 100 | "lstm_2 (LSTM) (16, 64, 256) 525312 \n", 101 | "_________________________________________________________________\n", 102 | "dropout_2 (Dropout) (16, 64, 256) 0 \n", 103 | "_________________________________________________________________\n", 104 | "time_distributed (TimeDistri (16, 64, 86) 22102 \n", 105 | "_________________________________________________________________\n", 106 | "activation (Activation) (16, 64, 86) 0 \n", 107 | "=================================================================\n", 108 | "Total params: 1,904,214\n", 109 | "Trainable params: 1,904,214\n", 110 | "Non-trainable params: 0\n", 111 | "_________________________________________________________________\n", 112 | "training data\n", 113 | "Length of text:129665\n", 114 | "\n", 115 | "Epoch 1/1\n", 116 | "[[51. 25. 1. ... 28. 0. 40.]\n", 117 | " [84. 29. 61. ... 29. 1. 34.]\n", 118 | " [47. 75. 58. ... 32. 33. 1.]\n", 119 | " ...\n", 120 | " [ 3. 29. 17. ... 22. 0. 47.]\n", 121 | " [22. 3. 28. ... 28. 84. 3.]\n", 122 | " [ 3. 33. 17. ... 3. 31. 3.]]\n", 123 | "Batch 1: loss = 4.454742431640625, acc = 0.0166015625\n", 124 | "[[25. 21. 14. ... 3. 28. 3.]\n", 125 | " [29. 61. 84. ... 1. 31. 33.]\n", 126 | " [28. 17. 34. ... 64. 84. 3.]\n", 127 | " ...\n", 128 | " [25. 41. 72. ... 78. 76. 66.]\n", 129 | " [31. 3. 63. ... 3. 28. 3.]\n", 130 | " [61. 17. 63. ... 0. 47. 25.]]\n", 131 | "Batch 2: loss = 4.432841777801514, acc = 0.1484375\n", 132 | "[[62. 60. 60. ... 60. 62. 1.]\n", 133 | " [28. 84. 0. ... 29. 84. 3.]\n", 134 | " [30. 3. 64. ... 1. 29. 17.]\n", 135 | " ...\n", 136 | " [60. 1. 31. ... 30. 0. 3.]\n", 137 | " [60. 17. 60. ... 17. 61. 84.]\n", 138 | " [50. 58. 68. ... 66. 58. 1.]]\n", 139 | "Batch 3: loss = 4.380638122558594, acc = 0.115234375\n", 140 | "[[58. 64. 26. ... 3. 28. 22.]\n", 141 | " [30. 3. 62. ... 1. 17. 18.]\n", 142 | " [28. 84. 3. ... 3. 31. 70.]\n", 143 | " ...\n", 144 | " [30. 3. 60. ... 34. 32. 1.]\n", 145 | " [ 3. 31. 3. ... 64. 63. 62.]\n", 146 | " [43. 65. 66. ... 3. 60. 17.]]\n", 147 | "Batch 4: loss = 4.123827934265137, acc = 0.1455078125\n", 148 | "[[ 3. 60. 29. ... 3. 31. 3.]\n", 149 | " [ 0. 47. 25. ... 75. 66. 64.]\n", 150 | " [ 3. 58. 18. ... 62. 62. 63.]\n", 151 | " ...\n", 152 | " [33. 32. 31. ... 56. 34. 17.]\n", 153 | " [84. 3. 28. ... 61. 1. 3.]\n", 154 | " [29. 1. 3. ... 3. 61. 29.]]\n", 155 | "Batch 5: loss = 3.7452220916748047, acc = 0.1611328125\n", 156 | "[[61. 18. 1. ... 1. 29. 58.]\n", 157 | " [64. 76. 11. ... 28. 22. 3.]\n", 158 | " [84. 3. 30. ... 3. 30. 3.]\n", 159 | " ...\n", 160 | " [32. 1. 3. ... 28. 29. 60.]\n", 161 | " [29. 70. 3. ... 17. 62. 84.]\n", 162 | " [34. 1. 3. ... 70. 3. 62.]]\n", 163 | "Batch 6: loss = 3.66882586479187, acc = 0.142578125\n", 164 | "[[75. 80. 62. ... 3. 28. 17.]\n", 165 | " [62. 17. 60. ... 29. 60. 84.]\n", 166 | " [62. 17. 61. ... 62. 1. 35.]\n", 167 | " ...\n", 168 | " [84. 3. 32. ... 28. 70. 3.]\n", 169 | " [ 3. 31. 3. ... 58. 77. 58.]\n", 170 | " [60. 28. 84. ... 31. 17. 29.]]\n", 171 | "Batch 7: loss = 3.616271495819092, acc = 0.0859375\n", 172 | "[[28. 84. 3. ... 1. 3. 34.]\n", 173 | " [ 3. 31. 3. ... 29. 60. 84.]\n", 174 | " [78. 71. 77. ... 25. 21. 14.]\n", 175 | " ...\n", 176 | " [28. 18. 1. ... 77. 77. 66.]\n", 177 | " [59. 58. 76. ... 29. 1. 28.]\n", 178 | " [ 1. 60. 17. ... 3. 28. 60.]]\n", 179 | "Batch 8: loss = 3.560462474822998, acc = 0.08203125\n", 180 | "[[ 3. 34. 18. ... 3. 34. 3.]\n", 181 | " [ 3. 31. 3. ... 1. 61. 17.]\n", 182 | " [23. 0. 38. ... 32. 70. 3.]\n", 183 | " ...\n", 184 | " [71. 64. 65. ... 30. 3. 58.]\n", 185 | " [34. 33. 84. ... 17. 25. 84.]\n", 186 | " [62. 1. 61. ... 1. 40. 78.]]\n", 187 | "Batch 9: loss = 3.628173351287842, acc = 0.0771484375\n", 188 | "[[64. 17. 64. ... 1. 34. 17.]\n", 189 | " [25. 84. 0. ... 38. 62. 79.]\n", 190 | " [62. 60. 28. ... 62. 61. 60.]\n", 191 | " ...\n", 192 | " [18. 1. 64. ... 33. 3. 58.]\n", 193 | " [ 0. 43. 25. ... 3. 29. 29.]\n", 194 | " [76. 66. 60. ... 3. 61. 17.]]\n", 195 | "Batch 10: loss = 3.577927589416504, acc = 0.125\n", 196 | "[[33. 84. 84. ... 3. 28. 29.]\n", 197 | " [66. 71. 1. ... 30. 3. 28.]\n", 198 | " [ 1. 3. 31. ... 28. 1. 28.]\n", 199 | " ...\n", 200 | " [18. 1. 58. ... 61. 84. 3.]\n", 201 | " [29. 1. 29. ... 34. 84. 3.]\n", 202 | " [28. 1. 12. ... 17. 28. 1.]]\n", 203 | "Batch 11: loss = 3.413236141204834, acc = 0.1552734375\n", 204 | "[[60. 84. 0. ... 29. 1. 62.]\n", 205 | " [34. 32. 84. ... 84. 3. 34.]\n", 206 | " [58. 64. 84. ... 0. 0. 0.]\n", 207 | " ...\n", 208 | " [31. 22. 3. ... 60. 1. 29.]\n", 209 | " [31. 22. 3. ... 70. 1. 40.]\n", 210 | " [12. 3. 28. ... 84. 3. 31.]]\n", 211 | "Batch 12: loss = 3.4859752655029297, acc = 0.1640625\n", 212 | "[[17. 62. 84. ... 62. 17. 62.]\n", 213 | " [ 3. 31. 34. ... 30. 3. 64.]\n", 214 | " [51. 25. 1. ... 1. 38. 66.]\n", 215 | " ...\n", 216 | " [17. 60. 84. ... 18. 1. 58.]\n", 217 | " [78. 76. 66. ... 33. 1. 34.]\n", 218 | " [ 3. 33. 34. ... 84. 54. 0.]]\n", 219 | "Batch 13: loss = 3.5393099784851074, acc = 0.1435546875\n", 220 | "[[84. 0. 3. ... 18. 0. 47.]\n", 221 | " [63. 62. 84. ... 28. 84. 3.]\n", 222 | " [75. 68. 73. ... 32. 84. 3.]\n", 223 | " ...\n", 224 | " [17. 56. 58. ... 58. 18. 84.]\n", 225 | " [17. 61. 84. ... 1. 12. 34.]\n", 226 | " [ 3. 32. 70. ... 70. 1. 40.]]\n", 227 | "Batch 14: loss = 3.498975992202759, acc = 0.1455078125\n", 228 | "[[25. 47. 65. ... 0. 40. 25.]\n", 229 | " [34. 3. 29. ... 58. 76. 62.]\n", 230 | " [34. 3. 31. ... 29. 11. 34.]\n", 231 | " ...\n", 232 | " [ 3. 34. 22. ... 1. 40. 78.]\n", 233 | " [17. 25. 25. ... 64. 58. 59.]\n", 234 | " [78. 76. 66. ... 29. 1. 3.]]\n", 235 | "Batch 15: loss = 3.531930446624756, acc = 0.1396484375\n", 236 | "[[21. 14. 23. ... 22. 3. 1.]\n", 237 | " [ 0. 52. 25. ... 31. 1. 31.]\n", 238 | " [11. 1. 31. ... 30. 1. 3.]\n", 239 | " ...\n", 240 | " [76. 66. 60. ... 84. 3. 34.]\n", 241 | " [84. 3. 30. ... 72. 75. 1.]\n", 242 | " [32. 70. 3. ... 29. 17. 56.]]\n", 243 | "Batch 16: loss = 3.4190330505371094, acc = 0.138671875\n", 244 | "[[62. 17. 1. ... 1. 61. 17.]\n", 245 | " [32. 34. 84. ... 34. 1. 28.]\n", 246 | " [33. 3. 33. ... 1. 28. 17.]\n", 247 | " ...\n", 248 | " [ 3. 61. 56. ... 3. 32. 70.]\n", 249 | " [58. 1. 45. ... 28. 0. 84.]\n", 250 | " [60. 84. 54. ... 3. 32. 70.]]\n", 251 | "Batch 17: loss = 3.3994626998901367, acc = 0.154296875\n", 252 | "[[25. 84. 0. ... 3. 62. 17.]\n", 253 | " [29. 60. 84. ... 29. 34. 32.]\n", 254 | " [34. 84. 33. ... 33. 17. 28.]\n", 255 | " ...\n", 256 | " [ 3. 29. 34. ... 3. 30. 3.]\n", 257 | " [64. 84. 1. ... 84. 3. 34.]\n", 258 | " [ 3. 34. 17. ... 29. 18. 1.]]\n", 259 | "Batch 18: loss = 3.283130168914795, acc = 0.1591796875\n", 260 | "[[ 1. 64. 84. ... 0. 51. 25.]\n", 261 | " [84. 3. 34. ... 34. 1. 28.]\n", 262 | " [84. 1. 3. ... 66. 71. 64.]\n", 263 | " ...\n", 264 | " [64. 63. 62. ... 63. 62. 1.]\n", 265 | " [ 3. 34. 17. ... 63. 17. 63.]\n", 266 | " [12. 29. 17. ... 12. 64. 17.]]\n", 267 | "Batch 19: loss = 3.5331082344055176, acc = 0.1318359375\n", 268 | "[[ 1. 19. 0. ... 1. 73. 1.]\n", 269 | " [34. 32. 84. ... 34. 32. 84.]\n", 270 | " [65. 58. 70. ... 25. 21. 14.]\n", 271 | " ...\n", 272 | " [ 3. 34. 3. ... 62. 79. 62.]\n", 273 | " [ 1. 63. 62. ... 84. 3. 28.]\n", 274 | " [63. 84. 64. ... 3. 29. 17.]]\n", 275 | "Batch 20: loss = 3.477001428604126, acc = 0.1298828125\n", 276 | "[[19. 21. 11. ... 34. 3. 61.]\n", 277 | " [ 3. 28. 70. ... 34. 32. 31.]\n", 278 | " [23. 0. 38. ... 17. 29. 1.]\n", 279 | " ...\n", 280 | " [71. 1. 46. ... 1. 79. 66.]\n", 281 | " [22. 3. 58. ... 58. 77. 58.]\n", 282 | " [61. 84. 54. ... 40. 78. 76.]]\n", 283 | "Batch 21: loss = 3.5111210346221924, acc = 0.1318359375\n", 284 | "[[62. 61. 1. ... 60. 29. 28.]\n", 285 | " [ 1. 31. 17. ... 62. 79. 66.]\n", 286 | " [29. 28. 29. ... 64. 17. 64.]\n", 287 | " ...\n", 288 | " [58. 1. 52. ... 3. 29. 60.]\n", 289 | " [59. 58. 76. ... 3. 28. 3.]\n", 290 | " [66. 60. 1. ... 84. 3. 34.]]\n", 291 | "Batch 22: loss = 3.4660110473632812, acc = 0.1376953125\n", 292 | "[[84. 3. 34. ... 61. 64. 29.]\n", 293 | " [71. 1. 29. ... 62. 84. 3.]\n", 294 | " [ 1. 64. 18. ... 22. 3. 62.]\n", 295 | " ...\n", 296 | " [61. 1. 3. ... 64. 84. 1.]\n", 297 | " [28. 60. 62. ... 3. 58. 64.]\n", 298 | " [ 3. 34. 17. ... 31. 22. 3.]]\n", 299 | "Batch 23: loss = 3.294177532196045, acc = 0.1708984375\n", 300 | "[[84. 3. 28. ... 28. 22. 3.]\n", 301 | " [31. 3. 63. ... 17. 61. 1.]\n", 302 | " [61. 60. 84. ... 66. 60. 1.]\n", 303 | " ...\n", 304 | " [53. 16. 3. ... 3. 28. 3.]\n", 305 | " [63. 1. 3. ... 64. 58. 84.]\n", 306 | " [28. 29. 60. ... 28. 70. 3.]]\n", 307 | "Batch 24: loss = 3.434953212738037, acc = 0.1513671875\n", 308 | "[[63. 64. 62. ... 1. 31. 58.]\n", 309 | " [61. 17. 61. ... 58. 17. 58.]\n", 310 | " [31. 58. 77. ... 28. 1. 28.]\n", 311 | " ...\n", 312 | " [62. 63. 64. ... 62. 63. 64.]\n", 313 | " [ 3. 31. 3. ... 28. 3. 28.]\n", 314 | " [60. 61. 62. ... 1. 3. 31.]]\n", 315 | "Batch 25: loss = 3.413236141204834, acc = 0.15625\n", 316 | "[[77. 58. 59. ... 84. 3. 32.]\n", 317 | " [84. 3. 28. ... 14. 64. 10.]\n", 318 | " [33. 28. 84. ... 28. 22. 3.]\n", 319 | " ...\n", 320 | " [84. 3. 28. ... 59. 58. 76.]\n", 321 | " [17. 25. 84. ... 45. 46. 30.]\n", 322 | " [22. 3. 61. ... 64. 58. 59.]]\n", 323 | "Batch 26: loss = 3.403815746307373, acc = 0.150390625\n", 324 | "[[22. 3. 29. ... 3. 29. 28.]\n", 325 | " [ 3. 62. 17. ... 3. 63. 17.]\n", 326 | " [62. 63. 64. ... 32. 32. 84.]\n", 327 | " ...\n", 328 | " [62. 0. 46. ... 3. 31. 3.]\n", 329 | " [31. 46. 1. ... 22. 3. 62.]\n", 330 | " [ 1. 64. 17. ... 58. 17. 64.]]\n" 331 | ] 332 | }, 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "Batch 27: loss = 3.14103102684021, acc = 0.193359375\n", 338 | "[[29. 1. 3. ... 28. 84. 3.]\n", 339 | " [63. 84. 3. ... 18. 1. 3.]\n", 340 | " [ 3. 34. 3. ... 31. 3. 63.]\n", 341 | " ...\n", 342 | " [28. 33. 31. ... 3. 34. 18.]\n", 343 | " [61. 60. 84. ... 60. 84. 3.]\n", 344 | " [84. 3. 30. ... 34. 70. 3.]]\n", 345 | "Batch 28: loss = 3.301115036010742, acc = 0.1728515625\n", 346 | "[[28. 3. 62. ... 3. 28. 3.]\n", 347 | " [34. 3. 61. ... 3. 31. 3.]\n", 348 | " [61. 61. 1. ... 25. 40. 66.]\n", 349 | " ...\n", 350 | " [ 1. 3. 29. ... 3. 29. 60.]\n", 351 | " [31. 3. 61. ... 1. 3. 28.]\n", 352 | " [64. 17. 34. ... 61. 60. 29.]]\n", 353 | "Batch 29: loss = 3.397311210632324, acc = 0.1494140625\n", 354 | "[[58. 62. 60. ... 25. 33. 72.]\n", 355 | " [63. 17. 64. ... 1. 46. 72.]\n", 356 | " [68. 62. 1. ... 28. 84. 3.]\n", 357 | " ...\n", 358 | " [28. 1. 3. ... 0. 3. 32.]\n", 359 | " [22. 3. 29. ... 17. 33. 84.]\n", 360 | " [ 1. 28. 34. ... 65. 58. 70.]]\n", 361 | "Batch 30: loss = 3.347227096557617, acc = 0.1435546875\n", 362 | "[[75. 70. 58. ... 17. 29. 14.]\n", 363 | " [ 1. 29. 69. ... 33. 0. 40.]\n", 364 | " [34. 3. 29. ... 3. 34. 3.]\n", 365 | " ...\n", 366 | " [70. 3. 29. ... 34. 84. 3.]\n", 367 | " [ 3. 34. 3. ... 28. 3. 32.]\n", 368 | " [ 1. 40. 78. ... 84. 3. 28.]]\n", 369 | "Batch 31: loss = 3.249284505844116, acc = 0.1552734375\n", 370 | "[[17. 84. 28. ... 40. 25. 21.]\n", 371 | " [25. 19. 14. ... 61. 17. 62.]\n", 372 | " [29. 34. 32. ... 1. 63. 61.]\n", 373 | " ...\n", 374 | " [31. 3. 33. ... 61. 66. 64.]\n", 375 | " [17. 28. 1. ... 62. 84. 0.]\n", 376 | " [ 3. 60. 18. ... 62. 1. 61.]]\n", 377 | "Batch 32: loss = 3.3391714096069336, acc = 0.1337890625\n", 378 | "[[14. 23. 0. ... 1. 28. 34.]\n", 379 | " [ 1. 64. 63. ... 25. 29. 0.]\n", 380 | " [61. 84. 3. ... 3. 61. 18.]\n", 381 | " ...\n", 382 | " [ 0. 5. 1. ... 45. 72. 80.]\n", 383 | " [ 3. 31. 22. ... 18. 1. 61.]\n", 384 | " [60. 29. 84. ... 32. 28. 60.]]\n", 385 | "Batch 33: loss = 3.273256778717041, acc = 0.15234375\n", 386 | "[[33. 84. 32. ... 31. 1. 32.]\n", 387 | " [84. 25. 61. ... 63. 84. 3.]\n", 388 | " [ 1. 61. 60. ... 3. 34. 29.]\n", 389 | " ...\n", 390 | " [62. 0. 40. ... 3. 62. 18.]\n", 391 | " [17. 84. 84. ... 43. 58. 78.]\n", 392 | " [ 1. 62. 61. ... 60. 29. 84.]]\n", 393 | "Batch 34: loss = 3.251103401184082, acc = 0.154296875\n", 394 | "[[31. 30. 84. ... 17. 63. 1.]\n", 395 | " [30. 14. 62. ... 3. 28. 70.]\n", 396 | " [29. 1. 3. ... 66. 75. 60.]\n", 397 | " ...\n", 398 | " [ 1. 3. 32. ... 84. 3. 28.]\n", 399 | " [69. 66. 71. ... 34. 3. 29.]\n", 400 | " [53. 16. 3. ... 50. 66. 63.]]\n", 401 | "Batch 35: loss = 3.2006759643554688, acc = 0.1494140625\n", 402 | "[[63. 62. 63. ... 17. 62. 84.]\n", 403 | " [ 3. 1. 28. ... 3. 34. 22.]\n", 404 | " [69. 62. 1. ... 79. 66. 58.]\n", 405 | " ...\n", 406 | " [ 3. 60. 62. ... 1. 31. 32.]\n", 407 | " [18. 1. 61. ... 3. 62. 17.]\n", 408 | " [62. 0. 5. ... 25. 21. 14.]]\n", 409 | "Batch 36: loss = 3.3504037857055664, acc = 0.1435546875\n", 410 | "[[84. 0. 38. ... 61. 60. 84.]\n", 411 | " [ 3. 1. 62. ... 0. 51. 25.]\n", 412 | " [ 1. 43. 65. ... 17. 33. 1.]\n", 413 | " ...\n", 414 | " [33. 84. 3. ... 18. 1. 12.]\n", 415 | " [28. 1. 29. ... 64. 17. 61.]\n", 416 | " [23. 0. 38. ... 32. 22. 3.]]\n", 417 | "Batch 37: loss = 3.3050389289855957, acc = 0.130859375\n", 418 | "[[ 0. 43. 25. ... 60. 1. 31.]\n", 419 | " [ 1. 17. 23. ... 1. 16. 1.]\n", 420 | " [34. 33. 32. ... 3. 31. 22.]\n", 421 | " ...\n", 422 | " [29. 17. 32. ... 3. 28. 3.]\n", 423 | " [ 1. 62. 17. ... 28. 21. 12.]\n", 424 | " [32. 17. 32. ... 17. 84. 84.]]\n", 425 | "Batch 38: loss = 3.2785816192626953, acc = 0.1376953125\n", 426 | "[[58. 77. 58. ... 70. 3. 33.]\n", 427 | " [73. 1. 19. ... 28. 22. 3.]\n", 428 | " [ 3. 61. 17. ... 33. 32. 84.]\n", 429 | " ...\n", 430 | " [28. 18. 1. ... 62. 1. 3.]\n", 431 | " [84. 28. 18. ... 1. 60. 17.]\n", 432 | " [ 0. 43. 25. ... 61. 29. 84.]]\n", 433 | "Batch 39: loss = 3.2367336750030518, acc = 0.1435546875\n", 434 | "[[34. 33. 1. ... 84. 0. 3.]\n", 435 | " [61. 63. 59. ... 31. 3. 61.]\n", 436 | " [ 3. 34. 3. ... 30. 3. 32.]\n", 437 | " ...\n", 438 | " [32. 3. 61. ... 76. 66. 60.]\n", 439 | " [61. 84. 3. ... 77. 66. 71.]\n", 440 | " [ 0. 3. 32. ... 3. 28. 3.]]\n", 441 | "Batch 40: loss = 3.216315269470215, acc = 0.158203125\n", 442 | "[[34. 3. 62. ... 33. 17. 28.]\n", 443 | " [18. 1. 61. ... 31. 3. 58.]\n", 444 | " [34. 32. 84. ... 3. 28. 3.]\n", 445 | " ...\n", 446 | " [ 1. 31. 58. ... 9. 3. 62.]\n", 447 | " [64. 65. 58. ... 0. 38. 25.]\n", 448 | " [32. 28. 60. ... 29. 84. 3.]]\n", 449 | "Batch 41: loss = 3.2099642753601074, acc = 0.1484375\n", 450 | "[[ 1. 61. 17. ... 33. 17. 32.]\n", 451 | " [18. 1. 3. ... 28. 22. 3.]\n", 452 | " [32. 28. 60. ... 3. 62. 18.]\n", 453 | " ...\n", 454 | " [17. 61. 1. ... 17. 60. 84.]\n", 455 | " [31. 0. 3. ... 34. 84. 3.]\n", 456 | " [28. 3. 28. ... 1. 31. 58.]]\n", 457 | "Batch 42: loss = 3.120986223220825, acc = 0.177734375\n", 458 | "[[84. 3. 31. ... 34. 3. 29.]\n", 459 | " [28. 18. 84. ... 62. 63. 84.]\n", 460 | " [ 1. 62. 17. ... 25. 47. 65.]\n", 461 | " ...\n", 462 | " [ 0. 3. 32. ... 34. 17. 29.]\n", 463 | " [31. 3. 28. ... 3. 34. 60.]\n", 464 | " [77. 58. 59. ... 61. 84. 3.]]\n", 465 | "Batch 43: loss = 3.1336774826049805, acc = 0.1767578125\n", 466 | "[[18. 1. 61. ... 28. 17. 28.]\n", 467 | " [ 3. 29. 70. ... 70. 62. 0.]\n", 468 | " [62. 1. 30. ... 25. 21. 14.]\n", 469 | " ...\n", 470 | " [84. 3. 28. ... 28. 18. 1.]\n", 471 | " [34. 1. 60. ... 16. 3. 28.]\n", 472 | " [30. 3. 62. ... 3. 28. 70.]]\n", 473 | "Batch 44: loss = 3.263458251953125, acc = 0.158203125\n", 474 | "[[84. 0. 3. ... 84. 3. 31.]\n", 475 | " [ 5. 1. 41. ... 0. 38. 25.]\n", 476 | " [23. 0. 38. ... 3. 28. 33.]\n", 477 | " ...\n", 478 | " [12. 28. 17. ... 32. 3. 29.]\n", 479 | " [22. 3. 34. ... 71. 64. 65.]\n", 480 | " [ 3. 60. 28. ... 33. 34. 84.]]\n", 481 | "Batch 45: loss = 3.2378668785095215, acc = 0.15625\n", 482 | "[[ 3. 31. 18. ... 59. 58. 76.]\n", 483 | " [28. 70. 0. ... 33. 3. 60.]\n", 484 | " [31. 84. 0. ... 3. 31. 3.]\n", 485 | " ...\n", 486 | " [18. 1. 12. ... 1. 64. 63.]\n", 487 | " [58. 70. 1. ... 84. 3. 32.]\n", 488 | " [ 0. 3. 28. ... 25. 50. 65.]]\n", 489 | "Batch 46: loss = 3.1844067573547363, acc = 0.166015625\n", 490 | "[[62. 0. 52. ... 3. 62. 18.]\n", 491 | " [29. 28. 1. ... 3. 30. 3.]\n", 492 | " [63. 58. 61. ... 22. 3. 62.]\n", 493 | " ...\n", 494 | " [62. 84. 0. ... 84. 0. 0.]\n", 495 | " [70. 3. 34. ... 3. 61. 56.]\n", 496 | " [62. 71. 1. ... 66. 71. 1.]]\n", 497 | "Batch 47: loss = 3.2154860496520996, acc = 0.16796875\n", 498 | "[[ 1. 62. 60. ... 28. 1. 3.]\n", 499 | " [60. 61. 62. ... 61. 60. 84.]\n", 500 | " [63. 64. 84. ... 41. 72. 77.]\n", 501 | " ...\n", 502 | " [ 0. 51. 25. ... 75. 66. 64.]\n", 503 | " [60. 61. 1. ... 60. 61. 1.]\n", 504 | " [29. 75. 66. ... 29. 60. 1.]]\n", 505 | "Batch 48: loss = 3.153773069381714, acc = 0.171875\n", 506 | "[[32. 22. 3. ... 3. 31. 3.]\n", 507 | " [ 3. 32. 22. ... 66. 60. 1.]\n", 508 | " [77. 66. 71. ... 43. 25. 28.]\n", 509 | " ...\n", 510 | " [64. 76. 11. ... 29. 34. 32.]\n", 511 | " [ 3. 32. 70. ... 34. 33. 84.]\n", 512 | " [ 3. 31. 22. ... 34. 18. 1.]]\n", 513 | "Batch 49: loss = 3.249722719192505, acc = 0.146484375\n", 514 | "[[28. 61. 63. ... 62. 14. 17.]\n", 515 | " [31. 58. 77. ... 3. 32. 70.]\n", 516 | " [ 0. 62. 14. ... 3. 28. 3.]\n", 517 | " ...\n", 518 | " [ 1. 32. 34. ... 28. 84. 3.]\n", 519 | " [ 3. 32. 70. ... 59. 58. 76.]\n", 520 | " [12. 34. 17. ... 34. 3. 34.]]\n", 521 | "Batch 50: loss = 3.1717400550842285, acc = 0.154296875\n", 522 | "[[58. 18. 14. ... 3. 58. 62.]\n", 523 | " [ 3. 64. 63. ... 70. 3. 62.]\n", 524 | " [60. 61. 60. ... 14. 17. 84.]\n", 525 | " ...\n", 526 | " [31. 3. 61. ... 63. 64. 84.]\n", 527 | " [62. 0. 52. ... 17. 34. 1.]\n", 528 | " [17. 28. 1. ... 84. 3. 31.]]\n", 529 | "Batch 51: loss = 3.1504948139190674, acc = 0.158203125\n", 530 | "[[62. 1. 62. ... 29. 1. 3.]\n", 531 | " [60. 28. 1. ... 3. 58. 64.]\n", 532 | " [ 3. 28. 3. ... 62. 1. 3.]\n", 533 | " ...\n", 534 | " [ 0. 3. 31. ... 65. 58. 71.]\n", 535 | " [ 3. 31. 22. ... 28. 33. 31.]\n", 536 | " [22. 3. 33. ... 70. 3. 61.]]\n", 537 | "Batch 52: loss = 3.1174869537353516, acc = 0.1728515625\n", 538 | "[[28. 3. 28. ... 29. 34. 29.]\n", 539 | " [63. 84. 3. ... 53. 16. 3.]\n", 540 | " [31. 3. 63. ... 58. 66. 71.]\n", 541 | " ...\n", 542 | " [62. 7. 76. ... 14. 23. 0.]\n", 543 | " [84. 3. 34. ... 34. 84. 3.]\n", 544 | " [17. 29. 1. ... 40. 78. 76.]]\n", 545 | "Batch 53: loss = 3.21025013923645, acc = 0.1513671875\n", 546 | "[[84. 0. 3. ... 0. 51. 25.]\n", 547 | " [32. 70. 3. ... 1. 41. 72.]\n", 548 | " [62. 0. 5. ... 70. 0. 32.]\n", 549 | " ...\n", 550 | " [38. 25. 34. ... 34. 3. 61.]\n", 551 | " [32. 70. 3. ... 17. 33. 1.]\n", 552 | " [66. 60. 1. ... 3. 61. 29.]]\n", 553 | "Batch 54: loss = 3.2601447105407715, acc = 0.1552734375\n", 554 | "[[ 1. 24. 0. ... 76. 1. 39.]\n", 555 | " [77. 77. 66. ... 3. 32. 28.]\n", 556 | " [84. 3. 28. ... 60. 84. 3.]\n", 557 | " ...\n", 558 | " [29. 34. 1. ... 84. 3. 34.]\n", 559 | " [34. 17. 3. ... 60. 61. 84.]\n", 560 | " [34. 1. 3. ... 3. 30. 3.]]\n", 561 | "Batch 55: loss = 3.2791213989257812, acc = 0.154296875\n", 562 | "[[ 1. 31. 72. ... 84. 3. 32.]\n", 563 | " [28. 1. 28. ... 84. 3. 28.]\n", 564 | " [32. 22. 3. ... 64. 17. 56.]\n", 565 | " ...\n", 566 | " [ 3. 61. 29. ... 61. 84. 3.]\n", 567 | " [ 3. 31. 3. ... 30. 3. 29.]\n", 568 | " [60. 61. 62. ... 64. 17. 58.]]\n", 569 | "Batch 56: loss = 3.1544079780578613, acc = 0.1552734375\n", 570 | "[[70. 3. 32. ... 63. 84. 62.]\n", 571 | " [70. 3. 62. ... 34. 34. 1.]\n", 572 | " [64. 84. 3. ... 84. 3. 28.]\n", 573 | " ...\n", 574 | " [30. 3. 62. ... 51. 25. 1.]\n", 575 | " [17. 60. 84. ... 5. 1. 41.]\n", 576 | " [84. 3. 32. ... 70. 3. 62.]]\n", 577 | "Batch 57: loss = 3.1329426765441895, acc = 0.1845703125\n", 578 | "[[17. 61. 1. ... 60. 84. 3.]\n", 579 | " [61. 34. 34. ... 28. 1. 28.]\n", 580 | " [70. 3. 60. ... 59. 69. 62.]\n", 581 | " ...\n", 582 | " [17. 23. 21. ... 71. 71. 62.]\n", 583 | " [72. 77. 77. ... 14. 17. 29.]\n", 584 | " [18. 1. 12. ... 65. 58. 70.]]\n", 585 | "Batch 58: loss = 3.2053322792053223, acc = 0.146484375\n", 586 | "[[31. 3. 61. ... 78. 69. 66.]\n", 587 | " [17. 63. 84. ... 29. 61. 84.]\n", 588 | " [76. 0. 5. ... 0. 28. 84.]\n", 589 | " ...\n", 590 | " [61. 82. 11. ... 34. 17. 34.]\n", 591 | " [14. 17. 56. ... 18. 1. 28.]\n", 592 | " [ 1. 40. 78. ... 28. 70. 0.]]\n", 593 | "Batch 59: loss = 3.227497100830078, acc = 0.158203125\n", 594 | "[[71. 62. 1. ... 33. 32. 31.]\n", 595 | " [ 3. 28. 70. ... 78. 71. 72.]\n", 596 | " [ 3. 31. 3. ... 28. 28. 1.]\n", 597 | " ...\n", 598 | " [ 1. 34. 17. ... 70. 3. 61.]\n", 599 | " [33. 28. 84. ... 84. 25. 3.]\n", 600 | " [ 3. 28. 70. ... 56. 63. 64.]]\n", 601 | "Batch 60: loss = 3.1385927200317383, acc = 0.158203125\n", 602 | "[[ 1. 28. 34. ... 0. 31. 84.]\n", 603 | " [ 0. 5. 1. ... 40. 25. 21.]\n", 604 | " [61. 28. 28. ... 59. 59. 84.]\n", 605 | " ...\n", 606 | " [18. 1. 61. ... 1. 32. 18.]\n", 607 | " [33. 3. 28. ... 84. 58. 60.]\n", 608 | " [ 1. 62. 17. ... 34. 3. 61.]]\n" 609 | ] 610 | }, 611 | { 612 | "name": "stdout", 613 | "output_type": "stream", 614 | "text": [ 615 | "Batch 61: loss = 3.1540369987487793, acc = 0.1650390625\n", 616 | "[[ 3. 34. 3. ... 31. 22. 3.]\n", 617 | " [14. 23. 0. ... 26. 29. 28.]\n", 618 | " [ 3. 31. 3. ... 1. 61. 62.]\n", 619 | " ...\n", 620 | " [84. 0. 3. ... 23. 22. 0.]\n", 621 | " [ 7. 58. 1. ... 51. 25. 1.]\n", 622 | " [29. 34. 1. ... 28. 1. 62.]]\n", 623 | "Batch 62: loss = 3.0938234329223633, acc = 0.181640625\n", 624 | "[[28. 29. 28. ... 64. 0. 5.]\n", 625 | " [84. 0. 3. ... 25. 0. 58.]\n", 626 | " [63. 84. 3. ... 61. 0. 5.]\n", 627 | " ...\n", 628 | " [47. 25. 47. ... 65. 58. 70.]\n", 629 | " [18. 15. 24. ... 45. 72. 80.]\n", 630 | " [28. 62. 84. ... 84. 62. 17.]]\n", 631 | "Batch 63: loss = 3.2291955947875977, acc = 0.158203125\n", 632 | "[[ 1. 41. 72. ... 3. 61. 17.]\n", 633 | " [84. 3. 31. ... 1. 59. 58.]\n", 634 | " [ 1. 41. 72. ... 31. 33. 28.]\n", 635 | " ...\n", 636 | " [ 1. 40. 78. ... 34. 3. 64.]\n", 637 | " [62. 0. 40. ... 61. 60. 84.]\n", 638 | " [61. 1. 64. ... 3. 28. 70.]]\n", 639 | "Batch 64: loss = 3.213397979736328, acc = 0.146484375\n", 640 | "[[59. 1. 59. ... 64. 1. 64.]\n", 641 | " [64. 84. 3. ... 66. 64. 0.]\n", 642 | " [84. 3. 31. ... 3. 31. 3.]\n", 643 | " ...\n", 644 | " [17. 34. 1. ... 30. 3. 32.]\n", 645 | " [ 0. 3. 34. ... 3. 30. 3.]\n", 646 | " [ 3. 62. 17. ... 0. 0. 51.]]\n", 647 | "Batch 65: loss = 3.0897982120513916, acc = 0.1591796875\n", 648 | "[[17. 64. 84. ... 64. 1. 59.]\n", 649 | " [ 5. 1. 41. ... 84. 63. 14.]\n", 650 | " [61. 62. 61. ... 3. 62. 60.]\n", 651 | " ...\n", 652 | " [34. 34. 1. ... 61. 1. 61.]\n", 653 | " [60. 61. 62. ... 17. 60. 1.]\n", 654 | " [25. 1. 18. ... 78. 58. 69.]]\n", 655 | "Batch 66: loss = 3.110912561416626, acc = 0.1689453125\n", 656 | "[[64. 62. 84. ... 84. 3. 34.]\n", 657 | " [17. 58. 18. ... 1. 29. 14.]\n", 658 | " [28. 1. 28. ... 1. 31. 33.]\n", 659 | " ...\n", 660 | " [29. 34. 84. ... 61. 62. 63.]\n", 661 | " [ 3. 31. 22. ... 77. 58. 59.]\n", 662 | " [11. 1. 79. ... 31. 17. 31.]]\n", 663 | "Batch 67: loss = 3.162571430206299, acc = 0.166015625\n", 664 | "[[ 3. 34. 29. ... 77. 58. 59.]\n", 665 | " [17. 61. 18. ... 17. 61. 84.]\n", 666 | " [28. 84. 3. ... 41. 72. 77.]\n", 667 | " ...\n", 668 | " [84. 3. 32. ... 72. 77. 77.]\n", 669 | " [58. 76. 62. ... 64. 1. 63.]\n", 670 | " [84. 3. 30. ... 3. 29. 17.]]\n", 671 | "Batch 68: loss = 3.186009407043457, acc = 0.158203125\n", 672 | "[[58. 76. 62. ... 32. 18. 1.]\n", 673 | " [ 3. 34. 3. ... 60. 29. 28.]\n", 674 | " [77. 66. 71. ... 38. 25. 31.]\n", 675 | " ...\n", 676 | " [66. 71. 64. ... 17. 84. 3.]\n", 677 | " [62. 61. 84. ... 62. 61. 1.]\n", 678 | " [28. 84. 3. ... 3. 29. 18.]]\n", 679 | "Batch 69: loss = 3.123450517654419, acc = 0.16796875\n", 680 | "[[29. 17. 28. ... 30. 3. 34.]\n", 681 | " [84. 3. 31. ... 63. 64. 25.]\n", 682 | " [ 0. 29. 84. ... 84. 63. 62.]\n", 683 | " ...\n", 684 | " [31. 3. 58. ... 3. 31. 3.]\n", 685 | " [61. 17. 25. ... 31. 3. 31.]\n", 686 | " [ 1. 29. 60. ... 25. 84. 0.]]\n", 687 | "Batch 70: loss = 3.036101818084717, acc = 0.1787109375\n", 688 | "[[17. 33. 84. ... 28. 18. 1.]\n", 689 | " [84. 0. 1. ... 76. 66. 60.]\n", 690 | " [61. 1. 28. ... 3. 28. 60.]\n", 691 | " ...\n", 692 | " [58. 59. 58. ... 58. 59. 58.]\n", 693 | " [33. 28. 1. ... 47. 25. 47.]\n", 694 | " [ 0. 0. 51. ... 58. 75. 75.]]\n", 695 | "Batch 71: loss = 3.1351051330566406, acc = 0.1669921875\n", 696 | "[[60. 17. 29. ... 72. 77. 77.]\n", 697 | " [ 1. 31. 58. ... 61. 84. 3.]\n", 698 | " [60. 1. 60. ... 61. 18. 1.]\n", 699 | " ...\n", 700 | " [ 1. 3. 28. ... 60. 29. 84.]\n", 701 | " [58. 70. 7. ... 28. 29. 0.]\n", 702 | " [ 1. 43. 65. ... 29. 84. 3.]]\n", 703 | "Batch 72: loss = 3.170665740966797, acc = 0.177734375\n", 704 | "[[66. 71. 64. ... 0. 29. 14.]\n", 705 | " [28. 70. 3. ... 3. 28. 70.]\n", 706 | " [12. 61. 17. ... 25. 47. 75.]\n", 707 | " ...\n", 708 | " [ 3. 31. 3. ... 61. 17. 60.]\n", 709 | " [40. 25. 21. ... 17. 62. 84.]\n", 710 | " [28. 3. 60. ... 3. 28. 3.]]\n", 711 | "Batch 73: loss = 3.100280523300171, acc = 0.17578125\n", 712 | "[[17. 28. 14. ... 17. 29. 84.]\n", 713 | " [ 3. 62. 17. ... 62. 60. 84.]\n", 714 | " [58. 61. 11. ... 1. 3. 28.]\n", 715 | " ...\n", 716 | " [84. 3. 34. ... 22. 3. 32.]\n", 717 | " [ 3. 31. 70. ... 62. 17. 25.]\n", 718 | " [28. 11. 29. ... 61. 84. 3.]]\n", 719 | "Batch 74: loss = 3.035527229309082, acc = 0.1796875\n", 720 | "[[ 0. 3. 34. ... 25. 29. 0.]\n", 721 | " [ 3. 34. 3. ... 3. 32. 22.]\n", 722 | " [ 3. 32. 33. ... 84. 0. 62.]\n", 723 | " ...\n", 724 | " [18. 1. 33. ... 29. 84. 3.]\n", 725 | " [84. 53. 17. ... 0. 3. 34.]\n", 726 | " [28. 3. 60. ... 32. 22. 3.]]\n", 727 | "Batch 75: loss = 3.0142033100128174, acc = 0.1748046875\n", 728 | "[[84. 25. 29. ... 62. 29. 62.]\n", 729 | " [ 3. 61. 62. ... 77. 77. 66.]\n", 730 | " [84. 3. 31. ... 1. 61. 17.]\n", 731 | " ...\n", 732 | " [31. 3. 28. ... 61. 1. 46.]\n", 733 | " [22. 3. 64. ... 16. 17. 0.]\n", 734 | " [61. 17. 29. ... 1. 34. 17.]]\n", 735 | "Batch 76: loss = 3.1168668270111084, acc = 0.171875\n", 736 | "[[ 1. 3. 32. ... 17. 25. 84.]\n", 737 | " [71. 64. 65. ... 1. 63. 17.]\n", 738 | " [60. 84. 0. ... 3. 60. 29.]\n", 739 | " ...\n", 740 | " [65. 66. 73. ... 25. 31. 0.]\n", 741 | " [47. 25. 31. ... 1. 32. 33.]\n", 742 | " [29. 84. 3. ... 32. 18. 84.]]\n", 743 | "Batch 77: loss = 3.021369218826294, acc = 0.197265625\n", 744 | "[[53. 17. 3. ... 58. 77. 58.]\n", 745 | " [58. 84. 3. ... 3. 64. 63.]\n", 746 | " [28. 84. 0. ... 0. 51. 25.]\n", 747 | " ...\n", 748 | " [28. 84. 3. ... 3. 62. 60.]\n", 749 | " [ 0. 52. 25. ... 28. 1. 3.]\n", 750 | " [ 3. 31. 3. ... 25. 1. 18.]]\n", 751 | "Batch 78: loss = 2.984117031097412, acc = 0.1806640625\n", 752 | "[[59. 58. 76. ... 34. 3. 61.]\n", 753 | " [64. 1. 3. ... 29. 1. 29.]\n", 754 | " [ 1. 20. 23. ... 76. 77. 11.]\n", 755 | " ...\n", 756 | " [28. 84. 3. ... 70. 3. 59.]\n", 757 | " [28. 22. 3. ... 31. 3. 63.]\n", 758 | " [18. 19. 0. ... 43. 65. 66.]]\n", 759 | "Batch 79: loss = 3.0566611289978027, acc = 0.1796875\n", 760 | "[[62. 61. 1. ... 84. 3. 30.]\n", 761 | " [17. 29. 84. ... 3. 58. 63.]\n", 762 | " [ 1. 79. 66. ... 28. 62. 63.]\n", 763 | " ...\n", 764 | " [17. 64. 1. ... 61. 18. 1.]\n", 765 | " [17. 61. 1. ... 84. 3. 32.]\n", 766 | " [69. 1. 45. ... 28. 34. 1.]]\n", 767 | "Batch 80: loss = 2.877868175506592, acc = 0.2080078125\n", 768 | "[[ 3. 28. 29. ... 62. 61. 29.]\n", 769 | " [61. 1. 3. ... 66. 60. 1.]\n", 770 | " [84. 3. 32. ... 84. 3. 31.]\n", 771 | " ...\n", 772 | " [12. 61. 17. ... 58. 59. 58.]\n", 773 | " [70. 3. 64. ... 1. 47. 58.]\n", 774 | " [ 3. 31. 22. ... 25. 84. 0.]]\n", 775 | "Batch 81: loss = 3.063302516937256, acc = 0.1748046875\n", 776 | "[[84. 3. 29. ... 34. 84. 3.]\n", 777 | " [31. 58. 77. ... 33. 1. 31.]\n", 778 | " [ 3. 58. 63. ... 1. 3. 29.]\n", 779 | " ...\n", 780 | " [76. 62. 0. ... 3. 63. 62.]\n", 781 | " [77. 77. 66. ... 21. 14. 23.]\n", 782 | " [61. 84. 3. ... 61. 84. 3.]]\n", 783 | "Batch 82: loss = 3.143832206726074, acc = 0.1748046875\n", 784 | "[[30. 3. 28. ... 66. 71. 64.]\n", 785 | " [33. 32. 12. ... 22. 3. 63.]\n", 786 | " [70. 3. 28. ... 28. 33. 28.]\n", 787 | " ...\n", 788 | " [61. 1. 28. ... 34. 3. 29.]\n", 789 | " [ 0. 38. 25. ... 3. 58. 28.]\n", 790 | " [30. 3. 62. ... 34. 3. 34.]]\n", 791 | "Batch 83: loss = 3.0165915489196777, acc = 0.1923828125\n", 792 | "[[65. 58. 70. ... 28. 17. 28.]\n", 793 | " [26. 62. 63. ... 33. 1. 3.]\n", 794 | " [ 1. 61. 62. ... 1. 3. 28.]\n", 795 | " ...\n", 796 | " [17. 61. 1. ... 28. 17. 63.]\n", 797 | " [28. 1. 3. ... 64. 1. 58.]\n", 798 | " [28. 34. 1. ... 17. 62. 84.]]\n", 799 | "Batch 84: loss = 2.987359046936035, acc = 0.1884765625\n", 800 | "[[84. 3. 34. ... 29. 34. 84.]\n", 801 | " [30. 70. 3. ... 31. 12. 33.]\n", 802 | " [22. 3. 33. ... 84. 3. 31.]\n", 803 | " ...\n", 804 | " [84. 3. 32. ... 84. 3. 31.]\n", 805 | " [17. 62. 84. ... 1. 63. 62.]\n", 806 | " [ 3. 34. 3. ... 1. 28. 1.]]\n", 807 | "Batch 85: loss = 2.8812522888183594, acc = 0.2060546875\n", 808 | "[[54. 0. 3. ... 17. 62. 84.]\n", 809 | " [28. 1. 29. ... 64. 84. 63.]\n", 810 | " [ 3. 28. 61. ... 40. 78. 76.]\n", 811 | " ...\n", 812 | " [ 3. 61. 18. ... 76. 62. 0.]\n", 813 | " [62. 84. 3. ... 66. 71. 64.]\n", 814 | " [50. 58. 71. ... 66. 77. 66.]]\n", 815 | "Batch 86: loss = 3.0345849990844727, acc = 0.1650390625\n", 816 | "[[ 3. 34. 3. ... 18. 84. 84.]\n", 817 | " [26. 62. 63. ... 1. 12. 63.]\n", 818 | " [66. 60. 1. ... 18. 28. 14.]\n", 819 | " ...\n", 820 | " [46. 25. 37. ... 62. 63. 64.]\n", 821 | " [65. 58. 70. ... 28. 84. 3.]\n", 822 | " [72. 71. 11. ... 60. 34. 32.]]\n", 823 | "Batch 87: loss = 3.17026948928833, acc = 0.1484375\n", 824 | "[[ 0. 3. 34. ... 31. 3. 28.]\n", 825 | " [17. 59. 84. ... 62. 75. 62.]\n", 826 | " [17. 29. 14. ... 29. 60. 84.]\n", 827 | " ...\n", 828 | " [ 1. 3. 31. ... 3. 34. 3.]\n", 829 | " [34. 3. 29. ... 3. 28. 70.]\n", 830 | " [84. 26. 60. ... 25. 0. 33.]]\n", 831 | "Batch 88: loss = 3.017911434173584, acc = 0.193359375\n", 832 | "[[17. 28. 84. ... 1. 8. 41.]\n", 833 | " [ 1. 30. 58. ... 69. 1. 45.]\n", 834 | " [ 0. 3. 31. ... 3. 61. 17.]\n", 835 | " ...\n", 836 | " [31. 34. 28. ... 18. 84. 3.]\n", 837 | " [ 3. 61. 62. ... 34. 84. 3.]\n", 838 | " [14. 17. 34. ... 3. 31. 3.]]\n", 839 | "Batch 89: loss = 3.0516326427459717, acc = 0.1728515625\n", 840 | "[[72. 75. 77. ... 14. 17. 84.]\n", 841 | " [72. 80. 62. ... 1. 61. 0.]\n", 842 | " [62. 1. 63. ... 14. 60. 10.]\n", 843 | " ...\n", 844 | " [31. 3. 28. ... 62. 18. 1.]\n", 845 | " [28. 70. 3. ... 59. 3. 3.]\n", 846 | " [28. 18. 1. ... 84. 3. 30.]]\n", 847 | "Batch 90: loss = 3.1139893531799316, acc = 0.1904296875\n", 848 | "[[ 3. 34. 3. ... 31. 84. 0.]\n", 849 | " [29. 34. 84. ... 29. 84. 3.]\n", 850 | " [ 3. 64. 17. ... 58. 76. 76.]\n", 851 | " ...\n", 852 | " [28. 29. 14. ... 28. 28. 84.]\n", 853 | " [33. 3. 64. ... 40. 78. 76.]\n", 854 | " [ 3. 64. 58. ... 31. 3. 61.]]\n", 855 | "Batch 91: loss = 2.9526102542877197, acc = 0.2080078125\n", 856 | "[[ 3. 30. 3. ... 34. 1. 3.]\n", 857 | " [30. 3. 60. ... 3. 61. 62.]\n", 858 | " [ 1. 42. 7. ... 80. 62. 0.]\n", 859 | " ...\n", 860 | " [54. 0. 3. ... 0. 46. 25.]\n", 861 | " [66. 60. 1. ... 32. 84. 3.]\n", 862 | " [18. 1. 12. ... 0. 46. 25.]]\n", 863 | "Batch 92: loss = 2.9916226863861084, acc = 0.185546875\n", 864 | "[[29. 70. 3. ... 34. 3. 29.]\n", 865 | " [56. 64. 84. ... 77. 77. 66.]\n", 866 | " [40. 25. 21. ... 3. 60. 17.]\n", 867 | " ...\n", 868 | " [33. 47. 29. ... 31. 3. 61.]\n", 869 | " [31. 3. 33. ... 84. 0. 3.]\n", 870 | " [50. 72. 69. ... 1. 63. 62.]]\n", 871 | "Batch 93: loss = 3.004767894744873, acc = 0.181640625\n", 872 | "[[28. 34. 84. ... 5. 1. 41.]\n", 873 | " [71. 64. 65. ... 25. 28. 0.]\n", 874 | " [25. 84. 0. ... 58. 17. 63.]\n", 875 | " ...\n", 876 | " [62. 61. 1. ... 61. 29. 28.]\n", 877 | " [32. 70. 3. ... 63. 17. 58.]\n", 878 | " [61. 84. 0. ... 84. 63. 62.]]\n", 879 | "Batch 94: loss = 2.9218974113464355, acc = 0.203125\n", 880 | "[[72. 77. 77. ... 3. 34. 17.]\n", 881 | " [28. 84. 3. ... 1. 61. 17.]\n", 882 | " [ 1. 3. 30. ... 71. 1. 43.]\n", 883 | " ...\n", 884 | " [84. 3. 31. ... 28. 22. 3.]\n", 885 | " [ 1. 58. 63. ... 63. 84. 0.]\n", 886 | " [61. 1. 56. ... 61. 63. 1.]]\n" 887 | ] 888 | }, 889 | { 890 | "name": "stdout", 891 | "output_type": "stream", 892 | "text": [ 893 | "Batch 95: loss = 2.8708009719848633, acc = 0.2158203125\n", 894 | "[[34. 1. 3. ... 17. 61. 1.]\n", 895 | " [63. 84. 3. ... 28. 17. 63.]\n", 896 | " [72. 80. 75. ... 1. 45. 72.]\n", 897 | " ...\n", 898 | " [58. 17. 28. ... 18. 0. 47.]\n", 899 | " [ 3. 31. 3. ... 25. 1. 18.]\n", 900 | " [28. 61. 63. ... 71. 68. 76.]]\n", 901 | "Batch 96: loss = 2.889787197113037, acc = 0.205078125\n", 902 | "[[ 3. 28. 70. ... 62. 63. 84.]\n", 903 | " [84. 3. 28. ... 56. 61. 62.]\n", 904 | " [80. 62. 0. ... 84. 3. 34.]\n", 905 | " ...\n", 906 | " [25. 46. 70. ... 14. 23. 0.]\n", 907 | " [16. 21. 0. ... 1. 73. 1.]\n", 908 | " [ 1. 72. 63. ... 14. 35. 37.]]\n", 909 | "Batch 97: loss = 3.033266067504883, acc = 0.1708984375\n", 910 | "[[64. 17. 64. ... 3. 31. 22.]\n", 911 | " [ 1. 58. 17. ... 17. 63. 1.]\n", 912 | " [70. 3. 61. ... 61. 17. 60.]\n", 913 | " ...\n", 914 | " [38. 25. 31. ... 1. 3. 28.]\n", 915 | " [16. 18. 19. ... 1. 32. 17.]\n", 916 | " [33. 0. 40. ... 3. 29. 61.]]\n", 917 | "Batch 98: loss = 3.0093603134155273, acc = 0.1845703125\n", 918 | "[[ 3. 28. 33. ... 32. 84. 3.]\n", 919 | " [63. 62. 63. ... 61. 18. 1.]\n", 920 | " [84. 3. 34. ... 1. 63. 62.]\n", 921 | " ...\n", 922 | " [22. 3. 34. ... 29. 0. 84.]\n", 923 | " [60. 84. 3. ... 3. 29. 18.]\n", 924 | " [61. 1. 3. ... 3. 33. 32.]]\n", 925 | "Batch 99: loss = 2.8540709018707275, acc = 0.22265625\n", 926 | "[[28. 22. 3. ... 18. 1. 12.]\n", 927 | " [12. 61. 17. ... 59. 58. 76.]\n", 928 | " [63. 84. 3. ... 30. 22. 3.]\n", 929 | " ...\n", 930 | " [63. 14. 17. ... 28. 22. 14.]\n", 931 | " [ 1. 3. 28. ... 0. 3. 28.]\n", 932 | " [32. 1. 3. ... 58. 64. 84.]]\n", 933 | "Batch 100: loss = 2.81121826171875, acc = 0.234375\n", 934 | "[[34. 17. 25. ... 25. 31. 72.]\n", 935 | " [62. 0. 46. ... 25. 3. 31.]\n", 936 | " [60. 61. 60. ... 84. 34. 17.]\n", 937 | " ...\n", 938 | " [62. 3. 64. ... 64. 63. 1.]\n", 939 | " [ 3. 62. 60. ... 62. 84. 3.]\n", 940 | " [ 0. 3. 31. ... 84. 0. 0.]]\n", 941 | "Batch 101: loss = 2.934945583343506, acc = 0.2080078125\n", 942 | "[[71. 1. 40. ... 38. 25. 31.]\n", 943 | " [ 3. 31. 32. ... 70. 3. 28.]\n", 944 | " [34. 1. 34. ... 65. 58. 70.]\n", 945 | " ...\n", 946 | " [63. 62. 61. ... 10. 3. 58.]\n", 947 | " [31. 3. 63. ... 84. 3. 31.]\n", 948 | " [ 0. 51. 25. ... 58. 76. 62.]]\n", 949 | "Batch 102: loss = 2.9974610805511475, acc = 0.1962890625\n", 950 | "[[ 0. 28. 84. ... 33. 1. 3.]\n", 951 | " [34. 33. 84. ... 1. 61. 62.]\n", 952 | " [ 1. 40. 78. ... 60. 62. 1.]\n", 953 | " ...\n", 954 | " [64. 63. 1. ... 62. 0. 5.]\n", 955 | " [ 3. 61. 63. ... 1. 40. 78.]\n", 956 | " [ 0. 46. 25. ... 28. 3. 62.]]\n", 957 | "Batch 103: loss = 2.956305980682373, acc = 0.1875\n", 958 | "[[34. 3. 34. ... 62. 17. 61.]\n", 959 | " [63. 84. 3. ... 84. 3. 32.]\n", 960 | " [ 3. 31. 70. ... 28. 60. 62.]\n", 961 | " ...\n", 962 | " [ 1. 41. 72. ... 38. 25. 28.]\n", 963 | " [76. 66. 60. ... 28. 28. 1.]\n", 964 | " [63. 58. 1. ... 84. 3. 31.]]\n", 965 | "Batch 104: loss = 2.8138980865478516, acc = 0.22265625\n", 966 | "[[84. 3. 28. ... 29. 60. 84.]\n", 967 | " [70. 3. 32. ... 17. 28. 84.]\n", 968 | " [ 1. 3. 31. ... 18. 1. 12.]\n", 969 | " ...\n", 970 | " [ 0. 43. 25. ... 28. 17. 28.]\n", 971 | " [60. 28. 28. ... 29. 29. 1.]\n", 972 | " [ 3. 62. 63. ... 3. 28. 28.]]\n", 973 | "Batch 105: loss = 2.771090030670166, acc = 0.2275390625\n", 974 | "[[ 3. 31. 3. ... 58. 59. 58.]\n", 975 | " [ 3. 34. 3. ... 47. 65. 62.]\n", 976 | " [28. 17. 84. ... 1. 28. 64.]\n", 977 | " ...\n", 978 | " [ 1. 28. 34. ... 1. 60. 61.]\n", 979 | " [ 3. 28. 3. ... 3. 31. 3.]\n", 980 | " [28. 84. 3. ... 3. 28. 18.]]\n", 981 | "Batch 106: loss = 2.824322462081909, acc = 0.2216796875\n", 982 | "[[76. 62. 0. ... 17. 33. 84.]\n", 983 | " [ 1. 30. 58. ... 14. 23. 0.]\n", 984 | " [28. 84. 57. ... 70. 58. 71.]\n", 985 | " ...\n", 986 | " [62. 84. 3. ... 3. 28. 14.]\n", 987 | " [63. 64. 58. ... 3. 62. 63.]\n", 988 | " [84. 3. 28. ... 60. 84. 3.]]\n", 989 | "Batch 107: loss = 2.8435311317443848, acc = 0.2216796875\n", 990 | "[[ 3. 30. 14. ... 3. 28. 17.]\n", 991 | " [38. 25. 34. ... 34. 18. 1.]\n", 992 | " [ 0. 5. 1. ... 14. 23. 0.]\n", 993 | " ...\n", 994 | " [62. 3. 62. ... 1. 40. 78.]\n", 995 | " [62. 1. 3. ... 40. 78. 76.]\n", 996 | " [28. 3. 62. ... 1. 63. 62.]]\n", 997 | "Batch 108: loss = 2.7772090435028076, acc = 0.2373046875\n", 998 | "[[60. 84. 3. ... 62. 61. 3.]\n", 999 | " [34. 28. 29. ... 3. 62. 14.]\n", 1000 | " [38. 25. 34. ... 61. 17. 60.]\n", 1001 | " ...\n", 1002 | " [76. 66. 60. ... 34. 17. 34.]\n", 1003 | " [66. 60. 1. ... 31. 3. 63.]\n", 1004 | " [62. 84. 3. ... 62. 63. 63.]]\n", 1005 | "Batch 109: loss = 2.6875505447387695, acc = 0.2431640625\n", 1006 | "[[30. 14. 62. ... 14. 60. 3.]\n", 1007 | " [17. 63. 14. ... 3. 31. 3.]\n", 1008 | " [ 1. 29. 17. ... 32. 18. 1.]\n", 1009 | " ...\n", 1010 | " [ 1. 34. 17. ... 84. 0. 53.]\n", 1011 | " [64. 58. 1. ... 3. 61. 60.]\n", 1012 | " [ 1. 3. 28. ... 28. 84. 3.]]\n", 1013 | "Batch 110: loss = 2.8363776206970215, acc = 0.228515625\n", 1014 | "[[60. 1. 64. ... 1. 60. 7.]\n", 1015 | " [61. 29. 28. ... 1. 41. 72.]\n", 1016 | " [12. 32. 17. ... 18. 84. 3.]\n", 1017 | " ...\n", 1018 | " [17. 3. 34. ... 21. 84. 3.]\n", 1019 | " [29. 84. 3. ... 1. 64. 17.]\n", 1020 | " [31. 3. 63. ... 18. 84. 0.]]\n", 1021 | "Batch 111: loss = 2.8449063301086426, acc = 0.22265625\n", 1022 | "[[64. 62. 1. ... 47. 75. 58.]\n", 1023 | " [77. 77. 66. ... 80. 62. 0.]\n", 1024 | " [28. 70. 3. ... 3. 61. 17.]\n", 1025 | " ...\n", 1026 | " [31. 22. 3. ... 34. 84. 0.]\n", 1027 | " [62. 84. 3. ... 3. 28. 3.]\n", 1028 | " [ 3. 28. 3. ... 29. 60. 84.]]\n", 1029 | "Batch 112: loss = 2.8083200454711914, acc = 0.2373046875\n", 1030 | "[[61. 11. 1. ... 84. 54. 0.]\n", 1031 | " [40. 25. 21. ... 84. 0. 3.]\n", 1032 | " [34. 1. 34. ... 1. 60. 29.]\n", 1033 | " ...\n", 1034 | " [ 3. 33. 3. ... 25. 1. 17.]\n", 1035 | " [28. 18. 1. ... 58. 76. 62.]\n", 1036 | " [ 3. 28. 3. ... 7. 76. 1.]]\n", 1037 | "Batch 113: loss = 2.768315553665161, acc = 0.244140625\n", 1038 | "[[ 3. 34. 22. ... 31. 3. 26.]\n", 1039 | " [29. 70. 3. ... 3. 61. 29.]\n", 1040 | " [28. 84. 3. ... 28. 1. 29.]\n", 1041 | " ...\n", 1042 | " [24. 21. 0. ... 69. 1. 45.]\n", 1043 | " [ 0. 46. 25. ... 62. 18. 14.]\n", 1044 | " [44. 78. 58. ... 25. 21. 14.]]\n", 1045 | "Batch 114: loss = 2.912618637084961, acc = 0.2216796875\n", 1046 | "[[28. 33. 31. ... 28. 17. 61.]\n", 1047 | " [29. 1. 29. ... 28. 1. 3.]\n", 1048 | " [17. 60. 84. ... 34. 1. 60.]\n", 1049 | " ...\n", 1050 | " [72. 80. 62. ... 31. 84. 3.]\n", 1051 | " [17. 60. 14. ... 33. 14. 17.]\n", 1052 | " [23. 0. 38. ... 1. 12. 3.]]\n", 1053 | "Batch 115: loss = 2.8427305221557617, acc = 0.228515625\n", 1054 | "[[84. 3. 34. ... 0. 26. 33.]\n", 1055 | " [34. 3. 29. ... 29. 1. 29.]\n", 1056 | " [17. 62. 84. ... 0. 0. 51.]\n", 1057 | " ...\n", 1058 | " [30. 3. 32. ... 31. 34. 29.]\n", 1059 | " [32. 84. 0. ... 3. 64. 18.]\n", 1060 | " [30. 70. 3. ... 29. 11. 31.]]\n", 1061 | "Batch 116: loss = 2.8150267601013184, acc = 0.2265625\n", 1062 | "[[31. 33. 1. ... 71. 64. 65.]\n", 1063 | " [25. 84. 0. ... 75. 75. 1.]\n", 1064 | " [25. 1. 21. ... 28. 11. 1.]\n", 1065 | " ...\n", 1066 | " [84. 61. 29. ... 3. 31. 17.]\n", 1067 | " [14. 17. 62. ... 17. 64. 62.]\n", 1068 | " [33. 1. 29. ... 60. 28. 60.]]\n", 1069 | "Batch 117: loss = 2.926863193511963, acc = 0.2119140625\n", 1070 | "[[58. 70. 1. ... 32. 28. 60.]\n", 1071 | " [43. 65. 66. ... 62. 17. 61.]\n", 1072 | " [79. 66. 58. ... 84. 3. 34.]\n", 1073 | " ...\n", 1074 | " [33. 84. 3. ... 34. 1. 32.]\n", 1075 | " [ 1. 3. 28. ... 14. 17. 32.]\n", 1076 | " [84. 62. 17. ... 84. 3. 29.]]\n", 1077 | "Batch 118: loss = 2.782355308532715, acc = 0.2451171875\n", 1078 | "[[ 1. 62. 17. ... 58. 17. 63.]\n", 1079 | " [84. 1. 3. ... 70. 3. 64.]\n", 1080 | " [ 3. 29. 60. ... 3. 34. 3.]\n", 1081 | " ...\n", 1082 | " [34. 28. 84. ... 56. 34. 1.]\n", 1083 | " [84. 0. 3. ... 64. 18. 14.]\n", 1084 | " [59. 3. 63. ... 17. 34. 84.]]\n", 1085 | "Batch 119: loss = 2.7212114334106445, acc = 0.2646484375\n", 1086 | "[[84. 3. 28. ... 1. 32. 28.]\n", 1087 | " [63. 62. 84. ... 28. 29. 60.]\n", 1088 | " [29. 60. 61. ... 31. 3. 61.]\n", 1089 | " ...\n", 1090 | " [28. 17. 34. ... 0. 0. 0.]\n", 1091 | " [17. 62. 14. ... 77. 77. 66.]\n", 1092 | " [ 3. 33. 22. ... 29. 1. 3.]]\n", 1093 | "Batch 120: loss = 2.8013296127319336, acc = 0.2578125\n", 1094 | "[[60. 84. 3. ... 84. 0. 0.]\n", 1095 | " [84. 0. 3. ... 0. 0. 0.]\n", 1096 | " [62. 61. 1. ... 63. 84. 3.]\n", 1097 | " ...\n", 1098 | " [51. 25. 1. ... 1. 31. 62.]\n", 1099 | " [71. 64. 65. ... 28. 1. 3.]\n", 1100 | " [33. 22. 3. ... 60. 1. 62.]]\n", 1101 | "Batch 121: loss = 2.717653274536133, acc = 0.25390625\n", 1102 | "[[ 0. 51. 25. ... 29. 66. 69.]\n", 1103 | " [51. 25. 1. ... 21. 14. 23.]\n", 1104 | " [28. 22. 3. ... 70. 1. 40.]\n", 1105 | " ...\n", 1106 | " [80. 65. 78. ... 3. 32. 3.]\n", 1107 | " [34. 3. 34. ... 33. 28. 61.]\n", 1108 | " [60. 28. 84. ... 66. 71. 64.]]\n", 1109 | "Batch 122: loss = 2.817592144012451, acc = 0.2275390625\n", 1110 | "[[69. 82. 1. ... 34. 1. 61.]\n", 1111 | " [ 0. 38. 25. ... 64. 63. 84.]\n", 1112 | " [78. 76. 66. ... 29. 28. 33.]\n", 1113 | " ...\n", 1114 | " [29. 17. 32. ... 3. 28. 18.]\n", 1115 | " [ 1. 63. 62. ... 29. 60. 1.]\n", 1116 | " [65. 58. 70. ... 3. 28. 3.]]\n", 1117 | "Batch 123: loss = 2.7893738746643066, acc = 0.2294921875\n", 1118 | "[[18. 84. 3. ... 3. 29. 61.]\n", 1119 | " [ 0. 3. 32. ... 1. 62. 18.]\n", 1120 | " [84. 3. 31. ... 29. 84. 3.]\n", 1121 | " ...\n", 1122 | " [ 1. 12. 28. ... 33. 1. 28.]\n", 1123 | " [62. 61. 60. ... 78. 75. 71.]\n", 1124 | " [62. 17. 61. ... 18. 1. 61.]]\n", 1125 | "Batch 124: loss = 2.662649631500244, acc = 0.2802734375\n", 1126 | "[[64. 1. 29. ... 61. 29. 1.]\n", 1127 | " [84. 84. 0. ... 3. 61. 18.]\n", 1128 | " [33. 3. 28. ... 61. 29. 34.]\n", 1129 | " ...\n", 1130 | " [60. 62. 84. ... 1. 36. 76.]\n", 1131 | " [73. 66. 68. ... 25. 21. 14.]\n", 1132 | " [17. 62. 84. ... 84. 3. 28.]]\n", 1133 | "Batch 125: loss = 2.648763418197632, acc = 0.2744140625\n", 1134 | "[[34. 29. 61. ... 62. 34. 29.]\n", 1135 | " [ 1. 3. 31. ... 72. 80. 71.]\n", 1136 | " [84. 3. 34. ... 26. 33. 18.]\n", 1137 | " ...\n", 1138 | " [69. 58. 71. ... 66. 58. 1.]\n", 1139 | " [23. 0. 38. ... 60. 29. 60.]\n", 1140 | " [ 3. 28. 29. ... 32. 70. 3.]]\n", 1141 | "Batch 126: loss = 2.759761095046997, acc = 0.2705078125\n", 1142 | "Saved checkpoint to weights.1.h5\n", 1143 | "training done...........\n" 1144 | ] 1145 | } 1146 | ], 1147 | "source": [ 1148 | "if __name__ == '__main__':\n", 1149 | " parser = argparse.ArgumentParser(description='Train the model on some text.')\n", 1150 | " parser.add_argument('--input', default='input.txt', help='name of the text file to train from')\n", 1151 | " parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for')\n", 1152 | " parser.add_argument('--freq', type=int, default=10, help='checkpoint save frequency')\n", 1153 | " args, unknown = parser.parse_known_args()\n", 1154 | "\n", 1155 | " if not os.path.exists(LOG_DIR):\n", 1156 | " os.makedirs(LOG_DIR)\n", 1157 | "\n", 1158 | "epochs = args.epochs\n", 1159 | "save_freq = args.freq\n", 1160 | "text = open(os.path.join(DATA_DIR, args.input)).read()\n", 1161 | "\n", 1162 | "print(\"processing\")\n", 1163 | "# character to index and vice-versa mappings\n", 1164 | "char_to_idx = { ch: i for (i, ch) in enumerate(sorted(list(set(text)))) }\n", 1165 | "print(\"Number of unique characters: \" + str(len(char_to_idx))) #86\n", 1166 | "\n", 1167 | "idx_to_char = { i: ch for (ch, i) in char_to_idx.items() }\n", 1168 | "vocab_size = len(char_to_idx)\n", 1169 | "print(\"processing done\")\n", 1170 | "\n", 1171 | "print(\"creating model\")\n", 1172 | "#model_architecture \n", 1173 | "model = Sequential()\n", 1174 | "model.add(Embedding(vocab_size, 512, batch_input_shape=(BATCH_SIZE, SEQ_LENGTH)))\n", 1175 | "for i in range(3):\n", 1176 | " model.add(LSTM(256, return_sequences=True, stateful=True))\n", 1177 | " model.add(Dropout(0.2))\n", 1178 | "\n", 1179 | "model.add(TimeDistributed(Dense(vocab_size))) \n", 1180 | "model.add(Activation('softmax'))\n", 1181 | "print(\"model created\")\n", 1182 | " \n", 1183 | "model.summary()\n", 1184 | "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", 1185 | "\n", 1186 | "\n", 1187 | "#Train data generation\n", 1188 | "print(\"training data\")\n", 1189 | "T = np.asarray([char_to_idx[c] for c in text], dtype=np.int32) #convert complete text into numerical indices\n", 1190 | "\n", 1191 | "print(\"Length of text:\" + str(T.size)) #129,665\n", 1192 | "\n", 1193 | "steps_per_epoch = (len(text) / BATCH_SIZE - 1) / SEQ_LENGTH \n", 1194 | "\n", 1195 | "log = TrainLogger('training_log.csv')\n", 1196 | "\n", 1197 | "for epoch in range(epochs):\n", 1198 | " print('\\nEpoch {}/{}'.format(epoch + 1, epochs))\n", 1199 | " \n", 1200 | " losses, accs = [], []\n", 1201 | "\n", 1202 | " for i, (X, Y) in enumerate(read_batches(T, vocab_size)):\n", 1203 | " \n", 1204 | " print(X);\n", 1205 | "\n", 1206 | " loss, acc = model.train_on_batch(X, Y)\n", 1207 | " print('Batch {}: loss = {}, acc = {}'.format(i + 1, loss, acc))\n", 1208 | " losses.append(loss)\n", 1209 | " accs.append(acc)\n", 1210 | "\n", 1211 | " log.add_entry(np.average(losses), np.average(accs))\n", 1212 | " \n", 1213 | " if (epoch + 1) % save_freq == 0:\n", 1214 | " save_weights(epoch + 1, model)\n", 1215 | " print('Saved checkpoint to', 'weights.{}.h5'.format(epoch + 1))\n", 1216 | "\n", 1217 | "print(\"training done...........\")" 1218 | ] 1219 | }, 1220 | { 1221 | "cell_type": "code", 1222 | "execution_count": 10, 1223 | "metadata": { 1224 | "scrolled": false 1225 | }, 1226 | "outputs": [ 1227 | { 1228 | "name": "stdout", 1229 | "output_type": "stream", 1230 | "text": [ 1231 | "Model: \"sequential_5\"\n", 1232 | "_________________________________________________________________\n", 1233 | "Layer (type) Output Shape Param # \n", 1234 | "=================================================================\n", 1235 | "embedding_4 (Embedding) (1, 1, 512) 44032 \n", 1236 | "_________________________________________________________________\n", 1237 | "lstm_12 (LSTM) (1, 1, 256) 787456 \n", 1238 | "_________________________________________________________________\n", 1239 | "dropout_12 (Dropout) (1, 1, 256) 0 \n", 1240 | "_________________________________________________________________\n", 1241 | "lstm_13 (LSTM) (1, 1, 256) 525312 \n", 1242 | "_________________________________________________________________\n", 1243 | "dropout_13 (Dropout) (1, 1, 256) 0 \n", 1244 | "_________________________________________________________________\n", 1245 | "lstm_14 (LSTM) (1, 1, 256) 525312 \n", 1246 | "_________________________________________________________________\n", 1247 | "dropout_14 (Dropout) (1, 1, 256) 0 \n", 1248 | "_________________________________________________________________\n", 1249 | "time_distributed_4 (TimeDist (1, 1, 86) 22102 \n", 1250 | "_________________________________________________________________\n", 1251 | "activation_4 (Activation) (1, 1, 86) 0 \n", 1252 | "=================================================================\n", 1253 | "Total params: 1,904,214\n", 1254 | "Trainable params: 1,904,214\n", 1255 | "Non-trainable params: 0\n", 1256 | "_________________________________________________________________\n", 1257 | "sampled\n", 1258 | "[17, 64, 14, 17, 63, 14, 17, 62, 14, 17, 1, 61, 18, 14, 17, 62, 14, 17, 60, 14, 17, 84, 0, 62, 14, 17, 61, 14, 17, 60, 14, 17, 29, 14, 17, 28, 14, 17, 34, 14, 17, 1, 62, 14, 17, 61, 14, 17, 60, 14, 17, 29, 14, 17, 28, 14, 17, 34, 14, 17, 84, 29, 14, 17, 34, 14, 17, 29, 14, 17, 61, 14, 17, 64, 14, 17, 61, 14, 17, 1, 59, 14, 17, 58, 14, 17, 64, 14, 17, 63, 14, 17, 62, 14, 17, 61, 14, 17, 84, 54, 0, 32, 28, 60, 1, 29, 60, 28, 84, 34, 19, 84, 84, 0, 43, 25, 33, 0, 29, 60, 84, 61, 18, 14, 17, 29, 14, 17, 61, 14, 17, 29, 14, 17, 1, 62, 18, 14, 17, 29, 14, 17, 62, 14, 17, 29, 14, 17, 84, 61, 18, 14, 17, 29, 14, 17, 61, 14, 17, 60, 14, 17, 1, 29, 28, 34, 84, 54, 0, 32, 14, 17, 31, 14, 17, 33, 14, 17, 28, 14, 17, 61, 14, 17, 28, 14, 17, 1, 64, 14, 17, 29, 14, 17, 60, 14, 17, 28, 14, 17, 84, 0, 60, 28, 14, 17, 29, 14, 17, 60, 14, 17, 28, 14, 17, 1, 60, 14, 17, 29, 14, 17, 28, 14, 17, 34, 14, 17, 84, 32, 18, 1, 28, 60, 29, 84, 84, 0, 60, 60, 60, 1, 60, 34, 29, 84, 61, 14, 17, 29, 14, 17, 29, 14, 17, 61, 14, 17, 60, 14, 17, 1, 29, 28, 34, 84, 54, 0, 32, 28, 60, 1, 29, 60, 28, 84, 34, 19, 84, 84, 0, 43, 25, 35, 0, 29, 60, 84, 61, 18, 14, 17, 29, 14, 17, 61, 14, 17, 29, 14, 17, 1, 62, 18, 14, 17, 29, 14, 17, 62, 14, 17, 29, 14, 17, 84, 61, 14, 17, 34, 14, 17, 29, 14, 17, 61, 14, 17, 64, 14, 17, 61, 14, 17, 1, 59, 14, 17, 58, 14, 17, 63, 14, 17, 84, 0, 63, 18, 14, 17, 64, 14, 17, 63, 14, 17, 62, 14, 17, 1, 61, 18, 14, 17, 62, 14, 17, 61, 14, 17, 29, 14, 17, 84, 61, 18, 14, 17, 29, 14, 17, 61, 14, 17, 60, 14, 17, 1, 29, 28, 34, 84, 54, 0, 32, 28, 60, 1, 29, 60, 28, 84, 34, 19, 84, 84, 0, 0, 0, 51, 25, 1, 17, 19, 23, 0, 47, 25, 46, 75, 58, 71, 61, 72, 71, 1, 45, 58, 60, 62, 0, 5, 1, 41, 72, 77, 77, 66, 71, 64, 65, 58, 70, 1, 40, 78, 76, 66, 60, 1, 31, 58, 77, 58, 59, 58, 76, 62, 0, 46, 25, 47, 75, 58, 61, 11, 1, 58, 75, 75, 1, 43, 65, 66, 69, 1, 45, 72, 80, 62, 0, 40, 25, 21, 14, 23, 0, 38, 25, 31, 0, 3, 31, 3, 61, 17, 61, 1, 3, 28, 3, 60, 29, 28, 84, 3, 34, 3, 61, 17, 61, 1, 3, 31, 3, 28, 17, 61, 84, 3, 34, 3, 29, 17, 61, 1, 3, 31, 3, 28, 17, 61, 84, 3, 32, 70, 3, 34, 18, 1, 12, 34, 25, 25, 0, 64, 84, 3, 31, 3, 63, 62, 61, 1, 62, 61, 63, 84, 3, 28, 22, 3, 28, 60, 62, 1, 62, 60, 28, 84, 3, 34, 3, 29, 60, 61, 1, 3, 28, 22, 3, 62, 60, 28, 84, 3, 31, 3, 61, 63, 62, 1, 61, 17, 25, 25, 0, 29, 84, 3, 31, 3, 28, 17, 61, 1, 61, 17, 63, 84, 3, 34, 3, 62, 61, 29, 1, 3, 28, 22, 3, 28, 63, 64, 84, 3, 31, 3, 58, 63, 61, 1, 3, 28, 22, 3, 62, 58, 60, 84, 3, 31, 3, 61, 18, 1, 61, 17, 84, 84, 0, 0, 0, 51, 25, 1, 17, 22, 21, 0, 47, 25, 46, 60, 72, 77, 76, 59, 75, 72, 72, 70, 62, 1, 37, 66, 64, 0, 5, 1, 41, 72, 77, 77, 66, 71, 64, 65, 58, 70, 1, 40, 78, 76, 66, 60, 1, 31, 58, 77, 58, 59, 58, 76, 62, 0, 46, 25, 47, 75, 58, 61, 11, 1, 79, 66, 58, 1, 32, 33, 0, 52, 25, 28, 29, 0, 40, 25, 21, 14, 23, 0, 38, 25, 31, 0, 43, 25, 28, 0, 28, 84, 3, 31, 3, 61, 17, 61, 1, 61, 60, 61, 84, 3, 34, 3, 29, 17, 61, 1, 3, 31, 14, 63, 10, 3, 28, 17, 61, 84, 3, 31, 3, 61, 18, 1, 3, 28, 22, 3, 28, 61, 60, 84, 3, 31, 3, 61, 18, 1, 61, 17, 25, 84, 0, 0, 0, 51, 25, 1, 17, 16, 0, 47, 25, 29, 66, 60, 68, 1, 46, 82, 1, 40, 72, 80, 66, 62, 0, 5, 1, 41, 72, 77, 77, 66, 71, 64, 65, 58, 70, 1, 40, 78, 76, 66, 60, 1, 31, 58, 77, 58, 59, 58, 76, 62, 0, 46, 25, 40, 66, 68, 62, 1, 45, 66, 60, 65, 58, 75, 61, 76, 72, 71, 1, 16, 21, 13, 16, 17, 13, 23, 24, 11, 1, 79, 66, 58, 1, 43, 65, 66, 69, 1, 45, 72, 80, 62, 0, 40, 25, 21, 14, 23, 0, 38, 25, 34, 0, 3, 34, 3, 34, 17, 34, 1, 3, 31, 3, 28, 17, 28, 84, 3, 34, 3, 29, 17, 60, 1, 61, 29, 34, 84, 3, 30, 3, 32, 17, 32, 1, 60, 17, 29, 84, 3, 28, 70, 3, 28, 29, 34, 1, 3, 31, 22, 3, 33, 32, 31, 84, 3, 34, 3, 34, 17, 34, 1, 3, 31, 3, 28, 17, 28, 84, 3, 34, 3, 29, 17, 60, 1, 61, 29, 34, 84, 54, 0, 3, 30, 3, 32, 33, 34, 1, 3, 31, 22, 3, 28, 17, 33, 84, 3, 34, 3, 34, 18, 1, 12, 34, 17, 25, 25, 0, 28, 84, 3, 34, 3, 29, 60, 29, 1, 64, 29, 60, 84, 3, 34, 3, 61, 17, 29]\n", 1259 | "2g/2f/2e/2 d3/2e/2c/2|\n", 1260 | "e/2d/2c/2B/2A/2G/2 e/2d/2c/2B/2A/2G/2|B/2G/2B/2d/2g/2d/2 b/2a/2g/2f/2e/2d/2|\\\n", 1261 | "EAc BcA|G4||\n", 1262 | "P:F\n", 1263 | "Bc|d3/2B/2d/2B/2 e3/2B/2e/2B/2|d3/2B/2d/2c/2 BAG|\\\n", 1264 | "E/2D/2F/2A/2d/2A/2 g/2B/2c/2A/2|\n", 1265 | "cA/2B/2c/2A/2 c/2B/2A/2G/2|E3 AcB||\n", 1266 | "ccc cGB|d/2B/2B/2d/2c/2 BAG|\\\n", 1267 | "EAc BcA|G4||\n", 1268 | "P:H\n", 1269 | "Bc|d3/2B/2d/2B/2 e3/2B/2e/2B/2|d/2G/2B/2d/2g/2d/2 b/2a/2f/2|\n", 1270 | "f3/2g/2f/2e/2 d3/2e/2d/2B/2|d3/2B/2d/2c/2 BAG|\\\n", 1271 | "EAc BcA|G4||\n", 1272 | "\n", 1273 | "\n", 1274 | "X: 248\n", 1275 | "T:Srandon Race\n", 1276 | "% Nottingham Music Database\n", 1277 | "S:Trad, arr Phil Rowe\n", 1278 | "M:6/8\n", 1279 | "K:D\n", 1280 | "\"D\"d2d \"A\"cBA|\"G\"d2d \"D\"A2d|\"G\"B2d \"D\"A2d|\"Em\"G3 -G::\n", 1281 | "g|\"D\"fed edf|\"A7\"Ace ecA|\"G\"Bcd \"A7\"ecA|\"D\"dfe d2::\n", 1282 | "B|\"D\"A2d d2f|\"G\"edB \"A7\"Afg|\"D\"afd \"A7\"eac|\"D\"d3 d2||\n", 1283 | "\n", 1284 | "\n", 1285 | "X: 276\n", 1286 | "T:Scotsbroome Jig\n", 1287 | "% Nottingham Music Database\n", 1288 | "S:Trad, via EF\n", 1289 | "Y:AB\n", 1290 | "M:6/8\n", 1291 | "K:D\n", 1292 | "P:A\n", 1293 | "A|\"D\"d2d dcd|\"G\"B2d \"D/f+\"A2d|\"D\"d3 \"A7\"Adc|\"D\"d3 d2:|\n", 1294 | "\n", 1295 | "\n", 1296 | "X: 21\n", 1297 | "T:Bick Sy Mowie\n", 1298 | "% Nottingham Music Database\n", 1299 | "S:Mike Richardson 16.12.89, via Phil Rowe\n", 1300 | "M:6/8\n", 1301 | "K:G\n", 1302 | "\"G\"G2G \"D\"A2A|\"G\"B2c dBG|\"C\"E2E c2B|\"Am\"ABG \"D7\"FED|\"G\"G2G \"D\"A2A|\"G\"B2c dBG|\\\n", 1303 | "\"C\"EFG \"D7\"A2F|\"G\"G3 -G2::\n", 1304 | "A|\"G\"BcB gBc|\"G\"d2B\n" 1305 | ] 1306 | } 1307 | ], 1308 | "source": [ 1309 | "MODEL_DIR = './model'\n", 1310 | "model2 = Sequential()\n", 1311 | "model2.add(Embedding(vocab_size, 512, batch_input_shape=(1,1)))\n", 1312 | "for i in range(3):\n", 1313 | " model2.add(LSTM(256, return_sequences=True, stateful=True))\n", 1314 | " model2.add(Dropout(0.2))\n", 1315 | "\n", 1316 | "model2.add(TimeDistributed(Dense(vocab_size))) \n", 1317 | "model2.add(Activation('softmax'))\n", 1318 | "\n", 1319 | "model2.load_weights(os.path.join(MODEL_DIR, 'weights.100.h5'.format(epoch)))\n", 1320 | "model2.summary()\n", 1321 | "\n", 1322 | "\n", 1323 | "\n", 1324 | "sampled = []\n", 1325 | "for i in range(1024):\n", 1326 | " batch = np.zeros((1, 1))\n", 1327 | " if sampled:\n", 1328 | " batch[0, 0] = sampled[-1]\n", 1329 | " else:\n", 1330 | " batch[0, 0] = np.random.randint(vocab_size)\n", 1331 | " result = model2.predict_on_batch(batch).ravel()\n", 1332 | " sample = np.random.choice(range(vocab_size), p=result)\n", 1333 | " sampled.append(sample)\n", 1334 | "\n", 1335 | "print(\"sampled\")\n", 1336 | "print(sampled)\n", 1337 | "print(''.join(idx_to_char[c] for c in sampled))\n" 1338 | ] 1339 | } 1340 | ], 1341 | "metadata": { 1342 | "kernelspec": { 1343 | "display_name": "Python 3", 1344 | "language": "python", 1345 | "name": "python3" 1346 | }, 1347 | "language_info": { 1348 | "codemirror_mode": { 1349 | "name": "ipython", 1350 | "version": 3 1351 | }, 1352 | "file_extension": ".py", 1353 | "mimetype": "text/x-python", 1354 | "name": "python", 1355 | "nbconvert_exporter": "python", 1356 | "pygments_lexer": "ipython3", 1357 | "version": "3.7.6" 1358 | } 1359 | }, 1360 | "nbformat": 4, 1361 | "nbformat_minor": 4 1362 | } 1363 | --------------------------------------------------------------------------------