├── .gitignore
├── CNN2RNN.ipynb
├── G729VAD.ipynb
├── GRUAutoencoder.ipynb
├── LICENSE
├── README.md
├── Youtube-GraphTheory-EdgeWeightedEx.sagews
├── allstate_pca.ipynb
├── audio_rnn_basic.ipynb
├── c++
│   └── pytorch_dcgan_tutorial
│       ├── CMakeLists.txt
│       └── dcgan.cpp
├── char-rnn.ipynb
├── convolution1d_tests.ipynb
├── coursera-algos-knapsack-ps.ipynb
├── coursera-algos-twosum.ipynb
├── data
│   ├── coursera_algos
│   │   ├── algos-2sums.txt.zip
│   │   ├── kargerMinCut.txt
│   │   ├── knapsack1.txt
│   │   └── knapsack_big.txt
│   └── tmp.zip
├── databricks_python_graphframes_GraphExample.ipynb
├── databricks_python_numpy_spark_kMeansExample.ipynb
├── denoising_autoencoder.ipynb
├── imdb_tokenize.ipynb
├── logistic-regression-pure-julia.ipynb
├── macro-hw3-2.ipynb
├── pymc-test.ipynb
├── pytorch_attention_audio.py
├── pytorch_basics.ipynb
├── pytorch_embedding_test.ipynb
├── pytorch_tutorial_classify_names.ipynb
├── rnn_autoencoder.ipynb
├── tedlium_labels.ipynb
├── test_VarLenDataset.ipynb
├── test_pdb_debugger.ipynb
├── test_tqdm.ipynb
└── timeseries.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
--------------------------------------------------------------------------------
/CNN2RNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import librosa \n",
13 | "import torch\n",
14 | "from torch import nn\n",
15 | "from torch.autograd import Variable\n",
16 | "import torch.nn.functional as F\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "%matplotlib inline"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stdout",
28 | "output_type": "stream",
29 | "text": [
30 | "(78000,) 16000 (74000,) 16000\n"
31 | ]
32 | }
33 | ],
34 | "source": [
35 | "files = librosa.util.find_files(\"pcsnpny-20150204-mkj/wav\")\n",
36 | "file = files[0]\n",
37 | "sig1, sr1 = librosa.core.load(file, sr=None)\n",
38 | "sig2, sr2 = librosa.core.load(files[1], sr=None)\n",
39 | "print(sig1.shape, sr1, sig2.shape, sr2)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stdout",
49 | "output_type": "stream",
50 | "text": [
51 | "torch.Size([1, 1, 78000])\n",
52 | "input size: torch.Size([1, 1, 78000])\n",
53 | "conv1d out: torch.Size([1, 66, 242])\n",
54 | "conv2d out: torch.Size([1, 64, 64, 240])\n",
55 | "gru in: torch.Size([1, 240, 4096])\n",
56 | "gru out: torch.Size([1, 240, 1024])\n",
57 | "dense in: torch.Size([240, 1024])\n",
58 | "torch.Size([1, 240])\n",
59 | "input size: torch.Size([1, 1, 74000])\n",
60 | "conv1d out: torch.Size([1, 66, 230])\n",
61 | "conv2d out: torch.Size([1, 64, 64, 228])\n",
62 | "gru in: torch.Size([1, 228, 4096])\n",
63 | "gru out: torch.Size([1, 228, 1024])\n",
64 | "dense in: torch.Size([228, 1024])\n",
65 | "torch.Size([1, 228])\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "class CNN2RNN(nn.Module):\n",
71 | " def __init__(self, conv_in_channels, conv_out_features, ws, hs, rnn_hidden_size, rnn_output_size):\n",
72 | " super(CNN2RNN, self).__init__()\n",
73 | " self.cin_channels = conv_in_channels\n",
74 | " self.cout_features = conv_out_features\n",
75 | " self.rnn_hid = rnn_hidden_size\n",
76 | " self.rnn_out = rnn_output_size\n",
77 | " \n",
78 | " \n",
79 | " # hard coding vars so I know what they are.\n",
80 | " n_layers = 2 # number of layers of RNN\n",
81 | " batch_size = 1\n",
82 | " kernel2d = 3\n",
83 | " # hidden initialization\n",
84 | " self.hidden = self.init_hidden(n_layers, batch_size, rnn_hidden_size)\n",
85 | " \n",
86 | " # net layer types\n",
87 | " self.c1 = nn.Conv1d(conv_in_channels, conv_out_features+kernel2d-1, ws, stride=hs)\n",
88 | " self.c2 = nn.Conv2d(1, conv_out_features, kernel2d)\n",
89 | " self.gru = nn.GRU(conv_out_features*conv_out_features, rnn_hidden_size, n_layers, \n",
90 | " batch_first=True, bidirectional=False)\n",
91 | " self.dense = nn.Linear(rnn_hidden_size, 1)\n",
92 | " \n",
93 | " def forward(self, input):\n",
94 | " print(\"input size: {}\".format(input.size()))\n",
95 | " conv_out = self.c1(input)\n",
96 | " print(\"conv1d out: {}\".format(conv_out.size()))\n",
97 | " conv2_out = self.c2(conv_out.unsqueeze(0))\n",
98 | " print(\"conv2d out: {}\".format(conv2_out.size()))\n",
99 | " gru_in = conv2_out.view(input.size(0), -1, self.cout_features * self.cout_features)\n",
100 | " print(\"gru in: {}\".format(gru_in.size()))\n",
101 | " gru_out, self.hidden = self.gru(gru_in, self.hidden)\n",
102 | " print(\"gru out: {}\".format(gru_out.size()))\n",
103 | " dense_in = gru_out.view(gru_in.size(1)*input.size(0) ,-1)\n",
104 | " print(\"dense in: {}\".format(dense_in.size()))\n",
105 | " out_space = self.dense(dense_in)\n",
106 | " out = F.sigmoid(out_space)\n",
107 | " out = out.view(input.size(0), -1)\n",
108 | " return(out)\n",
109 | " \n",
110 | " def init_hidden(self, nl, bat_dim, hid_dim):\n",
111 | " # The axes: (num_layers, minibatch_size, hidden_dim)\n",
112 | " # see docs\n",
113 | " return (Variable(torch.zeros(nl, bat_dim, hid_dim)))\n",
114 | "\n",
115 | "ws=640\n",
116 | "hs=ws//2\n",
117 | "nb=64\n",
118 | "\n",
119 | "net = CNN2RNN(1, nb, ws, hs, 1024, 2)\n",
120 | "inputs1 = torch.Tensor(sig1)\n",
121 | "inputs1.unsqueeze_(0)\n",
122 | "inputs1.unsqueeze_(0)\n",
123 | "print(net(Variable(inputs1)).size())\n",
124 | "inputs2 = torch.Tensor(sig2)\n",
125 | "inputs2.unsqueeze_(0)\n",
126 | "inputs2.unsqueeze_(0)\n",
127 | "print(net(Variable(inputs2)).size())\n"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {
134 | "collapsed": true
135 | },
136 | "outputs": [],
137 | "source": [
138 | "torch"
139 | ]
140 | }
141 | ],
142 | "metadata": {
143 | "kernelspec": {
144 | "display_name": "Python 3",
145 | "language": "python",
146 | "name": "python3"
147 | },
148 | "language_info": {
149 | "codemirror_mode": {
150 | "name": "ipython",
151 | "version": 3
152 | },
153 | "file_extension": ".py",
154 | "mimetype": "text/x-python",
155 | "name": "python",
156 | "nbconvert_exporter": "python",
157 | "pygments_lexer": "ipython3",
158 | "version": "3.6.1"
159 | }
160 | },
161 | "nbformat": 4,
162 | "nbformat_minor": 2
163 | }
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 David Pollack
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Programming Notebooks
2 |
3 | ## Python
4 |
5 | In the summer of 2016, I spent some time with various free online resources to learn
6 | graph theory and basic algorithm programming. Most of these notebooks were written
7 | for assignments in the courses listed below. Not all of my notebooks are here.
8 |
9 | ### List of Courses
10 | 1. [Sarada Herke's Graph Theory](https://www.youtube.com/user/DrSaradaHerke)
11 | 2. [Coursera - Social and Economic Networks: Models and Analysis](https://www.coursera.org/learn/social-economic-networks/home/welcome)
12 | 3. [Coursera - Algorithms: Design and Analysis, Part 1](https://www.coursera.org/learn/algorithm-design-analysis/home/welcome)
13 | 4. [Coursera - Algorithms: Design and Analysis, Part 2](https://www.coursera.org/learn/algorithm-design-analysis-2/home/welcome)
14 | 5. [Coursera - Probabilistic Graphical Models 1: Representation](https://www.coursera.org/learn/probabilistic-graphical-models/home/welcome)
15 |
--------------------------------------------------------------------------------
/Youtube-GraphTheory-EdgeWeightedEx.sagews:
--------------------------------------------------------------------------------
1 | ︠2e450aa8-978f-4dd7-bef5-643f502926e8s︠
2 | import numpy as np
3 | a = [0, 3, 1, 2, 0, 0]  # Example 1: rows a-f form a symmetric 6x6 weighted adjacency matrix (entry [i][j] = weight of edge i-j, 0 = no edge)
4 | b = [3, 0, 0, 0, 0, 0]
5 | c = [1, 0, 0, 0, 4, 2]
6 | d = [2, 0, 0, 0, 1, 0]
7 | e = [0, 0, 4, 1, 0, 0]
8 | f = [0, 0, 2, 0, 0, 0]
9 | A = Matrix([a, b, c, d, e, f])  # Sage matrix built from the rows above
10 | G = Graph(A, weighted=True)  # weighted=True: nonzero matrix entries become edge weights (labels)
11 |
def dijkstras_algo(G, u, S=None, t=None):
    """Single-source shortest distances with Dijkstra's algorithm.

    Parameters
    ----------
    G : Sage Graph, assumed to have vertices 0 .. G.order()-1 (as produced by
        ``Graph(Matrix(...))``). If ``G.weighted()`` is true, edge labels are
        used as weights (presumably non-negative -- Dijkstra requires it);
        otherwise every edge has weight 1.
    u : source vertex.
    S : optional list of settled vertices. When it is None or empty, a fresh
        search starts from ``u``. A caller-supplied list is extended in place.
        The default is None rather than ``[]`` so no state leaks between calls
        (a shared list default is mutated by the first call).
    t : distance array. It is only used when ``S`` is non-empty (resuming a
        search); it must then be indexable by vertex.

    Returns
    -------
    (S, t) : ``S`` lists the vertices in the order they were settled, and
        ``t[v]`` is the shortest distance from the source to ``v`` (``inf`` if
        ``v`` is unreachable).
    """
    by_weight = G.weighted()
    if S is None:
        S = []
    if not S:
        # Fresh search: only the source is settled, at distance 0.
        S.append(u)
        t = np.full(G.order(), np.inf)
        t[u] = 0
    vertices = set(G.vertices())
    # Iterative rather than recursive, so large graphs do not hit the
    # recursion limit.
    while True:
        settled = set(S)
        for v in set(G.neighbors(u)) - settled:
            # Relax edge (u, v) by its own weight, not by a full shortest-path
            # query between u and v.
            w = G.edge_label(u, v) if by_weight else 1
            t[v] = min(t[v], t[u] + w)
        left = list(vertices - settled)
        if not left:
            # Every vertex is settled; this also covers a 1-vertex graph.
            break
        # Settle the unsettled vertex with the smallest tentative distance.
        u = left[np.argmin(t[left])]
        S.append(u)
    return S, t
32 |
33 | S, t = dijkstras_algo(G, 0)  # example 1 from vertex 0: S = order in which vertices were settled, t[v] = shortest distance to v
34 | print S, t  # Python 2 print statement (this worksheet runs under a Python 2 Sage kernel)
35 |
36 | ︡9ffef417-fe72-4731-adbf-f9f2fd917223︡{"stdout":"[0, 2, 3, 1, 4, 5] [ 0. 3. 1. 2. 3. 3.]\n"}︡{"done":true}︡
37 | ︠2a091a46-11a3-4c7c-ae57-e523ef4d7dd4s︠
38 | a = [0,2,3,0,0,0]  # Example 2: rows a-f form another symmetric 6x6 weighted adjacency matrix (0 = no edge)
39 | b = [2,0,2,1,3,3]
40 | c = [3,2,0,0,1,0]
41 | d = [0,1,0,0,2,1]
42 | e = [0,3,1,2,0,2]
43 | f = [0,3,0,1,2,0]
44 | A = Matrix([a,b,c,d,e,f])
45 | G = Graph(A, weighted=True)  # nonzero entries become edge weights
46 | S, t = dijkstras_algo(G, 0, S = [], t = [])  # pass fresh S and t so no state from an earlier call is reused
47 | print S, t
48 |
49 | ︡8eaafbcd-1aeb-45de-a972-fc33646cade1︡{"stdout":"[0, 1, 2, 3, 4, 5] [ 0. 2. 3. 3. 4. 4.]\n"}︡{"done":true}︡
50 | ︠1e4eee2a-d583-4206-a288-aa6d2f21c0d2s︠
51 | x = 4  # Example 3: weight shared by entries f[7] and h[5], i.e. the edge between vertices 5 and 7
52 |
53 | a = [0, 3, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]  # rows a-m form a symmetric 13x13 weighted adjacency matrix (0 = no edge)
54 | b = [3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
55 | c = [1, 0, 0, 4, 0, 0, 7, 0, 0, 0, 0, 0, 0]
56 | d = [0, 1, 4, 0, 0, 2, 0, 6, 0, 0, 0, 0, 0]
57 | e = [5, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0]
58 | f = [0, 0, 0, 2, 2, 0, 1, x, 0, 0, 0, 1, 0]
59 | g = [0, 0, 7, 0, 0, 1, 0, 2, 0, 6, 0, 0, 0]
60 | h = [0, 0, 0, 6, 0, x, 2, 0, 1, 0, 0, 0, 0]
61 | i = [0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 7, 0, 1]
62 | j = [0, 0, 0, 0, 0, 0, 6, 0, 2, 0, 4, 0, 0]
63 | k = [0, 0, 0, 0, 0, 0, 0, 0, 7, 4, 0, 3, 0]
64 | l = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 4]
65 | m = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 4, 0]
66 | A=Matrix([a, b, c, d, e, f, g, h, i, j, k, l, m])
67 | G = Graph(A, weighted=True)
68 |
69 | print G.shortest_path(0,9, by_weight=True)  # Sage reference: path from 0 to 9 with minimum total edge weight
70 | print G.shortest_path(0,9, by_weight=False)  # Sage reference: path from 0 to 9 with the fewest edges
71 | print G.distance(0,9, by_weight=True)  # total weight of the weighted shortest path
72 | print G.distance(0,9, by_weight=False)  # number of edges on the unweighted shortest path
73 |
74 | S, t = dijkstras_algo(G, 0, S = [], t = [])  # t[9] should match G.distance(0,9, by_weight=True)
75 | print S, t
76 | ︡8fd014ff-0088-45a8-8475-cc59e2d270b0︡{"stdout":"[0, 1, 3, 5, 6, 7, 8, 9]\n"}︡{"stdout":"[0, 2, 6, 9]\n"}︡{"stdout":"12\n"}︡{"stdout":"3\n"}︡{"stdout":"[0, 2, 1, 3, 4, 5, 6, 11, 7, 8, 10, 12, 9] [ 0. 3. 1. 4. 5. 6. 7. 9. 10. 12. 10. 7. 11.]\n"}︡{"done":true}︡
77 | ︠9ba6667e-18b2-40ba-9dc5-5123e35dd7d4s︠
78 | a = [0,7,4,5,0,0,0,0,0]  # Example 4: rows a-i form a symmetric 9x9 weighted adjacency matrix (0 = no edge)
79 | b = [7,0,2,0,25,0,0,0,0]
80 | c = [4,2,0,0,0,0,0,9,0]
81 | d = [5,0,0,0,0,9,0,0,0]
82 | e = [0,25,0,0,0,0,10,0,0]
83 | f = [0,0,0,9,0,0,0,20,0]
84 | g = [0,0,0,0,10,0,0,0,2]
85 | h = [0,0,9,0,0,20,0,0,3]
86 | i = [0,0,0,0,0,0,2,3,0]
87 | A = Matrix([a,b,c,d,e,f,g,h,i])
88 | G = Graph(A, weighted=True)
89 | G.plot()  # render the graph in the worksheet output
90 |
91 | S, t = dijkstras_algo(G, 0, S = [], t = [])  # pass fresh S and t so no state from an earlier call is reused
92 | print S, t
93 | ︡89a52e87-0c4f-4ba7-8c4f-a943ed0d913a︡{"file":{"filename":"/projects/e4d6c2ac-d4ac-4792-9987-c8fef33777ee/.sage/temp/compute1-us/12429/tmp_7Y2ieD.svg","show":true,"text":null,"uuid":"c5d4a60b-592f-4e49-9116-e3eadb1a9287"},"once":false}︡{"html":"
"}︡{"stdout":"[0, 2, 3, 1, 7, 5, 8, 6, 4] [ 0. 6. 4. 5. 28. 14. 18. 13. 16.]\n"}︡{"done":true}︡
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
/audio_rnn_basic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import librosa"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 3,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "(78000,) 16000 (74000,) 16000\n"
25 | ]
26 | }
27 | ],
28 | "source": [
29 | "# load audio\n",
30 | "audio_manifest = librosa.util.find_files(\"pcsnpny-20150204-mkj\")\n",
31 | "sig1, sr1 = librosa.core.load(audio_manifest[0], sr=None)\n",
32 | "sig2, sr2 = librosa.core.load(audio_manifest[1], sr=None)\n",
33 | "print(sig1.shape, sr1, sig2.shape, sr2)"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 8,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "(128, 153) (128, 145) 0.0 0.0\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "win_len = 1024\n",
51 | "hop_len = win_len // 2\n",
52 | "gram1 = librosa.feature.melspectrogram(sig1, sr=sr1, n_fft=win_len, hop_length=hop_len)\n",
53 | "gram2 = librosa.feature.melspectrogram(sig2, sr=sr2, n_fft=win_len, hop_length=hop_len)\n",
54 | "\n",
55 | "gram1 = librosa.power_to_db(gram1, ref=np.max)\n",
56 | "gram2 = librosa.power_to_db(gram2, ref=np.max)\n",
57 | "\n",
58 | "gram1 -= gram1.min()\n",
59 | "gram2 -= gram2.min()\n",
60 | "\n",
61 | "print(gram1.shape, gram2.shape, gram1.min(), gram2.min())"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 56,
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "name": "stdout",
71 | "output_type": "stream",
72 | "text": [
73 | "[153, 145] torch.Size([2, 153, 128]) torch.Size([153, 128])\n",
74 | "pack size: torch.Size([298, 128])\n",
75 | "h0: torch.Size([3, 2, 20])\n",
76 | "out size: torch.Size([298, 20]) torch.Size([3, 2, 20])\n",
77 | "torch.Size([2, 153, 20]) [153, 145]\n"
78 | ]
79 | }
80 | ],
81 | "source": [
82 | "import torch\n",
83 | "import torch.nn as nn\n",
84 | "from torch.autograd import Variable\n",
85 | "\n",
86 | "BATCH_SIZE = 2\n",
87 | "MAX_LENGTH = max([g.shape[1] for g in [gram1, gram2]])\n",
88 | "HIDDEN_SIZE = 20\n",
89 | "N_LAYERS = 3\n",
90 | "\n",
91 | "def pad2d(t, length):\n",
92 | " if t.size(1) == length:\n",
93 | " return(t)\n",
94 | " else:\n",
95 | " return torch.cat((t, t.new(t.size(0), length - t.size(1)).zero_()),1)\n",
96 | "\n",
97 | " \n",
98 | "seq_lens = [g.shape[1] for g in [gram1, gram2]]\n",
99 | "batch_in = [pad2d(torch.Tensor(g), MAX_LENGTH) for g in [gram1, gram2]]\n",
100 | "batch_in = torch.stack(batch_in).transpose(1,2)\n",
101 | "print(seq_lens, batch_in.size(), in_size)\n",
102 | "batch_in = Variable(batch_in)\n",
103 | "\n",
104 | "pack = torch.nn.utils.rnn.pack_padded_sequence(batch_in, seq_lens, batch_first=True)\n",
105 | "print(\"pack size:\", pack.data.size())\n",
106 | "\n",
107 | "rnn = nn.GRU(128, HIDDEN_SIZE, N_LAYERS, batch_first=True)\n",
108 | "h0 = Variable(torch.randn(N_LAYERS, BATCH_SIZE, HIDDEN_SIZE))\n",
109 | "print(\"h0:\", h0.size())\n",
110 | "# forward\n",
111 | "out, hidden_new = rnn(pack, h0)\n",
112 | "\n",
113 | "print(\"out size:\", out.data.size(), hidden_new.size())\n",
114 | "\n",
115 | "unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)\n",
116 | "\n",
117 | "print(unpacked.size(), unpacked_len)\n"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 55,
123 | "metadata": {},
124 | "outputs": [
125 | {
126 | "data": {
127 | "text/plain": [
128 | "torch.utils.data.dataset.TensorDataset"
129 | ]
130 | },
131 | "execution_count": 55,
132 | "metadata": {},
133 | "output_type": "execute_result"
134 | }
135 | ],
136 | "source": [
137 | "from torch.utils.data import TensorDataset, DataLoader\n",
138 | "DataLoader()"
139 | ]
140 | }
141 | ],
142 | "metadata": {
143 | "kernelspec": {
144 | "display_name": "Python 3",
145 | "language": "python",
146 | "name": "python3"
147 | },
148 | "language_info": {
149 | "codemirror_mode": {
150 | "name": "ipython",
151 | "version": 3
152 | },
153 | "file_extension": ".py",
154 | "mimetype": "text/x-python",
155 | "name": "python",
156 | "nbconvert_exporter": "python",
157 | "pygments_lexer": "ipython3",
158 | "version": "3.6.1"
159 | }
160 | },
161 | "nbformat": 4,
162 | "nbformat_minor": 2
163 | }
164 |
--------------------------------------------------------------------------------
/c++/pytorch_dcgan_tutorial/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR)  # Build script for the LibTorch DCGAN example (dcgan.cpp)
2 | project(dcgan)
3 |
4 | find_package(Torch REQUIRED)  # locate LibTorch; TODO confirm: typically pointed at via -DCMAKE_PREFIX_PATH=/path/to/libtorch
5 |
6 | add_executable(dcgan dcgan.cpp)
7 | target_link_libraries(dcgan "${TORCH_LIBRARIES}")  # TORCH_LIBRARIES presumably populated by find_package(Torch) above
8 | set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)  # compile dcgan.cpp as C++14
9 |
--------------------------------------------------------------------------------
/c++/pytorch_dcgan_tutorial/dcgan.cpp:
--------------------------------------------------------------------------------
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <string>

#include <torch/torch.h>
3 |
4 | /*
5 | struct Net : torch::nn::Module {
6 | Net(int64_t N, int64_t M)
7 | : linear(register_module("linear", torch::nn::Linear(N, M))) {
8 | another_input = register_parameter("b", torch::randn(M));
9 | }
10 | torch::Tensor forward(torch::Tensor input) {
11 | return linear(input) + another_input;
12 | }
13 | torch::nn::Linear linear;
14 | torch::Tensor another_input;
15 | };
16 | */
17 |
18 | using namespace torch;
19 |
20 | int main() {
21 | string MNIST_path("/home/david/Programming/data/MNIST");
22 |
23 | nn::Sequential generator(
24 | // Layer 1
25 | nn::Conv2d(
26 | nn::Conv2dOptions(100, 256, 4).with_bias(false).transposed(true)),
27 | nn::BatchNorm(256), nn::Functional(torch::relu),
28 | // Layer 2
29 | nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
30 | .stride(2)
31 | .padding(1)
32 | .with_bias(false)
33 | .transposed(true)),
34 | nn::BatchNorm(128), nn::Functional(torch::relu),
35 | // Layer 3
36 | nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
37 | .stride(2)
38 | .padding(1)
39 | .with_bias(false)
40 | .transposed(true)),
41 | nn::BatchNorm(64), nn::Functional(torch::relu),
42 | // Layer 4
43 | nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
44 | .stride(2)
45 | .padding(1)
46 | .with_bias(false)
47 | .transposed(true)),
48 | nn::Functional(torch::tanh));
49 |
50 | nn::Sequential discriminator(
51 | // Layer 1
52 | nn::Conv2d(
53 | nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
54 | nn::Functional(torch::leaky_relu, 0.2),
55 | // Layer 2
56 | nn::Conv2d(nn::Conv2dOptions(64, 128, 4)
57 | .stride(2)
58 | .padding(1)
59 | .with_bias(false)),
60 | nn::BatchNorm(128), nn::Functional(torch::leaky_relu, 0.2),
61 | // Layer 3
62 | nn::Conv2d(nn::Conv2dOptions(128, 256, 4)
63 | .stride(2)
64 | .padding(1)
65 | .with_bias(false)),
66 | nn::BatchNorm(256), nn::Functional(torch::leaky_relu, 0.2),
67 | // Layer 4
68 | nn::Conv2d(
69 | nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
70 | nn::Functional(torch::sigmoid));
71 |
72 | auto dataset = torch::data::datasets::MNIST(MNIST_path.c_str())
73 | .map(torch::data::transforms::Normalize<>(0.5, 0.5))
74 | .map(torch::data::transforms::Stack<>());
75 |
76 | auto data_loader = torch::data::make_data_loader(
77 | std::move(dataset),
78 | torch::data::DataLoaderOptions().batch_size(64).workers(2));
79 |
80 | torch::optim::Adam gen_opt(generator->parameters(),
81 | torch::optim::AdamOptions(2e-4).beta1(0.5));
82 | torch::optim::Adam dis_opt(discriminator->parameters(),
83 | torch::optim::AdamOptions(5e-4).beta1(0.5));
84 | int64_t kNumEpochs = 100;
85 | for (int64_t epoch = 1; epoch <= kNumEpochs; ++epoch) {
86 | int64_t batch_index = 0;
87 | for (torch::data::Example<> &batch : *data_loader) {
88 | discriminator->zero_grad();
89 | torch::Tensor real_images = batch.data;
90 | torch::Tensor real_labels =
91 | torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
92 | torch::Tensor real_output = discriminator->forward(real_images);
93 | torch::Tensor d_loss_real =
94 | torch::binary_cross_entropy(real_output, real_labels);
95 | d_loss_real.backward();
96 |
97 | torch::Tensor noise = torch::randn({batch.data.size(0), 100, 1, 1});
98 | torch::Tensor fake_images = generator->forward(noise);
99 | torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
100 | torch::Tensor fake_output =
101 | discriminator->forward(fake_images.detach());
102 | torch::Tensor d_loss_fake =
103 | torch::binary_cross_entropy(fake_output, fake_labels);
104 | d_loss_fake.backward();
105 |
106 | torch::Tensor d_loss = d_loss_real + d_loss_fake;
107 | dis_opt.step();
108 |
109 | generator->zero_grad();
110 | fake_labels.fill_(1);
111 | fake_output = discriminator->forward(fake_images);
112 | torch::Tensor g_loss =
113 | torch::binary_cross_entropy(fake_output, fake_labels);
114 | g_loss.backward();
115 | gen_opt.step();
116 |
117 | std::printf("\r[%5d/%5d][%5d/%5d] D_loss: %.4f | G_loss: %.4f",
118 | epoch, kNumEpochs, ++batch_index, 938,
119 | d_loss.item(), g_loss.item());
120 | }
121 | }
122 | }
123 |
--------------------------------------------------------------------------------
/char-rnn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 12,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torch.nn as nn\n",
13 | "import torch.nn.functional as F\n",
14 | "import torch.optim as optim\n",
15 | "import torch.utils.data as data\n",
16 | "from torch.autograd import Variable\n",
17 | "\n",
18 | "from tqdm import tnrange, tqdm_notebook, tqdm"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 28,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "class SimpleGRU(nn.Module):\n",
28 | " def __init__(self, vocab_size, emb_size, hid_size, batch_size, seq_len, n_layers=1):\n",
29 | " super(SimpleGRU, self).__init__()\n",
30 | " self.vocab_size = vocab_size\n",
31 | " self.emb_size = emb_size\n",
32 | " self.hid_size = hid_size\n",
33 | " self.n_layers = n_layers\n",
34 | " self.batch_size = batch_size\n",
35 | " self.seq_len = seq_len\n",
36 | " self.emb = nn.Embedding(vocab_size, emb_size)\n",
37 | " self.gru = nn.GRU(emb_size, hid_size, batch_first=True)\n",
38 | " self.fc1 = nn.Linear(seq_len * hid_size, vocab_size)\n",
39 | " self.selu = nn.SELU()\n",
40 | " self.logsoftmax = nn.LogSoftmax()\n",
41 | " def forward(self, input, hidden):\n",
42 | " x = self.emb(input)\n",
43 | " x, hidden = self.gru(x, hidden)\n",
44 | " x = x.contiguous().view(self.batch_size, -1)\n",
45 | " x = self.selu(self.fc1(x))\n",
46 | " x = self.logsoftmax(x)\n",
47 | " return x, hidden\n",
48 | "\n",
49 | "class CharDataset(data.Dataset):\n",
50 | " def __init__(self, data, seq_len):\n",
51 | " self.data = data\n",
52 | " self.seq_len = seq_len\n",
53 | " def __getitem__(self, index):\n",
54 | " inp_seq = self.data[index:(index+self.seq_len-1)]\n",
55 | " tgt_seq = torch.Tensor([self.data[index+self.seq_len]]).type(self.data.type())\n",
56 | " return inp_seq, tgt_seq\n",
57 | " def __len__(self):\n",
58 | " return len(self.data) - self.seq_len\n",
59 | " \n"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 22,
65 | "metadata": {
66 | "collapsed": true
67 | },
68 | "outputs": [],
69 | "source": [
70 | "seq_length = 25\n",
71 | "batch_size = 50\n",
72 | "\n",
73 | "emb_size = 25\n",
74 | "hid_size = 100\n",
75 | "n_layers = 3"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 23,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "name": "stdout",
85 | "output_type": "stream",
86 | "text": [
87 | "1115394 1115394\n",
88 | "torch.Size([1115394])\n",
89 | "22307\n"
90 | ]
91 | }
92 | ],
93 | "source": [
94 | "with open(\"/home/david/Programming/data/project_gutenberg/tiny-shakespeare.txt\", \"r\") as f:\n",
95 | " text_raw = [c for l in f.readlines() for c in l]\n",
96 | " #dangling_length = len(text_raw) % seq_length\n",
97 | " #text_raw = text_raw[:-dangling_length]\n",
98 | "\n",
99 | "charset = sorted(list(set(text_raw)))\n",
100 | "c2i = {c: i for i, c in enumerate(charset)}\n",
101 | "i2c = {i: c for c, i in c2i.items()}\n",
102 | "text_idx = [c2i[c] for c in text_raw]\n",
103 | "print(len(text_idx), len(text_raw))\n",
104 | "\n",
105 | "#inputs = torch.Tensor([x for x in zip(*[text_idx[i::seq_length] for i in range(seq_length-1)])]).long()\n",
106 | "#targets = torch.Tensor(text_idx[(seq_length-1)::seq_length]).long()\n",
107 | "inputs = torch.Tensor(text_idx).long()\n",
108 | "print(inputs.size())\n",
109 | "\n",
110 | "#ds = data.TensorDataset(inputs, targets)\n",
111 | "ds = CharDataset(inputs, seq_length)\n",
112 | "dl = data.DataLoader(ds, batch_size=batch_size, drop_last=True)\n",
113 | "print(len(dl))"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 24,
119 | "metadata": {
120 | "collapsed": true
121 | },
122 | "outputs": [],
123 | "source": [
124 | "vocab_size = len(charset)\n",
125 | "num_batches = len(dl)\n",
126 | "epochs = 10"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 31,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "name": "stdout",
136 | "output_type": "stream",
137 | "text": [
138 | "SimpleGRU (\n",
139 | " (emb): Embedding(65, 25)\n",
140 | " (gru): GRU(25, 100, batch_first=True)\n",
141 | " (fc1): Linear (2400 -> 65)\n",
142 | " (selu): SELU\n",
143 | " (logsoftmax): LogSoftmax ()\n",
144 | ")\n"
145 | ]
146 | }
147 | ],
148 | "source": [
149 | "model = SimpleGRU(vocab_size, emb_size, hid_size, batch_size, seq_length-1, n_layers)\n",
150 | "criterion = nn.NLLLoss()\n",
151 | "optimizer = optim.Adam(model.parameters(), lr=0.005)\n",
152 | "print(model)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 32,
158 | "metadata": {
159 | "scrolled": false
160 | },
161 | "outputs": [
162 | {
163 | "data": {
164 | "application/vnd.jupyter.widget-view+json": {
165 | "model_id": "a34737486b714652b0ca1bebbbd82500"
166 | }
167 | },
168 | "metadata": {},
169 | "output_type": "display_data"
170 | },
171 | {
172 | "ename": "KeyboardInterrupt",
173 | "evalue": "",
174 | "output_type": "error",
175 | "traceback": [
176 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
177 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
178 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtgts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
179 | "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \"\"\"\n\u001b[0;32m--> 156\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_variables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
180 | "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(variables, grad_variables, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m Variable._execution_engine.run_backward(\n\u001b[0;32m---> 98\u001b[0;31m variables, grad_variables, retain_graph)\n\u001b[0m\u001b[1;32m 99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
181 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
182 | ]
183 | }
184 | ],
185 | "source": [
186 | "batch_bar = tqdm_notebook(dl, desc=\"batches\", mininterval=0.9)\n",
187 | "\n",
188 | "for epoch in range(epochs):\n",
189 | " running_loss = 0\n",
190 | " for i, (mb, tgts) in enumerate(batch_bar):\n",
191 | " h = Variable(torch.zeros(n_layers,batch_size, hid_size))\n",
192 | " tgts.squeeze_()\n",
193 | " model.train()\n",
194 | " model.zero_grad()\n",
195 | " mb, tgts = Variable(mb), Variable(tgts)\n",
196 | " out, h = model(mb, h)\n",
197 | " loss = criterion(out, tgts)\n",
198 | " loss.backward()\n",
199 | " optimizer.step()\n",
200 | " h.detach_()\n",
201 | " running_loss += loss.data[0]\n",
202 | " if i % 25 == 0 or i == num_batches - 1:\n",
203 | " batch_bar.set_postfix(loss=(running_loss / (i+1)))\n",
204 | " torch.save(model.state_dict(), \"model_charrnn_{}.pt\".format(epoch+1))"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "collapsed": true
212 | },
213 | "outputs": [],
214 | "source": []
215 | }
216 | ],
217 | "metadata": {
218 | "kernelspec": {
219 | "display_name": "Python [default]",
220 | "language": "python",
221 | "name": "python3"
222 | },
223 | "language_info": {
224 | "codemirror_mode": {
225 | "name": "ipython",
226 | "version": 3
227 | },
228 | "file_extension": ".py",
229 | "mimetype": "text/x-python",
230 | "name": "python",
231 | "nbconvert_exporter": "python",
232 | "pygments_lexer": "ipython3",
233 | "version": "3.6.1"
234 | }
235 | },
236 | "nbformat": 4,
237 | "nbformat_minor": 2
238 | }
239 |
--------------------------------------------------------------------------------
/convolution1d_tests.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torch.nn as nn\n",
13 | "import torch.nn.functional as F\n",
14 | "from torch.autograd import Variable\n",
15 | "import torchaudio"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 13,
21 | "metadata": {},
22 | "outputs": [
23 | {
24 | "data": {
25 | "text/plain": [
26 | "[Variable containing:\n",
27 | " (0 ,.,.) = \n",
28 | " \n",
29 | " Columns 0 to 18 \n",
30 | " 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n",
31 | " \n",
32 | " Columns 19 to 37 \n",
33 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37\n",
34 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n",
35 | " (0 ,.,.) = \n",
36 | " \n",
37 | " Columns 0 to 18 \n",
38 | " 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n",
39 | " \n",
40 | " Columns 19 to 37 \n",
41 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 24\n",
42 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n",
43 | " (0 ,.,.) = \n",
44 | " \n",
45 | " Columns 0 to 18 \n",
46 | " 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n",
47 | " \n",
48 | " Columns 19 to 37 \n",
49 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 30 23 19\n",
50 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n",
51 | " (0 ,.,.) = \n",
52 | " \n",
53 | " Columns 0 to 18 \n",
54 | " 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n",
55 | " \n",
56 | " Columns 19 to 37 \n",
57 | " 19 20 21 22 23 24 25 26 27 28 29 30 31 30 28 28 20 18 17\n",
58 | " [torch.FloatTensor of size 1x1x38], Variable containing:\n",
59 | " (0 ,.,.) = \n",
60 | " \n",
61 | " Columns 0 to 18 \n",
62 | " 1 2 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18\n",
63 | " \n",
64 | " Columns 19 to 37 \n",
65 | " 19 20 21 22 23 24 25 26 27 27 27 28 26 25 24 19 17 16 15\n",
66 | " [torch.FloatTensor of size 1x1x38]]"
67 | ]
68 | },
69 | "execution_count": 13,
70 | "metadata": {},
71 | "output_type": "execute_result"
72 | }
73 | ],
74 | "source": [
75 | "X = torch.arange(0, 38).view(1, 1, -1)\n",
76 | "X = Variable(X)\n",
77 | "ksize = 3\n",
78 | "conv1d1_3dN = []\n",
79 | "pad1dN = []\n",
80 | "for i in range(4):\n",
81 | " conv1d1_3dN += [nn.Conv1d(in_channels=1, out_channels=1, stride=1, padding=0, kernel_size=ksize, dilation=(i+1), bias=False)]\n",
82 | " pad1dN += [nn.ConstantPad1d(((i+1), (i+1)), 0)]\n",
83 | "outs = [X]\n",
84 | "for i, cv in enumerate(zip(conv1d1_3dN, pad1dN)):\n",
85 | " cv[0].weight.data.fill_(1.0 / ksize)\n",
86 | " model = nn.Sequential(cv[1], cv[0])\n",
87 | " #model = cv[0]\n",
88 | " out = model(outs[i])\n",
89 | " outs.append(out.long().float())\n",
90 | "outs"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 3,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": []
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "collapsed": true
105 | },
106 | "outputs": [],
107 | "source": []
108 | }
109 | ],
110 | "metadata": {
111 | "kernelspec": {
112 | "display_name": "Python [conda env:pytorch-dev]",
113 | "language": "python",
114 | "name": "conda-env-pytorch-dev-py"
115 | },
116 | "language_info": {
117 | "codemirror_mode": {
118 | "name": "ipython",
119 | "version": 3
120 | },
121 | "file_extension": ".py",
122 | "mimetype": "text/x-python",
123 | "name": "python",
124 | "nbconvert_exporter": "python",
125 | "pygments_lexer": "ipython3",
126 | "version": "3.6.2"
127 | }
128 | },
129 | "nbformat": 4,
130 | "nbformat_minor": 2
131 | }
132 |
--------------------------------------------------------------------------------
/coursera-algos-twosum.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import scipy.stats\n",
13 | "import scipy\n",
14 | "import bisect\n",
15 | "\n"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "metadata": {
22 | "collapsed": false
23 | },
24 | "outputs": [],
25 | "source": [
26 | "#raw1 = np.loadtxt(\"2sum.txt\", dtype=int)\n",
27 | "#np.savez_compressed(\"2sum.txt.npz\", data=raw1)\n",
28 | "with np.load('2sum.txt.npz') as data:\n",
29 | " raw1 = data['data']\n",
30 | "raw1.shape\n",
31 | "U = raw1"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 3,
37 | "metadata": {
38 | "collapsed": false
39 | },
40 | "outputs": [
41 | {
42 | "data": {
43 | "text/plain": [
44 | "-1.1990247437581012"
45 | ]
46 | },
47 | "execution_count": 3,
48 | "metadata": {},
49 | "output_type": "execute_result"
50 | }
51 | ],
52 | "source": [
53 | "scipy.stats.kurtosis(U)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 4,
59 | "metadata": {
60 | "collapsed": false
61 | },
62 | "outputs": [
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "Mean: 1322010.15214\n",
68 | "Standard Deviation: 57723915587.6\n"
69 | ]
70 | }
71 | ],
72 | "source": [
73 | "U_mean = scipy.mean(U)\n",
74 | "U_std = scipy.std(U)\n",
75 | "\n",
76 | "print \"Mean:\", U_mean\n",
77 | "print \"Standard Deviation:\", U_std"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 5,
83 | "metadata": {
84 | "collapsed": false
85 | },
86 | "outputs": [],
87 | "source": [
88 | "U.sort()"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 80,
94 | "metadata": {
95 | "collapsed": false
96 | },
97 | "outputs": [
98 | {
99 | "name": "stdout",
100 | "output_type": "stream",
101 | "text": [
102 | "set([8195, 5308, 3445])\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "# this function takes advantage of the fact that the numbers are uniformly distributed;\n",
108 | "# thus the answer should simply be on the \"opposite\" side of the distribution.\n",
109 | "# the function searches the other side of the distribution for a range of possible\n",
110 | "# y indices for a given x index\n",
111 | "def find_yrange(U, i, i_mag, mag_fac = 100, mag_zoom = 20):\n",
112 | " # U: universe of numbers\n",
113 | " # i: x index\n",
114 | " # i_mag: y index search median\n",
115 | " # mag_fac: number of indices between range checks\n",
116 | " global t_bounds\n",
117 | " x_i = U[i]\n",
118 | " y_indices = [mag_fac*k + i_mag for k in range(-10, 11) if -1*(mag_fac*k + i_mag) > 0 and -1*(mag_fac*k + i_mag) < len(U)]\n",
119 | " y = U[y_indices]\n",
120 | " t = x_i + y\n",
121 | " t_min = t[0]\n",
122 | " t_max = t[-1]\n",
123 | " # check if the max or min is outside of the bounds\n",
124 | " if t_max < t_bounds[0] or t_min > t_bounds[1]:\n",
125 | " return None\n",
126 | " l_bound = y_indices[bisect.bisect_left(t, -t_bounds[0])-1]\n",
127 | " r_bound = y_indices[bisect.bisect_right(t, t_bounds[1])]\n",
128 | " m_bound = np.mean([l_bound, r_bound], dtype=int)\n",
129 | " mag_fac /= mag_zoom\n",
130 | " if mag_fac >= 5:\n",
131 | " l_bound, r_bound = find_yrange(U, i, m_bound, mag_fac)\n",
132 | " return l_bound, r_bound\n",
133 | "\n",
134 | "t_bounds = (-10000,10000)\n",
135 | "t_set = set()\n",
136 | "for i in range(1000,1010):\n",
137 | " bounds = find_yrange(U,i, -1*(i+1))\n",
138 | " if bounds:\n",
139 | " t_i = [U[i] + U[i_inv] for i_inv in range(bounds[0], bounds[1]+1) if U[i] + U[i_inv] > t_bounds[0] and U[i] + U[i_inv] < t_bounds[1]]\n",
140 | " if t_i: t_set.update(t_i)\n",
141 | "print t_set\n",
142 | "# set([8195, 5308, 3445])"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 70,
148 | "metadata": {
149 | "collapsed": false
150 | },
151 | "outputs": [
152 | {
153 | "name": "stdout",
154 | "output_type": "stream",
155 | "text": [
156 | "set([2048, 4097, 6146, 8195, 6145, 8194, 7171, -7637, -2795, 1024, -6613, -2794, 6890, -8104, -8103, -4006, -4005, 92, 2141, 6238, 8287, 8288, -747, 9125, 6891, -5588, -746, 3073, 4096, -9873, -8848, -4751, -2702, -653, -8011, -5962, -5961, -1864, 185, 2234, 2235, 6332, 8381, 2421, 4470, 9591, -3354, 5122, -1492, 2327, -9967, -7918, -7917, -5868, -3819, -1770, 279, 4376, 6425, 8474, 8475, 2328, 3351, -4286, -9127, -9500, -4285, 557, -6427, -4378, -1305, -9874, -7825, -5776, -3727, -3726, 371, 2420, 4469, 6518, 6519, 8568, 558, 3818, 7915, 7916, -2237, -6055, -9780, -7731, -5682, -1585, 464, 2513, 2514, 4563, 8660, 8661, 2607, 7449, -5030, -1211, -188, 3631, 7450, -7080, -5031, -6054, -1957, 1116, -7638, -5589, -3540, -1491, 2606, 4655, 4656, 8753, 8754, -8382, 3166, 6239, 9312, 838, 1395, 5681, 6704, -9594, -7545, -5496, -5495, -1398, 651, 2700, 4749, 8846, -5775, -1956, -6147, -933, 7728, 2887, -8569, -6707, -6706, -3633, -1584, -9501, -7452, -5403, -5402, 1489, 2792, 2793, 4842, 8939, 8940, 4562, -7824, 6611, 9777, 6612, -7544, 1117, 4936, -3541, -1678, 5960, 6983, -9408, -9407, -7358, -5309, -1212, 837, 2886, 4935, 9032, 9033, -1677, 2142, 4189, -7359, 1023, 4190, -5310, -3261, -3260, -9315, -9314, 1861, -5216, -3167, -1118, 931, 2980, 1862, 9126, -3447, 372, 5959, 6984, 8009, -7265, -3446, 1396, 5215, -9222, -7173, -5124, -3075, -3074, -1025, 3072, 5121, 7170, 9219, 9220, 3444, -5217, -1397, -6240, -6986, 650, -3913, -3912, 1209, -9128, -7079, -2982, -2981, -932, 3165, 5214, 7263, 7264, 9313, 5307, 5493, 5308, 7357, -2144, -6985, -2143, 1676, 2699, -9035, -9034, -4937, -2888, -839, 3258, 3259, 7356, 9405, 9406, -4938, 7542, -9779, -9687, -8662, -94, 8567, -4564, -2515, -8942, -8941, -4844, -466, 1302, 1303, 5400, 5401, 9498, 9499, -2889, 930, 3632, -8010, 6705, -7730, -5683, 1955, 5774, 6797, -8849, -6800, -6799, -4750, -2701, -652, 3445, 5494, 7543, 9592, -1863, -840, -7266, -4657, -3634, -6241, 7823, -3168, -1119, -8756, -8755, -4658, -2609, -560, 
-559, 3538, 2979, 9684, 9685, 186, 6052, 5028, 8101, -4191, 7078, -7451, -6428, 1210, 5029, -2608, -8663, -6614, -4565, -2516, -467, 1582, 1583, 5680, 7729, 9778, 7077, -9966, -6893, 4283, -3820, -1771, 278, -8570, -6521, -6520, -2423, -374, -373, 3724, 5773, 7822, 9871, 9872, 3352, 465, 4284, 4377, 6426, 93, -7172, -3353, -2329, 1490, -8477, -8476, -4379, -2330, -281, 1768, 1769, 5866, 5867, 9964, 9965, -1304, 8380, -9593, -280, 3539, -4472, -4471, -2422, -8383, -6334, 1675, -2236, -187, 3910, 3911, 8008, 4748, 744, 3725, 5586, 6798, 8847, 745, 5587, -6892, -5869, -2050, -8290, -8289, -4192, -95, 1954, 4003, 4004, 6053, 8102, -2049, -1026, 7635, -9686, 3817, 7636, -9221, -4843, -8196, 4841, -5123, -4098, -8197, -6148, -4099, -1, 2047])\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "t_bounds = (-10000,10000)\n",
162 | "t_set = set()\n",
163 | "for i in range(len(U)):\n",
164 | " bounds = find_yrange(U,i, -1*(i+1))\n",
165 | " if bounds:\n",
166 | " t_i = [U[i] + U[i_inv] for i_inv in range(bounds[0], bounds[1]+1) if U[i] + U[i_inv] > t_bounds[0] and U[i] + U[i_inv] < t_bounds[1]]\n",
167 | " if t_i: t_set.update(t_i)\n",
168 | "print t_set\n"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 66,
174 | "metadata": {
175 | "collapsed": false
176 | },
177 | "outputs": [
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "427\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "print len(t_set)"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "\n",
195 | "The goal of this problem is to implement the \"Median Maintenance\" algorithm (covered in the Week 5 lecture on heap applications). The text file contains a list of the integers from 1 to 10000 in unsorted order; you should treat this as a stream of numbers, arriving one by one. Letting $x_i$ denote the $i$th number of the file, the $k$th median $m_k$ is defined as the median of the numbers $x_1, \\ldots, x_k$. (So, if $k$ is odd, then $m_k$ is the $((k+1)/2)$th smallest number among $x_1, \\ldots, x_k$; if $k$ is even, then $m_k$ is the $(k/2)$th smallest number among $x_1, \\ldots, x_k$.)\n",
196 | "\n",
197 | "In the box below you should type the sum of these 10000 medians, modulo 10000 (i.e., only the last 4 digits). That is, you should compute $(m_1 + m_2 + m_3 + \\cdots + m_{10000}) \\bmod 10000$.\n",
198 | "\n",
199 | "OPTIONAL EXERCISE: Compare the performance achieved by heap-based and search-tree-based implementations of the algorithm.\n"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {
206 | "collapsed": true
207 | },
208 | "outputs": [],
209 | "source": [
210 | "\n",
211 | "\n"
212 | ]
213 | }
214 | ],
215 | "metadata": {
216 | "kernelspec": {
217 | "display_name": "Python 2 (SageMath)",
218 | "language": "python",
219 | "name": "python2"
220 | },
221 | "language_info": {
222 | "codemirror_mode": {
223 | "name": "ipython",
224 | "version": 2
225 | },
226 | "file_extension": ".py",
227 | "mimetype": "text/x-python",
228 | "name": "python",
229 | "nbconvert_exporter": "python",
230 | "pygments_lexer": "ipython2",
231 | "version": "2.7.10"
232 | }
233 | },
234 | "nbformat": 4,
235 | "nbformat_minor": 0
236 | }
237 |
--------------------------------------------------------------------------------
/data/coursera_algos/algos-2sums.txt.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhpollack/programming_notebooks/168936ac09071c00a61d56b35467d7d27ec61e60/data/coursera_algos/algos-2sums.txt.zip
--------------------------------------------------------------------------------
/data/coursera_algos/kargerMinCut.txt:
--------------------------------------------------------------------------------
1 | 193 26 112 191 62 195 25 91 39 71 158 131 46 5 38 37 185 18 118 124 56 156 58 86 53 108 68 175 87 120 15 42 73 99 4 75 151 64 35 48
2 | 1 37 79 164 155 32 87 39 113 15 18 78 175 140 200 4 160 97 191 100 91 20 69 198 196
3 | 2 123 134 10 141 13 12 43 47 3 177 101 179 77 182 117 116 36 103 51 154 162 128 30
4 | 3 48 123 134 109 41 17 159 49 136 16 130 141 29 176 2 190 66 153 157 70 114 65 173 104 194 54
5 | 4 91 171 118 125 158 76 107 18 73 140 42 193 127 100 84 121 60 81 99 80 150 55 1 35 23 93
6 | 5 193 156 102 118 175 39 124 119 19 99 160 75 20 112 37 23 145 135 146 73 35
7 | 6 155 56 52 120 131 160 124 119 14 196 144 25 75 76 166 35 87 26 20 32 23
8 | 7 156 185 178 79 27 52 144 107 78 22 71 26 31 15 56 76 112 39 8 113 93
9 | 8 185 155 171 178 108 64 164 53 140 25 100 133 9 52 191 46 20 150 144 39 62 131 42 119 127 31 7
10 | 9 91 155 8 160 107 132 195 26 20 133 39 76 100 78 122 127 38 156 191 196 115
11 | 10 190 184 154 49 2 182 173 170 161 47 189 101 153 50 30 109 177 148 179 16 163 116 13 90 185
12 | 11 123 134 163 41 12 28 130 13 101 83 77 109 114 21 82 88 74 24 94 48 33
13 | 12 161 109 169 21 24 36 65 50 2 101 159 148 54 192 88 47 11 142 43 70 182 177 179 189 194 33
14 | 13 161 141 157 44 83 90 181 41 2 176 10 29 116 134 182 170 165 173 190 159 47 82 111 142 72 154 110 21 103 130 11 33 138 152
15 | 14 91 156 58 122 62 113 107 73 137 25 19 40 6 139 150 46 37 76 39 127
16 | 15 149 58 68 52 39 67 121 191 1 45 100 18 118 174 40 85 196 122 42 193 119 139 26 127 145 135 57 38 7
17 | 16 48 10 36 187 43 3 114 173 111 142 129 88 189 117 128 147 141 194 180 106 167 179 66 74 136 51 59
18 | 17 48 123 134 36 163 3 44 117 167 161 152 95 170 83 180 77 65 72 109 47 43 88 159 197 28 194 181 49
19 | 18 193 149 56 62 15 160 67 191 140 52 178 96 107 132 1 145 89 198 4 26 73 151 126 34 115
20 | 19 156 80 178 164 108 84 71 174 40 62 113 22 89 45 91 126 195 144 5 14 172
21 | 20 185 122 171 56 8 52 73 191 67 126 9 119 1 89 79 107 96 31 75 55 5 6 34 23
22 | 21 188 187 12 173 180 197 138 167 63 111 95 13 192 116 94 114 105 49 177 51 130 90 11 50 66 157 176
23 | 22 156 27 32 131 7 56 53 81 149 23 100 146 115 26 175 121 96 75 57 39 119 71 132 19 150 140 93
24 | 23 91 122 124 22 200 195 145 5 69 125 55 68 156 20 58 191 4 57 149 6
25 | 24 123 134 161 163 169 72 116 167 30 33 77 162 143 159 187 63 184 130 28 50 153 12 148 11 53
26 | 25 193 185 79 108 8 158 87 73 81 115 39 64 178 132 27 68 127 84 14 52 200 97 6 93
27 | 26 193 58 27 108 52 144 160 18 84 81 22 75 139 166 15 107 198 131 7 9 133 6
28 | 27 156 139 144 166 112 100 26 174 31 42 75 158 122 81 22 7 58 73 89 115 39 25 200 69 169
29 | 28 134 188 24 184 159 29 72 114 152 116 169 173 141 17 111 61 192 90 11 177 179 77 33 66 83 136
30 | 29 48 134 188 13 47 88 3 82 92 28 194 50 192 189 123 199 177 147 43 106 148 197 77 103 129 181
31 | 30 165 123 10 24 41 187 47 168 92 148 197 101 50 2 179 111 130 77 153 199 70
32 | 31 27 171 56 131 146 139 191 89 20 108 38 71 75 69 196 149 97 8 86 98 7
33 | 32 156 149 171 62 22 185 35 124 56 38 158 97 53 121 160 1 191 58 89 127 87 120 39 99 84 60 151 174 6
34 | 33 48 161 109 141 24 187 47 88 168 183 110 103 95 116 28 12 11 13 83 134 63
35 | 34 37 122 171 118 76 131 166 137 40 46 97 87 80 164 127 18 62 52 20 139
36 | 35 79 164 125 32 107 137 75 121 85 55 69 45 193 132 4 5 200 135 76 139 198 6
37 | 36 165 188 17 106 88 16 177 110 147 154 159 179 136 41 50 141 66 162 152 168 184 12 43 72 180 190 77 2 170 61 122
38 | 37 193 149 39 121 191 115 146 52 127 79 198 58 125 38 34 1 76 89 164 97 86 178 108 87 84 124 98 174 195 14 5 57 196 186
39 | 38 193 37 86 32 76 107 73 85 127 100 46 89 31 57 96 158 99 160 45 15 9
40 | 39 193 37 122 102 8 158 32 87 85 81 200 60 5 27 155 1 58 150 15 113 76 84 22 25 151 139 100 14 145 9 7
41 | 40 91 156 122 79 118 125 52 175 87 15 81 166 132 121 19 14 160 34 78 71
42 | 41 36 169 184 116 163 106 189 11 104 61 30 123 129 111 3 47 49 154 161 152 13 153 65 92 183 177 162 95 54 70 108
43 | 42 178 79 27 53 171 164 102 52 87 113 15 191 131 91 62 193 8 122 89 56 4 127 145 112
44 | 43 165 161 12 70 199 54 17 190 16 153 141 36 47 44 194 110 82 189 2 148 183 29 130 94 170 51 61 59
45 | 44 188 163 169 17 13 43 114 173 142 154 103 129 181 105 157 148 182 101 110 66 176 49
46 | 45 156 80 149 58 178 53 108 68 56 125 15 93 75 135 174 198 81 166 113 100 19 89 35 97 38
47 | 46 193 58 86 122 155 8 175 160 99 127 67 14 150 144 126 146 34 131 55 38 196
48 | 47 123 10 109 41 17 12 43 116 59 33 13 2 187 165 88 117 29 30 176 147 180 101 130 194 50 94 152 70
49 | 48 180 128 188 197 105 51 94 190 116 29 183 114 153 33 16 49 3 63 184 17 141 168 179 162 11 66 83 193
50 | 49 48 123 10 186 141 41 168 3 148 142 179 21 136 109 44 117 17 103 187 74
51 | 50 165 123 10 188 36 169 24 187 12 65 29 167 47 21 134 130 111 168 77 116 138 106 66 30
52 | 51 165 48 109 141 163 159 92 190 143 2 21 138 43 59 192 117 16 184 104 169
53 | 52 37 178 8 6 15 200 133 80 102 96 40 119 164 166 127 151 20 42 7 26 76 18 73 99 78 25 132 139 191 150 34
54 | 53 156 122 102 193 137 133 42 62 45 64 60 78 160 132 155 56 144 131 196 178 125 8 32 113 22 98 121 198 24
55 | 54 123 187 12 43 168 65 129 130 147 95 41 61 59 141 3 138 114 199 66 110
56 | 55 68 56 125 113 73 78 200 46 131 85 107 89 185 60 84 4 35 99 171 71 20 23
57 | 56 193 122 53 108 45 55 18 125 86 171 79 6 85 133 80 140 96 31 107 20 32 87 120 42 73 84 22 112 196 7
58 | 57 79 155 158 76 84 22 75 121 85 191 100 97 15 37 102 69 38 64 195 145 23
59 | 58 15 146 107 193 140 84 144 86 14 150 26 60 46 166 80 87 139 195 126 45 37 27 102 62 32 175 39 124 23 93 188
60 | 59 186 163 187 47 168 92 143 147 157 54 51 61 199 43 103 63 184 16 188 197
61 | 60 58 86 178 53 171 125 39 144 146 73 124 4 93 32 99 158 132 80 91 96 75 174 55 196
62 | 61 188 116 41 114 183 28 77 54 36 163 187 147 179 65 63 190 43 186 59 189
63 | 62 193 185 79 53 171 102 146 84 198 14 58 137 8 166 175 32 113 18 115 196 42 200 132 121 19 112 133 172 34
64 | 63 48 24 116 183 111 167 21 162 90 181 147 105 106 101 33 123 153 184 128 168 61 59
65 | 64 91 149 122 53 8 120 113 124 119 25 132 131 193 121 144 68 185 108 175 107 57
66 | 65 186 17 12 114 82 3 66 159 167 197 50 101 176 94 54 134 41 165 157 194 180 77 103 74 61
67 | 66 188 36 116 3 65 183 72 180 77 16 136 177 82 159 28 95 48 50 187 21 44 54 176
68 | 67 91 156 102 68 175 131 144 15 18 146 119 20 200 46 112 139 75 80 86 137
69 | 68 122 171 55 78 175 96 75 71 93 118 195 115 67 79 45 158 15 120 155 193 87 146 81 25 64 23
70 | 69 91 108 125 131 81 75 85 174 145 89 27 200 35 156 1 98 86 23 191 127 31 57
71 | 70 165 123 163 153 12 43 168 3 114 82 148 190 129 74 176 47 110 181 41 30
72 | 71 193 79 171 68 124 98 120 73 75 93 151 108 89 155 19 96 22 119 7 125 85 127 40 55 140 31
73 | 72 109 169 24 116 17 114 182 190 141 186 192 28 104 13 66 165 162 36 159 77 110 129 130 181
74 | 73 91 80 27 160 4 20 100 56 193 60 38 18 25 172 76 14 55 52 149 96 78 71 195 127 5 112 97 93
75 | 74 88 104 197 111 130 95 11 138 123 77 159 65 179 94 165 70 141 16 153 136 83 49
76 | 75 91 79 27 171 68 158 87 81 22 119 71 166 57 60 35 160 69 175 26 193 45 127 67 126 89 20 5 31 198 6
77 | 76 37 178 175 120 122 57 52 34 4 156 98 115 133 131 171 149 85 38 174 39 73 99 14 140 35 135 172 9 7 6
78 | 77 24 116 17 72 190 110 36 173 143 94 65 29 61 154 2 161 90 66 159 28 128 117 11 50 138 74 30
79 | 78 156 80 53 164 68 131 144 107 73 195 55 175 1 126 7 149 171 81 91 52 124 40 9
80 | 79 80 37 158 40 75 108 198 25 62 42 7 86 57 102 122 145 175 71 1 35 164 68 118 56 144 200 139 20 112 135 172 163
81 | 80 91 144 100 118 45 78 139 112 149 98 113 87 195 124 85 73 119 19 79 99 58 56 52 146 81 4 60 137 67 126 145 97 34 134
82 | 81 86 27 120 39 131 40 132 25 69 96 26 127 108 75 68 135 98 80 115 149 78 22 4 45 172
83 | 82 161 186 109 153 43 159 114 101 104 70 29 197 192 116 148 65 152 184 13 154 92 110 130 11 147 66 176
84 | 83 17 13 153 88 111 90 117 95 11 33 147 28 109 199 48 74 167 182 154 101
85 | 84 149 58 86 178 171 62 175 113 26 135 96 198 39 19 57 144 37 118 32 56 119 4 98 25 200 150 55
86 | 85 156 80 185 155 102 56 158 39 76 15 69 38 150 175 118 137 71 35 120 57 55 135
87 | 86 58 84 120 81 112 37 60 46 185 193 38 125 118 175 149 99 108 126 133 102 79 56 144 67 69 31 109
88 | 87 80 149 58 118 25 75 126 115 68 178 37 200 32 40 164 56 42 193 107 1 39 113 145 89 172 6 34 93
89 | 88 36 116 12 47 165 168 29 83 159 154 74 148 92 176 180 33 110 194 167 17 114 173 16 11
90 | 89 37 27 102 32 42 18 99 71 151 19 55 75 98 131 69 45 31 38 195 87 20 135 115
91 | 90 163 116 13 173 28 110 190 77 129 117 130 10 179 92 134 194 21 63 83 157 94 152 199
92 | 91 126 14 75 135 107 80 102 67 160 158 144 23 64 155 4 198 40 9 73 69 193 185 149 164 42 78 60 98 19 1 165
93 | 92 161 109 163 187 88 186 141 51 117 197 59 173 192 82 29 143 116 41 30 111 148 129 90 194 152 170
94 | 93 68 99 71 60 45 150 145 98 25 87 122 178 185 7 144 4 22 73 58 112
95 | 94 48 163 65 173 182 154 77 21 117 11 147 74 189 43 114 47 177 192 104 90
96 | 95 123 141 169 17 168 114 179 21 33 165 167 177 41 103 83 199 129 182 188 74 66 157 152 54 176
97 | 96 171 164 68 56 52 18 73 84 81 124 22 71 60 126 150 20 97 38 149 158 196 115
98 | 97 185 149 37 178 122 108 118 32 131 1 31 140 73 34 45 25 80 133 57 96
99 | 98 80 171 164 158 120 76 99 81 71 37 91 151 144 172 140 150 160 121 84 53 139 89 69 31 133 93
100 | 99 80 86 178 122 118 125 52 193 32 174 46 113 98 93 89 150 102 76 100 124 4 60 55 5 38 196
101 | 100 80 185 149 27 8 113 131 15 73 99 22 4 200 45 102 172 57 164 38 39 1 112 9
102 | 101 165 10 188 109 141 163 187 12 168 82 65 148 180 47 167 2 162 169 192 30 179 11 44 83 170 63
103 | 102 91 86 79 53 171 155 67 113 39 137 158 196 166 120 42 89 58 5 119 85 62 52 146 99 121 100 57
104 | 103 161 188 186 184 153 159 167 189 2 168 13 29 104 142 33 65 197 117 148 44 95 181 59 49
105 | 104 134 188 186 141 184 41 187 168 114 82 148 192 194 154 74 3 190 161 105 163 72 103 94 170 51
106 | 105 165 48 161 116 159 104 21 129 63 44 169 106 142 190 181 143 179 157 173 188 199 176
107 | 106 36 163 169 184 41 159 29 197 111 16 179 162 105 138 148 50 165 63 173 109
108 | 107 91 58 56 125 158 87 191 133 9 38 122 164 175 135 4 35 14 7 185 78 18 146 151 26 64 126 55 20 112 198 172
109 | 108 86 178 79 164 37 25 195 144 120 198 26 56 19 132 69 193 115 45 172 97 155 8 81 71 166 139 174 64 31 41
110 | 109 10 47 116 12 128 51 167 180 179 33 123 101 142 92 143 199 154 3 72 82 141 17 168 114 173 183 148 189 11 181 106 83 152 49 86
111 | 110 165 186 36 163 169 43 88 168 33 138 147 13 167 189 184 134 72 90 188 82 77 44 54 70
112 | 111 41 92 114 162 188 129 187 170 152 83 142 74 157 16 13 138 186 106 63 184 28 21 50 30 192
113 | 112 193 80 86 27 137 174 67 5 79 198 132 127 42 131 100 73 107 56 135 62 7 93
114 | 113 156 80 149 171 102 62 39 133 84 87 198 1 53 64 55 172 122 100 42 14 160 99 124 137 45 19 145 7
115 | 114 48 104 136 16 61 72 181 128 82 88 138 109 129 70 192 3 153 65 95 44 111 28 21 11 94 54
116 | 115 185 37 27 108 68 62 87 120 76 81 22 25 151 139 191 140 9 89 96 18
117 | 116 48 188 109 24 157 61 41 181 128 66 88 63 152 189 168 90 105 72 77 10 13 47 82 173 92 142 28 180 2 21 117 50 33 136 164
118 | 117 17 47 92 16 179 2 103 90 194 189 192 94 143 170 162 83 116 129 184 77 138 152 51 49
119 | 118 193 156 80 86 155 68 196 145 40 4 97 79 5 99 87 149 174 34 137 166 131 15 160 84 121 85
120 | 119 80 178 102 52 160 124 84 185 6 22 67 174 8 121 133 5 75 64 15 150 71 151 145 20
121 | 120 185 86 171 108 102 68 200 193 126 64 160 32 150 6 115 98 81 56 191 76 124 71 139 85 195 135
122 | 121 156 185 37 32 15 146 22 119 4 98 151 57 62 126 53 139 40 102 35 118 64 135 133
123 | 122 156 146 23 99 196 46 14 56 53 175 164 20 68 39 195 126 97 40 34 64 79 27 155 113 76 131 15 42 107 166 151 9 93 36
124 | 123 165 194 50 70 95 11 153 30 24 17 49 47 54 2 189 147 199 3 188 197 186 109 41 29 129 130 128 74 136 63 156
125 | 124 193 80 125 32 146 99 23 5 60 113 120 191 6 58 64 119 71 37 78 96 137
126 | 125 37 86 56 4 133 69 178 132 53 172 45 60 151 40 198 55 35 107 99 124 71 137 127 23
127 | 126 91 58 86 122 155 87 120 146 78 200 121 19 80 20 137 75 96 107 18 156 46 135
128 | 127 37 178 32 52 131 160 81 4 25 191 125 38 71 15 8 42 75 46 174 73 14 69 112 172 9 34
129 | 128 48 134 161 109 116 187 114 183 16 179 189 130 181 142 123 138 2 136 77 152 63
130 | 129 165 184 41 153 114 148 111 177 16 180 44 29 123 54 173 72 92 70 199 105 90 117 95 157
131 | 130 186 169 24 3 154 189 21 136 72 82 74 190 11 47 13 173 54 181 43 123 90 128 50 138 30 176
132 | 131 193 53 76 78 22 69 97 118 178 122 140 81 34 164 8 100 127 67 31 6 42 146 200 195 26 64 46 55 89 112
133 | 132 185 149 53 171 108 125 18 81 60 25 139 191 178 64 52 9 62 22 40 164 35 112
134 | 133 86 53 56 125 8 52 113 76 107 119 195 97 156 98 121 62 140 26 175 9
135 | 134 153 11 17 182 165 197 138 173 24 28 104 128 29 2 167 152 183 3 162 154 186 13 65 142 110 90 50 33 157 80
136 | 135 91 171 164 107 84 81 45 5 35 112 120 15 151 85 76 79 178 121 126 89
137 | 136 165 36 187 3 114 173 190 130 128 66 74 152 123 154 116 28 49 153 16 169
138 | 137 156 185 53 155 102 118 62 166 175 200 151 112 113 14 124 80 146 35 125 34 85 67 126 145 172
139 | 138 134 186 153 114 111 142 110 179 21 128 130 117 197 169 77 13 74 51 180 50 106 54
140 | 139 80 58 27 164 132 121 108 115 140 31 79 98 15 160 52 26 171 39 195 120 67 14 35 196 34
141 | 140 185 58 56 8 175 131 144 18 4 166 98 139 1 195 22 156 71 76 178 115 97 133
142 | 141 109 104 184 2 143 188 51 49 33 197 48 176 95 159 148 101 154 13 170 190 36 43 3 92 72 16 28 167 180 147 74 54 178
143 | 142 161 109 169 12 183 111 163 187 162 180 16 165 116 184 138 44 13 49 134 143 177 103 128 181 105 192
144 | 143 165 161 109 141 24 184 168 92 142 154 180 51 147 179 176 183 189 59 159 167 77 117 105 152
145 | 144 91 80 58 178 27 53 164 108 8 158 60 79 67 155 78 140 7 86 26 174 84 98 64 19 46 145 6 93
146 | 145 178 79 155 118 18 151 69 39 113 42 119 93 80 144 87 15 137 23 158 166 5 57 172
147 | 146 149 58 37 122 62 60 68 107 31 67 172 195 121 131 102 185 174 124 126 80 22 137 46 5
148 | 147 123 186 36 47 183 29 182 197 16 167 110 143 168 83 82 141 63 94 54 59 157 170 61
149 | 148 10 188 186 141 163 12 43 88 82 109 192 70 161 30 129 24 92 49 177 29 104 197 154 101 103 162 181 106 44 176
150 | 149 156 80 200 146 15 185 87 91 100 45 37 178 18 32 64 132 97 84 150 113 86 118 76 73 81 78 22 31 96 23 161
151 | 150 149 58 8 175 120 39 99 119 166 98 200 85 14 4 46 96 84 22 93 52
152 | 151 125 52 160 71 166 98 137 145 122 171 119 193 175 32 121 107 115 18 39 89 135 196
153 | 152 134 161 36 116 41 17 82 111 28 128 157 13 92 47 143 117 109 90 95 136 170
154 | 153 165 48 123 134 10 186 138 187 103 82 70 163 129 188 170 173 24 83 167 41 43 3 114 154 180 74 136 63 30
155 | 154 134 10 109 141 36 41 88 159 82 104 143 148 197 44 13 94 130 162 177 153 167 77 2 83 136 170
156 | 155 91 185 53 164 108 145 171 200 6 122 8 9 85 57 1 137 126 102 118 46 68 158 39 144 160 71 172
157 | 156 7 53 113 45 14 193 122 40 67 19 121 22 118 78 5 137 166 149 32 85 178 27 158 76 126 140 69 133 9 23 123
158 | 157 184 116 13 3 65 111 90 105 187 147 152 134 59 95 186 197 44 170 129 21
159 | 158 91 193 79 27 102 68 32 75 156 107 185 144 4 57 39 155 98 171 85 25 175 60 145 38 96
160 | 159 188 141 36 24 13 12 88 103 105 179 177 17 3 182 51 82 28 161 154 106 65 183 72 143 77 66 74
161 | 160 91 185 53 32 120 155 6 9 18 127 171 118 113 119 151 26 46 195 200 73 98 75 139 40 1 5 198 38
162 | 161 165 10 24 173 142 182 162 92 13 152 33 82 189 43 143 128 103 186 105 12 188 184 41 17 159 148 104 180 77 194 170 149
163 | 162 134 161 188 186 36 24 173 182 111 142 72 154 101 48 148 2 179 106 63 41 117
164 | 163 177 92 10 148 101 44 169 59 110 188 17 51 90 186 94 11 24 70 106 182 41 153 187 104 142 190 61 79
165 | 164 185 37 122 175 139 144 91 35 19 96 98 78 155 1 108 198 42 135 79 195 8 52 87 131 107 166 132 100 34 116
166 | 165 30 105 36 101 167 183 43 50 70 189 181 190 161 123 136 51 143 129 153 110 134 169 13 47 88 65 142 72 95 106 74 91
167 | 166 156 58 27 102 118 62 52 191 140 164 40 122 151 108 75 174 137 196 150 34 200 45 26 145 6
168 | 167 165 134 109 24 17 153 187 88 168 65 182 177 154 16 194 141 103 50 147 63 110 101 179 143 21 95 83
169 | 168 36 169 116 88 48 70 30 33 101 104 109 110 143 181 167 54 199 49 59 95 177 180 103 50 147 63
170 | 169 163 50 184 12 176 106 41 194 173 168 110 142 44 95 130 24 165 182 72 177 197 28 101 138 105 136 51 27
171 | 170 10 186 141 17 13 153 111 117 157 92 161 152 147 104 179 154 43 101 36 199
172 | 171 34 60 68 31 20 42 62 4 120 71 32 135 113 175 84 102 132 96 98 75 155 56 8 158 76 160 78 151 139 55 184
173 | 172 185 108 125 113 146 73 98 174 100 155 81 79 76 127 87 137 107 19 62 145
174 | 173 134 10 161 169 13 153 109 162 44 21 88 182 3 190 179 136 90 16 116 94 92 28 77 129 130 105 106
175 | 174 27 118 175 76 144 15 146 99 119 166 45 32 19 172 191 60 108 37 69 112 127
176 | 175 86 122 79 171 164 68 62 84 140 5 40 196 46 174 150 67 58 76 193 158 107 78 22 137 75 151 85 195 64 1 133
177 | 176 141 169 13 47 88 3 65 183 177 143 148 82 105 130 66 181 44 21 182 95 70
178 | 177 10 36 163 169 159 29 148 190 12 199 167 176 180 2 192 168 142 41 179 129 154 28 21 95 66 94
179 | 178 149 76 19 52 127 108 45 156 60 198 99 37 144 7 97 145 84 119 42 191 53 125 8 87 131 18 25 132 195 140 135 93 141
180 | 179 10 109 36 159 173 177 49 117 12 48 101 189 2 138 128 95 106 28 167 16 143 162 90 105 74 170 199 30 61
181 | 180 48 109 17 88 142 177 16 65 194 36 153 116 161 66 143 141 129 47 168 21 101 181 138
182 | 181 165 188 184 116 13 168 114 182 130 128 44 103 148 63 29 142 180 17 109 72 105 176 70
183 | 182 134 10 161 163 169 184 13 159 173 183 147 189 162 181 12 72 167 94 199 192 2 95 44 83 176
184 | 183 165 48 134 188 186 63 33 147 43 61 109 142 159 66 41 128 182 176 190 184 143 189
185 | 184 48 10 141 169 24 143 41 161 181 36 129 197 104 186 182 28 106 157 103 199 82 183 111 142 110 117 63 51 59 171
186 | 185 193 140 8 97 7 121 160 155 85 172 91 100 164 132 25 120 20 62 137 115 149 86 32 158 107 146 119 64 55 93 10
187 | 186 161 104 199 162 148 110 138 183 153 103 192 82 130 123 59 49 147 134 65 170 163 184 187 92 197 111 72 157 61 37
188 | 187 24 153 186 92 189 54 33 59 21 167 163 101 136 50 197 16 104 128 192 30 47 111 142 66 157 61 49
189 | 188 48 123 50 104 36 101 148 66 161 44 28 21 162 116 183 103 159 181 61 29 141 163 153 111 110 190 95 105 59 58
190 | 189 165 123 10 161 116 41 187 43 29 182 197 16 110 179 143 12 183 103 199 109 130 128 194 117 94 61
191 | 190 165 48 10 141 13 43 3 173 183 104 177 72 36 77 188 90 136 51 163 70 130 105 61
192 | 191 193 37 178 8 32 120 15 42 107 18 124 166 132 174 20 52 57 31 115 127 1 69 9 23
193 | 192 186 187 12 114 82 92 29 148 104 182 177 72 28 101 21 117 94 51 111 142
194 | 194 123 169 43 88 65 29 104 16 167 180 90 189 161 199 47 92 3 12 17 117
195 | 195 193 80 58 122 164 108 68 160 146 78 139 23 120 178 133 73 175 37 131 9 19 89 140 57
196 | 196 122 53 102 118 62 175 15 166 31 9 37 6 99 151 56 139 46 1 96 60
197 | 197 48 123 134 141 184 187 82 65 92 106 148 147 199 29 17 30 189 186 74 169 154 21 103 138 157 59
198 | 198 91 37 178 79 164 108 125 62 113 18 84 45 26 112 35 107 160 53 1 75
199 | 199 123 186 109 184 43 168 29 182 197 177 189 129 194 95 83 170 54 105 90 179 30 59
200 | 200 149 155 52 87 120 39 160 137 27 79 131 100 25 55 23 126 84 166 150 62 67 1 69 35
201 |
--------------------------------------------------------------------------------
/data/coursera_algos/knapsack1.txt:
--------------------------------------------------------------------------------
1 | 10000 100
2 | 16808 250
3 | 50074 659
4 | 8931 273
5 | 27545 879
6 | 77924 710
7 | 64441 166
8 | 84493 43
9 | 7988 504
10 | 82328 730
11 | 78841 613
12 | 44304 170
13 | 17710 158
14 | 29561 934
15 | 93100 279
16 | 51817 336
17 | 99098 827
18 | 13513 268
19 | 23811 634
20 | 80980 150
21 | 36580 822
22 | 11968 673
23 | 1394 337
24 | 25486 746
25 | 25229 92
26 | 40195 358
27 | 35002 154
28 | 16709 945
29 | 15669 491
30 | 88125 197
31 | 9531 904
32 | 27723 667
33 | 28550 25
34 | 97802 854
35 | 40978 409
36 | 8229 934
37 | 60299 982
38 | 28636 14
39 | 23866 815
40 | 39064 537
41 | 39426 670
42 | 24116 95
43 | 75630 502
44 | 46518 196
45 | 30106 405
46 | 19452 299
47 | 82189 124
48 | 99506 883
49 | 6753 567
50 | 36717 338
51 | 54439 145
52 | 51502 898
53 | 83872 829
54 | 11138 359
55 | 53178 398
56 | 22295 905
57 | 21610 232
58 | 59746 176
59 | 53636 299
60 | 98143 400
61 | 27969 413
62 | 261 558
63 | 41595 9
64 | 16396 969
65 | 19114 531
66 | 71007 963
67 | 97943 366
68 | 42083 853
69 | 30768 822
70 | 85696 713
71 | 73672 902
72 | 48591 832
73 | 14739 58
74 | 31617 791
75 | 55641 680
76 | 37336 7
77 | 97973 99
78 | 49096 320
79 | 83455 224
80 | 12290 761
81 | 48906 127
82 | 36124 507
83 | 45814 771
84 | 35239 95
85 | 96221 845
86 | 12367 535
87 | 25227 395
88 | 41364 739
89 | 7845 591
90 | 36551 160
91 | 8624 948
92 | 97386 218
93 | 95273 540
94 | 99248 386
95 | 13497 886
96 | 40624 421
97 | 28145 969
98 | 35736 916
99 | 61626 535
100 | 46043 12
101 | 54680 153
102 |
--------------------------------------------------------------------------------
/data/coursera_algos/knapsack_big.txt:
--------------------------------------------------------------------------------
1 | 2000000 2000
2 | 16808 241486
3 | 50074 834558
4 | 8931 738037
5 | 27545 212860
6 | 77924 494349
7 | 64441 815107
8 | 84493 723724
9 | 7988 421316
10 | 82328 652893
11 | 78841 402599
12 | 44304 631607
13 | 17710 318556
14 | 29561 608119
15 | 93100 429390
16 | 51817 431959
17 | 99098 365899
18 | 13513 90282
19 | 23811 467558
20 | 80980 743542
21 | 36580 896948
22 | 11968 883369
23 | 1394 604449
24 | 25486 562244
25 | 25229 333236
26 | 40195 157443
27 | 35002 696933
28 | 16709 539123
29 | 15669 202584
30 | 88125 759690
31 | 9531 69730
32 | 27723 110467
33 | 28550 651058
34 | 97802 231944
35 | 40978 186803
36 | 8229 572887
37 | 60299 797491
38 | 28636 692529
39 | 23866 503157
40 | 39064 243688
41 | 39426 182032
42 | 24116 262772
43 | 75630 246257
44 | 46518 605917
45 | 30106 685556
46 | 19452 522241
47 | 82189 419114
48 | 99506 622799
49 | 6753 512332
50 | 36717 489578
51 | 54439 869423
52 | 51502 771067
53 | 83872 884103
54 | 11138 450309
55 | 53178 444804
56 | 22295 405168
57 | 21610 527012
58 | 59746 255432
59 | 53636 702071
60 | 98143 264607
61 | 27969 227173
62 | 261 471962
63 | 41595 807033
64 | 16396 471595
65 | 19114 520773
66 | 71007 21653
67 | 97943 882635
68 | 42083 801161
69 | 30768 806299
70 | 85696 402599
71 | 73672 858413
72 | 48591 122945
73 | 14739 256166
74 | 31617 350118
75 | 55641 783545
76 | 37336 518571
77 | 97973 17616
78 | 49096 640048
79 | 83455 74134
80 | 12290 610321
81 | 48906 141662
82 | 36124 589402
83 | 45814 638947
84 | 35239 751616
85 | 96221 533618
86 | 12367 850339
87 | 25227 670142
88 | 41364 699869
89 | 7845 269378
90 | 36551 198914
91 | 8624 98356
92 | 97386 122211
93 | 95273 646654
94 | 99248 231210
95 | 13497 806666
96 | 40624 153773
97 | 28145 427188
98 | 35736 90282
99 | 61626 23855
100 | 46043 585365
101 | 54680 654728
102 | 75245 593439
103 | 78819 117073
104 | 87693 903554
105 | 12992 357458
106 | 22670 469393
107 | 89554 642617
108 | 75826 568850
109 | 29663 532884
110 | 33809 531783
111 | 22762 171022
112 | 37336 833090
113 | 77996 467191
114 | 62768 784279
115 | 29875 716017
116 | 40557 519305
117 | 68873 549766
118 | 20095 435262
119 | 13756 171022
120 | 51408 709411
121 | 30194 111935
122 | 15681 116706
123 | 16856 669408
124 | 70964 512332
125 | 86677 892544
126 | 63250 836026
127 | 77845 418747
128 | 60809 531049
129 | 28652 532884
130 | 45204 478201
131 | 96532 514167
132 | 95420 866487
133 | 42010 780976
134 | 87930 258368
135 | 44055 202217
136 | 76738 877864
137 | 47318 901352
138 | 4201 634910
139 | 21565 69730
140 | 6649 16515
141 | 25551 203685
142 | 41977 835292
143 | 38555 736202
144 | 2505 215429
145 | 34749 905022
146 | 59379 171389
147 | 78210 744276
148 | 78554 485174
149 | 50448 394158
150 | 80774 697300
151 | 64306 324061
152 | 70054 724825
153 | 84631 377643
154 | 5401 191941
155 | 95371 695098
156 | 83017 107164
157 | 8156 807033
158 | 15558 218365
159 | 53179 670142
160 | 87358 219466
161 | 27879 757121
162 | 74820 194143
163 | 83134 660600
164 | 23721 531783
165 | 6634 32663
166 | 76032 528480
167 | 60590 884470
168 | 98878 569584
169 | 79359 553436
170 | 77255 80373
171 | 14916 269378
172 | 45481 776205
173 | 69956 608853
174 | 37108 143497
175 | 40001 213961
176 | 45036 344613
177 | 61114 472696
178 | 29594 539123
179 | 98355 675647
180 | 25358 809235
181 | 13730 649223
182 | 76564 218365
183 | 60918 373606
184 | 19420 63491
185 | 47748 105696
186 | 55396 600045
187 | 4474 913096
188 | 11749 417279
189 | 5659 576190
190 | 45407 777306
191 | 39825 916399
192 | 974 70097
193 | 46898 715650
194 | 3951 50646
195 | 88481 562978
196 | 13901 764828
197 | 62534 772535
198 | 17004 569217
199 | 20951 344246
200 | 58781 576557
201 | 41833 484073
202 | 33550 262405
203 | 6250 548665
204 | 66311 846302
205 | 984 237082
206 | 92041 86245
207 | 73651 451043
208 | 53230 873827
209 | 17747 340209
210 | 87231 470861
211 | 70777 354889
212 | 68245 560409
213 | 47697 456181
214 | 29065 364064
215 | 56599 853642
216 | 69941 86245
217 | 89552 891443
218 | 27206 238183
219 | 80419 708677
220 | 59650 713448
221 | 83321 593439
222 | 22013 800794
223 | 54789 811070
224 | 24009 64592
225 | 65325 643718
226 | 51258 547197
227 | 35691 770333
228 | 62411 360027
229 | 90554 384983
230 | 90961 242220
231 | 51404 456181
232 | 7308 263873
233 | 44911 44774
234 | 92800 210658
235 | 20172 702438
236 | 34415 522608
237 | 16642 883002
238 | 55349 229375
239 | 14316 229375
240 | 81446 587567
241 | 72221 702071
242 | 9240 681152
243 | 41306 439666
244 | 82254 569217
245 | 15167 44040
246 | 59305 492514
247 | 45899 274516
248 | 1721 647388
249 | 51581 173591
250 | 49353 713448
251 | 78699 803363
252 | 60812 626469
253 | 56892 885204
254 | 31729 70097
255 | 51662 300940
256 | 3121 193409
257 | 17159 448107
258 | 69628 266075
259 | 53712 737303
260 | 29595 767030
261 | 1728 778774
262 | 98796 623166
263 | 37443 173958
264 | 85007 28993
265 | 66925 438198
266 | 27098 28993
267 | 67104 736202
268 | 80109 49178
269 | 2282 51013
270 | 43711 456548
271 | 1906 245890
272 | 67782 447740
273 | 70792 346815
274 | 97866 424252
275 | 29537 643351
276 | 57126 605917
277 | 80168 124413
278 | 477 792353
279 | 21577 216530
280 | 77468 571419
281 | 27385 373973
282 | 25211 910160
283 | 28212 37434
284 | 19519 788316
285 | 21935 671977
286 | 80768 843366
287 | 49819 30461
288 | 11826 765562
289 | 58029 693263
290 | 78035 197813
291 | 51241 914564
292 | 86541 80006
293 | 90936 640415
294 | 48049 103127
295 | 87012 148268
296 | 53065 755286
297 | 17238 601146
298 | 70246 643351
299 | 70416 446272
300 | 24711 83676
301 | 63038 601146
302 | 68448 270846
303 | 41694 471962
304 | 67130 793454
305 | 33980 25323
306 | 95526 631974
307 | 74959 279287
308 | 11801 837127
309 | 87671 310482
310 | 80847 29727
311 | 9544 481504
312 | 85937 343145
313 | 97720 712347
314 | 91248 681519
315 | 28360 177628
316 | 45036 308647
317 | 83913 383148
318 | 45990 226072
319 | 63735 870891
320 | 520 608486
321 | 67046 170288
322 | 26662 853642
323 | 82820 815474
324 | 22904 485174
325 | 36799 171022
326 | 14005 75969
327 | 680 373239
328 | 7392 11377
329 | 64925 211025
330 | 57122 750148
331 | 77107 175426
332 | 51704 577658
333 | 71033 300206
334 | 61655 729596
335 | 42218 330300
336 | 57290 271580
337 | 84748 344246
338 | 89880 890342
339 | 1517 266809
340 | 82763 677482
341 | 49040 687391
342 | 5904 592338
343 | 10762 475632
344 | 61467 211025
345 | 44558 659866
346 | 26956 747579
347 | 21495 751983
348 | 92222 148635
349 | 2294 546096
350 | 42233 803730
351 | 34792 782811
352 | 49153 468292
353 | 73571 862083
354 | 27440 422784
355 | 35800 547931
356 | 91836 119642
357 | 46044 903187
358 | 74087 11744
359 | 39151 263873
360 | 18624 324428
361 | 9606 760057
362 | 37312 718953
363 | 77993 321125
364 | 38700 550133
365 | 80640 181665
366 | 25837 614358
367 | 82858 9542
368 | 83329 456548
369 | 74411 253964
370 | 85177 320391
371 | 72092 745377
372 | 84712 129551
373 | 91027 364798
374 | 10791 740973
375 | 59365 256900
376 | 24885 536554
377 | 41442 323694
378 | 16180 252863
379 | 94538 438198
380 | 87562 642250
381 | 74166 892911
382 | 57733 838228
383 | 47855 547197
384 | 9848 299105
385 | 40463 519672
386 | 43820 80006
387 | 13378 225338
388 | 43837 82942
389 | 32300 419114
390 | 34760 867588
391 | 74391 500955
392 | 30049 203685
393 | 59045 26424
394 | 40151 554904
395 | 41296 197079
396 | 95732 58720
397 | 28246 573254
398 | 46058 332135
399 | 93929 177995
400 | 62174 629038
401 | 9670 147534
402 | 4504 873827
403 | 18676 450676
404 | 24301 914197
405 | 40263 659132
406 | 68264 545729
407 | 24890 524443
408 | 76736 863551
409 | 31365 262772
410 | 50964 410306
411 | 85653 263506
412 | 57317 293233
413 | 61421 58353
414 | 7451 223136
415 | 78742 240385
416 | 99197 178729
417 | 2502 619863
418 | 93129 860248
419 | 77449 41104
420 | 17254 534352
421 | 69330 184601
422 | 33720 367367
423 | 38821 862817
424 | 81080 44407
425 | 94411 366266
426 | 66114 185702
427 | 98139 389020
428 | 46045 216163
429 | 98394 321859
430 | 981 538022
431 | 45597 208089
432 | 67371 397461
433 | 6544 836393
434 | 62853 536921
435 | 76825 856211
436 | 82381 835292
437 | 71858 673445
438 | 59499 253597
439 | 52846 637479
440 | 68510 375808
441 | 96850 896214
442 | 19291 907591
443 | 85744 97989
444 | 36386 749047
445 | 86276 493248
446 | 11152 843733
447 | 75252 74134
448 | 32784 804831
449 | 66438 447740
450 | 7699 873827
451 | 83138 347549
452 | 45879 875295
453 | 70343 314519
454 | 45961 662802
455 | 35141 228641
456 | 43567 609220
457 | 563 114504
458 | 41970 721889
459 | 12624 847770
460 | 17314 450309
461 | 27011 782811
462 | 76979 415811
463 | 7395 42939
464 | 87657 633075
465 | 61787 423518
466 | 54545 680051
467 | 8777 22754
468 | 74253 642984
469 | 75515 14680
470 | 95020 29360
471 | 66672 310482
472 | 56664 101292
473 | 84025 546463
474 | 16855 720788
475 | 14982 751249
476 | 86289 257267
477 | 8360 856578
478 | 55380 560776
479 | 19462 674913
480 | 8354 849238
481 | 85789 906857
482 | 69881 895847
483 | 38284 303142
484 | 237 855110
485 | 98568 353054
486 | 47940 451410
487 | 69652 48811
488 | 9267 640415
489 | 111 11010
490 | 86328 820245
491 | 25582 745377
492 | 7937 250661
493 | 45436 715283
494 | 66973 184234
495 | 19799 778040
496 | 74588 571786
497 | 69090 86245
498 | 29091 784646
499 | 66613 883369
500 | 72113 21653
501 | 51952 540591
502 | 3322 410306
503 | 55129 850339
504 | 28601 104228
505 | 49755 254331
506 | 55195 428656
507 | 48633 722623
508 | 58778 149736
509 | 18633 499120
510 | 86384 131386
511 | 37154 588668
512 | 98522 914197
513 | 9819 361495
514 | 66901 55417
515 | 74121 362229
516 | 33600 582429
517 | 27145 390488
518 | 40871 298738
519 | 4600 309748
520 | 74619 368101
521 | 79453 143864
522 | 52801 734000
523 | 44841 895480
524 | 86386 107898
525 | 24516 877130
526 | 14879 11010
527 | 18562 376909
528 | 28049 591237
529 | 49441 336539
530 | 18771 747212
531 | 40252 427555
532 | 5870 45141
533 | 72470 833824
534 | 8173 477834
535 | 52126 31195
536 | 98851 889608
537 | 11244 81107
538 | 25870 266442
539 | 41395 105696
540 | 26393 240018
541 | 9525 389754
542 | 39520 545729
543 | 81563 285526
544 | 3726 176527
545 | 5950 333236
546 | 37323 66060
547 | 60094 485908
548 | 85123 695098
549 | 26367 398929
550 | 71024 556372
551 | 66640 416545
552 | 8794 681152
553 | 59179 372872
554 | 56772 86245
555 | 51356 797858
556 | 68490 587200
557 | 67121 642617
558 | 10313 176160
559 | 33348 127349
560 | 32852 402599
561 | 25812 271947
562 | 91737 773636
563 | 29608 779141
564 | 87867 301307
565 | 41413 698034
566 | 14041 177995
567 | 34258 571052
568 | 88296 292132
569 | 42252 345347
570 | 92695 437464
571 | 51139 339108
572 | 19876 332135
573 | 94293 868689
574 | 14563 16882
575 | 93237 747579
576 | 83874 734000
577 | 27485 734367
578 | 54120 609220
579 | 99329 619129
580 | 263 449942
581 | 5717 293967
582 | 92173 438932
583 | 44001 91383
584 | 61399 552702
585 | 89546 254698
586 | 12301 898783
587 | 52193 452144
588 | 14784 411407
589 | 69802 322226
590 | 46222 89181
591 | 24514 555271
592 | 32532 75969
593 | 45820 808868
594 | 46857 233045
595 | 64255 249193
596 | 49815 649223
597 | 66538 40370
598 | 47298 505726
599 | 25761 418013
600 | 80192 799326
601 | 54175 458016
602 | 14878 412875
603 | 94785 455447
604 | 6751 870524
605 | 45000 848504
606 | 99610 675647
607 | 55365 606651
608 | 32813 613624
609 | 24135 177995
610 | 93584 433794
611 | 14260 402966
612 | 754 651792
613 | 79322 812538
614 | 65254 684455
615 | 36097 720421
616 | 30662 296536
617 | 27492 84043
618 | 42921 905022
619 | 81688 573621
620 | 55473 265708
621 | 55642 12111
622 | 82952 493982
623 | 58446 712714
624 | 49475 822814
625 | 96177 417279
626 | 25827 788683
627 | 23345 460952
628 | 85980 621698
629 | 68130 490312
630 | 59849 858780
631 | 5618 865019
632 | 47839 910894
633 | 88930 409205
634 | 29969 669041
635 | 25200 811437
636 | 59830 39269
637 | 3489 489945
638 | 2820 386818
639 | 91138 760057
640 | 80319 370303
641 | 80684 500955
642 | 91121 310115
643 | 89962 86979
644 | 11019 480403
645 | 79947 830521
646 | 31479 715650
647 | 25064 716384
648 | 71507 487376
649 | 78969 133588
650 | 4120 517103
651 | 49678 676381
652 | 78954 905022
653 | 51854 149736
654 | 8272 794188
655 | 40205 286627
656 | 28033 104228
657 | 83532 623900
658 | 9469 226072
659 | 52611 148268
660 | 67271 613624
661 | 6266 660967
662 | 43280 84777
663 | 75328 402966
664 | 25364 620964
665 | 75767 92851
666 | 3084 805565
667 | 26185 337640
668 | 44157 648122
669 | 63665 15414
670 | 57887 586833
671 | 44925 557106
672 | 195 422417
673 | 36768 109733
674 | 6911 186436
675 | 56806 793087
676 | 79653 110834
677 | 73586 457649
678 | 97807 78905
679 | 17961 505726
680 | 10404 55784
681 | 1086 325896
682 | 48586 907224
683 | 30485 416178
684 | 55201 100191
685 | 80101 117073
686 | 28200 564446
687 | 4429 216530
688 | 93918 235614
689 | 87803 403333
690 | 56579 692896
691 | 20790 496918
692 | 60596 59821
693 | 42320 533985
694 | 20722 190840
695 | 98515 348650
696 | 75707 677849
697 | 34627 164783
698 | 23802 230843
699 | 97645 616193
700 | 23259 658765
701 | 3132 276718
702 | 60721 690327
703 | 61025 446639
704 | 80509 309381
705 | 19398 707943
706 | 78315 171389
707 | 34000 513800
708 | 92534 558574
709 | 80540 457649
710 | 58822 61289
711 | 76537 726293
712 | 14663 573254
713 | 14494 426454
714 | 74989 349384
715 | 6839 822814
716 | 64892 717485
717 | 50561 900251
718 | 44093 210291
719 | 89776 625001
720 | 30414 895480
721 | 16637 850339
722 | 23348 471962
723 | 54833 766296
724 | 28929 407003
725 | 40880 877130
726 | 39055 17983
727 | 32538 316354
728 | 4422 67528
729 | 51035 164049
730 | 19417 294334
731 | 81939 192675
732 | 35184 735101
733 | 81097 109733
734 | 70525 902453
735 | 60364 721889
736 | 73013 765195
737 | 72203 877864
738 | 61537 513066
739 | 2209 152672
740 | 8032 47710
741 | 32096 685556
742 | 13113 151938
743 | 29938 909793
744 | 15686 815841
745 | 84382 746845
746 | 99909 865019
747 | 99440 445538
748 | 96120 408838
749 | 75000 358192
750 | 74516 291031
751 | 59923 474531
752 | 3981 581328
753 | 26886 306078
754 | 56326 149369
755 | 46995 605917
756 | 52603 461686
757 | 28481 481504
758 | 92835 316721
759 | 98882 427555
760 | 33459 588668
761 | 74993 713081
762 | 48652 441501
763 | 8543 506460
764 | 49636 580961
765 | 21873 269745
766 | 47323 338374
767 | 78442 572520
768 | 85780 168820
769 | 47161 8808
770 | 64956 177628
771 | 29365 356724
772 | 19314 212126
773 | 34329 491046
774 | 19856 145332
775 | 3882 800060
776 | 99905 65693
777 | 77617 574722
778 | 48865 345347
779 | 32123 216530
780 | 4944 81841
781 | 85119 542793
782 | 52189 822080
783 | 17640 481504
784 | 70344 635277
785 | 36220 725559
786 | 3318 725559
787 | 27783 747212
788 | 47860 142763
789 | 35591 789050
790 | 44011 767030
791 | 53016 718219
792 | 28543 727761
793 | 48271 81474
794 | 3009 348283
795 | 31100 637112
796 | 38214 226439
797 | 42952 508295
798 | 3503 699135
799 | 14954 819511
800 | 17917 915298
801 | 11870 25323
802 | 23504 536187
803 | 74330 494349
804 | 13383 238183
805 | 9353 311216
806 | 32527 267910
807 | 78756 725192
808 | 2953 48811
809 | 67423 128450
810 | 43787 478201
811 | 41880 701704
812 | 78891 838962
813 | 14687 406636
814 | 68828 222035
815 | 199 206988
816 | 70818 798592
817 | 73477 147167
818 | 31564 288829
819 | 6384 353054
820 | 29723 252129
821 | 32792 840797
822 | 6143 113770
823 | 31554 317822
824 | 22381 906123
825 | 64916 150103
826 | 37111 531783
827 | 11559 909426
828 | 11434 327731
829 | 62660 429757
830 | 17563 212493
831 | 49755 387552
832 | 33527 890342
833 | 27596 286627
834 | 81178 441501
835 | 69159 580961
836 | 6021 860615
837 | 64728 721889
838 | 59463 789784
839 | 22938 668674
840 | 46058 828686
841 | 908 470494
842 | 94259 797124
843 | 17344 31562
844 | 22282 707943
845 | 81458 113403
846 | 90118 842632
847 | 85625 905756
848 | 61460 119275
849 | 42217 223870
850 | 96852 256900
851 | 92923 728862
852 | 38616 510864
853 | 51362 423518
854 | 17232 440033
855 | 26124 162948
856 | 87445 83309
857 | 96239 523342
858 | 59360 177628
859 | 21266 661701
860 | 18388 378744
861 | 63949 523709
862 | 15425 656196
863 | 42776 834191
864 | 34990 905756
865 | 90846 222402
866 | 57258 365532
867 | 13262 844100
868 | 57105 348650
869 | 57032 81107
870 | 84687 61656
871 | 38644 59087
872 | 5649 452878
873 | 97318 75235
874 | 70126 70097
875 | 7864 328832
876 | 46681 638947
877 | 60980 275984
878 | 86815 889608
879 | 72190 295802
880 | 94906 131753
881 | 13615 17616
882 | 79309 103127
883 | 69450 85144
884 | 56030 569951
885 | 87867 790151
886 | 58986 825383
887 | 77519 436363
888 | 70580 205887
889 | 13384 570685
890 | 30560 792720
891 | 21175 104228
892 | 22676 682620
893 | 89437 393057
894 | 42916 345347
895 | 13014 906490
896 | 72493 729229
897 | 84084 654728
898 | 25848 916766
899 | 32707 140194
900 | 42307 80740
901 | 4115 726660
902 | 84471 597476
903 | 28618 678216
904 | 22305 70097
905 | 49664 248826
906 | 49899 60188
907 | 48657 601880
908 | 53013 423885
909 | 96281 572520
910 | 79105 398195
911 | 52074 516736
912 | 92980 539123
913 | 48270 881167
914 | 5632 771801
915 | 25634 520406
916 | 52521 543527
917 | 77598 282223
918 | 36872 73767
919 | 93321 805565
920 | 28019 248459
921 | 20518 732532
922 | 12614 192308
923 | 91357 62023
924 | 88595 650691
925 | 35671 469026
926 | 21650 123679
927 | 76426 378744
928 | 91892 796390
929 | 81218 203318
930 | 42655 64959
931 | 63902 493982
932 | 96393 822447
933 | 24971 502056
934 | 56188 651425
935 | 31566 132854
936 | 12972 545362
937 | 71161 385717
938 | 40776 141662
939 | 69695 230109
940 | 31535 429757
941 | 34415 371037
942 | 40424 70097
943 | 73078 91016
944 | 25628 497652
945 | 85575 255065
946 | 98840 183867
947 | 68184 32296
948 | 50531 842999
949 | 82607 293233
950 | 2013 830888
951 | 31939 325529
952 | 5322 913830
953 | 39228 562978
954 | 24461 636745
955 | 43488 310115
956 | 11582 756020
957 | 39195 23855
958 | 51296 409205
959 | 33420 166251
960 | 63444 719687
961 | 68707 422050
962 | 26594 680785
963 | 31227 100191
964 | 42734 881534
965 | 95297 52481
966 | 18062 623166
967 | 21378 502056
968 | 50523 615826
969 | 1053 516736
970 | 55181 212860
971 | 57152 542793
972 | 94974 280021
973 | 30418 241486
974 | 12466 864285
975 | 22740 140194
976 | 25008 822080
977 | 52431 157443
978 | 37826 322226
979 | 8328 829787
980 | 13498 638213
981 | 58226 315620
982 | 98823 230476
983 | 32554 688859
984 | 69469 455447
985 | 37425 171389
986 | 74278 826851
987 | 80501 757121
988 | 31602 455814
989 | 30699 632341
990 | 14550 809969
991 | 77128 565547
992 | 60 37067
993 | 54671 68262
994 | 7225 502790
995 | 5021 712347
996 | 5220 498386
997 | 38696 582429
998 | 25665 507928
999 | 95119 740606
1000 | 13724 819511
1001 | 43755 280388
1002 | 38588 433427
1003 | 4914 560042
1004 | 4158 171022
1005 | 40583 405902
1006 | 40963 49545
1007 | 98441 611055
1008 | 21171 583163
1009 | 45202 25690
1010 | 51621 23488
1011 | 30318 189739
1012 | 31753 40737
1013 | 63236 149369
1014 | 10832 73033
1015 | 51852 344613
1016 | 55452 150103
1017 | 45550 815107
1018 | 26235 445171
1019 | 49224 379845
1020 | 86772 447006
1021 | 8960 669408
1022 | 75346 643351
1023 | 59724 694364
1024 | 51820 184968
1025 | 34887 501322
1026 | 17393 262405
1027 | 84947 714182
1028 | 29366 304610
1029 | 2757 307179
1030 | 97918 745377
1031 | 83696 386451
1032 | 6865 409205
1033 | 90792 583897
1034 | 13757 809602
1035 | 19898 216530
1036 | 75310 702805
1037 | 46023 295435
1038 | 76900 645186
1039 | 71721 548298
1040 | 50149 66794
1041 | 95777 123312
1042 | 76834 349384
1043 | 95194 418747
1044 | 8009 782077
1045 | 86767 683721
1046 | 36631 165517
1047 | 69187 913463
1048 | 67670 643718
1049 | 75159 601513
1050 | 52400 579860
1051 | 83115 333970
1052 | 95057 708310
1053 | 10054 336172
1054 | 75776 136157
1055 | 89556 359293
1056 | 45049 331034
1057 | 99453 778407
1058 | 1632 762626
1059 | 52264 11010
1060 | 98920 209924
1061 | 28202 460952
1062 | 98583 884470
1063 | 38997 74501
1064 | 63502 635644
1065 | 50931 481137
1066 | 66421 831989
1067 | 82152 743175
1068 | 93171 565547
1069 | 85158 80006
1070 | 48016 213961
1071 | 53862 820612
1072 | 91174 844100
1073 | 93087 494349
1074 | 13053 640048
1075 | 74380 84777
1076 | 23433 126615
1077 | 65773 708310
1078 | 30188 532517
1079 | 95609 888874
1080 | 97037 769232
1081 | 18479 711246
1082 | 44089 49912
1083 | 84510 72666
1084 | 15124 688125
1085 | 78124 73033
1086 | 23880 702805
1087 | 61675 256533
1088 | 51432 36333
1089 | 4972 414710
1090 | 66195 765929
1091 | 4657 439666
1092 | 58217 127349
1093 | 6814 731064
1094 | 50258 728862
1095 | 51558 331034
1096 | 1237 729596
1097 | 27534 413242
1098 | 48911 644452
1099 | 53267 902820
1100 | 45109 709044
1101 | 62174 834191
1102 | 67897 388286
1103 | 23225 160379
1104 | 45409 193409
1105 | 79562 193409
1106 | 43572 445905
1107 | 76115 270846
1108 | 41290 660233
1109 | 91860 101659
1110 | 8695 139460
1111 | 52928 904655
1112 | 50843 89548
1113 | 52122 622065
1114 | 43581 638580
1115 | 8810 873827
1116 | 5484 237816
1117 | 48439 718953
1118 | 80939 95420
1119 | 86482 85511
1120 | 4551 440400
1121 | 41176 298371
1122 | 78275 310482
1123 | 98843 913463
1124 | 29821 736936
1125 | 49806 507928
1126 | 49221 188638
1127 | 37904 649223
1128 | 38851 607752
1129 | 35363 835292
1130 | 47071 455447
1131 | 69143 697300
1132 | 64736 866487
1133 | 38735 165517
1134 | 39875 478201
1135 | 17391 851807
1136 | 63007 887406
1137 | 16044 813639
1138 | 21356 297637
1139 | 41976 653260
1140 | 47376 712347
1141 | 18198 864652
1142 | 40538 653260
1143 | 14784 539857
1144 | 20177 387552
1145 | 25083 76336
1146 | 99152 230476
1147 | 15306 706108
1148 | 76378 288829
1149 | 23791 365532
1150 | 18397 652159
1151 | 28232 462787
1152 | 80607 323694
1153 | 90471 499120
1154 | 21358 474898
1155 | 4054 899517
1156 | 75281 502790
1157 | 35084 727027
1158 | 82583 744276
1159 | 79564 704273
1160 | 4638 910894
1161 | 74688 568483
1162 | 19952 678216
1163 | 11623 220567
1164 | 69631 351219
1165 | 31507 10643
1166 | 74718 826851
1167 | 49140 418380
1168 | 71693 422784
1169 | 96145 121110
1170 | 18084 436363
1171 | 11866 149736
1172 | 33388 313051
1173 | 92288 471595
1174 | 46680 460952
1175 | 10364 132854
1176 | 18000 296169
1177 | 21502 380946
1178 | 97206 613991
1179 | 26245 827585
1180 | 53077 649590
1181 | 69320 338374
1182 | 52989 780609
1183 | 36153 581695
1184 | 24559 693263
1185 | 51682 175059
1186 | 33217 206621
1187 | 69246 51380
1188 | 13686 498019
1189 | 814 28993
1190 | 48106 524076
1191 | 49655 597843
1192 | 98471 40003
1193 | 13980 377276
1194 | 80040 146066
1195 | 79946 353054
1196 | 53749 411407
1197 | 95178 900985
1198 | 27492 524810
1199 | 499 96521
1200 | 12005 265708
1201 | 25557 909059
1202 | 83134 101292
1203 | 21293 610688
1204 | 24271 24956
1205 | 5354 215429
1206 | 70129 582429
1207 | 91083 533618
1208 | 49301 106797
1209 | 73456 463521
1210 | 25385 62390
1211 | 95848 513800
1212 | 12489 711246
1213 | 84309 754919
1214 | 32775 630873
1215 | 8262 854743
1216 | 23925 675280
1217 | 96390 342778
1218 | 13345 284425
1219 | 37748 707209
1220 | 43968 117807
1221 | 74577 680051
1222 | 60837 423518
1223 | 92147 136524
1224 | 99758 568116
1225 | 89597 729963
1226 | 28382 453245
1227 | 79295 890709
1228 | 23805 92851
1229 | 90688 822080
1230 | 42218 858046
1231 | 35294 817309
1232 | 29939 243688
1233 | 19488 290297
1234 | 12681 501689
1235 | 53527 767030
1236 | 8708 481137
1237 | 57618 125881
1238 | 82639 229742
1239 | 19546 741707
1240 | 56126 888874
1241 | 93880 508662
1242 | 61642 303509
1243 | 43982 834558
1244 | 76326 230109
1245 | 93894 98723
1246 | 14564 553436
1247 | 25502 625368
1248 | 76941 325162
1249 | 14330 149002
1250 | 42281 198914
1251 | 85475 197079
1252 | 78777 914564
1253 | 87705 491413
1254 | 34969 820979
1255 | 91267 828686
1256 | 66375 617661
1257 | 39822 820979
1258 | 4416 261304
1259 | 29842 322593
1260 | 50159 342411
1261 | 3766 617294
1262 | 37762 349017
1263 | 30273 890342
1264 | 6195 358926
1265 | 65255 633809
1266 | 517 642617
1267 | 27684 821713
1268 | 15114 328098
1269 | 55545 685189
1270 | 1760 705374
1271 | 74673 467558
1272 | 21173 814740
1273 | 75230 148635
1274 | 77144 349017
1275 | 56510 662435
1276 | 49185 439666
1277 | 59512 36333
1278 | 18384 600412
1279 | 15786 22754
1280 | 70761 909059
1281 | 32990 205153
1282 | 39734 488844
1283 | 38586 697667
1284 | 7586 751616
1285 | 21175 466090
1286 | 38664 26424
1287 | 23304 501322
1288 | 70364 765929
1289 | 63992 313051
1290 | 71090 681519
1291 | 89830 357458
1292 | 570 147901
1293 | 66532 59454
1294 | 6712 434161
1295 | 18097 297637
1296 | 70805 499120
1297 | 80889 416545
1298 | 43821 895480
1299 | 68366 829787
1300 | 63373 732899
1301 | 50381 63858
1302 | 7901 655462
1303 | 62619 341677
1304 | 91920 624634
1305 | 70764 784279
1306 | 2059 278186
1307 | 85468 673078
1308 | 12552 418013
1309 | 5896 325162
1310 | 62430 605917
1311 | 76992 535086
1312 | 45461 397461
1313 | 54993 913830
1314 | 43240 675647
1315 | 80385 522241
1316 | 83514 753084
1317 | 98309 754552
1318 | 86121 416545
1319 | 67246 125514
1320 | 30880 520773
1321 | 59703 750882
1322 | 36168 131753
1323 | 95970 443703
1324 | 89830 215062
1325 | 54619 139460
1326 | 68806 166251
1327 | 63838 774003
1328 | 57130 752350
1329 | 59916 269378
1330 | 47473 151571
1331 | 50161 189372
1332 | 74280 101659
1333 | 85325 139460
1334 | 60781 219099
1335 | 76559 200749
1336 | 52116 908692
1337 | 49622 296536
1338 | 21946 346448
1339 | 36096 283324
1340 | 79011 371404
1341 | 43283 71932
1342 | 98283 748313
1343 | 92302 372872
1344 | 59636 18717
1345 | 12685 213227
1346 | 46606 396727
1347 | 65778 766296
1348 | 40050 586833
1349 | 76545 800060
1350 | 91625 262772
1351 | 77062 657297
1352 | 55639 653994
1353 | 47508 401865
1354 | 47014 254331
1355 | 84852 901719
1356 | 85699 826484
1357 | 47801 697667
1358 | 21090 358559
1359 | 74500 729596
1360 | 26675 439299
1361 | 11529 563712
1362 | 64648 750515
1363 | 13294 650691
1364 | 56169 785380
1365 | 77477 122945
1366 | 13054 497652
1367 | 82488 407370
1368 | 88520 572520
1369 | 54807 824282
1370 | 92232 485541
1371 | 41350 183500
1372 | 4210 93585
1373 | 14633 571052
1374 | 42985 895113
1375 | 14554 219099
1376 | 59389 127716
1377 | 59939 515268
1378 | 80608 867588
1379 | 21159 506093
1380 | 89176 708677
1381 | 69493 248826
1382 | 22030 656563
1383 | 69861 135790
1384 | 84028 437097
1385 | 91440 539123
1386 | 93365 355990
1387 | 49137 758222
1388 | 3012 426087
1389 | 40656 380946
1390 | 93160 307179
1391 | 42650 716751
1392 | 42173 901719
1393 | 91991 254698
1394 | 16743 113403
1395 | 93022 689593
1396 | 69835 482605
1397 | 81294 240752
1398 | 4102 214695
1399 | 58249 689960
1400 | 82333 102026
1401 | 98766 82942
1402 | 59003 854009
1403 | 20720 794555
1404 | 82438 627937
1405 | 50737 824649
1406 | 34009 387185
1407 | 50934 688492
1408 | 38 193042
1409 | 49695 124046
1410 | 36011 148635
1411 | 75906 477467
1412 | 23208 106430
1413 | 75575 460585
1414 | 91908 876763
1415 | 90815 457649
1416 | 55440 286994
1417 | 81088 556739
1418 | 24449 670509
1419 | 27314 354155
1420 | 96561 674913
1421 | 7671 53215
1422 | 74957 275617
1423 | 40227 763360
1424 | 45323 259469
1425 | 65200 850339
1426 | 99897 100558
1427 | 29342 915665
1428 | 70436 281122
1429 | 72472 625001
1430 | 70108 707943
1431 | 56293 486642
1432 | 44553 548665
1433 | 53014 174692
1434 | 77604 282223
1435 | 20418 131386
1436 | 61989 263139
1437 | 1044 96154
1438 | 99689 481137
1439 | 1767 361862
1440 | 47639 354155
1441 | 21168 297637
1442 | 88123 107898
1443 | 49332 150103
1444 | 74121 757855
1445 | 81465 618395
1446 | 38206 32663
1447 | 97843 291031
1448 | 12189 321859
1449 | 86686 786481
1450 | 60674 277452
1451 | 93398 149369
1452 | 51253 655095
1453 | 10387 248459
1454 | 89718 342411
1455 | 80724 340576
1456 | 93197 916399
1457 | 37080 114871
1458 | 50843 198547
1459 | 89598 871625
1460 | 12312 414710
1461 | 55545 671977
1462 | 33926 24222
1463 | 28073 355256
1464 | 5458 117073
1465 | 74328 159645
1466 | 40783 694731
1467 | 43145 478935
1468 | 31814 485174
1469 | 13902 580227
1470 | 88724 500955
1471 | 38965 371037
1472 | 11459 367367
1473 | 15344 881901
1474 | 35528 239651
1475 | 55134 288462
1476 | 25542 742808
1477 | 69900 563712
1478 | 26731 109366
1479 | 63577 849972
1480 | 49508 693630
1481 | 84399 611789
1482 | 52416 840063
1483 | 93648 251028
1484 | 95476 872359
1485 | 3187 366633
1486 | 12248 460585
1487 | 34039 273048
1488 | 20703 817676
1489 | 24080 246991
1490 | 21147 766663
1491 | 42441 706475
1492 | 34442 227173
1493 | 98503 498019
1494 | 25370 48077
1495 | 63909 355256
1496 | 41844 649957
1497 | 87081 205887
1498 | 93003 291765
1499 | 65971 222769
1500 | 5588 187904
1501 | 29109 751616
1502 | 30206 773269
1503 | 74415 282223
1504 | 49666 142029
1505 | 12160 334337
1506 | 61391 612156
1507 | 16883 320758
1508 | 22904 701704
1509 | 17185 495083
1510 | 68605 911995
1511 | 45028 850339
1512 | 46525 702438
1513 | 7415 901719
1514 | 4949 13579
1515 | 93905 306078
1516 | 66836 177995
1517 | 61137 503157
1518 | 78526 78905
1519 | 83955 585365
1520 | 42234 303876
1521 | 2231 12845
1522 | 7371 435996
1523 | 11305 35232
1524 | 84680 894746
1525 | 17656 548298
1526 | 92680 561877
1527 | 76721 71198
1528 | 73502 351219
1529 | 11350 292132
1530 | 48803 216897
1531 | 23505 95787
1532 | 27161 205887
1533 | 61929 23488
1534 | 10102 657297
1535 | 68383 166985
1536 | 74986 355990
1537 | 14824 99824
1538 | 46769 305711
1539 | 38928 851440
1540 | 8275 475999
1541 | 29238 389754
1542 | 17321 507194
1543 | 24813 314886
1544 | 26684 131386
1545 | 87542 716384
1546 | 25068 602981
1547 | 16289 528480
1548 | 25034 177995
1549 | 87282 374707
1550 | 93233 12111
1551 | 75751 95787
1552 | 74971 578392
1553 | 15024 167719
1554 | 8575 90282
1555 | 43819 484440
1556 | 98201 480403
1557 | 26792 539490
1558 | 10120 576924
1559 | 91794 650324
1560 | 83080 458750
1561 | 7519 268277
1562 | 52499 556372
1563 | 53658 59821
1564 | 76769 801528
1565 | 9427 841531
1566 | 76653 680051
1567 | 64994 155975
1568 | 67442 85878
1569 | 40708 339475
1570 | 60764 561877
1571 | 11923 55784
1572 | 9259 430858
1573 | 55047 909793
1574 | 78164 55784
1575 | 71911 104962
1576 | 38774 473430
1577 | 35233 487009
1578 | 21308 248826
1579 | 94485 94319
1580 | 4136 433060
1581 | 31699 892177
1582 | 45627 826851
1583 | 76135 732532
1584 | 45069 837494
1585 | 32600 296169
1586 | 41570 397828
1587 | 27397 583163
1588 | 12874 564446
1589 | 13012 424252
1590 | 70901 97622
1591 | 2380 917133
1592 | 67070 42939
1593 | 8415 653994
1594 | 46885 607018
1595 | 24645 623166
1596 | 9128 663536
1597 | 61400 91383
1598 | 72496 772902
1599 | 21292 647021
1600 | 49539 42205
1601 | 59454 703539
1602 | 95124 598944
1603 | 25981 172490
1604 | 76874 88080
1605 | 34181 275250
1606 | 44294 721522
1607 | 90144 656196
1608 | 71947 761892
1609 | 65776 32296
1610 | 18777 473063
1611 | 48336 398195
1612 | 7449 537655
1613 | 98163 206988
1614 | 57248 187170
1615 | 60815 420215
1616 | 66631 656930
1617 | 87857 487009
1618 | 98344 660233
1619 | 10016 434161
1620 | 51143 609220
1621 | 524 129551
1622 | 70013 310849
1623 | 53381 319657
1624 | 14158 463521
1625 | 59224 30461
1626 | 20007 786481
1627 | 51377 493982
1628 | 73939 269745
1629 | 77267 801895
1630 | 54310 530315
1631 | 27397 485908
1632 | 9297 860615
1633 | 56736 684822
1634 | 3659 167719
1635 | 83904 334337
1636 | 1847 466457
1637 | 65951 688492
1638 | 53628 435996
1639 | 80153 635644
1640 | 19785 529214
1641 | 64117 841531
1642 | 85252 714916
1643 | 50684 611055
1644 | 12000 847770
1645 | 20243 135056
1646 | 6837 602614
1647 | 67661 221668
1648 | 34715 665738
1649 | 41067 565547
1650 | 9900 434895
1651 | 12020 854376
1652 | 6644 217264
1653 | 8177 873827
1654 | 76047 571052
1655 | 17721 380212
1656 | 29250 898049
1657 | 68348 759323
1658 | 18477 739138
1659 | 6079 654728
1660 | 68237 136157
1661 | 75640 211025
1662 | 86480 374707
1663 | 78267 347549
1664 | 1607 430858
1665 | 357 297270
1666 | 73347 378744
1667 | 61634 703172
1668 | 3591 637846
1669 | 48059 54683
1670 | 34595 843366
1671 | 58098 432326
1672 | 41236 420215
1673 | 42969 617661
1674 | 61465 699502
1675 | 89789 867955
1676 | 41884 679684
1677 | 7832 644085
1678 | 69157 703539
1679 | 5813 55784
1680 | 25753 909793
1681 | 84882 232311
1682 | 28236 792353
1683 | 66425 307179
1684 | 2867 126248
1685 | 31129 510497
1686 | 12334 481137
1687 | 61304 343145
1688 | 76689 696566
1689 | 92490 815107
1690 | 62682 829787
1691 | 31924 73400
1692 | 6453 217264
1693 | 82962 768131
1694 | 6502 908692
1695 | 88994 531783
1696 | 19614 77437
1697 | 27 778774
1698 | 32957 197079
1699 | 42307 901352
1700 | 98965 849238
1701 | 67907 93218
1702 | 24720 845568
1703 | 8135 531783
1704 | 85993 579860
1705 | 49595 798959
1706 | 19044 590136
1707 | 66428 464989
1708 | 96965 610688
1709 | 39167 422784
1710 | 14485 517470
1711 | 33322 292866
1712 | 56871 622432
1713 | 35593 187170
1714 | 71130 298738
1715 | 94548 31562
1716 | 89781 341310
1717 | 11698 841898
1718 | 33541 476733
1719 | 84401 208456
1720 | 23782 871992
1721 | 71776 150103
1722 | 89303 524810
1723 | 97012 59821
1724 | 34276 40370
1725 | 67799 633075
1726 | 48738 80373
1727 | 73587 75969
1728 | 80505 658398
1729 | 99460 243688
1730 | 16396 538756
1731 | 29282 255799
1732 | 76269 712347
1733 | 9079 838962
1734 | 90505 875662
1735 | 8527 843733
1736 | 97741 561143
1737 | 62909 58353
1738 | 93060 685556
1739 | 25098 82942
1740 | 33019 434161
1741 | 99158 521507
1742 | 20722 517103
1743 | 21072 32296
1744 | 76023 423518
1745 | 38953 587567
1746 | 22750 746845
1747 | 65705 419481
1748 | 6688 163315
1749 | 92015 565914
1750 | 1060 461686
1751 | 32277 192675
1752 | 40017 244422
1753 | 86073 401498
1754 | 45443 785747
1755 | 67915 350118
1756 | 20130 790151
1757 | 13588 441868
1758 | 45732 183500
1759 | 37586 690327
1760 | 95720 336539
1761 | 79270 795656
1762 | 15654 522241
1763 | 40442 166985
1764 | 28149 143864
1765 | 74227 125147
1766 | 37929 741340
1767 | 32564 819511
1768 | 88748 697667
1769 | 72932 667940
1770 | 34891 725192
1771 | 85724 344613
1772 | 5130 536187
1773 | 27184 384983
1774 | 3375 469760
1775 | 9147 913830
1776 | 27132 805565
1777 | 60156 914197
1778 | 19480 436363
1779 | 63085 118541
1780 | 95875 678216
1781 | 96386 473797
1782 | 30546 634910
1783 | 2901 522608
1784 | 84309 830888
1785 | 10913 880433
1786 | 52951 568116
1787 | 32186 407370
1788 | 71350 798592
1789 | 29969 495817
1790 | 54597 772902
1791 | 70674 617661
1792 | 21686 226072
1793 | 35620 373239
1794 | 28105 240385
1795 | 72430 546830
1796 | 6095 768131
1797 | 28032 113403
1798 | 56270 489578
1799 | 23400 894746
1800 | 4746 169554
1801 | 76656 698401
1802 | 35221 440400
1803 | 94976 549032
1804 | 3769 224971
1805 | 23568 309014
1806 | 94145 905389
1807 | 59191 684088
1808 | 89326 69363
1809 | 22744 817309
1810 | 10313 615826
1811 | 52897 803730
1812 | 5064 569217
1813 | 30828 354889
1814 | 20704 302041
1815 | 49244 622065
1816 | 35601 670876
1817 | 31530 126248
1818 | 35798 652159
1819 | 50917 133221
1820 | 45780 747946
1821 | 69649 495083
1822 | 77149 17249
1823 | 74561 210658
1824 | 27993 888140
1825 | 91518 127716
1826 | 92589 871258
1827 | 16146 13212
1828 | 31780 486275
1829 | 42676 142396
1830 | 45754 424986
1831 | 20541 23855
1832 | 28 658031
1833 | 63157 28259
1834 | 77586 383148
1835 | 96707 695832
1836 | 76473 162948
1837 | 58300 786481
1838 | 20447 760791
1839 | 13038 55784
1840 | 1823 732899
1841 | 56455 102760
1842 | 43847 598944
1843 | 28197 224604
1844 | 95753 538389
1845 | 77658 736936
1846 | 22300 244789
1847 | 30904 169187
1848 | 64264 774003
1849 | 24516 273782
1850 | 55622 415444
1851 | 90153 399663
1852 | 22516 728862
1853 | 22624 400764
1854 | 55010 444437
1855 | 82924 18717
1856 | 70743 564446
1857 | 72758 278553
1858 | 78699 313785
1859 | 24163 117073
1860 | 39335 492147
1861 | 57124 309748
1862 | 99084 509763
1863 | 97387 295802
1864 | 32194 816942
1865 | 43610 774737
1866 | 31068 205887
1867 | 39599 567382
1868 | 51053 818043
1869 | 97433 571419
1870 | 99402 525911
1871 | 20142 150470
1872 | 32195 513433
1873 | 65175 703539
1874 | 54938 186436
1875 | 66297 121477
1876 | 28273 891810
1877 | 47434 297270
1878 | 65237 477467
1879 | 49416 797124
1880 | 31312 105696
1881 | 10547 394892
1882 | 89521 915665
1883 | 60827 557840
1884 | 15 677849
1885 | 55502 314519
1886 | 91984 166985
1887 | 29946 561877
1888 | 43871 206988
1889 | 70548 47710
1890 | 24782 626102
1891 | 13642 523342
1892 | 57502 387185
1893 | 6224 866854
1894 | 98862 847770
1895 | 92233 60922
1896 | 30741 790518
1897 | 53965 303509
1898 | 69578 647755
1899 | 72056 431592
1900 | 78927 299472
1901 | 92924 277085
1902 | 28940 482972
1903 | 43668 827585
1904 | 52084 134322
1905 | 21674 881534
1906 | 2587 526645
1907 | 48418 16148
1908 | 71794 770333
1909 | 18615 532150
1910 | 82917 279287
1911 | 85208 114137
1912 | 73381 139093
1913 | 49805 231577
1914 | 98845 284058
1915 | 59964 797124
1916 | 74593 721155
1917 | 75721 789417
1918 | 82430 308280
1919 | 79182 468659
1920 | 37556 424986
1921 | 47980 588301
1922 | 26496 773636
1923 | 93251 824282
1924 | 86790 666472
1925 | 98458 429390
1926 | 46383 241119
1927 | 18715 499120
1928 | 1643 826117
1929 | 31313 361862
1930 | 53967 801161
1931 | 37791 758222
1932 | 89490 280755
1933 | 27636 574355
1934 | 78337 83676
1935 | 43324 188638
1936 | 38603 558941
1937 | 23628 742441
1938 | 56106 233779
1939 | 8093 723724
1940 | 41550 419114
1941 | 87109 524810
1942 | 33232 473797
1943 | 97029 819511
1944 | 18545 770333
1945 | 39937 201116
1946 | 77310 897315
1947 | 31291 520039
1948 | 4005 410306
1949 | 38057 384616
1950 | 62018 271947
1951 | 40666 719687
1952 | 59611 549032
1953 | 21189 492881
1954 | 82634 525911
1955 | 92735 873460
1956 | 90651 85511
1957 | 4708 340943
1958 | 12421 57619
1959 | 97722 533985
1960 | 27200 329199
1961 | 50462 662068
1962 | 85191 290664
1963 | 70601 631974
1964 | 49579 456915
1965 | 38894 208823
1966 | 6916 685556
1967 | 47648 62390
1968 | 22752 157076
1969 | 87996 420582
1970 | 91615 167352
1971 | 66420 189372
1972 | 73099 561877
1973 | 16210 423885
1974 | 27646 357825
1975 | 41077 8074
1976 | 82442 527746
1977 | 16816 714549
1978 | 25747 790151
1979 | 35284 623900
1980 | 20134 253964
1981 | 15887 560042
1982 | 83813 12478
1983 | 50283 241486
1984 | 67569 757121
1985 | 95705 683354
1986 | 64813 227540
1987 | 98810 628304
1988 | 40017 360761
1989 | 92262 667940
1990 | 42362 94686
1991 | 50513 26424
1992 | 88057 465723
1993 | 55000 567382
1994 | 98343 472696
1995 | 53524 780976
1996 | 83908 740973
1997 | 82788 711613
1998 | 20058 484073
1999 | 40941 895480
2000 | 76263 416912
2001 | 48122 661701
2002 |
--------------------------------------------------------------------------------
/data/tmp.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhpollack/programming_notebooks/168936ac09071c00a61d56b35467d7d27ec61e60/data/tmp.zip
--------------------------------------------------------------------------------
/databricks_python_graphframes_GraphExample.ipynb:
--------------------------------------------------------------------------------
1 | {"cells":[{"cell_type":"code","source":["import graphframes as gf\n# note on databricks: sc and sqlContext are preloaded"],"metadata":{},"outputs":[],"execution_count":1},{"cell_type":"code","source":["rawdat = sqlContext.sql(\"SELECT * FROM coursera_algos_mincut2\")\nprint rawdat.first()"],"metadata":{},"outputs":[],"execution_count":2},{"cell_type":"code","source":["def verticesList(row):\n return [(row.vertex, str(row.vertex))]\n\ndef adjacencyList(row):\n e = []\n v1 = row.vertex\n for i in range(1, 40):\n v2 = row[\"C\"+str(i)]\n if v2 is not None:\n edge = (v1, v2) if v1 < v2 else (v2, v1)\n e.append(edge)\n return e\n\n# http://graphframes.github.io/\n# http://cdn2.hubspot.net/hubfs/438089/notebooks/help/Setup_graphframes_package.html\nvertices = rawdat.flatMap(verticesList)\nedges = rawdat.flatMap(adjacencyList)\n# using distinct on edges because this is an undirected graph\nG = gf.GraphFrame(vertices.toDF([\"id\", \"name\"]), edges.distinct().toDF([\"src\", \"dst\"]))\nprint G.edges.collect()\n"],"metadata":{},"outputs":[],"execution_count":3},{"cell_type":"code","source":["print(G.edges.filter(\"src = 1\").collect())"],"metadata":{},"outputs":[],"execution_count":4},{"cell_type":"code","source":["print G.edges.count(), G.vertices.count()\n"],"metadata":{},"outputs":[],"execution_count":5},{"cell_type":"code","source":[""],"metadata":{},"outputs":[],"execution_count":6}],"metadata":{"name":"GraphExample","notebookId":4197684023626608},"nbformat":4,"nbformat_minor":0}
2 |
--------------------------------------------------------------------------------
/databricks_python_numpy_spark_kMeansExample.ipynb:
--------------------------------------------------------------------------------
1 | {"cells":[{"cell_type":"code","source":["import numpy as np\n#import pandas as pd\nimport matplotlib.pyplot as plt\n#%matplotlib inline\n\nc_num = 100\ncenters = [tuple([np.random.rand(), np.random.rand()]) for p in range(c_num)]\nprint centers\nsd = np.eye(2) / 10000\nX=[]\nn = 1000\nfor c in centers:\n X.extend(np.random.multivariate_normal(c, sd, size=n))\nX = np.array(X)\nfig = plt.figure()\nplt.scatter(x=X[:,0], y=X[:,1])\ndisplay(fig)\n"],"metadata":{},"outputs":[],"execution_count":1},{"cell_type":"code","source":["from pyspark.mllib.clustering import KMeans, KMeansModel\n\nX_parr = sc.parallelize(X)\n\nmodel = KMeans.train(X_parr, c_num, maxIterations=15, initializationMode=\"random\")\nprint model.centers\npredictions = model.predict(X_parr).collect()\nfig = plt.figure()\nplt.scatter(x=X[:,0], y=X[:,1], c=predictions)\ndisplay(fig)\n"],"metadata":{},"outputs":[],"execution_count":2},{"cell_type":"code","source":[""],"metadata":{},"outputs":[],"execution_count":3}],"metadata":{"name":"kMeansExample","notebookId":3297546998532744},"nbformat":4,"nbformat_minor":0}
2 |
--------------------------------------------------------------------------------
/imdb_tokenize.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 8,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.utils.data as data\n",
11 | "import torch.nn as nn\n",
12 | "import torch.nn.functional as F\n",
13 | "import numpy as np\n",
14 | "import nltk\n",
15 | "import os\n",
16 | "import bs4\n",
17 | "import random\n",
18 | "\n",
19 | "torch.manual_seed(12345)\n",
20 | "random.seed(12345)\n",
21 | "\n",
22 | "nltk.data.path.append(\"/home/david/Programming/data/nltk_data\")"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 25,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "def concat_contractions(tokens):\n",
34 | " contractions = set([\"'ve\", \"'d\", \"'m\", \"'ll\", \"'re\", \"n't\"])\n",
35 | " return [\"\".join(tokens[i]) if (i+1 == len(tokens) or tokens[i+1] not in contractions) else \"\".join(tokens[(i):(i+2)]) for i in range(len(tokens)) if tokens[i] not in contractions]\n",
36 | "\n",
37 | "def data_processing(ds_paths, max_len=500, split_ratio=1.0):\n",
38 | " ds = []\n",
39 | " for i, tfp in enumerate(ds_paths):\n",
40 | " idx, rating = os.path.basename(tfp).split(\".\")[0].split(\"_\")\n",
41 | " with open(tfp, \"r\") as f:\n",
42 | " raw = f.readlines()\n",
43 | " raw = bs4.BeautifulSoup(raw[0], \"html5lib\")\n",
44 | " txt = raw.get_text(separator=' ')\n",
45 | " tokens = nltk.word_tokenize(txt)\n",
46 | " tokens = concat_contractions(tokens)\n",
47 | " #tokens = [vocab[w] if w in vocab else len(vocab) for w in tokens] # keep out of vocab\n",
48 | " tokens = [vocab[w] for w in tokens if w in vocab]\n",
49 | " if len(tokens) > max_len:\n",
50 | " tokens = tokens[:max_len]\n",
51 | " elif len(tokens) < max_len:\n",
52 | " tokens = tokens + [0]*(max_len-len(tokens))\n",
53 | " ds.append((tokens, int(rating)))\n",
54 | " dat, labels = zip(*ds)\n",
55 | " assert split_ratio >= 0. and split_ratio <= 1.0\n",
56 | " if split_ratio == 1.:\n",
57 | " return (dat, labels), (None, None)\n",
58 | " else:\n",
59 | " split_idx = int(len(dat) * split_ratio)\n",
60 | " tidx = list(range(len(dat)))\n",
61 | " random.shuffle(tidx)\n",
62 | " tidx, vidx = tidx[:split_idx], tidx[split_idx:]\n",
63 | " ts, ts_labels = [dat[tid] for tid in tidx], [labels[tid] for tid in tidx]\n",
64 | " vs, vs_labels = [dat[vid] for vid in vidx], [labels[vid] for vid in vidx]\n",
65 | " return (ts, ts_labels), (vs, vs_labels)\n"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 5,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "imdbEr.txt imdb.vocab README \u001b[0m\u001b[01;34mtest\u001b[0m/ \u001b[01;34mtrain\u001b[0m/\r\n"
78 | ]
79 | },
80 | {
81 | "data": {
82 | "text/plain": [
83 | "['/home/david/Programming/data/aclImdb/train/neg/0_3.txt',\n",
84 | " '/home/david/Programming/data/aclImdb/train/neg/10000_4.txt',\n",
85 | " '/home/david/Programming/data/aclImdb/train/neg/10001_4.txt',\n",
86 | " '/home/david/Programming/data/aclImdb/train/neg/10002_1.txt',\n",
87 | " '/home/david/Programming/data/aclImdb/train/neg/10003_1.txt']"
88 | ]
89 | },
90 | "execution_count": 5,
91 | "metadata": {},
92 | "output_type": "execute_result"
93 | }
94 | ],
95 | "source": [
96 | "IMDB_BASEDIR = \"/home/david/Programming/data/aclImdb\"\n",
97 | "%ls $IMDB_BASEDIR\n",
98 | "train_paths = sorted([f.path for d in [\"pos\", \"neg\"] for f in os.scandir(os.path.join(IMDB_BASEDIR, \"train\", d))])\n",
99 | "test_paths = sorted([f.path for d in [\"pos\", \"neg\"] for f in os.scandir(os.path.join(IMDB_BASEDIR, \"test\", d))])\n",
100 | "\n",
101 | "train_paths[:5]"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 6,
107 | "metadata": {
108 | "collapsed": true
109 | },
110 | "outputs": [],
111 | "source": [
112 | "vocab_limit = 5000\n",
113 | "with open(os.path.join(IMDB_BASEDIR, \"imdb.vocab\"), \"r\") as f:\n",
114 | " vocab = {w:(i+1) for i, w in enumerate([l.strip() for l in f.readlines()][:vocab_limit])}\n"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 26,
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "500 500\n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "\n",
132 | "trainset, validset = data_processing(train_paths, split_ratio = 0.9)\n",
133 | "print(len(trainset[0][0]), len(validset[0][0]))\n",
134 | "ts, ts_labels = torch.Tensor(trainset[0]).long(), torch.Tensor(trainset[1])\n",
135 | "ts_labels = (ts_labels > 5).float()\n",
136 | "dts = data.TensorDataset(ts, ts_labels)\n",
137 | "dlts = data.DataLoader(dts, batch_size=100)"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 33,
143 | "metadata": {},
144 | "outputs": [
145 | {
146 | "name": "stdout",
147 | "output_type": "stream",
148 | "text": [
149 | "2500\n"
150 | ]
151 | }
152 | ],
153 | "source": [
154 | "vs, vs_labels = torch.Tensor(validset[0]).long(), torch.Tensor(validset[1])\n",
155 | "vs_labels = (vs_labels > 5).float()\n",
156 | "dvs = data.TensorDataset(vs, vs_labels)\n",
157 | "dlvs = data.DataLoader(dvs, batch_size=100)\n",
158 | "print(len(vs))"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 12,
164 | "metadata": {},
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "25000 2\n"
171 | ]
172 | }
173 | ],
174 | "source": [
175 | "#split_ratio = 0.9\n",
176 | "#split_idx = int(len(trainset) * split_ratio)\n",
177 | "#tidx = list(range(len(trainset)))\n",
178 | "#random.shuffle(tidx)\n",
179 | "#tidx, vidx = tidx[:split_idx], tidx[split_idx:]\n",
180 | "#ts, ts_labels = [trainset[tid] for tid in tidx], [train_labels[tid] for tid in tidx]\n",
181 | "#vs, vs_labels = [trainset[vid] for vid in vidx], [train_labels[vid] for vid in vidx]\n",
182 | "\n",
183 | "#ts, ts_labels = torch.Tensor(ts).long(), torch.Tensor(ts_labels)\n",
184 | "#ts_labels = (ts_labels > 5).float()\n",
185 | "#dts = data.TensorDataset(ts, ts_labels)\n",
186 | "#dlts = data.DataLoader(dts, batch_size=100)"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 60,
192 | "metadata": {
193 | "scrolled": false
194 | },
195 | "outputs": [
196 | {
197 | "name": "stdout",
198 | "output_type": "stream",
199 | "text": [
200 | "0 5000\n",
201 | "torch.Size([22500, 500])\n",
202 | "SingleHiddenNN (\n",
203 | " (emb): Embedding(5001, 32)\n",
204 | " (fc): Linear (16000 -> 100)\n",
205 | " (relu): SELU\n",
206 | " (dropout): Dropout (p = 0.7)\n",
207 | " (out): Linear (100 -> 1)\n",
208 | " (sigmoid): Sigmoid ()\n",
209 | ")\n",
210 | "epoch 1 had a loss of 180.21:\n",
211 | "epoch 2 had a loss of 149.41:\n",
212 | "epoch 3 had a loss of 157.26:\n",
213 | "epoch 4 had a loss of 136.89:\n",
214 | "epoch 5 had a loss of 146.05:\n",
215 | "epoch 6 had a loss of 128.16:\n",
216 | "correct: 1506, total: 2500\n",
217 | "validation accuracy: 60.24\n",
218 | "epoch 7 had a loss of 138.26:\n",
219 | "epoch 8 had a loss of 122.53:\n",
220 | "epoch 9 had a loss of 131.25:\n",
221 | "epoch 10 had a loss of 116.95:\n",
222 | "epoch 11 had a loss of 125.69:\n",
223 | "correct: 1515, total: 2500\n",
224 | "validation accuracy: 60.60\n",
225 | "epoch 12 had a loss of 113.37:\n",
226 | "epoch 13 had a loss of 120.07:\n",
227 | "epoch 14 had a loss of 106.46:\n",
228 | "epoch 15 had a loss of 116.02:\n",
229 | "epoch 16 had a loss of 102.44:\n",
230 | "correct: 1567, total: 2500\n",
231 | "validation accuracy: 62.68\n",
232 | "epoch 17 had a loss of 111.7:\n",
233 | "epoch 18 had a loss of 99.776:\n",
234 | "epoch 19 had a loss of 109.25:\n",
235 | "epoch 20 had a loss of 96.52:\n",
236 | "epoch 21 had a loss of 104.85:\n",
237 | "correct: 1563, total: 2500\n",
238 | "validation accuracy: 62.52\n",
239 | "epoch 22 had a loss of 91.543:\n",
240 | "epoch 23 had a loss of 100.93:\n",
241 | "epoch 24 had a loss of 89.977:\n",
242 | "epoch 25 had a loss of 97.779:\n",
243 | "epoch 26 had a loss of 85.79:\n",
244 | "correct: 1630, total: 2500\n",
245 | "validation accuracy: 65.20\n",
246 | "epoch 27 had a loss of 94.657:\n",
247 | "epoch 28 had a loss of 82.805:\n",
248 | "epoch 29 had a loss of 91.25:\n",
249 | "epoch 30 had a loss of 81.997:\n",
250 | "epoch 31 had a loss of 90.426:\n",
251 | "correct: 1589, total: 2500\n",
252 | "validation accuracy: 63.56\n",
253 | "epoch 32 had a loss of 78.93:\n",
254 | "epoch 33 had a loss of 86.228:\n",
255 | "epoch 34 had a loss of 76.29:\n",
256 | "epoch 35 had a loss of 83.203:\n",
257 | "epoch 36 had a loss of 74.013:\n",
258 | "correct: 1689, total: 2500\n",
259 | "validation accuracy: 67.56\n",
260 | "epoch 37 had a loss of 82.516:\n",
261 | "epoch 38 had a loss of 72.502:\n",
262 | "epoch 39 had a loss of 79.541:\n",
263 | "epoch 40 had a loss of 70.698:\n",
264 | "epoch 41 had a loss of 76.423:\n",
265 | "correct: 1676, total: 2500\n",
266 | "validation accuracy: 67.04\n",
267 | "epoch 42 had a loss of 66.121:\n",
268 | "epoch 43 had a loss of 74.163:\n",
269 | "epoch 44 had a loss of 65.6:\n",
270 | "epoch 45 had a loss of 72.702:\n",
271 | "epoch 46 had a loss of 63.273:\n",
272 | "correct: 1744, total: 2500\n",
273 | "validation accuracy: 69.76\n",
274 | "epoch 47 had a loss of 70.441:\n",
275 | "epoch 48 had a loss of 61.714:\n",
276 | "epoch 49 had a loss of 69.674:\n",
277 | "epoch 50 had a loss of 60.041:\n",
278 | "epoch 51 had a loss of 67.995:\n",
279 | "correct: 1672, total: 2500\n",
280 | "validation accuracy: 66.88\n",
281 | "epoch 52 had a loss of 58.486:\n",
282 | "epoch 53 had a loss of 67.985:\n",
283 | "epoch 54 had a loss of 56.784:\n",
284 | "epoch 55 had a loss of 65.629:\n",
285 | "epoch 56 had a loss of 55.778:\n",
286 | "correct: 1778, total: 2500\n",
287 | "validation accuracy: 71.12\n",
288 | "epoch 57 had a loss of 63.425:\n",
289 | "epoch 58 had a loss of 54.342:\n",
290 | "epoch 59 had a loss of 60.929:\n",
291 | "epoch 60 had a loss of 52.447:\n",
292 | "epoch 61 had a loss of 59.177:\n",
293 | "correct: 1784, total: 2500\n",
294 | "validation accuracy: 71.36\n",
295 | "epoch 62 had a loss of 50.707:\n",
296 | "epoch 63 had a loss of 59.645:\n",
297 | "epoch 64 had a loss of 48.444:\n",
298 | "epoch 65 had a loss of 58.692:\n",
299 | "epoch 66 had a loss of 47.309:\n",
300 | "correct: 1796, total: 2500\n",
301 | "validation accuracy: 71.84\n",
302 | "epoch 67 had a loss of 56.025:\n",
303 | "epoch 68 had a loss of 46.815:\n",
304 | "epoch 69 had a loss of 57.314:\n",
305 | "epoch 70 had a loss of 45.563:\n",
306 | "epoch 71 had a loss of 54.548:\n",
307 | "correct: 1807, total: 2500\n",
308 | "validation accuracy: 72.28\n",
309 | "epoch 72 had a loss of 43.947:\n",
310 | "epoch 73 had a loss of 51.221:\n",
311 | "epoch 74 had a loss of 43.44:\n",
312 | "epoch 75 had a loss of 50.027:\n",
313 | "epoch 76 had a loss of 41.055:\n",
314 | "correct: 1806, total: 2500\n",
315 | "validation accuracy: 72.24\n",
316 | "epoch 77 had a loss of 49.864:\n",
317 | "epoch 78 had a loss of 40.638:\n",
318 | "epoch 79 had a loss of 47.843:\n",
319 | "epoch 80 had a loss of 40.034:\n",
320 | "epoch 81 had a loss of 46.633:\n",
321 | "correct: 1817, total: 2500\n",
322 | "validation accuracy: 72.68\n",
323 | "epoch 82 had a loss of 38.564:\n",
324 | "epoch 83 had a loss of 45.391:\n",
325 | "epoch 84 had a loss of 38.921:\n",
326 | "epoch 85 had a loss of 44.739:\n",
327 | "epoch 86 had a loss of 36.371:\n",
328 | "correct: 1825, total: 2500\n",
329 | "validation accuracy: 73.00\n",
330 | "epoch 87 had a loss of 42.843:\n",
331 | "epoch 88 had a loss of 36.642:\n",
332 | "epoch 89 had a loss of 44.325:\n",
333 | "epoch 90 had a loss of 35.034:\n",
334 | "epoch 91 had a loss of 43.729:\n",
335 | "correct: 1835, total: 2500\n",
336 | "validation accuracy: 73.40\n",
337 | "epoch 92 had a loss of 34.83:\n",
338 | "epoch 93 had a loss of 43.054:\n",
339 | "epoch 94 had a loss of 34.262:\n",
340 | "epoch 95 had a loss of 42.408:\n",
341 | "epoch 96 had a loss of 33.244:\n",
342 | "correct: 1844, total: 2500\n",
343 | "validation accuracy: 73.76\n",
344 | "epoch 97 had a loss of 40.137:\n",
345 | "epoch 98 had a loss of 32.036:\n",
346 | "epoch 99 had a loss of 40.617:\n",
347 | "epoch 100 had a loss of 31.283:\n",
348 | "epoch 101 had a loss of 38.893:\n",
349 | "correct: 1818, total: 2500\n",
350 | "validation accuracy: 72.72\n",
351 | "epoch 102 had a loss of 30.514:\n",
352 | "epoch 103 had a loss of 36.414:\n",
353 | "epoch 104 had a loss of 29.846:\n",
354 | "epoch 105 had a loss of 35.484:\n",
355 | "epoch 106 had a loss of 28.688:\n",
356 | "correct: 1844, total: 2500\n",
357 | "validation accuracy: 73.76\n",
358 | "epoch 107 had a loss of 32.477:\n",
359 | "epoch 108 had a loss of 27.244:\n",
360 | "epoch 109 had a loss of 33.256:\n",
361 | "epoch 110 had a loss of 26.222:\n",
362 | "epoch 111 had a loss of 31.101:\n",
363 | "correct: 1846, total: 2500\n",
364 | "validation accuracy: 73.84\n",
365 | "epoch 112 had a loss of 26.721:\n",
366 | "epoch 113 had a loss of 30.268:\n",
367 | "epoch 114 had a loss of 25.618:\n",
368 | "epoch 115 had a loss of 28.931:\n",
369 | "epoch 116 had a loss of 25.284:\n",
370 | "correct: 1859, total: 2500\n",
371 | "validation accuracy: 74.36\n",
372 | "epoch 117 had a loss of 28.243:\n",
373 | "epoch 118 had a loss of 24.153:\n",
374 | "epoch 119 had a loss of 28.069:\n",
375 | "epoch 120 had a loss of 23.269:\n",
376 | "epoch 121 had a loss of 27.448:\n",
377 | "correct: 1866, total: 2500\n",
378 | "validation accuracy: 74.64\n",
379 | "epoch 122 had a loss of 22.092:\n",
380 | "epoch 123 had a loss of 25.823:\n",
381 | "epoch 124 had a loss of 22.316:\n",
382 | "epoch 125 had a loss of 26.363:\n",
383 | "epoch 126 had a loss of 22.562:\n",
384 | "correct: 1869, total: 2500\n",
385 | "validation accuracy: 74.76\n",
386 | "epoch 127 had a loss of 25.361:\n",
387 | "epoch 128 had a loss of 21.893:\n",
388 | "epoch 129 had a loss of 23.67:\n",
389 | "epoch 130 had a loss of 20.472:\n",
390 | "epoch 131 had a loss of 24.869:\n",
391 | "correct: 1872, total: 2500\n",
392 | "validation accuracy: 74.88\n",
393 | "epoch 132 had a loss of 20.084:\n",
394 | "epoch 133 had a loss of 23.902:\n",
395 | "epoch 134 had a loss of 20.17:\n",
396 | "epoch 135 had a loss of 24.072:\n",
397 | "epoch 136 had a loss of 19.705:\n",
398 | "correct: 1873, total: 2500\n",
399 | "validation accuracy: 74.92\n",
400 | "epoch 137 had a loss of 23.764:\n",
401 | "epoch 138 had a loss of 20.042:\n",
402 | "epoch 139 had a loss of 24.471:\n",
403 | "epoch 140 had a loss of 19.268:\n",
404 | "epoch 141 had a loss of 26.174:\n",
405 | "correct: 1865, total: 2500\n",
406 | "validation accuracy: 74.60\n",
407 | "epoch 142 had a loss of 19.303:\n",
408 | "epoch 143 had a loss of 26.392:\n",
409 | "epoch 144 had a loss of 20.878:\n",
410 | "epoch 145 had a loss of 29.359:\n",
411 | "epoch 146 had a loss of 19.566:\n",
412 | "correct: 1882, total: 2500\n",
413 | "validation accuracy: 75.28\n",
414 | "epoch 147 had a loss of 29.096:\n",
415 | "epoch 148 had a loss of 17.609:\n",
416 | "epoch 149 had a loss of 28.129:\n",
417 | "epoch 150 had a loss of 16.877:\n",
418 | "epoch 151 had a loss of 25.728:\n",
419 | "correct: 1887, total: 2500\n",
420 | "validation accuracy: 75.48\n",
421 | "epoch 152 had a loss of 16.436:\n",
422 | "epoch 153 had a loss of 24.018:\n",
423 | "epoch 154 had a loss of 15.716:\n",
424 | "epoch 155 had a loss of 20.792:\n",
425 | "epoch 156 had a loss of 14.76:\n",
426 | "correct: 1888, total: 2500\n",
427 | "validation accuracy: 75.52\n",
428 | "epoch 157 had a loss of 19.901:\n",
429 | "epoch 158 had a loss of 14.685:\n",
430 | "epoch 159 had a loss of 20.57:\n",
431 | "epoch 160 had a loss of 14.596:\n",
432 | "epoch 161 had a loss of 20.118:\n",
433 | "correct: 1909, total: 2500\n",
434 | "validation accuracy: 76.36\n",
435 | "epoch 162 had a loss of 14.506:\n",
436 | "epoch 163 had a loss of 18.156:\n",
437 | "epoch 164 had a loss of 14.234:\n",
438 | "epoch 165 had a loss of 18.269:\n",
439 | "epoch 166 had a loss of 14.057:\n",
440 | "correct: 1897, total: 2500\n",
441 | "validation accuracy: 75.88\n",
442 | "epoch 167 had a loss of 16.399:\n",
443 | "epoch 168 had a loss of 13.191:\n",
444 | "epoch 169 had a loss of 16.284:\n",
445 | "epoch 170 had a loss of 12.713:\n",
446 | "epoch 171 had a loss of 14.768:\n",
447 | "correct: 1897, total: 2500\n",
448 | "validation accuracy: 75.88\n",
449 | "epoch 172 had a loss of 12.85:\n",
450 | "epoch 173 had a loss of 14.618:\n",
451 | "epoch 174 had a loss of 12.6:\n",
452 | "epoch 175 had a loss of 13.754:\n",
453 | "epoch 176 had a loss of 12.1:\n",
454 | "correct: 1896, total: 2500\n",
455 | "validation accuracy: 75.84\n",
456 | "epoch 177 had a loss of 14.189:\n",
457 | "epoch 178 had a loss of 12.096:\n",
458 | "epoch 179 had a loss of 13.668:\n",
459 | "epoch 180 had a loss of 13.047:\n",
460 | "epoch 181 had a loss of 14.236:\n",
461 | "correct: 1884, total: 2500\n",
462 | "validation accuracy: 75.36\n",
463 | "epoch 182 had a loss of 12.84:\n",
464 | "epoch 183 had a loss of 12.927:\n",
465 | "epoch 184 had a loss of 12.742:\n",
466 | "epoch 185 had a loss of 12.622:\n",
467 | "epoch 186 had a loss of 13.217:\n",
468 | "correct: 1904, total: 2500\n",
469 | "validation accuracy: 76.16\n",
470 | "epoch 187 had a loss of 13.594:\n",
471 | "epoch 188 had a loss of 12.24:\n",
472 | "epoch 189 had a loss of 12.772:\n",
473 | "epoch 190 had a loss of 12.673:\n",
474 | "epoch 191 had a loss of 12.067:\n",
475 | "correct: 1835, total: 2500\n",
476 | "validation accuracy: 73.40\n"
477 | ]
478 | },
479 | {
480 | "name": "stdout",
481 | "output_type": "stream",
482 | "text": [
483 | "epoch 192 had a loss of 13.577:\n",
484 | "epoch 193 had a loss of 13.533:\n",
485 | "epoch 194 had a loss of 13.713:\n",
486 | "epoch 195 had a loss of 12.667:\n",
487 | "epoch 196 had a loss of 11.576:\n",
488 | "correct: 1915, total: 2500\n",
489 | "validation accuracy: 76.60\n",
490 | "epoch 197 had a loss of 12.638:\n",
491 | "epoch 198 had a loss of 10.731:\n",
492 | "epoch 199 had a loss of 12.506:\n",
493 | "epoch 200 had a loss of 10.383:\n",
494 | "epoch 201 had a loss of 12.581:\n",
495 | "correct: 1844, total: 2500\n",
496 | "validation accuracy: 73.76\n",
497 | "epoch 202 had a loss of 12.099:\n",
498 | "epoch 203 had a loss of 13.259:\n",
499 | "epoch 204 had a loss of 10.19:\n",
500 | "epoch 205 had a loss of 14.359:\n",
501 | "epoch 206 had a loss of 10.284:\n",
502 | "correct: 1917, total: 2500\n",
503 | "validation accuracy: 76.68\n",
504 | "epoch 207 had a loss of 14.6:\n",
505 | "epoch 208 had a loss of 10.931:\n",
506 | "epoch 209 had a loss of 12.596:\n",
507 | "epoch 210 had a loss of 10.348:\n",
508 | "epoch 211 had a loss of 11.99:\n",
509 | "correct: 1856, total: 2500\n",
510 | "validation accuracy: 74.24\n",
511 | "epoch 212 had a loss of 11.241:\n",
512 | "epoch 213 had a loss of 11.915:\n",
513 | "epoch 214 had a loss of 11.312:\n",
514 | "epoch 215 had a loss of 10.298:\n",
515 | "epoch 216 had a loss of 10.769:\n",
516 | "correct: 1921, total: 2500\n",
517 | "validation accuracy: 76.84\n",
518 | "epoch 217 had a loss of 9.7846:\n",
519 | "epoch 218 had a loss of 10.23:\n",
520 | "epoch 219 had a loss of 8.9255:\n",
521 | "epoch 220 had a loss of 9.081:\n",
522 | "epoch 221 had a loss of 8.568:\n",
523 | "correct: 1853, total: 2500\n",
524 | "validation accuracy: 74.12\n",
525 | "epoch 222 had a loss of 8.9211:\n",
526 | "epoch 223 had a loss of 8.0079:\n",
527 | "epoch 224 had a loss of 9.1798:\n",
528 | "epoch 225 had a loss of 7.9771:\n",
529 | "epoch 226 had a loss of 8.3975:\n",
530 | "correct: 1921, total: 2500\n",
531 | "validation accuracy: 76.84\n",
532 | "epoch 227 had a loss of 9.0924:\n",
533 | "epoch 228 had a loss of 8.2583:\n",
534 | "epoch 229 had a loss of 8.2822:\n",
535 | "epoch 230 had a loss of 7.976:\n",
536 | "epoch 231 had a loss of 9.0271:\n",
537 | "correct: 1854, total: 2500\n",
538 | "validation accuracy: 74.16\n",
539 | "epoch 232 had a loss of 8.3784:\n",
540 | "epoch 233 had a loss of 8.3691:\n",
541 | "epoch 234 had a loss of 8.1224:\n",
542 | "epoch 235 had a loss of 8.0464:\n",
543 | "epoch 236 had a loss of 7.1277:\n",
544 | "correct: 1926, total: 2500\n",
545 | "validation accuracy: 77.04\n",
546 | "epoch 237 had a loss of 8.9603:\n",
547 | "epoch 238 had a loss of 8.8157:\n",
548 | "epoch 239 had a loss of 9.0398:\n",
549 | "epoch 240 had a loss of 6.0725:\n",
550 | "epoch 241 had a loss of 9.0863:\n",
551 | "correct: 1908, total: 2500\n",
552 | "validation accuracy: 76.32\n",
553 | "epoch 242 had a loss of 6.1797:\n",
554 | "epoch 243 had a loss of 8.3911:\n",
555 | "epoch 244 had a loss of 6.0233:\n",
556 | "epoch 245 had a loss of 7.594:\n",
557 | "epoch 246 had a loss of 5.2482:\n",
558 | "correct: 1922, total: 2500\n",
559 | "validation accuracy: 76.88\n",
560 | "epoch 247 had a loss of 8.3338:\n",
561 | "epoch 248 had a loss of 6.1232:\n",
562 | "epoch 249 had a loss of 8.2647:\n",
563 | "epoch 250 had a loss of 6.3322:\n"
564 | ]
565 | }
566 | ],
567 | "source": [
568 | "class SingleHiddenNN(nn.Module):\n",
569 | " def __init__(self, vocab_size, max_len, embed_elems, batch_size):\n",
570 | " super(SingleHiddenNN, self).__init__()\n",
571 | " self.vocab_size = vocab_size\n",
572 | " self.embed_elems = embed_elems\n",
573 | " self.max_len = max_len\n",
574 | " self.emb = nn.Embedding(self.vocab_size+1, self.embed_elems)\n",
575 | " self.fc = nn.Linear(int(self.max_len * self.embed_elems), 100)\n",
576 | " self.relu = nn.SELU()\n",
577 | " self.dropout = nn.Dropout(0.7)\n",
578 | " self.out = nn.Linear(100, 1)\n",
579 | " self.sigmoid = nn.Sigmoid()\n",
580 | " def forward(self, input):\n",
581 | " x = self.emb(input)\n",
582 | " x = x.view(input.size(0), -1)\n",
583 | " x = self.fc(x)\n",
584 | " x = self.relu(x)\n",
585 | " x = self.dropout(x)\n",
586 | " x = self.out(x)\n",
587 | " x = self.sigmoid(x)\n",
588 | " return x.view(-1)\n",
589 | "\n",
590 | "print(ts.min(), ts.max())\n",
591 | "print(ts.size())\n",
592 | "model = SingleHiddenNN(len(vocab), 500, 32, 100)\n",
593 | "print(model)\n",
594 | "criterion = nn.BCELoss()\n",
595 | "optimizer = []\n",
596 | "optimizer += [torch.optim.Adam(model.parameters(), lr=0.0001)]\n",
597 | "optimizer += [torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)]\n",
598 | "epochs = 250\n",
599 | "for epoch in range(epochs):\n",
600 | " model.train()\n",
601 | " running_loss = 0\n",
602 | " for i, (mb, tgts) in enumerate(dlts):\n",
603 | " model.zero_grad()\n",
604 | " mb, tgts = torch.autograd.Variable(mb), torch.autograd.Variable(tgts.float())\n",
605 | " out = model(mb)\n",
606 | " loss = criterion(out, tgts)\n",
607 | " loss.backward()\n",
608 | " opt_idx = epoch % 2\n",
609 | " optimizer[opt_idx].step()\n",
610 | " running_loss += loss.data[0]\n",
611 | " print(\"epoch {} had a loss of {:.5}:\".format(epoch+1, running_loss))\n",
612 | " if epoch > 0 and epoch % 5 == 0:\n",
613 | " model.eval()\n",
614 | " correct = 0\n",
615 | " for vmb, vtgts in dlvs:\n",
616 | " vmb, vtgts = torch.autograd.Variable(vmb), torch.autograd.Variable(vtgts.float())\n",
617 | " vout = model(vmb)\n",
618 | " vpred = vout.round()\n",
619 | " correct += (vpred == vtgts).data.sum()\n",
620 | " print(\"correct: {}, total: {}\".format(correct, len(vs)))\n",
621 | " print(\"validation accuracy: {:.2f}\".format(100.*correct/len(vs)))\n"
622 | ]
623 | },
624 | {
625 | "cell_type": "code",
626 | "execution_count": 56,
627 | "metadata": {
628 | "scrolled": false
629 | },
630 | "outputs": [
631 | {
632 | "name": "stdout",
633 | "output_type": "stream",
634 | "text": [
635 | "2019 0.8076 2500\n",
636 | "\n",
637 | " 0 0\n",
638 | " 1 1\n",
639 | " 0 0\n",
640 | " 0 1\n",
641 | " 1 1\n",
642 | " 0 0\n",
643 | " 0 0\n",
644 | " 1 0\n",
645 | " 0 0\n",
646 | " 1 1\n",
647 | " 0 0\n",
648 | " 1 0\n",
649 | " 1 1\n",
650 | " 0 0\n",
651 | " 1 0\n",
652 | " 0 0\n",
653 | " 0 0\n",
654 | " 1 0\n",
655 | " 0 0\n",
656 | " 1 0\n",
657 | " 0 0\n",
658 | " 1 1\n",
659 | " 0 1\n",
660 | " 1 1\n",
661 | " 0 0\n",
662 | " 0 0\n",
663 | " 1 1\n",
664 | " 0 1\n",
665 | " 1 1\n",
666 | " 1 1\n",
667 | " 1 0\n",
668 | " 1 1\n",
669 | " 0 0\n",
670 | " 0 0\n",
671 | " 0 0\n",
672 | " 0 0\n",
673 | " 0 1\n",
674 | " 0 1\n",
675 | " 1 1\n",
676 | " 1 0\n",
677 | " 0 0\n",
678 | " 1 1\n",
679 | " 1 0\n",
680 | " 1 1\n",
681 | " 0 1\n",
682 | " 0 0\n",
683 | " 1 1\n",
684 | " 0 1\n",
685 | " 0 0\n",
686 | " 0 0\n",
687 | " 0 0\n",
688 | " 1 1\n",
689 | " 0 0\n",
690 | " 1 1\n",
691 | " 0 0\n",
692 | " 0 0\n",
693 | " 0 1\n",
694 | " 1 1\n",
695 | " 1 1\n",
696 | " 1 1\n",
697 | " 1 0\n",
698 | " 0 0\n",
699 | " 1 1\n",
700 | " 0 0\n",
701 | " 0 0\n",
702 | " 0 0\n",
703 | " 0 0\n",
704 | " 1 0\n",
705 | " 1 1\n",
706 | " 0 0\n",
707 | " 0 0\n",
708 | " 0 0\n",
709 | " 1 1\n",
710 | " 1 1\n",
711 | " 1 1\n",
712 | " 1 1\n",
713 | " 0 1\n",
714 | " 0 0\n",
715 | " 1 1\n",
716 | " 0 1\n",
717 | " 0 0\n",
718 | " 0 0\n",
719 | " 1 1\n",
720 | " 1 1\n",
721 | " 0 0\n",
722 | " 1 1\n",
723 | " 1 1\n",
724 | " 0 0\n",
725 | " 1 1\n",
726 | " 0 0\n",
727 | " 1 0\n",
728 | " 0 0\n",
729 | " 1 1\n",
730 | " 1 1\n",
731 | " 0 0\n",
732 | " 0 0\n",
733 | " 0 0\n",
734 | " 0 0\n",
735 | " 0 0\n",
736 | " 1 1\n",
737 | "[torch.FloatTensor of size 100x2]\n",
738 | "\n"
739 | ]
740 | }
741 | ],
742 | "source": [
743 | "correct = 0\n",
744 | "for i, (mb, tgts) in enumerate(dlvs):\n",
745 | " mb, tgts = torch.autograd.Variable(mb), torch.autograd.Variable(tgts.float())\n",
746 | " out = model(mb)\n",
747 | " pred = out.round()\n",
748 | " correct += (pred == tgts).data.sum()\n",
749 | "print(correct, correct / len(vs), len(vs))\n",
750 | "print(torch.stack((pred.data, tgts.data), 1))"
751 | ]
752 | },
753 | {
754 | "cell_type": "code",
755 | "execution_count": 53,
756 | "metadata": {
757 | "collapsed": true
758 | },
759 | "outputs": [],
760 | "source": [
761 | "torch.save(model.state_dict(), \"model_imdb_20170912.pt\")"
762 | ]
763 | },
764 | {
765 | "cell_type": "markdown",
766 | "metadata": {},
767 | "source": [
768 | "## torchtext"
769 | ]
770 | },
771 | {
772 | "cell_type": "code",
773 | "execution_count": null,
774 | "metadata": {
775 | "collapsed": true
776 | },
777 | "outputs": [],
778 | "source": [
779 | "import torchtext\n",
780 | "import torchtext.data as ttdata\n",
781 | "TEXT = ttdata.Field()\n",
782 | "LABEL = ttdata.Field(sequential=False)\n",
783 | "imdb_ds = torchtext.datasets.IMDB(\"/home/david/imdb_sentiment/data\", TEXT, LABEL)\n",
784 | "train_iter, test_iter = imdb_ds.iters(batch_size=4, device=-1)"
785 | ]
786 | },
787 | {
788 | "cell_type": "code",
789 | "execution_count": null,
790 | "metadata": {
791 | "collapsed": true
792 | },
793 | "outputs": [],
794 | "source": [
795 | "train_iter, test_iter = imdb_ds.iters(batch_size=25, device=-1)\n",
796 | "for x in train_iter:\n",
797 | " print(x.text, x.label.size())\n",
798 | " break\n",
799 | " "
800 | ]
801 | },
802 | {
803 | "cell_type": "code",
804 | "execution_count": 52,
805 | "metadata": {},
806 | "outputs": [
807 | {
808 | "name": "stdout",
809 | "output_type": "stream",
810 | "text": [
811 | "aae_supervised.py notes.md\r\n",
812 | "\u001b[0m\u001b[01;34malgore\u001b[0m/ numpy_reshape_test.ipynb\r\n",
813 | "AlGore_2009.sph pad_test.py\r\n",
814 | "AlGore_2009.stm \u001b[01;34mpcsnpny-20150204-mkj\u001b[0m/\r\n",
815 | "audio_rnn_basic.ipynb \u001b[01;31mpcsnpny-20150204-mkj.tgz\u001b[0m\r\n",
816 | "\u001b[01;35mclipmin.png\u001b[0m \u001b[00;36mpiano2.mp3\u001b[0m\r\n",
817 | "CNN2RNN.ipynb \u001b[00;36mpiano.mp3\u001b[0m\r\n",
818 | "collate_variable.py \u001b[00;36mpiano_new.wav\u001b[0m\r\n",
819 | "\u001b[01;34mdata\u001b[0m/ playground.ipynb\r\n",
820 | "\u001b[01;31mdata.zip\u001b[0m predict_audio.ipynb\r\n",
821 | "deepspeech1d.ipynb Presentation.ipynb\r\n",
822 | "denoising_autoencoder.ipynb prime_factors.py\r\n",
823 | "extract_mnist.py pyaudio-test.py\r\n",
824 | "\u001b[00;36mfile2.wav\u001b[0m \u001b[01;34m__pycache__\u001b[0m/\r\n",
825 | "\u001b[00;36mfile.flac\u001b[0m pytorch_basics.ipynb\r\n",
826 | "\u001b[00;36mfile.mp3\u001b[0m PyTorch Embeddings Test.ipynb\r\n",
827 | "\u001b[00;36mfile.wav\u001b[0m pytorch_tutorial_classify_names.ipynb\r\n",
828 | "\u001b[01;34mfrancemusique\u001b[0m/ rnn_autoencoder.ipynb\r\n",
829 | "G729VAD.ipynb \u001b[01;35mrnn_beispiel1.png\u001b[0m\r\n",
830 | "GRUAutoencoder.ipynb \u001b[01;35mrnn_beispiel2.png\u001b[0m\r\n",
831 | "\u001b[00;36mhallöchen.wav\u001b[0m \u001b[01;35mrnn_predictions_final.png\u001b[0m\r\n",
832 | "\u001b[01;32mhelloworld.py\u001b[0m* \u001b[01;35mrnn_predictions.png\u001b[0m\r\n",
833 | "imdb_tokenize.ipynb \u001b[01;35msmoother1.png\u001b[0m\r\n",
834 | "\u001b[01;35mkmeans1.png\u001b[0m \u001b[01;34msnipsdata\u001b[0m/\r\n",
835 | "\u001b[01;35mkmeans2.png\u001b[0m Snips_Report.ipynb\r\n",
836 | "label_smoothing.ipynb \u001b[01;34mstarters\u001b[0m/\r\n",
837 | "levinson.py startup.sh\r\n",
838 | "librispeech_load_txts.ipynb stm_sph_processing.ipynb\r\n",
839 | "librispeech.py \u001b[01;35mtedlium1.png\u001b[0m\r\n",
840 | "loader.py \u001b[01;35mtedlium2.png\u001b[0m\r\n",
841 | "\u001b[01;34mmini_librispeech_dataset\u001b[0m/ tedlium_labels.ipynb\r\n",
842 | "\u001b[01;35mmlp_beispiel_noise1.png\u001b[0m test_argparse.py\r\n",
843 | "\u001b[01;35mmlp_beispiel_noise2.png\u001b[0m test_nltk.py\r\n",
844 | "\u001b[01;35mmlp_beispiel_nonoise1.png\u001b[0m test_spacy.py\r\n",
845 | "\u001b[01;35mmlp_beispiel_nonoise2.png\u001b[0m timeseries.ipynb\r\n",
846 | "\u001b[01;35mmlp_predictions_noise.png\u001b[0m torchaudio\r\n",
847 | "\u001b[01;35mmlp_predictions.png\u001b[0m torchaudio_vs_librosa.ipynb\r\n",
848 | "\u001b[01;34mMNIST\u001b[0m/ VAD_labeling.ipynb\r\n",
849 | "model_20170912.pt wavelet.ipynb\r\n",
850 | "\u001b[01;34mmodels\u001b[0m/ yesno_torchaudio_playground.ipynb\r\n",
851 | "mu_law_companding.ipynb zero2d_test.py\r\n"
852 | ]
853 | }
854 | ],
855 | "source": [
856 | "%ls"
857 | ]
858 | }
859 | ],
860 | "metadata": {
861 | "kernelspec": {
862 | "display_name": "Python 3",
863 | "language": "python",
864 | "name": "python3"
865 | },
866 | "language_info": {
867 | "codemirror_mode": {
868 | "name": "ipython",
869 | "version": 3
870 | },
871 | "file_extension": ".py",
872 | "mimetype": "text/x-python",
873 | "name": "python",
874 | "nbconvert_exporter": "python",
875 | "pygments_lexer": "ipython3",
876 | "version": "3.6.1"
877 | }
878 | },
879 | "nbformat": 4,
880 | "nbformat_minor": 2
881 | }
882 |
--------------------------------------------------------------------------------
/pytorch_attention_audio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from torch.optim import lr_scheduler
7 | import torch.utils.data as data
8 | from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack
9 | import torchaudio
10 | import torchaudio.transforms as tat
11 | import numpy as np
12 | import os
13 | import glob
14 | import matplotlib
15 | matplotlib.use('Agg')
16 | import matplotlib.pyplot as plt
17 |
18 | from pytorch_audio_utils import *
19 |
parser = argparse.ArgumentParser(description='PyTorch Language ID Classifier Trainer')
parser.add_argument('--epochs', type=int, default=5,
                    help='upper epoch limit')
parser.add_argument('--batch-size', type=int, default=6,
                    help='batch size')
# the window size is the frame length fed to the feature extractor; the fft
# length is configured separately (n_fft) and the hop is half a window
parser.add_argument('--window-size', type=int, default=200,
                    help='analysis window (frame) size in samples; the hop is half of it')
# NOTE(review): --validate, --load-model, --save-model and --train-full-model
# are parsed but never read anywhere in this script.
parser.add_argument('--validate', action='store_true',
                    help='do out-of-bag validation')
# used as "print the loss every N batches" by the training loop
parser.add_argument('--log-interval', type=int, default=5,
                    help='print the batch loss every N batches')
parser.add_argument('--load-model', type=str, default=None,
                    help='path of model to load')
# store_true flag, not a path
parser.add_argument('--save-model', action='store_true',
                    help='save the final model')
parser.add_argument('--train-full-model', action='store_true',
                    help='train full model vs. final layer')
args = parser.parse_args()
38 |
class Preemphasis(object):
    """First-order pre-emphasis filter for audio tensors.

    Every sample after the first is replaced by

        y[n] = x[n] - alpha * x[n-1]

    The tensor passed in is modified in place and returned.

    Args:
        alpha (float): pre-emphasis coefficient; ``0`` leaves the signal untouched.
    """

    def __init__(self, alpha=0.97):
        self.alpha = alpha

    def __call__(self, sig):
        """Filter ``sig`` along its first (time) axis.

        Args:
            sig (Tensor): audio of size (Samples x Channels).

        Returns:
            Tensor: the same tensor object, pre-emphasized in place.
        """
        if self.alpha != 0:
            # the scaled delayed copy is a new tensor, materialised before the
            # in-place subtraction, so the overlapping slices do not interfere
            delayed = self.alpha * sig[:-1, :]
            sig[1:, :] -= delayed
        return sig
67 |
class RfftPow(object):
    """Periodogram of a real signal computed with an explicit DFT matrix.

    Emulates ``|rfft(x, n=K)|**2 / K`` per channel: the signal is projected
    onto cosine and sine bases of ``K`` frequencies (equivalent to zero padding
    to length ``K`` when the signal is shorter) and only the ``K//2 + 1``
    non-negative frequencies are kept.

    Args:
        K (int, optional): number of fft freq bands. Defaults to the number of
            samples in the signal.
    """

    def __init__(self, K=None):
        self.K = K

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): audio of size (Channels x Samples); Sig2Features
                passes each frame transposed into this layout.

        Returns:
            S (Tensor): periodogram of size (Channels x K//2+1); the channel
                axis is squeezed away for mono input, giving (K//2+1,).
        """
        N = sig.size(1)
        K = N if self.K is None else self.K

        k_vec = torch.arange(0, K).unsqueeze(0)  # 1 x K  (frequency index)
        n_vec = torch.arange(0, N).unsqueeze(1)  # N x 1  (sample index)
        # Reduce k*n modulo K while it is still an exact integer product, so
        # the phase stays in [0, 2*pi) instead of growing to ~2*pi*N, which
        # costs precision in single-precision floats.
        phase_idx = (k_vec * n_vec) % K
        angular_pt = phase_idx.type_as(sig) * (2 * np.pi / K)
        re = torch.matmul(sig, angular_pt.cos())
        im = torch.matmul(sig, angular_pt.sin())
        # |X_k|^2 directly (the former sqrt(...)**2 round trip was redundant)
        S = (re ** 2 + im ** 2) / K
        # slice along the frequency axis, then drop the channel axis for mono
        return S[..., :(K // 2 + 1)].squeeze()
105 |
class FilterBanks(object):
    """Bins a periodogram from K fft frequency bands into triangular filterbanks.

    fft bands (K//2+1) -> filterbanks (n_filterbanks), returned on a
    ``20 * log10`` scale.

    Args:
        n_filterbanks (int): number of triangular filters
        bins (sequence): ``n_filterbanks + 2`` fft-bin indices; filter ``m``
            (1-based) rises from ``bins[m-1]`` to a peak of 1 at ``bins[m]``
            and falls back towards ``bins[m+1]``.
    """

    def __init__(self, n_filterbanks, bins):
        self.n_filterbanks = n_filterbanks
        self.bins = bins
        # the weight matrix depends only on the periodogram length, and this
        # transform runs once per frame, so keep built matrices around
        self._fb_cache = {}

    def _filter_matrix(self, K):
        """Return the (n_filterbanks x K) triangular weight matrix, building it once per K."""
        fb_mat = self._fb_cache.get(K)
        if fb_mat is None:
            fb_mat = torch.zeros((self.n_filterbanks, K))
            for m in range(1, self.n_filterbanks + 1):
                f_m_minus = int(self.bins[m - 1])
                f_m = int(self.bins[m])
                f_m_plus = int(self.bins[m + 1])
                # .float() guarantees true division: an integer arange divided
                # by an int floors to 0 on PyTorch versions where arange is
                # integer-typed and '/' keeps integer semantics
                rising = torch.arange(f_m_minus, f_m).float()
                falling = torch.arange(f_m, f_m_plus).float()
                fb_mat[m - 1, f_m_minus:f_m] = (rising - f_m_minus) / (f_m - f_m_minus)
                fb_mat[m - 1, f_m:f_m_plus] = (f_m_plus - falling) / (f_m_plus - f_m)
            self._fb_cache[K] = fb_mat
        return fb_mat

    def __call__(self, S):
        """

        Args:
            S (Tensor): Tensor of Spectro- / Periodogram (length K along dim 0)

        Returns:
            fb (Tensor): filterbank energies as 20 * log10(energy)

        """
        conversion_factor = np.log(10)  # log10(x) == log(x) / log(10)
        fb_mat = self._filter_matrix(S.size(0))
        fb = torch.matmul(S, fb_mat.t().type_as(S))
        fb = 20 * torch.log(fb) / conversion_factor
        return fb
144 |
class MFCC(object):
    """Discrete Cosine Transform

    There are three types of the DCT. This is 'Type 2' as described in the
    scipy docs: ``y[k] = 2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``,
    optionally with orthonormal scaling. Coefficient 0 is dropped and
    coefficients ``1 .. n_coeffs`` are returned.

    filterbank bins (n_filterbanks) -> mfcc (n_coeffs)

    Args:
        n_filterbanks (int): number of filterbanks (DCT length N)
        n_coeffs (int): number of mfc coefficients to keep
        mode (str or None): "ortho" applies orthonormal scaling (scipy's
            ``norm="ortho"``); any other value returns the unscaled DCT

    """

    def __init__(self, n_filterbanks, n_coeffs, mode="ortho"):
        self.n_filterbanks = n_filterbanks
        self.n_coeffs = n_coeffs
        # previously hard-coded to "ortho", silently ignoring the argument
        self.mode = mode

    def __call__(self, fb):
        """

        Args:
            fb (Tensor): 1-D tensor of n_filterbanks filterbank values

        Returns:
            mfcc (Tensor): Tensor of n_coeffs mfcc coefficients (1..n_coeffs)

        """
        K = self.n_filterbanks
        # float index vectors so (2n+1)/(2K) is never integer-floored
        k_vec = torch.arange(0, K).type_as(fb).unsqueeze(0)  # 1 x K (output index k)
        n_vec = torch.arange(0, K).type_as(fb).unsqueeze(1)  # K x 1 (input index n)
        angular_pt = np.pi * k_vec * ((2 * n_vec + 1) / (2 * K))
        mfcc = 2 * torch.matmul(fb, angular_pt.cos())
        if self.mode == "ortho":
            mfcc[0] *= np.sqrt(1 / (4 * K))
            mfcc[1:] *= np.sqrt(1 / (2 * K))
        return mfcc[1:(self.n_coeffs + 1)]
183 |
class Sig2Features(object):
    """Frame a signal and stack log power, MFCCs and their first differences.

    Frames of ``ws`` samples start every ``hs`` samples. Each frame goes
    through ``transformDict["RfftPow"]`` -> ``["FilterBanks"]`` -> ``["MFCC"]``.
    Per-frame results are concatenated along dim 0 as

        [log total power (1), its delta (1), mfcc (C), mfcc delta (C)]

    where C is the number of coefficients returned by the "MFCC" transform.
    Deltas are backward differences ``x[t] - x[t-1]`` with 0 for the first frame.

    Args:
        ws (int): window (frame) size in samples
        hs (int): hop size in samples
        transformDict (dict): callables stored under "RfftPow", "FilterBanks"
            and "MFCC"

    """

    def __init__(self, ws, hs, transformDict):
        self.ws = ws
        self.hs = hs
        self.td = transformDict

    @staticmethod
    def _delta(x):
        """Backward difference along the last axis; the first column is zero."""
        d = torch.zeros_like(x)
        d[..., 1:] = x[..., 1:] - x[..., :-1]
        return d

    def __call__(self, sig):
        """

        Args:
            sig (Tensor): Tensor of signal, size (Samples x Channels)

        Returns:
            Feats (Tensor): (2 + 2*C) x n_frames tensor of log-power, its delta,
                the mfcc coefficients and their deltas

        Raises:
            ValueError: if the signal is shorter than one window

        """
        ws, hs = self.ws, self.hs
        # number of complete windows that fit into the signal
        n_hops = (sig.size(0) - ws) // hs + 1
        if n_hops < 1:
            raise ValueError("signal of {} samples is shorter than the window size {}".format(sig.size(0), ws))

        P = []
        Mfcc = []
        for i in range(n_hops):
            # create frame
            st = int(i * hs)
            sig_n = sig[st:st + ws]

            # power spectrum; the RfftPow transform expects (Channels x Samples)
            P.append(self.td["RfftPow"](sig_n.transpose(0, 1)))

            # filter banks, then mfccs
            fb = self.td["FilterBanks"](P[-1])
            Mfcc.append(self.td["MFCC"](fb))

        # concat and calculate derivatives (same convention for both features)
        P = torch.stack(P, 1)
        P_sum = torch.log(P.sum(0))
        Mfcc = torch.stack(Mfcc, 1)
        Feats = torch.cat((P_sum.unsqueeze(0), self._delta(P_sum).unsqueeze(0),
                           Mfcc, self._delta(Mfcc)), 0)
        return Feats
236 |
class Labeler(object):
    """Convert textual labels into a LongTensor of class indices shifted by one.

    Each label is parsed with ``int`` and incremented, e.g. ``"0" -> 1``.
    """

    @staticmethod
    def _shift(label):
        # parse one label and move it up by one
        return int(label) + 1

    def __call__(self, labels):
        return torch.LongTensor(list(map(self._shift, labels)))
244 |
def pad_packed_collate(batch):
    """Collate variable-length samples into a packed_padded_sequence and labels.

    Samples are sorted longest first (as pack_padded_sequence requires),
    transposed to (time x features), zero-padded to the longest length,
    stacked batch-first and packed. The input tensors are never modified.

    Args:
        batch: (list of tuples) [(audio, target)].
            audio is a FloatTensor of size (features x time)
            target is a LongTensor; all targets must share one size
    Output:
        packed_batch: (PackedSequence), see torch.nn.utils.rnn.pack_padded_sequence
        labels: (Tensor), the targets stacked along a new leading batch dim,
            in the same (sorted) order as the packed sequences

    Raises:
        ValueError: if ``batch`` is empty.

    """
    if len(batch) == 0:
        raise ValueError("pad_packed_collate received an empty batch")

    ordered = sorted(batch, key=lambda x: x[0].size(1), reverse=True)
    sigs = [a.t() for a, _ in ordered]          # each: time x features
    labels = [b for _, b in ordered]
    lengths = [s.size(0) for s in sigs]

    max_len, n_feats = sigs[0].size()
    # pad by concatenation into new tensors (no in-place ops on dataset items)
    sigs = [torch.cat((s, s.new(max_len - s.size(0), n_feats).zero_()), 0)
            if s.size(0) != max_len else s for s in sigs]
    sigs = torch.stack(sigs, 0)
    labels = torch.stack(labels, 0)
    packed_batch = pack(Variable(sigs), lengths, batch_first=True)
    return packed_batch, labels
274 |
def unpack_lengths(batch_sizes):
    """Recover per-sequence lengths from a PackedSequence's ``batch_sizes``.

    Logic taken directly from the older pad_packed_sequence(). For sequences
    packed longest-first, the result lists lengths in that same decreasing
    order.

    Args:
        batch_sizes: iterable of ints (a list, or a 1-D integer tensor) giving
            the number of active sequences at each time step.

    Returns:
        list of int: one length per sequence; ``[]`` for empty input.
    """
    # plain ints so tuple/list repetition works for tensor inputs as well
    sizes = [int(b) for b in batch_sizes]
    if not sizes:
        return []

    lengths = []
    prev_batch_size = sizes[0]
    for i, batch_size in enumerate(sizes):
        dec = prev_batch_size - batch_size
        if dec > 0:
            # `dec` sequences ended after exactly i steps
            lengths.extend([i] * dec)
        prev_batch_size = batch_size
    # sequences still active at the last step span every step
    lengths.extend([len(sizes)] * sizes[-1])
    lengths.reverse()
    return lengths
289 |
class EncoderRNN2(nn.Module):
    """Unidirectional, batch-first GRU encoder.

    Args:
        input_size (int): features per time step
        hidden_size (int): GRU hidden size
        n_layers (int): number of stacked GRU layers
        batch_size (int): default batch size used by ``initHidden``
    """

    def __init__(self, input_size, hidden_size, n_layers=1, batch_size=1):
        super(EncoderRNN2, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.batch_size = batch_size

        self.gru = nn.GRU(input_size, hidden_size, n_layers, batch_first=True)

    def forward(self, input, hidden):
        """Run the GRU over a batch.

        Args:
            input: (B x L x input_size) batch-first sequence
            hidden: (n_layers x B x hidden_size) initial hidden state

        Returns:
            output (B x L x hidden_size), hidden (n_layers x B x hidden_size)
        """
        output, hidden = self.gru(input, hidden)
        return output, hidden

    def initHidden(self, ttype=None, batch_size=None, cuda=None):
        """Create an all-zero initial hidden state.

        Args:
            ttype: tensor constructor to allocate with; defaults to torch.FloatTensor
            batch_size (int, optional): overrides ``self.batch_size``, e.g. for a
                final batch that is smaller than the rest
            cuda (bool, optional): move the state to the GPU. When omitted, the
                module-level ``use_cuda`` flag is used if the defining module
                sets one, otherwise False.

        Returns:
            Variable of size (n_layers x batch_size x hidden_size)
        """
        if ttype is None:
            ttype = torch.FloatTensor
        if batch_size is None:
            batch_size = self.batch_size
        if cuda is None:
            # keeps the script's behaviour without a hard NameError elsewhere
            cuda = globals().get("use_cuda", False)
        # first dim is n_layers * num_directions (unidirectional: 1)
        result = Variable(ttype(self.n_layers * 1, batch_size, self.hidden_size).fill_(0))
        return result.cuda() if cuda else result
313 |
class Attn(nn.Module):
    """Luong-style attention weights of decoder states over encoder states.

    Args:
        hidden_size (int): feature size N shared by decoder and encoder states
        batch_size (int): only used to shape ``v`` for "concat", which therefore
            requires inputs with exactly this batch size
        method (str): "dot" (default), "general" or "concat"; any other value
            falls back to "dot"
    """

    def __init__(self, hidden_size, batch_size=1, method="dot"):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size
        self.batch_size = batch_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size, bias=False)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size, bias=False)
            self.v = nn.Parameter(torch.FloatTensor(batch_size, 1, hidden_size))
            # torch.FloatTensor(sizes) is uninitialised memory; draw v from
            # U(-1/sqrt(N), 1/sqrt(N)), the bound nn.Linear uses for weights
            stdv = 1.0 / hidden_size ** 0.5
            self.v.data.uniform_(-stdv, stdv)

    def forward(self, hidden, encoder_outputs):
        """
        Args:
            hidden: (B x S x N) decoder states
            encoder_outputs: (B x L x N) encoder states

        Returns:
            (B x S x L) weights; each row sums to 1 over the L encoder steps
        """
        attn_energies = self.score(hidden, encoder_outputs)
        # Normalise over encoder positions (last axis). Without an explicit
        # dim, F.softmax falls back to dim 0 for 3-D input: the batch axis.
        return F.softmax(attn_energies, dim=-1)

    def score(self, hidden, encoder_output):
        """Unnormalised scores of size (B x S x L)."""
        if self.method == 'general':
            energy = self.attn(encoder_output).transpose(2, 1)  # B x N x L
            return hidden.bmm(energy)

        if self.method == 'concat':
            B, S, N = hidden.size()
            L = encoder_output.size(1)
            # pair every decoder step with every encoder step
            h = hidden.unsqueeze(2).expand(B, S, L, N)
            e = encoder_output.unsqueeze(1).expand(B, S, L, N)
            # Luong concat score: v^T tanh(W [h; e]); without the tanh the
            # decoder term is constant over L and cancels in the softmax
            energy = torch.tanh(self.attn(torch.cat((h, e), -1)))  # B x S x L x N
            energy = energy.contiguous().view(B, S * L, N).transpose(2, 1)  # B x N x S*L
            return self.v.bmm(energy).view(B, S, L)

        # 'dot'
        return hidden.bmm(encoder_output.transpose(2, 1))
357 |
class LuongAttnDecoderRNN(nn.Module):
    """GRU decoder with Luong-style global attention over encoder outputs.

    Args:
        hidden_size (int): GRU hidden size N (must match the encoder outputs)
        output_size (int): number of output classes per step
        attn_model (str): "dot", "general", "concat", or "none" to decode
            without attention
        n_layers (int): number of GRU layers
        dropout (float): dropout between GRU layers
        batch_size (int): forwarded to Attn (used to shape its "concat" vector)
    """

    def __init__(self, hidden_size, output_size, attn_model="dot", n_layers=1, dropout=0.1, batch_size=1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.batch_size = batch_size

        # Define layers
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        # Choose attention model
        if attn_model != 'none':
            self.attn = Attn(hidden_size, method=attn_model, batch_size=batch_size)

    def forward(self, input_seq, last_hidden, encoder_outputs):
        """Decode all output steps in one call (originally run one step at a time).

        Shapes: B = batch size, S = output length, L = encoder length,
        N = hidden features.

        Args:
            input_seq: (B x S x N) decoder inputs
            last_hidden: (n_layers x B x N) initial hidden state
            encoder_outputs: (B x L x N) encoder states

        Returns:
            output (B x S x output_size) unnormalised scores (Luong eq. 6
            without softmax), the GRU hidden state, and the attention weights
            (B x S x L), or None when attn_model is "none"
        """
        # GRU over the decoder inputs, seeded with the previous hidden state
        rnn_output, hidden = self.gru(input_seq, last_hidden)

        if self.attn_model == 'none':
            # no attention layer was built: project the GRU states directly
            return self.out(rnn_output), hidden, None

        # attention from the RNN states over all encoder outputs, applied to
        # the encoder outputs to get a weighted average per decoder step
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs)  # [B, S, L] dot [B, L, N] -> [B, S, N]

        # Attentional vector from RNN state and context (Luong eq. 5)
        concat_input = torch.cat((rnn_output, context), -1)  # B x S x 2*N
        concat_output = torch.tanh(self.concat(concat_input))

        # Finally predict next token (Luong eq. 6, without softmax)
        output = self.out(concat_output)

        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden, attn_weights
410 |
# ---- training parameters --------------------------------------------------
epochs = args.epochs

# ---- dataset / feature-extraction parameters ------------------------------
DATADIR = "/home/david/Programming/data"
sr = 8000                  # sample rate (Hz) assumed when mapping Hz to fft bins
ws = args.window_size      # analysis window (frame) length in samples
hs = ws // 2               # hop of half a window
n_fft = 512                # DFT length (an earlier setting was 256)
n_filterbanks = 26
n_coefficients = 12

# Triangular filter edges: n_filterbanks + 2 points equally spaced on the
# mel scale between 0 and sr/2, mapped back to Hz and then to fft-bin indices.
low_mel_freq = 0
high_freq_mel = 2595 * np.log10(1 + (sr / 2) / 700)
mel_pts = np.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2)
hz_pts = np.floor(700 * (10 ** (mel_pts / 2595) - 1))
bins = np.floor((n_fft + 1) * hz_pts / sr)

# per-frame transforms looked up by name inside Sig2Features
td = dict(
    RfftPow=RfftPow(n_fft),
    FilterBanks=FilterBanks(n_filterbanks, bins),
    MFCC=MFCC(n_filterbanks, n_coefficients),
)

# raw audio -> (features x frames)
transforms = tat.Compose([
    tat.Scale(),
    tat.PadTrim(58000, fill_value=1e-8),  # presumably pads/trims every clip to 58000 samples
    Preemphasis(),
    Sig2Features(ws, hs, td),
])
441 |
# set network parameters
use_cuda = torch.cuda.is_available()
batch_size = args.batch_size
# Sig2Features stacks log-power, its delta, the MFCCs and the MFCC deltas, so
# the encoder input width follows from the coefficient count (26 for 12)
input_features = 2 + 2 * n_coefficients
hidden_size = 100
# NOTE(review): Labeler maps labels to int+1; presumably class 0 is kept free
# (e.g. for a blank) -- confirm
output_size = 3
#output_length = (8 + 7 + 2) # with "blanks"
output_length = 8 # without blanks
n_layers = 1
attn_modus = "dot"

# build networks, criterion, optimizers, dataset and dataloader
encoder2 = EncoderRNN2(input_features, hidden_size, n_layers=n_layers, batch_size=batch_size)
decoder2 = LuongAttnDecoderRNN(hidden_size, output_size, n_layers=n_layers, attn_model=attn_modus, batch_size=batch_size)
print(encoder2)
print(decoder2)
criterion = nn.CrossEntropyLoss()
# encoder uses the default lr=0.001, the decoder a 10x smaller lr
optimizer = torch.optim.RMSprop([
    {"params": encoder2.parameters()},
    {"params": decoder2.parameters(), "lr": 0.0001}
], lr=0.001, momentum=0.9)
# multiply every group's lr by 0.6 every 80 scheduler steps
scheduler = lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.6)
ds = torchaudio.datasets.YESNO(DATADIR, transform=transforms, target_transform=Labeler())
# reshuffle every epoch so the same fixed batches are not replayed each epoch
dl = data.DataLoader(ds, batch_size=batch_size, shuffle=True)
466 |
if use_cuda:
    print("using CUDA: {}".format(use_cuda))
    # move both networks' parameters to the GPU
    encoder2, decoder2 = encoder2.cuda(), decoder2.cuda()
471 |
loss_total = []
# begin training
for epoch in range(epochs):
    # stepped at the start of each epoch (PyTorch < 1.1 convention; from 1.1
    # on, schedulers are meant to be stepped after the epoch's optimizer steps)
    scheduler.step()
    print("epoch {}".format(epoch+1))
    running_loss = 0
    loss_epoch = []
    for i, (mb, tgts) in enumerate(dl):
        # set model into train mode and clear gradients
        encoder2.train()
        decoder2.train()
        encoder2.zero_grad()
        decoder2.zero_grad()

        # set inputs and targets
        mb = mb.transpose(2, 1) # [B x N x L] -> [B, L, N]
        if use_cuda:
            mb, tgts = mb.cuda(), tgts.cuda()
        mb, tgts = Variable(mb), Variable(tgts)
        # the final batch can be smaller than batch_size, so size the
        # recurrent state and decoder inputs from the data itself
        cur_batch_size = mb.size(0)

        # zero initial encoder state (n_layers x B x hidden_size), allocated
        # with the same tensor type/device as the inputs
        encoder2_hidden = Variable(mb.data.new(n_layers, cur_batch_size, hidden_size).zero_())
        encoder2_output, encoder2_hidden = encoder2(mb, encoder2_hidden)

        # decoder inputs are all zeros (B x output_length x hidden_size);
        # the decoder starts from the encoder's last hidden state
        dec_i = Variable(encoder2_output.data.new(cur_batch_size, output_length, hidden_size).zero_())
        dec_h = encoder2_hidden

        # run through decoder in one shot (all output_length steps at once)
        dec_o, dec_h, dec_attn = decoder2(dec_i, dec_h, encoder2_output)

        # calculate loss and backprop
        loss = criterion(dec_o.view(-1, output_size), tgts.view(-1))
        # PyTorch 0.3 yields a 1-element loss tensor, later versions a 0-d one
        loss_val = loss.data[0] if loss.data.dim() > 0 else loss.data.item()
        running_loss += loss_val
        loss_epoch += [loss_val]
        loss.backward()
        #nn.utils.clip_grad_norm(encoder2.parameters(), 0.05)
        #nn.utils.clip_grad_norm(decoder2.parameters(), 0.05)
        optimizer.step()

        # logging stuff
        if (i % args.log_interval == 0 and i != 0) or epoch == 0:
            print(loss_val)
    loss_total += [loss_epoch]
    # per-sample fraction of correctly predicted target steps in the last batch
    print((dec_o.max(2)[1].data == tgts.data).float().sum(1) / tgts.size(1))
    print("ave loss of {} at epoch {}".format(running_loss / (i+1), epoch+1))
538 |
# ---- loss curve: mean batch loss of every epoch -----------------------------
loss_total = np.array(loss_total)  # epochs x batches
plt.figure()
plt.plot(loss_total.mean(1))
plt.savefig("pytorch_attention_audio-loss.png")

# ---- attention map of the first sample of the final batch ------------------
# .cpu() returns the tensor itself when it already lives on the CPU
attn_plot = dec_attn[0, :, :].data.cpu().numpy()
fig = plt.figure(figsize=(20, 6))
ax = fig.add_subplot(111)
cax = ax.matshow(attn_plot, cmap='bone', aspect="auto")
fig.colorbar(cax)
fig.savefig("pytorch_attention_audio-attention.png")
552 |
--------------------------------------------------------------------------------
/pytorch_basics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import torch\n",
13 | "import torchaudio\n",
14 | "import torch.nn as nn\n",
15 | "import torch.nn.functional as F\n",
16 | "from torch.autograd import Variable\n",
17 | "import os\n",
18 | "\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "from IPython.display import display, Audio\n",
21 | "%matplotlib notebook"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 4,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "def find_files(dir): # full paths of entries directly inside dir for which os.path.isfile is true\n",
31 | " return list(filter(os.path.isfile, (os.path.join(dir, name) for name in os.listdir(dir))))\n",
32 | "\n",
33 | "files = find_files(\"data/waves_yesno\")"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 9,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "data": {
43 | "text/plain": [
44 | "torch.Size([1, 1, 1000])"
45 | ]
46 | },
47 | "metadata": {},
48 | "output_type": "display_data"
49 | },
50 | {
51 | "data": {
52 | "text/plain": [
53 | "[(39, -0.648193359375),\n",
54 | " (53, 0.966796875),\n",
55 | " (53, 0.966796875),\n",
56 | " (104, -0.9051513671875),\n",
57 | " (104, -0.9051513671875),\n",
58 | " (104, -0.9051513671875),\n",
59 | " (118, 1.31011962890625),\n",
60 | " (118, 1.31011962890625),\n",
61 | " (168, -1.84478759765625),\n",
62 | " (168, -1.84478759765625),\n",
63 | " (168, -1.84478759765625),\n",
64 | " (168, -1.84478759765625),\n",
65 | " (178, 1.91741943359375),\n",
66 | " (178, 1.91741943359375),\n",
67 | " (178, 1.91741943359375),\n",
68 | " (178, 1.91741943359375),\n",
69 | " (226, -1.46392822265625),\n",
70 | " (226, -1.46392822265625),\n",
71 | " (239, 1.43524169921875),\n",
72 | " (239, 1.43524169921875),\n",
73 | " (239, 1.43524169921875),\n",
74 | " (282, -0.94512939453125),\n",
75 | " (343, -1.25518798828125),\n",
76 | " (343, -1.25518798828125),\n",
77 | " (343, -1.25518798828125),\n",
78 | " (361, 1.47003173828125),\n",
79 | " (378, 1.640625),\n",
80 | " (378, 1.640625),\n",
81 | " (414, -1.5948486328125),\n",
82 | " (414, -1.5948486328125),\n",
83 | " (423, 2.0025634765625),\n",
84 | " (423, 2.0025634765625),\n",
85 | " (423, 2.0025634765625),\n",
86 | " (423, 2.0025634765625),\n",
87 | " (476, -1.59820556640625),\n",
88 | " (485, 1.990966796875),\n",
89 | " (508, -1.619873046875),\n",
90 | " (508, -1.619873046875),\n",
91 | " (538, -2.08831787109375),\n",
92 | " (538, -2.08831787109375),\n",
93 | " (546, 2.00653076171875),\n",
94 | " (546, 2.00653076171875),\n",
95 | " (599, -2.55279541015625),\n",
96 | " (599, -2.55279541015625),\n",
97 | " (599, -2.55279541015625),\n",
98 | " (599, -2.55279541015625),\n",
99 | " (608, 2.11395263671875),\n",
100 | " (622, 2.373046875),\n",
101 | " (622, 2.373046875),\n",
102 | " (661, -2.41485595703125),\n",
103 | " (684, 2.3974609375),\n",
104 | " (684, 2.3974609375),\n",
105 | " (722, -2.6776123046875),\n",
106 | " (722, -2.6776123046875),\n",
107 | " (730, 2.626953125),\n",
108 | " (745, 2.7105712890625),\n",
109 | " (745, 2.7105712890625),\n",
110 | " (784, -2.9150390625),\n",
111 | " (784, -2.9150390625),\n",
112 | " (784, -2.9150390625),\n",
113 | " (784, -2.9150390625),\n",
114 | " (807, 2.7520751953125),\n",
115 | " (807, 2.7520751953125),\n",
116 | " (807, 2.7520751953125),\n",
117 | " (807, 2.7520751953125),\n",
118 | " (845, -2.86590576171875),\n",
119 | " (845, -2.86590576171875),\n",
120 | " (845, -2.86590576171875),\n",
121 | " (853, 2.7276611328125),\n",
122 | " (906, -2.56103515625),\n",
123 | " (914, 2.86346435546875),\n",
124 | " (914, 2.86346435546875),\n",
125 | " (914, 2.86346435546875),\n",
126 | " (914, 2.86346435546875),\n",
127 | " (967, -2.66265869140625),\n",
128 | " (967, -2.66265869140625)]"
129 | ]
130 | },
131 | "metadata": {},
132 | "output_type": "display_data"
133 | }
134 | ],
135 | "source": [
136 | "sig, sr = torchaudio.load(files[2])\n",
137 | "sig = torchaudio.transforms.Scale()(sig)\n",
138 | "\n",
139 | "X = sig[7900:8900].transpose(0, 1) * 10\n",
140 | "X.unsqueeze_(0)\n",
141 | "#X.unsqueeze_(0)\n",
142 | "display(X.size())\n",
143 | "\n",
144 | "x_max, idx_max = torch.nn.functional.max_pool1d(X, 100, 25, 13, return_indices=True)\n",
145 | "x_min, idx_min = torch.nn.functional.max_pool1d(-X, 100, 25, 13, return_indices=True)\n",
146 | "x_max.squeeze(), idx_max.data.numpy(), -x_min.squeeze(), idx_min.data.numpy()\n",
147 | "max_list = list(zip(idx_max.data.squeeze().numpy(), x_max.data.squeeze()))\n",
148 | "min_list = list(zip(idx_min.data.squeeze().numpy(), -x_min.data.squeeze()))\n",
149 | "combined_list = sorted(max_list + min_list)\n",
150 | "display(combined_list)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 83,
156 | "metadata": {
157 | "scrolled": false
158 | },
159 | "outputs": [
160 | {
161 | "data": {
162 | "text/plain": [
163 | "\n",
164 | "( 0 ,.,.) = \n",
165 | "\n",
166 | "Columns 0 to 8 \n",
167 | " -0.0624 -0.0716 -0.0797 -0.0856 -0.0905 -0.0900 -0.0818 -0.0692 -0.0501\n",
168 | "\n",
169 | "Columns 9 to 17 \n",
170 | " -0.0287 -0.0072 0.0182 0.0451 0.0731 0.0965 0.1147 0.1244 0.1286\n",
171 | "\n",
172 | "Columns 18 to 26 \n",
173 | " 0.1310 0.1280 0.1182 0.1014 0.0768 0.0516 0.0299 0.0112 -0.0106\n",
174 | "\n",
175 | "Columns 27 to 35 \n",
176 | " -0.0326 -0.0532 -0.0644 -0.0667 -0.0625 -0.0581 -0.0517 -0.0411 -0.0269\n",
177 | "\n",
178 | "Columns 36 to 44 \n",
179 | " -0.0071 0.0119 0.0245 0.0345 0.0388 0.0443 0.0449 0.0386 0.0284\n",
180 | "\n",
181 | "Columns 45 to 53 \n",
182 | " 0.0139 0.0010 -0.0069 -0.0127 -0.0211 -0.0263 -0.0274 -0.0225 -0.0141\n",
183 | "\n",
184 | "Columns 54 to 62 \n",
185 | " -0.0041 0.0028 0.0041 0.0074 0.0062 -0.0039 -0.0194 -0.0416 -0.0692\n",
186 | "\n",
187 | "Columns 63 to 71 \n",
188 | " -0.0992 -0.1273 -0.1520 -0.1707 -0.1809 -0.1845 -0.1731 -0.1436 -0.0975\n",
189 | "\n",
190 | "Columns 72 to 80 \n",
191 | " -0.0462 0.0043 0.0543 0.1037 0.1487 0.1815 0.1917 0.1761 0.1512\n",
192 | "\n",
193 | "Columns 81 to 89 \n",
194 | " 0.1140 0.0734 0.0276 -0.0191 -0.0643 -0.0964 -0.1084 -0.1012 -0.0851\n",
195 | "\n",
196 | "Columns 90 to 98 \n",
197 | " -0.0621 -0.0371 -0.0034 0.0344 0.0718 0.0935 0.0977 0.0902 0.0768\n",
198 | "\n",
199 | "Columns 99 to 99 \n",
200 | " 0.0605\n",
201 | "[torch.FloatTensor of size 1x1x100]"
202 | ]
203 | },
204 | "metadata": {},
205 | "output_type": "display_data"
206 | },
207 | {
208 | "name": "stdout",
209 | "output_type": "stream",
210 | "text": [
211 | "Conv1d(1, 3, kernel_size=(10,), stride=(9,), dilation=(2,), bias=False)\n"
212 | ]
213 | },
214 | {
215 | "data": {
216 | "text/plain": [
217 | "Variable containing:\n",
218 | "(0 ,.,.) = \n",
219 | "\n",
220 | "Columns 0 to 8 \n",
221 | " -0.0624 -0.0287 0.1310 -0.0326 -0.0071 0.0139 -0.0041 -0.0992 -0.0462\n",
222 | " -0.0210 -0.0097 0.0441 -0.0110 -0.0024 0.0047 -0.0014 -0.0334 -0.0156\n",
223 | " -0.0180 -0.0083 0.0378 -0.0094 -0.0020 0.0040 -0.0012 -0.0286 -0.0133\n",
224 | "\n",
225 | "Columns 9 to 11 \n",
226 | " 0.1140 -0.0621 0.0605\n",
227 | " 0.0384 -0.0209 0.0204\n",
228 | " 0.0329 -0.0179 0.0175\n",
229 | "[torch.FloatTensor of size 1x3x12]"
230 | ]
231 | },
232 | "execution_count": 83,
233 | "metadata": {},
234 | "output_type": "execute_result"
235 | }
236 | ],
237 | "source": [
238 | "X2 = sig[8000:8100].transpose(0,1).unsqueeze(0)\n",
239 | "display(X2)\n",
240 | "conv_op = nn.Conv1d(1, 3, kernel_size=10, dilation=2, stride=9, bias=False)\n",
241 | "conv_op.weight = nn.Parameter(torch.cat((torch.ones(1, 1, 1), torch.rand(2, 1, 1))))\n",
242 | "print(conv_op)\n",
243 | "X2_conv = conv_op(Variable(X2))\n",
244 | "X2_conv"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 144,
250 | "metadata": {},
251 | "outputs": [
252 | {
253 | "data": {
254 | "text/plain": [
255 | "\n",
256 | "(0 ,.,.) = \n",
257 | " 0 1 2 3 4 5 6 7 8 9\n",
258 | "[torch.FloatTensor of size 1x1x10]"
259 | ]
260 | },
261 | "metadata": {},
262 | "output_type": "display_data"
263 | },
264 | {
265 | "data": {
266 | "text/plain": [
267 | "Variable containing:\n",
268 | "(0 ,.,.) = \n",
269 | " 2 3 5 7 9 11 13 15 7\n",
270 | "[torch.FloatTensor of size 1x1x9]"
271 | ]
272 | },
273 | "metadata": {},
274 | "output_type": "display_data"
275 | },
276 | {
277 | "data": {
278 | "text/plain": [
279 | "Variable containing:\n",
280 | "(0 ,.,.) = \n",
281 | " 0 1 2 4 6 8 10 12 14 16 8 9\n",
282 | "[torch.FloatTensor of size 1x1x12]"
283 | ]
284 | },
285 | "metadata": {},
286 | "output_type": "display_data"
287 | }
288 | ],
289 | "source": [
290 | "# Convolutions\n",
291 | "X = torch.arange(0, 10).view(1, 1, -1)\n",
292 | "display(X)\n",
293 | "conv_op = nn.Conv1d(1, 1, 2, dilation=3, padding=1, bias=False)\n",
294 | "conv_op.weight = nn.Parameter(torch.ones(conv_op.weight.size()))\n",
295 | "display(conv_op(Variable(X)))\n",
296 | "tconv_op = nn.ConvTranspose1d(1, 1, 2, dilation=2, bias=False)\n",
297 | "tconv_op.weight = nn.Parameter(torch.ones(tconv_op.weight.size()))\n",
298 | "display(tconv_op(Variable(X)))"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 155,
304 | "metadata": {},
305 | "outputs": [
306 | {
307 | "data": {
308 | "text/plain": [
309 | "\n",
310 | "(0 ,.,.) = \n",
311 | "\n",
312 | "Columns 0 to 8 \n",
313 | " 0.0000 1.0000 2.0000 3.0000 4.0000 5.0000 6.0000 7.0000 8.0000\n",
314 | "\n",
315 | "Columns 9 to 9 \n",
316 | " 9.0000\n",
317 | "\n",
318 | "(1 ,.,.) = \n",
319 | "\n",
320 | "Columns 0 to 8 \n",
321 | " 0.1731 0.8068 0.1422 0.7546 0.9261 0.5195 0.4176 0.6777 0.1002\n",
322 | "\n",
323 | "Columns 9 to 9 \n",
324 | " 0.1254\n",
325 | "[torch.FloatTensor of size 2x1x10]"
326 | ]
327 | },
328 | "metadata": {},
329 | "output_type": "display_data"
330 | },
331 | {
332 | "data": {
333 | "text/plain": [
334 | "Variable containing:\n",
335 | "(0 ,.,.) = \n",
336 | " 1.0000 3.0000 5.0000 7.0000 9.0000\n",
337 | "\n",
338 | "(1 ,.,.) = \n",
339 | " 0.8068 0.7546 0.9261 0.6777 0.1254\n",
340 | "[torch.FloatTensor of size 2x1x5]"
341 | ]
342 | },
343 | "execution_count": 155,
344 | "metadata": {},
345 | "output_type": "execute_result"
346 | }
347 | ],
348 | "source": [
349 | "# Pooling\n",
350 | "X = torch.cat((torch.arange(0, 10).view(1, 1, -1), torch.rand(1, 1, 10)))\n",
351 | "display(X)\n",
352 | "pool_op = nn.AdaptiveMaxPool1d(5)\n",
353 | "pool_op(X)"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 163,
359 | "metadata": {},
360 | "outputs": [
361 | {
362 | "data": {
363 | "text/plain": [
364 | "Variable containing:\n",
365 | "(0 ,0 ,.,.) = \n",
366 | " 1 0 1 2 3 4 5 6 7 8 9 8\n",
367 | "[torch.FloatTensor of size 1x1x1x12]"
368 | ]
369 | },
370 | "execution_count": 163,
371 | "metadata": {},
372 | "output_type": "execute_result"
373 | }
374 | ],
375 | "source": [
376 | "# Padding\n",
377 | "X = torch.arange(0, 10).view(1, 1, 1, -1)\n",
378 | "nn.ReflectionPad2d((1, 1, 0, 0))(X)\n"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": null,
384 | "metadata": {},
385 | "outputs": [],
386 | "source": []
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "metadata": {
392 | "collapsed": true
393 | },
394 | "outputs": [],
395 | "source": []
396 | }
397 | ],
398 | "metadata": {
399 | "kernelspec": {
400 | "display_name": "Python 3",
401 | "language": "python",
402 | "name": "python3"
403 | },
404 | "language_info": {
405 | "codemirror_mode": {
406 | "name": "ipython",
407 | "version": 3
408 | },
409 | "file_extension": ".py",
410 | "mimetype": "text/x-python",
411 | "name": "python",
412 | "nbconvert_exporter": "python",
413 | "pygments_lexer": "ipython3",
414 | "version": "3.6.1"
415 | }
416 | },
417 | "nbformat": 4,
418 | "nbformat_minor": 2
419 | }
420 |
--------------------------------------------------------------------------------
/pytorch_embedding_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torch.nn as nn\n",
13 | "from torch.autograd import Variable"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 10,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "data": {
23 | "text/plain": [
24 | "(torch.Size([2, 5, 3]), \n",
25 | " (0 ,.,.) = \n",
26 | " 0.3966 1.1329 0.9703\n",
27 | " 1.2760 -0.3156 -0.3003\n",
28 | " 0.4951 -2.4289 -1.8240\n",
29 | " -1.2025 -0.5043 -0.1327\n",
30 | " 0.3966 1.1329 0.9703\n",
31 | " \n",
32 | " (1 ,.,.) = \n",
33 | " 0.4951 -2.4289 -1.8240\n",
34 | " -0.3016 -1.0829 1.2743\n",
35 | " -0.0533 -0.9346 -0.3806\n",
36 | " -1.8540 -1.1871 -0.9740\n",
37 | " 0.5305 -0.0377 -0.4296\n",
38 | " [torch.FloatTensor of size 2x5x3], 1.275978684425354, -0.304186017687122, -2.4288904666900635)"
39 | ]
40 | },
41 | "execution_count": 10,
42 | "metadata": {},
43 | "output_type": "execute_result"
44 | }
45 | ],
46 | "source": [
47 | "# Integer index (0 <= idx < N, i.e. max valid index = N-1) to M-dim vector\n",
48 | "N = 1000\n",
49 | "M = 3\n",
50 | "embedding = nn.Embedding(N, M)\n",
51 | "input = Variable(torch.LongTensor([[1, 2, 4, 5, 1],[4, 3, 0, 9, 999]]))\n",
52 | "output = embedding(input)\n",
53 | "output.size(), output.data, output.data.max(), output.data.mean(), output.data.min()"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 13,
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "data": {
63 | "text/plain": [
64 | "(torch.Size([2, 5, 9]), \n",
65 | " (0 ,.,.) = \n",
66 | " 1.2783 -2.0387 0.3489 2.5870 0.4003 -0.8323 0.4795 -0.1188 1.5058\n",
67 | " 1.9345 0.6939 -0.6005 0.0113 -1.3579 -1.5732 0.6476 0.3265 2.0663\n",
68 | " -1.7766 1.7161 1.5391 -0.2763 -1.0611 -1.7473 1.4869 0.1519 0.5832\n",
69 | " 0.2865 1.8251 -0.3772 0.1366 -0.1392 -1.4507 -0.0412 0.3509 -1.0199\n",
70 | " 1.2783 -2.0387 0.3489 2.5870 0.4003 -0.8323 0.4795 -0.1188 1.5058\n",
71 | " \n",
72 | " (1 ,.,.) = \n",
73 | " -1.7766 1.7161 1.5391 -0.2763 -1.0611 -1.7473 1.4869 0.1519 0.5832\n",
74 | " -0.1941 -0.7290 -1.2425 2.0444 -0.2431 -0.7992 -1.1772 0.0754 0.0805\n",
75 | " -1.7917 0.3172 -0.4379 -0.0060 0.6569 2.5088 -0.9982 -1.5249 0.3790\n",
76 | " -0.9209 -0.2454 1.3805 -2.9321 -0.3620 1.3712 1.5372 -0.0832 -0.0649\n",
77 | " -0.7073 -0.0540 0.0487 1.0955 -0.9316 0.5288 0.2074 0.9193 0.6603\n",
78 | " [torch.FloatTensor of size 2x5x9], 2.587045907974243, 0.09486205115293463, -2.9320576190948486)"
79 | ]
80 | },
81 | "execution_count": 13,
82 | "metadata": {},
83 | "output_type": "execute_result"
84 | }
85 | ],
86 | "source": [
87 | "M_new = 9\n",
88 | "embedding_new = nn.Embedding(N, M_new)\n",
89 | "output = embedding_new(input)\n",
90 | "output.size(), output.data, output.data.max(), output.data.mean(), output.data.min()"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 11,
96 | "metadata": {},
97 | "outputs": [
98 | {
99 | "ename": "RuntimeError",
100 | "evalue": "index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/TH/generic/THTensorMath.c:277",
101 | "output_type": "error",
102 | "traceback": [
103 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
104 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
105 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mN_invalid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m999\u001b[0m \u001b[0;31m# the last val in input is illegal\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0membedding_invalid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEmbedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mN_invalid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0membedding_invalid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
106 | "\u001b[0;32m~/miniconda3/envs/ml/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_pre_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 325\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 326\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 327\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
107 | "\u001b[0;32m~/miniconda3/envs/ml/lib/python3.6/site-packages/torch/nn/modules/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscale_grad_by_freq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msparse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
108 | "\u001b[0;32m~/miniconda3/envs/ml/lib/python3.6/site-packages/torch/nn/_functions/thnn/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_select\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_select\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
109 | "\u001b[0;31mRuntimeError\u001b[0m: index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/TH/generic/THTensorMath.c:277"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "N_invalid = 999 # the last val in input is illegal\n",
115 | "embedding_invalid = nn.Embedding(N_invalid, M)\n",
116 | "embedding_invalid(input)"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {
123 | "collapsed": true
124 | },
125 | "outputs": [],
126 | "source": []
127 | }
128 | ],
129 | "metadata": {
130 | "kernelspec": {
131 | "display_name": "Python 3",
132 | "language": "python",
133 | "name": "python3"
134 | },
135 | "language_info": {
136 | "codemirror_mode": {
137 | "name": "ipython",
138 | "version": 3
139 | },
140 | "file_extension": ".py",
141 | "mimetype": "text/x-python",
142 | "name": "python",
143 | "nbconvert_exporter": "python",
144 | "pygments_lexer": "ipython3",
145 | "version": "3.6.3"
146 | }
147 | },
148 | "nbformat": 4,
149 | "nbformat_minor": 2
150 | }
151 |
--------------------------------------------------------------------------------
/test_VarLenDataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torch.utils.data as data\n",
13 | "import torchaudio\n",
14 | "import librosa\n",
15 | "import numpy as np\n",
16 | "import random\n",
17 | "import os\n",
18 | "import glob"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 11,
24 | "metadata": {
25 | "collapsed": true,
26 | "scrolled": false
27 | },
28 | "outputs": [],
29 | "source": [
30 | "class VariableLengthDataset(data.Dataset): # serves chunks of snippet_length along dim 0; __len__ grows as longer files are visited\n",
31 | " def __init__(self, manifest, snippet_length=24000, get_sequentially=False, ret_np=False, use_librosa=False):\n",
32 | " self.manifest = manifest\n",
33 | " self.snippet_length = snippet_length\n",
34 | " self.get_sequentially = get_sequentially\n",
35 | " self.use_librosa = use_librosa\n",
36 | " self.ret_np = ret_np\n",
37 | " self.acc = 0\n",
38 | " self.snippet_counter = 0\n",
39 | " self.audio_idx = 0\n",
40 | " self.st = 0\n",
41 | " self.data = {}\n",
42 | " def __getitem__(self, index): # NOTE(review): stateful; presumably assumes indices arrive in order (audio_idx = index - acc)\n",
43 | " # load audio data from file or cache\n",
44 | " if self.snippet_counter == 0:\n",
45 | " self.audio_idx = index - self.acc\n",
46 | " apath = self.manifest[self.audio_idx]\n",
47 | " if apath not in self.data:\n",
48 | " if self.use_librosa:\n",
49 | " sig, sr = librosa.core.load(apath, sr=None)\n",
50 | " sig = torch.from_numpy(sig).unsqueeze(1).float()\n",
51 | " else:\n",
52 | " sig, sr = torchaudio.load(apath, normalization=True)\n",
53 | " self.data[apath] = (sig, sr)\n",
54 | " else:\n",
55 | " sig, sr = self.data[apath]\n",
56 | "\n",
57 | " # increase iterations based on length of audio\n",
58 | " num_snippets = int(sig.size(0) // self.snippet_length)\n",
59 | " self.acc += max(num_snippets-1,0)\n",
60 | " else:\n",
61 | " apath = self.manifest[self.audio_idx]\n",
62 | " sig, sr = self.data[apath]\n",
63 | " num_snippets = int(sig.size(0) // self.snippet_length)\n",
64 | "\n",
65 | " # create snippet\n",
66 | " if self.get_sequentially:\n",
67 | " self.st = self.snippet_counter * self.snippet_length # k-th snippet starts at k*snippet_length, so the first one is not skipped\n",
68 | " else:\n",
69 | " self.st = random.randrange(max(int(sig.size(0)-self.snippet_length) + 1, 1)) # any start in [0, len-snippet_length]; 0 if the signal is too short\n",
70 | " ret_sig = sig[self.st:(self.st+self.snippet_length)]\n",
71 | " if self.ret_np:\n",
72 | " ret_sig = ret_sig.numpy()\n",
73 | "\n",
74 | " # update counter for current audio file\n",
75 | " self.snippet_counter += 1\n",
76 | "\n",
77 | " # label creation\n",
78 | " #spkr = os.path.dirname(apath).rsplit(\"/\", 1)[-1]\n",
79 | " spkr = 0 # just using a dummy label now.\n",
80 | "\n",
81 | " # check for reset\n",
82 | " if self.snippet_counter >= num_snippets:\n",
83 | " self.snippet_counter = 0\n",
84 | " self.st = 0\n",
85 | "\n",
86 | " return ret_sig, spkr\n",
87 | "\n",
88 | " def __len__(self):\n",
89 | " return len(self.manifest) + self.acc\n",
90 | "\n",
91 | " def reset_acc(self):\n",
92 | " self.acc = 0\n",
93 | "\n",
94 | "class FixedLengthDataset(data.Dataset): # one random snippet_length chunk per manifest entry, reloaded from disk on every access\n",
95 | " def __init__(self, manifest, snippet_length=24000, ret_np=False, use_librosa=False):\n",
96 | " self.manifest = manifest\n",
97 | " self.snippet_length = snippet_length\n",
98 | " self.use_librosa = use_librosa\n",
99 | " self.ret_np = ret_np\n",
100 | " self.num_snippets = 1\n",
101 | " self.acc = 0\n",
102 | " self.audio_idx = 0\n",
103 | " self.st = 0\n",
104 | " self.data = {}\n",
105 | " def __getitem__(self, index):\n",
106 | " # load audio data from file or cache\n",
107 | " self.audio_idx = index if self.num_snippets == 1 else index // self.num_snippets\n",
108 | " apath = self.manifest[self.audio_idx]\n",
109 | " if self.use_librosa:\n",
110 | " sig, sr = librosa.core.load(apath, sr=None)\n",
111 | " sig = torch.from_numpy(sig).unsqueeze(1).float()\n",
112 | " else:\n",
113 | " sig, sr = torchaudio.load(apath, normalization=True)\n",
114 | "\n",
115 | " # create snippet\n",
116 | " if sig.size(0) < self.snippet_length:\n",
117 | " ret_sig = sig\n",
118 | " else:\n",
119 | " self.st = random.randrange(int(sig.size(0)-self.snippet_length) + 1) # +1: allows start len-snippet_length and avoids randrange(0) when len == snippet_length\n",
120 | " ret_sig = sig[self.st:(self.st+self.snippet_length)]\n",
121 | " if self.ret_np:\n",
122 | " ret_sig = ret_sig.numpy()\n",
123 | "\n",
124 | " # label creation\n",
125 | " #spkr = os.path.dirname(apath).rsplit(\"/\", 1)[-1]\n",
126 | " spkr = 0 # just using a dummy label now.\n",
127 | "\n",
128 | " return ret_sig, spkr\n",
129 | "\n",
130 | " def __len__(self):\n",
131 | " return len(self.manifest) * self.num_snippets\n",
132 | "\n",
133 | "def run_dataset(): # iterates the global `ds` once and prints (epoch, number of items)\n",
134 | " for epoch in range(1):\n",
135 | " all_data = [(x, label) for x, label in ds]\n",
136 | " print(epoch, len(all_data))\n",
137 | " try:\n",
138 | " ds.reset_acc()\n",
139 | " except AttributeError: # FixedLengthDataset has no reset_acc\n",
140 | " pass\n"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 3,
146 | "metadata": {},
147 | "outputs": [
148 | {
149 | "name": "stdout",
150 | "output_type": "stream",
151 | "text": [
152 | "0\n",
153 | "CPU times: user 91.1 ms, sys: 8.62 ms, total: 99.7 ms\n",
154 | "Wall time: 291 ms\n",
155 | "0\n",
156 | "CPU times: user 79 ms, sys: 6.72 ms, total: 85.7 ms\n",
157 | "Wall time: 69.8 ms\n"
158 | ]
159 | }
160 | ],
161 | "source": [
162 | "datadir = \"/home/david/Programming/tests/pcsnpny-20150204-mkj\"\n",
163 | "audio_manifest = [a for a in glob.glob(datadir+\"/**/*.wav\", recursive=True)]\n",
164 | "\n",
165 | "ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=False)\n",
166 | "%time run_dataset()\n",
167 | "ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=True)\n",
168 | "%time run_dataset()\n"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 4,
174 | "metadata": {
175 | "scrolled": false
176 | },
177 | "outputs": [
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "10 1\n",
183 | "torch.Size([10, 12000, 1]) torch.Size([10])\n"
184 | ]
185 | }
186 | ],
187 | "source": [
188 | "ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=False, ret_np=False, use_librosa=False)\n",
189 | "dl = data.DataLoader(ds, batch_size=10)\n",
190 | "print(len(ds), len(dl))\n",
191 | "for mb, tgts in dl:\n",
192 | " print(mb.size(), tgts.size())\n",
193 | " break"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 13,
199 | "metadata": {},
200 | "outputs": [
201 | {
202 | "name": "stdout",
203 | "output_type": "stream",
204 | "text": [
205 | "0 10\n",
206 | "CPU times: user 18.1 ms, sys: 0 ns, total: 18.1 ms\n",
207 | "Wall time: 18.4 ms\n"
208 | ]
209 | }
210 | ],
211 | "source": [
212 | "datadir = \"/home/david/Programming/tests/pcsnpny-20150204-mkj\"\n",
213 | "audio_manifest = [a for a in glob.glob(datadir+\"/**/*.wav\", recursive=True)]\n",
214 | "ds = FixedLengthDataset(audio_manifest, 12000, ret_np=True, use_librosa=True)\n",
215 | "%time run_dataset()\n"
216 | ]
217 | }
218 | ],
219 | "metadata": {
220 | "kernelspec": {
221 | "display_name": "Python [conda root]",
222 | "language": "python",
223 | "name": "conda-root-py"
224 | },
225 | "language_info": {
226 | "codemirror_mode": {
227 | "name": "ipython",
228 | "version": 3
229 | },
230 | "file_extension": ".py",
231 | "mimetype": "text/x-python",
232 | "name": "python",
233 | "nbconvert_exporter": "python",
234 | "pygments_lexer": "ipython3",
235 | "version": "3.6.1"
236 | }
237 | },
238 | "nbformat": 4,
239 | "nbformat_minor": 2
240 | }
241 |
--------------------------------------------------------------------------------
/test_pdb_debugger.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import pdb"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 3,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "name": "stdout",
21 | "output_type": "stream",
22 | "text": [
23 | "--Return--\n",
24 | "> (2)()->None\n",
25 | "-> pdb.set_trace()\n",
26 | "(Pdb) cont\n",
27 | "1\n",
28 | "3\n",
29 | "6\n",
30 | "10\n",
31 | "15\n",
32 | "21\n",
33 | "28\n",
34 | "36\n",
35 | "45\n",
36 | "55\n"
37 | ]
38 | }
39 | ],
40 | "source": [
41 | "x = 0\n",
42 | "pdb.set_trace()\n",
43 | "for i in range(10):\n",
44 | " x += i+1\n",
45 | " print(x)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": []
54 | }
55 | ],
56 | "metadata": {
57 | "kernelspec": {
58 | "display_name": "Python 3",
59 | "language": "python",
60 | "name": "python3"
61 | },
62 | "language_info": {
63 | "codemirror_mode": {
64 | "name": "ipython",
65 | "version": 3
66 | },
67 | "file_extension": ".py",
68 | "mimetype": "text/x-python",
69 | "name": "python",
70 | "nbconvert_exporter": "python",
71 | "pygments_lexer": "ipython3",
72 | "version": "3.6.1"
73 | }
74 | },
75 | "nbformat": 4,
76 | "nbformat_minor": 2
77 | }
78 |
--------------------------------------------------------------------------------
/test_tqdm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 10,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "application/vnd.jupyter.widget-view+json": {
11 | "model_id": "983855d2e2764926ada81bfb689a388e"
12 | }
13 | },
14 | "metadata": {},
15 | "output_type": "display_data"
16 | },
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "from tqdm import tqdm_notebook\n",
27 | "import time\n",
28 | "\n",
29 | "bar = tqdm_notebook(range(100), desc=\"test range\")\n",
30 | "x = 0\n",
31 | "for j, i in enumerate(bar):\n",
32 | " x += i\n",
33 | " time.sleep(0.1)\n",
34 | " bar.set_description(\"Processing {}\".format(x))\n"
35 | ]
36 | }
37 | ],
38 | "metadata": {
39 | "kernelspec": {
40 | "display_name": "Python [default]",
41 | "language": "python",
42 | "name": "python3"
43 | },
44 | "language_info": {
45 | "codemirror_mode": {
46 | "name": "ipython",
47 | "version": 3
48 | },
49 | "file_extension": ".py",
50 | "mimetype": "text/x-python",
51 | "name": "python",
52 | "nbconvert_exporter": "python",
53 | "pygments_lexer": "ipython3",
54 | "version": "3.6.1"
55 | }
56 | },
57 | "nbformat": 4,
58 | "nbformat_minor": 2
59 | }
60 |
--------------------------------------------------------------------------------