├── 1.선형회귀.ipynb
├── 2.이진 분류(로지스틱_회귀).ipynb
├── 3.다중 분류 (Multi-Class Classification).ipynb
├── 4.심층신경망.ipynb
├── README.md
├── slides
│   ├── 1주차_1강.pdf
│   ├── 1주차_2강.pdf
│   ├── 2주차.pdf
│   ├── 3주차.pdf
│   ├── 4주차.pdf
│   └── 5주차.pdf
├── 롤 승패 예측하기.ipynb
└── 리그 오브 레전드 승패 예측하기
    ├── README.md
    ├── x_test.csv
    ├── x_train.csv
    ├── y_test.csv
    └── y_train.csv
/1.선형회귀.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "1. 선형회귀.ipynb",
7 | "provenance": [],
8 | "authorship_tag": "ABX9TyN1XCxenMLW6BogIJ2W3kyE",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "view-in-github",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | ""
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "tMZmjJqp482w",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "# Linear Regression"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {
40 | "id": "XqXMlTwE5Gkp",
41 | "colab_type": "text"
42 | },
43 | "source": [
44 | "## Theoretical Overview\n",
45 | "$$ H(x) = Wx +b$$\n",
46 | "\n",
47 | "$$ cost(W, b) = \\frac{1}{m} \\sum^m_{i=1} \\left( H(x^{(i)}) - y^{(i)} \\right)^2$$\n",
48 | "\n",
49 | "- H(x): the prediction for input x\n",
50 | "- cost(W,b): the loss of H(x) for the given W and b"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "metadata": {
56 | "id": "F3TmX7Lg40VR",
57 | "colab_type": "code",
58 | "colab": {}
59 | },
60 | "source": [
61 | "import torch\n",
62 | "import torch.optim as optim"
63 | ],
64 | "execution_count": 0,
65 | "outputs": []
66 | },
67 | {
68 | "cell_type": "code",
69 | "metadata": {
70 | "id": "lw29q_sl7GgK",
71 | "colab_type": "code",
72 | "colab": {
73 | "base_uri": "https://localhost:8080/",
74 | "height": 36
75 | },
76 | "outputId": "80312a96-1b54-4519-95cb-750957433ccd"
77 | },
78 | "source": [
79 | "torch.manual_seed(1) # fix the random seed for reproducibility"
80 | ],
81 | "execution_count": 2,
82 | "outputs": [
83 | {
84 | "output_type": "execute_result",
85 | "data": {
86 | "text/plain": [
87 | ""
88 | ]
89 | },
90 | "metadata": {
91 | "tags": []
92 | },
93 | "execution_count": 2
94 | }
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {
100 | "id": "FLL00ihV7pPE",
101 | "colab_type": "text"
102 | },
103 | "source": [
104 | "## Data\n",
105 | "Create and use fake data that satisfies the following relation:\n",
106 | "$$ y(x) = 2x+1$$"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "metadata": {
112 | "id": "vSCMadY17ywZ",
113 | "colab_type": "code",
114 | "colab": {}
115 | },
116 | "source": [
117 | "x_train = torch.FloatTensor([[1], [2], [3]])\n",
118 | "y_train = torch.FloatTensor([[3], [5], [7]])"
119 | ],
120 | "execution_count": 0,
121 | "outputs": []
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "bBZ4Pz6P8ROy",
127 | "colab_type": "code",
128 | "colab": {
129 | "base_uri": "https://localhost:8080/",
130 | "height": 92
131 | },
132 | "outputId": "4914edd7-4faf-48f0-81cf-d0cea815fd61"
133 | },
134 | "source": [
135 | "print(x_train)\n",
136 | "print(x_train.shape)"
137 | ],
138 | "execution_count": 4,
139 | "outputs": [
140 | {
141 | "output_type": "stream",
142 | "text": [
143 | "tensor([[1.],\n",
144 | " [2.],\n",
145 | " [3.]])\n",
146 | "torch.Size([3, 1])\n"
147 | ],
148 | "name": "stdout"
149 | }
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "metadata": {
155 | "id": "8St3rJ6o8dJ2",
156 | "colab_type": "code",
157 | "colab": {
158 | "base_uri": "https://localhost:8080/",
159 | "height": 92
160 | },
161 | "outputId": "1b022b57-f721-40ba-eda2-5dd05e953e27"
162 | },
163 | "source": [
164 | "print(y_train)\n",
165 | "print(y_train.shape)"
166 | ],
167 | "execution_count": 5,
168 | "outputs": [
169 | {
170 | "output_type": "stream",
171 | "text": [
172 | "tensor([[3.],\n",
173 | " [5.],\n",
174 | " [7.]])\n",
175 | "torch.Size([3, 1])\n"
176 | ],
177 | "name": "stdout"
178 | }
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {
184 | "id": "Y_oMSoFi9BhQ",
185 | "colab_type": "text"
186 | },
187 | "source": [
188 | "## Weight Initialization\n",
189 | "\n",
190 | "`requires_grad` specifies whether a parameter should be trained, i.e. whether autograd tracks gradients for it."
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "metadata": {
196 | "id": "izJbLNFw9IML",
197 | "colab_type": "code",
198 | "colab": {
199 | "base_uri": "https://localhost:8080/",
200 | "height": 36
201 | },
202 | "outputId": "897a4317-3776-4355-956c-7879eaf71c5f"
203 | },
204 | "source": [
205 | "W = torch.zeros(1, requires_grad=True)\n",
206 | "print(W)"
207 | ],
208 | "execution_count": 6,
209 | "outputs": [
210 | {
211 | "output_type": "stream",
212 | "text": [
213 | "tensor([0.], requires_grad=True)\n"
214 | ],
215 | "name": "stdout"
216 | }
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "metadata": {
222 | "id": "EZNjEfyP9QOQ",
223 | "colab_type": "code",
224 | "colab": {
225 | "base_uri": "https://localhost:8080/",
226 | "height": 36
227 | },
228 | "outputId": "c5134de0-1f6b-433c-fec6-9f0517fe67e3"
229 | },
230 | "source": [
231 | "b = torch.zeros(1, requires_grad=True)\n",
232 | "print(b)"
233 | ],
234 | "execution_count": 7,
235 | "outputs": [
236 | {
237 | "output_type": "stream",
238 | "text": [
239 | "tensor([0.], requires_grad=True)\n"
240 | ],
241 | "name": "stdout"
242 | }
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "metadata": {
248 | "id": "j1Pszwxw9kFf",
249 | "colab_type": "text"
250 | },
251 | "source": [
252 | "## Setting the Hypothesis\n",
253 | "$$ H(x) = Wx + b $$"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "metadata": {
259 | "id": "EOfMTkR09sgO",
260 | "colab_type": "code",
261 | "colab": {
262 | "base_uri": "https://localhost:8080/",
263 | "height": 73
264 | },
265 | "outputId": "9cffeca1-a327-4982-c50c-b6bc8d0df436"
266 | },
267 | "source": [
268 | "hypothesis = x_train * W + b\n",
269 | "print(hypothesis)"
270 | ],
271 | "execution_count": 8,
272 | "outputs": [
273 | {
274 | "output_type": "stream",
275 | "text": [
276 | "tensor([[0.],\n",
277 | " [0.],\n",
278 | " [0.]], grad_fn=)\n"
279 | ],
280 | "name": "stdout"
281 | }
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {
287 | "id": "SEP4eRAo9067",
288 | "colab_type": "text"
289 | },
290 | "source": [
291 | "## Cost\n",
292 | "$$ cost(W, b) = \\frac{1}{m} \\sum^m_{i=1} \\left( H(x^{(i)}) - y^{(i)} \\right)^2$$"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "metadata": {
298 | "id": "jKAl-khJ96eU",
299 | "colab_type": "code",
300 | "colab": {
301 | "base_uri": "https://localhost:8080/",
302 | "height": 73
303 | },
304 | "outputId": "e57e48f5-bd95-4978-e817-b628128d51fd"
305 | },
306 | "source": [
307 | "print(hypothesis - y_train)"
308 | ],
309 | "execution_count": 9,
310 | "outputs": [
311 | {
312 | "output_type": "stream",
313 | "text": [
314 | "tensor([[-3.],\n",
315 | " [-5.],\n",
316 | " [-7.]], grad_fn=)\n"
317 | ],
318 | "name": "stdout"
319 | }
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "metadata": {
325 | "id": "N_LFkAXH-AOQ",
326 | "colab_type": "code",
327 | "colab": {
328 | "base_uri": "https://localhost:8080/",
329 | "height": 36
330 | },
331 | "outputId": "80bbd87a-b381-48a8-83b0-297782214d60"
332 | },
333 | "source": [
334 | "cost = torch.mean((hypothesis - y_train)**2) #MSE\n",
335 | "print(cost)"
336 | ],
337 | "execution_count": 10,
338 | "outputs": [
339 | {
340 | "output_type": "stream",
341 | "text": [
342 | "tensor(27.6667, grad_fn=)\n"
343 | ],
344 | "name": "stdout"
345 | }
346 | ]
347 | },
348 | {
349 | "cell_type": "markdown",
350 | "metadata": {
351 | "id": "OXKeWtnF-OP-",
352 | "colab_type": "text"
353 | },
354 | "source": [
355 | "## Gradient Descent"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "metadata": {
361 | "id": "tTe-plz6-Sre",
362 | "colab_type": "code",
363 | "colab": {}
364 | },
365 | "source": [
366 | "optimizer = optim.SGD([W,b], lr=0.01)"
367 | ],
368 | "execution_count": 0,
369 | "outputs": []
370 | },
371 | {
372 | "cell_type": "code",
373 | "metadata": {
374 | "id": "6StgW4-J-YFZ",
375 | "colab_type": "code",
376 | "colab": {}
377 | },
378 | "source": [
379 | "optimizer.zero_grad() # reset the gradients accumulated so far\n",
380 | "cost.backward() # backpropagate to compute the gradient of the loss with respect to each parameter\n",
381 | "optimizer.step() # update the parameters"
382 | ],
383 | "execution_count": 0,
384 | "outputs": []
385 | },
386 | {
387 | "cell_type": "code",
388 | "metadata": {
389 | "id": "uTqzaKBX-iIi",
390 | "colab_type": "code",
391 | "colab": {
392 | "base_uri": "https://localhost:8080/",
393 | "height": 55
394 | },
395 | "outputId": "a44c7a78-672f-4a2f-975a-8bde658c4f7b"
396 | },
397 | "source": [
398 | "print(W)\n",
399 | "print(b)"
400 | ],
401 | "execution_count": 13,
402 | "outputs": [
403 | {
404 | "output_type": "stream",
405 | "text": [
406 | "tensor([0.2267], requires_grad=True)\n",
407 | "tensor([0.1000], requires_grad=True)\n"
408 | ],
409 | "name": "stdout"
410 | }
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {
416 | "id": "wdhLdnqV_fxT",
417 | "colab_type": "text"
418 | },
419 | "source": [
420 | "## Validate Training Result"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "metadata": {
426 | "id": "0WBSlNen_lq9",
427 | "colab_type": "code",
428 | "colab": {
429 | "base_uri": "https://localhost:8080/",
430 | "height": 73
431 | },
432 | "outputId": "e3edb664-6b01-482d-b5d4-7c33a67487d0"
433 | },
434 | "source": [
435 | "hypothesis = x_train * W + b\n",
436 | "print(hypothesis)"
437 | ],
438 | "execution_count": 14,
439 | "outputs": [
440 | {
441 | "output_type": "stream",
442 | "text": [
443 | "tensor([[0.3267],\n",
444 | " [0.5533],\n",
445 | " [0.7800]], grad_fn=)\n"
446 | ],
447 | "name": "stdout"
448 | }
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "metadata": {
454 | "id": "s1EsjgGq_n3h",
455 | "colab_type": "code",
456 | "colab": {
457 | "base_uri": "https://localhost:8080/",
458 | "height": 36
459 | },
460 | "outputId": "e9255812-1543-43a1-8f33-70dd6af41d76"
461 | },
462 | "source": [
463 | "cost = torch.mean((hypothesis - y_train)**2) #MSE\n",
464 | "print(cost)"
465 | ],
466 | "execution_count": 15,
467 | "outputs": [
468 | {
469 | "output_type": "stream",
470 | "text": [
471 | "tensor(21.8693, grad_fn=)\n"
472 | ],
473 | "name": "stdout"
474 | }
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "metadata": {
480 | "id": "gAcxhYLU_wef",
481 | "colab_type": "text"
482 | },
483 | "source": [
484 | "## Training\n",
485 | "Now that the loss has been confirmed to decrease as expected, train for enough epochs."
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "metadata": {
491 | "id": "WrC1sw-g_0X7",
492 | "colab_type": "code",
493 | "colab": {
494 | "base_uri": "https://localhost:8080/",
495 | "height": 204
496 | },
497 | "outputId": "3304d230-8e82-48df-d6cd-ba5a565f5a7a"
498 | },
499 | "source": [
500 | "epochs = 1000\n",
501 | "for epoch in range(1, epochs+1):\n",
502 | "    # Train one step\n",
503 | "    optimizer.zero_grad()\n",
504 | "    cost.backward()\n",
505 | "    optimizer.step()\n",
506 | "\n",
507 | "    # Calculate new cost\n",
508 | "    hypothesis = x_train * W + b\n",
509 | "    cost = torch.mean((hypothesis - y_train)**2)\n",
510 | "\n",
511 | "    if epoch % 100 == 0:\n",
512 | "        print('Epoch {:4d}/{} W: {:.3f} b: {:.3f} Cost: {:.6f}'.format(epoch, epochs, W.item(), b.item(), cost.item()))"
513 | ],
514 | "execution_count": 16,
515 | "outputs": [
516 | {
517 | "output_type": "stream",
518 | "text": [
519 | "Epoch 100/1000 W: 2.035 b: 0.921 Cost: 0.000895\n",
520 | "Epoch 200/1000 W: 2.027 b: 0.938 Cost: 0.000553\n",
521 | "Epoch 300/1000 W: 2.021 b: 0.951 Cost: 0.000342\n",
522 | "Epoch 400/1000 W: 2.017 b: 0.962 Cost: 0.000211\n",
523 | "Epoch 500/1000 W: 2.013 b: 0.970 Cost: 0.000130\n",
524 | "Epoch 600/1000 W: 2.010 b: 0.976 Cost: 0.000081\n",
525 | "Epoch 700/1000 W: 2.008 b: 0.981 Cost: 0.000050\n",
526 | "Epoch 800/1000 W: 2.006 b: 0.985 Cost: 0.000031\n",
527 | "Epoch 900/1000 W: 2.005 b: 0.988 Cost: 0.000019\n",
528 | "Epoch 1000/1000 W: 2.004 b: 0.991 Cost: 0.000012\n"
529 | ],
530 | "name": "stdout"
531 | }
532 | ]
533 | },
534 | {
535 | "cell_type": "markdown",
536 | "metadata": {
537 | "id": "yB5aDwSTA6lO",
538 | "colab_type": "text"
539 | },
540 | "source": [
541 | "## Test"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "metadata": {
547 | "id": "e2zrXzIZA8r6",
548 | "colab_type": "code",
549 | "colab": {}
550 | },
551 | "source": [
552 | "x_test = torch.FloatTensor([[5], [7], [10]])\n",
553 | "y_test = torch.FloatTensor([[11], [15], [21]])"
554 | ],
555 | "execution_count": 0,
556 | "outputs": []
557 | },
558 | {
559 | "cell_type": "code",
560 | "metadata": {
561 | "id": "fvP6TnTPBK2S",
562 | "colab_type": "code",
563 | "colab": {
564 | "base_uri": "https://localhost:8080/",
565 | "height": 92
566 | },
567 | "outputId": "3dce38f4-0c27-4742-940a-1b8bb625236a"
568 | },
569 | "source": [
570 | "hypothesis = W * x_test + b\n",
571 | "cost = torch.mean((hypothesis - y_test)**2)\n",
572 | "\n",
573 | "print(hypothesis)\n",
574 | "print('Cost:',cost)"
575 | ],
576 | "execution_count": 18,
577 | "outputs": [
578 | {
579 | "output_type": "stream",
580 | "text": [
581 | "tensor([[11.0109],\n",
582 | " [15.0188],\n",
583 | " [21.0308]], grad_fn=)\n",
584 | "Cost: tensor(0.0005, grad_fn=)\n"
585 | ],
586 | "name": "stdout"
587 | }
588 | ]
589 | }
590 | ]
591 | }
592 |
--------------------------------------------------------------------------------
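The first gradient-descent step in the linear regression notebook above can be checked by hand. What follows is a minimal pure-Python sketch (no PyTorch), assuming the same toy data, zero-initialized `W` and `b`, and a learning rate of 0.01. The helper name `mse_and_grads` is hypothetical and does not appear in the notebook.

```python
# Toy data from the notebook: y = 2x + 1
xs = [1.0, 2.0, 3.0]
ys = [3.0, 5.0, 7.0]


def mse_and_grads(w, b):
    """Return the MSE cost and its gradients with respect to w and b."""
    m = len(xs)
    errors = [w * x + b - y for x, y in zip(xs, ys)]
    cost = sum(e * e for e in errors) / m
    grad_w = 2.0 / m * sum(e * x for e, x in zip(errors, xs))
    grad_b = 2.0 / m * sum(errors)
    return cost, grad_w, grad_b


lr = 0.01          # same learning rate as the notebook's optimizer
w, b = 0.0, 0.0    # zero initialization, as in the notebook
cost0, grad_w, grad_b = mse_and_grads(w, b)

# One plain gradient-descent step (SGD without momentum reduces to this here)
w -= lr * grad_w
b -= lr * grad_b
cost1, _, _ = mse_and_grads(w, b)

print(f"cost before: {cost0:.4f}  W: {w:.4f}  b: {b:.4f}  cost after: {cost1:.4f}")
```

Under these assumptions the numbers should agree with the notebook's printed values: a cost of about 27.6667 before the step, `W` of about 0.2267 and `b` of about 0.1000 after it, and a cost of about 21.8693 afterwards.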
/2.이진 분류(로지스틱_회귀).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "4.로지스틱 회귀",
7 | "provenance": []
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | }
13 | },
14 | "cells": [
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "id": "yLWxB9XMk28U",
19 | "colab_type": "text"
20 | },
21 | "source": [
22 | "# Logistic Regression"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "id": "dq2UAIVClurv",
29 | "colab_type": "text"
30 | },
31 | "source": [
32 | "## Hypothesis\n",
33 | "$$ H(X) = \\frac{1}{1+e^{-WX}} $$"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {
39 | "id": "6ZbwZhcZl8Rc",
40 | "colab_type": "text"
41 | },
42 | "source": [
43 | "## Cost\n",
44 | "$$ cost(W) = -\\frac{1}{m} \\sum \\left[ y \\log(H(x)) + (1 - y) \\log(1 - H(x)) \\right] $$\n",
45 | "- $Log$ loss for binary classification\n",
46 | "  - If $H(x)$ approaches the true label $y$, the cost converges to 0.\n",
47 | "  - If $H(x)$ approaches the wrong label (e.g. $y = 1$ while $H(x) \\to 0$), the cost grows toward $\\infty$."
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {
53 | "id": "ikgnr9jwnR7X",
54 | "colab_type": "text"
55 | },
56 | "source": [
57 | "## Weight Update via Gradient Descent\n",
58 | "$$ W := W - lr\\cdot \\frac{\\partial}{\\partial W}cost(W)$$"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {
64 | "id": "2lG1GqGEn05x",
65 | "colab_type": "text"
66 | },
67 | "source": [
68 | "# Prepare Data"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "metadata": {
74 | "id": "ZFcK8GO2lv0q",
75 | "colab_type": "code",
76 | "outputId": "1787ceac-ca9b-4c09-f58c-1727cf654696",
77 | "colab": {
78 | "base_uri": "https://localhost:8080/",
79 | "height": 158
80 | }
81 | },
82 | "source": [
83 | "import torch\n",
84 | "import torch.optim as optim\n",
85 | "\n",
86 | "torch.manual_seed(1)\n",
87 | "\n",
88 | "x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]]\n",
89 | "y_data = [[0], [0], [0], [1], [1], [1]]\n",
90 | "\n",
91 | "x_train = torch.FloatTensor(x_data)\n",
92 | "y_train = torch.FloatTensor(y_data)\n",
93 | "\n",
94 | "print(x_train)\n",
95 | "print(x_train.shape)\n",
96 | "print(y_train.shape)"
97 | ],
98 | "execution_count": 1,
99 | "outputs": [
100 | {
101 | "output_type": "stream",
102 | "text": [
103 | "tensor([[1., 2.],\n",
104 | " [2., 3.],\n",
105 | " [3., 1.],\n",
106 | " [4., 3.],\n",
107 | " [5., 3.],\n",
108 | " [6., 2.]])\n",
109 | "torch.Size([6, 2])\n",
110 | "torch.Size([6, 1])\n"
111 | ],
112 | "name": "stdout"
113 | }
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {
119 | "id": "36fb8seLoNer",
120 | "colab_type": "text"
121 | },
122 | "source": [
123 | "# Train Model"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "metadata": {
129 | "id": "lk6PmgvYoXPP",
130 | "colab_type": "code",
131 | "outputId": "23d0dbee-3fca-433c-c36c-1f7915243b29",
132 | "colab": {
133 | "base_uri": "https://localhost:8080/",
134 | "height": 210
135 | }
136 | },
137 | "source": [
138 | "W = torch.ones((2, 1), requires_grad=True)\n",
139 | "b = torch.zeros(1, requires_grad=True)\n",
140 | "\n",
141 | "optimizer = optim.SGD([W, b], lr=1)\n",
142 | "\n",
143 | "epochs = 1000\n",
144 | "for epoch in range(epochs + 1):\n",
145 | "    # Calculate Cost\n",
146 | "    hypothesis = torch.sigmoid(x_train.matmul(W) + b)\n",
147 | "    cost = -(y_train * torch.log(hypothesis) +\n",
148 | "             (1 - y_train) * torch.log(1 - hypothesis)).mean()\n",
149 | "\n",
150 | "    # Train\n",
151 | "    optimizer.zero_grad()\n",
152 | "    cost.backward()\n",
153 | "    optimizer.step()\n",
154 | "\n",
155 | "    if epoch % 100 == 0:\n",
156 | "        print('Epoch {:4d}/{} Cost: {:.6f}'.format(\n",
157 | "            epoch, epochs, cost.item()\n",
158 | "        ))"
159 | ],
160 | "execution_count": 2,
161 | "outputs": [
162 | {
163 | "output_type": "stream",
164 | "text": [
165 | "Epoch 0/1000 Cost: 2.012506\n",
166 | "Epoch 100/1000 Cost: 0.131138\n",
167 | "Epoch 200/1000 Cost: 0.079377\n",
168 | "Epoch 300/1000 Cost: 0.057256\n",
169 | "Epoch 400/1000 Cost: 0.044908\n",
170 | "Epoch 500/1000 Cost: 0.036998\n",
171 | "Epoch 600/1000 Cost: 0.031483\n",
172 | "Epoch 700/1000 Cost: 0.027413\n",
173 | "Epoch 800/1000 Cost: 0.024282\n",
174 | "Epoch 900/1000 Cost: 0.021798\n",
175 | "Epoch 1000/1000 Cost: 0.019778\n"
176 | ],
177 | "name": "stdout"
178 | }
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {
184 | "id": "PCDz7KX1prB1",
185 | "colab_type": "text"
186 | },
187 | "source": [
188 | "# Evaluate Model"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "metadata": {
194 | "id": "jQ560dl9pad4",
195 | "colab_type": "code",
196 | "outputId": "bb556b89-31b3-4858-dc7e-af39aee3d187",
197 | "colab": {
198 | "base_uri": "https://localhost:8080/",
199 | "height": 228
200 | }
201 | },
202 | "source": [
203 | "hypothesis = torch.sigmoid(x_train.mm(W) + b)\n",
204 | "\n",
205 | "prediction = hypothesis >= torch.FloatTensor([0.5]) # True where the hypothesis is >= 0.5, otherwise False\n",
206 | "print(prediction)\n",
207 | "\n",
208 | "correct_prediction = prediction.float() == y_train # convert True/False to 1.0/0.0, then compare with y_train\n",
209 | "print(correct_prediction)"
210 | ],
211 | "execution_count": 3,
212 | "outputs": [
213 | {
214 | "output_type": "stream",
215 | "text": [
216 | "tensor([[False],\n",
217 | " [False],\n",
218 | " [False],\n",
219 | " [ True],\n",
220 | " [ True],\n",
221 | " [ True]])\n",
222 | "tensor([[True],\n",
223 | " [True],\n",
224 | " [True],\n",
225 | " [True],\n",
226 | " [True],\n",
227 | " [True]])\n"
228 | ],
229 | "name": "stdout"
230 | }
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "metadata": {
236 | "id": "iBnVXUDiqDAl",
237 | "colab_type": "code",
238 | "outputId": "2809ba7c-c586-4dc9-9de0-f0cd41621046",
239 | "colab": {
240 | "base_uri": "https://localhost:8080/",
241 | "height": 34
242 | }
243 | },
244 | "source": [
245 | "accuracy = correct_prediction.sum().item() / len(correct_prediction)\n",
246 | "print('The model has an accuracy of {:2.2f}% for the training set.'.format(accuracy * 100))"
247 | ],
248 | "execution_count": 4,
249 | "outputs": [
250 | {
251 | "output_type": "stream",
252 | "text": [
253 | "The model has an accuracy of 100.00% for the training set.\n"
254 | ],
255 | "name": "stdout"
256 | }
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "metadata": {
262 | "id": "v70dHimtqMx1",
263 | "colab_type": "code",
264 | "outputId": "62f79c79-fb7a-4aec-e423-bbba0113bbda",
265 | "colab": {
266 | "base_uri": "https://localhost:8080/",
267 | "height": 34
268 | }
269 | },
270 | "source": [
271 | "XX = [[100, 5]]\n",
272 | "xx = torch.FloatTensor(XX)\n",
273 | "hypothesis = torch.sigmoid(xx.matmul(W) + b)\n",
274 | "prediction = hypothesis >= torch.FloatTensor([0.5])\n",
275 | "print(prediction)"
276 | ],
277 | "execution_count": 5,
278 | "outputs": [
279 | {
280 | "output_type": "stream",
281 | "text": [
282 | "tensor([[True]])\n"
283 | ],
284 | "name": "stdout"
285 | }
286 | ]
287 | }
288 | ]
289 | }
--------------------------------------------------------------------------------
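The log loss used in the logistic regression notebook above calls `torch.log` on the sigmoid output, which can hit `log(0)` when the sigmoid saturates. As a hedged illustration, the pure-Python sketch below evaluates the same cost from the logit `z` through the identity `-(y*log(s) + (1-y)*log(1-s)) = softplus(z) - y*z`, where `s = sigmoid(z)`. It assumes the notebook's data and its initial parameters (`W` all ones, `b = 0`). The helper names `softplus` and `bce_from_logit` are hypothetical.

```python
import math

# Data and labels from the notebook
x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]]
y_data = [0, 0, 0, 1, 1, 1]


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_from_logit(z, y):
    """Binary cross entropy for one sample, given logit z and label y in {0, 1}.

    Algebraically equal to -(y*log(s) + (1-y)*log(1-s)) with s = sigmoid(z),
    but it never takes the log of a saturated sigmoid.
    """
    return softplus(z) - y * z


# Initial parameters as in the notebook: W = ones((2, 1)), b = 0
w = [1.0, 1.0]
b = 0.0

logits = [w[0] * x1 + w[1] * x2 + b for x1, x2 in x_data]
cost = sum(bce_from_logit(z, y) for z, y in zip(logits, y_data)) / len(y_data)
print(f"initial cost: {cost:.6f}")
```

With these inputs the result should match the cost the notebook prints at epoch 0 (about 2.012506), while staying finite even for very large logits.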
/3.다중 분류 (Multi-Class Classification).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Untitled3.ipynb",
7 | "provenance": []
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | }
13 | },
14 | "cells": [
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "id": "Rd0RDB6_mF2f",
19 | "colab_type": "text"
20 | },
21 | "source": [
22 | "# Multi-Class Classification\n",
23 | "When there are three or more classes, classify using the softmax function."
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {
29 | "id": "MfqKrtLBmFSz",
30 | "colab_type": "text"
31 | },
32 | "source": [
33 | "## Load Data"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "metadata": {
39 | "id": "crftV5RUyMWg",
40 | "colab_type": "code",
41 | "outputId": "4fd9fc05-ab7b-4e9d-f6d2-34c63b6409de",
42 | "colab": {
43 | "base_uri": "https://localhost:8080/",
44 | "height": 122
45 | }
46 | },
47 | "source": [
48 | "import torch\n",
49 | "import torch.optim as optim\n",
50 | "\n",
51 | "torch.manual_seed(1)\n",
52 | "\n",
53 | "x_train = [[1, 2, 1, 1],\n",
54 | "           [2, 1, 3, 2],\n",
55 | "           [3, 1, 3, 4],\n",
56 | "           [4, 1, 5, 5],\n",
57 | "           [1, 7, 5, 5],\n",
58 | "           [1, 2, 5, 6],\n",
59 | "           [1, 6, 6, 6],\n",
60 | "           [1, 7, 7, 7]]\n",
61 | "y_train = [2, 2, 2, 1, 1, 1, 0, 0]\n",
62 | "\n",
63 | "x_train = torch.FloatTensor(x_train)\n",
64 | "y_train = torch.LongTensor(y_train)\n",
65 | "\n",
66 | "print(x_train[:5])\n",
67 | "print(y_train[:5])"
68 | ],
69 | "execution_count": 1,
70 | "outputs": [
71 | {
72 | "output_type": "stream",
73 | "text": [
74 | "tensor([[1., 2., 1., 1.],\n",
75 | " [2., 1., 3., 2.],\n",
76 | " [3., 1., 3., 4.],\n",
77 | " [4., 1., 5., 5.],\n",
78 | " [1., 7., 5., 5.]])\n",
79 | "tensor([2, 2, 2, 1, 1])\n"
80 | ],
81 | "name": "stdout"
82 | }
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {
88 | "id": "e23GdH-xYftA",
89 | "colab_type": "text"
90 | },
91 | "source": [
92 | "## one-hot encoding"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "miaw_JqdbsWd",
99 | "colab_type": "code",
100 | "outputId": "b36a3684-b581-4ce4-b0c3-590b9ed37068",
101 | "colab": {
102 | "base_uri": "https://localhost:8080/",
103 | "height": 351
104 | }
105 | },
106 | "source": [
107 | "print(y_train)\n",
108 | "print(y_train.shape)\n",
109 | "print(y_train.unsqueeze(1)) # unsqueeze inserts a new dimension of size 1 at the given position\n",
110 | "print(y_train.unsqueeze(1).shape)\n",
111 | "\n",
112 | "nb_class = 3\n",
113 | "nb_data = len(y_train)\n",
114 | "y_one_hot = torch.zeros(nb_data, nb_class)\n",
115 | "y_one_hot.scatter_(1, y_train.unsqueeze(1), 1) # scatter_(dim, index, value): along dimension dim, write value at the positions given by index\n",
116 | "# A trailing underscore (_) marks an in-place operation: the result is written directly into y_one_hot\n",
117 | "# (there is no need to write y_one_hot = y_one_hot.scatter(...))\n",
118 | "\n",
119 | "print(y_one_hot)"
120 | ],
121 | "execution_count": 2,
122 | "outputs": [
123 | {
124 | "output_type": "stream",
125 | "text": [
126 | "tensor([2, 2, 2, 1, 1, 1, 0, 0])\n",
127 | "torch.Size([8])\n",
128 | "tensor([[2],\n",
129 | " [2],\n",
130 | " [2],\n",
131 | " [1],\n",
132 | " [1],\n",
133 | " [1],\n",
134 | " [0],\n",
135 | " [0]])\n",
136 | "torch.Size([8, 1])\n",
137 | "tensor([[0., 0., 1.],\n",
138 | " [0., 0., 1.],\n",
139 | " [0., 0., 1.],\n",
140 | " [0., 1., 0.],\n",
141 | " [0., 1., 0.],\n",
142 | " [0., 1., 0.],\n",
143 | " [1., 0., 0.],\n",
144 | " [1., 0., 0.]])\n"
145 | ],
146 | "name": "stdout"
147 | }
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {
153 | "id": "OTerh5sZo69T",
154 | "colab_type": "text"
155 | },
156 | "source": [
157 | "## Train Model"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "metadata": {
163 | "id": "GHSZppOfyp-n",
164 | "colab_type": "code",
165 | "outputId": "84ef90d7-5661-4fe4-c255-f2c44d7b279a",
166 | "colab": {
167 | "base_uri": "https://localhost:8080/",
168 | "height": 210
169 | }
170 | },
171 | "source": [
172 | "import torch.nn.functional as F # module that provides softmax\n",
173 | "\n",
174 | "W = torch.zeros((4, nb_class), requires_grad=True) # one weight per feature (4) for each of the 3 classes\n",
175 | "b = torch.zeros(nb_class, requires_grad=True) # the bias can be a single value or one per class\n",
176 | "\n",
177 | "# set up the optimizer\n",
178 | "optimizer = optim.SGD([W, b], lr=0.01)\n",
179 | "nb_epochs = 1000\n",
180 | "for epoch in range(nb_epochs + 1):\n",
181 | "\n",
182 | "    # Compute the hypothesis\n",
183 | "    hypothesis = F.softmax(x_train.matmul(W) + b, dim=1) # or .mm or @\n",
184 | "\n",
185 | "    # Cost, option 1:\n",
186 | "    # the labels are one-hot encoded and cross entropy is implemented by hand\n",
187 | "    cost = (y_one_hot * -torch.log(hypothesis)).sum(dim=1).mean()\n",
188 | "\n",
189 | "    # Cost, option 2:\n",
190 | "    # use the cross entropy provided by torch.nn.functional\n",
191 | "    # with F.cross_entropy, one-hot encoding via scatter_ is not needed\n",
192 | "    # cost = F.cross_entropy((x_train.matmul(W) + b), y_train)\n",
193 | "\n",
194 | "    # improve H(x) using the cost\n",
195 | "    optimizer.zero_grad()\n",
196 | "    cost.backward()\n",
197 | "    optimizer.step()\n",
198 | "\n",
199 | "    # log every 100 epochs\n",
200 | "    if epoch % 100 == 0:\n",
201 | "        print('Epoch {:4d}/{} Cost: {:.6f}'.format(\n",
202 | "            epoch, nb_epochs, cost.item()\n",
203 | "        ))"
204 | ],
205 | "execution_count": 3,
206 | "outputs": [
207 | {
208 | "output_type": "stream",
209 | "text": [
210 | "Epoch 0/1000 Cost: 1.098612\n",
211 | "Epoch 100/1000 Cost: 0.825978\n",
212 | "Epoch 200/1000 Cost: 0.745367\n",
213 | "Epoch 300/1000 Cost: 0.695094\n",
214 | "Epoch 400/1000 Cost: 0.658135\n",
215 | "Epoch 500/1000 Cost: 0.629088\n",
216 | "Epoch 600/1000 Cost: 0.605386\n",
217 | "Epoch 700/1000 Cost: 0.585531\n",
218 | "Epoch 800/1000 Cost: 0.568553\n",
219 | "Epoch 900/1000 Cost: 0.553787\n",
220 | "Epoch 1000/1000 Cost: 0.540759\n"
221 | ],
222 | "name": "stdout"
223 | }
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {
229 | "id": "QixkNjHgpwHM",
230 | "colab_type": "text"
231 | },
232 | "source": [
233 | "## Evaluate Model"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "metadata": {
239 | "id": "QFcRqSOtyxNR",
240 | "colab_type": "code",
241 | "outputId": "03ca8640-01c0-4107-a7bc-9c4d8efefc57",
242 | "colab": {
243 | "base_uri": "https://localhost:8080/",
244 | "height": 228
245 | }
246 | },
247 | "source": [
248 | "# predict classes with the trained W and b\n",
249 | "hypothesis = F.softmax(x_train.matmul(W) + b, dim=1) # or .mm or @\n",
250 | "predict = torch.argmax(hypothesis, dim=1) # index of the largest value in each row\n",
251 | "\n",
252 | "print(hypothesis)\n",
253 | "print(predict)\n",
254 | "print(y_train)\n",
255 | "\n",
256 | "\n",
257 | "# compute accuracy\n",
258 | "correct_prediction = predict == y_train\n",
259 | "print(correct_prediction)\n",
260 | "accuracy = correct_prediction.sum().item() / len(correct_prediction)\n",
261 | "print('The model has an accuracy of {:2.2f}% for the training set.'.format(accuracy * 100))"
262 | ],
263 | "execution_count": 4,
264 | "outputs": [
265 | {
266 | "output_type": "stream",
267 | "text": [
268 | "tensor([[0.0808, 0.1519, 0.7673],\n",
269 | " [0.0561, 0.2938, 0.6502],\n",
270 | " [0.0169, 0.4154, 0.5677],\n",
271 | " [0.0107, 0.5506, 0.4387],\n",
272 | " [0.6195, 0.3130, 0.0675],\n",
273 | " [0.2591, 0.7232, 0.0178],\n",
274 | " [0.5794, 0.3999, 0.0206],\n",
275 | " [0.6471, 0.3453, 0.0075]], grad_fn=)\n",
276 | "tensor([2, 2, 2, 1, 0, 1, 0, 0], grad_fn=)\n",
277 | "tensor([2, 2, 2, 1, 1, 1, 0, 0])\n",
278 | "tensor([ True, True, True, True, False, True, True, True])\n",
279 | "The model has an accuracy of 87.50% for the training set.\n"
280 | ],
281 | "name": "stdout"
282 | }
283 | ]
284 | }
285 | ]
286 | }
--------------------------------------------------------------------------------
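The multi-class notebook above notes that the hand-written one-hot cross entropy and `F.cross_entropy` on class indices are interchangeable. The pure-Python sketch below illustrates why, assuming the notebook's labels and its zero-initialized weights, so that every logit is 0 at epoch 0. The helper names `softmax`, `cross_entropy_one_hot`, and `cross_entropy_index` are hypothetical, and the code is not PyTorch's implementation.

```python
import math


def softmax(logits):
    """Softmax with the maximum subtracted first for numerical stability."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_one_hot(probs, one_hot):
    """Cross entropy against a one-hot target vector."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def cross_entropy_index(probs, label):
    """Cross entropy against a class index: picks out the true-class term."""
    return -math.log(probs[label])


labels = [2, 2, 2, 1, 1, 1, 0, 0]   # y_train from the notebook
nb_class = 3

# W and b start at zero in the notebook, so all logits are zero at epoch 0
logits = [[0.0] * nb_class for _ in labels]
probs = [softmax(row) for row in logits]

one_hots = [[1.0 if k == y else 0.0 for k in range(nb_class)] for y in labels]
cost_one_hot = sum(cross_entropy_one_hot(p, t) for p, t in zip(probs, one_hots)) / len(labels)
cost_index = sum(cross_entropy_index(p, y) for p, y in zip(probs, labels)) / len(labels)

print(f"one-hot: {cost_one_hot:.6f}  index: {cost_index:.6f}")
```

Both forms give the same value because the one-hot vector zeroes out every term except the true class. With uniform predictions over three classes this is `log(3)`, which agrees with the cost of about 1.098612 the notebook prints at epoch 0.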
/4.심층신경망.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "8.심층신경망.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "vaU2pRqXqESs",
20 | "colab_type": "text"
21 | },
22 | "source": [
23 | "# Deep Neural Networks (DNNs)\n",
24 | "MNIST Dataset Classification with a DNN"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "metadata": {
30 | "id": "sZs_JS5GqD3v",
31 | "colab_type": "code",
32 | "colab": {}
33 | },
34 | "source": [
35 | "import torch\n",
36 | "import torchvision.datasets as dsets\n",
37 | "import torchvision.transforms as transforms\n",
38 | "from torch.utils.data import DataLoader\n",
39 | "import matplotlib.pyplot as plt\n",
40 | "import random"
41 | ],
42 | "execution_count": 0,
43 | "outputs": []
44 | },
45 | {
46 | "cell_type": "code",
47 | "metadata": {
48 | "id": "vSu0b7UZqSWa",
49 | "colab_type": "code",
50 | "colab": {}
51 | },
52 | "source": [
53 | "GPU = torch.cuda.is_available() \n",
54 | "device = torch.device(\"cuda\" if GPU else \"cpu\") # use the GPU if available, otherwise the CPU"
55 | ],
56 | "execution_count": 0,
57 | "outputs": []
58 | },
59 | {
60 | "cell_type": "code",
61 | "metadata": {
62 | "id": "iFlS9aB3qdS_",
63 | "colab_type": "code",
64 | "colab": {}
65 | },
66 | "source": [
67 | "torch.manual_seed(1)\n",
68 | "if device.type == 'cuda':\n",
69 | "    torch.cuda.manual_seed_all(1)"
70 | ],
71 | "execution_count": 0,
72 | "outputs": []
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {
77 | "id": "v3TrMkiPqn8B",
78 | "colab_type": "text"
79 | },
80 | "source": [
81 | "## Load MNIST Dataset"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "metadata": {
87 | "id": "DY1b__-fqmPI",
88 | "colab_type": "code",
89 | "colab": {}
90 | },
91 | "source": [
92 | "mnist_train = dsets.MNIST(root='data/',\n",
93 | "                          train=True,\n",
94 | "                          transform=transforms.ToTensor(),\n",
95 | "                          download=True)\n",
96 | "mnist_test = dsets.MNIST(root='data/',\n",
97 | "                         train=False,\n",
98 | "                         transform=transforms.ToTensor(),\n",
99 | "                         download=True)"
100 | ],
101 | "execution_count": 0,
102 | "outputs": []
103 | },
104 | {
105 | "cell_type": "code",
106 | "metadata": {
107 | "id": "0SnG76TJqxUj",
108 | "colab_type": "code",
109 | "colab": {}
110 | },
111 | "source": [
112 | "data_loader = DataLoader(dataset=mnist_train,\n",
113 | "                         batch_size=100,\n",
114 | "                         shuffle=True,\n",
115 | "                         drop_last=True) # drop the last incomplete batch"
116 | ],
117 | "execution_count": 0,
118 | "outputs": []
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {
123 | "id": "FRIkMK9nrFCC",
124 | "colab_type": "text"
125 | },
126 | "source": [
127 | "## DNN Model\n",
128 | "| Layer | Units |\n",
129 | "|--|--|\n",
130 | "| Input | 784 |\n",
131 | "| Dense | 512 |\n",
132 | "| ReLU | |\n",
133 | "| Dense (Output) | 10 |"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "metadata": {
139 | "id": "jat07t9IrHfv",
140 | "colab_type": "code",
141 | "outputId": "ee78e709-4f28-4243-dc97-43afe72888fc",
142 | "colab": {
143 | "base_uri": "https://localhost:8080/",
144 | "height": 111
145 | }
146 | },
147 | "source": [
148 | "l1 = torch.nn.Linear((28*28), 512).to(device)\n",
149 | "l2 = torch.nn.Linear(512, 10).to(device)\n",
150 | "relu = torch.nn.ReLU()\n",
151 | "\n",
152 | "model = torch.nn.Sequential(l1, relu, l2)\n",
153 | "model"
154 | ],
155 | "execution_count": 0,
156 | "outputs": [
157 | {
158 | "output_type": "execute_result",
159 | "data": {
160 | "text/plain": [
161 | "Sequential(\n",
162 | " (0): Linear(in_features=784, out_features=512, bias=True)\n",
163 | " (1): ReLU()\n",
164 | " (2): Linear(in_features=512, out_features=10, bias=True)\n",
165 | ")"
166 | ]
167 | },
168 | "metadata": {
169 | "tags": []
170 | },
171 | "execution_count": 6
172 | }
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {
178 | "id": "sCPzWgSLsucr",
179 | "colab_type": "text"
180 | },
181 | "source": [
182 | "## Train Model"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "metadata": {
188 | "id": "Nd4Pa7gZsQjY",
189 | "colab_type": "code",
190 | "colab": {}
191 | },
192 | "source": [
193 | "cost = torch.nn.CrossEntropyLoss().to(device) # Built in Softmax Function\n",
194 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)"
195 | ],
196 | "execution_count": 0,
197 | "outputs": []
198 | },
199 | {
200 | "cell_type": "code",
201 | "metadata": {
202 | "id": "SU3xQ2R0tGQJ",
203 | "colab_type": "code",
204 | "outputId": "12fd388a-5169-4d6b-da60-97e09ca3619b",
205 | "colab": {
206 | "base_uri": "https://localhost:8080/",
207 | "height": 111
208 | }
209 | },
210 | "source": [
211 | "epochs = 20\n",
212 | "for epoch in range(1, epochs+1):\n",
213 | "    avg_cost = 0\n",
214 | "    total_batch = len(data_loader)\n",
215 | "\n",
216 | "    for x, y in data_loader: # batch loop\n",
217 | "        x = x.view(-1, 28*28).to(device) # flatten each 28x28 image into a 784-dim vector\n",
218 | "        y = y.to(device)\n",
219 | "\n",
220 | "        optimizer.zero_grad()\n",
221 | "        hypothesis = model(x)\n",
222 | "        cost_val = cost(hypothesis, y)\n",
223 | "        cost_val.backward()\n",
224 | "        optimizer.step()\n",
225 | "\n",
226 | "        avg_cost += cost_val.detach() # detach so the running sum does not hold on to the autograd graph\n",
227 | "\n",
228 | "    avg_cost /= total_batch\n",
229 | "\n",
230 | "    if epoch % 5 == 1 or epoch == epochs:\n",
231 | "        print('Epoch {:4d}/{} Cost: {:.6f}'.format(epoch, epochs, avg_cost.item()))"
232 | ],
233 | "execution_count": 0,
234 | "outputs": [
235 | {
236 | "output_type": "stream",
237 | "text": [
238 | "Epoch 1/20 Cost: 0.503713\n",
239 | "Epoch 6/20 Cost: 0.121553\n",
240 | "Epoch 11/20 Cost: 0.070358\n",
241 | "Epoch 16/20 Cost: 0.047347\n",
242 | "Epoch 20/20 Cost: 0.035963\n"
243 | ],
244 | "name": "stdout"
245 | }
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {
251 | "id": "tbGC7MbHulXu",
252 | "colab_type": "text"
253 | },
254 | "source": [
255 | "## Evaluate Model"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "metadata": {
261 | "id": "K4VM9AxauV9K",
262 | "colab_type": "code",
263 | "outputId": "97731ca7-e06c-46fb-b14e-a07e834a1856",
264 | "colab": {
265 | "base_uri": "https://localhost:8080/",
266 | "height": 396
267 | }
268 | },
269 | "source": [
270 | "with torch.no_grad(): # no gradients are needed for evaluation\n",
271 | "    # NOTE: .data holds the raw uint8 pixels (0-255); the ToTensor() scaling used for training is not applied here\n",
272 | "    x_test = mnist_test.data.view(-1, 28 * 28).float().to(device)\n",
273 | "    y_test = mnist_test.targets.to(device)\n",
274 | "\n",
275 | "    pred = model(x_test)\n",
276 | "    correct_pred = torch.argmax(pred, 1) == y_test\n",
277 | "    acc = correct_pred.float().mean()\n",
278 | "    print('Accuracy:',acc.item())\n",
279 | "\n",
280 | "    r = random.randint(0, len(mnist_test) - 1) # pick one random test sample\n",
281 | "    X_single_data = mnist_test.data[r].view(-1, 28 * 28).float().to(device)\n",
282 | "    Y_single_data = mnist_test.targets[r].to(device)\n",
283 | "\n",
284 | "    print('Label: ', Y_single_data.item())\n",
285 | "    single_prediction = model(X_single_data)\n",
286 | "    print('Prediction: ', torch.argmax(single_prediction, 1).item())\n",
287 | "\n",
288 | "    plt.imshow(mnist_test.data[r:r + 1].view(28, 28), cmap='Greys', interpolation='nearest')\n",
289 | "    plt.show()"
290 | ],
291 | "execution_count": 0,
292 | "outputs": [
302 | {
303 | "output_type": "stream",
304 | "text": [
305 | "Accuracy: 0.977400004863739\n",
306 | "Label: 5\n",
307 | "Prediction: 5\n"
308 | ],
309 | "name": "stdout"
310 | },
311 | {
312 | "output_type": "display_data",
313 | "data": {
314 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAN0klEQVR4nO3db4xUZZbH8d8BIYY/MdB0SOMYGUeNMRtloEI2AQnrxAngCyRRM5gQSEx6EiRCMi9WZzVjNPHPZgE3Zh3tWQm9G5ZxEoeIia7DkvHP+AJpkFW0MyNrMEBauogvRkIQgbMv+jJpse9TZdWtugXn+0kqVXVP3b4nV3/cqvvUrcfcXQAuf+PKbgBAexB2IAjCDgRB2IEgCDsQxBXt3NiMGTN89uzZ7dwkEMrhw4d14sQJG6vWVNjNbImkf5U0XtK/u/vTqdfPnj1bAwMDzWwSQEKlUsmtNfw23szGS/o3SUsl3SxppZnd3OjfA9BazXxmny/pkLt/5u5nJP1W0vJi2gJQtGbCfrWkI6OeH82WfYuZ9ZrZgJkNVKvVJjYHoBktPxvv7n3uXnH3Snd3d6s3ByBHM2E/JumaUc9/kC0D0IGaCfteSTeY2Q/NbKKkn0naWUxbAIrW8NCbu581s3WS3tTI0NsWd/+4sM4AFKqpcXZ3f13S6wX1AqCF+LosEARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IIimpmw2s8OSvpJ0TtJZd68U0RSA4jUV9sw/uPuJAv4OgBbibTwQRLNhd0l/MLN9ZtY71gvMrNfMBsxsoFqtNrk5AI1qNuwL3X2upKWSHjCzRRe/wN373L3i7pXu7u4mNwegUU2F3d2PZffDknZIml9EUwCK13DYzWyymU298FjSTyUdLKoxAMVq5mz8TEk7zOzC3/kvd//vQrrCJePrr79O1o8fP55bGx4ebmrbb7zxRrJ+4kTrBonGjx+frN93333J+k033ZRbmzJlSkM91dJw2N39M0m3FtgLgBZi6A0IgrADQRB2IAjCDgRB2IEgirgQBh3s9OnTyfqTTz6ZrL/11lvJ+t69e5P1b775Jrd2/vz55LqXsmeffTZZ37BhQ25t06ZNRbcjiSM7EAZhB4Ig7EAQhB0IgrADQRB2IAjCDgTBOPtlIDWWvmLFiuS6b775ZtHtfMuSJUtya0eOHEmuW+vy2Xnz5iXrZ86cya198MEHyXW7urqS9UmTJiXrs2bNStYnTJiQrLcCR3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIJx9svA2bNnc2upnyyWpLVr1ybr06dPT9ZrjXVPnDgxt1brenZ3T9ZrjVWn1k/tM0nKfiK94Xqtn5ouA0d2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCcfbLQGqK382bN7exk+9n3LjWHmtSY+FlXE9etpp728y2mNmwmR0ctWy6me0ys0+z+2mtbRNAs+r5p3WrpIt/buQhSbvd/QZJu7PnADpYzbC7+zuSvrxo8XJJ/dnjfkl3FdwXgII1+qFpprsPZY+/kDQz74Vm1mtmA2Y2UK1WG9wcgGY1fYbER642yL3iwN373L3i7pXu7u5mNwegQY2G/biZ9UhSdj9cXEsAWqHRsO+UtDp7vFrSq8W0A6BVao6zm9l2SYslzTCzo5J+JelpSb8zs/slfS7p3lY22emWLVuWrPf09CTrTzzxRLJe6zfIgXrUDLu7r8wp/aTgXgC0EF+XBYIg7EAQhB0IgrADQRB2IAgucS3Arl27kvVz584l6++++26yvm/fvmR96tSpyTogcWQHwiDsQBCEHQiCsANBEHYgCMIOBEHYgSAYZy/A4OBgsn7bbbcl
64cOHUrWb7zxxmT9ueeey63dfffdyXURB0d2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCcfYCXH/99cn6a6+9lqw/88wzyforr7ySrK9Zsya3dvDgwdyaJD366KPJ+vjx45N1XDo4sgNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIyzt0GlUknWt23blqxfcUX6P9PLL7+cW3v88ceT665YsSJZv/XWW5N1XDpqHtnNbIuZDZvZwVHLHjOzY2Z2ILulJygHULp63sZvlbRkjOWb3X1Odnu92LYAFK1m2N39HUlftqEXAC3UzAm6dWb2YfY2f1rei8ys18wGzGygWq02sTkAzWg07L+W9CNJcyQNSdqY90J373P3irtXuru7G9wcgGY1FHZ3P+7u59z9vKTfSJpfbFsAitZQ2M2sZ9TTFZLS11ECKF3NcXYz2y5psaQZZnZU0q8kLTazOZJc0mFJP29hj5e9iRMnJuv9/f3JeldXV27t+eefT67b29ubrL/99tvJ+pVXXpmso3PUDLu7rxxj8Ust6AVAC/F1WSAIwg4EQdiBIAg7EARhB4LgEtc6nTp1Krc2adKklm671tDcnXfemVt74YUXkuvu3bs3We/r60vWH3zwwWQdnYMjOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EwTh7na699trc2o4dO5LrLliwIFk3s4Z6umDp0qW5teXLlyfXrdX7kSNHGuoJnYcjOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EwTh7nU6ePJlbW7RoUXLdrVu3JuurVq1K1psZh681xl9rnP3FF19M1h955JFk/aqrrkrW0T4c2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMbZ67Rx48bc2rp165LrrlmzJlkfGhpK1hcvXpys33LLLbm1hQsXJtedPHlysp76foEkvf/++8n6HXfckayjfWoe2c3sGjP7o5l9YmYfm9n6bPl0M9tlZp9m99Na3y6ARtXzNv6spF+4+82S/l7SA2Z2s6SHJO129xsk7c6eA+hQNcPu7kPuvj97/JWkQUlXS1ouqT97Wb+ku1rVJIDmfa8TdGY2W9KPJe2RNNPdL3zY/ELSzJx1es1swMwGqtVqE60CaEbdYTezKZJekbTB3f86uubuLsnHWs/d+9y94u6V7u7uppoF0Li6wm5mEzQS9G3u/vts8XEz68nqPZKGW9MigCLUHHqzkesrX5I06O6bRpV2Slot6ens/tWWdNgh1q5dm1vbs2dPct3t27cn6w8//HBDPbXDuHHp48HcuXPb1AmaVc84+wJJqyR9ZGYHsmW/1EjIf2dm90v6XNK9rWkRQBFqht3d/yQp79cTflJsOwBaha/LAkEQdiAIwg4EQdiBIAg7EASXuBagv78/WX/qqaeS9XvuuSdZ379/f7J++vTp3NqECROS695+++3J+vr165P1rq6uZB2dgyM7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTBOHsbzJo1K1l/7733kvVTp04l64ODg7m16667LrnutGn8KHAUHNmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2S8BkyZNStbnzZvXpk5wKePIDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANB1Ay7mV1jZn80s0/M7GMzW58tf8zMjpnZgey2rPXtAmhUPV+qOSvpF+6+38ymStpnZruy2mZ3/5fWtQegKPXMzz4kaSh7/JWZDUq6utWNASjW9/rMbmazJf1Y0p5s0Toz+9DMtpjZmL9vZGa9ZjZgZgPVarWpZgE0ru6wm9kUSa9I2uDuf5X0a0k/kjRHI0f+jWOt5+597l5x90p3d3cBLQNoRF1hN7MJGgn6Nnf/vSS5+3F3P+fu5yX9RtL81rUJoFn1nI03SS9JGnT3TaOW94x62QpJB4tvD0BR6jkbv0DSKkkfmdmBbNkvJa00szmSXNJhST9vSYcAClHP2fg/SbIxSq8X3w6AVuEbdEAQhB0I
grADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSDM3du3MbOqpM9HLZoh6UTbGvh+OrW3Tu1LordGFdnbte4+5u+/tTXs39m42YC7V0prIKFTe+vUviR6a1S7euNtPBAEYQeCKDvsfSVvP6VTe+vUviR6a1Rbeiv1MzuA9in7yA6gTQg7EEQpYTezJWb2ZzM7ZGYPldFDHjM7bGYfZdNQD5TcyxYzGzazg6OWTTezXWb2aXY/5hx7JfXWEdN4J6YZL3XflT39eds/s5vZeEl/kXSHpKOS9kpa6e6ftLWRHGZ2WFLF3Uv/AoaZLZJ0UtJ/uPvfZcv+WdKX7v509g/lNHf/xw7p7TFJJ8uexjubrahn9DTjku6StEYl7rtEX/eqDfutjCP7fEmH3P0zdz8j6beSlpfQR8dz93ckfXnR4uWS+rPH/Rr5n6XtcnrrCO4+5O77s8dfSbowzXip+y7RV1uUEfarJR0Z9fyoOmu+d5f0BzPbZ2a9ZTczhpnuPpQ9/kLSzDKbGUPNabzb6aJpxjtm3zUy/XmzOEH3XQvdfa6kpZIeyN6udiQf+QzWSWOndU3j3S5jTDP+N2Xuu0anP29WGWE/JumaUc9/kC3rCO5+LLsflrRDnTcV9fELM+hm98Ml9/M3nTSN91jTjKsD9l2Z05+XEfa9km4wsx+a2URJP5O0s4Q+vsPMJmcnTmRmkyX9VJ03FfVOSauzx6slvVpiL9/SKdN4500zrpL3XenTn7t722+SlmnkjPz/SfqnMnrI6es6Sf+b3T4uuzdJ2zXytu4bjZzbuF9Sl6Tdkj6V9D+SpndQb/8p6SNJH2okWD0l9bZQI2/RP5R0ILstK3vfJfpqy37j67JAEJygA4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEg/h94ARDRbyNZCAAAAABJRU5ErkJggg==\n",
315 | "text/plain": [
316 | ""
317 | ]
318 | },
319 | "metadata": {
320 | "tags": [],
321 | "needs_background": "light"
322 | }
323 | }
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {
329 | "id": "rGeA7JLcvnjL",
330 | "colab_type": "text"
331 | },
332 | "source": [
333 | "### Reference\n",
334 | "- [PyTorch로 시작하는 딥러닝 입문](https://wikidocs.net/60324)"
335 | ]
336 | }
337 | ]
338 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Study Information
2 | **Interface Open AI Study 2020**
3 |
8 | Instructor: [백지오](https://github.com/skyil7)
9 | Lecture videos are available on YouTube!
10 |
11 | ## Study Plan
12 | |Week|Title|Material|Video|
13 | |:---:|:---:|:---:|:---:|
14 | |1|Dive into AI|[Slides](https://github.com/sejonginterface/Study_AI/blob/master/slides/1%EC%A3%BC%EC%B0%A8_1%EA%B0%95.pdf)|[YouTube](https://youtu.be/trpAbbBUm0M)|
15 | |1|Mathematics for AI|[Slides](https://github.com/sejonginterface/Study_AI/blob/master/slides/1%EC%A3%BC%EC%B0%A8_2%EA%B0%95.pdf)|[YouTube](https://youtu.be/BJ0GfyoFgZM)|
16 | |Lab|Experiencing AI with Teachable Machine|[Teachable Machine](https://teachablemachine.withgoogle.com/)||
17 | |2|Linear Regression and Convexity|[Slides](https://github.com/sejonginterface/Study_AI/blob/master/slides/2%EC%A3%BC%EC%B0%A8.pdf)|[YouTube](https://youtu.be/JZuVEoBB3XA)|
18 | |Lab|Linear Regression|[Colab](http://colab.research.google.com/), [Notebook](https://github.com/sejonginterface/Study_AI/blob/master/1.%EC%84%A0%ED%98%95%ED%9A%8C%EA%B7%80.ipynb)||
19 | |3|Linear Classification|[Slides](https://github.com/sejonginterface/Study_AI/blob/master/slides/3%EC%A3%BC%EC%B0%A8.pdf)|[YouTube](https://youtu.be/RvIf-POuZ4Y)|
20 | |Lab|Linear Classification|[Binary classification](https://github.com/sejonginterface/Study_AI/blob/master/2.%EC%9D%B4%EC%A7%84%20%EB%B6%84%EB%A5%98(%EB%A1%9C%EC%A7%80%EC%8A%A4%ED%8B%B1_%ED%9A%8C%EA%B7%80).ipynb), [Multi-class classification](https://github.com/sejonginterface/Study_AI/blob/master/3.%EB%8B%A4%EC%A4%91%20%EB%B6%84%EB%A5%98%20(Multi-Class%20Classification).ipynb)||
21 | |4|Perceptrons and Deep Neural Networks (DNN)|[Slides](https://github.com/sejonginterface/Study_AI/blob/master/slides/4%EC%A3%BC%EC%B0%A8.pdf)|[YouTube](https://youtu.be/tqqU2n8cCpk)|
22 | |Lab|Handwritten Digit Classification|[Notebook](https://github.com/sejonginterface/Study_AI/blob/master/4.%EC%8B%AC%EC%B8%B5%EC%8B%A0%EA%B2%BD%EB%A7%9D.ipynb)||
23 | |5|Applying Deep Neural Networks|[Slides](https://github.com/sejonginterface/Study_AI/blob/master/slides/5%EC%A3%BC%EC%B0%A8.pdf)|[YouTube](https://youtu.be/uQoPrtL7Zos)|
24 | |Lab|Predicting League of Legends Match Outcomes|[Notebook](https://github.com/sejonginterface/Study_AI/tree/master/%EB%A6%AC%EA%B7%B8%20%EC%98%A4%EB%B8%8C%20%EB%A0%88%EC%A0%84%EB%93%9C%20%EC%8A%B9%ED%8C%A8%20%EC%98%88%EC%B8%A1%ED%95%98%EA%B8%B0)||
25 |
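26 | ## Self-Study
27 | The resources below cover background knowledge needed to take part in the study, as well as more advanced topics that the study does not cover.
28 | Use them as needed!
29 | |Field|Name|Type|Description|
30 | |:---:|:---:|:---:|:---:|
31 | |Python|[점프투파이썬 (Jump to Python)](https://wikidocs.net/book/1)|e-book|A well-regarded Python book that is free to read, with solid coverage from the basics up!|
32 | |Python|[모두를위한파이썬 (Python for Everybody)](https://www.edwith.org/pythonforeverybody)|Lecture|A free online Python course as well!|
33 | |Deep Learning|[What is Neural Network?](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi)|YouTube|3Blue1Brown's deep learning series, which explains how neural networks learn and work through visualizations|
34 | |Math|[Derivatives, Numerical Differentiation, and Gradient Descent for Programmers](https://www.youtube.com/watch?v=LwhK9HBEVAM&list=PLNfg4W25Tapy5hIBmFZgT5coii1HUX6BD&index=9)|Lecture|Mathematics for deep learning, explained by 홍정모|
35 | |Deep Learning|[NYU Deep Learning](https://atcold.github.io/pytorch-Deep-Learning/ko/)|Lecture notes|Deep learning course materials from Professor Yann LeCun at New York University. Also available in Korean!|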
36 |
--------------------------------------------------------------------------------
/slides/1주차_1강.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sejonginterface/Study_AI/1398d769f0e8b52687d57a059aa2f1fbb0ceeb43/slides/1주차_1강.pdf
--------------------------------------------------------------------------------
/slides/1주차_2강.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sejonginterface/Study_AI/1398d769f0e8b52687d57a059aa2f1fbb0ceeb43/slides/1주차_2강.pdf
--------------------------------------------------------------------------------
/slides/2주차.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sejonginterface/Study_AI/1398d769f0e8b52687d57a059aa2f1fbb0ceeb43/slides/2주차.pdf
--------------------------------------------------------------------------------
/slides/3주차.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sejonginterface/Study_AI/1398d769f0e8b52687d57a059aa2f1fbb0ceeb43/slides/3주차.pdf
--------------------------------------------------------------------------------
/slides/4주차.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sejonginterface/Study_AI/1398d769f0e8b52687d57a059aa2f1fbb0ceeb43/slides/4주차.pdf
--------------------------------------------------------------------------------
/slides/5주차.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sejonginterface/Study_AI/1398d769f0e8b52687d57a059aa2f1fbb0ceeb43/slides/5주차.pdf
--------------------------------------------------------------------------------
/롤 승패 예측하기.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "롤.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "8M5BVUv5FVsC",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "# League of Legends Win/Lose Prediction\n",
26 | "Task: Binary Classification (Logistic Regression)"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "mK0IzuJC4k5D",
33 | "colab_type": "code",
34 | "colab": {
35 | "resources": {
36 | "http://localhost:8080/nbextensions/google.colab/files.js": {
37 | "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc3luYyBnZW
5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhcmVudEVsZW
1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkYXRhIHRvIG
JlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=",
38 | "ok": true,
39 | "headers": [
40 | [
41 | "content-type",
42 | "application/javascript"
43 | ]
44 | ],
45 | "status": 200,
46 | "status_text": ""
47 | }
48 | },
49 | "base_uri": "https://localhost:8080/",
50 | "height": 183
51 | },
52 | "outputId": "17c26229-1376-4caa-918e-12247ff5f3a2"
53 | },
54 | "source": [
55 | "from google.colab import files # upload files\n",
56 | "files.upload()"
57 | ],
58 | "execution_count": 1,
59 | "outputs": [
60 | {
61 | "output_type": "display_data",
62 | "data": {
63 | "text/html": [
64 | "\n",
65 | " \n",
66 | " \n",
70 | " "
71 | ],
72 | "text/plain": [
73 | ""
74 | ]
75 | },
76 | "metadata": {
77 | "tags": []
78 | }
79 | },
80 | {
81 | "output_type": "stream",
82 | "text": [
83 | "Saving y_test.csv to y_test.csv\n",
84 | "Saving x_test.csv to x_test.csv\n",
85 | "Saving y_train.csv to y_train.csv\n",
86 | "Saving x_train.csv to x_train.csv\n"
87 | ],
88 | "name": "stdout"
89 | }
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "metadata": {
95 | "id": "C9rkMtwEFyA9",
96 | "colab_type": "code",
97 | "colab": {}
98 | },
99 | "source": [
100 | "import torch\n",
101 | "import torch.optim as optim\n",
102 | "import torch.nn.functional as F\n",
103 | "import pandas as pd\n",
104 | "import numpy as np\n",
105 | "from sklearn.preprocessing import MinMaxScaler # For Normalization"
106 | ],
107 | "execution_count": 0,
108 | "outputs": []
109 | },
110 | {
111 | "cell_type": "code",
112 | "metadata": {
113 | "id": "QwmVe2U3F4W4",
114 | "colab_type": "code",
115 | "colab": {}
116 | },
117 | "source": [
118 | "GPU = torch.cuda.is_available() \n",
119 | "device = torch.device(\"cuda\" if GPU else \"cpu\") # use the GPU if available, otherwise fall back to the CPU"
120 | ],
121 | "execution_count": 0,
122 | "outputs": []
123 | },
124 | {
125 | "cell_type": "code",
126 | "metadata": {
127 | "id": "sLw8Xyp7GLa-",
128 | "colab_type": "code",
129 | "colab": {}
130 | },
131 | "source": [
132 | "torch.manual_seed(777)\n",
133 | "torch.cuda.manual_seed_all(777)"
134 | ],
135 | "execution_count": 0,
136 | "outputs": []
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {
141 | "id": "QhIv-gmZGPNA",
142 | "colab_type": "text"
143 | },
144 | "source": [
145 | "## Load Dataset and pre-process"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "metadata": {
151 | "id": "5nfj8WYzH8tG",
152 | "colab_type": "code",
153 | "colab": {}
154 | },
155 | "source": [
156 | "x_train = pd.read_csv('x_train.csv', index_col=0)\n",
157 | "y_train = pd.read_csv('y_train.csv', index_col=0)\n",
158 | "x_test = pd.read_csv('x_test.csv', index_col=0)\n",
159 | "y_test = pd.read_csv('y_test.csv', index_col=0)"
160 | ],
161 | "execution_count": 0,
162 | "outputs": []
163 | },
164 | {
165 | "cell_type": "code",
166 | "metadata": {
167 | "id": "GzRq_BG4IGeI",
168 | "colab_type": "code",
169 | "outputId": "36f2b8a3-548d-4007-875f-65d94ddf40d2",
170 | "colab": {
171 | "base_uri": "https://localhost:8080/",
172 | "height": 224
173 | }
174 | },
175 | "source": [
176 | "x_train.head()"
177 | ],
178 | "execution_count": 6,
179 | "outputs": [
180 | {
181 | "output_type": "execute_result",
182 | "data": {
183 | "text/html": [
184 | "<div>\n",
185 | "<table border=\"1\" class=\"dataframe\">\n",
186 | "  <thead>\n",
187 | "    <tr style=\"text-align: right;\"><th></th><th>gameDuraton</th><th>blueFirstBlood</th><th>blueFirstTower</th><th>blueFirstBaron</th><th>blueFirstDragon</th><th>blueFirstInhibitor</th><th>blueDragonKills</th><th>blueBaronKills</th><th>blueTowerKills</th><th>blueInhibitorKills</th><th>blueWardPlaced</th><th>blueWardkills</th><th>blueKills</th><th>blueDeath</th><th>blueAssist</th><th>blueChampionDamageDealt</th><th>blueTotalGold</th><th>blueTotalMinionKills</th><th>blueTotalLevel</th><th>blueAvgLevel</th><th>blueJungleMinionKills</th><th>blueKillingSpree</th><th>blueTotalHeal</th><th>blueObjectDamageDealt</th><th>redFirstBaron</th><th>redFirstDragon</th><th>redFirstInhibitor</th><th>redDragonKills</th><th>redBaronKills</th><th>redTowerKills</th><th>redInhibitorKills</th><th>redWardPlaced</th><th>redWardkills</th><th>redKills</th><th>redDeath</th><th>redAssist</th><th>redChampionDamageDealt</th><th>redTotalGold</th><th>redTotalMinionKills</th><th>redTotalLevel</th><th>redAvgLevel</th><th>redJungleMinionKills</th><th>redKillingSpree</th><th>redTotalHeal</th><th>redObjectDamageDealt</th><th>Rank_Challenger</th><th>Rank_GrandMaster</th><th>Rank_Master</th></tr>\n",
188 | "  </thead>\n",
189 | "  <tbody>\n",
190 | "    <tr><th>0</th><td>1713</td><td>0</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>4</td><td>1</td><td>0</td><td>0</td><td>52</td><td>60</td><td>135</td><td>181255</td><td>86971</td><td>367</td><td>90</td><td>18.0</td><td>0</td><td>11</td><td>56766</td><td>9732</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>4</td><td>1</td><td>0</td><td>0</td><td>60</td><td>52</td><td>155</td><td>236665</td><td>86392</td><td>351</td><td>90</td><td>18.0</td><td>0</td><td>16</td><td>75745</td><td>8025</td><td>0</td><td>1</td><td>0</td></tr>\n",
191 | "    <tr><th>1</th><td>1577</td><td>1</td><td>1</td><td>1</td><td>0</td><td>1</td><td>1</td><td>1</td><td>11</td><td>3</td><td>63</td><td>18</td><td>24</td><td>8</td><td>40</td><td>70900</td><td>54263</td><td>660</td><td>75</td><td>15.0</td><td>167</td><td>7</td><td>18005</td><td>69615</td><td>0</td><td>1</td><td>0</td><td>3</td><td>0</td><td>0</td><td>0</td><td>48</td><td>19</td><td>8</td><td>24</td><td>13</td><td>45032</td><td>41144</td><td>637</td><td>64</td><td>12.8</td><td>142</td><td>3</td><td>37366</td><td>38517</td><td>0</td><td>0</td><td>1</td></tr>\n",
192 | "    <tr><th>2</th><td>931</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>26</td><td>6</td><td>6</td><td>28</td><td>6</td><td>19738</td><td>23858</td><td>324</td><td>45</td><td>9.0</td><td>74</td><td>1</td><td>7488</td><td>3953</td><td>0</td><td>1</td><td>0</td><td>2</td><td>0</td><td>2</td><td>0</td><td>30</td><td>8</td><td>28</td><td>6</td><td>36</td><td>33310</td><td>35044</td><td>378</td><td>50</td><td>10.0</td><td>76</td><td>5</td><td>12242</td><td>25314</td><td>0</td><td>1</td><td>0</td></tr>\n",
193 | "    <tr><th>3</th><td>1388</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>4</td><td>1</td><td>0</td><td>0</td><td>58</td><td>45</td><td>114</td><td>165759</td><td>73791</td><td>304</td><td>90</td><td>18.0</td><td>0</td><td>15</td><td>36128</td><td>12432</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>2</td><td>1</td><td>0</td><td>0</td><td>45</td><td>58</td><td>113</td><td>131644</td><td>69276</td><td>265</td><td>90</td><td>18.0</td><td>0</td><td>10</td><td>25100</td><td>4895</td><td>0</td><td>1</td><td>0</td></tr>\n",
194 | "    <tr><th>4</th><td>1746</td><td>0</td><td>0</td><td>1</td><td>1</td><td>1</td><td>2</td><td>1</td><td>10</td><td>2</td><td>58</td><td>20</td><td>43</td><td>41</td><td>63</td><td>108907</td><td>67493</td><td>623</td><td>80</td><td>16.0</td><td>234</td><td>12</td><td>46089</td><td>85557</td><td>0</td><td>0</td><td>0</td><td>2</td><td>0</td><td>3</td><td>0</td><td>51</td><td>15</td><td>41</td><td>43</td><td>51</td><td>110696</td><td>58327</td><td>596</td><td>73</td><td>14.6</td><td>145</td><td>7</td><td>34174</td><td>24350</td><td>1</td><td>0</td><td>0</td></tr>\n",
195 | "  </tbody>\n",
196 | "</table>\n",
197 | "</div>"
511 | ],
512 | "text/plain": [
513 | "   gameDuraton  blueFirstBlood  ...  Rank_GrandMaster  Rank_Master\n",
514 | "0         1713               0  ...                 1            0\n",
515 | "1         1577               1  ...                 0            1\n",
516 | "2          931               0  ...                 1            0\n",
517 | "3         1388               1  ...                 1            0\n",
518 | "4         1746               0  ...                 0            0\n",
519 | "\n",
520 | "[5 rows x 48 columns]"
521 | ]
522 | },
523 | "metadata": {
524 | "tags": []
525 | },
526 | "execution_count": 6
527 | }
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "metadata": {
533 | "id": "xoiRVKXhII08",
534 | "colab_type": "code",
535 | "outputId": "74df418b-612c-4119-ee46-177619bade6a",
536 | "colab": {
537 | "base_uri": "https://localhost:8080/",
538 | "height": 204
539 | }
540 | },
541 | "source": [
542 | "y_train.head()"
543 | ],
544 | "execution_count": 7,
545 | "outputs": [
546 | {
547 | "output_type": "execute_result",
548 | "data": {
549 | "text/html": [
550 | "<div>\n",
564 | "<table>\n",
565 | "  <thead>\n",
566 | "    <tr><th></th><th>blueWins</th></tr>\n",
570 | "  </thead>\n",
571 | "  <tbody>\n",
572 | "    <tr><th>0</th><td>1</td></tr>\n",
576 | "    <tr><th>1</th><td>1</td></tr>\n",
580 | "    <tr><th>2</th><td>0</td></tr>\n",
584 | "    <tr><th>3</th><td>1</td></tr>\n",
588 | "    <tr><th>4</th><td>1</td></tr>\n",
592 | "  </tbody>\n",
593 | "</table>\n",
594 | "</div>"
595 | ],
596 | "text/plain": [
597 | "   blueWins\n",
598 | "0         1\n",
599 | "1         1\n",
600 | "2         0\n",
601 | "3         1\n",
602 | "4         1"
603 | ]
604 | },
605 | "metadata": {
606 | "tags": []
607 | },
608 | "execution_count": 7
609 | }
610 | ]
611 | },
612 | {
613 | "cell_type": "code",
614 | "metadata": {
615 | "id": "GcH2HmV4IfTT",
616 | "colab_type": "code",
617 | "colab": {}
618 | },
619 | "source": [
620 | "x_train = np.array(x_train)\n",
621 | "x_test = np.array(x_test)"
622 | ],
623 | "execution_count": 0,
624 | "outputs": []
625 | },
626 | {
627 | "cell_type": "code",
628 | "metadata": {
629 | "id": "IVzbGwbhIo77",
630 | "colab_type": "code",
631 | "outputId": "b2816cc8-4323-49e1-e1b6-c65b4afd0f00",
632 | "colab": {
633 | "base_uri": "https://localhost:8080/",
634 | "height": 317
635 | }
636 | },
637 | "source": [
638 | "scaler = MinMaxScaler()  # rescale each feature to the [0, 1] range\n",
639 | "x_train = scaler.fit_transform(x_train)\n",
640 | "x_test = scaler.transform(x_test)  # reuse the training-set min/max; do not refit on test data\n",
641 | "\n",
642 | "pd.DataFrame(x_train).describe()"
643 | ],
644 | "execution_count": 9,
645 | "outputs": [
646 | {
647 | "output_type": "execute_result",
648 | "data": {
649 | "text/html": [
650 | "\n",
651 | "\n",
664 | "
\n",
665 | " \n",
666 | " \n",
667 | " | \n",
668 | " 0 | \n",
669 | " 1 | \n",
670 | " 2 | \n",
671 | " 3 | \n",
672 | " 4 | \n",
673 | " 5 | \n",
674 | " 6 | \n",
675 | " 7 | \n",
676 | " 8 | \n",
677 | " 9 | \n",
678 | " 10 | \n",
679 | " 11 | \n",
680 | " 12 | \n",
681 | " 13 | \n",
682 | " 14 | \n",
683 | " 15 | \n",
684 | " 16 | \n",
685 | " 17 | \n",
686 | " 18 | \n",
687 | " 19 | \n",
688 | " 20 | \n",
689 | " 21 | \n",
690 | " 22 | \n",
691 | " 23 | \n",
692 | " 24 | \n",
693 | " 25 | \n",
694 | " 26 | \n",
695 | " 27 | \n",
696 | " 28 | \n",
697 | " 29 | \n",
698 | " 30 | \n",
699 | " 31 | \n",
700 | " 32 | \n",
701 | " 33 | \n",
702 | " 34 | \n",
703 | " 35 | \n",
704 | " 36 | \n",
705 | " 37 | \n",
706 | " 38 | \n",
707 | " 39 | \n",
708 | " 40 | \n",
709 | " 41 | \n",
710 | " 42 | \n",
711 | " 43 | \n",
712 | " 44 | \n",
713 | " 45 | \n",
714 | " 46 | \n",
715 | " 47 | \n",
716 | "
\n",
717 | " \n",
718 | " \n",
719 | " \n",
720 | " count | \n",
721 | " 180000.000000 | \n",
722 | " 180000.000000 | \n",
723 | " 180000.000000 | \n",
724 | " 180000.000000 | \n",
725 | " 180000.000000 | \n",
726 | " 180000.000000 | \n",
727 | " 180000.000000 | \n",
728 | " 180000.000000 | \n",
729 | " 180000.000000 | \n",
730 | " 180000.000000 | \n",
731 | " 180000.000000 | \n",
732 | " 180000.000000 | \n",
733 | " 180000.000000 | \n",
734 | " 180000.000000 | \n",
735 | " 180000.000000 | \n",
736 | " 180000.000000 | \n",
737 | " 180000.000000 | \n",
738 | " 180000.000000 | \n",
739 | " 180000.000000 | \n",
740 | " 180000.000000 | \n",
741 | " 180000.000000 | \n",
742 | " 180000.000000 | \n",
743 | " 180000.000000 | \n",
744 | " 180000.000000 | \n",
745 | " 180000.000000 | \n",
746 | " 180000.000000 | \n",
747 | " 180000.000000 | \n",
748 | " 180000.000000 | \n",
749 | " 180000.000000 | \n",
750 | " 180000.000000 | \n",
751 | " 180000.000000 | \n",
752 | " 180000.000000 | \n",
753 | " 180000.000000 | \n",
754 | " 180000.000000 | \n",
755 | " 180000.000000 | \n",
756 | " 180000.000000 | \n",
757 | " 180000.000000 | \n",
758 | " 180000.000000 | \n",
759 | " 180000.000000 | \n",
760 | " 180000.000000 | \n",
761 | " 180000.000000 | \n",
762 | " 180000.000000 | \n",
763 | " 180000.000000 | \n",
764 | " 180000.000000 | \n",
765 | " 180000.000000 | \n",
766 | " 180000.000000 | \n",
767 | " 180000.000000 | \n",
768 | " 180000.000000 | \n",
769 | "
\n",
770 | " \n",
771 | " mean | \n",
772 | " 0.389464 | \n",
773 | " 0.506100 | \n",
774 | " 0.515950 | \n",
775 | " 0.236872 | \n",
776 | " 0.403194 | \n",
777 | " 0.382561 | \n",
778 | " 0.193798 | \n",
779 | " 0.074492 | \n",
780 | " 0.394855 | \n",
781 | " 0.071870 | \n",
782 | " 0.232275 | \n",
783 | " 0.165640 | \n",
784 | " 0.220675 | \n",
785 | " 0.211136 | \n",
786 | " 0.157746 | \n",
787 | " 0.182146 | \n",
788 | " 0.340240 | \n",
789 | " 0.330281 | \n",
790 | " 0.438990 | \n",
791 | " 0.438990 | \n",
792 | " 0.308292 | \n",
793 | " 0.183562 | \n",
794 | " 0.093211 | \n",
795 | " 0.221428 | \n",
796 | " 0.259844 | \n",
797 | " 0.486356 | \n",
798 | " 0.370389 | \n",
799 | " 0.216998 | \n",
800 | " 0.065892 | \n",
801 | " 0.388419 | \n",
802 | " 0.063286 | \n",
803 | " 0.215931 | \n",
804 | " 0.164242 | \n",
805 | " 0.212202 | \n",
806 | " 0.221425 | \n",
807 | " 0.175684 | \n",
808 | " 0.192699 | \n",
809 | " 0.338799 | \n",
810 | " 0.350299 | \n",
811 | " 0.447585 | \n",
812 | " 0.447585 | \n",
813 | " 0.257617 | \n",
814 | " 0.195832 | \n",
815 | " 0.119842 | \n",
816 | " 0.187069 | \n",
817 | " 0.134772 | \n",
818 | " 0.329461 | \n",
819 | " 0.535767 | \n",
820 | "
\n",
821 | " \n",
822 | " std | \n",
823 | " 0.129449 | \n",
824 | " 0.499964 | \n",
825 | " 0.499747 | \n",
826 | " 0.425164 | \n",
827 | " 0.490541 | \n",
828 | " 0.486014 | \n",
829 | " 0.179185 | \n",
830 | " 0.133805 | \n",
831 | " 0.307004 | \n",
832 | " 0.101102 | \n",
833 | " 0.132004 | \n",
834 | " 0.124341 | \n",
835 | " 0.120861 | \n",
836 | " 0.115792 | \n",
837 | " 0.113872 | \n",
838 | " 0.098202 | \n",
839 | " 0.122716 | \n",
840 | " 0.117284 | \n",
841 | " 0.108986 | \n",
842 | " 0.108986 | \n",
843 | " 0.164263 | \n",
844 | " 0.117027 | \n",
845 | " 0.058162 | \n",
846 | " 0.158606 | \n",
847 | " 0.438550 | \n",
848 | " 0.499815 | \n",
849 | " 0.482910 | \n",
850 | " 0.186876 | \n",
851 | " 0.111787 | \n",
852 | " 0.308390 | \n",
853 | " 0.090163 | \n",
854 | " 0.123342 | \n",
855 | " 0.123443 | \n",
856 | " 0.116677 | \n",
857 | " 0.120998 | \n",
858 | " 0.126915 | \n",
859 | " 0.104284 | \n",
860 | " 0.114847 | \n",
861 | " 0.125647 | \n",
862 | " 0.103166 | \n",
863 | " 0.103166 | \n",
864 | " 0.137015 | \n",
865 | " 0.125373 | \n",
866 | " 0.074750 | \n",
867 | " 0.136529 | \n",
868 | " 0.341481 | \n",
869 | " 0.470019 | \n",
870 | " 0.498720 | \n",
871 | "
\n",
872 | " \n",
873 | " min | \n",
874 | " 0.000000 | \n",
875 | " 0.000000 | \n",
876 | " 0.000000 | \n",
877 | " 0.000000 | \n",
878 | " 0.000000 | \n",
879 | " 0.000000 | \n",
880 | " 0.000000 | \n",
881 | " 0.000000 | \n",
882 | " 0.000000 | \n",
883 | " 0.000000 | \n",
884 | " 0.000000 | \n",
885 | " 0.000000 | \n",
886 | " 0.000000 | \n",
887 | " 0.000000 | \n",
888 | " 0.000000 | \n",
889 | " 0.000000 | \n",
890 | " 0.000000 | \n",
891 | " 0.000000 | \n",
892 | " 0.000000 | \n",
893 | " 0.000000 | \n",
894 | " 0.000000 | \n",
895 | " 0.000000 | \n",
896 | " 0.000000 | \n",
897 | " 0.000000 | \n",
898 | " 0.000000 | \n",
899 | " 0.000000 | \n",
900 | " 0.000000 | \n",
901 | " 0.000000 | \n",
902 | " 0.000000 | \n",
903 | " 0.000000 | \n",
904 | " 0.000000 | \n",
905 | " 0.000000 | \n",
906 | " 0.000000 | \n",
907 | " 0.000000 | \n",
908 | " 0.000000 | \n",
909 | " 0.000000 | \n",
910 | " 0.000000 | \n",
911 | " 0.000000 | \n",
912 | " 0.000000 | \n",
913 | " 0.000000 | \n",
914 | " 0.000000 | \n",
915 | " 0.000000 | \n",
916 | " 0.000000 | \n",
917 | " 0.000000 | \n",
918 | " 0.000000 | \n",
919 | " 0.000000 | \n",
920 | " 0.000000 | \n",
921 | " 0.000000 | \n",
922 | "
\n",
923 | " \n",
924 | " 25% | \n",
925 | " 0.297216 | \n",
926 | " 0.000000 | \n",
927 | " 0.000000 | \n",
928 | " 0.000000 | \n",
929 | " 0.000000 | \n",
930 | " 0.000000 | \n",
931 | " 0.000000 | \n",
932 | " 0.000000 | \n",
933 | " 0.090909 | \n",
934 | " 0.000000 | \n",
935 | " 0.143478 | \n",
936 | " 0.067797 | \n",
937 | " 0.133929 | \n",
938 | " 0.128205 | \n",
939 | " 0.078125 | \n",
940 | " 0.108956 | \n",
941 | " 0.253113 | \n",
942 | " 0.252972 | \n",
943 | " 0.367647 | \n",
944 | " 0.367647 | \n",
945 | " 0.206468 | \n",
946 | " 0.096774 | \n",
947 | " 0.050926 | \n",
948 | " 0.081428 | \n",
949 | " 0.000000 | \n",
950 | " 0.000000 | \n",
951 | " 0.000000 | \n",
952 | " 0.000000 | \n",
953 | " 0.000000 | \n",
954 | " 0.090909 | \n",
955 | " 0.000000 | \n",
956 | " 0.133065 | \n",
957 | " 0.068376 | \n",
958 | " 0.129310 | \n",
959 | " 0.133929 | \n",
960 | " 0.087336 | \n",
961 | " 0.114955 | \n",
962 | " 0.257289 | \n",
963 | " 0.267498 | \n",
964 | " 0.379310 | \n",
965 | " 0.379310 | \n",
966 | " 0.172131 | \n",
967 | " 0.103448 | \n",
968 | " 0.065705 | \n",
969 | " 0.064619 | \n",
970 | " 0.000000 | \n",
971 | " 0.000000 | \n",
972 | " 0.000000 | \n",
973 | "
\n",
974 | " \n",
975 | " 50% | \n",
976 | " 0.385513 | \n",
977 | " 1.000000 | \n",
978 | " 1.000000 | \n",
979 | " 0.000000 | \n",
980 | " 0.000000 | \n",
981 | " 0.000000 | \n",
982 | " 0.142857 | \n",
983 | " 0.000000 | \n",
984 | " 0.363636 | \n",
985 | " 0.000000 | \n",
986 | " 0.230435 | \n",
987 | " 0.152542 | \n",
988 | " 0.214286 | \n",
989 | " 0.205128 | \n",
990 | " 0.136719 | \n",
991 | " 0.169018 | \n",
992 | " 0.340068 | \n",
993 | " 0.340159 | \n",
994 | " 0.448529 | \n",
995 | " 0.448529 | \n",
996 | " 0.313433 | \n",
997 | " 0.161290 | \n",
998 | " 0.081488 | \n",
999 | " 0.197553 | \n",
1000 | " 0.000000 | \n",
1001 | " 0.000000 | \n",
1002 | " 0.000000 | \n",
1003 | " 0.142857 | \n",
1004 | " 0.000000 | \n",
1005 | " 0.363636 | \n",
1006 | " 0.000000 | \n",
1007 | " 0.213710 | \n",
1008 | " 0.153846 | \n",
1009 | " 0.206897 | \n",
1010 | " 0.214286 | \n",
1011 | " 0.152838 | \n",
1012 | " 0.178874 | \n",
1013 | " 0.338862 | \n",
1014 | " 0.361746 | \n",
1015 | " 0.455172 | \n",
1016 | " 0.455172 | \n",
1017 | " 0.262295 | \n",
1018 | " 0.172414 | \n",
1019 | " 0.105084 | \n",
1020 | " 0.164466 | \n",
1021 | " 0.000000 | \n",
1022 | " 0.000000 | \n",
1023 | " 1.000000 | \n",
1024 | "
\n",
1025 | " \n",
1026 | " 75% | \n",
1027 | " 0.478899 | \n",
1028 | " 1.000000 | \n",
1029 | " 1.000000 | \n",
1030 | " 0.000000 | \n",
1031 | " 1.000000 | \n",
1032 | " 1.000000 | \n",
1033 | " 0.285714 | \n",
1034 | " 0.250000 | \n",
1035 | " 0.636364 | \n",
1036 | " 0.111111 | \n",
1037 | " 0.317391 | \n",
1038 | " 0.245763 | \n",
1039 | " 0.294643 | \n",
1040 | " 0.282051 | \n",
1041 | " 0.210938 | \n",
1042 | " 0.238266 | \n",
1043 | " 0.423363 | \n",
1044 | " 0.411493 | \n",
1045 | " 0.514706 | \n",
1046 | " 0.514706 | \n",
1047 | " 0.420398 | \n",
1048 | " 0.258065 | \n",
1049 | " 0.122641 | \n",
1050 | " 0.338782 | \n",
1051 | " 1.000000 | \n",
1052 | " 1.000000 | \n",
1053 | " 1.000000 | \n",
1054 | " 0.285714 | \n",
1055 | " 0.200000 | \n",
1056 | " 0.636364 | \n",
1057 | " 0.100000 | \n",
1058 | " 0.294355 | \n",
1059 | " 0.239316 | \n",
1060 | " 0.284483 | \n",
1061 | " 0.294643 | \n",
1062 | " 0.235808 | \n",
1063 | " 0.252328 | \n",
1064 | " 0.417088 | \n",
1065 | " 0.437283 | \n",
1066 | " 0.517241 | \n",
1067 | " 0.517241 | \n",
1068 | " 0.352459 | \n",
1069 | " 0.275862 | \n",
1070 | " 0.157672 | \n",
1071 | " 0.290507 | \n",
1072 | " 0.000000 | \n",
1073 | " 1.000000 | \n",
1074 | " 1.000000 | \n",
1075 | "
\n",
1076 | " \n",
1077 | " max | \n",
1078 | " 1.000000 | \n",
1079 | " 1.000000 | \n",
1080 | " 1.000000 | \n",
1081 | " 1.000000 | \n",
1082 | " 1.000000 | \n",
1083 | " 1.000000 | \n",
1084 | " 1.000000 | \n",
1085 | " 1.000000 | \n",
1086 | " 1.000000 | \n",
1087 | " 1.000000 | \n",
1088 | " 1.000000 | \n",
1089 | " 1.000000 | \n",
1090 | " 1.000000 | \n",
1091 | " 1.000000 | \n",
1092 | " 1.000000 | \n",
1093 | " 1.000000 | \n",
1094 | " 1.000000 | \n",
1095 | " 1.000000 | \n",
1096 | " 1.000000 | \n",
1097 | " 1.000000 | \n",
1098 | " 1.000000 | \n",
1099 | " 1.000000 | \n",
1100 | " 1.000000 | \n",
1101 | " 1.000000 | \n",
1102 | " 1.000000 | \n",
1103 | " 1.000000 | \n",
1104 | " 1.000000 | \n",
1105 | " 1.000000 | \n",
1106 | " 1.000000 | \n",
1107 | " 1.000000 | \n",
1108 | " 1.000000 | \n",
1109 | " 1.000000 | \n",
1110 | " 1.000000 | \n",
1111 | " 1.000000 | \n",
1112 | " 1.000000 | \n",
1113 | " 1.000000 | \n",
1114 | " 1.000000 | \n",
1115 | " 1.000000 | \n",
1116 | " 1.000000 | \n",
1117 | " 1.000000 | \n",
1118 | " 1.000000 | \n",
1119 | " 1.000000 | \n",
1120 | " 1.000000 | \n",
1121 | " 1.000000 | \n",
1122 | " 1.000000 | \n",
1123 | " 1.000000 | \n",
1124 | " 1.000000 | \n",
1125 | " 1.000000 | \n",
1126 | "
\n",
1127 | " \n",
1128 | "
\n",
1129 | "
"
1130 | ],
1131 | "text/plain": [
1132 | "                   0              1  ...             46             47\n",
1133 | "count  180000.000000  180000.000000  ...  180000.000000  180000.000000\n",
1134 | "mean        0.389464       0.506100  ...       0.329461       0.535767\n",
1135 | "std         0.129449       0.499964  ...       0.470019       0.498720\n",
1136 | "min         0.000000       0.000000  ...       0.000000       0.000000\n",
1137 | "25%         0.297216       0.000000  ...       0.000000       0.000000\n",
1138 | "50%         0.385513       1.000000  ...       0.000000       1.000000\n",
1139 | "75%         0.478899       1.000000  ...       1.000000       1.000000\n",
1140 | "max         1.000000       1.000000  ...       1.000000       1.000000\n",
1141 | "\n",
1142 | "[8 rows x 48 columns]"
1143 | ]
1144 | },
1145 | "metadata": {
1146 | "tags": []
1147 | },
1148 | "execution_count": 9
1149 | }
1150 | ]
1151 | },
1152 | {
1153 | "cell_type": "code",
1154 | "metadata": {
1155 | "id": "PCUWNw9AJpu4",
1156 | "colab_type": "code",
1157 | "colab": {}
1158 | },
1159 | "source": [
1160 | "y_train = np.array(y_train)\n",
1161 | "y_test = np.array(y_test)"
1162 | ],
1163 | "execution_count": 0,
1164 | "outputs": []
1165 | },
1166 | {
1167 | "cell_type": "code",
1168 | "metadata": {
1169 | "id": "AyhqyNgIJVPt",
1170 | "colab_type": "code",
1171 | "colab": {}
1172 | },
1173 | "source": [
1174 | "x_train = torch.FloatTensor(x_train)\n",
1175 | "y_train = torch.FloatTensor(y_train)\n",
1176 | "x_test = torch.FloatTensor(x_test)\n",
1177 | "y_test = torch.FloatTensor(y_test)"
1178 | ],
1179 | "execution_count": 0,
1180 | "outputs": []
1181 | },
1182 | {
1183 | "cell_type": "markdown",
1184 | "metadata": {
1185 | "id": "KWsozbXxJ02R",
1186 | "colab_type": "text"
1187 | },
1188 | "source": [
1189 | "## Data Loader"
1190 | ]
1191 | },
1192 | {
1193 | "cell_type": "code",
1194 | "metadata": {
1195 | "id": "LYhXTHYAJ2dk",
1196 | "colab_type": "code",
1197 | "colab": {}
1198 | },
1199 | "source": [
1200 | "from torch.utils.data import TensorDataset\n",
1201 | "from torch.utils.data import DataLoader"
1202 | ],
1203 | "execution_count": 0,
1204 | "outputs": []
1205 | },
1206 | {
1207 | "cell_type": "code",
1208 | "metadata": {
1209 | "id": "-NkuyDOuKEYW",
1210 | "colab_type": "code",
1211 | "colab": {}
1212 | },
1213 | "source": [
1214 | "train_set = TensorDataset(x_train, y_train)"
1215 | ],
1216 | "execution_count": 0,
1217 | "outputs": []
1218 | },
1219 | {
1220 | "cell_type": "code",
1221 | "metadata": {
1222 | "id": "KCOC13euKOj-",
1223 | "colab_type": "code",
1224 | "colab": {}
1225 | },
1226 | "source": [
1227 | "data_loader = DataLoader(dataset=train_set,\n",
1228 | "                         batch_size=20000,\n",
1229 | "                         shuffle=True)"
1230 | ],
1231 | "execution_count": 0,
1232 | "outputs": []
1233 | },
1234 | {
1235 | "cell_type": "markdown",
1236 | "metadata": {
1237 | "id": "8634QyKgJw2j",
1238 | "colab_type": "text"
1239 | },
1240 | "source": [
1241 | "## Make Model"
1242 | ]
1243 | },
1244 | {
1245 | "cell_type": "code",
1246 | "metadata": {
1247 | "id": "5-t35eIgKlQO",
1248 | "colab_type": "code",
1249 | "outputId": "bb67af03-04a1-40f3-855b-fe710c2fbeae",
1250 | "colab": {
1251 | "base_uri": "https://localhost:8080/",
1252 | "height": 53
1253 | }
1254 | },
1255 | "source": [
1256 | "print(x_train.shape)\n",
1257 | "print(y_train.shape)"
1258 | ],
1259 | "execution_count": 15,
1260 | "outputs": [
1261 | {
1262 | "output_type": "stream",
1263 | "text": [
1264 | "torch.Size([180000, 48])\n",
1265 | "torch.Size([180000, 1])\n"
1266 | ],
1267 | "name": "stdout"
1268 | }
1269 | ]
1270 | },
1271 | {
1272 | "cell_type": "code",
1273 | "metadata": {
1274 | "id": "tcxpDRLMJkvo",
1275 | "colab_type": "code",
1276 | "outputId": "0b413c66-5d2b-43af-ad6c-8362fdabab00",
1277 | "colab": {
1278 | "base_uri": "https://localhost:8080/",
1279 | "height": 125
1280 | }
1281 | },
1282 | "source": [
1283 | "l1 = torch.nn.Linear(48, 16).to(device)\n",
1284 | "l2 = torch.nn.Linear(16, 1).to(device)\n",
1285 | "relu = torch.nn.ReLU()\n",
1286 | "sigmoid = torch.nn.Sigmoid()\n",
1287 | "\n",
1288 | "model = torch.nn.Sequential(l1, relu, l2, sigmoid)\n",
1289 | "model"
1290 | ],
1291 | "execution_count": 16,
1292 | "outputs": [
1293 | {
1294 | "output_type": "execute_result",
1295 | "data": {
1296 | "text/plain": [
1297 | "Sequential(\n",
1298 | " (0): Linear(in_features=48, out_features=16, bias=True)\n",
1299 | " (1): ReLU()\n",
1300 | " (2): Linear(in_features=16, out_features=1, bias=True)\n",
1301 | " (3): Sigmoid()\n",
1302 | ")"
1303 | ]
1304 | },
1305 | "metadata": {
1306 | "tags": []
1307 | },
1308 | "execution_count": 16
1309 | }
1310 | ]
1311 | },
1312 | {
1313 | "cell_type": "markdown",
1314 | "metadata": {
1315 | "id": "etRx0ou_LH6b",
1316 | "colab_type": "text"
1317 | },
1318 | "source": [
1319 | "## Train Model"
1320 | ]
1321 | },
1322 | {
1323 | "cell_type": "code",
1324 | "metadata": {
1325 | "id": "AhBM4EGOLDxy",
1326 | "colab_type": "code",
1327 | "colab": {}
1328 | },
1329 | "source": [
1330 | "cost = torch.nn.BCELoss().to(device)\n",
1331 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)"
1332 | ],
1333 | "execution_count": 0,
1334 | "outputs": []
1335 | },
1336 | {
1337 | "cell_type": "code",
1338 | "metadata": {
1339 | "id": "l3Vd9cVCLVKF",
1340 | "colab_type": "code",
1341 | "outputId": "ebe95e6f-22a8-4d75-8991-0629e89d033e",
1342 | "colab": {
1343 | "base_uri": "https://localhost:8080/",
1344 | "height": 143
1345 | }
1346 | },
1347 | "source": [
1348 | "epochs = 60\n",
1349 | "for epoch in range(1, epochs+1):\n",
1350 | "    avg_cost = 0.0\n",
1351 | "    total_batch = len(data_loader)\n",
1352 | "\n",
1353 | "    for x, y in data_loader: # batch loop\n",
1354 | "        x = x.to(device)\n",
1355 | "        y = y.to(device)\n",
1356 | "\n",
1357 | "        optimizer.zero_grad()\n",
1358 | "        hypothesis = model(x)\n",
1359 | "        cost_val = cost(hypothesis, y)\n",
1360 | "        cost_val.backward()\n",
1361 | "        optimizer.step()\n",
1362 | "\n",
1363 | "        avg_cost += cost_val.item()  # accumulate a Python float, not a graph-attached tensor\n",
1364 | "\n",
1365 | "    avg_cost /= total_batch\n",
1366 | "\n",
1367 | "    if epoch % 10 == 1 or epoch == epochs:\n",
1368 | "        print('Epoch {:4d}/{} Cost: {:.6f}'.format(epoch, epochs, avg_cost))"
1369 | ],
1370 | "execution_count": 18,
1371 | "outputs": [
1372 | {
1373 | "output_type": "stream",
1374 | "text": [
1375 | "Epoch    1/60 Cost: 0.679092\n",
1376 | "Epoch   11/60 Cost: 0.278091\n",
1377 | "Epoch   21/60 Cost: 0.212358\n",
1378 | "Epoch   31/60 Cost: 0.181309\n",
1379 | "Epoch   41/60 Cost: 0.155241\n",
1380 | "Epoch   51/60 Cost: 0.133555\n",
1381 | "Epoch   60/60 Cost: 0.118290\n"
1382 | ],
1383 | "name": "stdout"
1384 | }
1385 | ]
1386 | },
1387 | {
1388 | "cell_type": "markdown",
1389 | "metadata": {
1390 | "id": "53fyQhMsLtCb",
1391 | "colab_type": "text"
1392 | },
1393 | "source": [
1394 | "## Evaluate Model"
1395 | ]
1396 | },
1397 | {
1398 | "cell_type": "code",
1399 | "metadata": {
1400 | "id": "VkSMq6X1Lokh",
1401 | "colab_type": "code",
1402 | "colab": {}
1403 | },
1404 | "source": [
1405 | "with torch.no_grad():  # disable gradient tracking during inference\n",
1406 | "    x_test = x_test.to(device)\n",
1407 | "\n",
1408 | "    pred = model(x_test)"
1409 | ],
1410 | "execution_count": 0,
1411 | "outputs": []
1412 | },
1413 | {
1414 | "cell_type": "code",
1415 | "metadata": {
1416 | "id": "AP0IZnTdL4Ad",
1417 | "colab_type": "code",
1418 | "colab": {
1419 | "base_uri": "https://localhost:8080/",
1420 | "height": 35
1421 | },
1422 | "outputId": "f11157ba-7d03-42ed-ba37-74cb8f9b00d7"
1423 | },
1424 | "source": [
1425 | "# Threshold the predicted probabilities at 0.5 to get hard 0/1 labels\n",
1426 | "pred = (pred >= 0.5).float()\n",
1427 | "correct_prediction = pred.detach().cpu().float() == y_test\n",
1428 | "accuracy = correct_prediction.sum().item() / len(correct_prediction)\n",
1429 | "print('Accuracy {:2.2f}%'.format(accuracy * 100))"
1430 | ],
1431 | "execution_count": 22,
1432 | "outputs": [
1433 | {
1434 | "output_type": "stream",
1435 | "text": [
1436 | "Accuracy 96.12%\n"
1437 | ],
1438 | "name": "stdout"
1439 | }
1440 | ]
1441 | }
1442 | ]
1443 | }
--------------------------------------------------------------------------------
/리그 오브 레전드 승패 예측하기/README.md:
--------------------------------------------------------------------------------
1 | # Predicting League of Legends Match Outcomes
2 | - Data source: https://www.kaggle.com/gyejr95/league-of-legends-challenger-ranked-games2020
3 | - Data preprocessing: https://www.kaggle.com/skyil7/data-processing-for-regression-tasks
4 |
5 | The task is to predict the winning team from League of Legends game data.
6 |
7 | The dataset consists of 180,000 training samples and 19,925 test samples.
8 |
9 | Try varying the model size, learning rate, number of epochs, and so on to design the model with the highest accuracy!
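
One way to organize that search is a small grid over hidden-layer size, learning rate, and epoch count. The sketch below is only an illustration: `grid_search`, `train_and_evaluate`, and `dummy_train_and_evaluate` are hypothetical names that do not appear in the notebook, and the candidate values are arbitrary apart from 16 hidden units, a learning rate of 0.1, and 60 epochs, which are the notebook's settings. The dummy scorer returns made-up numbers so the example runs on its own. In practice `train_and_evaluate` would train the notebook's model with the given settings and return its accuracy.

```python
from itertools import product
from typing import Callable, Iterable, Tuple


def grid_search(
    train_and_evaluate: Callable[[int, float, int], float],
    hidden_sizes: Iterable[int],
    learning_rates: Iterable[float],
    epoch_counts: Iterable[int],
) -> Tuple[dict, float]:
    """Try every combination and return the best config with its accuracy."""
    best_config, best_accuracy = None, float("-inf")
    for hidden, lr, epochs in product(hidden_sizes, learning_rates, epoch_counts):
        accuracy = train_and_evaluate(hidden, lr, epochs)
        if accuracy > best_accuracy:
            best_config = {"hidden_size": hidden, "lr": lr, "epochs": epochs}
            best_accuracy = accuracy
    if best_config is None:
        raise ValueError("each hyperparameter needs at least one candidate value")
    return best_config, best_accuracy


def dummy_train_and_evaluate(hidden_size: int, lr: float, epochs: int) -> float:
    # Hypothetical stand-in: returns a made-up score, not a real accuracy.
    # Replace the body with real training and evaluation code.
    return 1.0 - abs(hidden_size - 16) / 100 - abs(lr - 0.1) - abs(epochs - 60) / 1000


if __name__ == "__main__":
    config, accuracy = grid_search(
        dummy_train_and_evaluate,
        hidden_sizes=[8, 16, 32],
        learning_rates=[0.01, 0.1],
        epoch_counts=[30, 60],
    )
    print(config, accuracy)
```

If the score comes from the test set, the chosen configuration is tuned to that set. Holding out part of the training data for this comparison keeps the final test accuracy an honest estimate.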
--------------------------------------------------------------------------------