├── README.md
├── code
    ├── 1-1.Intro_to_NDArray.ipynb
    ├── 1-2.Basic_Deeplearning_Model.ipynb
    ├── 2.Word_embedding.ipynb
    ├── 3_1_intent_classification_pycon2019.ipynb
    ├── 3_2_entity_tagging_pycon2019.ipynb
    └── 3_3_naver_review_classifications_gluon_bert.ipynb
└── slide
    ├── 1.MXNet_Basic.pdf
    ├── 2.Word_Embedding.pdf
    └── 3_bert.pdf


/README.md:
--------------------------------------------------------------------------------
1 | # PyCon 2019 Tutorial
2 | ## Hands-On Deep Learning NLP, Made Easy
3 | ---
4 | 
5 | 1. MXNet Basic
6 |     - [Slide](slide/1.MXNet_Basic.pdf)
7 |     - Code
8 |         - Intro to NDArray [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/seujung/gluonnlp_tutorial/blob/master/code/1-1.Intro_to_NDArray.ipynb)
9 |         - Basic Deep Learning Model [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/seujung/gluonnlp_tutorial/blob/master/code/1-2.Basic_Deeplearning_Model.ipynb)
10 | 
11 | 
12 | 2. Embedding
13 |     - [Slide](slide/2.Word_Embedding.pdf)
14 |     - Code
15 |         - Word Embedding [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/seujung/gluonnlp_tutorial/blob/master/code/2.Word_embedding.ipynb)
16 | 
17 | 3. RNN, Attention, BERT with GluonNLP
18 |     - [Slide](slide/3_bert.pdf)
19 |     - Code
20 |         - Intent Classification [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/seujung/gluonnlp_tutorial/blob/master/code/3_1_intent_classification_pycon2019.ipynb)
21 |         - Entity Tagging [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/seujung/gluonnlp_tutorial/blob/master/code/3_2_entity_tagging_pycon2019.ipynb)
22 |         - NSMC with BERT [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/seujung/gluonnlp_tutorial/blob/master/code/3_3_naver_review_classifications_gluon_bert.ipynb)
23 | 
24 | ---
25 | 
26 | ## [Reference]
27 | - GluonNLP Tutorial (KDD 2018): https://kdd18.mxnet.io/
28 | - JSALT 2019 NLP Tutorial: https://jsalt19.mxnet.io/
29 | - Dive into Deep Learning (NLP): https://www.d2l.ai/chapter_natural-language-processing/index.html
30 | - GluonNLP Official Tutorial: https://gluon-nlp.mxnet.io/examples/index.html
31 | 
--------------------------------------------------------------------------------
/code/1-1.Intro_to_NDArray.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"1.Intro_to_NDArray.ipynb","version":"0.3.2","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"bWMP9WRr2k1w","colab_type":"text"},"source":["Install packages (mxnet, gluonnlp)"]},{"cell_type":"code","metadata":{"id":"VfQFZKLe2tIQ","colab_type":"code","outputId":"cd00746f-1a3b-419f-c29d-546d6be9b87b","executionInfo":{"status":"ok","timestamp":1565312276921,"user_tz":-540,"elapsed":50264,"user":{"displayName":"seung hwan
Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":496}},"source":["!pip install mxnet-cu100\n","!pip install gluonnlp"],"execution_count":1,"outputs":[{"output_type":"stream","text":["Collecting mxnet-cu100\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/56/d3/e939814957c2f09ecdd22daa166898889d54e5981e356832425d514edfb6/mxnet_cu100-1.5.0-py2.py3-none-manylinux1_x86_64.whl (540.1MB)\n","\u001b[K |████████████████████████████████| 540.1MB 31kB/s \n","\u001b[?25hRequirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n","Collecting graphviz<0.9.0,>=0.8.1 (from mxnet-cu100)\n"," Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl\n","Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.16.4)\n","Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n","Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n","Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.6.16)\n","Installing collected packages: graphviz, mxnet-cu100\n"," Found existing installation: graphviz 0.10.1\n"," Uninstalling graphviz-0.10.1:\n"," Successfully uninstalled graphviz-0.10.1\n","Successfully installed graphviz-0.8.4 mxnet-cu100-1.5.0\n","Collecting gluonnlp\n","\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/c1/c8/e180cd98ab190e7ac3c6a767a909918e719be33f967bca13d0d4cd7c5468/gluonnlp-0.8.0.tar.gz (235kB)\n","\u001b[K |████████████████████████████████| 245kB 18.1MB/s \n","\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluonnlp) (1.16.4)\n","Building wheels for collected packages: gluonnlp\n"," Building wheel for gluonnlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for gluonnlp: filename=gluonnlp-0.8.0-cp36-none-any.whl size=292704 sha256=6d4ef2faeca258d8b30c2330c639a82e7635bb2d8e07c83b888105b3d58fe854\n"," Stored in directory: /root/.cache/pip/wheels/28/ff/33/d73801f242fb93c02f2076f81232fcb9a29305480cc42c5454\n","Successfully built gluonnlp\n","Installing collected packages: gluonnlp\n","Successfully installed gluonnlp-0.8.0\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"JNXFGV6p3BZW","colab_type":"code","colab":{}},"source":["import mxnet as mx\n","from mxnet import nd, autograd\n","import gluonnlp as nlp"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"liL3eFl3aEoH","colab_type":"code","outputId":"8079745f-3f50-45db-a28c-9a8e0908f875","executionInfo":{"status":"ok","timestamp":1565312295362,"user_tz":-540,"elapsed":1628,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["mx.__version__"],"execution_count":3,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'1.5.0'"]},"metadata":{"tags":[]},"execution_count":3}]},{"cell_type":"code","metadata":{"id":"yvFmSwij4iQk","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"outputId":"8e2d3f93-1679-4792-c4c3-c5159d1235d2","executionInfo":{"status":"ok","timestamp":1565312305438,"user_tz":-540,"elapsed":1627,"user":{"displayName":"seung hwan 
Jung","photoUrl":"","userId":"16375924554386741873"}}},"source":["nlp.__version__"],"execution_count":4,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'0.8.0'"]},"metadata":{"tags":[]},"execution_count":4}]},{"cell_type":"markdown","metadata":{"id":"5tOt7PfzZyxf","colab_type":"text"},"source":["Create data"]},{"cell_type":"code","metadata":{"id":"hDL8ufHx3HcU","colab_type":"code","outputId":"aba55888-bbeb-48ba-eb81-7a1c79404e07","executionInfo":{"status":"ok","timestamp":1565312312914,"user_tz":-540,"elapsed":1647,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":85}},"source":["x1 = nd.random.normal(shape=(1,10))\n","print(x1)"],"execution_count":5,"outputs":[{"output_type":"stream","text":["\n","[[ 2.2122064 0.7740038 1.0434403 1.1839255 1.8917114 -1.2347414\n"," -1.771029 -0.45138445 0.57938355 -1.856082 ]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"6Lo8bTPx3fLP","colab_type":"code","outputId":"b3751d31-d1fc-4ba3-a945-8a9bb7ca6c56","executionInfo":{"status":"ok","timestamp":1565312317014,"user_tz":-540,"elapsed":1632,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":85}},"source":["## reshape\n","x2 = nd.arange(10).reshape(2,5)\n","print(x2)"],"execution_count":6,"outputs":[{"output_type":"stream","text":["\n","[[0. 1. 2. 3. 4.]\n"," [5. 6. 7. 8.
9.]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"PAmV0-0a4owc","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":136},"outputId":"bb01eefa-4f52-40f8-e7af-4f65efa8373b","executionInfo":{"status":"ok","timestamp":1565312417332,"user_tz":-540,"elapsed":1611,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}}},"source":["## transpose (swap axes 0 and 1)\n","x2_1 = x2.swapaxes(0,1)\n","print(x2_1)"],"execution_count":8,"outputs":[{"output_type":"stream","text":["\n","[[0. 5.]\n"," [1. 6.]\n"," [2. 7.]\n"," [3. 8.]\n"," [4. 9.]]\n","\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"xPAjQ6CzaJxS","colab_type":"text"},"source":["GPU / CPU allocation"]},{"cell_type":"code","metadata":{"id":"6xol7q2J3iPt","colab_type":"code","outputId":"5507f664-923e-4aec-928f-ac72507bc172","executionInfo":{"status":"ok","timestamp":1565312425780,"user_tz":-540,"elapsed":6273,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":119}},"source":["## allocate on the GPU\n","x3 = nd.random.normal(shape=(2,10), ctx=mx.gpu())\n","print(x3)"],"execution_count":9,"outputs":[{"output_type":"stream","text":["\n","[[-1.3204551 0.68232244 -0.9858383 0.0199282 0.7842404 0.50066984\n"," -1.0283493 0.98445714 0.23791966 0.5675242 ]\n"," [ 0.416008 1.2724396 0.90007704 0.43224135 -0.04885176 0.00976894\n"," -0.6189104 -0.59411055 -0.7401756 -0.29974517]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"F3SKfOKm3rRE","colab_type":"code","outputId":"ac096fc7-0511-4396-f8ac-78f39c96c5bd","executionInfo":{"status":"ok","timestamp":1565313199497,"user_tz":-540,"elapsed":1650,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":119}},"source":["## copy from GPU to CPU\n","x4 = 
x3.as_in_context(mx.cpu())\n","print(x4)"],"execution_count":10,"outputs":[{"output_type":"stream","text":["\n","[[-1.3204551 0.68232244 -0.9858383 0.0199282 0.7842404 0.50066984\n"," -1.0283493 0.98445714 0.23791966 0.5675242 ]\n"," [ 0.416008 1.2724396 0.90007704 0.43224135 -0.04885176 0.00976894\n"," -0.6189104 -0.59411055 -0.7401756 -0.29974517]]\n","\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"QUnQ5NQnkP1s","colab_type":"text"},"source":["Operations"]},{"cell_type":"code","metadata":{"id":"TXXbkxMA364t","colab_type":"code","outputId":"90ddf6b3-5266-4841-a6f5-5949ebaa45ec","executionInfo":{"status":"ok","timestamp":1565313204962,"user_tz":-540,"elapsed":913,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":85}},"source":["## slicing: keep the first 3 columns\n","x5 = x4[:,:3]\n","print(x5)"],"execution_count":11,"outputs":[{"output_type":"stream","text":["\n","[[-1.3204551 0.68232244 -0.9858383 ]\n"," [ 0.416008 1.2724396 0.90007704]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"BR9lFedW4FJ2","colab_type":"code","colab":{}},"source":["## matrix operations\n","### dot product: (2, 5) x (5, 2) -> (2, 2)\n","a = nd.arange(10).reshape(2,5)\n","b = nd.arange(10).reshape(5,2)\n","c = nd.dot(a, b)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"qwvSAAV64S2P","colab_type":"code","outputId":"456b8036-7848-42b2-dbdf-3fc01da46d25","executionInfo":{"status":"ok","timestamp":1565313221066,"user_tz":-540,"elapsed":1664,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":85}},"source":["print(a)"],"execution_count":13,"outputs":[{"output_type":"stream","text":["\n","[[0. 1. 2. 3. 4.]\n"," [5. 6. 7. 8.
9.]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"Jeu28-tH4TuF","colab_type":"code","outputId":"d7f0fd76-04fa-49be-a104-1c0f837332cb","executionInfo":{"status":"ok","timestamp":1565313223122,"user_tz":-540,"elapsed":1359,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":136}},"source":["print(b)"],"execution_count":14,"outputs":[{"output_type":"stream","text":["\n","[[0. 1.]\n"," [2. 3.]\n"," [4. 5.]\n"," [6. 7.]\n"," [8. 9.]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"yS_pRdMK4VIN","colab_type":"code","outputId":"c3816c29-3aee-4d1b-ec11-7b790c1f9e77","executionInfo":{"status":"ok","timestamp":1565313225020,"user_tz":-540,"elapsed":855,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":85}},"source":["print(c)"],"execution_count":15,"outputs":[{"output_type":"stream","text":["\n","[[ 60. 70.]\n"," [160. 195.]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"OwaWtZyP4W4V","colab_type":"code","colab":{}},"source":["##Batch matrix multiplication\n","a = nd.arange(30).reshape(5,2,3)\n","b = nd.arange(60).reshape(5,3,4)\n","c = nd.batch_dot(a, b)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"RfoQs56M4tVm","colab_type":"code","outputId":"0aba3586-ebd0-4586-af55-23f728e76a55","executionInfo":{"status":"ok","timestamp":1565313235306,"user_tz":-540,"elapsed":927,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":289}},"source":["print(a)"],"execution_count":18,"outputs":[{"output_type":"stream","text":["\n","[[[ 0. 1. 2.]\n"," [ 3. 4. 5.]]\n","\n"," [[ 6. 7. 8.]\n"," [ 9. 10. 11.]]\n","\n"," [[12. 13. 14.]\n"," [15. 16. 17.]]\n","\n"," [[18. 19. 20.]\n"," [21. 22. 23.]]\n","\n"," [[24. 25. 26.]\n"," [27. 28. 
29.]]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"AK8iq_Ut4wjF","colab_type":"code","outputId":"50a94ff4-ed2d-4609-fe78-0ef467917e9c","executionInfo":{"status":"ok","timestamp":1565313237757,"user_tz":-540,"elapsed":1441,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":374}},"source":["print(b)"],"execution_count":19,"outputs":[{"output_type":"stream","text":["\n","[[[ 0. 1. 2. 3.]\n"," [ 4. 5. 6. 7.]\n"," [ 8. 9. 10. 11.]]\n","\n"," [[12. 13. 14. 15.]\n"," [16. 17. 18. 19.]\n"," [20. 21. 22. 23.]]\n","\n"," [[24. 25. 26. 27.]\n"," [28. 29. 30. 31.]\n"," [32. 33. 34. 35.]]\n","\n"," [[36. 37. 38. 39.]\n"," [40. 41. 42. 43.]\n"," [44. 45. 46. 47.]]\n","\n"," [[48. 49. 50. 51.]\n"," [52. 53. 54. 55.]\n"," [56. 57. 58. 59.]]]\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"UXTPmZar4xTu","colab_type":"code","outputId":"c0255bf6-d50d-4a74-83fe-f8570a5ea85d","executionInfo":{"status":"ok","timestamp":1565313239057,"user_tz":-540,"elapsed":752,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":289}},"source":["print(c)"],"execution_count":20,"outputs":[{"output_type":"stream","text":["\n","[[[ 20. 23. 26. 29.]\n"," [ 56. 68. 80. 92.]]\n","\n"," [[ 344. 365. 386. 407.]\n"," [ 488. 518. 548. 578.]]\n","\n"," [[1100. 1139. 1178. 1217.]\n"," [1352. 1400. 1448. 1496.]]\n","\n"," [[2288. 2345. 2402. 2459.]\n"," [2648. 2714. 2780. 2846.]]\n","\n"," [[3908. 3983. 4058. 4133.]\n"," [4376. 4460. 4544. 
4628.]]]\n","\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"d-JtXmqnlFr8","colab_type":"text"},"source":["Convert to NumPy"]},{"cell_type":"code","metadata":{"id":"77U86HiB4yOG","colab_type":"code","colab":{}},"source":["d = c.asnumpy()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"OYvJ04uAlNMP","colab_type":"code","outputId":"5cf47e00-7d77-4d19-8ec3-b59208ed06a8","executionInfo":{"status":"ok","timestamp":1565313251127,"user_tz":-540,"elapsed":1264,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":255}},"source":["d"],"execution_count":22,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([[[ 20., 23., 26., 29.],\n"," [ 56., 68., 80., 92.]],\n","\n"," [[ 344., 365., 386., 407.],\n"," [ 488., 518., 548., 578.]],\n","\n"," [[1100., 1139., 1178., 1217.],\n"," [1352., 1400., 1448., 1496.]],\n","\n"," [[2288., 2345., 2402., 2459.],\n"," [2648., 2714., 2780., 2846.]],\n","\n"," [[3908., 3983., 4058., 4133.],\n"," [4376., 4460., 4544., 4628.]]], dtype=float32)"]},"metadata":{"tags":[]},"execution_count":22}]},{"cell_type":"code","metadata":{"id":"7tLMa2XGlNiO","colab_type":"code","colab":{}},"source":[""],"execution_count":0,"outputs":[]}]}
--------------------------------------------------------------------------------
/code/1-2.Basic_Deeplearning_Model.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2.Basic_Deeplearning_Model.ipynb","version":"0.3.2","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IbUdIa1v8ieb","colab_type":"text"},"source":["Install packages"]},{"cell_type":"code","metadata":{"id":"I17WmOBmrK51","colab_type":"code","outputId":"458a565a-35fb-43c5-9438-6d14abbd80c2","executionInfo":{"status":"ok","timestamp":1565316278087,"user_tz":-540,"elapsed":9877,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":493}},"source":["!pip install mxnet-cu100\n","!pip install gluonnlp\n","!pip install gluoncv"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Requirement already satisfied: mxnet-cu100 in /usr/local/lib/python3.6/dist-packages (1.5.0)\n","Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (0.8.4)\n","Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n","Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.16.4)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.6.16)\n","Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n","Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n","Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n","Requirement already satisfied: gluonnlp in 
/usr/local/lib/python3.6/dist-packages (0.8.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluonnlp) (1.16.4)\n","Requirement already satisfied: gluoncv in /usr/local/lib/python3.6/dist-packages (0.4.0.post0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluoncv) (1.16.4)\n","Requirement already
satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from gluoncv) (3.0.3)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from gluoncv) (1.3.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gluoncv) (2.21.0)\n","Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from gluoncv) (4.3.0)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gluoncv) (4.28.1)\n","Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (1.1.0)\n","Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (0.10.0)\n","Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (2.4.2)\n","Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (2.5.3)\n","Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (3.0.4)\n","Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (2.8)\n","Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (1.24.3)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (2019.6.16)\n","Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->gluoncv) (0.46)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->gluoncv) (41.0.1)\n","Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib->gluoncv) 
(1.12.0)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"iklvnNNGHvS7","colab_type":"code","outputId":"1f2d16c6-f8f1-4436-b593-c4865e5e25d6","executionInfo":{"status":"ok","timestamp":1565316287723,"user_tz":-540,"elapsed":2635,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":306}},"source":["!nvidia-smi"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Fri Aug 9 02:04:46 2019 \n","+-----------------------------------------------------------------------------+\n","| NVIDIA-SMI 418.67 Driver Version: 410.79 CUDA Version: 10.0 |\n","|-------------------------------+----------------------+----------------------+\n","| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n","| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n","|===============================+======================+======================|\n","| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n","| N/A 67C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |\n","+-------------------------------+----------------------+----------------------+\n"," \n","+-----------------------------------------------------------------------------+\n","| Processes: GPU Memory |\n","| GPU PID Type Process name Usage |\n","|=============================================================================|\n","| No running processes found |\n","+-----------------------------------------------------------------------------+\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"xiU4uRVKraSG","colab_type":"code","colab":{}},"source":["import mxnet as mx\n","from mxnet import nd, gluon, autograd\n","from mxnet.gluon import nn\n","import gluoncv\n","import gluonnlp as 
nlp"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"1j9HvqMF9LC9","colab_type":"code","outputId":"e74b95dc-8ccb-4308-aab1-8387e97476c8","executionInfo":{"status":"ok","timestamp":1565316292596,"user_tz":-540,"elapsed":759,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["mx.__version__"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'1.5.0'"]},"metadata":{"tags":[]},"execution_count":4}]},{"cell_type":"code","metadata":{"id":"OKWLMVQE9Nru","colab_type":"code","outputId":"e6610cbb-355d-4ee6-d41f-5fcf86906e72","executionInfo":{"status":"ok","timestamp":1565316293700,"user_tz":-540,"elapsed":543,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["gluoncv.__version__"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'0.4.0'"]},"metadata":{"tags":[]},"execution_count":5}]},{"cell_type":"code","metadata":{"id":"B-FQxb3c9Qnf","colab_type":"code","outputId":"8001a35b-4c6c-4977-c1bd-aaef1bc2899e","executionInfo":{"status":"ok","timestamp":1565316294815,"user_tz":-540,"elapsed":577,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["nlp.__version__"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'0.8.0'"]},"metadata":{"tags":[]},"execution_count":6}]},{"cell_type":"code","metadata":{"id":"c73K3HyIrPco","colab_type":"code","colab":{}},"source":["## Setting transform\n","def transform(data, label):\n","    data = data.astype('float32')/255\n","    return data, label\n","\n","train_dataset = mx.gluon.data.vision.datasets.FashionMNIST(train=True, transform=transform)\n","valid_dataset = 
mx.gluon.data.vision.datasets.FashionMNIST(train=False, transform=transform)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"I8yJX2cJs-Mu","colab_type":"code","colab":{}},"source":["batch_size = 32\n","train_data_loader = mx.gluon.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=1)\n","valid_data_loader = mx.gluon.data.DataLoader(valid_dataset, batch_size, num_workers=1)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"o_B70cNG5BhY","colab_type":"text"},"source":["DNN Model"]},{"cell_type":"code","metadata":{"id":"LMaNqn8LtDC7","colab_type":"code","colab":{}},"source":["model = nn.HybridSequential()\n","model.add(nn.Dense(128, activation='relu'))\n","model.add(nn.Dense(64, activation='relu'))\n","# output layer emits raw logits (no ReLU): SoftmaxCrossEntropyLoss applies softmax internally\n","model.add(nn.Dense(10))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"yFPV9VsG-X3N","colab_type":"code","outputId":"3aa216a1-44c9-4412-cdd2-6c0e72123844","executionInfo":{"status":"ok","timestamp":1565316312138,"user_tz":-540,"elapsed":765,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":102}},"source":["model"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["HybridSequential(\n"," (0): Dense(None -> 128, Activation(relu))\n"," (1): Dense(None -> 64, Activation(relu))\n"," (2): Dense(None -> 10, linear)\n",")"]},"metadata":{"tags":[]},"execution_count":10}]},{"cell_type":"code","metadata":{"id":"u_GDm28RtQm_","colab_type":"code","colab":{}},"source":["ctx = mx.gpu()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"LYviHcEXtpcI","colab_type":"code","colab":{}},"source":["## the model must be initialized before it can be used\n","model.initialize(mx.init.Xavier(), 
ctx=ctx)\n","model.hybridize()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"PwlsSTLltvNP","colab_type":"code","colab":{}},"source":["# define loss and trainer.\n","criterion = gluon.loss.SoftmaxCrossEntropyLoss()\n","trainer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 0.1})"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"M14NMweMtzEh","colab_type":"code","outputId":"218e8a33-9bd8-4094-cdcb-c97a5a68244a","executionInfo":{"status":"ok","timestamp":1565316509588,"user_tz":-540,"elapsed":138366,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":187}},"source":["epochs = 10\n","for epoch in range(epochs):\n","    # training loop (with autograd and trainer steps, etc.)\n","    cumulative_train_loss = mx.nd.zeros(1, ctx=ctx)\n","    training_samples = 0\n","\n","    for batch_idx, (data, label) in enumerate(train_data_loader):\n","        data = data.as_in_context(ctx).reshape((-1, 784))  # 28*28=784\n","        label = label.as_in_context(ctx)\n","        with autograd.record():\n","            output = model(data)\n","            loss = criterion(output, label)\n","        loss.backward()\n","        trainer.step(data.shape[0])\n","        cumulative_train_loss += loss.sum()\n","        training_samples += data.shape[0]\n","    train_loss = cumulative_train_loss.asscalar()/training_samples\n","\n","    # validation loop\n","    cumulative_valid_loss = mx.nd.zeros(1, ctx)\n","    valid_samples = 0\n","    for batch_idx, (data, label) in enumerate(valid_data_loader):\n","        data = data.as_in_context(ctx).reshape((-1, 784))  # 28*28=784\n","        label = label.as_in_context(ctx)\n","        output = model(data)\n","        loss = criterion(output, label)\n","        cumulative_valid_loss += loss.sum()\n","        valid_samples += data.shape[0]\n","    valid_loss = 
cumulative_valid_loss.asscalar()/valid_samples\n","\n","    print(\"Epoch {}, training loss: {:.2f}, validation loss: {:.2f}\".format(epoch+1, train_loss, 
valid_loss))\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Epoch 1, training loss: 0.75, validation loss: 0.44\n","Epoch 2, training loss: 0.41, validation loss: 0.38\n","Epoch 3, training loss: 0.36, validation loss: 0.41\n","Epoch 4, training loss: 0.34, validation loss: 0.35\n","Epoch 5, training loss: 0.32, validation loss: 0.35\n","Epoch 6, training loss: 0.30, validation loss: 0.32\n","Epoch 7, training loss: 0.29, validation loss: 0.32\n","Epoch 8, training loss: 0.28, validation loss: 0.31\n","Epoch 9, training loss: 0.27, validation loss: 0.33\n","Epoch 10, training loss: 0.26, validation loss: 0.31\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"S5KWLOXP-651","colab_type":"text"},"source":["CNN Model"]},{"cell_type":"code","metadata":{"id":"daXm8Imkt4Zb","colab_type":"code","colab":{}},"source":["model_cnn = nn.HybridSequential()\n","model_cnn.add(nn.Conv2D(3, 3, 1, activation='relu'))\n","model_cnn.add(nn.Conv2D(3, 3, 1, activation='relu'))\n","model_cnn.add(nn.Flatten())\n","model_cnn.add(nn.Dense(100, activation='relu'))\n","# output layer emits raw logits (no ReLU): SoftmaxCrossEntropyLoss applies softmax internally\n","model_cnn.add(nn.Dense(10))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"IJS8rUYQF6Gc","colab_type":"code","outputId":"03ce98ef-7ab4-4949-9abc-aba3c8722499","executionInfo":{"status":"ok","timestamp":1565316517027,"user_tz":-540,"elapsed":512,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":136}},"source":["model_cnn"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["HybridSequential(\n"," (0): Conv2D(None -> 3, kernel_size=(3, 3), stride=(1, 1), Activation(relu))\n"," (1): Conv2D(None -> 3, kernel_size=(3, 3), stride=(1, 1), Activation(relu))\n"," (2): Flatten\n"," (3): Dense(None -> 100, Activation(relu))\n"," (4): Dense(None -> 10, 
linear)\n",")"]},"metadata":{"tags":[]},"execution_count":16}]},{"cell_type":"code","metadata":{"id":"FAoqTQPD_QCh","colab_type":"code","colab":{}},"source":["model_cnn.initialize(mx.init.Xavier(), ctx=ctx)\n","model_cnn.hybridize()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"YXrv1gGv_RU0","colab_type":"code","colab":{}},"source":["# define loss and trainer.\n","criterion = gluon.loss.SoftmaxCrossEntropyLoss()\n","trainer = gluon.Trainer(model_cnn.collect_params(), 'sgd', {'learning_rate': 0.01})"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"3_8f8z_d_oB0","colab_type":"code","outputId":"073c85a9-e9a4-4b18-8f96-8eb30c4d5b40","executionInfo":{"status":"ok","timestamp":1565316711905,"user_tz":-540,"elapsed":141777,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":187}},"source":["epochs = 10\n","for epoch in range(epochs):\n","    # training loop (with autograd and trainer steps, etc.)\n","    cumulative_train_loss = mx.nd.zeros(1, ctx=ctx)\n","    training_samples = 0\n","\n","    for batch_idx, (data, label) in enumerate(train_data_loader):\n","        # NHWC (N, 28, 28, 1) -> NCHW (N, 1, 28, 28), the layout Conv2D expects\n","        data = data.as_in_context(ctx).transpose((0,3,1,2))\n","        label = label.as_in_context(ctx)\n","        with autograd.record():\n","            output = model_cnn(data)\n","            loss = criterion(output, label)\n","        loss.backward()\n","        trainer.step(data.shape[0])\n","        cumulative_train_loss += loss.sum()\n","        training_samples += data.shape[0]\n","    train_loss = cumulative_train_loss.asscalar()/training_samples\n","\n","    # validation loop\n","    cumulative_valid_loss = mx.nd.zeros(1, ctx)\n","    valid_samples = 0\n","    for batch_idx, (data, label) in enumerate(valid_data_loader):\n","        data = data.as_in_context(ctx).transpose((0,3,1,2))\n","        label = 
label.as_in_context(ctx)\n","        output = model_cnn(data)\n","        loss = criterion(output, label)\n","        cumulative_valid_loss += loss.sum()\n","        valid_samples += data.shape[0]\n","    valid_loss = cumulative_valid_loss.asscalar()/valid_samples\n","\n","    print(\"Epoch {}, training loss: {:.2f}, validation loss: {:.2f}\".format(epoch+1, train_loss, valid_loss))\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Epoch 1, training loss: 0.95, validation loss: 0.51\n","Epoch 2, training loss: 0.47, validation loss: 0.46\n","Epoch 3, training loss: 0.41, validation loss: 0.40\n","Epoch 4, training loss: 0.38, validation loss: 0.38\n","Epoch 5, training loss: 0.36, validation loss: 0.36\n","Epoch 6, training loss: 0.34, validation loss: 0.35\n","Epoch 7, training loss: 0.33, validation loss: 0.35\n","Epoch 8, training loss: 0.32, validation loss: 0.33\n","Epoch 9, training loss: 0.30, validation loss: 0.33\n","Epoch 10, training loss: 0.29, validation loss: 0.32\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"ipZq-4HjI0C6","colab_type":"code","colab":{}},"source":[""],"execution_count":0,"outputs":[]}]}
--------------------------------------------------------------------------------
/code/2.Word_embedding.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"3.Word_embedding.ipynb","version":"0.3.2","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"_gCH1Ue-JnwM","colab_type":"text"},"source":["Install packages"]},{"cell_type":"code","metadata":{"id":"Cl1fGvONHpnt","colab_type":"code","outputId":"892ab623-22d8-49a1-c3c2-763d94213895","executionInfo":{"status":"ok","timestamp":1565323574101,"user_tz":-540,"elapsed":47628,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":870}},"source":["!pip install mxnet-cu100\n","!pip install gluonnlp\n","!pip install
gluoncv"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Collecting mxnet-cu100\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/56/d3/e939814957c2f09ecdd22daa166898889d54e5981e356832425d514edfb6/mxnet_cu100-1.5.0-py2.py3-none-manylinux1_x86_64.whl (540.1MB)\n","\u001b[K |████████████████████████████████| 540.1MB 46kB/s \n","\u001b[?25hRequirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.16.4)\n","Collecting graphviz<0.9.0,>=0.8.1 (from mxnet-cu100)\n"," Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl\n","Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.6.16)\n","Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n","Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n","Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n","Installing collected packages: graphviz, mxnet-cu100\n"," Found existing installation: graphviz 0.10.1\n"," Uninstalling graphviz-0.10.1:\n"," Successfully uninstalled graphviz-0.10.1\n","Successfully installed graphviz-0.8.4 mxnet-cu100-1.5.0\n","Collecting gluonnlp\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/c1/c8/e180cd98ab190e7ac3c6a767a909918e719be33f967bca13d0d4cd7c5468/gluonnlp-0.8.0.tar.gz (235kB)\n","\u001b[K |████████████████████████████████| 245kB 1.4MB/s \n","\u001b[?25hRequirement already satisfied: numpy in 
/usr/local/lib/python3.6/dist-packages (from gluonnlp) (1.16.4)\n","Building wheels for collected packages: gluonnlp\n"," Building wheel for gluonnlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for gluonnlp: filename=gluonnlp-0.8.0-cp36-none-any.whl size=292704 sha256=92ad06adc4590f1df92c1c48804404d71da556919e8a16fa5aeeaf26f27681cb\n"," Stored in directory: /root/.cache/pip/wheels/28/ff/33/d73801f242fb93c02f2076f81232fcb9a29305480cc42c5454\n","Successfully built gluonnlp\n","Installing collected packages: gluonnlp\n","Successfully installed gluonnlp-0.8.0\n","Collecting gluoncv\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/3d/31/9c02604787d852bd0356ff0b5d727f7c94c9ff524cc8267245ece8d5c4a5/gluoncv-0.4.0.post0-py2.py3-none-any.whl (342kB)\n","\u001b[K |████████████████████████████████| 348kB 1.4MB/s \n","\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from gluoncv) (4.3.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gluoncv) (2.21.0)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gluoncv) (4.28.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from gluoncv) (1.3.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluoncv) (1.16.4)\n","Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from gluoncv) (3.0.3)\n","Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->gluoncv) (0.46)\n","Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (1.24.3)\n","Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (2.8)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) 
(2019.6.16)\n","Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gluoncv) (3.0.4)\n","Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (1.1.0)\n","Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (2.5.3)\n","Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (0.10.0)\n","Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->gluoncv) (2.4.2)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->gluoncv) (41.0.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib->gluoncv) (1.12.0)\n","Installing collected packages: gluoncv\n","Successfully installed gluoncv-0.4.0.post0\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"FDYUzwOSKi79","colab_type":"code","outputId":"0a6d516d-873c-424f-907e-baf5ca638572","executionInfo":{"status":"ok","timestamp":1565323601193,"user_tz":-540,"elapsed":1494,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":306}},"source":["!nvidia-smi"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Fri Aug 9 04:06:40 2019 \n","+-----------------------------------------------------------------------------+\n","| NVIDIA-SMI 418.67 Driver Version: 410.79 CUDA Version: 10.0 |\n","|-------------------------------+----------------------+----------------------+\n","| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n","| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. 
|\n","|===============================+======================+======================|\n","| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n","| N/A 51C P8 16W / 70W | 0MiB / 15079MiB | 0% Default |\n","+-------------------------------+----------------------+----------------------+\n"," \n","+-----------------------------------------------------------------------------+\n","| Processes: GPU Memory |\n","| GPU PID Type Process name Usage |\n","|=============================================================================|\n","| No running processes found |\n","+-----------------------------------------------------------------------------+\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"aTagkRTjKkAH","colab_type":"code","colab":{}},"source":["from mxnet import gluon\n","\n","import warnings\n","warnings.simplefilter('ignore')\n","\n","import mxnet as mx\n","from mxnet import nd\n","import gluonnlp as nlp\n","import re"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ka17IvgPPZcQ","colab_type":"text"},"source":["Build the vocabulary"]},{"cell_type":"code","metadata":{"id":"HA8svWEkLDGT","colab_type":"code","colab":{}},"source":["text = \"\"\"\n","지난달 일본 수출 규제 조치로 촉발된 일본제품 불매운동 여파로 유니클로 무인양품 등 일본 브랜드의 모바일 앱 사용자가 급격히 감소했다는 조사 결과가 나왔다\n","\"\"\""],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"b-gGb5mSLNxd","colab_type":"code","colab":{}},"source":["def simple_tokenize(source_str, token_delim=' ', seq_delim='\\n'):\n","    # split on spaces and newlines; filter(None, ...) drops the empty strings\n","    return filter(None, re.split(token_delim + '|' + seq_delim, source_str))\n","\n","counter = nlp.data.count_tokens(simple_tokenize(text))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"cd_HsAA-LQYE","colab_type":"code","colab":{}},"source":["vocab = 
nlp.Vocab(counter)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"TphCGc6RLQ8u","colab_type":"code","outputId":"7b7e27fe-1bdb-427b-f261-bb32cf186094","executionInfo":{"status":"ok","timestamp":1565323632696,"user_tz":-540,"elapsed":568,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":442}},"source":["for word in vocab.idx_to_token:\n","    print(word)"],"execution_count":0,"outputs":[{"output_type":"stream","text":["<unk>\n","<pad>\n","<bos>\n","<eos>\n","일본\n","감소했다는\n","결과가\n","규제\n","급격히\n","나왔다\n","등\n","모바일\n","무인양품\n","불매운동\n","브랜드의\n","사용자가\n","수출\n","앱\n","여파로\n","유니클로\n","일본제품\n","조사\n","조치로\n","지난달\n","촉발된\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"cdjiSKodPdoN","colab_type":"text"},"source":["Load pre-trained word embeddings"]},{"cell_type":"code","metadata":{"id":"sZ1jXr0Rj4q-","colab_type":"code","colab":{}},"source":["source_list = nlp.embedding.list_sources('fasttext')"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"gBZin34xj8x9","colab_type":"code","outputId":"6368dfc7-76a5-40a0-a276-7b18f46d7554","executionInfo":{"status":"ok","timestamp":1565323695673,"user_tz":-540,"elapsed":551,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":68}},"source":["for s in source_list:\n","    if 'ko' in s:\n","        print(s)"],"execution_count":0,"outputs":[{"output_type":"stream","text":["wiki.koi\n","wiki.ko\n","cc.ko.300\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"xdfP4MPNLXJe","colab_type":"code","colab":{}},"source":["fasttext_simple = nlp.embedding.create('fasttext', 
source='wiki.ko')"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"sdqtWu9QM3bR","colab_type":"code","colab":{}},"source":["vocab.set_embedding(fasttext_simple)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"gL2aW6qDM5yf","colab_type":"code","outputId":"38c78d70-080e-4f9d-c021-39f4de321c41","executionInfo":{"status":"ok","timestamp":1565324284643,"user_tz":-540,"elapsed":334,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":71}},"source":["print('Vocabulary has length', len(vocab))\n","print(vocab.idx_to_token)"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Vocabulary has length 25\n","['<unk>', '<pad>', '<bos>', '<eos>', '일본', '감소했다는', '결과가', '규제', '급격히', '나왔다', '등', '모바일', '무인양품', '불매운동', '브랜드의', '사용자가', '수출', '앱', '여파로', '유니클로', '일본제품', '조사', '조치로', '지난달', '촉발된']\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"CAHGdevKM9ne","colab_type":"code","outputId":"94bb97c2-a641-4a84-fd54-1159268d654b","executionInfo":{"status":"ok","timestamp":1565324285293,"user_tz":-540,"elapsed":525,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":901}},"source":["vocab.embedding['수출']"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\n","[ 4.2638e-01 -3.4225e-01 6.3850e-02 1.8274e-01 -6.9525e-02 -8.3702e-01\n"," 1.9475e-01 -1.5824e-02 -1.7478e-04 4.9751e-02 -3.3179e-03 2.0388e-01\n"," -5.5108e-02 4.6147e-01 -4.2385e-01 2.5457e-01 -1.4857e-01 1.8358e-01\n"," -7.3676e-02 1.9369e-01 2.7336e-01 -3.5977e-01 -3.1542e-02 -3.2799e-01\n"," -1.0815e-01 8.5354e-02 -1.4569e-01 -4.0252e-01 9.3765e-03 5.4110e-02\n"," 2.6510e-01 1.2582e-01 -3.9834e-01 -2.7832e-01 1.6225e-01 4.5138e-02\n"," -5.4792e-02 -6.0437e-02 -3.2641e-01 -2.1744e-01 -8.7771e-02 1.3443e-01\n"," 1.3209e-01 -1.0430e-01 1.8851e-01 1.5267e-01 
2.5451e-01 -3.8108e-01\n"," 5.0737e-01 1.2683e-01 3.6272e-02 -2.0192e-01 -7.0588e-02 -8.8823e-02\n"," 1.4248e-02 -1.3948e-01 -5.3001e-02 3.8400e-01 -7.1702e-02 -2.8495e-01\n"," 3.9161e-01 6.6414e-01 -1.2302e-01 -3.1577e-01 8.6273e-02 1.6665e-02\n"," -1.0768e-01 -1.5478e-01 -5.7639e-02 -3.9881e-01 3.4402e-01 -1.6982e-01\n"," 1.4718e-01 -6.1504e-01 -7.8528e-02 -5.3038e-02 -1.4703e-01 1.4217e-01\n"," 5.5486e-01 3.8348e-01 1.3249e-01 -2.2455e-01 2.8402e-01 5.3096e-02\n"," -9.9325e-02 -3.8445e-01 -1.2884e-01 -6.0842e-01 1.2491e-01 -2.3915e-01\n"," -4.1258e-01 -1.3533e-01 3.8791e-01 -3.1587e-01 -1.6278e-02 -9.0981e-02\n"," -5.2956e-02 2.6161e-01 -5.4387e-01 7.3484e-02 -1.5726e-01 -1.1190e-01\n"," -2.1202e-01 -4.6383e-01 -1.8112e-01 1.1111e-01 -1.4531e-01 -2.4862e-02\n"," -4.4514e-01 1.2022e-01 -1.7806e-01 3.1296e-01 1.0636e-01 -2.4084e-01\n"," 1.3785e-01 -5.8384e-01 2.3597e-01 5.5893e-02 -9.8666e-02 2.4723e-01\n"," -1.3603e-01 8.7914e-02 -2.4429e-01 -8.8579e-02 -1.8117e-02 2.1893e-01\n"," 3.3154e-01 1.0325e-01 5.5140e-02 3.3320e-03 -1.2757e-01 -1.8881e-02\n"," 8.7895e-02 -4.8405e-01 5.4673e-02 -6.9837e-01 1.8022e-02 6.0194e-01\n"," -9.6470e-03 -2.2723e-01 2.7933e-01 2.1040e-02 -3.8659e-01 3.3674e-01\n"," -2.1744e-01 3.2504e-01 4.9919e-02 -4.5032e-01 -9.0482e-01 -2.0204e-01\n"," -2.8953e-01 -1.9171e-01 1.2623e-01 -4.8595e-01 -8.6951e-02 -1.4328e-01\n"," 3.5679e-02 5.3210e-01 -1.9692e-01 7.5380e-02 2.2338e-01 -9.5107e-02\n"," 4.0847e-01 1.4180e-01 1.9727e-01 4.5828e-01 5.5788e-01 2.3105e-01\n"," -1.7546e-01 -2.3614e-01 3.9640e-02 -3.2772e-01 4.0468e-02 1.3873e-01\n"," 1.0368e-01 6.0927e-01 -2.2480e-01 2.9803e-01 -6.9286e-02 -1.3855e-01\n"," -5.9916e-02 -4.2182e-01 -3.0607e-01 -2.1136e-01 4.1947e-01 -6.2482e-01\n"," 2.6623e-01 1.6588e-01 9.0851e-01 2.5510e-01 4.7693e-01 5.9349e-01\n"," -3.0342e-01 6.8446e-01 1.1693e-01 -9.1779e-02 -1.6568e-02 -1.8064e-01\n"," -1.0110e-01 8.4037e-02 -3.0629e-01 2.7792e-01 4.2220e-01 3.4675e-01\n"," -4.4871e-02 -7.4020e-02 4.7964e-01 
-2.4249e-01 2.6915e-01 2.4986e-01\n"," 6.1426e-02 -1.2248e-01 -2.4956e-01 -2.0275e-01 4.9726e-02 5.1225e-01\n"," -5.9698e-02 -4.4546e-01 -2.8758e-01 3.8029e-01 -1.1247e-01 -2.9148e-01\n"," 4.0575e-01 -3.6503e-01 5.8112e-02 3.4802e-01 6.3945e-01 -2.6490e-01\n"," -3.3999e-01 4.4482e-01 -4.5096e-01 3.3705e-01 -5.5393e-01 1.0783e-02\n"," -4.6373e-01 2.1076e-01 3.3907e-01 7.5862e-01 5.7707e-01 9.5326e-02\n"," 8.5128e-02 -1.3907e-01 -1.3865e-01 5.1447e-02 -6.5677e-01 6.2436e-03\n"," -1.9172e-01 -6.5026e-01 4.7809e-01 -3.0694e-01 -2.3818e-01 -2.0427e-01\n"," 2.9395e-01 2.5033e-01 -3.2696e-01 -1.8797e-01 -3.1706e-01 -1.7542e-01\n"," 2.0871e-01 -5.5197e-01 -2.0134e-01 -4.7705e-01 1.9890e-01 -1.1906e-01\n"," 5.3541e-01 3.3463e-02 1.5817e-01 1.1216e-01 -1.5229e-02 3.9583e-02\n"," 1.6018e-01 -4.5875e-01 -1.0683e-01 -5.3861e-01 -1.3970e-01 1.9986e-01\n"," -3.9659e-01 -1.8464e-01 -6.4784e-03 5.9818e-01 -2.2969e-01 -1.3996e-01\n"," -5.4446e-01 2.0123e-01 3.8428e-01 -6.4578e-02 -1.3425e-01 -2.9552e-01\n"," -2.7100e-02 6.4368e-01 -4.3834e-01 -3.4919e-01 3.4768e-01 1.1412e-01\n"," -1.0259e-01 -1.3130e-01 1.6735e-01 2.7717e-01 -7.9090e-02 -4.1903e-01]\n",""]},"metadata":{"tags":[]},"execution_count":50}]},{"cell_type":"markdown","metadata":{"id":"BY4QhBsgkVWg","colab_type":"text"},"source":["Use the pre-trained word embeddings in an Embedding layer"]},{"cell_type":"code","metadata":{"id":"2zS1XeVGPBBB","colab_type":"code","outputId":"5860215b-2cbc-436b-de02-b0c36a59313e","executionInfo":{"status":"ok","timestamp":1565323799977,"user_tz":-540,"elapsed":520,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["vocab['일본', '수출']"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[4, 16]"]},"metadata":{"tags":[]},"execution_count":25}]},{"cell_type":"code","metadata":{"id":"ekmXg76_Lc0F","colab_type":"code","colab":{}},"source":["input_dim, output_dim = 
vocab.embedding.idx_to_vec.shape"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"ukqH2njCMkUD","colab_type":"code","colab":{}},"source":["layer = gluon.nn.Embedding(input_dim, output_dim)\n","layer.initialize(ctx=mx.gpu())\n","layer.weight.set_data(vocab.embedding.idx_to_vec)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Wi71FkWLMuNN","colab_type":"code","colab":{}},"source":["emb_out = layer(nd.array([4, 16], ctx = mx.gpu()))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"OmLDR0PjOwU7","colab_type":"code","outputId":"c6c660bf-d162-4717-e5ba-fcfb1e9edd6f","executionInfo":{"status":"ok","timestamp":1565323842821,"user_tz":-540,"elapsed":519,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":1000}},"source":["emb_out"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\n","[[ 1.8810e-01 1.9917e-01 -3.8932e-01 8.2046e-02 4.6638e-02 4.5055e-02\n"," -3.1671e-01 -3.6490e-02 5.0224e-02 -1.3456e-01 6.7149e-02 3.5040e-01\n"," -3.7936e-02 1.1197e-01 -4.0459e-01 3.6445e-01 3.8595e-01 1.4581e-01\n"," -1.4990e-01 -2.9435e-01 2.5718e-01 -6.7448e-02 2.2115e-01 -4.3913e-01\n"," 7.7741e-02 -1.7991e-01 -3.2989e-01 1.1673e-01 -3.1833e-01 -1.6808e-02\n"," -2.8415e-02 3.6290e-02 2.6130e-01 -4.2483e-01 -2.7343e-01 4.2707e-01\n"," -7.9224e-02 4.2217e-02 -2.5600e-01 -4.6470e-04 -3.3134e-02 4.3725e-02\n"," 1.7770e-01 -1.4922e-01 3.9644e-01 -1.2606e-01 -2.5324e-02 -9.2616e-02\n"," 1.5082e-02 1.5065e-01 2.9300e-01 3.6282e-01 -1.0867e-01 8.3840e-02\n"," 4.4520e-02 -6.1860e-02 -3.8661e-01 -7.9247e-02 -7.2375e-02 -5.9540e-02\n"," 2.4200e-01 -1.1036e-01 2.0190e-03 1.2865e-01 -6.6143e-02 3.5190e-02\n"," 2.2175e-01 -1.3672e-01 1.7555e-01 -2.4323e-01 2.6840e-01 1.9171e-01\n"," -1.1305e-01 -1.6910e-01 -3.9212e-03 -2.0387e-02 -3.6294e-01 1.1597e-01\n"," 4.2091e-02 -5.8388e-02 -9.0254e-02 
-6.9849e-02 -1.6469e-01 2.2502e-02\n"," -8.5158e-02 1.3084e-02 4.9627e-01 9.4673e-02 -2.1923e-01 2.2622e-01\n"," 2.5301e-01 -2.0715e-02 1.0021e-01 -5.6142e-02 -1.8559e-01 -7.4258e-02\n"," -1.0302e-01 2.7017e-01 -1.2915e-01 -1.2952e-01 1.8582e-01 -6.3601e-02\n"," -7.3645e-02 2.3037e-02 -3.5136e-01 1.1340e-01 -3.7866e-03 -7.5002e-02\n"," -5.7064e-01 7.1485e-02 -5.6294e-02 -3.7366e-01 -1.8191e-01 -2.8404e-01\n"," 7.3318e-02 -9.2956e-02 1.3273e-01 2.6637e-01 2.6431e-02 1.4460e-01\n"," -2.3493e-01 2.7698e-01 -2.6268e-01 -4.3460e-03 -1.9252e-01 9.6839e-02\n"," 9.0349e-02 -6.2047e-02 2.6696e-01 -3.2123e-01 4.4355e-04 -2.6696e-01\n"," -1.3026e-01 -8.9901e-02 -1.0029e-01 -2.5110e-01 1.4319e-01 1.7760e-01\n"," -2.6844e-01 1.2893e-01 -4.7380e-02 9.6227e-03 -8.1744e-02 6.1455e-02\n"," 2.8503e-01 3.4132e-01 1.2467e-01 2.3180e-01 -3.5404e-02 2.1141e-01\n"," 1.7600e-01 3.6461e-01 3.0553e-01 -3.9010e-01 3.0512e-01 -1.5444e-01\n"," -2.5756e-01 4.7070e-01 -1.2910e-01 1.6549e-01 5.6344e-01 2.0581e-01\n"," 4.0742e-02 1.7386e-01 2.5355e-01 1.0239e-01 1.9662e-01 1.0232e-01\n"," -3.0957e-01 -4.9555e-01 -7.9367e-03 1.9833e-01 5.0840e-01 -1.2383e-01\n"," 2.5617e-01 1.2360e-01 -3.6718e-01 -2.1787e-01 -1.0816e-01 -2.5369e-01\n"," 7.4009e-02 -4.1643e-02 -4.8805e-02 -6.7081e-02 4.4067e-01 -4.7110e-01\n"," 2.2228e-01 -2.2026e-01 4.0538e-01 -4.2518e-01 1.0732e-01 3.1531e-01\n"," 9.5380e-02 4.6025e-01 3.6392e-01 -5.6403e-01 -3.2210e-03 2.9005e-01\n"," -4.8466e-01 7.9554e-02 1.6156e-02 2.5028e-01 -1.0411e-01 1.5570e-01\n"," -6.5607e-02 -4.2981e-01 -1.1829e-01 -3.0163e-01 -3.3469e-01 -2.5432e-01\n"," -1.9175e-01 6.1773e-01 1.5270e-01 -3.7654e-01 3.6843e-02 -3.7617e-01\n"," -1.2876e-01 -1.0712e-01 3.7194e-02 4.6964e-01 -5.0154e-01 -1.4163e-01\n"," 1.6431e-01 -5.6524e-01 4.4852e-01 3.6020e-01 4.6376e-01 -1.0570e-03\n"," -1.7882e-01 1.1948e-01 -4.8075e-01 1.4392e-01 1.3072e-01 -1.7797e-03\n"," -2.1375e-01 2.2058e-01 -1.5685e-01 -1.7553e-01 4.3184e-01 -1.5801e-01\n"," 9.7793e-02 2.3302e-02 3.4538e-02 
3.8361e-01 -6.9856e-02 -5.0388e-01\n"," 1.1685e-01 -4.2123e-01 2.5137e-01 -4.1061e-01 -7.9262e-02 -5.0755e-01\n"," -5.6375e-02 2.3734e-03 1.0352e-02 2.9952e-01 -3.6774e-01 2.2305e-01\n"," 1.2627e-01 1.5429e-01 4.1420e-02 -3.8243e-01 4.1338e-01 -8.4710e-02\n"," -7.0449e-02 4.1873e-02 4.3457e-01 4.6773e-01 2.3689e-02 -3.9291e-03\n"," 1.2443e-01 1.5383e-01 4.3505e-04 7.3725e-02 -1.8785e-02 -2.6812e-01\n"," -1.1089e-01 2.8354e-02 -1.6654e-01 1.9243e-01 1.8665e-01 1.0634e-02\n"," -2.4506e-01 -3.1665e-01 -4.5618e-01 -1.4756e-01 3.0725e-01 -3.6764e-02\n"," -2.7107e-01 3.4827e-01 -2.5737e-01 9.3833e-02 3.7829e-02 4.9497e-02\n"," 4.3981e-01 3.7246e-02 2.2015e-01 8.5563e-02 -1.6931e-02 4.7679e-01]\n"," [ 4.2638e-01 -3.4225e-01 6.3850e-02 1.8274e-01 -6.9525e-02 -8.3702e-01\n"," 1.9475e-01 -1.5824e-02 -1.7478e-04 4.9751e-02 -3.3179e-03 2.0388e-01\n"," -5.5108e-02 4.6147e-01 -4.2385e-01 2.5457e-01 -1.4857e-01 1.8358e-01\n"," -7.3676e-02 1.9369e-01 2.7336e-01 -3.5977e-01 -3.1542e-02 -3.2799e-01\n"," -1.0815e-01 8.5354e-02 -1.4569e-01 -4.0252e-01 9.3765e-03 5.4110e-02\n"," 2.6510e-01 1.2582e-01 -3.9834e-01 -2.7832e-01 1.6225e-01 4.5138e-02\n"," -5.4792e-02 -6.0437e-02 -3.2641e-01 -2.1744e-01 -8.7771e-02 1.3443e-01\n"," 1.3209e-01 -1.0430e-01 1.8851e-01 1.5267e-01 2.5451e-01 -3.8108e-01\n"," 5.0737e-01 1.2683e-01 3.6272e-02 -2.0192e-01 -7.0588e-02 -8.8823e-02\n"," 1.4248e-02 -1.3948e-01 -5.3001e-02 3.8400e-01 -7.1702e-02 -2.8495e-01\n"," 3.9161e-01 6.6414e-01 -1.2302e-01 -3.1577e-01 8.6273e-02 1.6665e-02\n"," -1.0768e-01 -1.5478e-01 -5.7639e-02 -3.9881e-01 3.4402e-01 -1.6982e-01\n"," 1.4718e-01 -6.1504e-01 -7.8528e-02 -5.3038e-02 -1.4703e-01 1.4217e-01\n"," 5.5486e-01 3.8348e-01 1.3249e-01 -2.2455e-01 2.8402e-01 5.3096e-02\n"," -9.9325e-02 -3.8445e-01 -1.2884e-01 -6.0842e-01 1.2491e-01 -2.3915e-01\n"," -4.1258e-01 -1.3533e-01 3.8791e-01 -3.1587e-01 -1.6278e-02 -9.0981e-02\n"," -5.2956e-02 2.6161e-01 -5.4387e-01 7.3484e-02 -1.5726e-01 -1.1190e-01\n"," -2.1202e-01 -4.6383e-01 
-1.8112e-01 1.1111e-01 -1.4531e-01 -2.4862e-02\n"," -4.4514e-01 1.2022e-01 -1.7806e-01 3.1296e-01 1.0636e-01 -2.4084e-01\n"," 1.3785e-01 -5.8384e-01 2.3597e-01 5.5893e-02 -9.8666e-02 2.4723e-01\n"," -1.3603e-01 8.7914e-02 -2.4429e-01 -8.8579e-02 -1.8117e-02 2.1893e-01\n"," 3.3154e-01 1.0325e-01 5.5140e-02 3.3320e-03 -1.2757e-01 -1.8881e-02\n"," 8.7895e-02 -4.8405e-01 5.4673e-02 -6.9837e-01 1.8022e-02 6.0194e-01\n"," -9.6470e-03 -2.2723e-01 2.7933e-01 2.1040e-02 -3.8659e-01 3.3674e-01\n"," -2.1744e-01 3.2504e-01 4.9919e-02 -4.5032e-01 -9.0482e-01 -2.0204e-01\n"," -2.8953e-01 -1.9171e-01 1.2623e-01 -4.8595e-01 -8.6951e-02 -1.4328e-01\n"," 3.5679e-02 5.3210e-01 -1.9692e-01 7.5380e-02 2.2338e-01 -9.5107e-02\n"," 4.0847e-01 1.4180e-01 1.9727e-01 4.5828e-01 5.5788e-01 2.3105e-01\n"," -1.7546e-01 -2.3614e-01 3.9640e-02 -3.2772e-01 4.0468e-02 1.3873e-01\n"," 1.0368e-01 6.0927e-01 -2.2480e-01 2.9803e-01 -6.9286e-02 -1.3855e-01\n"," -5.9916e-02 -4.2182e-01 -3.0607e-01 -2.1136e-01 4.1947e-01 -6.2482e-01\n"," 2.6623e-01 1.6588e-01 9.0851e-01 2.5510e-01 4.7693e-01 5.9349e-01\n"," -3.0342e-01 6.8446e-01 1.1693e-01 -9.1779e-02 -1.6568e-02 -1.8064e-01\n"," -1.0110e-01 8.4037e-02 -3.0629e-01 2.7792e-01 4.2220e-01 3.4675e-01\n"," -4.4871e-02 -7.4020e-02 4.7964e-01 -2.4249e-01 2.6915e-01 2.4986e-01\n"," 6.1426e-02 -1.2248e-01 -2.4956e-01 -2.0275e-01 4.9726e-02 5.1225e-01\n"," -5.9698e-02 -4.4546e-01 -2.8758e-01 3.8029e-01 -1.1247e-01 -2.9148e-01\n"," 4.0575e-01 -3.6503e-01 5.8112e-02 3.4802e-01 6.3945e-01 -2.6490e-01\n"," -3.3999e-01 4.4482e-01 -4.5096e-01 3.3705e-01 -5.5393e-01 1.0783e-02\n"," -4.6373e-01 2.1076e-01 3.3907e-01 7.5862e-01 5.7707e-01 9.5326e-02\n"," 8.5128e-02 -1.3907e-01 -1.3865e-01 5.1447e-02 -6.5677e-01 6.2436e-03\n"," -1.9172e-01 -6.5026e-01 4.7809e-01 -3.0694e-01 -2.3818e-01 -2.0427e-01\n"," 2.9395e-01 2.5033e-01 -3.2696e-01 -1.8797e-01 -3.1706e-01 -1.7542e-01\n"," 2.0871e-01 -5.5197e-01 -2.0134e-01 -4.7705e-01 1.9890e-01 -1.1906e-01\n"," 5.3541e-01 3.3463e-02 
1.5817e-01 1.1216e-01 -1.5229e-02 3.9583e-02\n"," 1.6018e-01 -4.5875e-01 -1.0683e-01 -5.3861e-01 -1.3970e-01 1.9986e-01\n"," -3.9659e-01 -1.8464e-01 -6.4784e-03 5.9818e-01 -2.2969e-01 -1.3996e-01\n"," -5.4446e-01 2.0123e-01 3.8428e-01 -6.4578e-02 -1.3425e-01 -2.9552e-01\n"," -2.7100e-02 6.4368e-01 -4.3834e-01 -3.4919e-01 3.4768e-01 1.1412e-01\n"," -1.0259e-01 -1.3130e-01 1.6735e-01 2.7717e-01 -7.9090e-02 -4.1903e-01]]\n",""]},"metadata":{"tags":[]},"execution_count":31}]},{"cell_type":"markdown","metadata":{"id":"csw9vNmAlok4","colab_type":"text"},"source":["Use the existing vocabulary of the pre-trained embedding"]},{"cell_type":"code","metadata":{"id":"qZ9NIm_zklK4","colab_type":"code","colab":{}},"source":["vocab_pretrained = nlp.Vocab(nlp.data.Counter(fasttext_simple.idx_to_token))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"c3Gn74-ul7IZ","colab_type":"code","outputId":"47595287-b190-4540-8a7a-6913ec1cd677","executionInfo":{"status":"ok","timestamp":1565324230118,"user_tz":-540,"elapsed":547,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["len(vocab_pretrained)"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["879133"]},"metadata":{"tags":[]},"execution_count":43}]},{"cell_type":"code","metadata":{"id":"tF0I3dNFmFb7","colab_type":"code","outputId":"f9d5a7c6-04c2-4deb-8a21-af25f75316b6","executionInfo":{"status":"ok","timestamp":1565324345808,"user_tz":-540,"elapsed":554,"user":{"displayName":"seung hwan 
Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["vocab_pretrained.token_to_idx['수출']"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["472039"]},"metadata":{"tags":[]},"execution_count":56}]},{"cell_type":"code","metadata":{"id":"fTUlNvaombPt","colab_type":"code","outputId":"335c1221-4870-4da3-d7dc-0d774c320873","executionInfo":{"status":"ok","timestamp":1565324339403,"user_tz":-540,"elapsed":543,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["vocab_pretrained.idx_to_token[591821]"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'위치였기'"]},"metadata":{"tags":[]},"execution_count":55}]},{"cell_type":"markdown","metadata":{"id":"EuTIiewZmijs","colab_type":"text"},"source":["Word Similarity"]},{"cell_type":"code","metadata":{"id":"KVbCnF7LkuM6","colab_type":"code","colab":{}},"source":["def norm_vecs_by_row(x):\n","    # L2-normalize each row; the 1E-10 term avoids division by zero for all-zero rows\n","    return x / nd.sqrt(nd.sum(x * x, axis=1) + 1E-10).reshape((-1,1))\n"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pujzyq0bkxFJ","colab_type":"code","colab":{}},"source":["def get_knn(vocab, k, word):\n","    word_vec = vocab.embedding[word].reshape((-1, 1))\n","    vocab_vecs = norm_vecs_by_row(vocab.embedding.idx_to_vec)\n","    dot_prod = nd.dot(vocab_vecs, word_vec)\n","    indices = nd.topk(dot_prod.reshape((len(vocab), )), k=k+1, ret_typ='indices')\n","    indices = [int(i.asscalar()) for i in indices]\n","    # Drop the top hit, which is normally the query word itself.\n","    return vocab.to_tokens(indices[1:])"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"r7zoFi6AoaFC","colab_type":"code","colab":{}},"source":["def cos_sim(x, y):\n","    return nd.dot(x, y) / (nd.norm(x) * 
nd.norm(y))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"APoFddGEkyFr","colab_type":"code","outputId":"dc871c05-f6b2-4df5-eb06-6c5e9075f64b","executionInfo":{"status":"ok","timestamp":1565324426426,"user_tz":-540,"elapsed":534,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["get_knn(vocab, 5, '수출')"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['무인양품', '규제', '브랜드의', '불매운동', '앱']"]},"metadata":{"tags":[]},"execution_count":60}]},{"cell_type":"code","metadata":{"id":"tthKdI3WmugP","colab_type":"code","colab":{}},"source":["vocab_pretrained.set_embedding(fasttext_simple)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"TId-n1nVnChf","colab_type":"code","outputId":"4931ed98-4f34-443d-e1ad-0fdc7f3d125c","executionInfo":{"status":"ok","timestamp":1565324499706,"user_tz":-540,"elapsed":3738,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["get_knn(vocab_pretrained, 5, '한국')"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['한국뿐', '한국와', '한국뿐만', '대한민국뿐만', '대한민국와']"]},"metadata":{"tags":[]},"execution_count":62}]},{"cell_type":"code","metadata":{"id":"3lT9H1fKnEwT","colab_type":"code","outputId":"2e6b8cdc-75eb-423b-9e4a-54e96418b432","executionInfo":{"status":"ok","timestamp":1565324879099,"user_tz":-540,"elapsed":494,"user":{"displayName":"seung hwan Jung","photoUrl":"","userId":"16375924554386741873"}},"colab":{"base_uri":"https://localhost:8080/","height":68}},"source":["cos_sim(vocab_pretrained.embedding['한국'], 
vocab_pretrained.embedding['일본'])"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\n","[0.58051175]\n",""]},"metadata":{"tags":[]},"execution_count":64}]}]} -------------------------------------------------------------------------------- /code/3_1_intent_classification_pycon2019.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "language_info": { 10 | "codemirror_mode": { 11 | "name": "ipython", 12 | "version": 3 13 | }, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "nbconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": "3.5.2" 20 | }, 21 | "colab": { 22 | "name": "intent_classification.ipynb", 23 | "version": "0.3.2", 24 | "provenance": [], 25 | "collapsed_sections": [] 26 | }, 27 | "accelerator": "GPU" 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "code", 32 | "metadata": { 33 | "id": "Z94sO9eQBEFe", 34 | "colab_type": "code", 35 | "colab": { 36 | "base_uri": "https://localhost:8080/", 37 | "height": 320 38 | }, 39 | "outputId": "40ebffb7-652a-408c-9039-6a796a760182" 40 | }, 41 | "source": [ 42 | "!pip install mxnet-cu100\n", 43 | "!pip install gluonnlp pandas tqdm" 44 | ], 45 | "execution_count": 3, 46 | "outputs": [ 47 | { 48 | "output_type": "stream", 49 | "text": [ 50 | "Collecting gluonnlp\n", 51 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/c1/c8/e180cd98ab190e7ac3c6a767a909918e719be33f967bca13d0d4cd7c5468/gluonnlp-0.8.0.tar.gz (235kB)\n", 52 | "\u001b[K |████████████████████████████████| 245kB 1.4MB/s \n", 53 | "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages 
(0.24.2)\n", 54 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)\n", 55 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluonnlp) (1.16.4)\n", 56 | "Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas) (2018.9)\n", 57 | "Requirement already satisfied: python-dateutil>=2.5.0 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.5.3)\n", 58 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.5.0->pandas) (1.12.0)\n", 59 | "Building wheels for collected packages: gluonnlp\n", 60 | " Building wheel for gluonnlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 61 | " Created wheel for gluonnlp: filename=gluonnlp-0.8.0-cp36-none-any.whl size=292704 sha256=4fed194e608655d54c907bdce8c379817cb6c1e42c09936efccaee6111328e97\n", 62 | " Stored in directory: /root/.cache/pip/wheels/28/ff/33/d73801f242fb93c02f2076f81232fcb9a29305480cc42c5454\n", 63 | "Successfully built gluonnlp\n", 64 | "Installing collected packages: gluonnlp\n", 65 | "Successfully installed gluonnlp-0.8.0\n" 66 | ], 67 | "name": "stdout" 68 | } 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "EGi63M1KBBYd", 75 | "colab_type": "text" 76 | }, 77 | "source": [ 78 | "## Intent Classification" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "5FkNc8PsBBYk", 85 | "colab_type": "code", 86 | "colab": {} 87 | }, 88 | "source": [ 89 | "import pandas as pd\n", 90 | "import numpy as np\n", 91 | "from mxnet.gluon import nn, rnn\n", 92 | "from mxnet import gluon, autograd\n", 93 | "import gluonnlp as nlp\n", 94 | "from mxnet import nd \n", 95 | "import mxnet as mx\n", 96 | "import time\n", 97 | "import itertools\n", 98 | "from tqdm import tqdm\n", 99 | "import multiprocessing as mp" 100 | ], 101 | "execution_count": 0, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": 
"code", 106 | "metadata": { 107 | "id": "r6fFvaPPBBYx", 108 | "colab_type": "code", 109 | "colab": {} 110 | }, 111 | "source": [ 112 | "train_raw = pd.read_csv(\"https://www.dropbox.com/s/83n0uoy20rd2vq4/trainset.txt?dl=1\",names=['intent', 'entity', 'sentence'], sep='\\t')\n", 113 | "#validation_raw = pd.read_csv(\"https://www.dropbox.com/s/kbl7kw54jdo2550/test_hidden.txt?dl=1\",names=['intent', 'entity', 'sentence'], sep='\\t')\n", 114 | "validation_raw = pd.read_csv(\"https://www.dropbox.com/s/enxp9yt9cstcal2/validation.txt?dl=1\",names=['intent', 'entity', 'sentence'], sep='\\t')" 115 | ], 116 | "execution_count": 0, 117 | "outputs": [] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "xlIx66ZgBBZB", 123 | "colab_type": "code", 124 | "colab": { 125 | "base_uri": "https://localhost:8080/", 126 | "height": 990 127 | }, 128 | "outputId": "f79eb6ae-1907-4b21-a87a-c6dbcf0325e8" 129 | }, 130 | "source": [ 131 | "train_raw.head(30)" 132 | ], 133 | "execution_count": 6, 134 | "outputs": [ 135 | { 136 | "output_type": "execute_result", 137 | "data": { 138 | "text/html": [ 139 | "
\n", 345 | "
" 346 | ], 347 | "text/plain": [ 348 | " intent entity sentence\n", 349 | "0 area EECCCCCCCCCCCCCCCCCCC 자강의 면적은 얼마 정도되는지 알려줄래\n", 350 | "1 birth_date CCCCCCCCCCCCEEECCCCCCCCCCCC WIKI PEDIA로 변재일 생년월일을 알고 싶어\n", 351 | "2 age EEEEEEEEEEECCCCCCCCCCCCCCCCC 남쪽 물고기자리 알파 나이가 위키백과사전으로 얼마야\n", 352 | "3 length EEEECCCCCCCCCCCCCCCCCC 삼양터널의 총 길이 위키백과사전에서 뭐야\n", 353 | "4 birth_place EEEEEECCCCCCCCCCC 코니 윌리스의 태어난 곳은 뭐지\n", 354 | "5 weight CCCCCCCCCCCCEEEECCCCCCCCCCCCC WIKI백과사전 검색 AA12의 무게가 얼만지 찾아봐\n", 355 | "6 definition CCCCCCCCCCCCCEEECCCCCCCC WIKIPEDIA백과로 라이프 찾아서 말해줘\n", 356 | "7 height EEEEEEEECCCCCCCCCCCCCCCCCCC 송파 헬리오시티 구조물 높이 위키 피디아에서 뭐야\n", 357 | "8 birth_date CCCEEEEEECCCCCCCCCCCCCCC 검색 HLKVAM 언제 출생했는지를 검색해라\n", 358 | "9 height CCCCCCCCEEEEEECCCCCCCC 위키 피디아에 푸조 508 전고가 몇이야\n", 359 | "10 length CCCEEEEECCCCCCC 검색 호몬혼 섬 길이를 찾아\n", 360 | "11 definition EEEEECCCCCCCCCCCCC 영산중학교 좀 위키피디아사전 검색\n", 361 | "12 age CCCCCCEEEEEECCCCCCC 위키백과로 침보라조 산 나이 어떤지\n", 362 | "13 birth_date EEEEEEECCCCCCCC 마무드 아스라의 출생 찾아줘\n", 363 | "14 birth_place CCCCCCEEEEEEECCCCCCCCC 위키 백과 조제 카리오카의 출생지를 찾아\n", 364 | "15 birth_date CCCEEEEEECCCCCCCCC 검색 제이 개츠비 생년월일은 뭐지\n", 365 | "16 length EEEECCCCCCCCCCCCCCCCC 증약터널의 길이가 얼마쯤인지 혹시 알아\n", 366 | "17 belong_to EEEEEEEEEEEEEEEECCCCCCCCCCCCCC 리히텐슈타인의 한스 아담 2세 소속사는 어딘지 검색해봐\n", 367 | "18 height CCCCCCCCCCCCEEEEECCCCCCCCC WIKI사전백과 검색 벨록스여우의 높이는 얼만지\n", 368 | "19 age EEEEEECCCCCCCCC 파블롭스키구의 나이를 찾아줘\n", 369 | "20 width EEEEEEECCCCCCCCCCCCCC 사카피솔라 섬의 너비는 WIKI에서 뭐\n", 370 | "21 birth_place EECCCCCCCCCCCCCCCCC 나미는 태어난 곳이 WIKI로 뭔지\n", 371 | "22 weight CCCCCEEEEECCCCCCC 위키에서 피니스테르의 무게 찾기\n", 372 | "23 birth_place CCCEEEEEEECCCCCCCCCCC 검색 카를 야스퍼스 출신지역이 어디라고\n", 373 | "24 width EEEEEEEEEEECCCCCCC 63식 병력수송장갑차의 폭 얼만지\n", 374 | "25 birth_place CCCCCEEECCCCCCCCCCCC 검색으로 강마에가 출생 장소를 찾아줘\n", 375 | "26 birth_date EEEEEECCCCCCCCCCCCCC 쿠죠 히카리의 언제 출생했는지 탐색해\n", 376 | "27 length EEECCCCCCCCCCC 사하라의 길이가 얼마쯤이지\n", 377 | "28 area EEEECCCCCCCCC 송대산성의 면적은 얼만지\n", 378 | "29 area 
CCCCCCCCCCEEEEEECCCCCCC WIKI 피디아에 신자경선생묘의 넓이 뭔지" 379 | ] 380 | }, 381 | "metadata": { 382 | "tags": [] 383 | }, 384 | "execution_count": 6 385 | } 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": { 391 | "id": "Rd8SGLptBBZO", 392 | "colab_type": "text" 393 | }, 394 | "source": [ 395 | "### Intent Classification" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": { 401 | "id": "qrQjYp7iBBZR", 402 | "colab_type": "text" 403 | }, 404 | "source": [ 405 | "#### Data Preprocessing" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "metadata": { 411 | "id": "KuuZ73g6BBZT", 412 | "colab_type": "code", 413 | "colab": {} 414 | }, 415 | "source": [ 416 | "train_dataset = [(l, d) for d,l in zip(train_raw['intent'], train_raw['sentence'])]\n", 417 | "valid_dataset = [(l, d) for d,l in zip(validation_raw['intent'], validation_raw['sentence'])]" 418 | ], 419 | "execution_count": 0, 420 | "outputs": [] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "metadata": { 425 | "id": "1nAgM9q_BBZc", 426 | "colab_type": "code", 427 | "colab": {} 428 | }, 429 | "source": [ 430 | "seq_len = 32\n", 431 | "\n", 432 | "length_clip = nlp.data.PadSequence(seq_len, pad_val=\"<pad>\")\n", 433 | "\n", 434 | "def preprocess(data):\n", 435 | " sent, entity = data\n", 436 | " char_sent = list(str(sent))\n", 437 | " char_entity = str(entity)\n", 438 | " return(length_clip(char_sent), len(sent),char_entity)\n", 439 | "\n", 440 | "def preprocess_dataset(dataset):\n", 441 | " start = time.time()\n", 442 | " with mp.Pool() as pool:\n", 443 | " dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset))\n", 444 | " end = time.time()\n", 445 | " print('Done! 
Tokenizing Time={:.2f}s, #Sentences={}'\n", 446 | " .format(end - start, len(dataset)))\n", 447 | " return dataset\n" 448 | ], 449 | "execution_count": 0, 450 | "outputs": [] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "metadata": { 455 | "id": "yZZpGLE5BBZj", 456 | "colab_type": "code", 457 | "colab": { 458 | "base_uri": "https://localhost:8080/", 459 | "height": 55 460 | }, 461 | "outputId": "7ed38251-0b55-4600-cabe-6b71bf087069" 462 | }, 463 | "source": [ 464 | "train_preprocessed = preprocess_dataset(train_dataset)\n", 465 | "valid_preprocessed = preprocess_dataset(valid_dataset)" 466 | ], 467 | "execution_count": 9, 468 | "outputs": [ 469 | { 470 | "output_type": "stream", 471 | "text": [ 472 | "Done! Tokenizing Time=0.13s, #Sentences=9000\n", 473 | "Done! Tokenizing Time=0.38s, #Sentences=1000\n" 474 | ], 475 | "name": "stdout" 476 | } 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "id": "8aYVByIKBBZt", 483 | "colab_type": "code", 484 | "colab": {} 485 | }, 486 | "source": [ 487 | "counter_sent = nlp.data.count_tokens(itertools.chain.from_iterable([c for c, _, _ in train_preprocessed]))\n", 488 | "counter_intent = nlp.data.count_tokens([c for _,_, c in train_preprocessed])" 489 | ], 490 | "execution_count": 0, 491 | "outputs": [] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "metadata": { 496 | "id": "tNHQEllPBBZ5", 497 | "colab_type": "code", 498 | "colab": { 499 | "base_uri": "https://localhost:8080/", 500 | "height": 207 501 | }, 502 | "outputId": "6f9f6ed9-09f0-44af-8cdb-8ac2c742c848" 503 | }, 504 | "source": [ 505 | "counter_intent" 506 | ], 507 | "execution_count": 11, 508 | "outputs": [ 509 | { 510 | "output_type": "execute_result", 511 | "data": { 512 | "text/plain": [ 513 | "Counter({'age': 900,\n", 514 | " 'area': 900,\n", 515 | " 'belong_to': 900,\n", 516 | " 'birth_date': 900,\n", 517 | " 'birth_place': 900,\n", 518 | " 'definition': 900,\n", 519 | " 'height': 900,\n", 520 | " 'length': 900,\n", 521 | " 
'weight': 900,\n", 522 | " 'width': 900})" 523 | ] 524 | }, 525 | "metadata": { 526 | "tags": [] 527 | }, 528 | "execution_count": 11 529 | } 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "metadata": { 535 | "id": "VE3NxndmBBaF", 536 | "colab_type": "code", 537 | "colab": {} 538 | }, 539 | "source": [ 540 | "vocab_sent = nlp.Vocab(counter_sent, bos_token=None, eos_token=None, min_freq=15)\n", 541 | "vocab_intent = nlp.Vocab(counter_intent, bos_token=None, eos_token=None, unknown_token=None, padding_token=None)" 542 | ], 543 | "execution_count": 0, 544 | "outputs": [] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "metadata": { 549 | "id": "OaIDZ4b_BBaP", 550 | "colab_type": "code", 551 | "colab": { 552 | "base_uri": "https://localhost:8080/", 553 | "height": 226 554 | }, 555 | "outputId": "9fec1ab5-bcde-4b68-9fcc-4e8ba9b85c14" 556 | }, 557 | "source": [ 558 | "vocab_sent.idx_to_token[:10], vocab_intent.idx_to_token[:10], " 559 | ], 560 | "execution_count": 13, 561 | "outputs": [ 562 | { 563 | "output_type": "execute_result", 564 | "data": { 565 | "text/plain": [ 566 | "(['<unk>', '<pad>', ' ', 'I', '이', '색', '검', '의', '지', '아'],\n", 567 | " ['age',\n", 568 | " 'area',\n", 569 | " 'belong_to',\n", 570 | " 'birth_date',\n", 571 | " 'birth_place',\n", 572 | " 'definition',\n", 573 | " 'height',\n", 574 | " 'length',\n", 575 | " 'weight',\n", 576 | " 'width'])" 577 | ] 578 | }, 579 | "metadata": { 580 | "tags": [] 581 | }, 582 | "execution_count": 13 583 | } 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "metadata": { 589 | "id": "lxMs-jRVBBaY", 590 | "colab_type": "code", 591 | "colab": {} 592 | }, 593 | "source": [ 594 | "train_preprocessed_encoded = [(vocab_sent[sent], length ,vocab_intent[entity]) for sent, length ,entity in train_preprocessed ]\n", 595 | "valid = [(vocab_sent[sent], length ,vocab_intent[entity]) for sent, length ,entity in valid_preprocessed ]" 596 | ], 597 | "execution_count": 0, 598 | "outputs": [] 599 | }, 600 | { 601 | 
"cell_type": "code", 602 | "metadata": { 603 | "id": "huQE5dXMBBaf", 604 | "colab_type": "code", 605 | "colab": {} 606 | }, 607 | "source": [ 608 | "train, test = nlp.data.train_valid_split(train_preprocessed_encoded, valid_ratio=0.1)" 609 | ], 610 | "execution_count": 0, 611 | "outputs": [] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "metadata": { 616 | "id": "GrvFaVwCBBal", 617 | "colab_type": "code", 618 | "colab": {} 619 | }, 620 | "source": [ 621 | "nbatch = 30\n", 622 | "batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Stack(),\n", 623 | " nlp.data.batchify.Stack('float32'),\n", 624 | " nlp.data.batchify.Stack())\n", 625 | "\n", 626 | "train_dataloader = gluon.data.DataLoader(train, batch_size=nbatch, batchify_fn=batchify_fn, shuffle=True)\n", 627 | "test_dataloader = gluon.data.DataLoader(test, batch_size=nbatch, batchify_fn=batchify_fn, shuffle=True)\n", 628 | "valid_dataloader = gluon.data.DataLoader(valid, batch_size=nbatch, batchify_fn=batchify_fn, shuffle=True)" 629 | ], 630 | "execution_count": 0, 631 | "outputs": [] 632 | }, 633 | { 634 | "cell_type": "markdown", 635 | "metadata": { 636 | "id": "DGsOoAObBBar", 637 | "colab_type": "text" 638 | }, 639 | "source": [ 640 | "#### Modeling" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "metadata": { 646 | "id": "kfyEnrV7BBau", 647 | "colab_type": "code", 648 | "colab": {} 649 | }, 650 | "source": [ 651 | "class IntentClassification(gluon.HybridBlock):\n", 652 | " def __init__(self, vocab_size, vocab_out_size, num_embed, seq_len, hidden_size, **kwargs):\n", 653 | " super(IntentClassification, self).__init__(**kwargs)\n", 654 | " self.seq_len = seq_len\n", 655 | " self.hidden_size = hidden_size \n", 656 | " self.vocab_out_size = vocab_out_size\n", 657 | " with self.name_scope():\n", 658 | " self.embed = nn.Embedding(input_dim=vocab_size, output_dim=num_embed)\n", 659 | " self.bigru = rnn.GRU(self.hidden_size, dropout=0.2, bidirectional=True)\n", 660 | " self.dense_prev = nn.Dense(10, 
flatten=False)\n", 661 | " self.dense = nn.Dense(self.vocab_out_size) \n", 662 | " \n", 663 | " def hybrid_forward(self, F ,inputs, length):\n", 664 | " em_out = self.embed(inputs)\n", 665 | " bigruout = self.bigru(em_out)\n", 666 | " masked_encoded = F.SequenceMask(bigruout,\n", 667 | " sequence_length=length,\n", 668 | " use_sequence_length=True).transpose((1,0,2))\n", 669 | " dense_out = self.dense_prev(masked_encoded)\n", 670 | " outs = self.dense(dense_out) \n", 671 | " return(outs)" 672 | ], 673 | "execution_count": 0, 674 | "outputs": [] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "metadata": { 679 | "id": "3vndtQsyBBaz", 680 | "colab_type": "code", 681 | "colab": {} 682 | }, 683 | "source": [ 684 | "ctx = mx.gpu()\n", 685 | "\n", 686 | "model = IntentClassification(vocab_size = len(vocab_sent.idx_to_token), \n", 687 | " vocab_out_size=len(vocab_intent.idx_to_token), num_embed=50, seq_len=seq_len, hidden_size=30)" 688 | ], 689 | "execution_count": 0, 690 | "outputs": [] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "metadata": { 695 | "id": "_-KdKk5ABBa5", 696 | "colab_type": "code", 697 | "colab": {} 698 | }, 699 | "source": [ 700 | "model.initialize(mx.initializer.Xavier(), ctx=ctx)" 701 | ], 702 | "execution_count": 0, 703 | "outputs": [] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "metadata": { 708 | "id": "MizE5hD0BBbB", 709 | "colab_type": "code", 710 | "colab": {} 711 | }, 712 | "source": [ 713 | "trainer = gluon.Trainer(model.collect_params(),\"Adam\")\n", 714 | "loss = gluon.loss.SoftmaxCELoss() " 715 | ], 716 | "execution_count": 0, 717 | "outputs": [] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "metadata": { 722 | "id": "0_7dByznBBbH", 723 | "colab_type": "code", 724 | "colab": {} 725 | }, 726 | "source": [ 727 | "model.hybridize()" 728 | ], 729 | "execution_count": 0, 730 | "outputs": [] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "metadata": { 735 | "id": "UuPyjvu7BBbN", 736 | "colab_type": "code", 737 | "colab": {} 
738 | }, 739 | "source": [ 740 | "def evaluate_accuracy(model, data_iter, ctx=ctx):\n", 741 | " acc = mx.metric.Accuracy()\n", 742 | " for i, (data, length, label) in enumerate(data_iter):\n", 743 | " data = data.as_in_context(ctx)\n", 744 | " label = label.as_in_context(ctx)\n", 745 | " length = length.as_in_context(ctx)\n", 746 | " output = model(data.T, length)\n", 747 | " predictions = nd.argmax(output, axis=1)\n", 748 | " acc.update(preds=predictions, labels=label)\n", 749 | " return(acc.get()[1])" 750 | ], 751 | "execution_count": 0, 752 | "outputs": [] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "metadata": { 757 | "id": "eohIThCRBBbU", 758 | "colab_type": "code", 759 | "colab": {} 760 | }, 761 | "source": [ 762 | "def calculate_loss(model, data_iter, loss_obj, ctx=ctx):\n", 763 | " test_loss = []\n", 764 | " for i, (te_data, te_length, te_label) in enumerate(data_iter):\n", 765 | " te_data = te_data.as_in_context(ctx)\n", 766 | " te_label = te_label.as_in_context(ctx)\n", 767 | " te_length = te_length.as_in_context(ctx)\n", 768 | " te_output = model(te_data.T, te_length)\n", 769 | " loss_te = loss_obj(te_output, te_label)\n", 770 | " curr_loss = nd.mean(loss_te).asscalar()\n", 771 | " test_loss.append(curr_loss)\n", 772 | " return(np.mean(test_loss))" 773 | ], 774 | "execution_count": 0, 775 | "outputs": [] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "metadata": { 780 | "id": "EdBPiGpYBBbc", 781 | "colab_type": "code", 782 | "colab": { 783 | "base_uri": "https://localhost:8080/", 784 | "height": 1000 785 | }, 786 | "outputId": "f42081fb-a7a7-428f-85d0-acd244b677cb" 787 | }, 788 | "source": [ 789 | "epochs = 100\n", 790 | "\n", 791 | "\n", 792 | "tot_test_loss = []\n", 793 | "tot_test_accu = []\n", 794 | "tot_train_loss = []\n", 795 | "tot_train_accu = []\n", 796 | "tot_valid_accu = [] \n", 797 | "for e in range(epochs):\n", 798 | " #batch training \n", 799 | " for i, (data, length, label) in enumerate(tqdm(train_dataloader)):\n", 800 | " data = 
data.as_in_context(ctx)\n", 801 | " label = label.as_in_context(ctx)\n", 802 | " length = length.as_in_context(ctx)\n", 803 | " with autograd.record():\n", 804 | " output = model(data.T, length)\n", 805 | " loss_ = loss(output, label)\n", 806 | " loss_.backward()\n", 807 | " trainer.step(data.shape[0])\n", 808 | "\n", 809 | " #calculate test loss\n", 810 | " if e % 10 == 0: \n", 811 | " test_loss = calculate_loss(model, test_dataloader, loss_obj = loss, ctx=ctx) \n", 812 | " train_loss = calculate_loss(model, train_dataloader, loss_obj = loss, ctx=ctx) \n", 813 | " test_accu = evaluate_accuracy(model, test_dataloader, ctx=ctx)\n", 814 | " train_accu = evaluate_accuracy(model, train_dataloader, ctx=ctx)\n", 815 | " valid_accu = evaluate_accuracy(model, valid_dataloader, ctx=ctx)\n", 816 | "\n", 817 | " print(\"Epoch %s. Train Loss: %s, Test Loss : %s,\" \\\n", 818 | " \" Test Accuracy : %s,\" \\\n", 819 | " \" Train Accuracy : %s : Valid Accuracy : %s\" % (e, train_loss, test_loss, test_accu, train_accu, valid_accu)) \n", 820 | " tot_test_loss.append(test_loss)\n", 821 | " tot_train_loss.append(train_loss)\n", 822 | " tot_test_accu.append(test_accu)\n", 823 | " tot_train_accu.append(train_accu)\n", 824 | " tot_valid_accu.append(valid_accu)" 825 | ], 826 | "execution_count": 24, 827 | "outputs": [ 828 | { 829 | "output_type": "stream", 830 | "text": [ 831 | "100%|██████████| 270/270 [00:01<00:00, 222.23it/s]\n", 832 | " 9%|▉ | 24/270 [00:00<00:01, 227.36it/s]" 833 | ], 834 | "name": "stderr" 835 | }, 836 | { 837 | "output_type": "stream", 838 | "text": [ 839 | "Epoch 0. 
Train Loss: 0.18470986, Test Loss : 0.18911381, Test Accuracy : 0.9455555555555556, Train Accuracy : 0.9514814814814815 : Valid Accuracy : 0.939\n" 840 | ], 841 | "name": "stdout" 842 | }, 843 | { 844 | "output_type": "stream", 845 | "text": [ 846 | "100%|██████████| 270/270 [00:01<00:00, 234.59it/s]\n", 847 | "100%|██████████| 270/270 [00:01<00:00, 245.90it/s]\n", 848 | "100%|██████████| 270/270 [00:01<00:00, 209.04it/s]\n", 849 | "100%|██████████| 270/270 [00:01<00:00, 236.16it/s]\n", 850 | "100%|██████████| 270/270 [00:01<00:00, 235.68it/s]\n", 851 | "100%|██████████| 270/270 [00:01<00:00, 230.73it/s]\n", 852 | "100%|██████████| 270/270 [00:01<00:00, 226.29it/s]\n", 853 | "100%|██████████| 270/270 [00:01<00:00, 243.93it/s]\n", 854 | "100%|██████████| 270/270 [00:01<00:00, 206.43it/s]\n", 855 | "100%|██████████| 270/270 [00:01<00:00, 208.71it/s]\n", 856 | " 9%|▊ | 23/270 [00:00<00:01, 228.15it/s]" 857 | ], 858 | "name": "stderr" 859 | }, 860 | { 861 | "output_type": "stream", 862 | "text": [ 863 | "Epoch 10. 
Train Loss: 0.0107786665, Test Loss : 0.03377559, Test Accuracy : 0.99, Train Accuracy : 0.9971604938271605 : Valid Accuracy : 0.989\n" 864 | ], 865 | "name": "stdout" 866 | }, 867 | { 868 | "output_type": "stream", 869 | "text": [ 870 | "100%|██████████| 270/270 [00:01<00:00, 233.27it/s]\n", 871 | "100%|██████████| 270/270 [00:01<00:00, 232.03it/s]\n", 872 | "100%|██████████| 270/270 [00:01<00:00, 241.32it/s]\n", 873 | "100%|██████████| 270/270 [00:01<00:00, 227.35it/s]\n", 874 | "100%|██████████| 270/270 [00:01<00:00, 226.52it/s]\n", 875 | "100%|██████████| 270/270 [00:01<00:00, 230.68it/s]\n", 876 | "100%|██████████| 270/270 [00:01<00:00, 227.91it/s]\n", 877 | "100%|██████████| 270/270 [00:01<00:00, 242.82it/s]\n", 878 | "100%|██████████| 270/270 [00:01<00:00, 227.59it/s]\n", 879 | "100%|██████████| 270/270 [00:01<00:00, 226.63it/s]\n", 880 | " 9%|▉ | 25/270 [00:00<00:01, 241.38it/s]" 881 | ], 882 | "name": "stderr" 883 | }, 884 | { 885 | "output_type": "stream", 886 | "text": [ 887 | "Epoch 20. 
Train Loss: 6.623845e-05, Test Loss : 0.022001153, Test Accuracy : 0.9933333333333333, Train Accuracy : 1.0 : Valid Accuracy : 0.995\n" 888 | ], 889 | "name": "stdout" 890 | }, 891 | { 892 | "output_type": "stream", 893 | "text": [ 894 | "100%|██████████| 270/270 [00:01<00:00, 213.77it/s]\n", 895 | "100%|██████████| 270/270 [00:01<00:00, 237.85it/s]\n", 896 | "100%|██████████| 270/270 [00:01<00:00, 230.91it/s]\n", 897 | "100%|██████████| 270/270 [00:01<00:00, 235.98it/s]\n", 898 | "100%|██████████| 270/270 [00:01<00:00, 235.30it/s]\n", 899 | "100%|██████████| 270/270 [00:01<00:00, 243.00it/s]\n", 900 | "100%|██████████| 270/270 [00:01<00:00, 237.30it/s]\n", 901 | "100%|██████████| 270/270 [00:01<00:00, 226.72it/s]\n", 902 | "100%|██████████| 270/270 [00:01<00:00, 240.38it/s]\n", 903 | "100%|██████████| 270/270 [00:01<00:00, 231.97it/s]\n", 904 | " 9%|▉ | 25/270 [00:00<00:01, 242.76it/s]" 905 | ], 906 | "name": "stderr" 907 | }, 908 | { 909 | "output_type": "stream", 910 | "text": [ 911 | "Epoch 30. 
Train Loss: 1.3015229e-05, Test Loss : 0.029010795, Test Accuracy : 0.9933333333333333, Train Accuracy : 1.0 : Valid Accuracy : 0.994\n" 912 | ], 913 | "name": "stdout" 914 | }, 915 | { 916 | "output_type": "stream", 917 | "text": [ 918 | "100%|██████████| 270/270 [00:01<00:00, 226.23it/s]\n", 919 | "100%|██████████| 270/270 [00:01<00:00, 228.21it/s]\n", 920 | "100%|██████████| 270/270 [00:01<00:00, 225.57it/s]\n", 921 | "100%|██████████| 270/270 [00:01<00:00, 230.60it/s]\n", 922 | "100%|██████████| 270/270 [00:01<00:00, 245.43it/s]\n", 923 | "100%|██████████| 270/270 [00:01<00:00, 241.21it/s]\n", 924 | "100%|██████████| 270/270 [00:01<00:00, 222.03it/s]\n", 925 | "100%|██████████| 270/270 [00:01<00:00, 231.73it/s]\n", 926 | "100%|██████████| 270/270 [00:01<00:00, 245.55it/s]\n", 927 | "100%|██████████| 270/270 [00:01<00:00, 238.74it/s]\n", 928 | " 10%|▉ | 26/270 [00:00<00:00, 254.18it/s]" 929 | ], 930 | "name": "stderr" 931 | }, 932 | { 933 | "output_type": "stream", 934 | "text": [ 935 | "Epoch 40. 
Train Loss: 2.8318507e-06, Test Loss : 0.036456905, Test Accuracy : 0.9911111111111112, Train Accuracy : 1.0 : Valid Accuracy : 0.992\n" 936 | ], 937 | "name": "stdout" 938 | }, 939 | { 940 | "output_type": "stream", 941 | "text": [ 942 | "100%|██████████| 270/270 [00:01<00:00, 222.40it/s]\n", 943 | "100%|██████████| 270/270 [00:01<00:00, 242.70it/s]\n", 944 | "100%|██████████| 270/270 [00:01<00:00, 234.89it/s]\n", 945 | "100%|██████████| 270/270 [00:01<00:00, 234.75it/s]\n", 946 | "100%|██████████| 270/270 [00:01<00:00, 240.13it/s]\n", 947 | "100%|██████████| 270/270 [00:01<00:00, 247.74it/s]\n", 948 | "100%|██████████| 270/270 [00:01<00:00, 241.33it/s]\n", 949 | "100%|██████████| 270/270 [00:01<00:00, 232.07it/s]\n", 950 | "100%|██████████| 270/270 [00:01<00:00, 225.22it/s]\n", 951 | "100%|██████████| 270/270 [00:01<00:00, 232.54it/s]\n", 952 | " 8%|▊ | 22/270 [00:00<00:01, 218.93it/s]" 953 | ], 954 | "name": "stderr" 955 | }, 956 | { 957 | "output_type": "stream", 958 | "text": [ 959 | "Epoch 50. 
Train Loss: 6.434988e-07, Test Loss : 0.041608997, Test Accuracy : 0.9911111111111112, Train Accuracy : 1.0 : Valid Accuracy : 0.992\n" 960 | ], 961 | "name": "stdout" 962 | }, 963 | { 964 | "output_type": "stream", 965 | "text": [ 966 | "100%|██████████| 270/270 [00:01<00:00, 223.49it/s]\n", 967 | "100%|██████████| 270/270 [00:01<00:00, 230.21it/s]\n", 968 | "100%|██████████| 270/270 [00:01<00:00, 245.84it/s]\n", 969 | "100%|██████████| 270/270 [00:01<00:00, 224.68it/s]\n", 970 | "100%|██████████| 270/270 [00:01<00:00, 219.22it/s]\n", 971 | "100%|██████████| 270/270 [00:01<00:00, 204.98it/s]\n", 972 | "100%|██████████| 270/270 [00:01<00:00, 248.60it/s]\n", 973 | "100%|██████████| 270/270 [00:01<00:00, 229.96it/s]\n", 974 | "100%|██████████| 270/270 [00:01<00:00, 223.54it/s]\n", 975 | "100%|██████████| 270/270 [00:01<00:00, 232.01it/s]\n", 976 | " 9%|▉ | 24/270 [00:00<00:01, 237.49it/s]" 977 | ], 978 | "name": "stderr" 979 | }, 980 | { 981 | "output_type": "stream", 982 | "text": [ 983 | "Epoch 60. 
Train Loss: 1.4549369e-07, Test Loss : 0.051583666, Test Accuracy : 0.9911111111111112, Train Accuracy : 1.0 : Valid Accuracy : 0.992\n" 984 | ], 985 | "name": "stdout" 986 | }, 987 | { 988 | "output_type": "stream", 989 | "text": [ 990 | "100%|██████████| 270/270 [00:01<00:00, 237.42it/s]\n", 991 | "100%|██████████| 270/270 [00:01<00:00, 219.72it/s]\n", 992 | "100%|██████████| 270/270 [00:01<00:00, 221.40it/s]\n", 993 | "100%|██████████| 270/270 [00:01<00:00, 229.97it/s]\n", 994 | "100%|██████████| 270/270 [00:01<00:00, 200.17it/s]\n", 995 | "100%|██████████| 270/270 [00:01<00:00, 206.36it/s]\n", 996 | "100%|██████████| 270/270 [00:01<00:00, 224.62it/s]\n", 997 | "100%|██████████| 270/270 [00:01<00:00, 224.25it/s]\n", 998 | "100%|██████████| 270/270 [00:01<00:00, 238.12it/s]\n", 999 | "100%|██████████| 270/270 [00:01<00:00, 236.44it/s]\n", 1000 | " 10%|█ | 27/270 [00:00<00:00, 256.06it/s]" 1001 | ], 1002 | "name": "stderr" 1003 | }, 1004 | { 1005 | "output_type": "stream", 1006 | "text": [ 1007 | "Epoch 70. 
Train Loss: 9.1341244e-05, Test Loss : 0.021057624, Test Accuracy : 0.9944444444444445, Train Accuracy : 1.0 : Valid Accuracy : 0.995\n" 1008 | ], 1009 | "name": "stdout" 1010 | }, 1011 | { 1012 | "output_type": "stream", 1013 | "text": [ 1014 | "100%|██████████| 270/270 [00:01<00:00, 231.05it/s]\n", 1015 | "100%|██████████| 270/270 [00:01<00:00, 232.42it/s]\n", 1016 | "100%|██████████| 270/270 [00:01<00:00, 220.26it/s]\n", 1017 | "100%|██████████| 270/270 [00:01<00:00, 237.14it/s]\n", 1018 | "100%|██████████| 270/270 [00:01<00:00, 221.74it/s]\n", 1019 | "100%|██████████| 270/270 [00:01<00:00, 227.62it/s]\n", 1020 | "100%|██████████| 270/270 [00:01<00:00, 237.04it/s]\n", 1021 | "100%|██████████| 270/270 [00:01<00:00, 231.40it/s]\n", 1022 | "100%|██████████| 270/270 [00:01<00:00, 224.03it/s]\n", 1023 | "100%|██████████| 270/270 [00:01<00:00, 228.41it/s]\n", 1024 | " 8%|▊ | 22/270 [00:00<00:01, 213.39it/s]" 1025 | ], 1026 | "name": "stderr" 1027 | }, 1028 | { 1029 | "output_type": "stream", 1030 | "text": [ 1031 | "Epoch 80. 
Train Loss: 1.5107484e-05, Test Loss : 0.020093959, Test Accuracy : 0.9955555555555555, Train Accuracy : 1.0 : Valid Accuracy : 0.995\n" 1032 | ], 1033 | "name": "stdout" 1034 | }, 1035 | { 1036 | "output_type": "stream", 1037 | "text": [ 1038 | "100%|██████████| 270/270 [00:01<00:00, 246.41it/s]\n", 1039 | "100%|██████████| 270/270 [00:01<00:00, 239.37it/s]\n", 1040 | "100%|██████████| 270/270 [00:01<00:00, 235.47it/s]\n", 1041 | "100%|██████████| 270/270 [00:01<00:00, 236.65it/s]\n", 1042 | "100%|██████████| 270/270 [00:01<00:00, 243.75it/s]\n", 1043 | "100%|██████████| 270/270 [00:01<00:00, 227.64it/s]\n", 1044 | "100%|██████████| 270/270 [00:01<00:00, 245.69it/s]\n", 1045 | "100%|██████████| 270/270 [00:01<00:00, 240.43it/s]\n", 1046 | "100%|██████████| 270/270 [00:01<00:00, 235.14it/s]\n", 1047 | "100%|██████████| 270/270 [00:01<00:00, 251.99it/s]\n", 1048 | " 10%|█ | 28/270 [00:00<00:00, 270.58it/s]" 1049 | ], 1050 | "name": "stderr" 1051 | }, 1052 | { 1053 | "output_type": "stream", 1054 | "text": [ 1055 | "Epoch 90. 
Train Loss: 4.6649916e-06, Test Loss : 0.019216476, Test Accuracy : 0.9944444444444445, Train Accuracy : 1.0 : Valid Accuracy : 0.996\n" 1056 | ], 1057 | "name": "stdout" 1058 | }, 1059 | { 1060 | "output_type": "stream", 1061 | "text": [ 1062 | "100%|██████████| 270/270 [00:01<00:00, 231.63it/s]\n", 1063 | "100%|██████████| 270/270 [00:01<00:00, 235.63it/s]\n", 1064 | "100%|██████████| 270/270 [00:01<00:00, 246.76it/s]\n", 1065 | "100%|██████████| 270/270 [00:01<00:00, 234.70it/s]\n", 1066 | "100%|██████████| 270/270 [00:01<00:00, 228.10it/s]\n", 1067 | "100%|██████████| 270/270 [00:01<00:00, 233.21it/s]\n", 1068 | "100%|██████████| 270/270 [00:01<00:00, 241.09it/s]\n", 1069 | "100%|██████████| 270/270 [00:01<00:00, 238.56it/s]\n", 1070 | "100%|██████████| 270/270 [00:01<00:00, 239.69it/s]\n" 1071 | ], 1072 | "name": "stderr" 1073 | } 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "code", 1078 | "metadata": { 1079 | "id": "9nx-k3BRHvPo", 1080 | "colab_type": "code", 1081 | "colab": {} 1082 | }, 1083 | "source": [ 1084 | "model.collect_params().reset_ctx(mx.cpu())" 1085 | ], 1086 | "execution_count": 0, 1087 | "outputs": [] 1088 | }, 1089 | { 1090 | "cell_type": "code", 1091 | "metadata": { 1092 | "id": "iIyNj4FFBBb_", 1093 | "colab_type": "code", 1094 | "colab": {} 1095 | }, 1096 | "source": [ 1097 | "def get_intent(sent):\n", 1098 | " sent_len = len(sent)\n", 1099 | " coded_sent = vocab_sent[length_clip(list(sent))]\n", 1100 | " co = nd.array(coded_sent).expand_dims(axis=1)\n", 1101 | " ret_code = model(co, nd.array([sent_len,]))\n", 1102 | " ret_seq = vocab_intent.to_tokens(ret_code.argmax(axis=1).asnumpy().astype('int').tolist())\n", 1103 | " return(''.join(ret_seq))" 1104 | ], 1105 | "execution_count": 0, 1106 | "outputs": [] 1107 | }, 1108 | { 1109 | "cell_type": "code", 1110 | "metadata": { 1111 | "id": "chPfI_HNHHva", 1112 | "colab_type": "code", 1113 | "colab": { 1114 | "base_uri": "https://localhost:8080/", 1115 | "height": 36 1116 | }, 1117 | 
"outputId": "3b1e8269-d47f-4ff7-92b5-340e2f96ba2e" 1118 | }, 1119 | "source": [ 1120 | "get_intent(\"파이콘이 뭔지 알려줘?\")" 1121 | ], 1122 | "execution_count": 31, 1123 | "outputs": [ 1124 | { 1125 | "output_type": "execute_result", 1126 | "data": { 1127 | "text/plain": [ 1128 | "'definition'" 1129 | ] 1130 | }, 1131 | "metadata": { 1132 | "tags": [] 1133 | }, 1134 | "execution_count": 31 1135 | } 1136 | ] 1137 | }, 1138 | { 1139 | "cell_type": "markdown", 1140 | "metadata": { 1141 | "id": "jwhz3EetBBcE", 1142 | "colab_type": "text" 1143 | }, 1144 | "source": [ 1145 | "### TODO\n", 1146 | "- Combine the separate Intent and Entity models into a single model. (Multi-Task Learning) \n", 1147 | " - Does classification performance improve? How does the training convergence speed change?" 1148 | ] 1149 | } 1150 | ] 1151 | } -------------------------------------------------------------------------------- /code/3_2_entity_tagging_pycon2019.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "language_info": { 10 | "codemirror_mode": { 11 | "name": "ipython", 12 | "version": 3 13 | }, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "nbconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": "3.5.2" 20 | }, 21 | "colab": { 22 | "name": "entity_tagging.ipynb", 23 | "version": "0.3.2", 24 | "provenance": [], 25 | "collapsed_sections": [] 26 | }, 27 | "accelerator": "GPU" 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "code", 32 | "metadata": { 33 | "id": "xex7CIsaIZms", 34 | "colab_type": "code", 35 | "colab": { 36 | "base_uri": "https://localhost:8080/", 37 | "height": 302 38 | }, 39 | "outputId": "a2728546-990f-407b-8b17-8956e47a0335" 40 | }, 41 | "source": [ 42 | "!pip install mxnet-cu100\n", 43 | "!pip install gluonnlp pandas tqdm" 44 | ], 45 | "execution_count": 2, 46 | "outputs": [ 47 | { 48 | "output_type": 
"stream", 49 | "text": [ 50 | "Requirement already satisfied: mxnet-cu100 in /usr/local/lib/python3.6/dist-packages (1.5.0)\n", 51 | "Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.16.4)\n", 52 | "Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n", 53 | "Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (0.8.4)\n", 54 | "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n", 55 | "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n", 56 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.6.16)\n", 57 | "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n", 58 | "Requirement already satisfied: gluonnlp in /usr/local/lib/python3.6/dist-packages (0.8.0)\n", 59 | "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (0.24.2)\n", 60 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)\n", 61 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluonnlp) (1.16.4)\n", 62 | "Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas) (2018.9)\n", 63 | "Requirement already satisfied: python-dateutil>=2.5.0 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.5.3)\n", 64 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.5.0->pandas) (1.12.0)\n" 65 | ], 66 | "name": "stdout" 67 | } 68 | ] 69 | }, 70 | { 71 | "cell_type": 
"markdown", 72 | "metadata": { 73 | "id": "_NZPM9avIXjH", 74 | "colab_type": "text" 75 | }, 76 | "source": [ 77 | "## Entity Tagging" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "id": "fS9E3KCwIXjM", 84 | "colab_type": "code", 85 | "colab": {} 86 | }, 87 | "source": [ 88 | "import pandas as pd\n", 89 | "import numpy as np\n", 90 | "from mxnet.gluon import nn, rnn\n", 91 | "from mxnet import gluon, autograd\n", 92 | "import gluonnlp as nlp\n", 93 | "from mxnet import nd \n", 94 | "import mxnet as mx\n", 95 | "import time\n", 96 | "import itertools\n", 97 | "from tqdm import tqdm\n", 98 | "import multiprocessing as mp" 99 | ], 100 | "execution_count": 0, 101 | "outputs": [] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "metadata": { 106 | "id": "7rei93JgIXja", 107 | "colab_type": "code", 108 | "colab": {} 109 | }, 110 | "source": [ 111 | "train_raw = pd.read_csv(\"https://www.dropbox.com/s/83n0uoy20rd2vq4/trainset.txt?dl=1\",names=['intent', 'entity', 'sentence'], sep='\\t')\n", 112 | "#validation_raw = pd.read_csv(\"https://www.dropbox.com/s/kbl7kw54jdo2550/test_hidden.txt?dl=1\",names=['intent', 'entity', 'sentence'], sep='\\t')\n", 113 | "validation_raw = pd.read_csv(\"https://www.dropbox.com/s/enxp9yt9cstcal2/validation.txt?dl=1\",names=['intent', 'entity', 'sentence'], sep='\\t')" 114 | ], 115 | "execution_count": 0, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "metadata": { 121 | "id": "up8TL-t9IXjo", 122 | "colab_type": "code", 123 | "colab": { 124 | "base_uri": "https://localhost:8080/", 125 | "height": 990 126 | }, 127 | "outputId": "0c6416b8-25ec-4652-f4e3-7a288567609d" 128 | }, 129 | "source": [ 130 | "train_raw.head(30)" 131 | ], 132 | "execution_count": 5, 133 | "outputs": [ 134 | { 135 | "output_type": "execute_result", 136 | "data": { 137 | "text/html": [ 138 | "
\n", 344 | "
" 345 | ], 346 | "text/plain": [ 347 | " intent entity sentence\n", 348 | "0 area EECCCCCCCCCCCCCCCCCCC 자강의 면적은 얼마 정도되는지 알려줄래\n", 349 | "1 birth_date CCCCCCCCCCCCEEECCCCCCCCCCCC WIKI PEDIA로 변재일 생년월일을 알고 싶어\n", 350 | "2 age EEEEEEEEEEECCCCCCCCCCCCCCCCC 남쪽 물고기자리 알파 나이가 위키백과사전으로 얼마야\n", 351 | "3 length EEEECCCCCCCCCCCCCCCCCC 삼양터널의 총 길이 위키백과사전에서 뭐야\n", 352 | "4 birth_place EEEEEECCCCCCCCCCC 코니 윌리스의 태어난 곳은 뭐지\n", 353 | "5 weight CCCCCCCCCCCCEEEECCCCCCCCCCCCC WIKI백과사전 검색 AA12의 무게가 얼만지 찾아봐\n", 354 | "6 definition CCCCCCCCCCCCCEEECCCCCCCC WIKIPEDIA백과로 라이프 찾아서 말해줘\n", 355 | "7 height EEEEEEEECCCCCCCCCCCCCCCCCCC 송파 헬리오시티 구조물 높이 위키 피디아에서 뭐야\n", 356 | "8 birth_date CCCEEEEEECCCCCCCCCCCCCCC 검색 HLKVAM 언제 출생했는지를 검색해라\n", 357 | "9 height CCCCCCCCEEEEEECCCCCCCC 위키 피디아에 푸조 508 전고가 몇이야\n", 358 | "10 length CCCEEEEECCCCCCC 검색 호몬혼 섬 길이를 찾아\n", 359 | "11 definition EEEEECCCCCCCCCCCCC 영산중학교 좀 위키피디아사전 검색\n", 360 | "12 age CCCCCCEEEEEECCCCCCC 위키백과로 침보라조 산 나이 어떤지\n", 361 | "13 birth_date EEEEEEECCCCCCCC 마무드 아스라의 출생 찾아줘\n", 362 | "14 birth_place CCCCCCEEEEEEECCCCCCCCC 위키 백과 조제 카리오카의 출생지를 찾아\n", 363 | "15 birth_date CCCEEEEEECCCCCCCCC 검색 제이 개츠비 생년월일은 뭐지\n", 364 | "16 length EEEECCCCCCCCCCCCCCCCC 증약터널의 길이가 얼마쯤인지 혹시 알아\n", 365 | "17 belong_to EEEEEEEEEEEEEEEECCCCCCCCCCCCCC 리히텐슈타인의 한스 아담 2세 소속사는 어딘지 검색해봐\n", 366 | "18 height CCCCCCCCCCCCEEEEECCCCCCCCC WIKI사전백과 검색 벨록스여우의 높이는 얼만지\n", 367 | "19 age EEEEEECCCCCCCCC 파블롭스키구의 나이를 찾아줘\n", 368 | "20 width EEEEEEECCCCCCCCCCCCCC 사카피솔라 섬의 너비는 WIKI에서 뭐\n", 369 | "21 birth_place EECCCCCCCCCCCCCCCCC 나미는 태어난 곳이 WIKI로 뭔지\n", 370 | "22 weight CCCCCEEEEECCCCCCC 위키에서 피니스테르의 무게 찾기\n", 371 | "23 birth_place CCCEEEEEEECCCCCCCCCCC 검색 카를 야스퍼스 출신지역이 어디라고\n", 372 | "24 width EEEEEEEEEEECCCCCCC 63식 병력수송장갑차의 폭 얼만지\n", 373 | "25 birth_place CCCCCEEECCCCCCCCCCCC 검색으로 강마에가 출생 장소를 찾아줘\n", 374 | "26 birth_date EEEEEECCCCCCCCCCCCCC 쿠죠 히카리의 언제 출생했는지 탐색해\n", 375 | "27 length EEECCCCCCCCCCC 사하라의 길이가 얼마쯤이지\n", 376 | "28 area EEEECCCCCCCCC 송대산성의 면적은 얼만지\n", 377 | "29 area 
CCCCCCCCCCEEEEEECCCCCCC WIKI 피디아에 신자경선생묘의 넓이 뭔지" 378 | ] 379 | }, 380 | "metadata": { 381 | "tags": [] 382 | }, 383 | "execution_count": 5 384 | } 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": { 390 | "id": "IRSAAujAIXj0", 391 | "colab_type": "text" 392 | }, 393 | "source": [ 394 | "#### Data Preprocessing" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "metadata": { 400 | "id": "mLJgprxPIXj3", 401 | "colab_type": "code", 402 | "colab": {} 403 | }, 404 | "source": [ 405 | "train_dataset = [(l, d) for d,l in zip(train_raw['entity'], train_raw['sentence'])]\n", 406 | "valid_dataset = [(l, d) for d,l in zip(validation_raw['entity'], validation_raw['sentence'])]" 407 | ], 408 | "execution_count": 0, 409 | "outputs": [] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "metadata": { 414 | "id": "OOByF7V-IXj_", 415 | "colab_type": "code", 416 | "colab": {} 417 | }, 418 | "source": [ 419 | "seq_len = 32\n", 420 | "\n", 421 | "length_clip = nlp.data.PadSequence(seq_len, pad_val=\"<pad>\")\n", 422 | "\n", 423 | "def preprocess(data):\n", 424 | " sent, entity = data\n", 425 | " char_sent = list(str(sent))\n", 426 | " char_entity = list(str(entity))\n", 427 | " return(length_clip(char_sent), len(sent),length_clip(char_entity))\n", 428 | "\n", 429 | "def preprocess_dataset(dataset):\n", 430 | " start = time.time()\n", 431 | " with mp.Pool() as pool:\n", 432 | " dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset))\n", 433 | " end = time.time()\n", 434 | " print('Done!
Tokenizing Time={:.2f}s, #Sentences={}'\n", 435 | " .format(end - start, len(dataset)))\n", 436 | " return dataset\n" 437 | ], 438 | "execution_count": 0, 439 | "outputs": [] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "metadata": { 444 | "id": "0XZMULFiIXkJ", 445 | "colab_type": "code", 446 | "colab": { 447 | "base_uri": "https://localhost:8080/", 448 | "height": 55 449 | }, 450 | "outputId": "355ba816-6bfc-4c70-8503-08fab70f2f66" 451 | }, 452 | "source": [ 453 | "train_preprocessed = preprocess_dataset(train_dataset)\n", 454 | "valid_preprocessed = preprocess_dataset(valid_dataset)" 455 | ], 456 | "execution_count": 8, 457 | "outputs": [ 458 | { 459 | "output_type": "stream", 460 | "text": [ 461 | "Done! Tokenizing Time=0.34s, #Sentences=9000\n", 462 | "Done! Tokenizing Time=0.13s, #Sentences=1000\n" 463 | ], 464 | "name": "stdout" 465 | } 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "metadata": { 471 | "id": "5Y5xFJOKIXkS", 472 | "colab_type": "code", 473 | "colab": {} 474 | }, 475 | "source": [ 476 | "counter_sent = nlp.data.count_tokens(itertools.chain.from_iterable([c for c, _, _ in train_preprocessed]))\n", 477 | "counter_entity = nlp.data.count_tokens(itertools.chain.from_iterable([c for _,_, c in train_preprocessed]))" 478 | ], 479 | "execution_count": 0, 480 | "outputs": [] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "metadata": { 485 | "id": "d_54U3R_IXka", 486 | "colab_type": "code", 487 | "colab": {} 488 | }, 489 | "source": [ 490 | "vocab_sent = nlp.Vocab(counter_sent, bos_token=None, eos_token=None, min_freq=15)\n", 491 | "vocab_entity = nlp.Vocab(counter_entity, bos_token=None, eos_token=None, unknown_token=None ,min_freq=15)" 492 | ], 493 | "execution_count": 0, 494 | "outputs": [] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "metadata": { 499 | "id": "zJ8pwPwXIXkq", 500 | "colab_type": "code", 501 | "colab": { 502 | "base_uri": "https://localhost:8080/", 503 | "height": 55 504 | }, 505 | "outputId": 
"a122c09c-9300-4bee-fb76-ab7292141336" 506 | }, 507 | "source": [ 508 | "vocab_sent.idx_to_token[:10], vocab_entity.idx_to_token[:10], " 509 | ], 510 | "execution_count": 11, 511 | "outputs": [ 512 | { 513 | "output_type": "execute_result", 514 | "data": { 515 | "text/plain": [ 516 | "(['<unk>', '<pad>', ' ', 'I', '이', '색', '검', '의', '지', '아'],\n", 517 | " ['<pad>', 'C', 'E'])" 518 | ] 519 | }, 520 | "metadata": { 521 | "tags": [] 522 | }, 523 | "execution_count": 11 524 | } 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "metadata": { 530 | "id": "hhkhqOGNIXkz", 531 | "colab_type": "code", 532 | "colab": {} 533 | }, 534 | "source": [ 535 | "train_preprocessed_encoded = [(vocab_sent[sent], length ,vocab_entity[entity]) for sent, length ,entity in train_preprocessed ]\n", 536 | "valid = [(vocab_sent[sent], length ,vocab_entity[entity]) for sent, length ,entity in valid_preprocessed ]" 537 | ], 538 | "execution_count": 0, 539 | "outputs": [] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "metadata": { 544 | "id": "Me3AVOOeIXk9", 545 | "colab_type": "code", 546 | "colab": {} 547 | }, 548 | "source": [ 549 | "train, test = nlp.data.train_valid_split(train_preprocessed_encoded, valid_ratio=0.1)" 550 | ], 551 | "execution_count": 0, 552 | "outputs": [] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "metadata": { 557 | "id": "pHhp6acqIXlD", 558 | "colab_type": "code", 559 | "colab": {} 560 | }, 561 | "source": [ 562 | "nbatch = 30\n", 563 | "batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Stack(),\n", 564 | " nlp.data.batchify.Stack('float32'),\n", 565 | " nlp.data.batchify.Stack())\n", 566 | "\n", 567 | "train_dataloader = gluon.data.DataLoader(train, batch_size=nbatch, batchify_fn=batchify_fn, shuffle=True)\n", 568 | "test_dataloader = gluon.data.DataLoader(test, batch_size=nbatch, batchify_fn=batchify_fn, shuffle=True)\n", 569 | "valid_dataloader = gluon.data.DataLoader(valid, batch_size=nbatch, batchify_fn=batchify_fn, shuffle=True)" 570 | ], 571 | 
"execution_count": 0, 572 | "outputs": [] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": { 577 | "id": "rt5gw0_dIXlP", 578 | "colab_type": "text" 579 | }, 580 | "source": [ 581 | "#### Modeling " 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "metadata": { 587 | "id": "NfxurevYIXlS", 588 | "colab_type": "code", 589 | "colab": {} 590 | }, 591 | "source": [ 592 | "class EntityTagger(gluon.HybridBlock):\n", 593 | " def __init__(self, vocab_size, vocab_out_size, num_embed, hidden_size, use_attention=False, **kwargs):\n", 594 | " super(EntityTagger, self).__init__(**kwargs)\n", 595 | " self.hidden_size = hidden_size \n", 596 | " self.vocab_out_size = vocab_out_size\n", 597 | " self.use_attention = use_attention\n", 598 | " with self.name_scope():\n", 599 | " self.embed = nn.Embedding(input_dim=vocab_size, output_dim=num_embed)\n", 600 | " self.bigru = rnn.GRU(self.hidden_size, dropout=0.2, bidirectional=True)\n", 601 | " self.dense_prev = nn.Dense(10, flatten=False)\n", 602 | " self.dense = nn.Dense(self.vocab_out_size, flatten=False)\n", 603 | " if self.use_attention:\n", 604 | " self.attention = nlp.model.MLPAttentionCell(30, dropout=0.2)\n", 605 | " \n", 606 | " def hybrid_forward(self, F ,inputs, length):\n", 607 | " em_out = self.embed(inputs)\n", 608 | " bigruout = self.bigru(em_out)\n", 609 | " masked_encoded = F.SequenceMask(bigruout,\n", 610 | " sequence_length=length,\n", 611 | " use_sequence_length=True).transpose((1,0,2))\n", 612 | " if self.use_attention:\n", 613 | " masked_encoded,_ = self.attention(masked_encoded, masked_encoded)\n", 614 | " dense_out = self.dense_prev(masked_encoded)\n", 615 | " outs = self.dense(dense_out) \n", 616 | " return(outs)" 617 | ], 618 | "execution_count": 0, 619 | "outputs": [] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "metadata": { 624 | "id": "MEWoD-4nIXla", 625 | "colab_type": "code", 626 | "colab": {} 627 | }, 628 | "source": [ 629 | "ctx = mx.gpu()\n", 630 | "\n", 631 | "model = 
EntityTagger(vocab_size = len(vocab_sent.idx_to_token), vocab_out_size=len(vocab_entity.idx_to_token), \n", 632 | " num_embed=50, hidden_size=30, use_attention=True)" 633 | ], 634 | "execution_count": 0, 635 | "outputs": [] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "metadata": { 640 | "id": "LBs0t8kcIXlf", 641 | "colab_type": "code", 642 | "colab": {} 643 | }, 644 | "source": [ 645 | "model.initialize(mx.initializer.Xavier(), ctx=ctx)" 646 | ], 647 | "execution_count": 0, 648 | "outputs": [] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "metadata": { 653 | "id": "xhMH6ZBMIXlk", 654 | "colab_type": "code", 655 | "colab": {} 656 | }, 657 | "source": [ 658 | "trainer = gluon.Trainer(model.collect_params(),\"Adam\")\n", 659 | "loss = gluon.loss.SoftmaxCELoss() " 660 | ], 661 | "execution_count": 0, 662 | "outputs": [] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "metadata": { 667 | "id": "Fo8OAOnWIXls", 668 | "colab_type": "code", 669 | "colab": {} 670 | }, 671 | "source": [ 672 | "model.hybridize()" 673 | ], 674 | "execution_count": 0, 675 | "outputs": [] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "metadata": { 680 | "id": "f0T8Pbg0IXlx", 681 | "colab_type": "code", 682 | "colab": { 683 | "base_uri": "https://localhost:8080/", 684 | "height": 264 685 | }, 686 | "outputId": "e51bff23-80aa-448f-eb5d-ebe4c8213f52" 687 | }, 688 | "source": [ 689 | "model" 690 | ], 691 | "execution_count": 54, 692 | "outputs": [ 693 | { 694 | "output_type": "execute_result", 695 | "data": { 696 | "text/plain": [ 697 | "EntityTagger(\n", 698 | " (embed): Embedding(481 -> 50, float32)\n", 699 | " (bigru): GRU(None -> 30, TNC, dropout=0.2, bidirectional)\n", 700 | " (dense_prev): Dense(None -> 10, linear)\n", 701 | " (dense): Dense(None -> 3, linear)\n", 702 | " (attention): MLPAttentionCell(\n", 703 | " (_act): Activation(tanh)\n", 704 | " (_dropout_layer): Dropout(p = 0.2, axes=())\n", 705 | " (_query_mid_layer): Dense(None -> 30, linear)\n", 706 | " 
(_key_mid_layer): Dense(None -> 30, linear)\n", 707 | " (_attention_score): Dense(30 -> 1, linear)\n", 708 | " )\n", 709 | ")" 710 | ] 711 | }, 712 | "metadata": { 713 | "tags": [] 714 | }, 715 | "execution_count": 54 716 | } 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "metadata": { 722 | "id": "o1NSWMLEIXl5", 723 | "colab_type": "code", 724 | "colab": {} 725 | }, 726 | "source": [ 727 | "def evaluate_accuracy(model, data_iter, ctx=ctx):\n", 728 | " corrected = 0\n", 729 | " n = 0\n", 730 | " for i, (data, length, label) in enumerate(data_iter):\n", 731 | " data = data.as_in_context(ctx)\n", 732 | " label = label.as_in_context(ctx)\n", 733 | " length = length.as_in_context(ctx)\n", 734 | " output = model(data.T, length)\n", 735 | " predictions = nd.argmax(output, axis=2)\n", 736 | " tf = predictions.astype('int64') == label\n", 737 | " for i in range(length.shape[0]):\n", 738 | " l = int(length[i].asscalar())\n", 739 | " corrected += nd.sum(tf[i][:l]).asscalar() == l\n", 740 | " n += 1\n", 741 | " #acc.update(preds=predictions, labels=label)\n", 742 | " return(corrected/n)" 743 | ], 744 | "execution_count": 0, 745 | "outputs": [] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "metadata": { 750 | "id": "kYE2pbAKIXl_", 751 | "colab_type": "code", 752 | "colab": {} 753 | }, 754 | "source": [ 755 | "def calculate_loss(model, data_iter, loss_obj, ctx=ctx):\n", 756 | " test_loss = []\n", 757 | " for i, (te_data, te_length, te_label) in enumerate(data_iter):\n", 758 | " te_data = te_data.as_in_context(ctx)\n", 759 | " te_label = te_label.as_in_context(ctx)\n", 760 | " te_length = te_length.as_in_context(ctx)\n", 761 | " te_output = model(te_data.T, te_length)\n", 762 | " loss_te = loss_obj(te_output, te_label)\n", 763 | " curr_loss = nd.mean(loss_te).asscalar()\n", 764 | " test_loss.append(curr_loss)\n", 765 | " return(np.mean(test_loss))" 766 | ], 767 | "execution_count": 0, 768 | "outputs": [] 769 | }, 770 | { 771 | "cell_type": "code", 772 | 
"metadata": { 773 | "id": "KjuksSCWIXmE", 774 | "colab_type": "code", 775 | "colab": { 776 | "base_uri": "https://localhost:8080/", 777 | "height": 1000 778 | }, 779 | "outputId": "d3f63be7-941c-4e67-f674-06fad143dc71" 780 | }, 781 | "source": [ 782 | "epochs = 100\n", 783 | "\n", 784 | "\n", 785 | "tot_test_loss = []\n", 786 | "tot_test_accu = []\n", 787 | "tot_train_loss = []\n", 788 | "tot_train_accu = []\n", 789 | "tot_valid_accu = [] \n", 790 | "for e in range(epochs):\n", 791 | " #batch training \n", 792 | " for i, (data, length, label) in enumerate(tqdm(train_dataloader)):\n", 793 | " data = data.as_in_context(ctx)\n", 794 | " label = label.as_in_context(ctx)\n", 795 | " length = length.as_in_context(ctx)\n", 796 | " with autograd.record():\n", 797 | " output = model(data.T, length)\n", 798 | " loss_ = loss(output, label)\n", 799 | " loss_.backward()\n", 800 | " trainer.step(data.shape[0])\n", 801 | "\n", 802 | " #every 10 epochs, calculate loss and accuracy\n", 803 | " if e % 10 == 0: \n", 804 | " test_loss = calculate_loss(model, test_dataloader, loss_obj = loss, ctx=ctx) \n", 805 | " train_loss = calculate_loss(model, train_dataloader, loss_obj = loss, ctx=ctx) \n", 806 | " test_accu = evaluate_accuracy(model, test_dataloader, ctx=ctx)\n", 807 | " train_accu = evaluate_accuracy(model, train_dataloader, ctx=ctx)\n", 808 | " valid_accu = evaluate_accuracy(model, valid_dataloader, ctx=ctx)\n", 809 | "\n", 810 | " print(\"Epoch %s. 
Train Loss: %s, Test Loss : %s,\" \\\n", 811 | " \" Test Accuracy : %s,\" \\\n", 812 | " \" Train Accuracy : %s : Valid Accuracy : %s\" % (e, train_loss, test_loss, test_accu, train_accu, valid_accu)) \n", 813 | " tot_test_loss.append(test_loss)\n", 814 | " tot_train_loss.append(train_loss)\n", 815 | " tot_test_accu.append(test_accu)\n", 816 | " tot_train_accu.append(train_accu)\n", 817 | " tot_valid_accu.append(valid_accu)" 818 | ], 819 | "execution_count": 57, 820 | "outputs": [ 821 | { 822 | "output_type": "stream", 823 | "text": [ 824 | "100%|██████████| 270/270 [00:01<00:00, 240.68it/s]\n", 825 | " 9%|▊ | 23/270 [00:00<00:01, 226.74it/s]" 826 | ], 827 | "name": "stderr" 828 | }, 829 | { 830 | "output_type": "stream", 831 | "text": [ 832 | "Epoch 0. Train Loss: 0.04581105, Test Loss : 0.04551146, Test Accuracy : 0.7133333333333334, Train Accuracy : 0.7050617283950618 : Valid Accuracy : 0.704\n" 833 | ], 834 | "name": "stdout" 835 | }, 836 | { 837 | "output_type": "stream", 838 | "text": [ 839 | "100%|██████████| 270/270 [00:01<00:00, 175.59it/s]\n", 840 | "100%|██████████| 270/270 [00:01<00:00, 191.02it/s]\n", 841 | "100%|██████████| 270/270 [00:01<00:00, 167.02it/s]\n", 842 | "100%|██████████| 270/270 [00:01<00:00, 184.27it/s]\n", 843 | "100%|██████████| 270/270 [00:01<00:00, 202.12it/s]\n", 844 | "100%|██████████| 270/270 [00:01<00:00, 181.01it/s]\n", 845 | "100%|██████████| 270/270 [00:01<00:00, 176.99it/s]\n", 846 | "100%|██████████| 270/270 [00:01<00:00, 182.01it/s]\n", 847 | "100%|██████████| 270/270 [00:01<00:00, 197.47it/s]\n", 848 | "100%|██████████| 270/270 [00:01<00:00, 196.37it/s]\n", 849 | " 7%|▋ | 19/270 [00:00<00:01, 188.03it/s]" 850 | ], 851 | "name": "stderr" 852 | }, 853 | { 854 | "output_type": "stream", 855 | "text": [ 856 | "Epoch 10. 
Train Loss: 0.0020585118, Test Loss : 0.00977939, Test Accuracy : 0.9533333333333334, Train Accuracy : 0.9849382716049383 : Valid Accuracy : 0.963\n" 857 | ], 858 | "name": "stdout" 859 | }, 860 | { 861 | "output_type": "stream", 862 | "text": [ 863 | "100%|██████████| 270/270 [00:01<00:00, 178.87it/s]\n", 864 | "100%|██████████| 270/270 [00:01<00:00, 193.61it/s]\n", 865 | "100%|██████████| 270/270 [00:01<00:00, 180.71it/s]\n", 866 | "100%|██████████| 270/270 [00:01<00:00, 193.57it/s]\n", 867 | "100%|██████████| 270/270 [00:01<00:00, 186.82it/s]\n", 868 | "100%|██████████| 270/270 [00:01<00:00, 192.00it/s]\n", 869 | "100%|██████████| 270/270 [00:01<00:00, 179.13it/s]\n", 870 | "100%|██████████| 270/270 [00:01<00:00, 188.45it/s]\n", 871 | "100%|██████████| 270/270 [00:01<00:00, 206.38it/s]\n", 872 | "100%|██████████| 270/270 [00:01<00:00, 192.83it/s]\n", 873 | " 9%|▉ | 24/270 [00:00<00:01, 238.61it/s]" 874 | ], 875 | "name": "stderr" 876 | }, 877 | { 878 | "output_type": "stream", 879 | "text": [ 880 | "Epoch 20. 
Train Loss: 0.0019672026, Test Loss : 0.010476824, Test Accuracy : 0.9588888888888889, Train Accuracy : 0.9837037037037037 : Valid Accuracy : 0.961\n" 881 | ], 882 | "name": "stdout" 883 | }, 884 | { 885 | "output_type": "stream", 886 | "text": [ 887 | "100%|██████████| 270/270 [00:01<00:00, 200.94it/s]\n", 888 | "100%|██████████| 270/270 [00:01<00:00, 175.08it/s]\n", 889 | "100%|██████████| 270/270 [00:01<00:00, 185.28it/s]\n", 890 | "100%|██████████| 270/270 [00:01<00:00, 176.77it/s]\n", 891 | "100%|██████████| 270/270 [00:01<00:00, 175.91it/s]\n", 892 | "100%|██████████| 270/270 [00:01<00:00, 186.16it/s]\n", 893 | "100%|██████████| 270/270 [00:01<00:00, 184.73it/s]\n", 894 | "100%|██████████| 270/270 [00:01<00:00, 183.85it/s]\n", 895 | "100%|██████████| 270/270 [00:01<00:00, 195.25it/s]\n", 896 | "100%|██████████| 270/270 [00:01<00:00, 178.42it/s]\n", 897 | " 8%|▊ | 22/270 [00:00<00:01, 219.21it/s]" 898 | ], 899 | "name": "stderr" 900 | }, 901 | { 902 | "output_type": "stream", 903 | "text": [ 904 | "Epoch 30. 
Train Loss: 0.0002923226, Test Loss : 0.0132401, Test Accuracy : 0.9611111111111111, Train Accuracy : 0.9969135802469136 : Valid Accuracy : 0.972\n" 905 | ], 906 | "name": "stdout" 907 | }, 908 | { 909 | "output_type": "stream", 910 | "text": [ 911 | "100%|██████████| 270/270 [00:01<00:00, 178.83it/s]\n", 912 | "100%|██████████| 270/270 [00:01<00:00, 180.42it/s]\n", 913 | "100%|██████████| 270/270 [00:01<00:00, 185.86it/s]\n", 914 | "100%|██████████| 270/270 [00:01<00:00, 204.22it/s]\n", 915 | "100%|██████████| 270/270 [00:01<00:00, 179.28it/s]\n", 916 | "100%|██████████| 270/270 [00:01<00:00, 191.57it/s]\n", 917 | "100%|██████████| 270/270 [00:01<00:00, 177.67it/s]\n", 918 | "100%|██████████| 270/270 [00:01<00:00, 191.65it/s]\n", 919 | "100%|██████████| 270/270 [00:01<00:00, 198.73it/s]\n", 920 | "100%|██████████| 270/270 [00:01<00:00, 198.85it/s]\n", 921 | " 9%|▊ | 23/270 [00:00<00:01, 229.68it/s]" 922 | ], 923 | "name": "stderr" 924 | }, 925 | { 926 | "output_type": "stream", 927 | "text": [ 928 | "Epoch 40. 
Train Loss: 3.341425e-05, Test Loss : 0.012503282, Test Accuracy : 0.9688888888888889, Train Accuracy : 0.9997530864197531 : Valid Accuracy : 0.98\n" 929 | ], 930 | "name": "stdout" 931 | }, 932 | { 933 | "output_type": "stream", 934 | "text": [ 935 | "100%|██████████| 270/270 [00:01<00:00, 200.50it/s]\n", 936 | "100%|██████████| 270/270 [00:01<00:00, 195.91it/s]\n", 937 | "100%|██████████| 270/270 [00:01<00:00, 175.03it/s]\n", 938 | "100%|██████████| 270/270 [00:01<00:00, 200.46it/s]\n", 939 | "100%|██████████| 270/270 [00:01<00:00, 195.08it/s]\n", 940 | "100%|██████████| 270/270 [00:01<00:00, 185.90it/s]\n", 941 | "100%|██████████| 270/270 [00:01<00:00, 207.01it/s]\n", 942 | "100%|██████████| 270/270 [00:01<00:00, 197.68it/s]\n", 943 | "100%|██████████| 270/270 [00:01<00:00, 174.51it/s]\n", 944 | "100%|██████████| 270/270 [00:01<00:00, 189.14it/s]\n", 945 | " 6%|▌ | 16/270 [00:00<00:01, 154.69it/s]" 946 | ], 947 | "name": "stderr" 948 | }, 949 | { 950 | "output_type": "stream", 951 | "text": [ 952 | "Epoch 50. 
Train Loss: 7.684146e-06, Test Loss : 0.013977752, Test Accuracy : 0.9711111111111111, Train Accuracy : 1.0 : Valid Accuracy : 0.982\n" 953 | ], 954 | "name": "stdout" 955 | }, 956 | { 957 | "output_type": "stream", 958 | "text": [ 959 | "100%|██████████| 270/270 [00:01<00:00, 173.55it/s]\n", 960 | "100%|██████████| 270/270 [00:01<00:00, 180.13it/s]\n", 961 | "100%|██████████| 270/270 [00:01<00:00, 196.53it/s]\n", 962 | "100%|██████████| 270/270 [00:01<00:00, 178.13it/s]\n", 963 | "100%|██████████| 270/270 [00:01<00:00, 188.66it/s]\n", 964 | "100%|██████████| 270/270 [00:01<00:00, 176.26it/s]\n", 965 | "100%|██████████| 270/270 [00:01<00:00, 191.22it/s]\n", 966 | "100%|██████████| 270/270 [00:01<00:00, 173.48it/s]\n", 967 | "100%|██████████| 270/270 [00:01<00:00, 179.11it/s]\n", 968 | "100%|██████████| 270/270 [00:01<00:00, 189.78it/s]\n", 969 | " 8%|▊ | 22/270 [00:00<00:01, 217.96it/s]" 970 | ], 971 | "name": "stderr" 972 | }, 973 | { 974 | "output_type": "stream", 975 | "text": [ 976 | "Epoch 60. 
Train Loss: 0.0005088213, Test Loss : 0.014564858, Test Accuracy : 0.9644444444444444, Train Accuracy : 0.9950617283950617 : Valid Accuracy : 0.972\n" 977 | ], 978 | "name": "stdout" 979 | }, 980 | { 981 | "output_type": "stream", 982 | "text": [ 983 | "100%|██████████| 270/270 [00:01<00:00, 170.08it/s]\n", 984 | "100%|██████████| 270/270 [00:01<00:00, 190.07it/s]\n", 985 | "100%|██████████| 270/270 [00:01<00:00, 174.73it/s]\n", 986 | "100%|██████████| 270/270 [00:01<00:00, 170.88it/s]\n", 987 | "100%|██████████| 270/270 [00:01<00:00, 189.08it/s]\n", 988 | "100%|██████████| 270/270 [00:01<00:00, 201.39it/s]\n", 989 | "100%|██████████| 270/270 [00:01<00:00, 176.72it/s]\n", 990 | "100%|██████████| 270/270 [00:01<00:00, 183.48it/s]\n", 991 | "100%|██████████| 270/270 [00:01<00:00, 167.09it/s]\n", 992 | "100%|██████████| 270/270 [00:01<00:00, 187.38it/s]\n", 993 | " 7%|▋ | 20/270 [00:00<00:01, 198.38it/s]" 994 | ], 995 | "name": "stderr" 996 | }, 997 | { 998 | "output_type": "stream", 999 | "text": [ 1000 | "Epoch 70. 
Train Loss: 2.6472359e-05, Test Loss : 0.012694985, Test Accuracy : 0.9744444444444444, Train Accuracy : 0.9998765432098765 : Valid Accuracy : 0.983\n" 1001 | ], 1002 | "name": "stdout" 1003 | }, 1004 | { 1005 | "output_type": "stream", 1006 | "text": [ 1007 | "100%|██████████| 270/270 [00:01<00:00, 171.07it/s]\n", 1008 | "100%|██████████| 270/270 [00:01<00:00, 173.73it/s]\n", 1009 | "100%|██████████| 270/270 [00:01<00:00, 178.87it/s]\n", 1010 | "100%|██████████| 270/270 [00:01<00:00, 194.86it/s]\n", 1011 | "100%|██████████| 270/270 [00:01<00:00, 190.54it/s]\n", 1012 | "100%|██████████| 270/270 [00:01<00:00, 174.13it/s]\n", 1013 | "100%|██████████| 270/270 [00:01<00:00, 183.71it/s]\n", 1014 | "100%|██████████| 270/270 [00:01<00:00, 183.54it/s]\n", 1015 | "100%|██████████| 270/270 [00:01<00:00, 195.52it/s]\n", 1016 | "100%|██████████| 270/270 [00:01<00:00, 201.67it/s]\n", 1017 | " 8%|▊ | 21/270 [00:00<00:01, 209.25it/s]" 1018 | ], 1019 | "name": "stderr" 1020 | }, 1021 | { 1022 | "output_type": "stream", 1023 | "text": [ 1024 | "Epoch 80. 
Train Loss: 3.950536e-06, Test Loss : 0.013301165, Test Accuracy : 0.98, Train Accuracy : 1.0 : Valid Accuracy : 0.983\n" 1025 | ], 1026 | "name": "stdout" 1027 | }, 1028 | { 1029 | "output_type": "stream", 1030 | "text": [ 1031 | "100%|██████████| 270/270 [00:01<00:00, 182.13it/s]\n", 1032 | "100%|██████████| 270/270 [00:01<00:00, 194.95it/s]\n", 1033 | "100%|██████████| 270/270 [00:01<00:00, 180.37it/s]\n", 1034 | "100%|██████████| 270/270 [00:01<00:00, 193.45it/s]\n", 1035 | "100%|██████████| 270/270 [00:01<00:00, 182.30it/s]\n", 1036 | "100%|██████████| 270/270 [00:01<00:00, 192.52it/s]\n", 1037 | "100%|██████████| 270/270 [00:01<00:00, 178.73it/s]\n", 1038 | "100%|██████████| 270/270 [00:01<00:00, 182.85it/s]\n", 1039 | "100%|██████████| 270/270 [00:01<00:00, 184.27it/s]\n", 1040 | "100%|██████████| 270/270 [00:01<00:00, 175.67it/s]\n", 1041 | " 9%|▉ | 25/270 [00:00<00:01, 240.53it/s]" 1042 | ], 1043 | "name": "stderr" 1044 | }, 1045 | { 1046 | "output_type": "stream", 1047 | "text": [ 1048 | "Epoch 90. 
Train Loss: 1.37819925e-05, Test Loss : 0.009812501, Test Accuracy : 0.9822222222222222, Train Accuracy : 0.9997530864197531 : Valid Accuracy : 0.985\n" 1049 | ], 1050 | "name": "stdout" 1051 | }, 1052 | { 1053 | "output_type": "stream", 1054 | "text": [ 1055 | "100%|██████████| 270/270 [00:01<00:00, 192.52it/s]\n", 1056 | "100%|██████████| 270/270 [00:01<00:00, 188.30it/s]\n", 1057 | "100%|██████████| 270/270 [00:01<00:00, 184.50it/s]\n", 1058 | "100%|██████████| 270/270 [00:01<00:00, 187.15it/s]\n", 1059 | "100%|██████████| 270/270 [00:01<00:00, 178.01it/s]\n", 1060 | "100%|██████████| 270/270 [00:01<00:00, 186.42it/s]\n", 1061 | "100%|██████████| 270/270 [00:01<00:00, 198.87it/s]\n", 1062 | "100%|██████████| 270/270 [00:01<00:00, 197.53it/s]\n", 1063 | "100%|██████████| 270/270 [00:01<00:00, 182.71it/s]\n" 1064 | ], 1065 | "name": "stderr" 1066 | } 1067 | ] 1068 | }, 1069 | { 1070 | "cell_type": "code", 1071 | "metadata": { 1072 | "id": "Ztk3C4GKLTgw", 1073 | "colab_type": "code", 1074 | "colab": {} 1075 | }, 1076 | "source": [ 1077 | "model.collect_params().reset_ctx(mx.cpu())" 1078 | ], 1079 | "execution_count": 0, 1080 | "outputs": [] 1081 | }, 1082 | { 1083 | "cell_type": "code", 1084 | "metadata": { 1085 | "id": "HPLjG5IiIXmb", 1086 | "colab_type": "code", 1087 | "colab": {} 1088 | }, 1089 | "source": [ 1090 | "def get_entitytag(sent):\n", 1091 | " sent_len = len(sent)\n", 1092 | " coded_sent = vocab_sent[length_clip(list(sent))]\n", 1093 | " co = nd.array(coded_sent).expand_dims(axis=1)\n", 1094 | " ret_code = model(co, nd.array([sent_len,]))\n", 1095 | " ret_seq = vocab_entity.to_tokens(ret_code.argmax(axis=2)[0].asnumpy().astype('int').tolist())\n", 1096 | " return(''.join(ret_seq)[:sent_len])" 1097 | ], 1098 | "execution_count": 0, 1099 | "outputs": [] 1100 | }, 1101 | { 1102 | "cell_type": "code", 1103 | "metadata": { 1104 | "id": "YZKEW5WNIXmj", 1105 | "colab_type": "code", 1106 | "colab": { 1107 | "base_uri": "https://localhost:8080/", 1108 | 
"height": 36 1109 | }, 1110 | "outputId": "35dd56e0-4352-42a0-e3bb-f009b28bfed8" 1111 | }, 1112 | "source": [ 1113 | "get_entitytag(\"파이콘이 뭔지 알려줘\")" 1114 | ], 1115 | "execution_count": 60, 1116 | "outputs": [ 1117 | { 1118 | "output_type": "execute_result", 1119 | "data": { 1120 | "text/plain": [ 1121 | "'EEEECCCCCCC'" 1122 | ] 1123 | }, 1124 | "metadata": { 1125 | "tags": [] 1126 | }, 1127 | "execution_count": 60 1128 | } 1129 | ] 1130 | }, 1131 | { 1132 | "cell_type": "markdown", 1133 | "metadata": { 1134 | "id": "TD5giWUqIXmp", 1135 | "colab_type": "text" 1136 | }, 1137 | "source": [ 1138 | "### TODO\n", 1139 | "- Raise Test Accuracy above 95%\n", 1140 | "- Raise performance on the test_hidden set above 90%\n", 1141 | "- Try merging Entity Tagging and Intent Classification through multi-task learning (does performance improve or get worse?)" 1142 | ] 1143 | } 1144 | ] 1145 | } -------------------------------------------------------------------------------- /code/3_3_naver_review_classifications_gluon_bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "language_info": { 10 | "codemirror_mode": { 11 | "name": "ipython", 12 | "version": 3 13 | }, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "nbconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": "3.6.7" 20 | }, 21 | "toc": { 22 | "nav_menu": {}, 23 | "number_sections": true, 24 | "sideBar": true, 25 | "skip_h1_title": false, 26 | "toc_cell": false, 27 | "toc_position": {}, 28 | "toc_section_display": "block", 29 | "toc_window_display": false 30 | }, 31 | "colab": { 32 | "name": "naver_review_classifications_gluon_bert.ipynb", 33 | "version": "0.3.2", 34 | "provenance": [], 35 | "collapsed_sections": [] 36 | }, 37 | "accelerator": "GPU" 38 | }, 39 | "cells": [ 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43
| "id": "GvjqOdDgT5zb", 44 | "colab_type": "code", 45 | "colab": { 46 | "base_uri": "https://localhost:8080/", 47 | "height": 302 48 | }, 49 | "outputId": "39c3f64c-5e35-45cf-ce44-57032b635e9a" 50 | }, 51 | "source": [ 52 | "!pip install mxnet-cu100\n", 53 | "!pip install gluonnlp pandas tqdm" 54 | ], 55 | "execution_count": 1, 56 | "outputs": [ 57 | { 58 | "output_type": "stream", 59 | "text": [ 60 | "Requirement already satisfied: mxnet-cu100 in /usr/local/lib/python3.6/dist-packages (1.5.0)\n", 61 | "Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.16.4)\n", 62 | "Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (0.8.4)\n", 63 | "Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n", 64 | "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n", 65 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.6.16)\n", 66 | "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n", 67 | "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n", 68 | "Requirement already satisfied: gluonnlp in /usr/local/lib/python3.6/dist-packages (0.8.0)\n", 69 | "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (0.24.2)\n", 70 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)\n", 71 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from gluonnlp) (1.16.4)\n", 72 | "Requirement already satisfied: pytz>=2011k in 
/usr/local/lib/python3.6/dist-packages (from pandas) (2018.9)\n", 73 | "Requirement already satisfied: python-dateutil>=2.5.0 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.5.3)\n", 74 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.5.0->pandas) (1.12.0)\n" 75 | ], 76 | "name": "stdout" 77 | } 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "id": "5mTNl7BKT2Fx", 84 | "colab_type": "code", 85 | "colab": {} 86 | }, 87 | "source": [ 88 | "import pandas as pd\n", 89 | "import numpy as np\n", 90 | "from mxnet.gluon import nn, rnn\n", 91 | "from mxnet import gluon, autograd\n", 92 | "import gluonnlp as nlp\n", 93 | "from mxnet import nd \n", 94 | "import mxnet as mx\n", 95 | "import time\n", 96 | "import itertools\n", 97 | "import random" 98 | ], 99 | "execution_count": 0, 100 | "outputs": [] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": { 105 | "id": "Cc-zco-ST2F_", 106 | "colab_type": "text" 107 | }, 108 | "source": [ 109 | "### Loading BERT " 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "metadata": { 115 | "id": "89lsydguT2GG", 116 | "colab_type": "code", 117 | "colab": {} 118 | }, 119 | "source": [ 120 | "ctx = mx.gpu()" 121 | ], 122 | "execution_count": 0, 123 | "outputs": [] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "id": "SGwX9REiT2Gb", 129 | "colab_type": "code", 130 | "colab": { 131 | "base_uri": "https://localhost:8080/", 132 | "height": 94 133 | }, 134 | "outputId": "466d7fad-023f-40f3-8297-b04bdaad1f22" 135 | }, 136 | "source": [ 137 | "bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',\n", 138 | " dataset_name='wiki_multilingual_cased',\n", 139 | " pretrained=True, ctx=ctx, use_pooler=True,\n", 140 | " use_decoder=False, use_classifier=False)\n", 141 | "#print(bert_base)" 142 | ], 143 | "execution_count": 5, 144 | "outputs": [ 145 | { 146 | "output_type": "stream", 147 | "text": [ 148 | "Vocab file
is not found. Downloading.\n", 149 | "Downloading /root/.mxnet/models/1565856577.1304765wiki_multilingual_cased-0247cb44.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/vocab/wiki_multilingual_cased-0247cb44.zip...\n", 150 | "Downloading /root/.mxnet/models/bert_12_768_12_wiki_multilingual_cased-b0f57a20.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_12_768_12_wiki_multilingual_cased-b0f57a20.zip...\n" 151 | ], 152 | "name": "stdout" 153 | } 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "id": "i69AUj9gT2Gk", 160 | "colab_type": "code", 161 | "colab": { 162 | "base_uri": "https://localhost:8080/", 163 | "height": 93 164 | }, 165 | "outputId": "f2338ed4-86f1-4118-aeea-cd0d07be15c0" 166 | }, 167 | "source": [ 168 | "ds = gluon.data.SimpleDataset([['나 보기가 역겨워', '김소월']])\n", 169 | "\n", 170 | "tok = nlp.data.BERTTokenizer(vocab=vocabulary, lower=False)\n", 171 | "\n", 172 | "trans = nlp.data.BERTSentenceTransform(tok, max_seq_length=10)\n", 173 | "\n", 174 | "list(ds.transform(trans))" 175 | ], 176 | "execution_count": 6, 177 | "outputs": [ 178 | { 179 | "output_type": "execute_result", 180 | "data": { 181 | "text/plain": [ 182 | "[(array([ 2, 8982, 9356, 47869, 9566, 3, 8935, 22333, 38851,\n", 183 | " 3], dtype=int32),\n", 184 | " array(10, dtype=int32),\n", 185 | " array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int32))]" 186 | ] 187 | }, 188 | "metadata": { 189 | "tags": [] 190 | }, 191 | "execution_count": 6 192 | } 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "4qy9g_UMVtdj", 199 | "colab_type": "code", 200 | "colab": { 201 | "base_uri": "https://localhost:8080/", 202 | "height": 796 203 | }, 204 | "outputId": "1c02e06b-94bc-4b9c-c9d2-0a04772f8469" 205 | }, 206 | "source": [ 207 | "!wget https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1\n", 208 | "!wget 
https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1" 209 | ], 210 | "execution_count": 8, 211 | "outputs": [ 212 | { 213 | "output_type": "stream", 214 | "text": [ 215 | "--2019-08-15 08:13:44-- https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1\n", 216 | "Resolving www.dropbox.com (www.dropbox.com)... 162.125.8.1, 2620:100:601b:1::a27d:801\n", 217 | "Connecting to www.dropbox.com (www.dropbox.com)|162.125.8.1|:443... connected.\n", 218 | "HTTP request sent, awaiting response... 301 Moved Permanently\n", 219 | "Location: /s/dl/374ftkec978br3d/ratings_train.txt [following]\n", 220 | "--2019-08-15 08:13:44-- https://www.dropbox.com/s/dl/374ftkec978br3d/ratings_train.txt\n", 221 | "Reusing existing connection to www.dropbox.com:443.\n", 222 | "HTTP request sent, awaiting response... 302 Found\n", 223 | "Location: https://ucef6f352d774589b42c518ef3f0.dl.dropboxusercontent.com/cd/0/get/Amq3zXJvW28FXK4jT8dPnncZI9ibrr6FYx7ZR_SOt-Z5Jt2lsXU9Y7bLGmO0LkPMnZ2eufdoF14xTEuRd-jV11A02AOXHYKmXj_MJPGzEGROAYWCe02sMg0a5Dnj0MkvXMo/file?dl=1# [following]\n", 224 | "--2019-08-15 08:13:45-- https://ucef6f352d774589b42c518ef3f0.dl.dropboxusercontent.com/cd/0/get/Amq3zXJvW28FXK4jT8dPnncZI9ibrr6FYx7ZR_SOt-Z5Jt2lsXU9Y7bLGmO0LkPMnZ2eufdoF14xTEuRd-jV11A02AOXHYKmXj_MJPGzEGROAYWCe02sMg0a5Dnj0MkvXMo/file?dl=1\n", 225 | "Resolving ucef6f352d774589b42c518ef3f0.dl.dropboxusercontent.com (ucef6f352d774589b42c518ef3f0.dl.dropboxusercontent.com)... 162.125.8.6, 2620:100:601b:6::a27d:806\n", 226 | "Connecting to ucef6f352d774589b42c518ef3f0.dl.dropboxusercontent.com (ucef6f352d774589b42c518ef3f0.dl.dropboxusercontent.com)|162.125.8.6|:443... connected.\n", 227 | "HTTP request sent, awaiting response... 
200 OK\n", 228 | "Length: 14628807 (14M) [application/binary]\n", 229 | "Saving to: ‘ratings_train.txt?dl=1’\n", 230 | "\n", 231 | "ratings_train.txt?d 100%[===================>] 13.95M 84.2MB/s in 0.2s \n", 232 | "\n", 233 | "2019-08-15 08:13:45 (84.2 MB/s) - ‘ratings_train.txt?dl=1’ saved [14628807/14628807]\n", 234 | "\n", 235 | "--2019-08-15 08:13:47-- https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1\n", 236 | "Resolving www.dropbox.com (www.dropbox.com)... 162.125.8.1, 2620:100:601b:1::a27d:801\n", 237 | "Connecting to www.dropbox.com (www.dropbox.com)|162.125.8.1|:443... connected.\n", 238 | "HTTP request sent, awaiting response... 301 Moved Permanently\n", 239 | "Location: /s/dl/977gbwh542gdy94/ratings_test.txt [following]\n", 240 | "--2019-08-15 08:13:47-- https://www.dropbox.com/s/dl/977gbwh542gdy94/ratings_test.txt\n", 241 | "Reusing existing connection to www.dropbox.com:443.\n", 242 | "HTTP request sent, awaiting response... 302 Found\n", 243 | "Location: https://ucdbdf9608e9f1730a558af5fdfd.dl.dropboxusercontent.com/cd/0/get/Amo8KQEYEkexXaQpzbWDgzoSDijM32HbshuUSzGxyQwKytWsKsv3sS0036wIJ3t8bmVAvElO2q25futhDQAoZWhUZ2IwdfwPJ1SaQUsmsLjC4b6nbEQdT07FBx6woV--b3U/file?dl=1# [following]\n", 244 | "--2019-08-15 08:13:47-- https://ucdbdf9608e9f1730a558af5fdfd.dl.dropboxusercontent.com/cd/0/get/Amo8KQEYEkexXaQpzbWDgzoSDijM32HbshuUSzGxyQwKytWsKsv3sS0036wIJ3t8bmVAvElO2q25futhDQAoZWhUZ2IwdfwPJ1SaQUsmsLjC4b6nbEQdT07FBx6woV--b3U/file?dl=1\n", 245 | "Resolving ucdbdf9608e9f1730a558af5fdfd.dl.dropboxusercontent.com (ucdbdf9608e9f1730a558af5fdfd.dl.dropboxusercontent.com)... 162.125.8.6, 2620:100:601b:6::a27d:806\n", 246 | "Connecting to ucdbdf9608e9f1730a558af5fdfd.dl.dropboxusercontent.com (ucdbdf9608e9f1730a558af5fdfd.dl.dropboxusercontent.com)|162.125.8.6|:443... connected.\n", 247 | "HTTP request sent, awaiting response... 
200 OK\n", 248 | "Length: 4893335 (4.7M) [application/binary]\n", 249 | "Saving to: ‘ratings_test.txt?dl=1’\n", 250 | "\n", 251 | "ratings_test.txt?dl 100%[===================>] 4.67M --.-KB/s in 0.1s \n", 252 | "\n", 253 | "2019-08-15 08:13:48 (37.9 MB/s) - ‘ratings_test.txt?dl=1’ saved [4893335/4893335]\n", 254 | "\n" 255 | ], 256 | "name": "stdout" 257 | } 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "metadata": { 263 | "id": "4LfCTweqT2Gt", 264 | "colab_type": "code", 265 | "colab": {} 266 | }, 267 | "source": [ 268 | "dataset_train = nlp.data.TSVDataset(\"ratings_train.txt?dl=1\", field_indices=[1,2], num_discard_samples=1)\n", 269 | "dataset_test = nlp.data.TSVDataset(\"ratings_test.txt?dl=1\", field_indices=[1,2], num_discard_samples=1)" 270 | ], 271 | "execution_count": 0, 272 | "outputs": [] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "metadata": { 277 | "id": "pt0raV8uT2G2", 278 | "colab_type": "code", 279 | "colab": {} 280 | }, 281 | "source": [ 282 | "class BERTDataset(mx.gluon.data.Dataset):\n", 283 | " def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,\n", 284 | " pad, pair):\n", 285 | " transform = nlp.data.BERTSentenceTransform(\n", 286 | " bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)\n", 287 | " sent_dataset = gluon.data.SimpleDataset([[\n", 288 | " i[sent_idx],\n", 289 | " ] for i in dataset])\n", 290 | " self.sentences = sent_dataset.transform(transform)\n", 291 | " self.labels = gluon.data.SimpleDataset(\n", 292 | " [np.array(np.int32(i[label_idx])) for i in dataset])\n", 293 | "\n", 294 | " def __getitem__(self, i):\n", 295 | " return (self.sentences[i] + (self.labels[i], ))\n", 296 | "\n", 297 | " def __len__(self):\n", 298 | " return (len(self.labels))\n" 299 | ], 300 | "execution_count": 0, 301 | "outputs": [] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "metadata": { 306 | "id": "vtk-8pQST2G9", 307 | "colab_type": "code", 308 | "colab": {} 309 | }, 310 | "source": [ 311 | 
"bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=False)\n", 312 | "max_len = 64" 313 | ], 314 | "execution_count": 0, 315 | "outputs": [] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "metadata": { 320 | "id": "_K_BLZP_T2HF", 321 | "colab_type": "code", 322 | "colab": {} 323 | }, 324 | "source": [ 325 | "data_train = BERTDataset(dataset_train, 0, 1, bert_tokenizer, max_len, True, False)\n", 326 | "data_test = BERTDataset(dataset_test, 0, 1, bert_tokenizer, max_len, True, False)" 327 | ], 328 | "execution_count": 0, 329 | "outputs": [] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "metadata": { 334 | "id": "rhaw0H4ST2HM", 335 | "colab_type": "code", 336 | "colab": {} 337 | }, 338 | "source": [ 339 | "class BERTClassifier(nn.Block):\n", 340 | " def __init__(self,\n", 341 | " bert,\n", 342 | " num_classes=2,\n", 343 | " dropout=None,\n", 344 | " prefix=None,\n", 345 | " params=None):\n", 346 | " super(BERTClassifier, self).__init__(prefix=prefix, params=params)\n", 347 | " self.bert = bert\n", 348 | " with self.name_scope():\n", 349 | " self.classifier = nn.HybridSequential(prefix=prefix)\n", 350 | " if dropout:\n", 351 | " self.classifier.add(nn.Dropout(rate=dropout))\n", 352 | " self.classifier.add(nn.Dense(units=num_classes))\n", 353 | "\n", 354 | " def forward(self, inputs, token_types, valid_length=None):\n", 355 | " _, pooler = self.bert(inputs, token_types, valid_length)\n", 356 | " return self.classifier(pooler)\n", 357 | " " 358 | ], 359 | "execution_count": 0, 360 | "outputs": [] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "metadata": { 365 | "id": "Y00BOPwST2HX", 366 | "colab_type": "code", 367 | "colab": {} 368 | }, 369 | "source": [ 370 | "model = BERTClassifier(bert_base, num_classes=2, dropout=0.3)\n", 371 | "# Initialize only the classifier layer; BERT keeps its pretrained weights. 
\n", 372 | "model.classifier.initialize(ctx=ctx)\n", 373 | "model.hybridize()\n", 374 | "\n", 375 | "# softmax cross entropy loss for classification\n", 376 | "loss_function = gluon.loss.SoftmaxCELoss()\n", 377 | "\n", 378 | "metric = mx.metric.Accuracy()" 379 | ], 380 | "execution_count": 0, 381 | "outputs": [] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "metadata": { 386 | "id": "A2dLhnHkT2Hf", 387 | "colab_type": "code", 388 | "colab": {} 389 | }, 390 | "source": [ 391 | "batch_size = 16\n", 392 | "lr = 5e-5\n", 393 | "\n", 394 | "train_dataloader = mx.gluon.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=5)\n", 395 | "test_dataloader = mx.gluon.data.DataLoader(data_test, batch_size=int(batch_size/2), num_workers=5)" 396 | ], 397 | "execution_count": 0, 398 | "outputs": [] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "metadata": { 403 | "id": "ESo76UH-T2Hr", 404 | "colab_type": "code", 405 | "colab": {} 406 | }, 407 | "source": [ 408 | "trainer = gluon.Trainer(model.collect_params(), 'bertadam',\n", 409 | " {'learning_rate': lr, 'epsilon': 1e-9, 'wd':0.01})\n", 410 | "\n", 411 | "log_interval = 50 # batches per log line; the summed loss is divided by this\n", 412 | "num_epochs = 4" 413 | ], 414 | "execution_count": 0, 415 | "outputs": [] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "metadata": { 420 | "id": "wspMBDOAT2H0", 421 | "colab_type": "code", 422 | "colab": {} 423 | }, 424 | "source": [ 425 | "# Do not apply weight decay to LayerNorm (beta, gamma) and bias parameters. 
\n", 426 | "for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():\n", 427 | " v.wd_mult = 0.0\n", 428 | "params = [\n", 429 | " p for p in model.collect_params().values() if p.grad_req != 'null'\n", 430 | "]\n" 431 | ], 432 | "execution_count": 0, 433 | "outputs": [] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "metadata": { 438 | "id": "NCR6AMKHT2H6", 439 | "colab_type": "code", 440 | "colab": {} 441 | }, 442 | "source": [ 443 | "def evaluate_accuracy(model, data_iter, ctx=ctx):\n", 444 | " acc = mx.metric.Accuracy()\n", 445 | " # Stops after ~1,000 batches, so only part of the test set is scored\n", 446 | " for i, (t,v,s, label) in enumerate(data_iter):\n", 447 | " token_ids = t.as_in_context(ctx)\n", 448 | " valid_length = v.as_in_context(ctx)\n", 449 | " segment_ids = s.as_in_context(ctx)\n", 450 | " label = label.as_in_context(ctx)\n", 451 | " output = model(token_ids, segment_ids, valid_length.astype('float32'))\n", 452 | " acc.update(preds=output, labels=label)\n", 453 | " if i > 1000:\n", 454 | " break\n", 455 | " i += 1\n", 456 | " return(acc.get()[1])" 457 | ], 458 | "execution_count": 0, 459 | "outputs": [] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "metadata": { 464 | "id": "SkcW6GyeT2IA", 465 | "colab_type": "code", 466 | "colab": {} 467 | }, 468 | "source": [ 469 | "# Prepare the learning-rate warmup schedule\n", 470 | "step_size = batch_size \n", 471 | "num_train_examples = len(data_train)\n", 472 | "num_train_steps = int(num_train_examples / step_size * num_epochs)\n", 473 | "warmup_ratio = 0.1\n", 474 | "num_warmup_steps = int(num_train_steps * warmup_ratio)\n", 475 | "step_num = 0" 476 | ], 477 | "execution_count": 0, 478 | "outputs": [] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "metadata": { 483 | "id": "0mJ3Pw_VT2IH", 484 | "colab_type": "code", 485 | "colab": { 486 | "base_uri": "https://localhost:8080/", 487 | "height": 188 488 | }, 489 | "outputId": "9bd064ae-94aa-484e-bc85-657144f5aa8d" 490 | }, 491 | "source": [ 492 | "for epoch_id in range(num_epochs):\n", 493 | " metric.reset()\n",
494 | " step_loss = 0\n", 495 | " for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):\n", 496 | " step_num += 1\n", 497 | " if step_num < num_warmup_steps:\n", 498 | " new_lr = lr * step_num / num_warmup_steps\n", 499 | " else:\n", 500 | " offset = (step_num - num_warmup_steps) * lr / (\n", 501 | " num_train_steps - num_warmup_steps)\n", 502 | " new_lr = lr - offset\n", 503 | " trainer.set_learning_rate(new_lr)\n", 504 | " with mx.autograd.record():\n", 505 | " # load data to GPU\n", 506 | " token_ids = token_ids.as_in_context(ctx)\n", 507 | " valid_length = valid_length.as_in_context(ctx)\n", 508 | " segment_ids = segment_ids.as_in_context(ctx)\n", 509 | " label = label.as_in_context(ctx)\n", 510 | "\n", 511 | " # forward computation\n", 512 | " out = model(token_ids, segment_ids, valid_length.astype('float32'))\n", 513 | " ls = loss_function(out, label).mean()\n", 514 | "\n", 515 | " # backward computation\n", 516 | " ls.backward()\n", 517 | " trainer.allreduce_grads()\n", 518 | " nlp.utils.clip_grad_global_norm(params, 1)\n", 519 | " trainer.update(1) # the loss is already averaged over the batch by .mean()\n", 520 | "\n", 521 | " step_loss += ls.asscalar()\n", 522 | " metric.update([label], [out])\n", 523 | " if (batch_id + 1) % (50) == 0:\n", 524 | " print('[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.10f}, acc={:.3f}'\n", 525 | " .format(epoch_id + 1, batch_id + 1, len(train_dataloader),\n", 526 | " step_loss / log_interval,\n", 527 | " trainer.learning_rate, metric.get()[1]))\n", 528 | " step_loss = 0\n", 529 | " test_acc = evaluate_accuracy(model, test_dataloader, ctx)\n", 530 | " print('Test Acc : {}'.format(test_acc))" 531 | ], 532 | "execution_count": 0, 533 | "outputs": [ 534 | { 535 | "output_type": "stream", 536 | "text": [ 537 | "[Epoch 1 Batch 50/9375] loss=8.7709, lr=0.0000006667, acc=0.505\n", 538 | "[Epoch 1 Batch 100/9375] loss=8.5914, lr=0.0000013333, acc=0.526\n", 539 | "[Epoch 1 Batch 150/9375] loss=8.3027, lr=0.0000020000, acc=0.555\n", 540 |
"[Epoch 1 Batch 200/9375] loss=7.6643, lr=0.0000026667, acc=0.579\n", 541 | "[Epoch 1 Batch 250/9375] loss=7.4951, lr=0.0000033333, acc=0.603\n", 542 | "[Epoch 1 Batch 300/9375] loss=7.2966, lr=0.0000040000, acc=0.620\n", 543 | "[Epoch 1 Batch 350/9375] loss=7.2736, lr=0.0000046667, acc=0.632\n", 544 | "[Epoch 1 Batch 400/9375] loss=7.0332, lr=0.0000053333, acc=0.641\n", 545 | "[Epoch 1 Batch 450/9375] loss=7.3415, lr=0.0000060000, acc=0.647\n" 546 | ], 547 | "name": "stdout" 548 | } 549 | ] 550 | } 551 | ] 552 | } -------------------------------------------------------------------------------- /slide/1.MXNet_Basic.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seujung/gluonnlp_tutorial/e3ddf2f834537e62b91e7ccbae5b3d15e226139f/slide/1.MXNet_Basic.pdf -------------------------------------------------------------------------------- /slide/2.Word_Embedding.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seujung/gluonnlp_tutorial/e3ddf2f834537e62b91e7ccbae5b3d15e226139f/slide/2.Word_Embedding.pdf -------------------------------------------------------------------------------- /slide/3_bert.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seujung/gluonnlp_tutorial/e3ddf2f834537e62b91e7ccbae5b3d15e226139f/slide/3_bert.pdf --------------------------------------------------------------------------------
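The training loop in `3_3_naver_review_classifications_gluon_bert.ipynb` recomputes the learning rate at every step. The rate ramps up linearly over the first `warmup_ratio` (10%) of `num_train_steps`, then decays linearly toward zero. The sketch below reproduces that schedule in isolation, using the notebook's own formula, so it can be checked without MXNet. `bert_lr_schedule` is a hypothetical helper name, not a GluonNLP API. The 150,000-example training set size is an assumption inferred from the logged 9,375 batches of size 16.

```python
def bert_lr_schedule(step_num, base_lr, num_train_steps, warmup_ratio=0.1):
    """Learning rate at a given step: linear warmup, then linear decay.

    Hypothetical helper that mirrors the formula used in the notebook's training loop.
    """
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    if step_num < num_warmup_steps:
        # Warmup: grow linearly from 0 toward base_lr
        return base_lr * step_num / num_warmup_steps
    # Decay: shrink linearly from base_lr (end of warmup) to 0 (last step)
    offset = (step_num - num_warmup_steps) * base_lr / (num_train_steps - num_warmup_steps)
    return base_lr - offset


# Assumed setup: 150,000 training reviews (9,375 batches x 16), 4 epochs, lr = 5e-5
num_train_steps = int(150000 / 16 * 4)  # 37500 steps, the first 3750 of which are warmup
for step in (50, 100, 3750, 20625, 37500):
    print('step {:>5}: lr={:.10f}'.format(step, bert_lr_schedule(step, 5e-5, num_train_steps)))
```

Under these assumptions, steps 50 and 100 give `0.0000006667` and `0.0000013333`, which match the `lr=` column of the logged output. The peak of `5e-5` is reached at step 3750, and the rate falls to half of the peak midway through the decay phase.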