├── .gitignore ├── CNN2RNN.ipynb ├── G729VAD.ipynb ├── GRUAutoencoder.ipynb ├── LICENSE ├── README.md ├── Youtube-GraphTheory-EdgeWeightedEx.sagews ├── allstate_pca.ipynb ├── audio_rnn_basic.ipynb ├── c++ └── pytorch_dcgan_tutorial │ ├── CMakeLists.txt │ └── dcgan.cpp ├── char-rnn.ipynb ├── convolution1d_tests.ipynb ├── coursera-algos-knapsack-ps.ipynb ├── coursera-algos-twosum.ipynb ├── data ├── coursera_algos │ ├── algos-2sums.txt.zip │ ├── kargerMinCut.txt │ ├── knapsack1.txt │ └── knapsack_big.txt └── tmp.zip ├── databricks_python_graphframes_GraphExample.ipynb ├── databricks_python_numpy_spark_kMeansExample.ipynb ├── denoising_autoencoder.ipynb ├── imdb_tokenize.ipynb ├── logistic-regression-pure-julia.ipynb ├── macro-hw3-2.ipynb ├── pymc-test.ipynb ├── pytorch_attention_audio.py ├── pytorch_basics.ipynb ├── pytorch_embedding_test.ipynb ├── pytorch_tutorial_classify_names.ipynb ├── rnn_autoencoder.ipynb ├── tedlium_labels.ipynb ├── test_VarLenDataset.ipynb ├── test_pdb_debugger.ipynb ├── test_tqdm.ipynb └── timeseries.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /CNN2RNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import librosa \n", 13 | "import torch\n", 14 | "from torch import nn\n", 15 | "from torch.autograd import Variable\n", 16 | "import torch.nn.functional as F\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "%matplotlib inline" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "(78000,) 16000 (74000,) 16000\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "files = librosa.util.find_files(\"pcsnpny-20150204-mkj/wav\")\n", 36 | "file = files[0]\n", 37 | "sig1, sr1 = librosa.core.load(file, sr=None)\n", 38 | "sig2, sr2 = librosa.core.load(files[1], sr=None)\n", 39 | "print(sig1.shape, sr1, sig2.shape, sr2)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "torch.Size([1, 1, 78000])\n", 52 | "input size: torch.Size([1, 1, 78000])\n", 53 | "conv1d out: torch.Size([1, 66, 242])\n", 54 | "conv2d out: torch.Size([1, 64, 64, 240])\n", 55 | "gru in: torch.Size([1, 240, 4096])\n", 56 | "gru out: torch.Size([1, 240, 1024])\n", 57 | "dense in: torch.Size([240, 1024])\n", 58 | "torch.Size([1, 240])\n", 59 | "input size: torch.Size([1, 1, 74000])\n", 60 | "conv1d out: torch.Size([1, 66, 230])\n", 61 | "conv2d out: torch.Size([1, 64, 64, 228])\n", 62 | "gru in: torch.Size([1, 228, 4096])\n", 63 | "gru out: torch.Size([1, 228, 1024])\n", 64 | "dense in: torch.Size([228, 1024])\n", 65 | "torch.Size([1, 228])\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "class CNN2RNN(nn.Module):\n", 71 | " def __init__(self, conv_in_channels, conv_out_features, ws, hs, rnn_hidden_size, rnn_output_size):\n", 72 | " super(CNN2RNN, self).__init__()\n", 73 | " self.cin_channels = conv_in_channels\n", 74 | " self.cout_features = conv_out_features\n", 75 | " self.rnn_hid = rnn_hidden_size\n", 76 | " self.rnn_out = rnn_output_size\n", 77 | " \n", 78 | " \n", 79 | " # hard coding vars so I know what they are.\n", 80 | " n_layers = 2 # number of layers of RNN\n", 81 | " batch_size = 1\n", 82 | " kernel2d = 3\n", 83 | " # hidden initialization\n", 84 | " self.hidden = self.init_hidden(n_layers, batch_size, rnn_hidden_size)\n", 85 | " \n", 86 | " # net layer types\n", 87 | " self.c1 = nn.Conv1d(conv_in_channels, conv_out_features+kernel2d-1, ws, stride=hs)\n", 88 | " self.c2 = nn.Conv2d(1, conv_out_features, kernel2d)\n", 89 | " self.gru = nn.GRU(conv_out_features*conv_out_features, rnn_hidden_size, n_layers, \n", 90 | " batch_first=True, bidirectional=False)\n", 91 | " self.dense = nn.Linear(rnn_hidden_size, 1)\n", 92 | " \n", 93 | " def forward(self, input):\n", 94 | " print(\"input size: {}\".format(input.size()))\n", 95 | " conv_out = self.c1(input)\n", 96 | " print(\"conv1d out: {}\".format(conv_out.size()))\n", 97 | " conv2_out = self.c2(conv_out.unsqueeze(0))\n", 98 | " print(\"conv2d out: {}\".format(conv2_out.size()))\n", 99 | " gru_in = conv2_out.view(input.size(0), -1, self.cout_features * self.cout_features)\n", 100 | " print(\"gru in: {}\".format(gru_in.size()))\n", 101 | " gru_out, self.hidden = self.gru(gru_in, self.hidden)\n", 102 | " print(\"gru out: {}\".format(gru_out.size()))\n", 103 | " dense_in = gru_out.view(gru_in.size(1)*input.size(0) ,-1)\n", 104 | " print(\"dense in: {}\".format(dense_in.size()))\n", 105 | " out_space = self.dense(dense_in)\n", 106 | " out = F.sigmoid(out_space)\n", 107 | " out = out.view(input.size(0), -1)\n", 108 | " return(out)\n", 109 | " \n", 110 | " def init_hidden(self, nl, bat_dim, hid_dim):\n", 111 | " # The axes: (num_layers, minibatch_size, hidden_dim)\n", 112 | " # see docs\n", 113 | " return (Variable(torch.zeros(nl, bat_dim, hid_dim)))\n", 114 | "\n", 115 | "ws=640\n", 116 | "hs=ws//2\n", 117 | "nb=64\n", 118 | "\n", 119 | "net = CNN2RNN(1, nb, ws, hs, 1024, 2)\n", 120 | "inputs1 = torch.Tensor(sig1)\n", 121 | "inputs1.unsqueeze_(0)\n", 122 | "inputs1.unsqueeze_(0)\n", 123 | "print(net(Variable(inputs1)).size())\n", 124 | "inputs2 = torch.Tensor(sig2)\n", 125 | "inputs2.unsqueeze_(0)\n", 126 | "inputs2.unsqueeze_(0)\n", 127 | "print(net(Variable(inputs2)).size())\n" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": { 134 | "collapsed": true 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "torch" 139 | ] 140 | } 141 | ], 142 | "metadata": { 143 | "kernelspec": { 144 | "display_name": "Python 3", 145 | "language": "python", 146 | "name": "python3" 147 | }, 148 | "language_info": { 149 | "codemirror_mode": { 150 | "name": "ipython", 151 | "version": 3 152 | }, 153 | "file_extension": ".py", 154 | "mimetype": "text/x-python", 155 | "name": "python", 156 | "nbconvert_exporter": "python", 157 | "pygments_lexer": "ipython3", 158 | "version": "3.6.1" 159 | } 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 2 163 | } 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 David Pollack 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Programming Notebooks 2 | 3 | ## Python 4 | 5 | In the summer of 2016, I spent some time with various free online resources to learn 6 | graph theory and basic algorithm programming. Most of these notebooks were written 7 | for various assignments in these courses. Not all of my notebooks are here. 8 | 9 | ### List of Courses 10 | 1. [Sarada Herke's Graph Theory](https://www.youtube.com/user/DrSaradaHerke) 11 | 2. [Coursera - Social and Economic Networks: Models and Analysis](https://www.coursera.org/learn/social-economic-networks/home/welcome) 12 | 3. [Coursera - Algorithms: Design and Analysis, Part 1](https://www.coursera.org/learn/algorithm-design-analysis/home/welcome) 13 | 4. [Coursera - Algorithms: Design and Analysis, Part 2](https://www.coursera.org/learn/algorithm-design-analysis-2/home/welcome) 14 | 5. [Coursera - Probabilistic Graphical Models 1: Representation](https://www.coursera.org/learn/probabilistic-graphical-models/home/welcome) 15 | -------------------------------------------------------------------------------- /Youtube-GraphTheory-EdgeWeightedEx.sagews: -------------------------------------------------------------------------------- 1 | ︠2e450aa8-978f-4dd7-bef5-643f502926e8s︠ 2 | import numpy as np 3 | a = [0, 3, 1, 2, 0, 0] 4 | b = [3, 0, 0, 0, 0, 0] 5 | c = [1, 0, 0, 0, 4, 2] 6 | d = [2, 0, 0, 0, 1, 0] 7 | e = [0, 0, 4, 1, 0, 0] 8 | f = [0, 0, 2, 0, 0, 0] 9 | A = Matrix([a, b, c, d, e, f]) 10 | G = Graph(A, weighted=True) 11 | 12 | def dijkstras_algo(G, u, S = [], t = []): 13 | byw = G.weighted() 14 | if not S: 15 | # initial empty list 16 | S.append(u) 17 | t = np.zeros(G.order()) 18 | t_max = 1 + np.sum(G.weighted_adjacency_matrix()) / 2 19 | for i in range(len(t)): 20 | if i != u: 21 | t[i] = t_max 22 | neighborhood = list(set(G.neighbors(u)) - set(S)) 23 | for v in neighborhood: 24 | if v not in S: 25 | t[v] = min(t[v], t[u] + G.distance(u,v,by_weight=byw)) 26 | left = list(set(G.vertices()) - set(S)) 27 | w = left[np.argmin(t[left])] 28 | S.append(w) 29 | if set(G.vertices()) != set(S): 30 | dijkstras_algo(G, w, S, t) 31 | return S, t 32 | 33 | S, t = dijkstras_algo(G, 0) 34 | print S, t 35 | 36 | ︡9ffef417-fe72-4731-adbf-f9f2fd917223︡{"stdout":"[0, 2, 3, 1, 4, 5] [ 0. 3. 1. 2. 3. 3.]\n"}︡{"done":true}︡ 37 | ︠2a091a46-11a3-4c7c-ae57-e523ef4d7dd4s︠ 38 | a = [0,2,3,0,0,0] 39 | b = [2,0,2,1,3,3] 40 | c = [3,2,0,0,1,0] 41 | d = [0,1,0,0,2,1] 42 | e = [0,3,1,2,0,2] 43 | f = [0,3,0,1,2,0] 44 | A = Matrix([a,b,c,d,e,f]) 45 | G = Graph(A, weighted=True) 46 | S, t = dijkstras_algo(G, 0, S = [], t = []) 47 | print S, t 48 | 49 | ︡8eaafbcd-1aeb-45de-a972-fc33646cade1︡{"stdout":"[0, 1, 2, 3, 4, 5] [ 0. 2. 3. 3. 4. 4.]\n"}︡{"done":true}︡ 50 | ︠1e4eee2a-d583-4206-a288-aa6d2f21c0d2s︠ 51 | x = 4 52 | 53 | a = [0, 3, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0] 54 | b = [3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] 55 | c = [1, 0, 0, 4, 0, 0, 7, 0, 0, 0, 0, 0, 0] 56 | d = [0, 1, 4, 0, 0, 2, 0, 6, 0, 0, 0, 0, 0] 57 | e = [5, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0] 58 | f = [0, 0, 0, 2, 2, 0, 1, x, 0, 0, 0, 1, 0] 59 | g = [0, 0, 7, 0, 0, 1, 0, 2, 0, 6, 0, 0, 0] 60 | h = [0, 0, 0, 6, 0, x, 2, 0, 1, 0, 0, 0, 0] 61 | i = [0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 7, 0, 1] 62 | j = [0, 0, 0, 0, 0, 0, 6, 0, 2, 0, 4, 0, 0] 63 | k = [0, 0, 0, 0, 0, 0, 0, 0, 7, 4, 0, 3, 0] 64 | l = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 4] 65 | m = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 4, 0] 66 | A=Matrix([a, b, c, d, e, f, g, h, i, j, k, l, m]) 67 | G = Graph(A, weighted=True) 68 | 69 | print G.shortest_path(0,9, by_weight=True) 70 | print G.shortest_path(0,9, by_weight=False) 71 | print G.distance(0,9, by_weight=True) 72 | print G.distance(0,9, by_weight=False) 73 | 74 | S, t = dijkstras_algo(G, 0, S = [], t = []) 75 | print S, t 76 | ︡8fd014ff-0088-45a8-8475-cc59e2d270b0︡{"stdout":"[0, 1, 3, 5, 6, 7, 8, 9]\n"}︡{"stdout":"[0, 2, 6, 9]\n"}︡{"stdout":"12\n"}︡{"stdout":"3\n"}︡{"stdout":"[0, 2, 1, 3, 4, 5, 6, 11, 7, 8, 10, 12, 9] [ 0. 3. 1. 4. 5. 6. 7. 9. 10. 12. 10. 7. 11.]\n"}︡{"done":true}︡ 77 | ︠9ba6667e-18b2-40ba-9dc5-5123e35dd7d4s︠ 78 | a = [0,7,4,5,0,0,0,0,0] 79 | b = [7,0,2,0,25,0,0,0,0] 80 | c = [4,2,0,0,0,0,0,9,0] 81 | d = [5,0,0,0,0,9,0,0,0] 82 | e = [0,25,0,0,0,0,10,0,0] 83 | f = [0,0,0,9,0,0,0,20,0] 84 | g = [0,0,0,0,10,0,0,0,2] 85 | h = [0,0,9,0,0,20,0,0,3] 86 | i = [0,0,0,0,0,0,2,3,0] 87 | A = Matrix([a,b,c,d,e,f,g,h,i]) 88 | G = Graph(A, weighted=True) 89 | G.plot() 90 | 91 | S, t = dijkstras_algo(G, 0, S = [], t = []) 92 | print S, t 93 | ︡89a52e87-0c4f-4ba7-8c4f-a943ed0d913a︡{"file":{"filename":"/projects/e4d6c2ac-d4ac-4792-9987-c8fef33777ee/.sage/temp/compute1-us/12429/tmp_7Y2ieD.svg","show":true,"text":null,"uuid":"c5d4a60b-592f-4e49-9116-e3eadb1a9287"},"once":false}︡{"html":"
"}︡{"stdout":"[0, 2, 3, 1, 7, 5, 8, 6, 4] [ 0. 6. 4. 5. 28. 14. 18. 13. 16.]\n"}︡{"done":true}︡ 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /audio_rnn_basic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import librosa" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 3, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "(78000,) 16000 (74000,) 16000\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "# load audio\n", 30 | "audio_manifest = librosa.util.find_files(\"pcsnpny-20150204-mkj\")\n", 31 | "sig1, sr1 = librosa.core.load(audio_manifest[0], sr=None)\n", 32 | "sig2, sr2 = librosa.core.load(audio_manifest[1], sr=None)\n", 33 | "print(sig1.shape, sr1, sig2.shape, sr2)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 8, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "(128, 153) (128, 145) 0.0 0.0\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "win_len = 1024\n", 51 | "hop_len = win_len // 2\n", 52 | "gram1 = librosa.feature.melspectrogram(sig1, sr=sr1, n_fft=win_len, hop_length=hop_len)\n", 53 | "gram2 = librosa.feature.melspectrogram(sig2, sr=sr2, n_fft=win_len, hop_length=hop_len)\n", 54 | "\n", 55 | "gram1 = librosa.power_to_db(gram1, ref=np.max)\n", 56 | "gram2 = librosa.power_to_db(gram2, ref=np.max)\n", 57 | "\n", 58 | "gram1 -= gram1.min()\n", 59 | "gram2 -= gram2.min()\n", 60 | "\n", 61 | "print(gram1.shape, gram2.shape, gram1.min(), gram2.min())" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 56, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "[153, 145] torch.Size([2, 153, 128]) torch.Size([153, 128])\n", 74 | "pack size: torch.Size([298, 128])\n", 75 | "h0: torch.Size([3, 2, 20])\n", 76 | "out size: torch.Size([298, 20]) torch.Size([3, 2, 20])\n", 77 | "torch.Size([2, 153, 20]) [153, 145]\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "import torch\n", 83 | "import torch.nn as nn\n", 84 | "from torch.autograd import Variable\n", 85 | "\n", 86 | "BATCH_SIZE = 2\n", 87 | "MAX_LENGTH = max([g.shape[1] for g in [gram1, gram2]])\n", 88 | "HIDDEN_SIZE = 20\n", 89 | "N_LAYERS = 3\n", 90 | "\n", 91 | "def pad2d(t, length):\n", 92 | " if t.size(1) == length:\n", 93 | " return(t)\n", 94 | " else:\n", 95 | " return torch.cat((t, t.new(t.size(0), length - t.size(1)).zero_()),1)\n", 96 | "\n", 97 | " \n", 98 | "seq_lens = [g.shape[1] for g in [gram1, gram2]]\n", 99 | "batch_in = [pad2d(torch.Tensor(g), MAX_LENGTH) for g in [gram1, gram2]]\n", 100 | "batch_in = torch.stack(batch_in).transpose(1,2)\n", 101 | "print(seq_lens, batch_in.size(), in_size)\n", 102 | "batch_in = Variable(batch_in)\n", 103 | "\n", 104 | "pack = torch.nn.utils.rnn.pack_padded_sequence(batch_in, seq_lens, batch_first=True)\n", 105 | "print(\"pack size:\", pack.data.size())\n", 106 | "\n", 107 | "rnn = nn.GRU(128, HIDDEN_SIZE, N_LAYERS, batch_first=True)\n", 108 | "h0 = Variable(torch.randn(N_LAYERS, BATCH_SIZE, HIDDEN_SIZE))\n", 109 | "print(\"h0:\", h0.size())\n", 110 | "# forward\n", 111 | "out, hidden_new = rnn(pack, h0)\n", 112 | "\n", 113 | "print(\"out size:\", out.data.size(), hidden_new.size())\n", 114 | "\n", 115 | "unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)\n", 116 | "\n", 117 | "print(unpacked.size(), unpacked_len)\n" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 55, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "torch.utils.data.dataset.TensorDataset" 129 | ] 130 | }, 131 | "execution_count": 55, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "from torch.utils.data import TensorDataset, DataLoader\n", 138 | "DataLoader()" 139 | ] 140 | } 141 | ], 142 | "metadata": { 143 | "kernelspec": { 144 | "display_name": "Python 3", 145 | "language": "python", 146 | "name": "python3" 147 | }, 148 | "language_info": { 149 | "codemirror_mode": { 150 | "name": "ipython", 151 | "version": 3 152 | }, 153 | "file_extension": ".py", 154 | "mimetype": "text/x-python", 155 | "name": "python", 156 | "nbconvert_exporter": "python", 157 | "pygments_lexer": "ipython3", 158 | "version": "3.6.1" 159 | } 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 2 163 | } 164 | -------------------------------------------------------------------------------- /c++/pytorch_dcgan_tutorial/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) 2 | project(dcgan) 3 | 4 | find_package(Torch REQUIRED) 5 | 6 | add_executable(dcgan dcgan.cpp) 7 | target_link_libraries(dcgan "${TORCH_LIBRARIES}") 8 | set_property(TARGET dcgan PROPERTY CXX_STANDARD 14) 9 | -------------------------------------------------------------------------------- /c++/pytorch_dcgan_tutorial/dcgan.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | /* 5 | struct Net : torch::nn::Module { 6 | Net(int64_t N, int64_t M) 7 | : linear(register_module("linear", torch::nn::Linear(N, M))) { 8 | another_input = register_parameter("b", torch::randn(M)); 9 | } 10 | torch::Tensor forward(torch::Tensor input) { 11 | return linear(input) + another_input; 12 | } 13 | torch::nn::Linear linear; 14 | torch::Tensor another_input; 15 | }; 16 | */ 17 | 18 | using namespace torch; 19 | 20 | int main() { 21 | string MNIST_path("/home/david/Programming/data/MNIST"); 22 | 23 | nn::Sequential generator( 24 | // Layer 1 25 | nn::Conv2d( 26 | nn::Conv2dOptions(100, 256, 4).with_bias(false).transposed(true)), 27 | nn::BatchNorm(256), nn::Functional(torch::relu), 28 | // Layer 2 29 | nn::Conv2d(nn::Conv2dOptions(256, 128, 3) 30 | .stride(2) 31 | .padding(1) 32 | .with_bias(false) 33 | .transposed(true)), 34 | nn::BatchNorm(128), nn::Functional(torch::relu), 35 | // Layer 3 36 | nn::Conv2d(nn::Conv2dOptions(128, 64, 4) 37 | .stride(2) 38 | .padding(1) 39 | .with_bias(false) 40 | .transposed(true)), 41 | nn::BatchNorm(64), nn::Functional(torch::relu), 42 | // Layer 4 43 | nn::Conv2d(nn::Conv2dOptions(64, 1, 4) 44 | .stride(2) 45 | .padding(1) 46 | .with_bias(false) 47 | .transposed(true)), 48 | nn::Functional(torch::tanh)); 49 | 50 | nn::Sequential discriminator( 51 | // Layer 1 52 | nn::Conv2d( 53 | nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)), 54 | nn::Functional(torch::leaky_relu, 0.2), 55 | // Layer 2 56 | nn::Conv2d(nn::Conv2dOptions(64, 128, 4) 57 | .stride(2) 58 | .padding(1) 59 | .with_bias(false)), 60 | nn::BatchNorm(128), nn::Functional(torch::leaky_relu, 0.2), 61 | // Layer 3 62 | nn::Conv2d(nn::Conv2dOptions(128, 256, 4) 63 | .stride(2) 64 | .padding(1) 65 | .with_bias(false)), 66 | nn::BatchNorm(256), nn::Functional(torch::leaky_relu, 0.2), 67 | // Layer 4 68 | nn::Conv2d( 69 | nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)), 70 | nn::Functional(torch::sigmoid)); 71 | 72 | auto dataset = torch::data::datasets::MNIST(MNIST_path.c_str()) 73 | .map(torch::data::transforms::Normalize<>(0.5, 0.5)) 74 | .map(torch::data::transforms::Stack<>()); 75 | 76 | auto data_loader = torch::data::make_data_loader( 77 | std::move(dataset), 78 | torch::data::DataLoaderOptions().batch_size(64).workers(2)); 79 | 80 | torch::optim::Adam gen_opt(generator->parameters(), 81 | torch::optim::AdamOptions(2e-4).beta1(0.5)); 82 | torch::optim::Adam dis_opt(discriminator->parameters(), 83 | torch::optim::AdamOptions(5e-4).beta1(0.5)); 84 | int64_t kNumEpochs = 100; 85 | for (int64_t epoch = 1; epoch <= kNumEpochs; ++epoch) { 86 | int64_t batch_index = 0; 87 | for (torch::data::Example<> &batch : *data_loader) { 88 | discriminator->zero_grad(); 89 | torch::Tensor real_images = batch.data; 90 | torch::Tensor real_labels = 91 | torch::empty(batch.data.size(0)).uniform_(0.8, 1.0); 92 | torch::Tensor real_output = discriminator->forward(real_images); 93 | torch::Tensor d_loss_real = 94 | torch::binary_cross_entropy(real_output, real_labels); 95 | d_loss_real.backward(); 96 | 97 | torch::Tensor noise = torch::randn({batch.data.size(0), 100, 1, 1}); 98 | torch::Tensor fake_images = generator->forward(noise); 99 | torch::Tensor fake_labels = torch::zeros(batch.data.size(0)); 100 | torch::Tensor fake_output = 101 | discriminator->forward(fake_images.detach()); 102 | torch::Tensor d_loss_fake = 103 | torch::binary_cross_entropy(fake_output, fake_labels); 104 | d_loss_fake.backward(); 105 | 106 | torch::Tensor d_loss = d_loss_real + d_loss_fake; 107 | dis_opt.step(); 108 | 109 | generator->zero_grad(); 110 | fake_labels.fill_(1); 111 | fake_output = discriminator->forward(fake_images); 112 | torch::Tensor g_loss = 113 | torch::binary_cross_entropy(fake_output, fake_labels); 114 | g_loss.backward(); 115 | gen_opt.step(); 116 | 117 | std::printf("\r[%5d/%5d][%5d/%5d] D_loss: %.4f | G_loss: %.4f", 118 | epoch, kNumEpochs, ++batch_index, 938, 119 | d_loss.item(), g_loss.item()); 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /char-rnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "import torch.nn.functional as F\n", 14 | "import torch.optim as optim\n", 15 | "import torch.utils.data as data\n", 16 | "from torch.autograd import Variable\n", 17 | "\n", 18 | "from tqdm import tnrange, tqdm_notebook, tqdm" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 28, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "class SimpleGRU(nn.Module):\n", 28 | " def __init__(self, vocab_size, emb_size, hid_size, batch_size, seq_len, n_layers=1):\n", 29 | " super(SimpleGRU, self).__init__()\n", 30 | " self.vocab_size = vocab_size\n", 31 | " self.emb_size = emb_size\n", 32 | " self.hid_size = hid_size\n", 33 | " self.n_layers = n_layers\n", 34 | " self.batch_size = batch_size\n", 35 | " self.seq_len = seq_len\n", 36 | " self.emb = nn.Embedding(vocab_size, emb_size)\n", 37 | " self.gru = nn.GRU(emb_size, hid_size, batch_first=True)\n", 38 | " self.fc1 = nn.Linear(seq_len * hid_size, vocab_size)\n", 39 | " self.selu = nn.SELU()\n", 40 | " self.logsoftmax = nn.LogSoftmax()\n", 41 | " def forward(self, input, hidden):\n", 42 | " x = self.emb(input)\n", 43 | " x, hidden = self.gru(x, hidden)\n", 44 | " x = x.contiguous().view(self.batch_size, -1)\n", 45 | " x = self.selu(self.fc1(x))\n", 46 | " x = self.logsoftmax(x)\n", 47 | " return x, hidden\n", 48 | "\n", 49 | "class CharDataset(data.Dataset):\n", 50 | " def __init__(self, data, seq_len):\n", 51 | " self.data = data\n", 52 | " self.seq_len = seq_len\n", 53 | " def __getitem__(self, index):\n", 54 | " inp_seq = self.data[index:(index+self.seq_len-1)]\n", 55 | " tgt_seq = torch.Tensor([self.data[index+self.seq_len]]).type(self.data.type())\n", 56 | " return inp_seq, tgt_seq\n", 57 | " def __len__(self):\n", 58 | " return len(self.data) - self.seq_len\n", 59 | " \n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 22, 65 | "metadata": { 66 | "collapsed": true 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "seq_length = 25\n", 71 | "batch_size = 50\n", 72 | "\n", 73 | "emb_size = 25\n", 74 | "hid_size = 100\n", 75 | "n_layers = 3" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 23, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "1115394 1115394\n", 88 | "torch.Size([1115394])\n", 89 | "22307\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "with open(\"/home/david/Programming/data/project_gutenberg/tiny-shakespeare.txt\", \"r\") as f:\n", 95 | " text_raw = [c for l in f.readlines() for c in l]\n", 96 | " #dangling_length = len(text_raw) % seq_length\n", 97 | " #text_raw = text_raw[:-dangling_length]\n", 98 | "\n", 99 | "charset = sorted(list(set(text_raw)))\n", 100 | "c2i = {c: i for i, c in enumerate(charset)}\n", 101 | "i2c = {i: c for c, i in c2i.items()}\n", 102 | "text_idx = [c2i[c] for c in text_raw]\n", 103 | "print(len(text_idx), len(text_raw))\n", 104 | "\n", 105 | "#inputs = torch.Tensor([x for x in zip(*[text_idx[i::seq_length] for i in range(seq_length-1)])]).long()\n", 106 | "#targets = torch.Tensor(text_idx[(seq_length-1)::seq_length]).long()\n", 107 | "inputs = torch.Tensor(text_idx).long()\n", 108 | "print(inputs.size())\n", 109 | "\n", 110 | "#ds = data.TensorDataset(inputs, targets)\n", 111 | "ds = CharDataset(inputs, seq_length)\n", 112 | "dl = data.DataLoader(ds, batch_size=batch_size, drop_last=True)\n", 113 | "print(len(dl))" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 24, 119 | "metadata": { 120 | "collapsed": true 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "vocab_size = len(charset)\n", 125 | "num_batches = len(dl)\n", 126 | "epochs = 10" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 31, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "SimpleGRU (\n", 139 | " (emb): Embedding(65, 25)\n", 140 | " (gru): GRU(25, 100, batch_first=True)\n", 141 | " (fc1): Linear (2400 -> 65)\n", 142 | " (selu): SELU\n", 143 | " (logsoftmax): LogSoftmax ()\n", 144 | ")\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "model = SimpleGRU(vocab_size, emb_size, hid_size, batch_size, seq_length-1, n_layers)\n", 150 | "criterion = nn.NLLLoss()\n", 151 | "optimizer = optim.Adam(model.parameters(), lr=0.005)\n", 152 | "print(model)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 32, 158 | "metadata": { 159 | "scrolled": false 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "application/vnd.jupyter.widget-view+json": { 165 | "model_id": "a34737486b714652b0ca1bebbbd82500" 166 | } 167 | }, 168 | "metadata": {}, 169 | "output_type": "display_data" 170 | }, 171 | { 172 | "ename": "KeyboardInterrupt", 173 | "evalue": "", 174 | "output_type": "error", 175 | "traceback": [ 176 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 177 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 178 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtgts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 179 | "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \"\"\"\n\u001b[0;32m--> 156\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_variables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 180 | "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(variables, grad_variables, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m Variable._execution_engine.run_backward(\n\u001b[0;32m---> 98\u001b[0;31m variables, grad_variables, retain_graph)\n\u001b[0m\u001b[1;32m 99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 181 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "batch_bar = tqdm_notebook(dl, desc=\"batches\", mininterval=0.9)\n", 187 | "\n", 188 | "for epoch in range(epochs):\n", 189 | " running_loss = 0\n", 190 | " for i, (mb, tgts) in enumerate(batch_bar):\n", 191 | " h = Variable(torch.zeros(n_layers,batch_size, hid_size))\n", 192 | " tgts.squeeze_()\n", 193 | " model.train()\n", 194 | " model.zero_grad()\n", 195 | " mb, tgts = Variable(mb), Variable(tgts)\n", 196 | " out, h = model(mb, h)\n", 197 | " loss = criterion(out, tgts)\n", 198 | " loss.backward()\n", 199 | " optimizer.step()\n", 200 | " h.detach_()\n", 201 | " running_loss += loss.data[0]\n", 202 | " if i % 25 == 0 or i == num_batches - 1:\n", 203 | " batch_bar.set_postfix(loss=(running_loss / (i+1)))\n", 204 | " torch.save(model.state_dict(), \"model_charrnn_{}.pt\".format(epoch+1))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "collapsed": true 212 | }, 213 | "outputs": [], 214 | "source": [] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python [default]", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.6.1" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 2 238 | } 239 | -------------------------------------------------------------------------------- /convolution1d_tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "import torch.nn.functional as F\n", 14 | "from torch.autograd import Variable\n", 15 | "import torchaudio" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 13, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "data": { 25 | "text/plain": [ 26 | "[Variable containing:\n", 27 | " (0 ,.,.) = \n", 28 | " \n", 29 | " Columns 0 to 18 \n", 30 | " 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n", 31 | " \n", 32 | " Columns 19 to 37 \n", 33 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37\n", 34 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n", 35 | " (0 ,.,.) = \n", 36 | " \n", 37 | " Columns 0 to 18 \n", 38 | " 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n", 39 | " \n", 40 | " Columns 19 to 37 \n", 41 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 24\n", 42 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n", 43 | " (0 ,.,.) = \n", 44 | " \n", 45 | " Columns 0 to 18 \n", 46 | " 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n", 47 | " \n", 48 | " Columns 19 to 37 \n", 49 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 30 23 19\n", 50 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n", 51 | " (0 ,.,.) = \n", 52 | " \n", 53 | " Columns 0 to 18 \n", 54 | " 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n", 55 | " \n", 56 | " Columns 19 to 37 \n", 57 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 30 28 28 20 18 17\n", 58 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n", 59 | " (0 ,.,.) = \n", 60 | " \n", 61 | " Columns 0 to 18 \n", 62 | " 1 2 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n", 63 | " \n", 64 | " Columns 19 to 37 \n", 65 | " 19 20 21 22 23 24 25 26 27 27 27 28 26 25 24 19 17 16 15\n", 66 | " [torch.FloatTensor of size 1x1x38]]" 67 | ] 68 | }, 69 | "execution_count": 13, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "X = torch.arange(0, 38).view(1, 1, -1)\n", 76 | "X = Variable(X)\n", 77 | "ksize = 3\n", 78 | "conv1d1_3dN = []\n", 79 | "pad1dN = []\n", 80 | "for i in range(4):\n", 81 | " conv1d1_3dN += [nn.Conv1d(in_channels=1, out_channels=1, stride=1, padding=0, kernel_size=ksize, dilation=(i+1), bias=False)]\n", 82 | " pad1dN += [nn.ConstantPad1d(((i+1), (i+1)), 0)]\n", 83 | "outs = [X]\n", 84 | "for i, cv in enumerate(zip(conv1d1_3dN, pad1dN)):\n", 85 | " cv[0].weight.data.fill_(1.0 / ksize)\n", 86 | " model = nn.Sequential(cv[1], cv[0])\n", 87 | " #model = cv[0]\n", 88 | " out = model(outs[i])\n", 89 | " outs.append(out.long().float())\n", 90 | "outs" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "collapsed": true 105 | }, 106 | "outputs": [], 107 | "source": [] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": "Python [conda env:pytorch-dev]", 113 | "language": "python", 114 | "name": "conda-env-pytorch-dev-py" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 3 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython3", 126 | "version": "3.6.2" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 2 131 | } 132 | -------------------------------------------------------------------------------- /coursera-algos-twosum.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import scipy.stats\n", 13 | "import scipy\n", 14 | "import bisect\n", 15 | "\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": { 22 | "collapsed": false 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "#raw1 = np.loadtxt(\"2sum.txt\", dtype=int)\n", 27 | "#np.savez_compressed(\"2sum.txt.npz\", data=raw1)\n", 28 | "with np.load('2sum.txt.npz') as data:\n", 29 | " raw1 = data['data']\n", 30 | "raw1.shape\n", 31 | "U = raw1" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": { 38 | "collapsed": false 39 | }, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "-1.1990247437581012" 45 | ] 46 | }, 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | } 51 | ], 52 | "source": [ 53 | "scipy.stats.kurtosis(U)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": { 60 | "collapsed": false 61 | }, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "Mean: 1322010.15214\n", 68 | "Standard Deviation: 57723915587.6\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "U_mean = scipy.mean(U)\n", 74 | "U_std = scipy.std(U)\n", 75 | "\n", 76 | "print \"Mean:\", U_mean\n", 77 | "print \"Standard Deviation:\", U_std" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": { 84 | "collapsed": false 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "U.sort()" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 80, 94 | "metadata": { 95 | "collapsed": false 96 | }, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "set([8195, 5308, 3445])\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "# this function takes advantage of the fact the numbers are uniformly distributed\n", 108 | "# thus the answer should simple be on the \"opposite\" side of the distribution\n", 109 | "# the function searches the other side of the distribution for a range of possible \n", 110 | "# y indexes for a given x index\n", 111 | "def find_yrange(U, i, i_mag, mag_fac = 100, mag_zoom = 20):\n", 112 | " # U: universe of numbers\n", 113 | " # i: x index\n", 114 | " # i_mag: y index search median\n", 115 | " # mag_fac: number of indices between range checks\n", 116 | " global t_bounds\n", 117 | " x_i = U[i]\n", 118 | " y_indices = [mag_fac*k + i_mag for k in range(-10, 11) if -1*(mag_fac*k + i_mag) > 0 and -1*(mag_fac*k + i_mag) < len(U)]\n", 119 | " y = U[y_indices]\n", 120 | " t = x_i + y\n", 121 | " t_min = t[0]\n", 122 | " t_max = t[-1]\n", 123 | " # check if the max or min is outside of the bounds\n", 124 | " if t_max < t_bounds[0] or t_min > t_bounds[1]:\n", 125 | " return None\n", 126 | " l_bound = y_indices[bisect.bisect_left(t, -t_bounds[0])-1]\n", 127 | " r_bound = y_indices[bisect.bisect_right(t, t_bounds[1])]\n", 128 | " m_bound = np.mean([l_bound, r_bound], dtype=int)\n", 129 | " mag_fac /= mag_zoom\n", 130 | " if mag_fac >= 5:\n", 131 | " l_bound, r_bound = find_yrange(U, i, m_bound, mag_fac)\n", 132 | " return l_bound, r_bound\n", 133 | "\n", 134 | "t_bounds = (-10000,10000)\n", 135 | "t_set = set()\n", 136 | "for i in range(1000,1010):\n", 137 | " bounds = find_yrange(U,i, -1*(i+1))\n", 138 | " if bounds:\n", 139 | " t_i = [U[i] + U[i_inv] for i_inv in range(bounds[0], bounds[1]+1) if U[i] + U[i_inv] > t_bounds[0] and U[i] + U[i_inv] < t_bounds[1]]\n", 140 | " if t_i: t_set.update(t_i)\n", 141 | "print t_set\n", 142 | "# set([8195, 5308, 3445])" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 70, 148 | "metadata": { 149 | "collapsed": false 150 | }, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "set([2048, 4097, 6146, 8195, 6145, 8194, 7171, -7637, -2795, 1024, -6613, -2794, 6890, -8104, -8103, -4006, -4005, 92, 2141, 6238, 8287, 8288, -747, 9125, 6891, -5588, -746, 3073, 4096, -9873, -8848, -4751, -2702, -653, -8011, -5962, -5961, -1864, 185, 2234, 2235, 6332, 8381, 2421, 4470, 9591, -3354, 5122, -1492, 2327, -9967, -7918, -7917, -5868, -3819, -1770, 279, 4376, 6425, 8474, 8475, 2328, 3351, -4286, -9127, -9500, -4285, 557, -6427, -4378, -1305, -9874, -7825, -5776, -3727, -3726, 371, 2420, 4469, 6518, 6519, 8568, 558, 3818, 7915, 7916, -2237, -6055, -9780, -7731, -5682, -1585, 464, 2513, 2514, 4563, 8660, 8661, 2607, 7449, -5030, -1211, -188, 3631, 7450, -7080, -5031, -6054, -1957, 1116, -7638, -5589, -3540, -1491, 2606, 4655, 4656, 8753, 8754, -8382, 3166, 6239, 9312, 838, 1395, 5681, 6704, -9594, -7545, -5496, -5495, -1398, 651, 2700, 4749, 8846, -5775, -1956, -6147, -933, 7728, 2887, -8569, -6707, -6706, -3633, -1584, -9501, -7452, -5403, -5402, 1489, 2792, 2793, 4842, 8939, 8940, 4562, -7824, 6611, 9777, 6612, -7544, 1117, 4936, -3541, -1678, 5960, 6983, -9408, -9407, -7358, -5309, -1212, 837, 2886, 4935, 9032, 9033, -1677, 2142, 4189, -7359, 1023, 4190, -5310, -3261, -3260, -9315, -9314, 1861, -5216, -3167, -1118, 931, 2980, 1862, 9126, -3447, 372, 5959, 6984, 8009, -7265, -3446, 1396, 5215, -9222, -7173, -5124, -3075, -3074, -1025, 3072, 5121, 7170, 9219, 9220, 3444, -5217, -1397, -6240, -6986, 650, -3913, -3912, 1209, -9128, -7079, -2982, -2981, -932, 3165, 5214, 7263, 7264, 9313, 5307, 5493, 5308, 7357, -2144, -6985, -2143, 1676, 2699, -9035, -9034, -4937, -2888, -839, 3258, 3259, 7356, 9405, 9406, -4938, 7542, -9779, -9687, -8662, -94, 8567, -4564, -2515, -8942, -8941, -4844, -466, 1302, 1303, 5400, 5401, 9498, 9499, -2889, 930, 3632, -8010, 6705, -7730, -5683, 1955, 5774, 6797, -8849, -6800, -6799, -4750, -2701, -652, 3445, 5494, 7543, 9592, -1863, -840, -7266, -4657, -3634, -6241, 7823, -3168, -1119, -8756, -8755, -4658, -2609, -560, -559, 3538, 2979, 9684, 9685, 186, 6052, 5028, 8101, -4191, 7078, -7451, -6428, 1210, 5029, -2608, -8663, -6614, -4565, -2516, -467, 1582, 1583, 5680, 7729, 9778, 7077, -9966, -6893, 4283, -3820, -1771, 278, -8570, -6521, -6520, -2423, -374, -373, 3724, 5773, 7822, 9871, 9872, 3352, 465, 4284, 4377, 6426, 93, -7172, -3353, -2329, 1490, -8477, -8476, -4379, -2330, -281, 1768, 1769, 5866, 5867, 9964, 9965, -1304, 8380, -9593, -280, 3539, -4472, -4471, -2422, -8383, -6334, 1675, -2236, -187, 3910, 3911, 8008, 4748, 744, 3725, 5586, 6798, 8847, 745, 5587, -6892, -5869, -2050, -8290, -8289, -4192, -95, 1954, 4003, 4004, 6053, 8102, -2049, -1026, 7635, -9686, 3817, 7636, -9221, -4843, -8196, 4841, -5123, -4098, -8197, -6148, -4099, -1, 2047])\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "t_bounds = (-10000,10000)\n", 162 | "t_set = set()\n", 163 | "for i in range(len(U)):\n", 164 | " bounds = find_yrange(U,i, -1*(i+1))\n", 165 | " if bounds:\n", 166 | " t_i = [U[i] + U[i_inv] for i_inv in range(bounds[0], bounds[1]+1) if U[i] + U[i_inv] > t_bounds[0] and U[i] + U[i_inv] < t_bounds[1]]\n", 167 | " if t_i: t_set.update(t_i)\n", 168 | "print t_set\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 66, 174 | "metadata": { 175 | "collapsed": false 176 | }, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "427\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "print len(t_set)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "\n", 195 | "The goal of this problem is to implement the \"Median Maintenance\" algorithm (covered in the Week 5 lecture on heap applications). The text file contains a list of the integers from 1 to 10000 in unsorted order; you should treat this as a stream of numbers, arriving one by one. Letting xi denote the ith number of the file, the kth median mk is defined as the median of the numbers x1,…,xk. (So, if k is odd, then mk is ((k+1)/2)th smallest number among x1,…,xk; if k is even, then mk is the (k/2)th smallest number among x1,…,xk.)\n", 196 | "\n", 197 | "In the box below you should type the sum of these 10000 medians, modulo 10000 (i.e., only the last 4 digits). That is, you should compute (m1+m2+m3+⋯+m10000)mod10000.\n", 198 | "\n", 199 | "OPTIONAL EXERCISE: Compare the performance achieved by heap-based and search-tree-based implementations of the algorithm.\n" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "collapsed": true 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "\n", 211 | "\n" 212 | ] 213 | } 214 | ], 215 | "metadata": { 216 | "kernelspec": { 217 | "display_name": "Python 2 (SageMath)", 218 | "language": "python", 219 | "name": "python2" 220 | }, 221 | "language_info": { 222 | "codemirror_mode": { 223 | "name": "ipython", 224 | "version": 2 225 | }, 226 | "file_extension": ".py", 227 | "mimetype": "text/x-python", 228 | "name": "python", 229 | "nbconvert_exporter": "python", 230 | "pygments_lexer": "ipython2", 231 | "version": "2.7.10" 232 | } 233 | }, 234 | "nbformat": 4, 235 | "nbformat_minor": 0 236 | } 237 | -------------------------------------------------------------------------------- /data/coursera_algos/algos-2sums.txt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhpollack/programming_notebooks/168936ac09071c00a61d56b35467d7d27ec61e60/data/coursera_algos/algos-2sums.txt.zip -------------------------------------------------------------------------------- /data/coursera_algos/kargerMinCut.txt: -------------------------------------------------------------------------------- 1 | 193 26 112 191 62 195 25 91 39 71 158 131 46 5 38 37 185 18 118 124 56 156 58 86 53 108 68 175 87 120 15 42 73 99 4 75 151 64 35 48 2 | 1 37 79 164 155 32 87 39 113 15 18 78 175 140 200 4 160 97 191 100 91 20 69 198 196 3 | 2 123 134 10 141 13 12 43 47 3 177 101 179 77 182 117 116 36 103 51 154 162 128 30 4 | 3 48 123 134 109 41 17 159 49 136 16 130 141 29 176 2 190 66 153 157 70 114 65 173 104 194 54 5 | 4 91 171 118 125 158 76 107 18 73 140 42 193 127 100 84 121 60 81 99 80 150 55 1 35 23 93 6 | 5 193 156 102 118 175 39 124 119 19 99 160 75 20 112 37 23 145 135 146 73 35 7 | 6 155 56 52 120 131 160 124 119 14 196 144 25 75 76 166 35 87 26 20 32 23 8 | 7 156 185 178 79 27 52 144 107 78 22 71 26 31 15 56 76 112 39 8 113 93 9 | 8 185 155 171 178 108 64 164 53 140 25 100 133 9 52 191 46 20 150 144 39 62 131 42 119 127 31 7 10 | 9 91 155 8 160 107 132 195 26 20 133 39 76 100 78 122 127 38 156 191 196 115 11 | 10 190 184 154 49 2 182 173 170 161 47 189 101 153 50 30 109 177 148 179 16 163 116 13 90 185 12 | 11 123 134 163 41 12 28 130 13 101 83 77 109 114 21 82 88 74 24 94 48 33 13 | 12 161 109 169 21 24 36 65 50 2 101 159 148 54 192 88 47 11 142 43 70 182 177 179 189 194 33 14 | 13 161 141 157 44 83 90 181 41 2 176 10 29 116 134 182 170 165 173 190 159 47 82 111 142 72 154 110 21 103 130 11 33 138 152 15 | 14 91 156 58 122 62 113 107 73 137 25 19 40 6 139 150 46 37 76 39 127 16 | 15 149 58 68 52 39 67 121 191 1 45 100 18 118 174 40 85 196 122 42 193 119 139 26 127 145 135 57 38 7 17 | 16 48 10 36 187 43 3 114 173 111 142 129 88 189 117 128 147 141 194 180 106 167 179 66 74 136 51 59 18 | 17 48 123 134 36 163 3 44 117 167 161 152 95 170 83 180 77 65 72 109 47 43 88 159 197 28 194 181 49 19 | 18 193 149 56 62 15 160 67 191 140 52 178 96 107 132 1 145 89 198 4 26 73 151 126 34 115 20 | 19 156 80 178 164 108 84 71 174 40 62 113 22 89 45 91 126 195 144 5 14 172 21 | 20 185 122 171 56 8 52 73 191 67 126 9 119 1 89 79 107 96 31 75 55 5 6 34 23 22 | 21 188 187 12 173 180 197 138 167 63 111 95 13 192 116 94 114 105 49 177 51 130 90 11 50 66 157 176 23 | 22 156 27 32 131 7 56 53 81 149 23 100 146 115 26 175 121 96 75 57 39 119 71 132 19 150 140 93 24 | 23 91 122 124 22 200 195 145 5 69 125 55 68 156 20 58 191 4 57 149 6 25 | 24 123 134 161 163 169 72 116 167 30 33 77 162 143 159 187 63 184 130 28 50 153 12 148 11 53 26 | 25 193 185 79 108 8 158 87 73 81 115 39 64 178 132 27 68 127 84 14 52 200 97 6 93 27 | 26 193 58 27 108 52 144 160 18 84 81 22 75 139 166 15 107 198 131 7 9 133 6 28 | 27 156 139 144 166 112 100 26 174 31 42 75 158 122 81 22 7 58 73 89 115 39 25 200 69 169 29 | 28 134 188 24 184 159 29 72 114 152 116 169 173 141 17 111 61 192 90 11 177 179 77 33 66 83 136 30 | 29 48 134 188 13 47 88 3 82 92 28 194 50 192 189 123 199 177 147 43 106 148 197 77 103 129 181 31 | 30 165 123 10 24 41 187 47 168 92 148 197 101 50 2 179 111 130 77 153 199 70 32 | 31 27 171 56 131 146 139 191 89 20 108 38 71 75 69 196 149 97 8 86 98 7 33 | 32 156 149 171 62 22 185 35 124 56 38 158 97 53 121 160 1 191 58 89 127 87 120 39 99 84 60 151 174 6 34 | 33 48 161 109 141 24 187 47 88 168 183 110 103 95 116 28 12 11 13 83 134 63 35 | 34 37 122 171 118 76 131 166 137 40 46 97 87 80 164 127 18 62 52 20 139 36 | 35 79 164 125 32 107 137 75 121 85 55 69 45 193 132 4 5 200 135 76 139 198 6 37 | 36 165 188 17 106 88 16 177 110 147 154 159 179 136 41 50 141 66 162 152 168 184 12 43 72 180 190 77 2 170 61 122 38 | 37 193 149 39 121 191 115 146 52 127 79 198 58 125 38 34 1 76 89 164 97 86 178 108 87 84 124 98 174 195 14 5 57 196 186 39 | 38 193 37 86 32 76 107 73 85 127 100 46 89 31 57 96 158 99 160 45 15 9 40 | 39 193 37 122 102 8 158 32 87 85 81 200 60 5 27 155 1 58 150 15 113 76 84 22 25 151 139 100 14 145 9 7 41 | 40 91 156 122 79 118 125 52 175 87 15 81 166 132 121 19 14 160 34 78 71 42 | 41 36 169 184 116 163 106 189 11 104 61 30 123 129 111 3 47 49 154 161 152 13 153 65 92 183 177 162 95 54 70 108 43 | 42 178 79 27 53 171 164 102 52 87 113 15 191 131 91 62 193 8 122 89 56 4 127 145 112 44 | 43 165 161 12 70 199 54 17 190 16 153 141 36 47 44 194 110 82 189 2 148 183 29 130 94 170 51 61 59 45 | 44 188 163 169 17 13 43 114 173 142 154 103 129 181 105 157 148 182 101 110 66 176 49 46 | 45 156 80 149 58 178 53 108 68 56 125 15 93 75 135 174 198 81 166 113 100 19 89 35 97 38 47 | 46 193 58 86 122 155 8 175 160 99 127 67 14 150 144 126 146 34 131 55 38 196 48 | 47 123 10 109 41 17 12 43 116 59 33 13 2 187 165 88 117 29 30 176 147 180 101 130 194 50 94 152 70 49 | 48 180 128 188 197 105 51 94 190 116 29 183 114 153 33 16 49 3 63 184 17 141 168 179 162 11 66 83 193 50 | 49 48 123 10 186 141 41 168 3 148 142 179 21 136 109 44 117 17 103 187 74 51 | 50 165 123 10 188 36 169 24 187 12 65 29 167 47 21 134 130 111 168 77 116 138 106 66 30 52 | 51 165 48 109 141 163 159 92 190 143 2 21 138 43 59 192 117 16 184 104 169 53 | 52 37 178 8 6 15 200 133 80 102 96 40 119 164 166 127 151 20 42 7 26 76 18 73 99 78 25 132 139 191 150 34 54 | 53 156 122 102 193 137 133 42 62 45 64 60 78 160 132 155 56 144 131 196 178 125 8 32 113 22 98 121 198 24 55 | 54 123 187 12 43 168 65 129 130 147 95 41 61 59 141 3 138 114 199 66 110 56 | 55 68 56 125 113 73 78 200 46 131 85 107 89 185 60 84 4 35 99 171 71 20 23 57 | 56 193 122 53 108 45 55 18 125 86 171 79 6 85 133 80 140 96 31 107 20 32 87 120 42 73 84 22 112 196 7 58 | 57 79 155 158 76 84 22 75 121 85 191 100 97 15 37 102 69 38 64 195 145 23 59 | 58 15 146 107 193 140 84 144 86 14 150 26 60 46 166 80 87 139 195 126 45 37 27 102 62 32 175 39 124 23 93 188 60 | 59 186 163 187 47 168 92 143 147 157 54 51 61 199 43 103 63 184 16 188 197 61 | 60 58 86 178 53 171 125 39 144 146 73 124 4 93 32 99 158 132 80 91 96 75 174 55 196 62 | 61 188 116 41 114 183 28 77 54 36 163 187 147 179 65 63 190 43 186 59 189 63 | 62 193 185 79 53 171 102 146 84 198 14 58 137 8 166 175 32 113 18 115 196 42 200 132 121 19 112 133 172 34 64 | 63 48 24 116 183 111 167 21 162 90 181 147 105 106 101 33 123 153 184 128 168 61 59 65 | 64 91 149 122 53 8 120 113 124 119 25 132 131 193 121 144 68 185 108 175 107 57 66 | 65 186 17 12 114 82 3 66 159 167 197 50 101 176 94 54 134 41 165 157 194 180 77 103 74 61 67 | 66 188 36 116 3 65 183 72 180 77 16 136 177 82 159 28 95 48 50 187 21 44 54 176 68 | 67 91 156 102 68 175 131 144 15 18 146 119 20 200 46 112 139 75 80 86 137 69 | 68 122 171 55 78 175 96 75 71 93 118 195 115 67 79 45 158 15 120 155 193 87 146 81 25 64 23 70 | 69 91 108 125 131 81 75 85 174 145 89 27 200 35 156 1 98 86 23 191 127 31 57 71 | 70 165 123 163 153 12 43 168 3 114 82 148 190 129 74 176 47 110 181 41 30 72 | 71 193 79 171 68 124 98 120 73 75 93 151 108 89 155 19 96 22 119 7 125 85 127 40 55 140 31 73 | 72 109 169 24 116 17 114 182 190 141 186 192 28 104 13 66 165 162 36 159 77 110 129 130 181 74 | 73 91 80 27 160 4 20 100 56 193 60 38 18 25 172 76 14 55 52 149 96 78 71 195 127 5 112 97 93 75 | 74 88 104 197 111 130 95 11 138 123 77 159 65 179 94 165 70 141 16 153 136 83 49 76 | 75 91 79 27 171 68 158 87 81 22 119 71 166 57 60 35 160 69 175 26 193 45 127 67 126 89 20 5 31 198 6 77 | 76 37 178 175 120 122 57 52 34 4 156 98 115 133 131 171 149 85 38 174 39 73 99 14 140 35 135 172 9 7 6 78 | 77 24 116 17 72 190 110 36 173 143 94 65 29 61 154 2 161 90 66 159 28 128 117 11 50 138 74 30 79 | 78 156 80 53 164 68 131 144 107 73 195 55 175 1 126 7 149 171 81 91 52 124 40 9 80 | 79 80 37 158 40 75 108 198 25 62 42 7 86 57 102 122 145 175 71 1 35 164 68 118 56 144 200 139 20 112 135 172 163 81 | 80 91 144 100 118 45 78 139 112 149 98 113 87 195 124 85 73 119 19 79 99 58 56 52 146 81 4 60 137 67 126 145 97 34 134 82 | 81 86 27 120 39 131 40 132 25 69 96 26 127 108 75 68 135 98 80 115 149 78 22 4 45 172 83 | 82 161 186 109 153 43 159 114 101 104 70 29 197 192 116 148 65 152 184 13 154 92 110 130 11 147 66 176 84 | 83 17 13 153 88 111 90 117 95 11 33 147 28 109 199 48 74 167 182 154 101 85 | 84 149 58 86 178 171 62 175 113 26 135 96 198 39 19 57 144 37 118 32 56 119 4 98 25 200 150 55 86 | 85 156 80 185 155 102 56 158 39 76 15 69 38 150 175 118 137 71 35 120 57 55 135 87 | 86 58 84 120 81 112 37 60 46 185 193 38 125 118 175 149 99 108 126 133 102 79 56 144 67 69 31 109 88 | 87 80 149 58 118 25 75 126 115 68 178 37 200 32 40 164 56 42 193 107 1 39 113 145 89 172 6 34 93 89 | 88 36 116 12 47 165 168 29 83 159 154 74 148 92 176 180 33 110 194 167 17 114 173 16 11 90 | 89 37 27 102 32 42 18 99 71 151 19 55 75 98 131 69 45 31 38 195 87 20 135 115 91 | 90 163 116 13 173 28 110 190 77 129 117 130 10 179 92 134 194 21 63 83 157 94 152 199 92 | 91 126 14 75 135 107 80 102 67 160 158 144 23 64 155 4 198 40 9 73 69 193 185 149 164 42 78 60 98 19 1 165 93 | 92 161 109 163 187 88 186 141 51 117 197 59 173 192 82 29 143 116 41 30 111 148 129 90 194 152 170 94 | 93 68 99 71 60 45 150 145 98 25 87 122 178 185 7 144 4 22 73 58 112 95 | 94 48 163 65 173 182 154 77 21 117 11 147 74 189 43 114 47 177 192 104 90 96 | 95 123 141 169 17 168 114 179 21 33 165 167 177 41 103 83 199 129 182 188 74 66 157 152 54 176 97 | 96 171 164 68 56 52 18 73 84 81 124 22 71 60 126 150 20 97 38 149 158 196 115 98 | 97 185 149 37 178 122 108 118 32 131 1 31 140 73 34 45 25 80 133 57 96 99 | 98 80 171 164 158 120 76 99 81 71 37 91 151 144 172 140 150 160 121 84 53 139 89 69 31 133 93 100 | 99 80 86 178 122 118 125 52 193 32 174 46 113 98 93 89 150 102 76 100 124 4 60 55 5 38 196 101 | 100 80 185 149 27 8 113 131 15 73 99 22 4 200 45 102 172 57 164 38 39 1 112 9 102 | 101 165 10 188 109 141 163 187 12 168 82 65 148 180 47 167 2 162 169 192 30 179 11 44 83 170 63 103 | 102 91 86 79 53 171 155 67 113 39 137 158 196 166 120 42 89 58 5 119 85 62 52 146 99 121 100 57 104 | 103 161 188 186 184 153 159 167 189 2 168 13 29 104 142 33 65 197 117 148 44 95 181 59 49 105 | 104 134 188 186 141 184 41 187 168 114 82 148 192 194 154 74 3 190 161 105 163 72 103 94 170 51 106 | 105 165 48 161 116 159 104 21 129 63 44 169 106 142 190 181 143 179 157 173 188 199 176 107 | 106 36 163 169 184 41 159 29 197 111 16 179 162 105 138 148 50 165 63 173 109 108 | 107 91 58 56 125 158 87 191 133 9 38 122 164 175 135 4 35 14 7 185 78 18 146 151 26 64 126 55 20 112 198 172 109 | 108 86 178 79 164 37 25 195 144 120 198 26 56 19 132 69 193 115 45 172 97 155 8 81 71 166 139 174 64 31 41 110 | 109 10 47 116 12 128 51 167 180 179 33 123 101 142 92 143 199 154 3 72 82 141 17 168 114 173 183 148 189 11 181 106 83 152 49 86 111 | 110 165 186 36 163 169 43 88 168 33 138 147 13 167 189 184 134 72 90 188 82 77 44 54 70 112 | 111 41 92 114 162 188 129 187 170 152 83 142 74 157 16 13 138 186 106 63 184 28 21 50 30 192 113 | 112 193 80 86 27 137 174 67 5 79 198 132 127 42 131 100 73 107 56 135 62 7 93 114 | 113 156 80 149 171 102 62 39 133 84 87 198 1 53 64 55 172 122 100 42 14 160 99 124 137 45 19 145 7 115 | 114 48 104 136 16 61 72 181 128 82 88 138 109 129 70 192 3 153 65 95 44 111 28 21 11 94 54 116 | 115 185 37 27 108 68 62 87 120 76 81 22 25 151 139 191 140 9 89 96 18 117 | 116 48 188 109 24 157 61 41 181 128 66 88 63 152 189 168 90 105 72 77 10 13 47 82 173 92 142 28 180 2 21 117 50 33 136 164 118 | 117 17 47 92 16 179 2 103 90 194 189 192 94 143 170 162 83 116 129 184 77 138 152 51 49 119 | 118 193 156 80 86 155 68 196 145 40 4 97 79 5 99 87 149 174 34 137 166 131 15 160 84 121 85 120 | 119 80 178 102 52 160 124 84 185 6 22 67 174 8 121 133 5 75 64 15 150 71 151 145 20 121 | 120 185 86 171 108 102 68 200 193 126 64 160 32 150 6 115 98 81 56 191 76 124 71 139 85 195 135 122 | 121 156 185 37 32 15 146 22 119 4 98 151 57 62 126 53 139 40 102 35 118 64 135 133 123 | 122 156 146 23 99 196 46 14 56 53 175 164 20 68 39 195 126 97 40 34 64 79 27 155 113 76 131 15 42 107 166 151 9 93 36 124 | 123 165 194 50 70 95 11 153 30 24 17 49 47 54 2 189 147 199 3 188 197 186 109 41 29 129 130 128 74 136 63 156 125 | 124 193 80 125 32 146 99 23 5 60 113 120 191 6 58 64 119 71 37 78 96 137 126 | 125 37 86 56 4 133 69 178 132 53 172 45 60 151 40 198 55 35 107 99 124 71 137 127 23 127 | 126 91 58 86 122 155 87 120 146 78 200 121 19 80 20 137 75 96 107 18 156 46 135 128 | 127 37 178 32 52 131 160 81 4 25 191 125 38 71 15 8 42 75 46 174 73 14 69 112 172 9 34 129 | 128 48 134 161 109 116 187 114 183 16 179 189 130 181 142 123 138 2 136 77 152 63 130 | 129 165 184 41 153 114 148 111 177 16 180 44 29 123 54 173 72 92 70 199 105 90 117 95 157 131 | 130 186 169 24 3 154 189 21 136 72 82 74 190 11 47 13 173 54 181 43 123 90 128 50 138 30 176 132 | 131 193 53 76 78 22 69 97 118 178 122 140 81 34 164 8 100 127 67 31 6 42 146 200 195 26 64 46 55 89 112 133 | 132 185 149 53 171 108 125 18 81 60 25 139 191 178 64 52 9 62 22 40 164 35 112 134 | 133 86 53 56 125 8 52 113 76 107 119 195 97 156 98 121 62 140 26 175 9 135 | 134 153 11 17 182 165 197 138 173 24 28 104 128 29 2 167 152 183 3 162 154 186 13 65 142 110 90 50 33 157 80 136 | 135 91 171 164 107 84 81 45 5 35 112 120 15 151 85 76 79 178 121 126 89 137 | 136 165 36 187 3 114 173 190 130 128 66 74 152 123 154 116 28 49 153 16 169 138 | 137 156 185 53 155 102 118 62 166 175 200 151 112 113 14 124 80 146 35 125 34 85 67 126 145 172 139 | 138 134 186 153 114 111 142 110 179 21 128 130 117 197 169 77 13 74 51 180 50 106 54 140 | 139 80 58 27 164 132 121 108 115 140 31 79 98 15 160 52 26 171 39 195 120 67 14 35 196 34 141 | 140 185 58 56 8 175 131 144 18 4 166 98 139 1 195 22 156 71 76 178 115 97 133 142 | 141 109 104 184 2 143 188 51 49 33 197 48 176 95 159 148 101 154 13 170 190 36 43 3 92 72 16 28 167 180 147 74 54 178 143 | 142 161 109 169 12 183 111 163 187 162 180 16 165 116 184 138 44 13 49 134 143 177 103 128 181 105 192 144 | 143 165 161 109 141 24 184 168 92 142 154 180 51 147 179 176 183 189 59 159 167 77 117 105 152 145 | 144 91 80 58 178 27 53 164 108 8 158 60 79 67 155 78 140 7 86 26 174 84 98 64 19 46 145 6 93 146 | 145 178 79 155 118 18 151 69 39 113 42 119 93 80 144 87 15 137 23 158 166 5 57 172 147 | 146 149 58 37 122 62 60 68 107 31 67 172 195 121 131 102 185 174 124 126 80 22 137 46 5 148 | 147 123 186 36 47 183 29 182 197 16 167 110 143 168 83 82 141 63 94 54 59 157 170 61 149 | 148 10 188 186 141 163 12 43 88 82 109 192 70 161 30 129 24 92 49 177 29 104 197 154 101 103 162 181 106 44 176 150 | 149 156 80 200 146 15 185 87 91 100 45 37 178 18 32 64 132 97 84 150 113 86 118 76 73 81 78 22 31 96 23 161 151 | 150 149 58 8 175 120 39 99 119 166 98 200 85 14 4 46 96 84 22 93 52 152 | 151 125 52 160 71 166 98 137 145 122 171 119 193 175 32 121 107 115 18 39 89 135 196 153 | 152 134 161 36 116 41 17 82 111 28 128 157 13 92 47 143 117 109 90 95 136 170 154 | 153 165 48 123 134 10 186 138 187 103 82 70 163 129 188 170 173 24 83 167 41 43 3 114 154 180 74 136 63 30 155 | 154 134 10 109 141 36 41 88 159 82 104 143 148 197 44 13 94 130 162 177 153 167 77 2 83 136 170 156 | 155 91 185 53 164 108 145 171 200 6 122 8 9 85 57 1 137 126 102 118 46 68 158 39 144 160 71 172 157 | 156 7 53 113 45 14 193 122 40 67 19 121 22 118 78 5 137 166 149 32 85 178 27 158 76 126 140 69 133 9 23 123 158 | 157 184 116 13 3 65 111 90 105 187 147 152 134 59 95 186 197 44 170 129 21 159 | 158 91 193 79 27 102 68 32 75 156 107 185 144 4 57 39 155 98 171 85 25 175 60 145 38 96 160 | 159 188 141 36 24 13 12 88 103 105 179 177 17 3 182 51 82 28 161 154 106 65 183 72 143 77 66 74 161 | 160 91 185 53 32 120 155 6 9 18 127 171 118 113 119 151 26 46 195 200 73 98 75 139 40 1 5 198 38 162 | 161 165 10 24 173 142 182 162 92 13 152 33 82 189 43 143 128 103 186 105 12 188 184 41 17 159 148 104 180 77 194 170 149 163 | 162 134 161 188 186 36 24 173 182 111 142 72 154 101 48 148 2 179 106 63 41 117 164 | 163 177 92 10 148 101 44 169 59 110 188 17 51 90 186 94 11 24 70 106 182 41 153 187 104 142 190 61 79 165 | 164 185 37 122 175 139 144 91 35 19 96 98 78 155 1 108 198 42 135 79 195 8 52 87 131 107 166 132 100 34 116 166 | 165 30 105 36 101 167 183 43 50 70 189 181 190 161 123 136 51 143 129 153 110 134 169 13 47 88 65 142 72 95 106 74 91 167 | 166 156 58 27 102 118 62 52 191 140 164 40 122 151 108 75 174 137 196 150 34 200 45 26 145 6 168 | 167 165 134 109 24 17 153 187 88 168 65 182 177 154 16 194 141 103 50 147 63 110 101 179 143 21 95 83 169 | 168 36 169 116 88 48 70 30 33 101 104 109 110 143 181 167 54 199 49 59 95 177 180 103 50 147 63 170 | 169 163 50 184 12 176 106 41 194 173 168 110 142 44 95 130 24 165 182 72 177 197 28 101 138 105 136 51 27 171 | 170 10 186 141 17 13 153 111 117 157 92 161 152 147 104 179 154 43 101 36 199 172 | 171 34 60 68 31 20 42 62 4 120 71 32 135 113 175 84 102 132 96 98 75 155 56 8 158 76 160 78 151 139 55 184 173 | 172 185 108 125 113 146 73 98 174 100 155 81 79 76 127 87 137 107 19 62 145 174 | 173 134 10 161 169 13 153 109 162 44 21 88 182 3 190 179 136 90 16 116 94 92 28 77 129 130 105 106 175 | 174 27 118 175 76 144 15 146 99 119 166 45 32 19 172 191 60 108 37 69 112 127 176 | 175 86 122 79 171 164 68 62 84 140 5 40 196 46 174 150 67 58 76 193 158 107 78 22 137 75 151 85 195 64 1 133 177 | 176 141 169 13 47 88 3 65 183 177 143 148 82 105 130 66 181 44 21 182 95 70 178 | 177 10 36 163 169 159 29 148 190 12 199 167 176 180 2 192 168 142 41 179 129 154 28 21 95 66 94 179 | 178 149 76 19 52 127 108 45 156 60 198 99 37 144 7 97 145 84 119 42 191 53 125 8 87 131 18 25 132 195 140 135 93 141 180 | 179 10 109 36 159 173 177 49 117 12 48 101 189 2 138 128 95 106 28 167 16 143 162 90 105 74 170 199 30 61 181 | 180 48 109 17 88 142 177 16 65 194 36 153 116 161 66 143 141 129 47 168 21 101 181 138 182 | 181 165 188 184 116 13 168 114 182 130 128 44 103 148 63 29 142 180 17 109 72 105 176 70 183 | 182 134 10 161 163 169 184 13 159 173 183 147 189 162 181 12 72 167 94 199 192 2 95 44 83 176 184 | 183 165 48 134 188 186 63 33 147 43 61 109 142 159 66 41 128 182 176 190 184 143 189 185 | 184 48 10 141 169 24 143 41 161 181 36 129 197 104 186 182 28 106 157 103 199 82 183 111 142 110 117 63 51 59 171 186 | 185 193 140 8 97 7 121 160 155 85 172 91 100 164 132 25 120 20 62 137 115 149 86 32 158 107 146 119 64 55 93 10 187 | 186 161 104 199 162 148 110 138 183 153 103 192 82 130 123 59 49 147 134 65 170 163 184 187 92 197 111 72 157 61 37 188 | 187 24 153 186 92 189 54 33 59 21 167 163 101 136 50 197 16 104 128 192 30 47 111 142 66 157 61 49 189 | 188 48 123 50 104 36 101 148 66 161 44 28 21 162 116 183 103 159 181 61 29 141 163 153 111 110 190 95 105 59 58 190 | 189 165 123 10 161 116 41 187 43 29 182 197 16 110 179 143 12 183 103 199 109 130 128 194 117 94 61 191 | 190 165 48 10 141 13 43 3 173 183 104 177 72 36 77 188 90 136 51 163 70 130 105 61 192 | 191 193 37 178 8 32 120 15 42 107 18 124 166 132 174 20 52 57 31 115 127 1 69 9 23 193 | 192 186 187 12 114 82 92 29 148 104 182 177 72 28 101 21 117 94 51 111 142 194 | 194 123 169 43 88 65 29 104 16 167 180 90 189 161 199 47 92 3 12 17 117 195 | 195 193 80 58 122 164 108 68 160 146 78 139 23 120 178 133 73 175 37 131 9 19 89 140 57 196 | 196 122 53 102 118 62 175 15 166 31 9 37 6 99 151 56 139 46 1 96 60 197 | 197 48 123 134 141 184 187 82 65 92 106 148 147 199 29 17 30 189 186 74 169 154 21 103 138 157 59 198 | 198 91 37 178 79 164 108 125 62 113 18 84 45 26 112 35 107 160 53 1 75 199 | 199 123 186 109 184 43 168 29 182 197 177 189 129 194 95 83 170 54 105 90 179 30 59 200 | 200 149 155 52 87 120 39 160 137 27 79 131 100 25 55 23 126 84 166 150 62 67 1 69 35 201 | -------------------------------------------------------------------------------- /data/coursera_algos/knapsack1.txt: -------------------------------------------------------------------------------- 1 | 10000 100 2 | 16808 250 3 | 50074 659 4 | 8931 273 5 | 27545 879 6 | 77924 710 7 | 64441 166 8 | 84493 43 9 | 7988 504 10 | 82328 730 11 | 78841 613 12 | 44304 170 13 | 17710 158 14 | 29561 934 15 | 93100 279 16 | 51817 336 17 | 99098 827 18 | 13513 268 19 | 23811 634 20 | 80980 150 21 | 36580 822 22 | 11968 673 23 | 1394 337 24 | 25486 746 25 | 25229 92 26 | 40195 358 27 | 35002 154 28 | 16709 945 29 | 15669 491 30 | 88125 197 31 | 9531 904 32 | 27723 667 33 | 28550 25 34 | 97802 854 35 | 40978 409 36 | 8229 934 37 | 60299 982 38 | 28636 14 39 | 23866 815 40 | 39064 537 41 | 39426 670 42 | 24116 95 43 | 75630 502 44 | 46518 196 45 | 30106 405 46 | 19452 299 47 | 82189 124 48 | 99506 883 49 | 6753 567 50 | 36717 338 51 | 54439 145 52 | 51502 898 53 | 83872 829 54 | 11138 359 55 | 53178 398 56 | 22295 905 57 | 21610 232 58 | 59746 176 59 | 53636 299 60 | 98143 400 61 | 27969 413 62 | 261 558 63 | 41595 9 64 | 16396 969 65 | 19114 531 66 | 71007 963 67 | 97943 366 68 | 42083 853 69 | 30768 822 70 | 85696 713 71 | 73672 902 72 | 48591 832 73 | 14739 58 74 | 31617 791 75 | 55641 680 76 | 37336 7 77 | 97973 99 78 | 49096 320 79 | 83455 224 80 | 12290 761 81 | 48906 127 82 | 36124 507 83 | 45814 771 84 | 35239 95 85 | 96221 845 86 | 12367 535 87 | 25227 395 88 | 41364 739 89 | 7845 591 90 | 36551 160 91 | 8624 948 92 | 97386 218 93 | 95273 540 94 | 99248 386 95 | 13497 886 96 | 40624 421 97 | 28145 969 98 | 35736 916 99 | 61626 535 100 | 46043 12 101 | 54680 153 102 | -------------------------------------------------------------------------------- /data/coursera_algos/knapsack_big.txt: -------------------------------------------------------------------------------- 1 | 2000000 2000 2 | 16808 241486 3 | 50074 834558 4 | 8931 738037 5 | 27545 212860 6 | 77924 494349 7 | 64441 815107 8 | 84493 723724 9 | 7988 421316 10 | 82328 652893 11 | 78841 402599 12 | 44304 631607 13 | 17710 318556 14 | 29561 608119 15 | 93100 429390 16 | 51817 431959 17 | 99098 365899 18 | 13513 90282 19 | 23811 467558 20 | 80980 743542 21 | 36580 896948 22 | 11968 883369 23 | 1394 604449 24 | 25486 562244 25 | 25229 333236 26 | 40195 157443 27 | 35002 696933 28 | 16709 539123 29 | 15669 202584 30 | 88125 759690 31 | 9531 69730 32 | 27723 110467 33 | 28550 651058 34 | 97802 231944 35 | 40978 186803 36 | 8229 572887 37 | 60299 797491 38 | 28636 692529 39 | 23866 503157 40 | 39064 243688 41 | 39426 182032 42 | 24116 262772 43 | 75630 246257 44 | 46518 605917 45 | 30106 685556 46 | 19452 522241 47 | 82189 419114 48 | 99506 622799 49 | 6753 512332 50 | 36717 489578 51 | 54439 869423 52 | 51502 771067 53 | 83872 884103 54 | 11138 450309 55 | 53178 444804 56 | 22295 405168 57 | 21610 527012 58 | 59746 255432 59 | 53636 702071 60 | 98143 264607 61 | 27969 227173 62 | 261 471962 63 | 41595 807033 64 | 16396 471595 65 | 19114 520773 66 | 71007 21653 67 | 97943 882635 68 | 42083 801161 69 | 30768 806299 70 | 85696 402599 71 | 73672 858413 72 | 48591 122945 73 | 14739 256166 74 | 31617 350118 75 | 55641 783545 76 | 37336 518571 77 | 97973 17616 78 | 49096 640048 79 | 83455 74134 80 | 12290 610321 81 | 48906 141662 82 | 36124 589402 83 | 45814 638947 84 | 35239 751616 85 | 96221 533618 86 | 12367 850339 87 | 25227 670142 88 | 41364 699869 89 | 7845 269378 90 | 36551 198914 91 | 8624 98356 92 | 97386 122211 93 | 95273 646654 94 | 99248 231210 95 | 13497 806666 96 | 40624 153773 97 | 28145 427188 98 | 35736 90282 99 | 61626 23855 100 | 46043 585365 101 | 54680 654728 102 | 75245 593439 103 | 78819 117073 104 | 87693 903554 105 | 12992 357458 106 | 22670 469393 107 | 89554 642617 108 | 75826 568850 109 | 29663 532884 110 | 33809 531783 111 | 22762 171022 112 | 37336 833090 113 | 77996 467191 114 | 62768 784279 115 | 29875 716017 116 | 40557 519305 117 | 68873 549766 118 | 20095 435262 119 | 13756 171022 120 | 51408 709411 121 | 30194 111935 122 | 15681 116706 123 | 16856 669408 124 | 70964 512332 125 | 86677 892544 126 | 63250 836026 127 | 77845 418747 128 | 60809 531049 129 | 28652 532884 130 | 45204 478201 131 | 96532 514167 132 | 95420 866487 133 | 42010 780976 134 | 87930 258368 135 | 44055 202217 136 | 76738 877864 137 | 47318 901352 138 | 4201 634910 139 | 21565 69730 140 | 6649 16515 141 | 25551 203685 142 | 41977 835292 143 | 38555 736202 144 | 2505 215429 145 | 34749 905022 146 | 59379 171389 147 | 78210 744276 148 | 78554 485174 149 | 50448 394158 150 | 80774 697300 151 | 64306 324061 152 | 70054 724825 153 | 84631 377643 154 | 5401 191941 155 | 95371 695098 156 | 83017 107164 157 | 8156 807033 158 | 15558 218365 159 | 53179 670142 160 | 87358 219466 161 | 27879 757121 162 | 74820 194143 163 | 83134 660600 164 | 23721 531783 165 | 6634 32663 166 | 76032 528480 167 | 60590 884470 168 | 98878 569584 169 | 79359 553436 170 | 77255 80373 171 | 14916 269378 172 | 45481 776205 173 | 69956 608853 174 | 37108 143497 175 | 40001 213961 176 | 45036 344613 177 | 61114 472696 178 | 29594 539123 179 | 98355 675647 180 | 25358 809235 181 | 13730 649223 182 | 76564 218365 183 | 60918 373606 184 | 19420 63491 185 | 47748 105696 186 | 55396 600045 187 | 4474 913096 188 | 11749 417279 189 | 5659 576190 190 | 45407 777306 191 | 39825 916399 192 | 974 70097 193 | 46898 715650 194 | 3951 50646 195 | 88481 562978 196 | 13901 764828 197 | 62534 772535 198 | 17004 569217 199 | 20951 344246 200 | 58781 576557 201 | 41833 484073 202 | 33550 262405 203 | 6250 548665 204 | 66311 846302 205 | 984 237082 206 | 92041 86245 207 | 73651 451043 208 | 53230 873827 209 | 17747 340209 210 | 87231 470861 211 | 70777 354889 212 | 68245 560409 213 | 47697 456181 214 | 29065 364064 215 | 56599 853642 216 | 69941 86245 217 | 89552 891443 218 | 27206 238183 219 | 80419 708677 220 | 59650 713448 221 | 83321 593439 222 | 22013 800794 223 | 54789 811070 224 | 24009 64592 225 | 65325 643718 226 | 51258 547197 227 | 35691 770333 228 | 62411 360027 229 | 90554 384983 230 | 90961 242220 231 | 51404 456181 232 | 7308 263873 233 | 44911 44774 234 | 92800 210658 235 | 20172 702438 236 | 34415 522608 237 | 16642 883002 238 | 55349 229375 239 | 14316 229375 240 | 81446 587567 241 | 72221 702071 242 | 9240 681152 243 | 41306 439666 244 | 82254 569217 245 | 15167 44040 246 | 59305 492514 247 | 45899 274516 248 | 1721 647388 249 | 51581 173591 250 | 49353 713448 251 | 78699 803363 252 | 60812 626469 253 | 56892 885204 254 | 31729 70097 255 | 51662 300940 256 | 3121 193409 257 | 17159 448107 258 | 69628 266075 259 | 53712 737303 260 | 29595 767030 261 | 1728 778774 262 | 98796 623166 263 | 37443 173958 264 | 85007 28993 265 | 66925 438198 266 | 27098 28993 267 | 67104 736202 268 | 80109 49178 269 | 2282 51013 270 | 43711 456548 271 | 1906 245890 272 | 67782 447740 273 | 70792 346815 274 | 97866 424252 275 | 29537 643351 276 | 57126 605917 277 | 80168 124413 278 | 477 792353 279 | 21577 216530 280 | 77468 571419 281 | 27385 373973 282 | 25211 910160 283 | 28212 37434 284 | 19519 788316 285 | 21935 671977 286 | 80768 843366 287 | 49819 30461 288 | 11826 765562 289 | 58029 693263 290 | 78035 197813 291 | 51241 914564 292 | 86541 80006 293 | 90936 640415 294 | 48049 103127 295 | 87012 148268 296 | 53065 755286 297 | 17238 601146 298 | 70246 643351 299 | 70416 446272 300 | 24711 83676 301 | 63038 601146 302 | 68448 270846 303 | 41694 471962 304 | 67130 793454 305 | 33980 25323 306 | 95526 631974 307 | 74959 279287 308 | 11801 837127 309 | 87671 310482 310 | 80847 29727 311 | 9544 481504 312 | 85937 343145 313 | 97720 712347 314 | 91248 681519 315 | 28360 177628 316 | 45036 308647 317 | 83913 383148 318 | 45990 226072 319 | 63735 870891 320 | 520 608486 321 | 67046 170288 322 | 26662 853642 323 | 82820 815474 324 | 22904 485174 325 | 36799 171022 326 | 14005 75969 327 | 680 373239 328 | 7392 11377 329 | 64925 211025 330 | 57122 750148 331 | 77107 175426 332 | 51704 577658 333 | 71033 300206 334 | 61655 729596 335 | 42218 330300 336 | 57290 271580 337 | 84748 344246 338 | 89880 890342 339 | 1517 266809 340 | 82763 677482 341 | 49040 687391 342 | 5904 592338 343 | 10762 475632 344 | 61467 211025 345 | 44558 659866 346 | 26956 747579 347 | 21495 751983 348 | 92222 148635 349 | 2294 546096 350 | 42233 803730 351 | 34792 782811 352 | 49153 468292 353 | 73571 862083 354 | 27440 422784 355 | 35800 547931 356 | 91836 119642 357 | 46044 903187 358 | 74087 11744 359 | 39151 263873 360 | 18624 324428 361 | 9606 760057 362 | 37312 718953 363 | 77993 321125 364 | 38700 550133 365 | 80640 181665 366 | 25837 614358 367 | 82858 9542 368 | 83329 456548 369 | 74411 253964 370 | 85177 320391 371 | 72092 745377 372 | 84712 129551 373 | 91027 364798 374 | 10791 740973 375 | 59365 256900 376 | 24885 536554 377 | 41442 323694 378 | 16180 252863 379 | 94538 438198 380 | 87562 642250 381 | 74166 892911 382 | 57733 838228 383 | 47855 547197 384 | 9848 299105 385 | 40463 519672 386 | 43820 80006 387 | 13378 225338 388 | 43837 82942 389 | 32300 419114 390 | 34760 867588 391 | 74391 500955 392 | 30049 203685 393 | 59045 26424 394 | 40151 554904 395 | 41296 197079 396 | 95732 58720 397 | 28246 573254 398 | 46058 332135 399 | 93929 177995 400 | 62174 629038 401 | 9670 147534 402 | 4504 873827 403 | 18676 450676 404 | 24301 914197 405 | 40263 659132 406 | 68264 545729 407 | 24890 524443 408 | 76736 863551 409 | 31365 262772 410 | 50964 410306 411 | 85653 263506 412 | 57317 293233 413 | 61421 58353 414 | 7451 223136 415 | 78742 240385 416 | 99197 178729 417 | 2502 619863 418 | 93129 860248 419 | 77449 41104 420 | 17254 534352 421 | 69330 184601 422 | 33720 367367 423 | 38821 862817 424 | 81080 44407 425 | 94411 366266 426 | 66114 185702 427 | 98139 389020 428 | 46045 216163 429 | 98394 321859 430 | 981 538022 431 | 45597 208089 432 | 67371 397461 433 | 6544 836393 434 | 62853 536921 435 | 76825 856211 436 | 82381 835292 437 | 71858 673445 438 | 59499 253597 439 | 52846 637479 440 | 68510 375808 441 | 96850 896214 442 | 19291 907591 443 | 85744 97989 444 | 36386 749047 445 | 86276 493248 446 | 11152 843733 447 | 75252 74134 448 | 32784 804831 449 | 66438 447740 450 | 7699 873827 451 | 83138 347549 452 | 45879 875295 453 | 70343 314519 454 | 45961 662802 455 | 35141 228641 456 | 43567 609220 457 | 563 114504 458 | 41970 721889 459 | 12624 847770 460 | 17314 450309 461 | 27011 782811 462 | 76979 415811 463 | 7395 42939 464 | 87657 633075 465 | 61787 423518 466 | 54545 680051 467 | 8777 22754 468 | 74253 642984 469 | 75515 14680 470 | 95020 29360 471 | 66672 310482 472 | 56664 101292 473 | 84025 546463 474 | 16855 720788 475 | 14982 751249 476 | 86289 257267 477 | 8360 856578 478 | 55380 560776 479 | 19462 674913 480 | 8354 849238 481 | 85789 906857 482 | 69881 895847 483 | 38284 303142 484 | 237 855110 485 | 98568 353054 486 | 47940 451410 487 | 69652 48811 488 | 9267 640415 489 | 111 11010 490 | 86328 820245 491 | 25582 745377 492 | 7937 250661 493 | 45436 715283 494 | 66973 184234 495 | 19799 778040 496 | 74588 571786 497 | 69090 86245 498 | 29091 784646 499 | 66613 883369 500 | 72113 21653 501 | 51952 540591 502 | 3322 410306 503 | 55129 850339 504 | 28601 104228 505 | 49755 254331 506 | 55195 428656 507 | 48633 722623 508 | 58778 149736 509 | 18633 499120 510 | 86384 131386 511 | 37154 588668 512 | 98522 914197 513 | 9819 361495 514 | 66901 55417 515 | 74121 362229 516 | 33600 582429 517 | 27145 390488 518 | 40871 298738 519 | 4600 309748 520 | 74619 368101 521 | 79453 143864 522 | 52801 734000 523 | 44841 895480 524 | 86386 107898 525 | 24516 877130 526 | 14879 11010 527 | 18562 376909 528 | 28049 591237 529 | 49441 336539 530 | 18771 747212 531 | 40252 427555 532 | 5870 45141 533 | 72470 833824 534 | 8173 477834 535 | 52126 31195 536 | 98851 889608 537 | 11244 81107 538 | 25870 266442 539 | 41395 105696 540 | 26393 240018 541 | 9525 389754 542 | 39520 545729 543 | 81563 285526 544 | 3726 176527 545 | 5950 333236 546 | 37323 66060 547 | 60094 485908 548 | 85123 695098 549 | 26367 398929 550 | 71024 556372 551 | 66640 416545 552 | 8794 681152 553 | 59179 372872 554 | 56772 86245 555 | 51356 797858 556 | 68490 587200 557 | 67121 642617 558 | 10313 176160 559 | 33348 127349 560 | 32852 402599 561 | 25812 271947 562 | 91737 773636 563 | 29608 779141 564 | 87867 301307 565 | 41413 698034 566 | 14041 177995 567 | 34258 571052 568 | 88296 292132 569 | 42252 345347 570 | 92695 437464 571 | 51139 339108 572 | 19876 332135 573 | 94293 868689 574 | 14563 16882 575 | 93237 747579 576 | 83874 734000 577 | 27485 734367 578 | 54120 609220 579 | 99329 619129 580 | 263 449942 581 | 5717 293967 582 | 92173 438932 583 | 44001 91383 584 | 61399 552702 585 | 89546 254698 586 | 12301 898783 587 | 52193 452144 588 | 14784 411407 589 | 69802 322226 590 | 46222 89181 591 | 24514 555271 592 | 32532 75969 593 | 45820 808868 594 | 46857 233045 595 | 64255 249193 596 | 49815 649223 597 | 66538 40370 598 | 47298 505726 599 | 25761 418013 600 | 80192 799326 601 | 54175 458016 602 | 14878 412875 603 | 94785 455447 604 | 6751 870524 605 | 45000 848504 606 | 99610 675647 607 | 55365 606651 608 | 32813 613624 609 | 24135 177995 610 | 93584 433794 611 | 14260 402966 612 | 754 651792 613 | 79322 812538 614 | 65254 684455 615 | 36097 720421 616 | 30662 296536 617 | 27492 84043 618 | 42921 905022 619 | 81688 573621 620 | 55473 265708 621 | 55642 12111 622 | 82952 493982 623 | 58446 712714 624 | 49475 822814 625 | 96177 417279 626 | 25827 788683 627 | 23345 460952 628 | 85980 621698 629 | 68130 490312 630 | 59849 858780 631 | 5618 865019 632 | 47839 910894 633 | 88930 409205 634 | 29969 669041 635 | 25200 811437 636 | 59830 39269 637 | 3489 489945 638 | 2820 386818 639 | 91138 760057 640 | 80319 370303 641 | 80684 500955 642 | 91121 310115 643 | 89962 86979 644 | 11019 480403 645 | 79947 830521 646 | 31479 715650 647 | 25064 716384 648 | 71507 487376 649 | 78969 133588 650 | 4120 517103 651 | 49678 676381 652 | 78954 905022 653 | 51854 149736 654 | 8272 794188 655 | 40205 286627 656 | 28033 104228 657 | 83532 623900 658 | 9469 226072 659 | 52611 148268 660 | 67271 613624 661 | 6266 660967 662 | 43280 84777 663 | 75328 402966 664 | 25364 620964 665 | 75767 92851 666 | 3084 805565 667 | 26185 337640 668 | 44157 648122 669 | 63665 15414 670 | 57887 586833 671 | 44925 557106 672 | 195 422417 673 | 36768 109733 674 | 6911 186436 675 | 56806 793087 676 | 79653 110834 677 | 73586 457649 678 | 97807 78905 679 | 17961 505726 680 | 10404 55784 681 | 1086 325896 682 | 48586 907224 683 | 30485 416178 684 | 55201 100191 685 | 80101 117073 686 | 28200 564446 687 | 4429 216530 688 | 93918 235614 689 | 87803 403333 690 | 56579 692896 691 | 20790 496918 692 | 60596 59821 693 | 42320 533985 694 | 20722 190840 695 | 98515 348650 696 | 75707 677849 697 | 34627 164783 698 | 23802 230843 699 | 97645 616193 700 | 23259 658765 701 | 3132 276718 702 | 60721 690327 703 | 61025 446639 704 | 80509 309381 705 | 19398 707943 706 | 78315 171389 707 | 34000 513800 708 | 92534 558574 709 | 80540 457649 710 | 58822 61289 711 | 76537 726293 712 | 14663 573254 713 | 14494 426454 714 | 74989 349384 715 | 6839 822814 716 | 64892 717485 717 | 50561 900251 718 | 44093 210291 719 | 89776 625001 720 | 30414 895480 721 | 16637 850339 722 | 23348 471962 723 | 54833 766296 724 | 28929 407003 725 | 40880 877130 726 | 39055 17983 727 | 32538 316354 728 | 4422 67528 729 | 51035 164049 730 | 19417 294334 731 | 81939 192675 732 | 35184 735101 733 | 81097 109733 734 | 70525 902453 735 | 60364 721889 736 | 73013 765195 737 | 72203 877864 738 | 61537 513066 739 | 2209 152672 740 | 8032 47710 741 | 32096 685556 742 | 13113 151938 743 | 29938 909793 744 | 15686 815841 745 | 84382 746845 746 | 99909 865019 747 | 99440 445538 748 | 96120 408838 749 | 75000 358192 750 | 74516 291031 751 | 59923 474531 752 | 3981 581328 753 | 26886 306078 754 | 56326 149369 755 | 46995 605917 756 | 52603 461686 757 | 28481 481504 758 | 92835 316721 759 | 98882 427555 760 | 33459 588668 761 | 74993 713081 762 | 48652 441501 763 | 8543 506460 764 | 49636 580961 765 | 21873 269745 766 | 47323 338374 767 | 78442 572520 768 | 85780 168820 769 | 47161 8808 770 | 64956 177628 771 | 29365 356724 772 | 19314 212126 773 | 34329 491046 774 | 19856 145332 775 | 3882 800060 776 | 99905 65693 777 | 77617 574722 778 | 48865 345347 779 | 32123 216530 780 | 4944 81841 781 | 85119 542793 782 | 52189 822080 783 | 17640 481504 784 | 70344 635277 785 | 36220 725559 786 | 3318 725559 787 | 27783 747212 788 | 47860 142763 789 | 35591 789050 790 | 44011 767030 791 | 53016 718219 792 | 28543 727761 793 | 48271 81474 794 | 3009 348283 795 | 31100 637112 796 | 38214 226439 797 | 42952 508295 798 | 3503 699135 799 | 14954 819511 800 | 17917 915298 801 | 11870 25323 802 | 23504 536187 803 | 74330 494349 804 | 13383 238183 805 | 9353 311216 806 | 32527 267910 807 | 78756 725192 808 | 2953 48811 809 | 67423 128450 810 | 43787 478201 811 | 41880 701704 812 | 78891 838962 813 | 14687 406636 814 | 68828 222035 815 | 199 206988 816 | 70818 798592 817 | 73477 147167 818 | 31564 288829 819 | 6384 353054 820 | 29723 252129 821 | 32792 840797 822 | 6143 113770 823 | 31554 317822 824 | 22381 906123 825 | 64916 150103 826 | 37111 531783 827 | 11559 909426 828 | 11434 327731 829 | 62660 429757 830 | 17563 212493 831 | 49755 387552 832 | 33527 890342 833 | 27596 286627 834 | 81178 441501 835 | 69159 580961 836 | 6021 860615 837 | 64728 721889 838 | 59463 789784 839 | 22938 668674 840 | 46058 828686 841 | 908 470494 842 | 94259 797124 843 | 17344 31562 844 | 22282 707943 845 | 81458 113403 846 | 90118 842632 847 | 85625 905756 848 | 61460 119275 849 | 42217 223870 850 | 96852 256900 851 | 92923 728862 852 | 38616 510864 853 | 51362 423518 854 | 17232 440033 855 | 26124 162948 856 | 87445 83309 857 | 96239 523342 858 | 59360 177628 859 | 21266 661701 860 | 18388 378744 861 | 63949 523709 862 | 15425 656196 863 | 42776 834191 864 | 34990 905756 865 | 90846 222402 866 | 57258 365532 867 | 13262 844100 868 | 57105 348650 869 | 57032 81107 870 | 84687 61656 871 | 38644 59087 872 | 5649 452878 873 | 97318 75235 874 | 70126 70097 875 | 7864 328832 876 | 46681 638947 877 | 60980 275984 878 | 86815 889608 879 | 72190 295802 880 | 94906 131753 881 | 13615 17616 882 | 79309 103127 883 | 69450 85144 884 | 56030 569951 885 | 87867 790151 886 | 58986 825383 887 | 77519 436363 888 | 70580 205887 889 | 13384 570685 890 | 30560 792720 891 | 21175 104228 892 | 22676 682620 893 | 89437 393057 894 | 42916 345347 895 | 13014 906490 896 | 72493 729229 897 | 84084 654728 898 | 25848 916766 899 | 32707 140194 900 | 42307 80740 901 | 4115 726660 902 | 84471 597476 903 | 28618 678216 904 | 22305 70097 905 | 49664 248826 906 | 49899 60188 907 | 48657 601880 908 | 53013 423885 909 | 96281 572520 910 | 79105 398195 911 | 52074 516736 912 | 92980 539123 913 | 48270 881167 914 | 5632 771801 915 | 25634 520406 916 | 52521 543527 917 | 77598 282223 918 | 36872 73767 919 | 93321 805565 920 | 28019 248459 921 | 20518 732532 922 | 12614 192308 923 | 91357 62023 924 | 88595 650691 925 | 35671 469026 926 | 21650 123679 927 | 76426 378744 928 | 91892 796390 929 | 81218 203318 930 | 42655 64959 931 | 63902 493982 932 | 96393 822447 933 | 24971 502056 934 | 56188 651425 935 | 31566 132854 936 | 12972 545362 937 | 71161 385717 938 | 40776 141662 939 | 69695 230109 940 | 31535 429757 941 | 34415 371037 942 | 40424 70097 943 | 73078 91016 944 | 25628 497652 945 | 85575 255065 946 | 98840 183867 947 | 68184 32296 948 | 50531 842999 949 | 82607 293233 950 | 2013 830888 951 | 31939 325529 952 | 5322 913830 953 | 39228 562978 954 | 24461 636745 955 | 43488 310115 956 | 11582 756020 957 | 39195 23855 958 | 51296 409205 959 | 33420 166251 960 | 63444 719687 961 | 68707 422050 962 | 26594 680785 963 | 31227 100191 964 | 42734 881534 965 | 95297 52481 966 | 18062 623166 967 | 21378 502056 968 | 50523 615826 969 | 1053 516736 970 | 55181 212860 971 | 57152 542793 972 | 94974 280021 973 | 30418 241486 974 | 12466 864285 975 | 22740 140194 976 | 25008 822080 977 | 52431 157443 978 | 37826 322226 979 | 8328 829787 980 | 13498 638213 981 | 58226 315620 982 | 98823 230476 983 | 32554 688859 984 | 69469 455447 985 | 37425 171389 986 | 74278 826851 987 | 80501 757121 988 | 31602 455814 989 | 30699 632341 990 | 14550 809969 991 | 77128 565547 992 | 60 37067 993 | 54671 68262 994 | 7225 502790 995 | 5021 712347 996 | 5220 498386 997 | 38696 582429 998 | 25665 507928 999 | 95119 740606 1000 | 13724 819511 1001 | 43755 280388 1002 | 38588 433427 1003 | 4914 560042 1004 | 4158 171022 1005 | 40583 405902 1006 | 40963 49545 1007 | 98441 611055 1008 | 21171 583163 1009 | 45202 25690 1010 | 51621 23488 1011 | 30318 189739 1012 | 31753 40737 1013 | 63236 149369 1014 | 10832 73033 1015 | 51852 344613 1016 | 55452 150103 1017 | 45550 815107 1018 | 26235 445171 1019 | 49224 379845 1020 | 86772 447006 1021 | 8960 669408 1022 | 75346 643351 1023 | 59724 694364 1024 | 51820 184968 1025 | 34887 501322 1026 | 17393 262405 1027 | 84947 714182 1028 | 29366 304610 1029 | 2757 307179 1030 | 97918 745377 1031 | 83696 386451 1032 | 6865 409205 1033 | 90792 583897 1034 | 13757 809602 1035 | 19898 216530 1036 | 75310 702805 1037 | 46023 295435 1038 | 76900 645186 1039 | 71721 548298 1040 | 50149 66794 1041 | 95777 123312 1042 | 76834 349384 1043 | 95194 418747 1044 | 8009 782077 1045 | 86767 683721 1046 | 36631 165517 1047 | 69187 913463 1048 | 67670 643718 1049 | 75159 601513 1050 | 52400 579860 1051 | 83115 333970 1052 | 95057 708310 1053 | 10054 336172 1054 | 75776 136157 1055 | 89556 359293 1056 | 45049 331034 1057 | 99453 778407 1058 | 1632 762626 1059 | 52264 11010 1060 | 98920 209924 1061 | 28202 460952 1062 | 98583 884470 1063 | 38997 74501 1064 | 63502 635644 1065 | 50931 481137 1066 | 66421 831989 1067 | 82152 743175 1068 | 93171 565547 1069 | 85158 80006 1070 | 48016 213961 1071 | 53862 820612 1072 | 91174 844100 1073 | 93087 494349 1074 | 13053 640048 1075 | 74380 84777 1076 | 23433 126615 1077 | 65773 708310 1078 | 30188 532517 1079 | 95609 888874 1080 | 97037 769232 1081 | 18479 711246 1082 | 44089 49912 1083 | 84510 72666 1084 | 15124 688125 1085 | 78124 73033 1086 | 23880 702805 1087 | 61675 256533 1088 | 51432 36333 1089 | 4972 414710 1090 | 66195 765929 1091 | 4657 439666 1092 | 58217 127349 1093 | 6814 731064 1094 | 50258 728862 1095 | 51558 331034 1096 | 1237 729596 1097 | 27534 413242 1098 | 48911 644452 1099 | 53267 902820 1100 | 45109 709044 1101 | 62174 834191 1102 | 67897 388286 1103 | 23225 160379 1104 | 45409 193409 1105 | 79562 193409 1106 | 43572 445905 1107 | 76115 270846 1108 | 41290 660233 1109 | 91860 101659 1110 | 8695 139460 1111 | 52928 904655 1112 | 50843 89548 1113 | 52122 622065 1114 | 43581 638580 1115 | 8810 873827 1116 | 5484 237816 1117 | 48439 718953 1118 | 80939 95420 1119 | 86482 85511 1120 | 4551 440400 1121 | 41176 298371 1122 | 78275 310482 1123 | 98843 913463 1124 | 29821 736936 1125 | 49806 507928 1126 | 49221 188638 1127 | 37904 649223 1128 | 38851 607752 1129 | 35363 835292 1130 | 47071 455447 1131 | 69143 697300 1132 | 64736 866487 1133 | 38735 165517 1134 | 39875 478201 1135 | 17391 851807 1136 | 63007 887406 1137 | 16044 813639 1138 | 21356 297637 1139 | 41976 653260 1140 | 47376 712347 1141 | 18198 864652 1142 | 40538 653260 1143 | 14784 539857 1144 | 20177 387552 1145 | 25083 76336 1146 | 99152 230476 1147 | 15306 706108 1148 | 76378 288829 1149 | 23791 365532 1150 | 18397 652159 1151 | 28232 462787 1152 | 80607 323694 1153 | 90471 499120 1154 | 21358 474898 1155 | 4054 899517 1156 | 75281 502790 1157 | 35084 727027 1158 | 82583 744276 1159 | 79564 704273 1160 | 4638 910894 1161 | 74688 568483 1162 | 19952 678216 1163 | 11623 220567 1164 | 69631 351219 1165 | 31507 10643 1166 | 74718 826851 1167 | 49140 418380 1168 | 71693 422784 1169 | 96145 121110 1170 | 18084 436363 1171 | 11866 149736 1172 | 33388 313051 1173 | 92288 471595 1174 | 46680 460952 1175 | 10364 132854 1176 | 18000 296169 1177 | 21502 380946 1178 | 97206 613991 1179 | 26245 827585 1180 | 53077 649590 1181 | 69320 338374 1182 | 52989 780609 1183 | 36153 581695 1184 | 24559 693263 1185 | 51682 175059 1186 | 33217 206621 1187 | 69246 51380 1188 | 13686 498019 1189 | 814 28993 1190 | 48106 524076 1191 | 49655 597843 1192 | 98471 40003 1193 | 13980 377276 1194 | 80040 146066 1195 | 79946 353054 1196 | 53749 411407 1197 | 95178 900985 1198 | 27492 524810 1199 | 499 96521 1200 | 12005 265708 1201 | 25557 909059 1202 | 83134 101292 1203 | 21293 610688 1204 | 24271 24956 1205 | 5354 215429 1206 | 70129 582429 1207 | 91083 533618 1208 | 49301 106797 1209 | 73456 463521 1210 | 25385 62390 1211 | 95848 513800 1212 | 12489 711246 1213 | 84309 754919 1214 | 32775 630873 1215 | 8262 854743 1216 | 23925 675280 1217 | 96390 342778 1218 | 13345 284425 1219 | 37748 707209 1220 | 43968 117807 1221 | 74577 680051 1222 | 60837 423518 1223 | 92147 136524 1224 | 99758 568116 1225 | 89597 729963 1226 | 28382 453245 1227 | 79295 890709 1228 | 23805 92851 1229 | 90688 822080 1230 | 42218 858046 1231 | 35294 817309 1232 | 29939 243688 1233 | 19488 290297 1234 | 12681 501689 1235 | 53527 767030 1236 | 8708 481137 1237 | 57618 125881 1238 | 82639 229742 1239 | 19546 741707 1240 | 56126 888874 1241 | 93880 508662 1242 | 61642 303509 1243 | 43982 834558 1244 | 76326 230109 1245 | 93894 98723 1246 | 14564 553436 1247 | 25502 625368 1248 | 76941 325162 1249 | 14330 149002 1250 | 42281 198914 1251 | 85475 197079 1252 | 78777 914564 1253 | 87705 491413 1254 | 34969 820979 1255 | 91267 828686 1256 | 66375 617661 1257 | 39822 820979 1258 | 4416 261304 1259 | 29842 322593 1260 | 50159 342411 1261 | 3766 617294 1262 | 37762 349017 1263 | 30273 890342 1264 | 6195 358926 1265 | 65255 633809 1266 | 517 642617 1267 | 27684 821713 1268 | 15114 328098 1269 | 55545 685189 1270 | 1760 705374 1271 | 74673 467558 1272 | 21173 814740 1273 | 75230 148635 1274 | 77144 349017 1275 | 56510 662435 1276 | 49185 439666 1277 | 59512 36333 1278 | 18384 600412 1279 | 15786 22754 1280 | 70761 909059 1281 | 32990 205153 1282 | 39734 488844 1283 | 38586 697667 1284 | 7586 751616 1285 | 21175 466090 1286 | 38664 26424 1287 | 23304 501322 1288 | 70364 765929 1289 | 63992 313051 1290 | 71090 681519 1291 | 89830 357458 1292 | 570 147901 1293 | 66532 59454 1294 | 6712 434161 1295 | 18097 297637 1296 | 70805 499120 1297 | 80889 416545 1298 | 43821 895480 1299 | 68366 829787 1300 | 63373 732899 1301 | 50381 63858 1302 | 7901 655462 1303 | 62619 341677 1304 | 91920 624634 1305 | 70764 784279 1306 | 2059 278186 1307 | 85468 673078 1308 | 12552 418013 1309 | 5896 325162 1310 | 62430 605917 1311 | 76992 535086 1312 | 45461 397461 1313 | 54993 913830 1314 | 43240 675647 1315 | 80385 522241 1316 | 83514 753084 1317 | 98309 754552 1318 | 86121 416545 1319 | 67246 125514 1320 | 30880 520773 1321 | 59703 750882 1322 | 36168 131753 1323 | 95970 443703 1324 | 89830 215062 1325 | 54619 139460 1326 | 68806 166251 1327 | 63838 774003 1328 | 57130 752350 1329 | 59916 269378 1330 | 47473 151571 1331 | 50161 189372 1332 | 74280 101659 1333 | 85325 139460 1334 | 60781 219099 1335 | 76559 200749 1336 | 52116 908692 1337 | 49622 296536 1338 | 21946 346448 1339 | 36096 283324 1340 | 79011 371404 1341 | 43283 71932 1342 | 98283 748313 1343 | 92302 372872 1344 | 59636 18717 1345 | 12685 213227 1346 | 46606 396727 1347 | 65778 766296 1348 | 40050 586833 1349 | 76545 800060 1350 | 91625 262772 1351 | 77062 657297 1352 | 55639 653994 1353 | 47508 401865 1354 | 47014 254331 1355 | 84852 901719 1356 | 85699 826484 1357 | 47801 697667 1358 | 21090 358559 1359 | 74500 729596 1360 | 26675 439299 1361 | 11529 563712 1362 | 64648 750515 1363 | 13294 650691 1364 | 56169 785380 1365 | 77477 122945 1366 | 13054 497652 1367 | 82488 407370 1368 | 88520 572520 1369 | 54807 824282 1370 | 92232 485541 1371 | 41350 183500 1372 | 4210 93585 1373 | 14633 571052 1374 | 42985 895113 1375 | 14554 219099 1376 | 59389 127716 1377 | 59939 515268 1378 | 80608 867588 1379 | 21159 506093 1380 | 89176 708677 1381 | 69493 248826 1382 | 22030 656563 1383 | 69861 135790 1384 | 84028 437097 1385 | 91440 539123 1386 | 93365 355990 1387 | 49137 758222 1388 | 3012 426087 1389 | 40656 380946 1390 | 93160 307179 1391 | 42650 716751 1392 | 42173 901719 1393 | 91991 254698 1394 | 16743 113403 1395 | 93022 689593 1396 | 69835 482605 1397 | 81294 240752 1398 | 4102 214695 1399 | 58249 689960 1400 | 82333 102026 1401 | 98766 82942 1402 | 59003 854009 1403 | 20720 794555 1404 | 82438 627937 1405 | 50737 824649 1406 | 34009 387185 1407 | 50934 688492 1408 | 38 193042 1409 | 49695 124046 1410 | 36011 148635 1411 | 75906 477467 1412 | 23208 106430 1413 | 75575 460585 1414 | 91908 876763 1415 | 90815 457649 1416 | 55440 286994 1417 | 81088 556739 1418 | 24449 670509 1419 | 27314 354155 1420 | 96561 674913 1421 | 7671 53215 1422 | 74957 275617 1423 | 40227 763360 1424 | 45323 259469 1425 | 65200 850339 1426 | 99897 100558 1427 | 29342 915665 1428 | 70436 281122 1429 | 72472 625001 1430 | 70108 707943 1431 | 56293 486642 1432 | 44553 548665 1433 | 53014 174692 1434 | 77604 282223 1435 | 20418 131386 1436 | 61989 263139 1437 | 1044 96154 1438 | 99689 481137 1439 | 1767 361862 1440 | 47639 354155 1441 | 21168 297637 1442 | 88123 107898 1443 | 49332 150103 1444 | 74121 757855 1445 | 81465 618395 1446 | 38206 32663 1447 | 97843 291031 1448 | 12189 321859 1449 | 86686 786481 1450 | 60674 277452 1451 | 93398 149369 1452 | 51253 655095 1453 | 10387 248459 1454 | 89718 342411 1455 | 80724 340576 1456 | 93197 916399 1457 | 37080 114871 1458 | 50843 198547 1459 | 89598 871625 1460 | 12312 414710 1461 | 55545 671977 1462 | 33926 24222 1463 | 28073 355256 1464 | 5458 117073 1465 | 74328 159645 1466 | 40783 694731 1467 | 43145 478935 1468 | 31814 485174 1469 | 13902 580227 1470 | 88724 500955 1471 | 38965 371037 1472 | 11459 367367 1473 | 15344 881901 1474 | 35528 239651 1475 | 55134 288462 1476 | 25542 742808 1477 | 69900 563712 1478 | 26731 109366 1479 | 63577 849972 1480 | 49508 693630 1481 | 84399 611789 1482 | 52416 840063 1483 | 93648 251028 1484 | 95476 872359 1485 | 3187 366633 1486 | 12248 460585 1487 | 34039 273048 1488 | 20703 817676 1489 | 24080 246991 1490 | 21147 766663 1491 | 42441 706475 1492 | 34442 227173 1493 | 98503 498019 1494 | 25370 48077 1495 | 63909 355256 1496 | 41844 649957 1497 | 87081 205887 1498 | 93003 291765 1499 | 65971 222769 1500 | 5588 187904 1501 | 29109 751616 1502 | 30206 773269 1503 | 74415 282223 1504 | 49666 142029 1505 | 12160 334337 1506 | 61391 612156 1507 | 16883 320758 1508 | 22904 701704 1509 | 17185 495083 1510 | 68605 911995 1511 | 45028 850339 1512 | 46525 702438 1513 | 7415 901719 1514 | 4949 13579 1515 | 93905 306078 1516 | 66836 177995 1517 | 61137 503157 1518 | 78526 78905 1519 | 83955 585365 1520 | 42234 303876 1521 | 2231 12845 1522 | 7371 435996 1523 | 11305 35232 1524 | 84680 894746 1525 | 17656 548298 1526 | 92680 561877 1527 | 76721 71198 1528 | 73502 351219 1529 | 11350 292132 1530 | 48803 216897 1531 | 23505 95787 1532 | 27161 205887 1533 | 61929 23488 1534 | 10102 657297 1535 | 68383 166985 1536 | 74986 355990 1537 | 14824 99824 1538 | 46769 305711 1539 | 38928 851440 1540 | 8275 475999 1541 | 29238 389754 1542 | 17321 507194 1543 | 24813 314886 1544 | 26684 131386 1545 | 87542 716384 1546 | 25068 602981 1547 | 16289 528480 1548 | 25034 177995 1549 | 87282 374707 1550 | 93233 12111 1551 | 75751 95787 1552 | 74971 578392 1553 | 15024 167719 1554 | 8575 90282 1555 | 43819 484440 1556 | 98201 480403 1557 | 26792 539490 1558 | 10120 576924 1559 | 91794 650324 1560 | 83080 458750 1561 | 7519 268277 1562 | 52499 556372 1563 | 53658 59821 1564 | 76769 801528 1565 | 9427 841531 1566 | 76653 680051 1567 | 64994 155975 1568 | 67442 85878 1569 | 40708 339475 1570 | 60764 561877 1571 | 11923 55784 1572 | 9259 430858 1573 | 55047 909793 1574 | 78164 55784 1575 | 71911 104962 1576 | 38774 473430 1577 | 35233 487009 1578 | 21308 248826 1579 | 94485 94319 1580 | 4136 433060 1581 | 31699 892177 1582 | 45627 826851 1583 | 76135 732532 1584 | 45069 837494 1585 | 32600 296169 1586 | 41570 397828 1587 | 27397 583163 1588 | 12874 564446 1589 | 13012 424252 1590 | 70901 97622 1591 | 2380 917133 1592 | 67070 42939 1593 | 8415 653994 1594 | 46885 607018 1595 | 24645 623166 1596 | 9128 663536 1597 | 61400 91383 1598 | 72496 772902 1599 | 21292 647021 1600 | 49539 42205 1601 | 59454 703539 1602 | 95124 598944 1603 | 25981 172490 1604 | 76874 88080 1605 | 34181 275250 1606 | 44294 721522 1607 | 90144 656196 1608 | 71947 761892 1609 | 65776 32296 1610 | 18777 473063 1611 | 48336 398195 1612 | 7449 537655 1613 | 98163 206988 1614 | 57248 187170 1615 | 60815 420215 1616 | 66631 656930 1617 | 87857 487009 1618 | 98344 660233 1619 | 10016 434161 1620 | 51143 609220 1621 | 524 129551 1622 | 70013 310849 1623 | 53381 319657 1624 | 14158 463521 1625 | 59224 30461 1626 | 20007 786481 1627 | 51377 493982 1628 | 73939 269745 1629 | 77267 801895 1630 | 54310 530315 1631 | 27397 485908 1632 | 9297 860615 1633 | 56736 684822 1634 | 3659 167719 1635 | 83904 334337 1636 | 1847 466457 1637 | 65951 688492 1638 | 53628 435996 1639 | 80153 635644 1640 | 19785 529214 1641 | 64117 841531 1642 | 85252 714916 1643 | 50684 611055 1644 | 12000 847770 1645 | 20243 135056 1646 | 6837 602614 1647 | 67661 221668 1648 | 34715 665738 1649 | 41067 565547 1650 | 9900 434895 1651 | 12020 854376 1652 | 6644 217264 1653 | 8177 873827 1654 | 76047 571052 1655 | 17721 380212 1656 | 29250 898049 1657 | 68348 759323 1658 | 18477 739138 1659 | 6079 654728 1660 | 68237 136157 1661 | 75640 211025 1662 | 86480 374707 1663 | 78267 347549 1664 | 1607 430858 1665 | 357 297270 1666 | 73347 378744 1667 | 61634 703172 1668 | 3591 637846 1669 | 48059 54683 1670 | 34595 843366 1671 | 58098 432326 1672 | 41236 420215 1673 | 42969 617661 1674 | 61465 699502 1675 | 89789 867955 1676 | 41884 679684 1677 | 7832 644085 1678 | 69157 703539 1679 | 5813 55784 1680 | 25753 909793 1681 | 84882 232311 1682 | 28236 792353 1683 | 66425 307179 1684 | 2867 126248 1685 | 31129 510497 1686 | 12334 481137 1687 | 61304 343145 1688 | 76689 696566 1689 | 92490 815107 1690 | 62682 829787 1691 | 31924 73400 1692 | 6453 217264 1693 | 82962 768131 1694 | 6502 908692 1695 | 88994 531783 1696 | 19614 77437 1697 | 27 778774 1698 | 32957 197079 1699 | 42307 901352 1700 | 98965 849238 1701 | 67907 93218 1702 | 24720 845568 1703 | 8135 531783 1704 | 85993 579860 1705 | 49595 798959 1706 | 19044 590136 1707 | 66428 464989 1708 | 96965 610688 1709 | 39167 422784 1710 | 14485 517470 1711 | 33322 292866 1712 | 56871 622432 1713 | 35593 187170 1714 | 71130 298738 1715 | 94548 31562 1716 | 89781 341310 1717 | 11698 841898 1718 | 33541 476733 1719 | 84401 208456 1720 | 23782 871992 1721 | 71776 150103 1722 | 89303 524810 1723 | 97012 59821 1724 | 34276 40370 1725 | 67799 633075 1726 | 48738 80373 1727 | 73587 75969 1728 | 80505 658398 1729 | 99460 243688 1730 | 16396 538756 1731 | 29282 255799 1732 | 76269 712347 1733 | 9079 838962 1734 | 90505 875662 1735 | 8527 843733 1736 | 97741 561143 1737 | 62909 58353 1738 | 93060 685556 1739 | 25098 82942 1740 | 33019 434161 1741 | 99158 521507 1742 | 20722 517103 1743 | 21072 32296 1744 | 76023 423518 1745 | 38953 587567 1746 | 22750 746845 1747 | 65705 419481 1748 | 6688 163315 1749 | 92015 565914 1750 | 1060 461686 1751 | 32277 192675 1752 | 40017 244422 1753 | 86073 401498 1754 | 45443 785747 1755 | 67915 350118 1756 | 20130 790151 1757 | 13588 441868 1758 | 45732 183500 1759 | 37586 690327 1760 | 95720 336539 1761 | 79270 795656 1762 | 15654 522241 1763 | 40442 166985 1764 | 28149 143864 1765 | 74227 125147 1766 | 37929 741340 1767 | 32564 819511 1768 | 88748 697667 1769 | 72932 667940 1770 | 34891 725192 1771 | 85724 344613 1772 | 5130 536187 1773 | 27184 384983 1774 | 3375 469760 1775 | 9147 913830 1776 | 27132 805565 1777 | 60156 914197 1778 | 19480 436363 1779 | 63085 118541 1780 | 95875 678216 1781 | 96386 473797 1782 | 30546 634910 1783 | 2901 522608 1784 | 84309 830888 1785 | 10913 880433 1786 | 52951 568116 1787 | 32186 407370 1788 | 71350 798592 1789 | 29969 495817 1790 | 54597 772902 1791 | 70674 617661 1792 | 21686 226072 1793 | 35620 373239 1794 | 28105 240385 1795 | 72430 546830 1796 | 6095 768131 1797 | 28032 113403 1798 | 56270 489578 1799 | 23400 894746 1800 | 4746 169554 1801 | 76656 698401 1802 | 35221 440400 1803 | 94976 549032 1804 | 3769 224971 1805 | 23568 309014 1806 | 94145 905389 1807 | 59191 684088 1808 | 89326 69363 1809 | 22744 817309 1810 | 10313 615826 1811 | 52897 803730 1812 | 5064 569217 1813 | 30828 354889 1814 | 20704 302041 1815 | 49244 622065 1816 | 35601 670876 1817 | 31530 126248 1818 | 35798 652159 1819 | 50917 133221 1820 | 45780 747946 1821 | 69649 495083 1822 | 77149 17249 1823 | 74561 210658 1824 | 27993 888140 1825 | 91518 127716 1826 | 92589 871258 1827 | 16146 13212 1828 | 31780 486275 1829 | 42676 142396 1830 | 45754 424986 1831 | 20541 23855 1832 | 28 658031 1833 | 63157 28259 1834 | 77586 383148 1835 | 96707 695832 1836 | 76473 162948 1837 | 58300 786481 1838 | 20447 760791 1839 | 13038 55784 1840 | 1823 732899 1841 | 56455 102760 1842 | 43847 598944 1843 | 28197 224604 1844 | 95753 538389 1845 | 77658 736936 1846 | 22300 244789 1847 | 30904 169187 1848 | 64264 774003 1849 | 24516 273782 1850 | 55622 415444 1851 | 90153 399663 1852 | 22516 728862 1853 | 22624 400764 1854 | 55010 444437 1855 | 82924 18717 1856 | 70743 564446 1857 | 72758 278553 1858 | 78699 313785 1859 | 24163 117073 1860 | 39335 492147 1861 | 57124 309748 1862 | 99084 509763 1863 | 97387 295802 1864 | 32194 816942 1865 | 43610 774737 1866 | 31068 205887 1867 | 39599 567382 1868 | 51053 818043 1869 | 97433 571419 1870 | 99402 525911 1871 | 20142 150470 1872 | 32195 513433 1873 | 65175 703539 1874 | 54938 186436 1875 | 66297 121477 1876 | 28273 891810 1877 | 47434 297270 1878 | 65237 477467 1879 | 49416 797124 1880 | 31312 105696 1881 | 10547 394892 1882 | 89521 915665 1883 | 60827 557840 1884 | 15 677849 1885 | 55502 314519 1886 | 91984 166985 1887 | 29946 561877 1888 | 43871 206988 1889 | 70548 47710 1890 | 24782 626102 1891 | 13642 523342 1892 | 57502 387185 1893 | 6224 866854 1894 | 98862 847770 1895 | 92233 60922 1896 | 30741 790518 1897 | 53965 303509 1898 | 69578 647755 1899 | 72056 431592 1900 | 78927 299472 1901 | 92924 277085 1902 | 28940 482972 1903 | 43668 827585 1904 | 52084 134322 1905 | 21674 881534 1906 | 2587 526645 1907 | 48418 16148 1908 | 71794 770333 1909 | 18615 532150 1910 | 82917 279287 1911 | 85208 114137 1912 | 73381 139093 1913 | 49805 231577 1914 | 98845 284058 1915 | 59964 797124 1916 | 74593 721155 1917 | 75721 789417 1918 | 82430 308280 1919 | 79182 468659 1920 | 37556 424986 1921 | 47980 588301 1922 | 26496 773636 1923 | 93251 824282 1924 | 86790 666472 1925 | 98458 429390 1926 | 46383 241119 1927 | 18715 499120 1928 | 1643 826117 1929 | 31313 361862 1930 | 53967 801161 1931 | 37791 758222 1932 | 89490 280755 1933 | 27636 574355 1934 | 78337 83676 1935 | 43324 188638 1936 | 38603 558941 1937 | 23628 742441 1938 | 56106 233779 1939 | 8093 723724 1940 | 41550 419114 1941 | 87109 524810 1942 | 33232 473797 1943 | 97029 819511 1944 | 18545 770333 1945 | 39937 201116 1946 | 77310 897315 1947 | 31291 520039 1948 | 4005 410306 1949 | 38057 384616 1950 | 62018 271947 1951 | 40666 719687 1952 | 59611 549032 1953 | 21189 492881 1954 | 82634 525911 1955 | 92735 873460 1956 | 90651 85511 1957 | 4708 340943 1958 | 12421 57619 1959 | 97722 533985 1960 | 27200 329199 1961 | 50462 662068 1962 | 85191 290664 1963 | 70601 631974 1964 | 49579 456915 1965 | 38894 208823 1966 | 6916 685556 1967 | 47648 62390 1968 | 22752 157076 1969 | 87996 420582 1970 | 91615 167352 1971 | 66420 189372 1972 | 73099 561877 1973 | 16210 423885 1974 | 27646 357825 1975 | 41077 8074 1976 | 82442 527746 1977 | 16816 714549 1978 | 25747 790151 1979 | 35284 623900 1980 | 20134 253964 1981 | 15887 560042 1982 | 83813 12478 1983 | 50283 241486 1984 | 67569 757121 1985 | 95705 683354 1986 | 64813 227540 1987 | 98810 628304 1988 | 40017 360761 1989 | 92262 667940 1990 | 42362 94686 1991 | 50513 26424 1992 | 88057 465723 1993 | 55000 567382 1994 | 98343 472696 1995 | 53524 780976 1996 | 83908 740973 1997 | 82788 711613 1998 | 20058 484073 1999 | 40941 895480 2000 | 76263 416912 2001 | 48122 661701 2002 | -------------------------------------------------------------------------------- /data/tmp.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhpollack/programming_notebooks/168936ac09071c00a61d56b35467d7d27ec61e60/data/tmp.zip -------------------------------------------------------------------------------- /databricks_python_graphframes_GraphExample.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import graphframes as gf\n# note on databricks: sc and sqlContext are preloaded"],"metadata":{},"outputs":[],"execution_count":1},{"cell_type":"code","source":["rawdat = sqlContext.sql(\"SELECT * FROM coursera_algos_mincut2\")\nprint rawdat.first()"],"metadata":{},"outputs":[],"execution_count":2},{"cell_type":"code","source":["def verticesList(row):\n return [(row.vertex, str(row.vertex))]\n\ndef adjacencyList(row):\n e = []\n v1 = row.vertex\n for i in range(1, 40):\n v2 = row[\"C\"+str(i)]\n if v2 is not None:\n edge = (v1, v2) if v1 < v2 else (v2, v1)\n e.append(edge)\n return e\n\n# http://graphframes.github.io/\n# http://cdn2.hubspot.net/hubfs/438089/notebooks/help/Setup_graphframes_package.html\nvertices = rawdat.flatMap(verticesList)\nedges = rawdat.flatMap(adjacencyList)\n# using distinct on edges because this is an undirected graph\nG = gf.GraphFrame(vertices.toDF([\"id\", \"name\"]), edges.distinct().toDF([\"src\", \"dst\"]))\nprint G.edges.collect()\n"],"metadata":{},"outputs":[],"execution_count":3},{"cell_type":"code","source":["print(G.edges.filter(\"src = 1\").collect())"],"metadata":{},"outputs":[],"execution_count":4},{"cell_type":"code","source":["print G.edges.count(), G.vertices.count()\n"],"metadata":{},"outputs":[],"execution_count":5},{"cell_type":"code","source":[""],"metadata":{},"outputs":[],"execution_count":6}],"metadata":{"name":"GraphExample","notebookId":4197684023626608},"nbformat":4,"nbformat_minor":0} 2 | -------------------------------------------------------------------------------- /databricks_python_numpy_spark_kMeansExample.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import numpy as np\n#import pandas as pd\nimport matplotlib.pyplot as plt\n#%matplotlib inline\n\nc_num = 100\ncenters = [tuple([np.random.rand(), np.random.rand()]) for p in range(c_num)]\nprint centers\nsd = np.eye(2) / 10000\nX=[]\nn = 1000\nfor c in centers:\n X.extend(np.random.multivariate_normal(c, sd, size=n))\nX = np.array(X)\nfig = plt.figure()\nplt.scatter(x=X[:,0], y=X[:,1])\ndisplay(fig)\n"],"metadata":{},"outputs":[],"execution_count":1},{"cell_type":"code","source":["from pyspark.mllib.clustering import KMeans, KMeansModel\n\nX_parr = sc.parallelize(X)\n\nmodel = KMeans.train(X_parr, c_num, maxIterations=15, initializationMode=\"random\")\nprint model.centers\npredictions = model.predict(X_parr).collect()\nfig = plt.figure()\nplt.scatter(x=X[:,0], y=X[:,1], c=predictions)\ndisplay(fig)\n"],"metadata":{},"outputs":[],"execution_count":2},{"cell_type":"code","source":[""],"metadata":{},"outputs":[],"execution_count":3}],"metadata":{"name":"kMeansExample","notebookId":3297546998532744},"nbformat":4,"nbformat_minor":0} 2 | -------------------------------------------------------------------------------- /imdb_tokenize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.utils.data as data\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F\n", 13 | "import numpy as np\n", 14 | "import nltk\n", 15 | "import os\n", 16 | "import bs4\n", 17 | "import random\n", 18 | "\n", 19 | "torch.manual_seed(12345)\n", 20 | "random.seed(12345)\n", 21 | "\n", 22 | "nltk.data.path.append(\"/home/david/Programming/data/nltk_data\")" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 25, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "def concat_contractions(tokens):\n", 34 | " contractions = set([\"'ve\", \"'d\", \"'m\", \"'ll\", \"'re\", \"n't\"])\n", 35 | " return [\"\".join(tokens[i]) if (i+1 == len(tokens) or tokens[i+1] not in contractions) else \"\".join(tokens[(i):(i+2)]) for i in range(len(tokens)) if tokens[i] not in contractions]\n", 36 | "\n", 37 | "def data_processing(ds_paths, max_len=500, split_ratio=1.0):\n", 38 | " ds = []\n", 39 | " for i, tfp in enumerate(ds_paths):\n", 40 | " idx, rating = os.path.basename(tfp).split(\".\")[0].split(\"_\")\n", 41 | " with open(tfp, \"r\") as f:\n", 42 | " raw = f.readlines()\n", 43 | " raw = bs4.BeautifulSoup(raw[0], \"html5lib\")\n", 44 | " txt = raw.get_text(separator=' ')\n", 45 | " tokens = nltk.word_tokenize(txt)\n", 46 | " tokens = concat_contractions(tokens)\n", 47 | " #tokens = [vocab[w] if w in vocab else len(vocab) for w in tokens] # keep out of vocab\n", 48 | " tokens = [vocab[w] for w in tokens if w in vocab]\n", 49 | " if len(tokens) > max_len:\n", 50 | " tokens = tokens[:max_len]\n", 51 | " elif len(tokens) < max_len:\n", 52 | " tokens = tokens + [0]*(max_len-len(tokens))\n", 53 | " ds.append((tokens, int(rating)))\n", 54 | " dat, labels = zip(*ds)\n", 55 | " assert split_ratio >= 0. and split_ratio <= 1.0\n", 56 | " if split_ratio == 1.:\n", 57 | " return (dat, labels), (None, None)\n", 58 | " else:\n", 59 | " split_idx = int(len(dat) * split_ratio)\n", 60 | " tidx = list(range(len(dat)))\n", 61 | " random.shuffle(tidx)\n", 62 | " tidx, vidx = tidx[:split_idx], tidx[split_idx:]\n", 63 | " ts, ts_labels = [dat[tid] for tid in tidx], [labels[tid] for tid in tidx]\n", 64 | " vs, vs_labels = [dat[vid] for vid in vidx], [labels[vid] for vid in vidx]\n", 65 | " return (ts, ts_labels), (vs, vs_labels)\n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "imdbEr.txt imdb.vocab README \u001b[0m\u001b[01;34mtest\u001b[0m/ \u001b[01;34mtrain\u001b[0m/\r\n" 78 | ] 79 | }, 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "['/home/david/Programming/data/aclImdb/train/neg/0_3.txt',\n", 84 | " '/home/david/Programming/data/aclImdb/train/neg/10000_4.txt',\n", 85 | " '/home/david/Programming/data/aclImdb/train/neg/10001_4.txt',\n", 86 | " '/home/david/Programming/data/aclImdb/train/neg/10002_1.txt',\n", 87 | " '/home/david/Programming/data/aclImdb/train/neg/10003_1.txt']" 88 | ] 89 | }, 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "IMDB_BASEDIR = \"/home/david/Programming/data/aclImdb\"\n", 97 | "%ls $IMDB_BASEDIR\n", 98 | "train_paths = sorted([f.path for d in [\"pos\", \"neg\"] for f in os.scandir(os.path.join(IMDB_BASEDIR, \"train\", d))])\n", 99 | "test_paths = sorted([f.path for d in [\"pos\", \"neg\"] for f in os.scandir(os.path.join(IMDB_BASEDIR, \"test\", d))])\n", 100 | "\n", 101 | "train_paths[:5]" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": { 108 | "collapsed": true 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "vocab_limit = 5000\n", 113 | "with open(os.path.join(IMDB_BASEDIR, \"imdb.vocab\"), \"r\") as f:\n", 114 | " vocab = {w:(i+1) for i, w in enumerate([l.strip() for l in f.readlines()][:vocab_limit])}\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 26, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "500 500\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "\n", 132 | "trainset, validset = data_processing(train_paths, split_ratio = 0.9)\n", 133 | "print(len(trainset[0][0]), len(validset[0][0]))\n", 134 | "ts, ts_labels = torch.Tensor(trainset[0]).long(), torch.Tensor(trainset[1])\n", 135 | "ts_labels = (ts_labels > 5).float()\n", 136 | "dts = data.TensorDataset(ts, ts_labels)\n", 137 | "dlts = data.DataLoader(dts, batch_size=100)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 33, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "2500\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "vs, vs_labels = torch.Tensor(validset[0]).long(), torch.Tensor(validset[1])\n", 155 | "vs_labels = (vs_labels > 5).float()\n", 156 | "dvs = data.TensorDataset(vs, vs_labels)\n", 157 | "dlvs = data.DataLoader(dvs, batch_size=100)\n", 158 | "print(len(vs))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 12, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "25000 2\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "#split_ratio = 0.9\n", 176 | "#split_idx = int(len(trainset) * split_ratio)\n", 177 | "#tidx = list(range(len(trainset)))\n", 178 | "#random.shuffle(tidx)\n", 179 | "#tidx, vidx = tidx[:split_idx], tidx[split_idx:]\n", 180 | "#ts, ts_labels = [trainset[tid] for tid in tidx], [train_labels[tid] for tid in tidx]\n", 181 | "#vs, vs_labels = [trainset[vid] for vid in vidx], [train_labels[vid] for vid in vidx]\n", 182 | "\n", 183 | "#ts, ts_labels = torch.Tensor(ts).long(), torch.Tensor(ts_labels)\n", 184 | "#ts_labels = (ts_labels > 5).float()\n", 185 | "#dts = data.TensorDataset(ts, ts_labels)\n", 186 | "#dlts = data.DataLoader(dts, batch_size=100)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 60, 192 | "metadata": { 193 | "scrolled": false 194 | }, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "0 5000\n", 201 | "torch.Size([22500, 500])\n", 202 | "SingleHiddenNN (\n", 203 | " (emb): Embedding(5001, 32)\n", 204 | " (fc): Linear (16000 -> 100)\n", 205 | " (relu): SELU\n", 206 | " (dropout): Dropout (p = 0.7)\n", 207 | " (out): Linear (100 -> 1)\n", 208 | " (sigmoid): Sigmoid ()\n", 209 | ")\n", 210 | "epoch 1 had a loss of 180.21:\n", 211 | "epoch 2 had a loss of 149.41:\n", 212 | "epoch 3 had a loss of 157.26:\n", 213 | "epoch 4 had a loss of 136.89:\n", 214 | "epoch 5 had a loss of 146.05:\n", 215 | "epoch 6 had a loss of 128.16:\n", 216 | "correct: 1506, total: 2500\n", 217 | "validation accuracy: 60.24\n", 218 | "epoch 7 had a loss of 138.26:\n", 219 | "epoch 8 had a loss of 122.53:\n", 220 | "epoch 9 had a loss of 131.25:\n", 221 | "epoch 10 had a loss of 116.95:\n", 222 | "epoch 11 had a loss of 125.69:\n", 223 | "correct: 1515, total: 2500\n", 224 | "validation accuracy: 60.60\n", 225 | "epoch 12 had a loss of 113.37:\n", 226 | "epoch 13 had a loss of 120.07:\n", 227 | "epoch 14 had a loss of 106.46:\n", 228 | "epoch 15 had a loss of 116.02:\n", 229 | "epoch 16 had a loss of 102.44:\n", 230 | "correct: 1567, total: 2500\n", 231 | "validation accuracy: 62.68\n", 232 | "epoch 17 had a loss of 111.7:\n", 233 | "epoch 18 had a loss of 99.776:\n", 234 | "epoch 19 had a loss of 109.25:\n", 235 | "epoch 20 had a loss of 96.52:\n", 236 | "epoch 21 had a loss of 104.85:\n", 237 | "correct: 1563, total: 2500\n", 238 | "validation accuracy: 62.52\n", 239 | "epoch 22 had a loss of 91.543:\n", 240 | "epoch 23 had a loss of 100.93:\n", 241 | "epoch 24 had a loss of 89.977:\n", 242 | "epoch 25 had a loss of 97.779:\n", 243 | "epoch 26 had a loss of 85.79:\n", 244 | "correct: 1630, total: 2500\n", 245 | "validation accuracy: 65.20\n", 246 | "epoch 27 had a loss of 94.657:\n", 247 | "epoch 28 had a loss of 82.805:\n", 248 | "epoch 29 had a loss of 91.25:\n", 249 | "epoch 30 had a loss of 81.997:\n", 250 | "epoch 31 had a loss of 90.426:\n", 251 | "correct: 1589, total: 2500\n", 252 | "validation accuracy: 63.56\n", 253 | "epoch 32 had a loss of 78.93:\n", 254 | "epoch 33 had a loss of 86.228:\n", 255 | "epoch 34 had a loss of 76.29:\n", 256 | "epoch 35 had a loss of 83.203:\n", 257 | "epoch 36 had a loss of 74.013:\n", 258 | "correct: 1689, total: 2500\n", 259 | "validation accuracy: 67.56\n", 260 | "epoch 37 had a loss of 82.516:\n", 261 | "epoch 38 had a loss of 72.502:\n", 262 | "epoch 39 had a loss of 79.541:\n", 263 | "epoch 40 had a loss of 70.698:\n", 264 | "epoch 41 had a loss of 76.423:\n", 265 | "correct: 1676, total: 2500\n", 266 | "validation accuracy: 67.04\n", 267 | "epoch 42 had a loss of 66.121:\n", 268 | "epoch 43 had a loss of 74.163:\n", 269 | "epoch 44 had a loss of 65.6:\n", 270 | "epoch 45 had a loss of 72.702:\n", 271 | "epoch 46 had a loss of 63.273:\n", 272 | "correct: 1744, total: 2500\n", 273 | "validation accuracy: 69.76\n", 274 | "epoch 47 had a loss of 70.441:\n", 275 | "epoch 48 had a loss of 61.714:\n", 276 | "epoch 49 had a loss of 69.674:\n", 277 | "epoch 50 had a loss of 60.041:\n", 278 | "epoch 51 had a loss of 67.995:\n", 279 | "correct: 1672, total: 2500\n", 280 | "validation accuracy: 66.88\n", 281 | "epoch 52 had a loss of 58.486:\n", 282 | "epoch 53 had a loss of 67.985:\n", 283 | "epoch 54 had a loss of 56.784:\n", 284 | "epoch 55 had a loss of 65.629:\n", 285 | "epoch 56 had a loss of 55.778:\n", 286 | "correct: 1778, total: 2500\n", 287 | "validation accuracy: 71.12\n", 288 | "epoch 57 had a loss of 63.425:\n", 289 | "epoch 58 had a loss of 54.342:\n", 290 | "epoch 59 had a loss of 60.929:\n", 291 | "epoch 60 had a loss of 52.447:\n", 292 | "epoch 61 had a loss of 59.177:\n", 293 | "correct: 1784, total: 2500\n", 294 | "validation accuracy: 71.36\n", 295 | "epoch 62 had a loss of 50.707:\n", 296 | "epoch 63 had a loss of 59.645:\n", 297 | "epoch 64 had a loss of 48.444:\n", 298 | "epoch 65 had a loss of 58.692:\n", 299 | "epoch 66 had a loss of 47.309:\n", 300 | "correct: 1796, total: 2500\n", 301 | "validation accuracy: 71.84\n", 302 | "epoch 67 had a loss of 56.025:\n", 303 | "epoch 68 had a loss of 46.815:\n", 304 | "epoch 69 had a loss of 57.314:\n", 305 | "epoch 70 had a loss of 45.563:\n", 306 | "epoch 71 had a loss of 54.548:\n", 307 | "correct: 1807, total: 2500\n", 308 | "validation accuracy: 72.28\n", 309 | "epoch 72 had a loss of 43.947:\n", 310 | "epoch 73 had a loss of 51.221:\n", 311 | "epoch 74 had a loss of 43.44:\n", 312 | "epoch 75 had a loss of 50.027:\n", 313 | "epoch 76 had a loss of 41.055:\n", 314 | "correct: 1806, total: 2500\n", 315 | "validation accuracy: 72.24\n", 316 | "epoch 77 had a loss of 49.864:\n", 317 | "epoch 78 had a loss of 40.638:\n", 318 | "epoch 79 had a loss of 47.843:\n", 319 | "epoch 80 had a loss of 40.034:\n", 320 | "epoch 81 had a loss of 46.633:\n", 321 | "correct: 1817, total: 2500\n", 322 | "validation accuracy: 72.68\n", 323 | "epoch 82 had a loss of 38.564:\n", 324 | "epoch 83 had a loss of 45.391:\n", 325 | "epoch 84 had a loss of 38.921:\n", 326 | "epoch 85 had a loss of 44.739:\n", 327 | "epoch 86 had a loss of 36.371:\n", 328 | "correct: 1825, total: 2500\n", 329 | "validation accuracy: 73.00\n", 330 | "epoch 87 had a loss of 42.843:\n", 331 | "epoch 88 had a loss of 36.642:\n", 332 | "epoch 89 had a loss of 44.325:\n", 333 | "epoch 90 had a loss of 35.034:\n", 334 | "epoch 91 had a loss of 43.729:\n", 335 | "correct: 1835, total: 2500\n", 336 | "validation accuracy: 73.40\n", 337 | "epoch 92 had a loss of 34.83:\n", 338 | "epoch 93 had a loss of 43.054:\n", 339 | "epoch 94 had a loss of 34.262:\n", 340 | "epoch 95 had a loss of 42.408:\n", 341 | "epoch 96 had a loss of 33.244:\n", 342 | "correct: 1844, total: 2500\n", 343 | "validation accuracy: 73.76\n", 344 | "epoch 97 had a loss of 40.137:\n", 345 | "epoch 98 had a loss of 32.036:\n", 346 | "epoch 99 had a loss of 40.617:\n", 347 | "epoch 100 had a loss of 31.283:\n", 348 | "epoch 101 had a loss of 38.893:\n", 349 | "correct: 1818, total: 2500\n", 350 | "validation accuracy: 72.72\n", 351 | "epoch 102 had a loss of 30.514:\n", 352 | "epoch 103 had a loss of 36.414:\n", 353 | "epoch 104 had a loss of 29.846:\n", 354 | "epoch 105 had a loss of 35.484:\n", 355 | "epoch 106 had a loss of 28.688:\n", 356 | "correct: 1844, total: 2500\n", 357 | "validation accuracy: 73.76\n", 358 | "epoch 107 had a loss of 32.477:\n", 359 | "epoch 108 had a loss of 27.244:\n", 360 | "epoch 109 had a loss of 33.256:\n", 361 | "epoch 110 had a loss of 26.222:\n", 362 | "epoch 111 had a loss of 31.101:\n", 363 | "correct: 1846, total: 2500\n", 364 | "validation accuracy: 73.84\n", 365 | "epoch 112 had a loss of 26.721:\n", 366 | "epoch 113 had a loss of 30.268:\n", 367 | "epoch 114 had a loss of 25.618:\n", 368 | "epoch 115 had a loss of 28.931:\n", 369 | "epoch 116 had a loss of 25.284:\n", 370 | "correct: 1859, total: 2500\n", 371 | "validation accuracy: 74.36\n", 372 | "epoch 117 had a loss of 28.243:\n", 373 | "epoch 118 had a loss of 24.153:\n", 374 | "epoch 119 had a loss of 28.069:\n", 375 | "epoch 120 had a loss of 23.269:\n", 376 | "epoch 121 had a loss of 27.448:\n", 377 | "correct: 1866, total: 2500\n", 378 | "validation accuracy: 74.64\n", 379 | "epoch 122 had a loss of 22.092:\n", 380 | "epoch 123 had a loss of 25.823:\n", 381 | "epoch 124 had a loss of 22.316:\n", 382 | "epoch 125 had a loss of 26.363:\n", 383 | "epoch 126 had a loss of 22.562:\n", 384 | "correct: 1869, total: 2500\n", 385 | "validation accuracy: 74.76\n", 386 | "epoch 127 had a loss of 25.361:\n", 387 | "epoch 128 had a loss of 21.893:\n", 388 | "epoch 129 had a loss of 23.67:\n", 389 | "epoch 130 had a loss of 20.472:\n", 390 | "epoch 131 had a loss of 24.869:\n", 391 | "correct: 1872, total: 2500\n", 392 | "validation accuracy: 74.88\n", 393 | "epoch 132 had a loss of 20.084:\n", 394 | "epoch 133 had a loss of 23.902:\n", 395 | "epoch 134 had a loss of 20.17:\n", 396 | "epoch 135 had a loss of 24.072:\n", 397 | "epoch 136 had a loss of 19.705:\n", 398 | "correct: 1873, total: 2500\n", 399 | "validation accuracy: 74.92\n", 400 | "epoch 137 had a loss of 23.764:\n", 401 | "epoch 138 had a loss of 20.042:\n", 402 | "epoch 139 had a loss of 24.471:\n", 403 | "epoch 140 had a loss of 19.268:\n", 404 | "epoch 141 had a loss of 26.174:\n", 405 | "correct: 1865, total: 2500\n", 406 | "validation accuracy: 74.60\n", 407 | "epoch 142 had a loss of 19.303:\n", 408 | "epoch 143 had a loss of 26.392:\n", 409 | "epoch 144 had a loss of 20.878:\n", 410 | "epoch 145 had a loss of 29.359:\n", 411 | "epoch 146 had a loss of 19.566:\n", 412 | "correct: 1882, total: 2500\n", 413 | "validation accuracy: 75.28\n", 414 | "epoch 147 had a loss of 29.096:\n", 415 | "epoch 148 had a loss of 17.609:\n", 416 | "epoch 149 had a loss of 28.129:\n", 417 | "epoch 150 had a loss of 16.877:\n", 418 | "epoch 151 had a loss of 25.728:\n", 419 | "correct: 1887, total: 2500\n", 420 | "validation accuracy: 75.48\n", 421 | "epoch 152 had a loss of 16.436:\n", 422 | "epoch 153 had a loss of 24.018:\n", 423 | "epoch 154 had a loss of 15.716:\n", 424 | "epoch 155 had a loss of 20.792:\n", 425 | "epoch 156 had a loss of 14.76:\n", 426 | "correct: 1888, total: 2500\n", 427 | "validation accuracy: 75.52\n", 428 | "epoch 157 had a loss of 19.901:\n", 429 | "epoch 158 had a loss of 14.685:\n", 430 | "epoch 159 had a loss of 20.57:\n", 431 | "epoch 160 had a loss of 14.596:\n", 432 | "epoch 161 had a loss of 20.118:\n", 433 | "correct: 1909, total: 2500\n", 434 | "validation accuracy: 76.36\n", 435 | "epoch 162 had a loss of 14.506:\n", 436 | "epoch 163 had a loss of 18.156:\n", 437 | "epoch 164 had a loss of 14.234:\n", 438 | "epoch 165 had a loss of 18.269:\n", 439 | "epoch 166 had a loss of 14.057:\n", 440 | "correct: 1897, total: 2500\n", 441 | "validation accuracy: 75.88\n", 442 | "epoch 167 had a loss of 16.399:\n", 443 | "epoch 168 had a loss of 13.191:\n", 444 | "epoch 169 had a loss of 16.284:\n", 445 | "epoch 170 had a loss of 12.713:\n", 446 | "epoch 171 had a loss of 14.768:\n", 447 | "correct: 1897, total: 2500\n", 448 | "validation accuracy: 75.88\n", 449 | "epoch 172 had a loss of 12.85:\n", 450 | "epoch 173 had a loss of 14.618:\n", 451 | "epoch 174 had a loss of 12.6:\n", 452 | "epoch 175 had a loss of 13.754:\n", 453 | "epoch 176 had a loss of 12.1:\n", 454 | "correct: 1896, total: 2500\n", 455 | "validation accuracy: 75.84\n", 456 | "epoch 177 had a loss of 14.189:\n", 457 | "epoch 178 had a loss of 12.096:\n", 458 | "epoch 179 had a loss of 13.668:\n", 459 | "epoch 180 had a loss of 13.047:\n", 460 | "epoch 181 had a loss of 14.236:\n", 461 | "correct: 1884, total: 2500\n", 462 | "validation accuracy: 75.36\n", 463 | "epoch 182 had a loss of 12.84:\n", 464 | "epoch 183 had a loss of 12.927:\n", 465 | "epoch 184 had a loss of 12.742:\n", 466 | "epoch 185 had a loss of 12.622:\n", 467 | "epoch 186 had a loss of 13.217:\n", 468 | "correct: 1904, total: 2500\n", 469 | "validation accuracy: 76.16\n", 470 | "epoch 187 had a loss of 13.594:\n", 471 | "epoch 188 had a loss of 12.24:\n", 472 | "epoch 189 had a loss of 12.772:\n", 473 | "epoch 190 had a loss of 12.673:\n", 474 | "epoch 191 had a loss of 12.067:\n", 475 | "correct: 1835, total: 2500\n", 476 | "validation accuracy: 73.40\n" 477 | ] 478 | }, 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "epoch 192 had a loss of 13.577:\n", 484 | "epoch 193 had a loss of 13.533:\n", 485 | "epoch 194 had a loss of 13.713:\n", 486 | "epoch 195 had a loss of 12.667:\n", 487 | "epoch 196 had a loss of 11.576:\n", 488 | "correct: 1915, total: 2500\n", 489 | "validation accuracy: 76.60\n", 490 | "epoch 197 had a loss of 12.638:\n", 491 | "epoch 198 had a loss of 10.731:\n", 492 | "epoch 199 had a loss of 12.506:\n", 493 | "epoch 200 had a loss of 10.383:\n", 494 | "epoch 201 had a loss of 12.581:\n", 495 | "correct: 1844, total: 2500\n", 496 | "validation accuracy: 73.76\n", 497 | "epoch 202 had a loss of 12.099:\n", 498 | "epoch 203 had a loss of 13.259:\n", 499 | "epoch 204 had a loss of 10.19:\n", 500 | "epoch 205 had a loss of 14.359:\n", 501 | "epoch 206 had a loss of 10.284:\n", 502 | "correct: 1917, total: 2500\n", 503 | "validation accuracy: 76.68\n", 504 | "epoch 207 had a loss of 14.6:\n", 505 | "epoch 208 had a loss of 10.931:\n", 506 | "epoch 209 had a loss of 12.596:\n", 507 | "epoch 210 had a loss of 10.348:\n", 508 | "epoch 211 had a loss of 11.99:\n", 509 | "correct: 1856, total: 2500\n", 510 | "validation accuracy: 74.24\n", 511 | "epoch 212 had a loss of 11.241:\n", 512 | "epoch 213 had a loss of 11.915:\n", 513 | "epoch 214 had a loss of 11.312:\n", 514 | "epoch 215 had a loss of 10.298:\n", 515 | "epoch 216 had a loss of 10.769:\n", 516 | "correct: 1921, total: 2500\n", 517 | "validation accuracy: 76.84\n", 518 | "epoch 217 had a loss of 9.7846:\n", 519 | "epoch 218 had a loss of 10.23:\n", 520 | "epoch 219 had a loss of 8.9255:\n", 521 | "epoch 220 had a loss of 9.081:\n", 522 | "epoch 221 had a loss of 8.568:\n", 523 | "correct: 1853, total: 2500\n", 524 | "validation accuracy: 74.12\n", 525 | "epoch 222 had a loss of 8.9211:\n", 526 | "epoch 223 had a loss of 8.0079:\n", 527 | "epoch 224 had a loss of 9.1798:\n", 528 | "epoch 225 had a loss of 7.9771:\n", 529 | "epoch 226 had a loss of 8.3975:\n", 530 | "correct: 1921, total: 2500\n", 531 | "validation accuracy: 76.84\n", 532 | "epoch 227 had a loss of 9.0924:\n", 533 | "epoch 228 had a loss of 8.2583:\n", 534 | "epoch 229 had a loss of 8.2822:\n", 535 | "epoch 230 had a loss of 7.976:\n", 536 | "epoch 231 had a loss of 9.0271:\n", 537 | "correct: 1854, total: 2500\n", 538 | "validation accuracy: 74.16\n", 539 | "epoch 232 had a loss of 8.3784:\n", 540 | "epoch 233 had a loss of 8.3691:\n", 541 | "epoch 234 had a loss of 8.1224:\n", 542 | "epoch 235 had a loss of 8.0464:\n", 543 | "epoch 236 had a loss of 7.1277:\n", 544 | "correct: 1926, total: 2500\n", 545 | "validation accuracy: 77.04\n", 546 | "epoch 237 had a loss of 8.9603:\n", 547 | "epoch 238 had a loss of 8.8157:\n", 548 | "epoch 239 had a loss of 9.0398:\n", 549 | "epoch 240 had a loss of 6.0725:\n", 550 | "epoch 241 had a loss of 9.0863:\n", 551 | "correct: 1908, total: 2500\n", 552 | "validation accuracy: 76.32\n", 553 | "epoch 242 had a loss of 6.1797:\n", 554 | "epoch 243 had a loss of 8.3911:\n", 555 | "epoch 244 had a loss of 6.0233:\n", 556 | "epoch 245 had a loss of 7.594:\n", 557 | "epoch 246 had a loss of 5.2482:\n", 558 | "correct: 1922, total: 2500\n", 559 | "validation accuracy: 76.88\n", 560 | "epoch 247 had a loss of 8.3338:\n", 561 | "epoch 248 had a loss of 6.1232:\n", 562 | "epoch 249 had a loss of 8.2647:\n", 563 | "epoch 250 had a loss of 6.3322:\n" 564 | ] 565 | } 566 | ], 567 | "source": [ 568 | "class SingleHiddenNN(nn.Module):\n", 569 | " def __init__(self, vocab_size, max_len, embed_elems, batch_size):\n", 570 | " super(SingleHiddenNN, self).__init__()\n", 571 | " self.vocab_size = vocab_size\n", 572 | " self.embed_elems = embed_elems\n", 573 | " self.max_len = max_len\n", 574 | " self.emb = nn.Embedding(self.vocab_size+1, self.embed_elems)\n", 575 | " self.fc = nn.Linear(int(self.max_len * self.embed_elems), 100)\n", 576 | " self.relu = nn.SELU()\n", 577 | " self.dropout = nn.Dropout(0.7)\n", 578 | " self.out = nn.Linear(100, 1)\n", 579 | " self.sigmoid = nn.Sigmoid()\n", 580 | " def forward(self, input):\n", 581 | " x = self.emb(input)\n", 582 | " x = x.view(input.size(0), -1)\n", 583 | " x = self.fc(x)\n", 584 | " x = self.relu(x)\n", 585 | " x = self.dropout(x)\n", 586 | " x = self.out(x)\n", 587 | " x = self.sigmoid(x)\n", 588 | " return x.view(-1)\n", 589 | "\n", 590 | "print(ts.min(), ts.max())\n", 591 | "print(ts.size())\n", 592 | "model = SingleHiddenNN(len(vocab), 500, 32, 100)\n", 593 | "print(model)\n", 594 | "criterion = nn.BCELoss()\n", 595 | "optimizer = []\n", 596 | "optimizer += [torch.optim.Adam(model.parameters(), lr=0.0001)]\n", 597 | "optimizer += [torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)]\n", 598 | "epochs = 250\n", 599 | "for epoch in range(epochs):\n", 600 | " model.train()\n", 601 | " running_loss = 0\n", 602 | " for i, (mb, tgts) in enumerate(dlts):\n", 603 | " model.zero_grad()\n", 604 | " mb, tgts = torch.autograd.Variable(mb), torch.autograd.Variable(tgts.float())\n", 605 | " out = model(mb)\n", 606 | " loss = criterion(out, tgts)\n", 607 | " loss.backward()\n", 608 | " opt_idx = epoch % 2\n", 609 | " optimizer[opt_idx].step()\n", 610 | " running_loss += loss.data[0]\n", 611 | " print(\"epoch {} had a loss of {:.5}:\".format(epoch+1, running_loss))\n", 612 | " if epoch > 0 and epoch % 5 == 0:\n", 613 | " model.eval()\n", 614 | " correct = 0\n", 615 | " for vmb, vtgts in dlvs:\n", 616 | " vmb, vtgts = torch.autograd.Variable(vmb), torch.autograd.Variable(vtgts.float())\n", 617 | " vout = model(vmb)\n", 618 | " vpred = vout.round()\n", 619 | " correct += (vpred == vtgts).data.sum()\n", 620 | " print(\"correct: {}, total: {}\".format(correct, len(vs)))\n", 621 | " print(\"validation accuracy: {:.2f}\".format(100.*correct/len(vs)))\n" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 56, 627 | "metadata": { 628 | "scrolled": false 629 | }, 630 | "outputs": [ 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | "2019 0.8076 2500\n", 636 | "\n", 637 | " 0 0\n", 638 | " 1 1\n", 639 | " 0 0\n", 640 | " 0 1\n", 641 | " 1 1\n", 642 | " 0 0\n", 643 | " 0 0\n", 644 | " 1 0\n", 645 | " 0 0\n", 646 | " 1 1\n", 647 | " 0 0\n", 648 | " 1 0\n", 649 | " 1 1\n", 650 | " 0 0\n", 651 | " 1 0\n", 652 | " 0 0\n", 653 | " 0 0\n", 654 | " 1 0\n", 655 | " 0 0\n", 656 | " 1 0\n", 657 | " 0 0\n", 658 | " 1 1\n", 659 | " 0 1\n", 660 | " 1 1\n", 661 | " 0 0\n", 662 | " 0 0\n", 663 | " 1 1\n", 664 | " 0 1\n", 665 | " 1 1\n", 666 | " 1 1\n", 667 | " 1 0\n", 668 | " 1 1\n", 669 | " 0 0\n", 670 | " 0 0\n", 671 | " 0 0\n", 672 | " 0 0\n", 673 | " 0 1\n", 674 | " 0 1\n", 675 | " 1 1\n", 676 | " 1 0\n", 677 | " 0 0\n", 678 | " 1 1\n", 679 | " 1 0\n", 680 | " 1 1\n", 681 | " 0 1\n", 682 | " 0 0\n", 683 | " 1 1\n", 684 | " 0 1\n", 685 | " 0 0\n", 686 | " 0 0\n", 687 | " 0 0\n", 688 | " 1 1\n", 689 | " 0 0\n", 690 | " 1 1\n", 691 | " 0 0\n", 692 | " 0 0\n", 693 | " 0 1\n", 694 | " 1 1\n", 695 | " 1 1\n", 696 | " 1 1\n", 697 | " 1 0\n", 698 | " 0 0\n", 699 | " 1 1\n", 700 | " 0 0\n", 701 | " 0 0\n", 702 | " 0 0\n", 703 | " 0 0\n", 704 | " 1 0\n", 705 | " 1 1\n", 706 | " 0 0\n", 707 | " 0 0\n", 708 | " 0 0\n", 709 | " 1 1\n", 710 | " 1 1\n", 711 | " 1 1\n", 712 | " 1 1\n", 713 | " 0 1\n", 714 | " 0 0\n", 715 | " 1 1\n", 716 | " 0 1\n", 717 | " 0 0\n", 718 | " 0 0\n", 719 | " 1 1\n", 720 | " 1 1\n", 721 | " 0 0\n", 722 | " 1 1\n", 723 | " 1 1\n", 724 | " 0 0\n", 725 | " 1 1\n", 726 | " 0 0\n", 727 | " 1 0\n", 728 | " 0 0\n", 729 | " 1 1\n", 730 | " 1 1\n", 731 | " 0 0\n", 732 | " 0 0\n", 733 | " 0 0\n", 734 | " 0 0\n", 735 | " 0 0\n", 736 | " 1 1\n", 737 | "[torch.FloatTensor of size 100x2]\n", 738 | "\n" 739 | ] 740 | } 741 | ], 742 | "source": [ 743 | "correct = 0\n", 744 | "for i, (mb, tgts) in enumerate(dlvs):\n", 745 | " mb, tgts = torch.autograd.Variable(mb), torch.autograd.Variable(tgts.float())\n", 746 | " out = model(mb)\n", 747 | " pred = out.round()\n", 748 | " correct += (pred == tgts).data.sum()\n", 749 | "print(correct, correct / len(vs), len(vs))\n", 750 | "print(torch.stack((pred.data, tgts.data), 1))" 751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": 53, 756 | "metadata": { 757 | "collapsed": true 758 | }, 759 | "outputs": [], 760 | "source": [ 761 | "torch.save(model.state_dict(), \"model_imdb_20170912.pt\")" 762 | ] 763 | }, 764 | { 765 | "cell_type": "markdown", 766 | "metadata": {}, 767 | "source": [ 768 | "## torchtext" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": null, 774 | "metadata": { 775 | "collapsed": true 776 | }, 777 | "outputs": [], 778 | "source": [ 779 | "import torchtext\n", 780 | "import torchtext.data as ttdata\n", 781 | "TEXT = ttdata.Field()\n", 782 | "LABEL = ttdata.Field(sequential=False)\n", 783 | "imdb_ds = torchtext.datasets.IMDB(\"/home/david/imdb_sentiment/data\", TEXT, LABEL)\n", 784 | "train_iter, test_iter = imdb_ds.iters(batch_size=4, device=-1)" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "metadata": { 791 | "collapsed": true 792 | }, 793 | "outputs": [], 794 | "source": [ 795 | "train_iter, test_iter = imdb_ds.iters(batch_size=25, device=-1)\n", 796 | "for x in train_iter:\n", 797 | " print(x.text, x.label.size())\n", 798 | " break\n", 799 | " " 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "execution_count": 52, 805 | "metadata": {}, 806 | "outputs": [ 807 | { 808 | "name": "stdout", 809 | "output_type": "stream", 810 | "text": [ 811 | "aae_supervised.py notes.md\r\n", 812 | "\u001b[0m\u001b[01;34malgore\u001b[0m/ numpy_reshape_test.ipynb\r\n", 813 | "AlGore_2009.sph pad_test.py\r\n", 814 | "AlGore_2009.stm \u001b[01;34mpcsnpny-20150204-mkj\u001b[0m/\r\n", 815 | "audio_rnn_basic.ipynb \u001b[01;31mpcsnpny-20150204-mkj.tgz\u001b[0m\r\n", 816 | "\u001b[01;35mclipmin.png\u001b[0m \u001b[00;36mpiano2.mp3\u001b[0m\r\n", 817 | "CNN2RNN.ipynb \u001b[00;36mpiano.mp3\u001b[0m\r\n", 818 | "collate_variable.py \u001b[00;36mpiano_new.wav\u001b[0m\r\n", 819 | "\u001b[01;34mdata\u001b[0m/ playground.ipynb\r\n", 820 | "\u001b[01;31mdata.zip\u001b[0m predict_audio.ipynb\r\n", 821 | "deepspeech1d.ipynb Presentation.ipynb\r\n", 822 | "denoising_autoencoder.ipynb prime_factors.py\r\n", 823 | "extract_mnist.py pyaudio-test.py\r\n", 824 | "\u001b[00;36mfile2.wav\u001b[0m \u001b[01;34m__pycache__\u001b[0m/\r\n", 825 | "\u001b[00;36mfile.flac\u001b[0m pytorch_basics.ipynb\r\n", 826 | "\u001b[00;36mfile.mp3\u001b[0m PyTorch Embeddings Test.ipynb\r\n", 827 | "\u001b[00;36mfile.wav\u001b[0m pytorch_tutorial_classify_names.ipynb\r\n", 828 | "\u001b[01;34mfrancemusique\u001b[0m/ rnn_autoencoder.ipynb\r\n", 829 | "G729VAD.ipynb \u001b[01;35mrnn_beispiel1.png\u001b[0m\r\n", 830 | "GRUAutoencoder.ipynb \u001b[01;35mrnn_beispiel2.png\u001b[0m\r\n", 831 | "\u001b[00;36mhallöchen.wav\u001b[0m \u001b[01;35mrnn_predictions_final.png\u001b[0m\r\n", 832 | "\u001b[01;32mhelloworld.py\u001b[0m* \u001b[01;35mrnn_predictions.png\u001b[0m\r\n", 833 | "imdb_tokenize.ipynb \u001b[01;35msmoother1.png\u001b[0m\r\n", 834 | "\u001b[01;35mkmeans1.png\u001b[0m \u001b[01;34msnipsdata\u001b[0m/\r\n", 835 | "\u001b[01;35mkmeans2.png\u001b[0m Snips_Report.ipynb\r\n", 836 | "label_smoothing.ipynb \u001b[01;34mstarters\u001b[0m/\r\n", 837 | "levinson.py startup.sh\r\n", 838 | "librispeech_load_txts.ipynb stm_sph_processing.ipynb\r\n", 839 | "librispeech.py \u001b[01;35mtedlium1.png\u001b[0m\r\n", 840 | "loader.py \u001b[01;35mtedlium2.png\u001b[0m\r\n", 841 | "\u001b[01;34mmini_librispeech_dataset\u001b[0m/ tedlium_labels.ipynb\r\n", 842 | "\u001b[01;35mmlp_beispiel_noise1.png\u001b[0m test_argparse.py\r\n", 843 | "\u001b[01;35mmlp_beispiel_noise2.png\u001b[0m test_nltk.py\r\n", 844 | "\u001b[01;35mmlp_beispiel_nonoise1.png\u001b[0m test_spacy.py\r\n", 845 | "\u001b[01;35mmlp_beispiel_nonoise2.png\u001b[0m timeseries.ipynb\r\n", 846 | "\u001b[01;35mmlp_predictions_noise.png\u001b[0m torchaudio\r\n", 847 | "\u001b[01;35mmlp_predictions.png\u001b[0m torchaudio_vs_librosa.ipynb\r\n", 848 | "\u001b[01;34mMNIST\u001b[0m/ VAD_labeling.ipynb\r\n", 849 | "model_20170912.pt wavelet.ipynb\r\n", 850 | "\u001b[01;34mmodels\u001b[0m/ yesno_torchaudio_playground.ipynb\r\n", 851 | "mu_law_companding.ipynb zero2d_test.py\r\n" 852 | ] 853 | } 854 | ], 855 | "source": [ 856 | "%ls" 857 | ] 858 | } 859 | ], 860 | "metadata": { 861 | "kernelspec": { 862 | "display_name": "Python 3", 863 | "language": "python", 864 | "name": "python3" 865 | }, 866 | "language_info": { 867 | "codemirror_mode": { 868 | "name": "ipython", 869 | "version": 3 870 | }, 871 | "file_extension": ".py", 872 | "mimetype": "text/x-python", 873 | "name": "python", 874 | "nbconvert_exporter": "python", 875 | "pygments_lexer": "ipython3", 876 | "version": "3.6.1" 877 | } 878 | }, 879 | "nbformat": 4, 880 | "nbformat_minor": 2 881 | } 882 | -------------------------------------------------------------------------------- /pytorch_attention_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from torch.optim import lr_scheduler 7 | import torch.utils.data as data 8 | from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack 9 | import torchaudio 10 | import torchaudio.transforms as tat 11 | import numpy as np 12 | import os 13 | import glob 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | import matplotlib.pyplot as plt 17 | 18 | from pytorch_audio_utils import * 19 | 20 | parser = argparse.ArgumentParser(description='PyTorch Language ID Classifier Trainer') 21 | parser.add_argument('--epochs', type=int, default=5, 22 | help='upper epoch limit') 23 | parser.add_argument('--batch-size', type=int, default=6, 24 | help='batch size') 25 | parser.add_argument('--window-size', type=int, default=200, 26 | help='size of fft window') 27 | parser.add_argument('--validate', action='store_true', 28 | help='do out-of-bag validation') 29 | parser.add_argument('--log-interval', type=int, default=5, 30 | help='reports per epoch') 31 | parser.add_argument('--load-model', type=str, default=None, 32 | help='path of model to load') 33 | parser.add_argument('--save-model', action='store_true', 34 | help='path to save the final model') 35 | parser.add_argument('--train-full-model', action='store_true', 36 | help='train full model vs. final layer') 37 | args = parser.parse_args() 38 | 39 | class Preemphasis(object): 40 | """Perform preemphasis on signal 41 | 42 | y = x[n] - α*x[n-1] 43 | 44 | Args: 45 | alpha (float): preemphasis coefficient 46 | 47 | """ 48 | 49 | def __init__(self, alpha=0.97): 50 | self.alpha = alpha 51 | 52 | def __call__(self, sig): 53 | """ 54 | 55 | Args: 56 | sig (Tensor): Tensor of audio of size (Samples x Channels) 57 | 58 | Returns: 59 | sig (Tensor): Preemphasized. See equation above. 60 | 61 | """ 62 | if self.alpha == 0: 63 | return sig 64 | else: 65 | sig[1:, :] -= self.alpha * sig[:-1, :] 66 | return sig 67 | 68 | class RfftPow(object): 69 | """This function emulates power of the discrete fourier transform. 70 | 71 | Note: this implementation may not be numerically stable 72 | 73 | Args: 74 | K (int): number of fft freq bands 75 | 76 | """ 77 | 78 | def __init__(self, K=None): 79 | self.K = K 80 | 81 | def __call__(self, sig): 82 | """ 83 | 84 | Args: 85 | sig (Tensor): Tensor of audio of size (Samples x Channels) 86 | 87 | Returns: 88 | S (Tensor): spectrogram 89 | 90 | """ 91 | N = sig.size(1) 92 | if self.K is None: 93 | K = N 94 | else: 95 | K = self.K 96 | 97 | k_vec = torch.arange(0, K).unsqueeze(0) 98 | n_vec = torch.arange(0, N).unsqueeze(1) 99 | angular_pt = 2 * np.pi * k_vec * n_vec / K 100 | S = torch.sqrt(torch.matmul(sig, angular_pt.cos())**2 + \ 101 | torch.matmul(sig, angular_pt.sin())**2) 102 | S = S.squeeze()[:(K//2+1)] 103 | S = (1 / K) * S**2 104 | return S 105 | 106 | class FilterBanks(object): 107 | """Bins a periodogram from K fft frequency bands into N bins (banks) 108 | 109 | fft bands (K//2+1) -> filterbanks (n_filterbanks) -> bins (bins) 110 | 111 | Args: 112 | n_filterbanks (int): number of filterbanks 113 | bins (list): number of bins 114 | 115 | """ 116 | 117 | def __init__(self, n_filterbanks, bins): 118 | self.n_filterbanks = n_filterbanks 119 | self.bins = bins 120 | 121 | def __call__(self, S): 122 | """ 123 | 124 | Args: 125 | S (Tensor): Tensor of Spectro- / Periodogram 126 | 127 | Returns: 128 | fb (Tensor): binned filterbanked spectrogram 129 | 130 | """ 131 | conversion_factor = np.log(10) # torch.log10 doesn't exist 132 | K = S.size(0) 133 | fb_mat = torch.zeros((self.n_filterbanks, K)) 134 | for m in range(1, self.n_filterbanks+1): 135 | f_m_minus = int(self.bins[m - 1]) 136 | f_m = int(self.bins[m]) 137 | f_m_plus = int(self.bins[m + 1]) 138 | 139 | fb_mat[m - 1, f_m_minus:f_m] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus) 140 | fb_mat[m - 1, f_m:f_m_plus] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m) 141 | fb = torch.matmul(S, fb_mat.t()) 142 | fb = 20 * torch.log(fb) / conversion_factor 143 | return fb 144 | 145 | class MFCC(object): 146 | """Discrete Cosine Transform 147 | 148 | There are three types of the DCT. This is 'Type 2' as described in the scipy docs. 149 | 150 | filterbank bins (bins) -> mfcc (mfcc) 151 | 152 | Args: 153 | n_filterbanks (int): number of filterbanks 154 | n_coeffs (int): number of mfc coefficients to keep 155 | mode (str): orthogonal transformation 156 | 157 | """ 158 | 159 | def __init__(self, n_filterbanks, n_coeffs, mode="ortho"): 160 | self.n_filterbanks = n_filterbanks 161 | self.n_coeffs = n_coeffs 162 | self.mode = "ortho" 163 | 164 | def __call__(self, fb): 165 | """ 166 | 167 | Args: 168 | fb (Tensor): Tensor of binned filterbanked spectrogram 169 | 170 | Returns: 171 | mfcc (Tensor): Tensor of mfcc coefficients 172 | 173 | """ 174 | K = self.n_filterbanks 175 | k_vec = torch.arange(0, K).unsqueeze(0) 176 | n_vec = torch.arange(0, self.n_filterbanks).unsqueeze(1) 177 | angular_pt = np.pi * k_vec * ((2*n_vec+1) / (2*K)) 178 | mfcc = 2 * torch.matmul(fb, angular_pt.cos()) 179 | if self.mode == "ortho": 180 | mfcc[0] *= np.sqrt(1/(4*self.n_filterbanks)) 181 | mfcc[1:] *= np.sqrt(1/(2*self.n_filterbanks)) 182 | return mfcc[1:(self.n_coeffs+1)] 183 | 184 | class Sig2Features(object): 185 | """Get the log power, MFCCs and 1st derivatives of the signal across n hops 186 | and concatenate all that together 187 | 188 | Args: 189 | n_hops (int): number of filterbanks 190 | transformDict (dict): dict of transformations for each hop 191 | 192 | """ 193 | 194 | def __init__(self, ws, hs, transformDict): 195 | self.ws = ws 196 | self.hs = hs 197 | self.td = transformDict 198 | 199 | def __call__(self, sig): 200 | """ 201 | 202 | Args: 203 | sig (Tensor): Tensor of signal 204 | 205 | Returns: 206 | Feats (Tensor): Tensor of log-power, 12 mfcc coefficients and 1st devs 207 | 208 | """ 209 | n_hops = (sig.size(0) - ws) // hs 210 | 211 | P = [] 212 | Mfcc = [] 213 | 214 | for i in range(n_hops): 215 | # create frame 216 | st = int(i * hs) 217 | end = st + ws 218 | sig_n = sig[st:end] 219 | 220 | # get power/energy 221 | P += [self.td["RfftPow"](sig_n.transpose(0, 1))] 222 | 223 | # get mfccs and filter banks 224 | fb = self.td["FilterBanks"](P[-1]) 225 | Mfcc += [self.td["MFCC"](fb)] 226 | 227 | # concat and calculate derivatives 228 | P = torch.stack(P, 1) 229 | P_sum = torch.log(P.sum(0)) 230 | P_dev = torch.zeros(P_sum.size()) 231 | P_dev[1:] = P_sum[1:] - P_sum[:-1] 232 | Mfcc = torch.stack(Mfcc, 1) 233 | Mfcc_dev = torch.cat((torch.zeros(n_coefficients, 1), Mfcc[:,:-1] - Mfcc[:,1:]), 1) 234 | Feats = torch.cat((P_sum.unsqueeze(0), P_dev.unsqueeze(0), Mfcc, Mfcc_dev), 0) 235 | return Feats 236 | 237 | class Labeler(object): 238 | """Labels from text to int + 1 239 | 240 | """ 241 | 242 | def __call__(self, labels): 243 | return torch.LongTensor([int(l)+1 for l in labels]) 244 | 245 | def pad_packed_collate(batch): 246 | """Puts data, and lengths into a packed_padded_sequence then returns 247 | the packed_padded_sequence and the labels. Set use_lengths to True 248 | to use this collate function. 249 | 250 | Args: 251 | batch: (list of tuples) [(audio, target)]. 252 | audio is a FloatTensor 253 | target is a LongTensor with a length of 8 254 | Output: 255 | packed_batch: (PackedSequence), see torch.nn.utils.rnn.pack_padded_sequence 256 | labels: (Tensor), labels from the file names of the wav. 257 | 258 | """ 259 | 260 | if len(batch) == 1: 261 | sigs, labels = batch[0][0], batch[0][1] 262 | sigs = sigs.t() 263 | lengths = [sigs.size(0)] 264 | sigs.unsqueeze_(0) 265 | labels.unsqueeze_(0) 266 | if len(batch) > 1: 267 | sigs, labels, lengths = zip(*[(a.t(), b, a.size(1)) for (a,b) in sorted(batch, key=lambda x: x[0].size(1), reverse=True)]) 268 | max_len, n_feats = sigs[0].size() 269 | sigs = [torch.cat((s, torch.zeros(max_len - s.size(0), n_feats)), 0) if s.size(0) != max_len else s for s in sigs] 270 | sigs = torch.stack(sigs, 0) 271 | labels = torch.stack(labels, 0) 272 | packed_batch = pack(Variable(sigs), lengths, batch_first=True) 273 | return packed_batch, labels 274 | 275 | def unpack_lengths(batch_sizes): 276 | """taken directly from pad_packed_sequence() 277 | """ 278 | lengths = [] 279 | data_offset = 0 280 | prev_batch_size = batch_sizes[0] 281 | for i, batch_size in enumerate(batch_sizes): 282 | dec = prev_batch_size - batch_size 283 | if dec > 0: 284 | lengths.extend((i,) * dec) 285 | prev_batch_size = batch_size 286 | lengths.extend((i + 1,) * batch_size) 287 | lengths.reverse() 288 | return lengths 289 | 290 | class EncoderRNN2(nn.Module): 291 | def __init__(self, input_size, hidden_size, n_layers=1, batch_size=1): 292 | super(EncoderRNN2, self).__init__() 293 | self.n_layers = n_layers 294 | self.hidden_size = hidden_size 295 | self.batch_size = batch_size 296 | 297 | self.gru = nn.GRU(input_size, hidden_size, n_layers, batch_first=True) 298 | 299 | def forward(self, input, hidden): 300 | output = input 301 | output, hidden = self.gru(output, hidden) 302 | #print("encoder:", output.size(), hidden.size()) 303 | return output, hidden 304 | 305 | def initHidden(self, ttype=None): 306 | if ttype == None: 307 | ttype = torch.FloatTensor 308 | result = Variable(ttype(self.n_layers * 1, self.batch_size, self.hidden_size).fill_(0)) 309 | if use_cuda: 310 | return result.cuda() 311 | else: 312 | return result 313 | 314 | class Attn(nn.Module): 315 | def __init__(self, hidden_size, batch_size=1, method="dot"): 316 | super(Attn, self).__init__() 317 | 318 | self.method = method 319 | self.hidden_size = hidden_size 320 | self.batch_size = batch_size 321 | 322 | if self.method == 'general': 323 | self.attn = nn.Linear(self.hidden_size, hidden_size, bias=False) 324 | 325 | elif self.method == 'concat': 326 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size, bias=False) 327 | self.v = nn.Parameter(torch.FloatTensor(batch_size, 1, hidden_size)) 328 | 329 | def forward(self, hidden, encoder_outputs): 330 | max_len = encoder_outputs.size(1) 331 | 332 | # get attn energies in one batch 333 | attn_energies = self.score(hidden, encoder_outputs) 334 | 335 | # Normalize energies to weights in range 0 to 1 336 | return F.softmax(attn_energies) 337 | 338 | def score(self, hidden, encoder_output): 339 | #print("attn.score:", hidden.size(), encoder_output.size()) 340 | if self.method == 'general': 341 | energy = self.attn(encoder_output) 342 | energy = energy.transpose(2, 1) 343 | energy = hidden.bmm(energy) 344 | return energy 345 | 346 | elif self.method == 'concat': 347 | hidden = hidden * Variable(encoder_output.data.new(encoder_output.size()).fill_(1)) # broadcast hidden to encoder_outputs size 348 | energy = self.attn(torch.cat((hidden, encoder_output), -1)) 349 | energy = energy.transpose(2, 1) 350 | energy = self.v.bmm(energy) 351 | return energy 352 | else: 353 | #self.method == 'dot': 354 | encoder_output = encoder_output.transpose(2, 1) 355 | energy = hidden.bmm(encoder_output) 356 | return energy 357 | 358 | class LuongAttnDecoderRNN(nn.Module): 359 | def __init__(self, hidden_size, output_size, attn_model="dot", n_layers=1, dropout=0.1, batch_size=1): 360 | super(LuongAttnDecoderRNN, self).__init__() 361 | 362 | # Keep for reference 363 | self.attn_model = attn_model 364 | self.hidden_size = hidden_size 365 | self.output_size = output_size 366 | self.n_layers = n_layers 367 | self.dropout = dropout 368 | self.batch_size = batch_size 369 | 370 | # Define layers 371 | self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout, batch_first=True) 372 | self.concat = nn.Linear(hidden_size * 2, hidden_size) 373 | self.out = nn.Linear(hidden_size, output_size) 374 | 375 | # Choose attention model 376 | if attn_model != 'none': 377 | self.attn = Attn(hidden_size, method=attn_model, batch_size=batch_size) 378 | 379 | def forward(self, input_seq, last_hidden, encoder_outputs): 380 | # Note: This now runs in batch but was originally run one 381 | # step at a time 382 | # B = batch size 383 | # S = output length 384 | # N = # of hidden features 385 | 386 | # Get the embedding of the current input word (last output word) 387 | batch_size = input_seq.size(0) 388 | 389 | # Get current hidden state from input word and last hidden state 390 | rnn_output, hidden = self.gru(input_seq, last_hidden) 391 | 392 | # Calculate attention from current RNN state and all encoder outputs; 393 | # apply to encoder outputs to get weighted average 394 | #print("decoder:", rnn_output.size(), encoder_outputs.size()) 395 | attn_weights = self.attn(rnn_output, encoder_outputs) 396 | context = attn_weights.bmm(encoder_outputs) # [B, S, L] dot [B, L, N] -> [B, S, N] 397 | print(attn_weights.size(), encoder_outputs.size(), context.size()) 398 | #print("decoder context:", context.size()) 399 | 400 | # Attentional vector using the RNN hidden state and context vector 401 | # concatenated together (Luong eq. 5) 402 | concat_input = torch.cat((rnn_output, context), -1) # B x S x 2*N 403 | concat_output = F.tanh(self.concat(concat_input)) 404 | 405 | # Finally predict next token (Luong eq. 6, without softmax) 406 | output = self.out(concat_output) 407 | 408 | # Return final output, hidden state, and attention weights (for visualization) 409 | return output, hidden, attn_weights 410 | 411 | # train parameters 412 | epochs = args.epochs 413 | 414 | # set dataset parameters 415 | DATADIR = "/home/david/Programming/data" 416 | sr = 8000 417 | ws = args.window_size 418 | hs = ws // 2 419 | n_fft = 512 # 256 420 | n_filterbanks = 26 421 | n_coefficients = 12 422 | low_mel_freq = 0 423 | high_freq_mel = (2595 * np.log10(1 + (sr/2) / 700)) 424 | mel_pts = np.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2) 425 | hz_pts = np.floor(700 * (10**(mel_pts / 2595) - 1)) 426 | bins = np.floor((n_fft + 1) * hz_pts / sr) 427 | 428 | # data transformations 429 | td = { 430 | "RfftPow": RfftPow(n_fft), 431 | "FilterBanks": FilterBanks(n_filterbanks, bins), 432 | "MFCC": MFCC(n_filterbanks, n_coefficients), 433 | } 434 | 435 | transforms = tat.Compose([ 436 | tat.Scale(), 437 | tat.PadTrim(58000, fill_value=1e-8), 438 | Preemphasis(), 439 | Sig2Features(ws, hs, td), 440 | ]) 441 | 442 | # set network parameters 443 | use_cuda = torch.cuda.is_available() 444 | batch_size = args.batch_size 445 | input_features = 26 446 | hidden_size = 100 447 | output_size = 3 448 | #output_length = (8 + 7 + 2) # with "blanks" 449 | output_length = 8 # without blanks 450 | n_layers = 1 451 | attn_modus = "dot" 452 | 453 | # build networks, criterion, optimizers, dataset and dataloader 454 | encoder2 = EncoderRNN2(input_features, hidden_size, n_layers=n_layers, batch_size=batch_size) 455 | decoder2 = LuongAttnDecoderRNN(hidden_size, output_size, n_layers=n_layers, attn_model=attn_modus, batch_size=batch_size) 456 | print(encoder2) 457 | print(decoder2) 458 | criterion = nn.CrossEntropyLoss() 459 | optimizer = torch.optim.RMSprop([ 460 | {"params": encoder2.parameters()}, 461 | {"params": decoder2.parameters(), "lr": 0.0001} 462 | ], lr=0.001, momentum=0.9) 463 | scheduler = lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.6) 464 | ds = torchaudio.datasets.YESNO(DATADIR, transform=transforms, target_transform=Labeler()) 465 | dl = data.DataLoader(ds, batch_size=batch_size) 466 | 467 | if use_cuda: 468 | print("using CUDA: {}".format(use_cuda)) 469 | encoder2 = encoder2.cuda() 470 | decoder2 = decoder2.cuda() 471 | 472 | loss_total = [] 473 | # begin training 474 | for epoch in range(epochs): 475 | scheduler.step() 476 | print("epoch {}".format(epoch+1)) 477 | running_loss = 0 478 | loss_epoch = [] 479 | for i, (mb, tgts) in enumerate(dl): 480 | # set model into train mode and clear gradients 481 | encoder2.train() 482 | decoder2.train() 483 | encoder2.zero_grad() 484 | decoder2.zero_grad() 485 | 486 | # set inputs and targets 487 | mb = mb.transpose(2, 1) # [B x N x L] -> [B, L, N] 488 | if use_cuda: 489 | mb, tgts = mb.cuda(), tgts.cuda() 490 | mb, tgts = Variable(mb), Variable(tgts) 491 | 492 | encoder2_hidden = encoder2.initHidden(type(mb.data)) 493 | encoder2_output, encoder2_hidden = encoder2(mb, encoder2_hidden) 494 | #print(encoder2_output) 495 | 496 | # Prepare input and output variables for decoder 497 | dec_i = Variable(encoder2_output.data.new([[[0] * hidden_size] * output_length] * batch_size)) 498 | dec_h = encoder2_hidden # Use last (forward) hidden state from encoder 499 | #print(dec_h.size()) 500 | 501 | """ 502 | # Run through decoder one time step at a time 503 | # collect attentions 504 | attentions = [] 505 | outputs = [] 506 | dec_i = Variable(torch.FloatTensor([[[0] * hidden_size] * 1])) 507 | target_seq = Variable(torch.FloatTensor([[[-1] * hidden_size]*8])) 508 | for t in range(output_length): 509 | #print("t:", t, dec_i.size()) 510 | dec_o, dec_h, dec_attn = decoder2( 511 | dec_i, dec_h, encoder2_output 512 | ) 513 | #print("decoder output", dec_o.size()) 514 | dec_i = target_seq[:,t].unsqueeze(1) # Next input is current target 515 | outputs += [dec_o] 516 | attentions += [dec_attn] 517 | dec_o = torch.cat(outputs, 1) 518 | dec_attn = torch.cat(attentions, 1) 519 | """ 520 | # run through decoder in one shot 521 | dec_o, dec_h, dec_attn = decoder2(dec_i, dec_h, encoder2_output) 522 | 523 | # calculate loss and backprop 524 | loss = criterion(dec_o.view(-1, output_size), tgts.view(-1)) 525 | running_loss += loss.data[0] 526 | loss_epoch += [loss.data[0]] 527 | loss.backward() 528 | #nn.utils.clip_grad_norm(encoder2.parameters(), 0.05) 529 | #nn.utils.clip_grad_norm(decoder2.parameters(), 0.05) 530 | optimizer.step() 531 | 532 | # logging stuff 533 | if (i % args.log_interval == 0 and i != 0) or epoch == 0: 534 | print(loss.data[0]) 535 | loss_total += [loss_epoch] 536 | print((dec_o.max(2)[1].data == tgts.data).float().sum(1) / tgts.size(1)) 537 | print("ave loss of {} at epoch {}".format(running_loss / (i+1), epoch+1)) 538 | 539 | loss_total = np.array(loss_total) 540 | plt.figure() 541 | plt.plot(loss_total.mean(1)) 542 | plt.savefig("pytorch_attention_audio-loss.png") 543 | 544 | # Set up figure with colorbar 545 | attn_plot = dec_attn[0, :, :].data 546 | attn_plot = attn_plot.numpy() if not use_cuda else attn_plot.cpu().numpy() 547 | fig = plt.figure(figsize=(20, 6)) 548 | ax = fig.add_subplot(111) 549 | cax = ax.matshow(attn_plot, cmap='bone', aspect="auto") 550 | fig.colorbar(cax) 551 | fig.savefig("pytorch_attention_audio-attention.png") 552 | -------------------------------------------------------------------------------- /pytorch_basics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "import torchaudio\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "from torch.autograd import Variable\n", 17 | "import os\n", 18 | "\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from IPython.display import display, Audio\n", 21 | "%matplotlib notebook" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 4, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "def find_files(dir):\n", 31 | " return [os.path.join(dir,x) for x in os.listdir(dir) if os.path.isfile(os.path.join(dir, x))]\n", 32 | "\n", 33 | "files = find_files(\"data/waves_yesno\")" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 9, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "torch.Size([1, 1, 1000])" 45 | ] 46 | }, 47 | "metadata": {}, 48 | "output_type": "display_data" 49 | }, 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "[(39, -0.648193359375),\n", 54 | " (53, 0.966796875),\n", 55 | " (53, 0.966796875),\n", 56 | " (104, -0.9051513671875),\n", 57 | " (104, -0.9051513671875),\n", 58 | " (104, -0.9051513671875),\n", 59 | " (118, 1.31011962890625),\n", 60 | " (118, 1.31011962890625),\n", 61 | " (168, -1.84478759765625),\n", 62 | " (168, -1.84478759765625),\n", 63 | " (168, -1.84478759765625),\n", 64 | " (168, -1.84478759765625),\n", 65 | " (178, 1.91741943359375),\n", 66 | " (178, 1.91741943359375),\n", 67 | " (178, 1.91741943359375),\n", 68 | " (178, 1.91741943359375),\n", 69 | " (226, -1.46392822265625),\n", 70 | " (226, -1.46392822265625),\n", 71 | " (239, 1.43524169921875),\n", 72 | " (239, 1.43524169921875),\n", 73 | " (239, 1.43524169921875),\n", 74 | " (282, -0.94512939453125),\n", 75 | " (343, -1.25518798828125),\n", 76 | " (343, -1.25518798828125),\n", 77 | " (343, -1.25518798828125),\n", 78 | " (361, 1.47003173828125),\n", 79 | " (378, 1.640625),\n", 80 | " (378, 1.640625),\n", 81 | " (414, -1.5948486328125),\n", 82 | " (414, -1.5948486328125),\n", 83 | " (423, 2.0025634765625),\n", 84 | " (423, 2.0025634765625),\n", 85 | " (423, 2.0025634765625),\n", 86 | " (423, 2.0025634765625),\n", 87 | " (476, -1.59820556640625),\n", 88 | " (485, 1.990966796875),\n", 89 | " (508, -1.619873046875),\n", 90 | " (508, -1.619873046875),\n", 91 | " (538, -2.08831787109375),\n", 92 | " (538, -2.08831787109375),\n", 93 | " (546, 2.00653076171875),\n", 94 | " (546, 2.00653076171875),\n", 95 | " (599, -2.55279541015625),\n", 96 | " (599, -2.55279541015625),\n", 97 | " (599, -2.55279541015625),\n", 98 | " (599, -2.55279541015625),\n", 99 | " (608, 2.11395263671875),\n", 100 | " (622, 2.373046875),\n", 101 | " (622, 2.373046875),\n", 102 | " (661, -2.41485595703125),\n", 103 | " (684, 2.3974609375),\n", 104 | " (684, 2.3974609375),\n", 105 | " (722, -2.6776123046875),\n", 106 | " (722, -2.6776123046875),\n", 107 | " (730, 2.626953125),\n", 108 | " (745, 2.7105712890625),\n", 109 | " (745, 2.7105712890625),\n", 110 | " (784, -2.9150390625),\n", 111 | " (784, -2.9150390625),\n", 112 | " (784, -2.9150390625),\n", 113 | " (784, -2.9150390625),\n", 114 | " (807, 2.7520751953125),\n", 115 | " (807, 2.7520751953125),\n", 116 | " (807, 2.7520751953125),\n", 117 | " (807, 2.7520751953125),\n", 118 | " (845, -2.86590576171875),\n", 119 | " (845, -2.86590576171875),\n", 120 | " (845, -2.86590576171875),\n", 121 | " (853, 2.7276611328125),\n", 122 | " (906, -2.56103515625),\n", 123 | " (914, 2.86346435546875),\n", 124 | " (914, 2.86346435546875),\n", 125 | " (914, 2.86346435546875),\n", 126 | " (914, 2.86346435546875),\n", 127 | " (967, -2.66265869140625),\n", 128 | " (967, -2.66265869140625)]" 129 | ] 130 | }, 131 | "metadata": {}, 132 | "output_type": "display_data" 133 | } 134 | ], 135 | "source": [ 136 | "sig, sr = torchaudio.load(files[2])\n", 137 | "sig = torchaudio.transforms.Scale()(sig)\n", 138 | "\n", 139 | "X = sig[7900:8900].transpose(0, 1) * 10\n", 140 | "X.unsqueeze_(0)\n", 141 | "#X.unsqueeze_(0)\n", 142 | "display(X.size())\n", 143 | "\n", 144 | "x_max, idx_max = torch.nn.functional.max_pool1d(X, 100, 25, 13, return_indices=True)\n", 145 | "x_min, idx_min = torch.nn.functional.max_pool1d(-X, 100, 25, 13, return_indices=True)\n", 146 | "x_max.squeeze(), idx_max.data.numpy(), -x_min.squeeze(), idx_min.data.numpy()\n", 147 | "max_list = list(zip(idx_max.data.squeeze().numpy(), x_max.data.squeeze()))\n", 148 | "min_list = list(zip(idx_min.data.squeeze().numpy(), -x_min.data.squeeze()))\n", 149 | "combined_list = sorted(max_list + min_list)\n", 150 | "display(combined_list)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 83, 156 | "metadata": { 157 | "scrolled": false 158 | }, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "\n", 164 | "( 0 ,.,.) = \n", 165 | "\n", 166 | "Columns 0 to 8 \n", 167 | " -0.0624 -0.0716 -0.0797 -0.0856 -0.0905 -0.0900 -0.0818 -0.0692 -0.0501\n", 168 | "\n", 169 | "Columns 9 to 17 \n", 170 | " -0.0287 -0.0072 0.0182 0.0451 0.0731 0.0965 0.1147 0.1244 0.1286\n", 171 | "\n", 172 | "Columns 18 to 26 \n", 173 | " 0.1310 0.1280 0.1182 0.1014 0.0768 0.0516 0.0299 0.0112 -0.0106\n", 174 | "\n", 175 | "Columns 27 to 35 \n", 176 | " -0.0326 -0.0532 -0.0644 -0.0667 -0.0625 -0.0581 -0.0517 -0.0411 -0.0269\n", 177 | "\n", 178 | "Columns 36 to 44 \n", 179 | " -0.0071 0.0119 0.0245 0.0345 0.0388 0.0443 0.0449 0.0386 0.0284\n", 180 | "\n", 181 | "Columns 45 to 53 \n", 182 | " 0.0139 0.0010 -0.0069 -0.0127 -0.0211 -0.0263 -0.0274 -0.0225 -0.0141\n", 183 | "\n", 184 | "Columns 54 to 62 \n", 185 | " -0.0041 0.0028 0.0041 0.0074 0.0062 -0.0039 -0.0194 -0.0416 -0.0692\n", 186 | "\n", 187 | "Columns 63 to 71 \n", 188 | " -0.0992 -0.1273 -0.1520 -0.1707 -0.1809 -0.1845 -0.1731 -0.1436 -0.0975\n", 189 | "\n", 190 | "Columns 72 to 80 \n", 191 | " -0.0462 0.0043 0.0543 0.1037 0.1487 0.1815 0.1917 0.1761 0.1512\n", 192 | "\n", 193 | "Columns 81 to 89 \n", 194 | " 0.1140 0.0734 0.0276 -0.0191 -0.0643 -0.0964 -0.1084 -0.1012 -0.0851\n", 195 | "\n", 196 | "Columns 90 to 98 \n", 197 | " -0.0621 -0.0371 -0.0034 0.0344 0.0718 0.0935 0.0977 0.0902 0.0768\n", 198 | "\n", 199 | "Columns 99 to 99 \n", 200 | " 0.0605\n", 201 | "[torch.FloatTensor of size 1x1x100]" 202 | ] 203 | }, 204 | "metadata": {}, 205 | "output_type": "display_data" 206 | }, 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "Conv1d(1, 3, kernel_size=(10,), stride=(9,), dilation=(2,), bias=False)\n" 212 | ] 213 | }, 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "Variable containing:\n", 218 | "(0 ,.,.) = \n", 219 | "\n", 220 | "Columns 0 to 8 \n", 221 | " -0.0624 -0.0287 0.1310 -0.0326 -0.0071 0.0139 -0.0041 -0.0992 -0.0462\n", 222 | " -0.0210 -0.0097 0.0441 -0.0110 -0.0024 0.0047 -0.0014 -0.0334 -0.0156\n", 223 | " -0.0180 -0.0083 0.0378 -0.0094 -0.0020 0.0040 -0.0012 -0.0286 -0.0133\n", 224 | "\n", 225 | "Columns 9 to 11 \n", 226 | " 0.1140 -0.0621 0.0605\n", 227 | " 0.0384 -0.0209 0.0204\n", 228 | " 0.0329 -0.0179 0.0175\n", 229 | "[torch.FloatTensor of size 1x3x12]" 230 | ] 231 | }, 232 | "execution_count": 83, 233 | "metadata": {}, 234 | "output_type": "execute_result" 235 | } 236 | ], 237 | "source": [ 238 | "X2 = sig[8000:8100].transpose(0,1).unsqueeze(0)\n", 239 | "display(X2)\n", 240 | "conv_op = nn.Conv1d(1, 3, kernel_size=10, dilation=2, stride=9, bias=False)\n", 241 | "conv_op.weight = nn.Parameter(torch.cat((torch.ones(1, 1, 1), torch.rand(2, 1, 1))))\n", 242 | "print(conv_op)\n", 243 | "X2_conv = conv_op(Variable(X2))\n", 244 | "X2_conv" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 144, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "text/plain": [ 255 | "\n", 256 | "(0 ,.,.) = \n", 257 | " 0 1 2 3 4 5 6 7 8 9\n", 258 | "[torch.FloatTensor of size 1x1x10]" 259 | ] 260 | }, 261 | "metadata": {}, 262 | "output_type": "display_data" 263 | }, 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "Variable containing:\n", 268 | "(0 ,.,.) = \n", 269 | " 2 3 5 7 9 11 13 15 7\n", 270 | "[torch.FloatTensor of size 1x1x9]" 271 | ] 272 | }, 273 | "metadata": {}, 274 | "output_type": "display_data" 275 | }, 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "Variable containing:\n", 280 | "(0 ,.,.) = \n", 281 | " 0 1 2 4 6 8 10 12 14 16 8 9\n", 282 | "[torch.FloatTensor of size 1x1x12]" 283 | ] 284 | }, 285 | "metadata": {}, 286 | "output_type": "display_data" 287 | } 288 | ], 289 | "source": [ 290 | "# Convolutions\n", 291 | "X = torch.arange(0, 10).view(1, 1, -1)\n", 292 | "display(X)\n", 293 | "conv_op = nn.Conv1d(1, 1, 2, dilation=3, padding=1, bias=False)\n", 294 | "conv_op.weight = nn.Parameter(torch.ones(conv_op.weight.size()))\n", 295 | "display(conv_op(Variable(X)))\n", 296 | "tconv_op = nn.ConvTranspose1d(1, 1, 2, dilation=2, bias=False)\n", 297 | "tconv_op.weight = nn.Parameter(torch.ones(tconv_op.weight.size()))\n", 298 | "display(tconv_op(Variable(X)))" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 155, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "data": { 308 | "text/plain": [ 309 | "\n", 310 | "(0 ,.,.) = \n", 311 | "\n", 312 | "Columns 0 to 8 \n", 313 | " 0.0000 1.0000 2.0000 3.0000 4.0000 5.0000 6.0000 7.0000 8.0000\n", 314 | "\n", 315 | "Columns 9 to 9 \n", 316 | " 9.0000\n", 317 | "\n", 318 | "(1 ,.,.) = \n", 319 | "\n", 320 | "Columns 0 to 8 \n", 321 | " 0.1731 0.8068 0.1422 0.7546 0.9261 0.5195 0.4176 0.6777 0.1002\n", 322 | "\n", 323 | "Columns 9 to 9 \n", 324 | " 0.1254\n", 325 | "[torch.FloatTensor of size 2x1x10]" 326 | ] 327 | }, 328 | "metadata": {}, 329 | "output_type": "display_data" 330 | }, 331 | { 332 | "data": { 333 | "text/plain": [ 334 | "Variable containing:\n", 335 | "(0 ,.,.) = \n", 336 | " 1.0000 3.0000 5.0000 7.0000 9.0000\n", 337 | "\n", 338 | "(1 ,.,.) = \n", 339 | " 0.8068 0.7546 0.9261 0.6777 0.1254\n", 340 | "[torch.FloatTensor of size 2x1x5]" 341 | ] 342 | }, 343 | "execution_count": 155, 344 | "metadata": {}, 345 | "output_type": "execute_result" 346 | } 347 | ], 348 | "source": [ 349 | "# Pooling\n", 350 | "X = torch.cat((torch.arange(0, 10).view(1, 1, -1), torch.rand(1, 1, 10)))\n", 351 | "display(X)\n", 352 | "pool_op = nn.AdaptiveMaxPool1d(5)\n", 353 | "pool_op(X)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 163, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "data": { 363 | "text/plain": [ 364 | "Variable containing:\n", 365 | "(0 ,0 ,.,.) = \n", 366 | " 1 0 1 2 3 4 5 6 7 8 9 8\n", 367 | "[torch.FloatTensor of size 1x1x1x12]" 368 | ] 369 | }, 370 | "execution_count": 163, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "# Padding\n", 377 | "X = torch.arange(0, 10).view(1, 1, 1, -1)\n", 378 | "nn.ReflectionPad2d((1, 1, 0, 0))(X)\n" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "metadata": { 392 | "collapsed": true 393 | }, 394 | "outputs": [], 395 | "source": [] 396 | } 397 | ], 398 | "metadata": { 399 | "kernelspec": { 400 | "display_name": "Python 3", 401 | "language": "python", 402 | "name": "python3" 403 | }, 404 | "language_info": { 405 | "codemirror_mode": { 406 | "name": "ipython", 407 | "version": 3 408 | }, 409 | "file_extension": ".py", 410 | "mimetype": "text/x-python", 411 | "name": "python", 412 | "nbconvert_exporter": "python", 413 | "pygments_lexer": "ipython3", 414 | "version": "3.6.1" 415 | } 416 | }, 417 | "nbformat": 4, 418 | "nbformat_minor": 2 419 | } 420 | -------------------------------------------------------------------------------- /pytorch_embedding_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "from torch.autograd import Variable" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 10, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "(torch.Size([2, 5, 3]), \n", 25 | " (0 ,.,.) = \n", 26 | " 0.3966 1.1329 0.9703\n", 27 | " 1.2760 -0.3156 -0.3003\n", 28 | " 0.4951 -2.4289 -1.8240\n", 29 | " -1.2025 -0.5043 -0.1327\n", 30 | " 0.3966 1.1329 0.9703\n", 31 | " \n", 32 | " (1 ,.,.) = \n", 33 | " 0.4951 -2.4289 -1.8240\n", 34 | " -0.3016 -1.0829 1.2743\n", 35 | " -0.0533 -0.9346 -0.3806\n", 36 | " -1.8540 -1.1871 -0.9740\n", 37 | " 0.5305 -0.0377 -0.4296\n", 38 | " [torch.FloatTensor of size 2x5x3], 1.275978684425354, -0.304186017687122, -2.4288904666900635)" 39 | ] 40 | }, 41 | "execution_count": 10, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "# One Hot Vector (max val = N) to M-dim vector\n", 48 | "N = 1000\n", 49 | "M = 3\n", 50 | "embedding = nn.Embedding(N, M)\n", 51 | "input = Variable(torch.LongTensor([[1, 2, 4, 5, 1],[4, 3, 0, 9, 999]]))\n", 52 | "output = embedding(input)\n", 53 | "output.size(), output.data, output.data.max(), output.data.mean(), output.data.min()" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 13, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "(torch.Size([2, 5, 9]), \n", 65 | " (0 ,.,.) = \n", 66 | " 1.2783 -2.0387 0.3489 2.5870 0.4003 -0.8323 0.4795 -0.1188 1.5058\n", 67 | " 1.9345 0.6939 -0.6005 0.0113 -1.3579 -1.5732 0.6476 0.3265 2.0663\n", 68 | " -1.7766 1.7161 1.5391 -0.2763 -1.0611 -1.7473 1.4869 0.1519 0.5832\n", 69 | " 0.2865 1.8251 -0.3772 0.1366 -0.1392 -1.4507 -0.0412 0.3509 -1.0199\n", 70 | " 1.2783 -2.0387 0.3489 2.5870 0.4003 -0.8323 0.4795 -0.1188 1.5058\n", 71 | " \n", 72 | " (1 ,.,.) = \n", 73 | " -1.7766 1.7161 1.5391 -0.2763 -1.0611 -1.7473 1.4869 0.1519 0.5832\n", 74 | " -0.1941 -0.7290 -1.2425 2.0444 -0.2431 -0.7992 -1.1772 0.0754 0.0805\n", 75 | " -1.7917 0.3172 -0.4379 -0.0060 0.6569 2.5088 -0.9982 -1.5249 0.3790\n", 76 | " -0.9209 -0.2454 1.3805 -2.9321 -0.3620 1.3712 1.5372 -0.0832 -0.0649\n", 77 | " -0.7073 -0.0540 0.0487 1.0955 -0.9316 0.5288 0.2074 0.9193 0.6603\n", 78 | " [torch.FloatTensor of size 2x5x9], 2.587045907974243, 0.09486205115293463, -2.9320576190948486)" 79 | ] 80 | }, 81 | "execution_count": 13, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "M_new = 9\n", 88 | "embedding_new = nn.Embedding(N, M_new)\n", 89 | "output = embedding_new(input)\n", 90 | "output.size(), output.data, output.data.max(), output.data.mean(), output.data.min()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 11, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "ename": "RuntimeError", 100 | "evalue": "index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/TH/generic/THTensorMath.c:277", 101 | "output_type": "error", 102 | "traceback": [ 103 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 104 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 105 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mN_invalid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m999\u001b[0m \u001b[0;31m# the last val in input is illegal\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0membedding_invalid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEmbedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mN_invalid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0membedding_invalid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 106 | "\u001b[0;32m~/miniconda3/envs/ml/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_pre_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 325\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 326\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 327\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 107 | "\u001b[0;32m~/miniconda3/envs/ml/lib/python3.6/site-packages/torch/nn/modules/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscale_grad_by_freq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msparse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 108 | "\u001b[0;32m~/miniconda3/envs/ml/lib/python3.6/site-packages/torch/nn/_functions/thnn/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_select\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_select\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 109 | "\u001b[0;31mRuntimeError\u001b[0m: index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/TH/generic/THTensorMath.c:277" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "N_invalid = 999 # the last val in input is illegal\n", 115 | "embedding_invalid = nn.Embedding(N_invalid, M)\n", 116 | "embedding_invalid(input)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "collapsed": true 124 | }, 125 | "outputs": [], 126 | "source": [] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "Python 3", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.6.3" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 2 150 | } 151 | -------------------------------------------------------------------------------- /test_VarLenDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.utils.data as data\n", 13 | "import torchaudio\n", 14 | "import librosa\n", 15 | "import numpy as np\n", 16 | "import random\n", 17 | "import os\n", 18 | "import glob" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 11, 24 | "metadata": { 25 | "collapsed": true, 26 | "scrolled": false 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "class VariableLengthDataset(data.Dataset):\n", 31 | " def __init__(self, manifest, snippet_length=24000, get_sequentially=False, ret_np=False, use_librosa=False):\n", 32 | " self.manifest = manifest\n", 33 | " self.snippet_length = snippet_length\n", 34 | " self.get_sequentially = get_sequentially\n", 35 | " self.use_librosa = use_librosa\n", 36 | " self.ret_np = ret_np\n", 37 | " self.acc = 0\n", 38 | " self.snippet_counter = 0\n", 39 | " self.audio_idx = 0\n", 40 | " self.st = 0\n", 41 | " self.data = {}\n", 42 | " def __getitem__(self, index):\n", 43 | " # load audio data from file or cache\n", 44 | " if self.snippet_counter == 0:\n", 45 | " self.audio_idx = index - self.acc\n", 46 | " apath = self.manifest[self.audio_idx]\n", 47 | " if apath not in self.data:\n", 48 | " if self.use_librosa:\n", 49 | " sig, sr = librosa.core.load(apath, sr=None)\n", 50 | " sig = torch.from_numpy(sig).unsqueeze(1).float()\n", 51 | " else:\n", 52 | " sig, sr = torchaudio.load(apath, normalization=True)\n", 53 | " self.data[apath] = (sig, sr)\n", 54 | " else:\n", 55 | " sig, sr = self.data[apath]\n", 56 | "\n", 57 | " # increase iterations based on length of audio\n", 58 | " num_snippets = int(sig.size(0) // self.snippet_length)\n", 59 | " self.acc += max(num_snippets-1,0)\n", 60 | " else:\n", 61 | " apath = self.manifest[self.audio_idx]\n", 62 | " sig, sr = self.data[apath]\n", 63 | " num_snippets = int(sig.size(0) // self.snippet_length)\n", 64 | "\n", 65 | " # create snippet\n", 66 | " if self.get_sequentially:\n", 67 | " self.st += self.snippet_length\n", 68 | " else:\n", 69 | " self.st = random.randrange(int(sig.size(0)-self.snippet_length))\n", 70 | " ret_sig = sig[self.st:(self.st+self.snippet_length)]\n", 71 | " if self.ret_np:\n", 72 | " ret_sig = ret_sig.numpy()\n", 73 | "\n", 74 | " # update counter for current audio file\n", 75 | " self.snippet_counter += 1\n", 76 | "\n", 77 | " # label creation\n", 78 | " spkr = os.path.dirname(apath).rsplit(\"/\", 1)[-1]\n", 79 | " spkr = 0\n", 80 | "\n", 81 | " # check for reset\n", 82 | " if self.snippet_counter >= num_snippets:\n", 83 | " self.snippet_counter = 0\n", 84 | " self.st = 0\n", 85 | "\n", 86 | " return ret_sig, spkr\n", 87 | "\n", 88 | " def __len__(self):\n", 89 | " return len(self.manifest) + self.acc\n", 90 | "\n", 91 | " def reset_acc(self):\n", 92 | " self.acc = 0\n", 93 | "\n", 94 | "class FixedLengthDataset(data.Dataset):\n", 95 | " def __init__(self, manifest, transforms = snippet_length=24000, ret_np=False, use_librosa=False):\n", 96 | " self.manifest = manifest\n", 97 | " self.snippet_length = snippet_length\n", 98 | " self.use_librosa = use_librosa\n", 99 | " self.ret_np = ret_np\n", 100 | " self.num_snippets = 1\n", 101 | " self.acc = 0\n", 102 | " self.audio_idx = 0\n", 103 | " self.st = 0\n", 104 | " self.data = {}\n", 105 | " def __getitem__(self, index):\n", 106 | " # load audio data from file or cache\n", 107 | " self.audio_idx = index if self.num_snippets == 1 else index // self.num_snippets\n", 108 | " apath = self.manifest[self.audio_idx]\n", 109 | " if self.use_librosa:\n", 110 | " sig, sr = librosa.core.load(apath, sr=None)\n", 111 | " sig = torch.from_numpy(sig).unsqueeze(1).float()\n", 112 | " else:\n", 113 | " sig, sr = torchaudio.load(apath, normalization=True)\n", 114 | "\n", 115 | " # create snippet\n", 116 | " if sig.size(0) < self.snippet_length:\n", 117 | " ret_sig = sig\n", 118 | " else:\n", 119 | " self.st = random.randrange(int(sig.size(0)-self.snippet_length))\n", 120 | " ret_sig = sig[self.st:(self.st+self.snippet_length)]\n", 121 | " if self.ret_np:\n", 122 | " ret_sig = ret_sig.numpy()\n", 123 | "\n", 124 | " # label creation\n", 125 | " #spkr = os.path.dirname(apath).rsplit(\"/\", 1)[-1]\n", 126 | " spkr = 0 # just using a dummy label now.\n", 127 | "\n", 128 | " return ret_sig, spkr\n", 129 | "\n", 130 | " def __len__(self):\n", 131 | " return len(self.manifest) * self.num_snippets\n", 132 | "\n", 133 | "def run_dataset():\n", 134 | " for epoch in range(1):\n", 135 | " all_data = [(x, label) for x, label in ds]\n", 136 | " print(epoch, len(all_data))\n", 137 | " try:\n", 138 | " ds.reset_acc()\n", 139 | " except:\n", 140 | " pass\n" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 3, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "0\n", 153 | "CPU times: user 91.1 ms, sys: 8.62 ms, total: 99.7 ms\n", 154 | "Wall time: 291 ms\n", 155 | "0\n", 156 | "CPU times: user 79 ms, sys: 6.72 ms, total: 85.7 ms\n", 157 | "Wall time: 69.8 ms\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "datadir = \"/home/david/Programming/tests/pcsnpny-20150204-mkj\"\n", 163 | "audio_manifest = [a for a in glob.glob(datadir+\"/**/*.wav\", recursive=True)]\n", 164 | "\n", 165 | "ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=False)\n", 166 | "%time run_dataset()\n", 167 | "ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=True)\n", 168 | "%time run_dataset()\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 4, 174 | "metadata": { 175 | "scrolled": false 176 | }, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "10 1\n", 183 | "torch.Size([10, 12000, 1]) torch.Size([10])\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=False, ret_np=False, use_librosa=False)\n", 189 | "dl = data.DataLoader(ds, batch_size=10)\n", 190 | "print(len(ds), len(dl))\n", 191 | "for mb, tgts in dl:\n", 192 | " print(mb.size(), tgts.size())\n", 193 | " break" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 13, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "0 10\n", 206 | "CPU times: user 18.1 ms, sys: 0 ns, total: 18.1 ms\n", 207 | "Wall time: 18.4 ms\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "datadir = \"/home/david/Programming/tests/pcsnpny-20150204-mkj\"\n", 213 | "audio_manifest = [a for a in glob.glob(datadir+\"/**/*.wav\", recursive=True)]\n", 214 | "ds = FixedLengthDataset(audio_manifest, 12000, ret_np=True, use_librosa=True)\n", 215 | "%time run_dataset()\n" 216 | ] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "Python [conda root]", 222 | "language": "python", 223 | "name": "conda-root-py" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.6.1" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 2 240 | } 241 | -------------------------------------------------------------------------------- /test_pdb_debugger.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pdb" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "--Return--\n", 24 | "> (2)()->None\n", 25 | "-> pdb.set_trace()\n", 26 | "(Pdb) cont\n", 27 | "1\n", 28 | "3\n", 29 | "6\n", 30 | "10\n", 31 | "15\n", 32 | "21\n", 33 | "28\n", 34 | "36\n", 35 | "45\n", 36 | "55\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "x = 0\n", 42 | "pdb.set_trace()\n", 43 | "for i in range(10):\n", 44 | " x += i+1\n", 45 | " print(x)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "Python 3", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.6.1" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 2 77 | } 78 | -------------------------------------------------------------------------------- /test_tqdm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "application/vnd.jupyter.widget-view+json": { 11 | "model_id": "983855d2e2764926ada81bfb689a388e" 12 | } 13 | }, 14 | "metadata": {}, 15 | "output_type": "display_data" 16 | }, 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "from tqdm import tqdm_notebook\n", 27 | "import time\n", 28 | "\n", 29 | "bar = tqdm_notebook(range(100), desc=\"test range\")\n", 30 | "x = 0\n", 31 | "for j, i in enumerate(bar):\n", 32 | " x += i\n", 33 | " time.sleep(0.1)\n", 34 | " bar.set_description(\"Processing {}\".format(x))\n" 35 | ] 36 | } 37 | ], 38 | "metadata": { 39 | "kernelspec": { 40 | "display_name": "Python [default]", 41 | "language": "python", 42 | "name": "python3" 43 | }, 44 | "language_info": { 45 | "codemirror_mode": { 46 | "name": "ipython", 47 | "version": 3 48 | }, 49 | "file_extension": ".py", 50 | "mimetype": "text/x-python", 51 | "name": "python", 52 | "nbconvert_exporter": "python", 53 | "pygments_lexer": "ipython3", 54 | "version": "3.6.1" 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 2 59 | } 60 | --------------------------------------------------------------------------------