├── LICENSE
├── README.md
└── With_GoogleNet.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 KobeWang-supreme
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Music-genre-classification-with-different-CNNs
2 | Using different CNN models to train on GTZAN Dataset
3 | You can download the whole dataset from: https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification
4 |
5 | Results up to now:
6 | 1. GoogleNet: Test Accuracy: Around 55% - 60%
7 |
8 | 3. Newly Designed CNN:
9 |
10 | With Mel spectrogram features only (n_feature = 128): test accuracy (10 genres) is around 70% - 80% (baseline: 10%); test accuracy (5 genres) is around 85% (baseline: 20%).
11 |
12 |
13 | With MFCC features only (n_feature = 128): test accuracy (10 genres) is around 82%, which is almost equal to the accuracy with both Mel and MFCC.
14 |
15 |
16 | With Mel and MFCC features added together (n_feature = 128): test accuracy (10 genres) is around 83% (baseline: 10%); test accuracy (5 genres) is around 91% (baseline: 20%).
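The baselines quoted above match chance-level accuracy on a balanced dataset: with k equally sized genres, uniform random guessing is right about 1/k of the time. A minimal sketch of that arithmetic, assuming balanced classes (the helper name `chance_accuracy` is illustrative and not part of this repository):

```python
def chance_accuracy(num_classes: int) -> float:
    """Expected accuracy of uniform random guessing on balanced classes."""
    if num_classes <= 0:
        raise ValueError("num_classes must be positive")
    return 1.0 / num_classes


# Print the chance-level baseline for the two settings reported above.
for k in (10, 5):
    print(f"{k} genres: baseline {chance_accuracy(k):.0%}")
```

Comparing a reported test accuracy against this value shows how far a model is above guessing.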
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | Confusion Matrices:
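For readers unfamiliar with the format, a confusion matrix counts how often each true genre (row) was predicted as each genre (column). The sketch below uses made-up toy labels, not results from this project, and the helper name is illustrative. The notebook imports `sklearn.metrics`, whose `confusion_matrix` function could be used instead.

```python
def confusion_matrix(true_labels, predicted_labels, num_classes):
    """Return a num_classes x num_classes matrix: rows = true class, columns = predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[t][p] += 1
    return matrix


# Toy labels for three classes (hypothetical data, for illustration only).
true_labels = [0, 0, 1, 1, 2, 2]
predicted_labels = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(true_labels, predicted_labels, num_classes=3)
for row in cm:
    print(row)
```

Diagonal entries are correct predictions; off-diagonal entries show which genres get confused with each other.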
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | Let's test on a Reggae Song:
36 |
37 |
38 |
39 | Result:
40 |
41 | 
42 |
43 |
44 |
45 |
46 | Findings:
47 | 1. Some music belongs to more than one genre. Many Norah Jones songs, for example, belong to both jazz and blues. Our model can only return one genre (the one with the highest score). However, visualizing the scores for every genre makes it clear that jazz and blues score much higher than the other genres.
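One way to inspect the per-genre scores, rather than only the top prediction, is to convert the model's raw outputs into probabilities and list the highest few. This is a minimal sketch under assumptions: the logits below are invented for illustration, the helper names are hypothetical, and the genre order follows the label mapping used in the notebook.

```python
import math

# Genre order as in the notebook's label mapping (blues=0 ... rock=9).
GENRES = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_genres(logits, k=2):
    """Return the k (genre, probability) pairs with the highest probability."""
    probs = softmax(logits)
    ranked = sorted(zip(GENRES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits in which jazz and blues dominate (not real model output).
example_logits = [2.1, -0.5, 0.0, -1.0, -1.2, 2.4, -2.0, 0.3, -0.8, 0.1]
for genre, prob in top_genres(example_logits, k=2):
    print(f"{genre}: {prob:.2f}")
```

Reporting the top two or three genres this way surfaces songs that sit between genres.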
48 |
49 |
50 |
51 |
52 | Let's take "Don't know why" as an example:
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/With_GoogleNet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "We use the GTZAN dataset, which is widely used in music genre classification.\n",
8 | "Our goal is to classify genres from the MFCCs of each audio clip. This data_deal.ipynb shows how the MFCCs are generated."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 159,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import numpy as np \n",
18 | "import pandas as pd \n",
19 | "import librosa\n",
20 | "import os\n",
21 | "import tqdm\n",
22 | "import matplotlib.pyplot as plt\n",
23 | "import random\n",
24 | "import librosa.display\n",
25 | "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
26 | "import torch\n",
27 | "import torchvision\n",
28 | "from torchvision import transforms\n",
29 | "from torchvision.datasets import ImageFolder\n",
30 | "from torch.utils.data.dataloader import DataLoader\n",
31 | "from torch.utils.data import random_split\n",
32 | "from torchvision.utils import make_grid\n",
33 | "import torch.nn as nn\n",
34 | "import torch.nn.functional as F\n",
35 | "from sklearn.model_selection import train_test_split\n",
36 | "from torch.utils.data import Dataset, TensorDataset\n",
37 | "from sklearn.metrics import classification_report\n",
38 | "from sklearn import metrics"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 160,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "def seed_torch(seed):\n",
48 | "    random.seed(seed) # python seed\n",
49 | "    os.environ['PYTHONHASHSEED'] = str(seed) # Set Python's hash seed, for certain hash-based operations (e.g., the item order in a set or a dict). A seed of 0 disables this feature; it can also be set to an integer. Sometimes it has to be set in the terminal, because setting it inside the script may be too late.\n",
50 | "    np.random.seed(seed) # If you or any of the libraries you are using rely on NumPy, e.g. sampling or some augmentations. For the exceptions, see https://pytorch.org/docs/stable/notes/randomness.html\n",
51 | "    torch.manual_seed(seed) # Set the random seed for the current CPU. The PyTorch docs, however, say it covers both CPU and CUDA.\n",
52 | "    torch.cuda.manual_seed(seed) # Set the random seed for the current GPU\n",
53 | "    # torch.cuda.manual_seed_all(seed) # When using multiple GPUs, set the random seed on all of them\n",
54 | "    torch.backends.cudnn.deterministic = True\n",
55 | "    # torch.backends.cudnn.benchmark = True # When set to True, cuDNN uses non-deterministic algorithms to find the most efficient one\n",
56 | "    # torch.backends.cudnn.enabled = True # PyTorch uses cuDNN acceleration, i.e. GPU acceleration\n",
57 | "\n",
58 | "seed_torch(seed=32)"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "First we should read the data."
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 161,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "def data_read(directory=r\"F:\\music\\archive\\Data\\genres_original\"): # change the directory if you like \n",
75 | "    data = [] # Store all data (read from music)\n",
76 | "    labels = [] # Store corresponding labels\n",
77 | "    srs = [] # Store the sampling rate of each file\n",
78 | "\n",
79 | "    x = 0 # running index over all files\n",
80 | "\n",
81 | "    for dirname, _, filenames in os.walk(directory):\n",
82 | "        for filename in filenames:\n",
83 | "            # Get name and skip the broken\n",
84 | "            if x == 554: # skip the broken one \n",
85 | "                x += 1\n",
86 | "                continue\n",
87 | "            filename = os.path.join(dirname, filename)\n",
88 | "\n",
89 | "            y, sr = librosa.load(filename)\n",
90 | "            # print(filename)\n",
91 | "            label = os.path.basename(dirname) # the genre is the name of the containing folder\n",
92 | "\n",
93 | "            # Append them to the final data\n",
94 | "            data.append(y)\n",
95 | "            labels.append(label)\n",
96 | "            srs.append(sr)\n",
97 | "            x += 1\n",
98 | "            #print(x) # counting\n",
99 | "    #print(len(data))\n",
100 | "\n",
101 | "    return data, labels, srs"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 162,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "data, labels, srs = data_read() "
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "Get one random sample to show whether we read the data successfully."
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 163,
123 | "metadata": {},
124 | "outputs": [
125 | {
126 | "data": {
127 | "image/png": "",
128 | "text/plain": [
129 | ""
130 | ]
131 | },
132 | "metadata": {},
133 | "output_type": "display_data"
134 | }
135 | ],
136 | "source": [
137 | "def check(): \n",
138 | "    fig, axi = plt.subplots(1,1) # initialize\n",
139 | "    ind = random.randint(0, len(labels)-1) # randomly choose one index\n",
140 | "    times = [sample/srs[ind] for sample in range(len(data[ind]))] # get time axis \n",
141 | "    axi.plot(times, data[ind])\n",
142 | "    axi.set_ylabel('Normalized Amplitude')\n",
143 | "    axi.set_xlabel('Time /s')\n",
144 | "    axi.set(title = \"one example from \" + labels[ind].capitalize())\n",
145 | "\n",
146 | "    plt.show()\n",
147 | "\n",
148 | "check()"
149 | ]
150 | },
151 | {
152 | "cell_type": "markdown",
153 | "metadata": {},
154 | "source": [
155 | "Then we need to extract MFCC features from \"data\"."
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 164,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "def get_mfcc(data, srs, num_mfcc):\n",
165 | "    '''\n",
166 | "    input:  data: array-like collection of audio time series\n",
167 | "            srs: sampling rates, one per audio time series\n",
168 | "            num_mfcc: how many mfcc features you want\n",
169 | "\n",
170 | "    return: mfccs, a list with one (frames, num_mfcc) array per audio time series\n",
171 | "    '''\n",
172 | "    mfccs = []\n",
173 | "    for i in range(len(data)):\n",
174 | "        mfcc = librosa.feature.mfcc(y=data[i], sr=srs[i], n_mfcc=num_mfcc).T\n",
175 | "        mfccs.append(mfcc) # every sample we get one array of mfccs and append it.\n",
176 | "\n",
177 | "    return mfccs # contains every music's mfcc"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 165,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "mfccs = get_mfcc(data, srs, num_mfcc=40)"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "Then we pad the MFCCs so that they all have the same length."
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 166,
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "max_length = len(max(mfccs, key=len)) # get the max length\n",
203 | "def pad_mfcc(mfccs, max_length):\n",
204 | "    'Zero-pad every MFCC array along the time axis up to max_length frames.'\n",
205 | "    mfccs_padded = []\n",
206 | "    for mfcc in mfccs:\n",
207 | "        mfcc_padded = np.pad(mfcc, pad_width=[(0,max_length-mfcc.shape[0]),(0,0)])\n",
208 | "        mfccs_padded.append(mfcc_padded)\n",
209 | "    return mfccs_padded"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 167,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "mfccs_padded = pad_mfcc(mfccs, max_length)"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {},
224 | "source": [
225 | "Next, we encode the labels so that mfccs_padded can be put into DataLoaders."
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 168,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | " labels\n",
238 | "0 blues\n",
239 | "1 blues\n",
240 | "2 blues\n",
241 | "3 blues\n",
242 | "4 blues\n",
243 | ".. ...\n",
244 | "994 rock\n",
245 | "995 rock\n",
246 | "996 rock\n",
247 | "997 rock\n",
248 | "998 rock\n",
249 | "\n",
250 | "[999 rows x 1 columns]\n",
251 | " labels\n",
252 | "0 0\n",
253 | "1 0\n",
254 | "2 0\n",
255 | "3 0\n",
256 | "4 0\n",
257 | ".. ...\n",
258 | "994 9\n",
259 | "995 9\n",
260 | "996 9\n",
261 | "997 9\n",
262 | "998 9\n",
263 | "\n",
264 | "[999 rows x 1 columns]\n",
265 | "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
266 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
267 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1\n",
268 | " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
269 | " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
270 | " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
271 | " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
272 | " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
273 | " 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
274 | " 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
275 | " 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4\n",
276 | " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
277 | " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
278 | " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5\n",
279 | " 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5\n",
280 | " 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5\n",
281 | " 5 5 5 5 5 5 5 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6\n",
282 | " 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6\n",
283 | " 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 7 7 7 7\n",
284 | " 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7\n",
285 | " 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7\n",
286 | " 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
287 | " 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
288 | " 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
289 | " 8 8 8 8 8 8 8 8 8 8 8 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9\n",
290 | " 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9\n",
291 | " 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9]\n"
292 | ]
293 | }
294 | ],
295 | "source": [
296 | "# convert data to numpy array\n",
297 | "X = np.asarray(mfccs_padded) \n",
298 | "# Mapping labels \n",
299 | "codes = {\n",
300 | " 'blues':0,\n",
301 | " 'classical':1,\n",
302 | " 'country':2,\n",
303 | " 'disco':3,\n",
304 | " 'hiphop':4,\n",
305 | " 'jazz':5,\n",
306 | " 'metal':6,\n",
307 | " 'pop':7,\n",
308 | " 'reggae':8,\n",
309 | " 'rock':9\n",
310 | "}\n",
311 | "\n",
312 | "df_map = pd.DataFrame (labels, columns = ['labels'])\n",
313 | "print(df_map)\n",
314 | "df_map['labels'] = df_map['labels'].map(codes)\n",
315 | "print(df_map)\n",
316 | "y = df_map['labels'].to_numpy()\n",
317 | "print(y)"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 169,
323 | "metadata": {},
324 | "outputs": [],
325 | "source": [
326 | "# To split train_data, valid_data and test_data\n",
327 | "def train_val_test_split(x, y, test_size, val_size, random_state=None, stratify=None):\n",
328 | "    x_t, x_test, y_t, y_test = train_test_split(x, y, test_size=test_size, random_state=random_state, stratify=stratify)\n",
329 | "    if stratify is not None:\n",
330 | "        stratify = y_t\n",
331 | "    x_train, x_val, y_train, y_val = train_test_split(x_t, y_t, test_size=val_size, random_state=random_state, stratify=stratify)\n",
332 | "\n",
333 | "    return x_train, y_train, x_val, y_val, x_test, y_test\n",
334 | "\n",
335 | "def to_Dataloader(x_train, y_train, x_val, y_val, x_test, y_test, batch_size, test_batch=1,shuffle=True):\n",
336 | "\n",
337 | "    train_dataset = TensorDataset(torch.Tensor(x_train),torch.Tensor(y_train).type(torch.LongTensor)) # create train_dataset\n",
338 | "    val_dataset = TensorDataset(torch.Tensor(x_val),torch.Tensor(y_val).type(torch.LongTensor)) # create val_dataset\n",
339 | "    test_dataset = TensorDataset(torch.Tensor(x_test),torch.Tensor(y_test).type(torch.LongTensor)) # create test_dataset\n",
340 | "\n",
341 | "    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle) # create your train_dataloader\n",
342 | "    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle) # create your val_dataloader\n",
343 | "    test_dataloader = DataLoader(test_dataset, batch_size=test_batch, shuffle=shuffle) # create your test_dataloader\n",
344 | "\n",
345 | "    return train_dataloader, val_dataloader, test_dataloader"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 170,
351 | "metadata": {},
352 | "outputs": [
353 | {
354 | "name": "stdout",
355 | "output_type": "stream",
356 | "text": [
357 | "finished\n"
358 | ]
359 | }
360 | ],
361 | "source": [
362 | "x_train, y_train, x_val, y_val, x_test, y_test = train_val_test_split(X, y, test_size=0.15, val_size=0.15, random_state=42, stratify=y)\n",
363 | "\n",
364 | "train_dataloader, val_dataloader, test_dataloader = to_Dataloader(x_train, y_train, x_val, y_val, x_test, y_test, batch_size=64, test_batch=1,shuffle=True)\n",
365 | "\n",
366 | "print('finished')"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "We then define an early-stopping class called \"EarlyStopping\".\n",
374 | "It is taken from: \"https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py\""
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 171,
380 | "metadata": {},
381 | "outputs": [],
382 | "source": [
383 | "class EarlyStopping:\n",
384 | "    \"\"\"Early stops the training if validation loss doesn't improve after a given patience.\"\"\"\n",
385 | "    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):\n",
386 | "        \"\"\"\n",
387 | "        Args:\n",
388 | "            patience (int): How long to wait after last time validation loss improved.\n",
389 | "                            Default: 7\n",
390 | "            verbose (bool): If True, prints a message for each validation loss improvement. \n",
391 | "                            Default: False\n",
392 | "            delta (float): Minimum change in the monitored quantity to qualify as an improvement.\n",
393 | "                            Default: 0\n",
394 | "            path (str): Path for the checkpoint to be saved to.\n",
395 | "                            Default: 'checkpoint.pt'\n",
396 | "            trace_func (function): trace print function.\n",
397 | "                            Default: print \n",
398 | "        \"\"\"\n",
399 | "        self.patience = patience\n",
400 | "        self.verbose = verbose\n",
401 | "        self.counter = 0\n",
402 | "        self.best_score = None\n",
403 | "        self.early_stop = False\n",
404 | "        self.val_loss_min = np.inf\n",
405 | "        self.delta = delta\n",
406 | "        self.path = path\n",
407 | "        self.trace_func = trace_func\n",
408 | "    def __call__(self, val_loss, model):\n",
409 | "\n",
410 | "        score = -val_loss\n",
411 | "\n",
412 | "        if self.best_score is None:\n",
413 | "            self.best_score = score\n",
414 | "            self.save_checkpoint(val_loss, model)\n",
415 | "        elif score < self.best_score + self.delta:\n",
416 | "            self.counter += 1\n",
417 | "            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
418 | "            if self.counter >= self.patience:\n",
419 | "                self.early_stop = True\n",
420 | "        else:\n",
421 | "            self.best_score = score\n",
422 | "            self.save_checkpoint(val_loss, model)\n",
423 | "            self.counter = 0\n",
424 | "\n",
425 | "    def save_checkpoint(self, val_loss, model):\n",
426 | "        '''Saves the model when the validation loss decreases.'''\n",
427 | "        if self.verbose:\n",
428 | "            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')\n",
429 | "        torch.save(model.state_dict(), self.path)\n",
430 | "        self.val_loss_min = val_loss"
431 | ]
432 | },
433 | {
434 | "cell_type": "markdown",
435 | "metadata": {},
436 | "source": [
437 | "After doing all things above, it is time to build our model to train this dataset!"
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {},
443 | "source": [
444 | "For the first model, we try to reproduce GoogleNet:"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": 172,
450 | "metadata": {},
451 | "outputs": [],
452 | "source": [
453 | "# A class to define the Inception block\n",
454 | "class Inception(nn.Module):\n",
455 | "    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):\n",
456 | "        super(Inception, self).__init__(**kwargs)\n",
457 | "        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)\n",
458 | "        # path 1: 1x1 kernel\n",
459 | "        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)\n",
460 | "        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)\n",
461 | "        # path 2: 1x1 kernel, then 3x3 kernel\n",
462 | "        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)\n",
463 | "        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)\n",
464 | "        # path 3: 1x1 kernel, then 5x5 kernel\n",
465 | "        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
466 | "        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)\n",
467 | "        # path 4: max pooling, then 1x1 kernel\n",
468 | "    def forward(self, x):\n",
469 | "        p1 = F.relu(self.p1_1(x))\n",
470 | "        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n",
471 | "        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n",
472 | "        p4 = F.relu(self.p4_2(self.p4_1(x)))\n",
473 | "        # Concatenate the four paths along the channel dimension\n",
474 | "        return torch.cat((p1, p2, p3, p4), dim=1)\n",
475 | " \n",
476 | "\n",
477 | "class GoogleNet(nn.Module):\n",
478 | "    def __init__(self):\n",
479 | "        super(GoogleNet, self).__init__()\n",
480 | "        self.num_conv_layers = 5\n",
481 | "        b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
482 | "                           nn.ReLU(),\n",
483 | "                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
484 | "\n",
485 | "        b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n",
486 | "                           nn.ReLU(),\n",
487 | "                           nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
488 | "                           nn.ReLU(),\n",
489 | "                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
490 | "\n",
491 | "        b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n",
492 | "                           Inception(256, 128, (128, 192), (32, 96), 64),\n",
493 | "                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
494 | "\n",
495 | "        b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n",
496 | "                           Inception(512, 160, (112, 224), (24, 64), 64),\n",
497 | "                           Inception(512, 128, (128, 256), (24, 64), 64),\n",
498 | "                           Inception(512, 112, (144, 288), (32, 64), 64),\n",
499 | "                           Inception(528, 256, (160, 320), (32, 128), 128),\n",
500 | "                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
501 | "\n",
502 | "        b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n",
503 | "                           Inception(832, 384, (192, 384), (48, 128), 128),\n",
504 | "                           nn.AdaptiveAvgPool2d((1,1)),\n",
505 | "                           nn.Flatten())\n",
506 | "        #seq = [b1, b2, b3, b4, b5]\n",
507 | "        self.fitter = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))\n",
508 | "\n",
509 | "    def forward(self, x):\n",
510 | "\n",
511 | "        x = x.transpose(1,2)\n",
512 | "        x.unsqueeze_(1)\n",
513 | "        out = self.fitter(x)\n",
514 | "        return out"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": 173,
520 | "metadata": {},
521 | "outputs": [],
522 | "source": [
523 | "# Then we define a small fully connected model\n",
524 | "class MusicGenreModel(nn.Module):\n",
525 | "    def __init__(self, classes=10):\n",
526 | "        super().__init__()\n",
527 | "\n",
528 | "        self.classes = classes\n",
529 | "        self.linear1 = nn.Linear(20, 200) # expects 20 features per sample after flattening\n",
530 | "        self.linear2 = nn.Linear(200, 100)\n",
531 | "        self.output = nn.Linear(100, self.classes)\n",
532 | "\n",
533 | "    def forward(self, x):\n",
534 | "        x = x.view(x.size(0), -1) # flatten everything except the batch dimension\n",
535 | "        x = self.linear1(x)\n",
536 | "        x = torch.relu_(x)\n",
537 | "        x = self.linear2(x)\n",
538 | "        x = torch.relu_(x)\n",
539 | "\n",
540 | "        return F.softmax(self.output(x), dim=1) # class probabilities\n"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": 174,
546 | "metadata": {},
547 | "outputs": [
548 | {
549 | "name": "stdout",
550 | "output_type": "stream",
551 | "text": [
552 | "cuda\n"
553 | ]
554 | }
555 | ],
556 | "source": [
557 | "# check the device\n",
558 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
559 | "print(device)"
560 | ]
561 | },
562 | {
563 | "cell_type": "markdown",
564 | "metadata": {},
565 | "source": [
566 | "Build the functions used for training and prediction."
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": 175,
572 | "metadata": {},
573 | "outputs": [],
574 | "source": [
575 | "def training_loop(model, train_dataloader, optimizer, device=device):\n",
576 | "\n",
577 | "    model.train() # Sets model to train mode\n",
578 | "    batch_losses = []\n",
579 | "\n",
580 | "    for x_batch, y_batch in train_dataloader:\n",
581 | "        # Move batches to device\n",
582 | "        x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
583 | "\n",
584 | "        # Clear gradients\n",
585 | "        optimizer.zero_grad()\n",
586 | "\n",
587 | "        yhat = model(x_batch) # forward pass: predictions on x\n",
588 | "\n",
589 | "        # Compute the loss (loss_function is expected to be defined globally)\n",
590 | "        loss = loss_function(yhat, y_batch)\n",
591 | "\n",
592 | "        # Backward pass\n",
593 | "        loss.backward()\n",
594 | "\n",
595 | "        # Update parameters\n",
596 | "        optimizer.step()\n",
597 | "\n",
598 | "        batch_losses.append(loss.item())\n",
599 | "\n",
600 | "    train_loss = np.mean(batch_losses)\n",
601 | "\n",
602 | "    return train_loss # Return train_loss and anything else you need\n",
603 | "\n",
604 | "def validation_loop(model, val_dataloader, device=device):\n",
605 | "\n",
606 | "    model.eval() # Sets model to val mode\n",
607 | "\n",
608 | "    batch_losses = []\n",
609 | "\n",
610 | "    for x_batch, y_batch in val_dataloader:\n",
611 | "\n",
612 | "        x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
613 | "\n",
614 | "        yhat = model(x_batch) #predictions on x\n",
615 | "\n",
616 | "        loss = loss_function(yhat, y_batch)\n",
617 | "\n",
618 | "        batch_losses.append(loss.data.item())\n",
619 | "\n",
620 | "    val_loss = np.mean(batch_losses)\n",
621 | "\n",
622 | "    return val_loss\n",
623 | "\n",
624 | "def train(model, train_dataloader, val_dataloader, optimizer, epochs, device=device, patience = None, temp=100):\n",
625 | "\n",
626 | "    train_losses = []\n",
627 | "    val_losses = []\n",
628 | "\n",
629 | "    print(\"Initiating training.\")\n",
630 | "\n",
631 | "    # Check if early stop is enabled:\n",
632 | "    if patience is not None:\n",
633 | "        # Initialize EarlyStopping\n",
634 | "        early_stopping = EarlyStopping(patience=patience, verbose=False, path='checkpoint.pt')\n",
635 | "\n",
636 | "    for epoch in range(epochs):\n",
637 | "        # Training loop\n",
638 | "\n",
639 | "        train_loss = training_loop(model, train_dataloader, optimizer, device)\n",
640 | "        train_losses.append(train_loss)\n",
641 | "\n",
642 | "        # Validation loop\n",
643 | "        with torch.no_grad():\n",
644 | "\n",
645 | "            val_loss = validation_loop(model, val_dataloader, device)\n",
646 | "            val_losses.append(val_loss)\n",
647 | "\n",
648 | "        if patience is not None and patience != -1: # skip when early stopping is disabled\n",
649 | "            early_stopping(val_loss, model)\n",
650 | "\n",
651 | "            if early_stopping.early_stop:\n",
652 | "                print(\"Early stop. Going back to the last checkpoint.\")\n",
653 | "                break\n",
654 | "\n",
655 | "        if epoch % temp == 0:\n",
656 | "            print(f\"[{epoch}/{epochs}] Training loss: {train_loss:.4f}\\t Validation loss: {val_loss:.4f}.\")\n",
657 | "\n",
658 | "    if patience is not None and early_stopping.early_stop:\n",
659 | "        print('Loading model from checkpoint...')\n",
660 | "        model.load_state_dict(torch.load('checkpoint.pt'))\n",
661 | "        print('Checkpoint loaded.')\n",
662 | "\n",
663 | "    print(\"training finished.\")\n",
664 | "\n",
665 | "    # visualize the loss as the network trained\n",
666 | "    fig = plt.figure(figsize=(10, 8))\n",
667 | "    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')\n",
668 | "    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')\n",
669 | "\n",
670 | "    # find position of lowest validation loss\n",
671 | "    minposs = val_losses.index(min(val_losses)) + 1\n",
672 | "    plt.axvline(minposs, linestyle='--', color='r', label='Early Stopping Checkpoint')\n",
673 | "\n",
674 | "    plt.xlabel('epochs')\n",
675 | "    plt.ylabel('loss')\n",
676 | "    plt.ylim(0, max(val_losses + train_losses)) # consistent scale\n",
677 | "    plt.xlim(0, len(train_losses) + 1) # consistent scale\n",
678 | "    plt.grid(True)\n",
679 | "    plt.legend()\n",
680 | "    plt.tight_layout()\n",
681 | "    plt.title('Validation and Training Loss of CNN')\n",
682 | "    plt.show()\n",
683 | "\n",
684 | "    return model\n",
685 | "\n",
686 | "\n",
687 | "def predict(model, test_loader, n_features, loss_function, device=device):\n",
688 | "    # Make predictions using model\n",
689 | "    preds = []\n",
690 | "    true_values = []\n",
691 | "    loss = 0\n",
692 | "    model.eval() # prep model for evaluation\n",
693 | "\n",
694 | "    with torch.no_grad():\n",
695 | "        for x_batch, y_batch in test_loader:\n",
696 | "            # move to device\n",
697 | "            x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
698 | "\n",
699 | "            # Make predictions\n",
700 | "            pred = model(x_batch)\n",
701 | "\n",
702 | "            preds.extend(np.argmax(pred.cpu().numpy(), axis=1)) # keep every sample in the batch\n",
703 | "            true_values.extend(y_batch.cpu().numpy())\n",
704 | "            loss += loss_function(pred, y_batch)\n",
705 | "\n",
706 | "    #Calculate Accuracy\n",
707 | "    accuracy = sum(np.array(preds) == np.array(true_values))/len(true_values)\n",
708 | "\n",
709 | "    return preds, true_values, accuracy"
710 | ]
711 | },
712 | {
713 | "cell_type": "markdown",
714 | "metadata": {},
715 | "source": [
716 | "Initialize the model"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": 176,
722 | "metadata": {},
723 | "outputs": [
724 | {
725 | "data": {
726 | "text/plain": [
727 | "GoogleNet(\n",
728 | " (fitter): Sequential(\n",
729 | " (0): Sequential(\n",
730 | " (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
731 | " (1): ReLU()\n",
732 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
733 | " )\n",
734 | " (1): Sequential(\n",
735 | " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
736 | " (1): ReLU()\n",
737 | " (2): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
738 | " (3): ReLU()\n",
739 | " (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
740 | " )\n",
741 | " (2): Sequential(\n",
742 | " (0): Inception(\n",
743 | " (p1_1): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))\n",
744 | " (p2_1): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))\n",
745 | " (p2_2): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
746 | " (p3_1): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))\n",
747 | " (p3_2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
748 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
749 | " (p4_2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))\n",
750 | " )\n",
751 | " (1): Inception(\n",
752 | " (p1_1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
753 | " (p2_1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
754 | " (p2_2): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
755 | " (p3_1): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))\n",
756 | " (p3_2): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
757 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
758 | " (p4_2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
759 | " )\n",
760 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
761 | " )\n",
762 | " (3): Sequential(\n",
763 | " (0): Inception(\n",
764 | " (p1_1): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))\n",
765 | " (p2_1): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))\n",
766 | " (p2_2): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
767 | " (p3_1): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))\n",
768 | " (p3_2): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
769 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
770 | " (p4_2): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))\n",
771 | " )\n",
772 | " (1): Inception(\n",
773 | " (p1_1): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))\n",
774 | " (p2_1): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))\n",
775 | " (p2_2): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
776 | " (p3_1): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))\n",
777 | " (p3_2): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
778 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
779 | " (p4_2): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
780 | " )\n",
781 | " (2): Inception(\n",
782 | " (p1_1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))\n",
783 | " (p2_1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))\n",
784 | " (p2_2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
785 | " (p3_1): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))\n",
786 | " (p3_2): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
787 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
788 | " (p4_2): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
789 | " )\n",
790 | " (3): Inception(\n",
791 | " (p1_1): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))\n",
792 | " (p2_1): Conv2d(512, 144, kernel_size=(1, 1), stride=(1, 1))\n",
793 | " (p2_2): Conv2d(144, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
794 | " (p3_1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
795 | " (p3_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
796 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
797 | " (p4_2): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
798 | " )\n",
799 | " (4): Inception(\n",
800 | " (p1_1): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1))\n",
801 | " (p2_1): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1))\n",
802 | " (p2_2): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
803 | " (p3_1): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1))\n",
804 | " (p3_2): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
805 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
806 | " (p4_2): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1))\n",
807 | " )\n",
808 | " (5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
809 | " )\n",
810 | " (4): Sequential(\n",
811 | " (0): Inception(\n",
812 | " (p1_1): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1))\n",
813 | " (p2_1): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1))\n",
814 | " (p2_2): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
815 | " (p3_1): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1))\n",
816 | " (p3_2): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
817 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
818 | " (p4_2): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))\n",
819 | " )\n",
820 | " (1): Inception(\n",
821 | " (p1_1): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1))\n",
822 | " (p2_1): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))\n",
823 | " (p2_2): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
824 | " (p3_1): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1))\n",
825 | " (p3_2): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
826 | " (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
827 | " (p4_2): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))\n",
828 | " )\n",
829 | " (2): AdaptiveAvgPool2d(output_size=(1, 1))\n",
830 | " (3): Flatten(start_dim=1, end_dim=-1)\n",
831 | " )\n",
832 | " (5): Linear(in_features=1024, out_features=10, bias=True)\n",
833 | " )\n",
834 | ")"
835 | ]
836 | },
837 | "execution_count": 176,
838 | "metadata": {},
839 | "output_type": "execute_result"
840 | }
841 | ],
842 | "source": [
843 | "model = GoogleNet()\n",
844 | "model.to(device)"
845 | ]
846 | },
847 | {
848 | "cell_type": "code",
849 | "execution_count": 177,
850 | "metadata": {},
851 | "outputs": [],
852 | "source": [
853 | "learning_rate = 0.0001\n",
854 | "weight_decay = 1e-4\n",
855 | "\n",
856 | "# Cross-entropy loss over the 10 genre classes; Adam with L2 weight decay.\n",
857 | "loss_function = nn.CrossEntropyLoss()\n",
858 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)"
859 | ]
860 | },
861 | {
862 | "cell_type": "code",
863 | "execution_count": 178,
864 | "metadata": {},
865 | "outputs": [
866 | {
867 | "name": "stdout",
868 | "output_type": "stream",
869 | "text": [
870 | "Initiating training.\n",
871 | "[0/1000] Training loss: 2.3035\t Validation loss: 2.3023.\n",
872 | "[5/1000] Training loss: 2.1136\t Validation loss: 2.0317.\n",
873 | "EarlyStopping counter: 1 out of 15\n",
874 | "EarlyStopping counter: 2 out of 15\n",
875 | "[10/1000] Training loss: 1.8832\t Validation loss: 1.9136.\n",
876 | "EarlyStopping counter: 1 out of 15\n",
877 | "EarlyStopping counter: 1 out of 15\n",
878 | "[15/1000] Training loss: 1.7949\t Validation loss: 1.9552.\n",
879 | "EarlyStopping counter: 2 out of 15\n",
880 | "EarlyStopping counter: 3 out of 15\n",
881 | "EarlyStopping counter: 1 out of 15\n",
882 | "[20/1000] Training loss: 1.6279\t Validation loss: 1.8498.\n",
883 | "EarlyStopping counter: 2 out of 15\n",
884 | "EarlyStopping counter: 3 out of 15\n",
885 | "EarlyStopping counter: 4 out of 15\n",
886 | "EarlyStopping counter: 1 out of 15\n",
887 | "[25/1000] Training loss: 1.5665\t Validation loss: 1.7585.\n",
888 | "EarlyStopping counter: 2 out of 15\n",
889 | "EarlyStopping counter: 1 out of 15\n",
890 | "[30/1000] Training loss: 1.5914\t Validation loss: 1.6470.\n",
891 | "EarlyStopping counter: 1 out of 15\n",
892 | "EarlyStopping counter: 2 out of 15\n",
893 | "EarlyStopping counter: 3 out of 15\n",
894 | "[35/1000] Training loss: 1.4190\t Validation loss: 1.6658.\n",
895 | "EarlyStopping counter: 4 out of 15\n",
896 | "EarlyStopping counter: 5 out of 15\n",
897 | "EarlyStopping counter: 1 out of 15\n",
898 | "[40/1000] Training loss: 1.4399\t Validation loss: 1.6641.\n",
899 | "EarlyStopping counter: 1 out of 15\n",
900 | "EarlyStopping counter: 2 out of 15\n",
901 | "EarlyStopping counter: 3 out of 15\n",
902 | "[45/1000] Training loss: 1.3014\t Validation loss: 1.6467.\n",
903 | "EarlyStopping counter: 4 out of 15\n",
904 | "EarlyStopping counter: 5 out of 15\n",
905 | "EarlyStopping counter: 6 out of 15\n",
906 | "EarlyStopping counter: 7 out of 15\n",
907 | "EarlyStopping counter: 8 out of 15\n",
908 | "[50/1000] Training loss: 1.2805\t Validation loss: 1.5588.\n",
909 | "EarlyStopping counter: 9 out of 15\n",
910 | "EarlyStopping counter: 1 out of 15\n",
911 | "[55/1000] Training loss: 1.2732\t Validation loss: 1.6181.\n",
912 | "EarlyStopping counter: 2 out of 15\n",
913 | "EarlyStopping counter: 3 out of 15\n",
914 | "EarlyStopping counter: 4 out of 15\n",
915 | "[60/1000] Training loss: 1.1238\t Validation loss: 1.3764.\n",
916 | "EarlyStopping counter: 1 out of 15\n",
917 | "EarlyStopping counter: 2 out of 15\n",
918 | "EarlyStopping counter: 3 out of 15\n",
919 | "EarlyStopping counter: 1 out of 15\n",
920 | "[65/1000] Training loss: 1.0934\t Validation loss: 1.4616.\n",
921 | "EarlyStopping counter: 2 out of 15\n",
922 | "EarlyStopping counter: 3 out of 15\n",
923 | "EarlyStopping counter: 4 out of 15\n",
924 | "EarlyStopping counter: 5 out of 15\n",
925 | "EarlyStopping counter: 6 out of 15\n",
926 | "[70/1000] Training loss: 1.2725\t Validation loss: 1.3733.\n",
927 | "EarlyStopping counter: 1 out of 15\n",
928 | "EarlyStopping counter: 2 out of 15\n",
929 | "EarlyStopping counter: 1 out of 15\n",
930 | "[75/1000] Training loss: 1.0698\t Validation loss: 1.2997.\n",
931 | "EarlyStopping counter: 2 out of 15\n",
932 | "EarlyStopping counter: 3 out of 15\n",
933 | "EarlyStopping counter: 4 out of 15\n",
934 | "EarlyStopping counter: 5 out of 15\n",
935 | "EarlyStopping counter: 6 out of 15\n",
936 | "[80/1000] Training loss: 0.9480\t Validation loss: 1.4509.\n",
937 | "EarlyStopping counter: 1 out of 15\n",
938 | "EarlyStopping counter: 1 out of 15\n",
939 | "EarlyStopping counter: 2 out of 15\n",
940 | "[85/1000] Training loss: 0.9552\t Validation loss: 1.8327.\n",
941 | "EarlyStopping counter: 3 out of 15\n",
942 | "EarlyStopping counter: 4 out of 15\n",
943 | "EarlyStopping counter: 5 out of 15\n",
944 | "EarlyStopping counter: 6 out of 15\n",
945 | "EarlyStopping counter: 7 out of 15\n",
946 | "[90/1000] Training loss: 0.8089\t Validation loss: 1.3086.\n",
947 | "EarlyStopping counter: 8 out of 15\n",
948 | "EarlyStopping counter: 9 out of 15\n",
949 | "EarlyStopping counter: 10 out of 15\n",
950 | "EarlyStopping counter: 1 out of 15\n",
951 | "[95/1000] Training loss: 0.9292\t Validation loss: 1.2831.\n",
952 | "EarlyStopping counter: 2 out of 15\n",
953 | "EarlyStopping counter: 3 out of 15\n",
954 | "EarlyStopping counter: 4 out of 15\n",
955 | "EarlyStopping counter: 5 out of 15\n",
956 | "[100/1000] Training loss: 0.7520\t Validation loss: 1.1526.\n",
957 | "EarlyStopping counter: 1 out of 15\n",
958 | "EarlyStopping counter: 2 out of 15\n",
959 | "EarlyStopping counter: 3 out of 15\n",
960 | "EarlyStopping counter: 1 out of 15\n",
961 | "[105/1000] Training loss: 0.6597\t Validation loss: 1.3155.\n",
962 | "EarlyStopping counter: 1 out of 15\n",
963 | "EarlyStopping counter: 2 out of 15\n",
964 | "EarlyStopping counter: 1 out of 15\n",
965 | "[110/1000] Training loss: 0.6098\t Validation loss: 1.2427.\n",
966 | "EarlyStopping counter: 2 out of 15\n",
967 | "EarlyStopping counter: 1 out of 15\n",
968 | "EarlyStopping counter: 2 out of 15\n",
969 | "EarlyStopping counter: 3 out of 15\n",
970 | "[115/1000] Training loss: 0.8273\t Validation loss: 1.1345.\n",
971 | "EarlyStopping counter: 4 out of 15\n",
972 | "EarlyStopping counter: 5 out of 15\n",
973 | "EarlyStopping counter: 6 out of 15\n",
974 | "EarlyStopping counter: 1 out of 15\n",
975 | "[120/1000] Training loss: 0.6367\t Validation loss: 1.3711.\n",
976 | "EarlyStopping counter: 2 out of 15\n",
977 | "EarlyStopping counter: 3 out of 15\n",
978 | "EarlyStopping counter: 4 out of 15\n",
979 | "EarlyStopping counter: 5 out of 15\n",
980 | "EarlyStopping counter: 6 out of 15\n",
981 | "[125/1000] Training loss: 0.4742\t Validation loss: 1.1495.\n",
982 | "EarlyStopping counter: 7 out of 15\n",
983 | "EarlyStopping counter: 8 out of 15\n",
984 | "EarlyStopping counter: 9 out of 15\n",
985 | "EarlyStopping counter: 10 out of 15\n",
986 | "EarlyStopping counter: 11 out of 15\n",
987 | "[130/1000] Training loss: 0.7824\t Validation loss: 1.1804.\n",
988 | "EarlyStopping counter: 12 out of 15\n",
989 | "EarlyStopping counter: 13 out of 15\n",
990 | "EarlyStopping counter: 14 out of 15\n",
991 | "EarlyStopping counter: 15 out of 15\n",
992 | "Early stop. Going back to the last checkpoint.\n",
993 | "Loading model from checkpoint...\n",
994 | "Checkpoint loaded.\n",
995 | "training finished.\n"
996 | ]
997 | },
998 | {
999 | "data": {
1000 | "image/png": "",
1001 | "text/plain": [
1002 | ""
1003 | ]
1004 | },
1005 | "metadata": {},
1006 | "output_type": "display_data"
1007 | }
1008 | ],
1009 | "source": [
1010 | "model = train(model, train_dataloader, val_dataloader, optimizer, epochs=1000, device=device, patience=15, temp=5)"
1011 | ]
1012 | },
1013 | {
1014 | "cell_type": "code",
1015 | "execution_count": 179,
1016 | "metadata": {},
1017 | "outputs": [
1018 | {
1019 | "name": "stdout",
1020 | "output_type": "stream",
1021 | "text": [
1022 | " precision recall f1-score support\n",
1023 | "\n",
1024 | " 0 0.65 0.73 0.69 15\n",
1025 | " 1 0.78 0.93 0.85 15\n",
1026 | " 2 0.56 0.60 0.58 15\n",
1027 | " 3 0.17 0.20 0.18 15\n",
1028 | " 4 0.75 0.60 0.67 15\n",
1029 | " 5 0.60 0.80 0.69 15\n",
1030 | " 6 0.86 0.80 0.83 15\n",
1031 | " 7 0.82 0.60 0.69 15\n",
1032 | " 8 0.62 0.53 0.57 15\n",
1033 | " 9 0.27 0.20 0.23 15\n",
1034 | "\n",
1035 | " accuracy 0.60 150\n",
1036 | " macro avg 0.61 0.60 0.60 150\n",
1037 | "weighted avg 0.61 0.60 0.60 150\n",
1038 | "\n"
1039 | ]
1040 | }
1041 | ],
1042 | "source": [
1043 | "preds, true_values, accuracy = predict(model, test_dataloader, n_features=40, loss_function=loss_function, device=device)\n",
1044 | "print(classification_report(true_values, preds))"
1045 | ]
1046 | }
1047 | ],
1048 | "metadata": {
1049 | "kernelspec": {
1050 | "display_name": "Music",
1051 | "language": "python",
1052 | "name": "python3"
1053 | },
1054 | "language_info": {
1055 | "codemirror_mode": {
1056 | "name": "ipython",
1057 | "version": 3
1058 | },
1059 | "file_extension": ".py",
1060 | "mimetype": "text/x-python",
1061 | "name": "python",
1062 | "nbconvert_exporter": "python",
1063 | "pygments_lexer": "ipython3",
1064 | "version": "3.9.17"
1065 | },
1066 | "orig_nbformat": 4
1067 | },
1068 | "nbformat": 4,
1069 | "nbformat_minor": 2
1070 | }
1071 |
--------------------------------------------------------------------------------