├── .gitignore ├── README.md ├── kcm ├── README.md ├── ex_kcm │ ├── __init__.py │ └── kcm.py └── test_kcm.py ├── requirements.txt ├── template ├── README.md ├── example_homework │ ├── __init__.py │ ├── cubing.py │ └── squaring.py └── test_example_homework.py ├── w2v ├── README.md ├── ex_w2v │ ├── __init__.py │ └── w2v.py └── test_w2v.py ├── w2v_classifier ├── README.md ├── ex_classifier │ ├── __init__.py │ ├── a_sent2vec.py │ ├── b_dataloader.py │ ├── c_model.py │ └── d_predict.py ├── test_classifier.py └── udicstm_for_dataloader.csv └── w2v_gru_classifier ├── README.md ├── Taipei_FAQ_for_dataloader.csv ├── ex_classifier ├── __init__.py ├── a_sent2vec.py ├── b_dataloader.py ├── c_model.py └── d_predict.py └── test_classifier.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### macOS template 3 | # General 4 | .DS_Store 5 | .AppleDouble 6 | .LSOverride 7 | 8 | # Icon must end with two \r 9 | Icon 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # Files that might appear in the root of a volume 15 | .DocumentRevisions-V100 16 | .fseventsd 17 | .Spotlight-V100 18 | .TemporaryItems 19 | .Trashes 20 | .VolumeIcon.icns 21 | .com.apple.timemachine.donotpresent 22 | 23 | # Directories potentially created on remote AFP share 24 | .AppleDB 25 | .AppleDesktop 26 | Network Trash Folder 27 | Temporary Items 28 | .apdisk 29 | 30 | ### JetBrains template 31 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 32 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 33 | 34 | # User-specific stuff 35 | .idea/**/workspace.xml 36 | .idea/**/tasks.xml 37 | .idea/**/usage.statistics.xml 38 | .idea/**/dictionaries 39 | .idea/**/shelf 40 | 41 | # Generated files 42 | .idea/**/contentModel.xml 43 | 44 | # Sensitive or high-churn files 45 | .idea/**/dataSources/ 46 | .idea/**/dataSources.ids 47 | .idea/**/dataSources.local.xml 48 | .idea/**/sqlDataSources.xml 49 | .idea/**/dynamic.xml 50 | .idea/**/uiDesigner.xml 51 | .idea/**/dbnavigator.xml 52 | 53 | # Gradle 54 | .idea/**/gradle.xml 55 | .idea/**/libraries 56 | 57 | # Gradle and Maven with auto-import 58 | # When using Gradle or Maven with auto-import, you should exclude module files, 59 | # since they will be recreated, and may cause churn. Uncomment if using 60 | # auto-import. 61 | # .idea/artifacts 62 | # .idea/compiler.xml 63 | # .idea/jarRepositories.xml 64 | # .idea/modules.xml 65 | # .idea/*.iml 66 | # .idea/modules 67 | # *.iml 68 | # *.ipr 69 | 70 | # CMake 71 | cmake-build-*/ 72 | 73 | # Mongo Explorer plugin 74 | .idea/**/mongoSettings.xml 75 | 76 | # File-based project format 77 | *.iws 78 | 79 | # IntelliJ 80 | out/ 81 | 82 | # mpeltonen/sbt-idea plugin 83 | .idea_modules/ 84 | 85 | # JIRA plugin 86 | atlassian-ide-plugin.xml 87 | 88 | # Cursive Clojure plugin 89 | .idea/replstate.xml 90 | 91 | # Crashlytics plugin (for Android Studio and IntelliJ) 92 | com_crashlytics_export_strings.xml 93 | crashlytics.properties 94 | crashlytics-build.properties 95 | fabric.properties 96 | 97 | # Editor-based Rest Client 98 | .idea/httpRequests 99 | 100 | # Android studio 3.1+ serialized cache file 101 | .idea/caches/build_file_checksums.ser 102 | 103 | ### Python template 104 | # Byte-compiled / optimized / DLL files 105 | __pycache__/ 106 | *.py[cod] 107 | *$py.class 108 | 109 | # C extensions 110 | *.so 111 | 112 | # Distribution / packaging 113 | .Python 114 | build/ 115 | develop-eggs/ 116 | dist/ 117 | downloads/ 118 | eggs/ 119 | .eggs/ 120 | lib/ 121 | lib64/ 122 | parts/ 123 | sdist/ 124 | var/ 125 | wheels/ 126 | pip-wheel-metadata/ 127 | share/python-wheels/ 128 | *.egg-info/ 129 | .installed.cfg 130 | *.egg 131 | MANIFEST 132 | 133 | # PyInstaller 134 | # Usually these files are written by a python script from a template 135 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 136 | *.manifest 137 | *.spec 138 | 139 | # Installer logs 140 | pip-log.txt 141 | pip-delete-this-directory.txt 142 | 143 | # Unit test / coverage reports 144 | htmlcov/ 145 | .tox/ 146 | .nox/ 147 | .coverage 148 | .coverage.* 149 | .cache 150 | nosetests.xml 151 | coverage.xml 152 | *.cover 153 | *.py,cover 154 | .hypothesis/ 155 | .pytest_cache/ 156 | cover/ 157 | 158 | # Translations 159 | *.mo 160 | *.pot 161 | 162 | # Django stuff: 163 | *.log 164 | local_settings.py 165 | db.sqlite3 166 | db.sqlite3-journal 167 | 168 | # Flask stuff: 169 | instance/ 170 | .webassets-cache 171 | 172 | # Scrapy stuff: 173 | .scrapy 174 | 175 | # Sphinx documentation 176 | docs/_build/ 177 | 178 | # PyBuilder 179 | .pybuilder/ 180 | target/ 181 | 182 | # Jupyter Notebook 183 | .ipynb_checkpoints 184 | 185 | # IPython 186 | profile_default/ 187 | ipython_config.py 188 | 189 | # pyenv 190 | # For a library or package, you might want to ignore these files since the code is 191 | # intended to run in multiple environments; otherwise, check them in: 192 | # .python-version 193 | 194 | # pipenv 195 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 196 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 197 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 198 | # install all needed dependencies. 199 | #Pipfile.lock 200 | 201 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 202 | __pypackages__/ 203 | 204 | # Celery stuff 205 | celerybeat-schedule 206 | celerybeat.pid 207 | 208 | # SageMath parsed files 209 | *.sage.py 210 | 211 | # Environments 212 | .env 213 | .venv 214 | env/ 215 | venv/ 216 | ENV/ 217 | env.bak/ 218 | venv.bak/ 219 | 220 | # Spyder project settings 221 | .spyderproject 222 | .spyproject 223 | 224 | # Rope project settings 225 | .ropeproject 226 | 227 | # mkdocs documentation 228 | /site 229 | 230 | # mypy 231 | .mypy_cache/ 232 | .dmypy.json 233 | dmypy.json 234 | 235 | # Pyre type checker 236 | .pyre/ 237 | 238 | # pytype static type analyzer 239 | .pytype/ 240 | 241 | # Cython debug symbols 242 | cython_debug/ 243 | 244 | ### Example user template template 245 | ### Example user template 246 | 247 | # IntelliJ project files 248 | .idea 249 | *.iml 250 | out 251 | gen 252 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Install Requirement 2 | pip install -r requirements.txt 3 | 4 | ## Goal 5 | ```bash 6 | # check if there are Python syntax errors or undefined names 7 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 8 | # run testcase 9 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 10 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 11 | pytest 12 | ``` -------------------------------------------------------------------------------- /kcm/README.md: -------------------------------------------------------------------------------- 1 | # KCM Exercise 2 | Keyword Correlation Models from Open Corpus 3 | 4 | English Wiki Corpus 5 | ```text 6 | Raw Dump Size: 12 GB 7 | 5,174,418 articles 8 | 2 billons words 9 | Simple Correlation Model Size: 39 GB 10 | # of pairs in Correlation Model: 2,428,774,782 11 | ``` 12 | Chinese Wiki Corpus 13 | ```text 14 | Raw Dump Size: 1GB 15 | 890,973 articles 16 | 10 millon words 17 | Simple Correlation Model Size: 2 GB 18 | # of pairs in Correlation Model: 2,428,774,782 19 | ``` 20 | - Chinese Wiki Corpus Raw Text Preprocess (資料集的下載與前處理) 21 | https://github.com/NCHU-NLU-Lab/Wiki_Extractor 22 | - 使用 500,000 篇文章進行 KCM 練習 23 | 24 | ## Objective 25 | 26 | - Objective 1: 利用Docker佈建Ubuntu 27 | - Objective 2: 熟悉Linux環境, Command LINE Interface 28 | - Objective 3: 了解從大量文字資料中找出知識(Knowledge)的可行性 (在這個例子中,是找到關鍵字共同出現的關係) 29 | 30 | ## Review 31 | - 想想看結果是否與預期相同? 32 | - 如何把品質做得好? 33 | - 雜訊該怎麼處理? 34 | - 時間跑多久? 35 | - 如何加快模型的建立速度? 36 | - 如何加快查詢速度? 37 | -------------------------------------------------------------------------------- /kcm/ex_kcm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NCHU-NLP-Lab/nlp-tutorials/515871e2045c9f2c4459a10238d76d2ca3fd52a6/kcm/ex_kcm/__init__.py -------------------------------------------------------------------------------- /kcm/ex_kcm/kcm.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | def sim(input): 5 | """ 6 | Return 10 most related keywords from kcm result. 7 | Parameters 8 | ---------- 9 | input : string 10 | only keywords on kcm dict will return result, otherwise return empty list 11 | 12 | output : two-dimensional list 13 | The output is a two-dimensional list. 14 | The output contains 10 elements. 15 | Each element is a list with two items: 16 | The first item is related keyword and the second item is frequency. 17 | 18 | Example: 19 | [ [“高雄市”, 39], [“國民黨”, 39], [“候選人”, 24],……. ] 20 | 21 | """ 22 | pass 23 | -------------------------------------------------------------------------------- /kcm/test_kcm.py: -------------------------------------------------------------------------------- 1 | """Example unit tests for `kcm` 2 | 3 | Important 4 | ========= 5 | 6 | Do not modify the way in which the functions `sim` is imported. 7 | 8 | """ 9 | 10 | import unittest 11 | 12 | # code should be importable 13 | from ex_kcm.kcm import sim 14 | 15 | 16 | class TestKCM(unittest.TestCase): 17 | """Test the `kcm` function defined in `.hw.kcm`. 18 | 19 | """ 20 | 21 | test_word_list = ['臺灣', '蔡英文', '復仇者聯盟', '中興大學', '肺炎'] 22 | 23 | def test_empty(self): 24 | self.assertEqual(sim(''), []) 25 | 26 | def test_notinlist(self): 27 | self.assertEqual(sim('KNFER'), []) 28 | 29 | def test_existwords(self): 30 | for word in self.test_word_list: 31 | self.assertEqual(len(sim(word)), 10) 32 | 33 | def test_sort(self): 34 | # result should be sort by frequency in descending order 35 | for word in self.test_word_list: 36 | self.assertTrue(all(sim(word)[i][1] >= sim(word)[i + 1][1] for i in range(len(sim(word)) - 1))) 37 | 38 | def test_length(self): 39 | # one word result should be removed 40 | for word in self.test_word_list: 41 | self.assertTrue(len(sim(word)[0]) > 1) 42 | 43 | def test_example(self): 44 | example_list = ['配音', '香港', '大陸', '日本', '聲演', '日治', '中國大陸', '名稱', '傳統', '地域'] 45 | overlap = set([result[0] for result in sim('臺灣')]) & set(example_list) 46 | self.assertTrue(len(overlap) / 10 > 0.7) 47 | 48 | example_list = ['總統', '中華民國總統', '民進黨', '主席', '臺灣', '民主進步黨', '時任', '競選', '馬英九', '總統府'] 49 | overlap = set([result[0] for result in sim('蔡英文')]) & set(example_list) 50 | self.assertTrue(len(overlap) / 10 > 0.7) 51 | 52 | example_list = ['無限', '電影', '奧創', '紀元', '終局', '漫威', '英雄', '內戰', '美國隊長', '飾演'] 53 | overlap = set([result[0] for result in sim('復仇者聯盟')]) & set(example_list) 54 | self.assertTrue(len(overlap) / 10 > 0.7) 55 | 56 | example_list = ['大學', '教授', '臺灣', '畢業', '臺灣省立', '農學院', '法商學院', '研究所', '合併', '師範大學'] 57 | overlap = set([result[0] for result in sim('中興大學')]) & set(example_list) 58 | self.assertTrue(len(overlap) / 10 > 0.7) 59 | 60 | example_list = ['病例', '新冠', '冠狀病毒', '疫情', '傳染性', '報告', '感染', '武漢', '人數', '患者'] 61 | overlap = set([result[0] for result in sim('肺炎')]) & set(example_list) 62 | self.assertTrue(len(overlap) / 10 > 0.7) 63 | 64 | 65 | if __name__ == '__main__': 66 | unittest.main(verbosity=2) 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | pytest -------------------------------------------------------------------------------- /template/README.md: -------------------------------------------------------------------------------- 1 | # Example Python Homework 2 | 3 | From https://github.com/uwhpsc-2016/example-python-homework 4 | 5 | *An example homework assignment in Python for demonstrating how homework is done 6 | in the course.* 7 | 8 | In this (example) assignment we will learn some basic Python. In particular, we 9 | will learn about basic Python arithmetic, list handling, and how to create 10 | Python modules as well as an introduction to writing tests using Python's 11 | `unittest` module. 12 | 13 | ## Objective 14 | 15 | Provide definitions to the functions `square` and `cube` defined in the Python 16 | submodules `example_homework.squaring` and `example_homework.cubing`, 17 | respectively. 18 | 19 | ```python 20 | def square(x): 21 | # return the square of x 22 | 23 | def cube(x): 24 | # return the cube of x 25 | ``` 26 | 27 | The provided example test suite can be executed using 28 | 29 | ``` 30 | $ python ./test_example_homework.py 31 | ``` 32 | 33 | As always, you are welcome (and encouraged) to add your own tests to the test 34 | suite to ensure that your code is robust. Most of all, make sure that the 35 | supplied tests pass so that the grading software can import and use your 36 | function as expected. 37 | 38 | ## Grading 39 | 40 | When the homework deadline is reached your implementation of `square` and `cube` 41 | will be run against the following tests: 42 | 43 | * (1/5) the square and cube of zero is zero (already in the provided test suite) 44 | * (1/5) the square and cube of two is four and eight, respectively (already in 45 | the provided test suite) 46 | * (1/5) `square` and `cube` should behave appropriately with complex numbers as 47 | input (e.g. `1.0 - 2.0j`) 48 | * (2/5) these functions should also be able to act on Python `list`s of numbers 49 | without the use of the Python function `map`. that is, if the input is of type 50 | `list` then the output should also be of type `list` containing appropriate 51 | corresponding values 52 | 53 | It is important that the function names and locations DO NOT CHANGE. Otherwise, 54 | the test suite used for grading may not be able to import your code. If it's 55 | importable in the test suite provided to you it should be importable in the 56 | grading test suite. 57 | -------------------------------------------------------------------------------- /template/example_homework/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NCHU-NLP-Lab/nlp-tutorials/515871e2045c9f2c4459a10238d76d2ca3fd52a6/template/example_homework/__init__.py -------------------------------------------------------------------------------- /template/example_homework/cubing.py: -------------------------------------------------------------------------------- 1 | r""" 2 | cubing.py 3 | =========== 4 | Provides a definition for the "cube" function. The function prototype should 5 | be 6 | cube(x : type) -> type 7 | Solution 8 | -------- 9 | Do not modify the interface to `cube()`. This function can, however, call 10 | other functions. 11 | """ 12 | 13 | 14 | def cube(x): 15 | """ 16 | Return the cube of a `x`. 17 | Parameters 18 | ---------- 19 | x : number 20 | Any numerical type that can be cubed. 21 | Returns 22 | ------- 23 | value : number 24 | The square of `x`. 25 | """ 26 | if type(x) != list: 27 | return x * x * x 28 | else: 29 | return [cube(a) for a in x] 30 | -------------------------------------------------------------------------------- /template/example_homework/squaring.py: -------------------------------------------------------------------------------- 1 | r""" 2 | squaring.py 3 | =========== 4 | Provides a definition for the "square" function. The function prototype should 5 | be 6 | square(x : type) -> type 7 | Solution 8 | -------- 9 | Do not modify the interface to `square()`. This function can, however, call 10 | other functions. 11 | """ 12 | 13 | 14 | def square(x): 15 | """ 16 | Return the square of a `x`. 17 | Parameters 18 | ---------- 19 | x : number 20 | Any numerical type that can be squared. 21 | Returns 22 | ------- 23 | value : number 24 | The square of `x`. 25 | """ 26 | if type(x) != list: 27 | return x * x 28 | else: 29 | return [square(a) for a in x] 30 | -------------------------------------------------------------------------------- /template/test_example_homework.py: -------------------------------------------------------------------------------- 1 | """Example unit tests for `example_homework` 2 | 3 | Important 4 | ========= 5 | 6 | Do not modify the way in which the functions `square` and `cube` are imported. 7 | This will be exactly the way we import your code for use in grading. You are 8 | encouraged to add as many additional tests as you like. 9 | 10 | """ 11 | 12 | import unittest 13 | 14 | # code should be importable 15 | from example_homework.squaring import square 16 | from example_homework.cubing import cube 17 | 18 | 19 | class TestSquare(unittest.TestCase): 20 | """Test the `square` function defined in `example_homework.square`. 21 | 22 | Test the basic properties of the `square` function. In particular, how it 23 | behaves on the numbers zero and one as well as some example numbers, both 24 | positive and negative. 25 | 26 | """ 27 | 28 | def test_zero(self): 29 | self.assertEqual(square(0), 0) 30 | 31 | def test_one(self): 32 | self.assertEqual(square(1), 1) 33 | 34 | def test_two_three_four(self): 35 | self.assertEqual(square(2), 4) 36 | self.assertEqual(square(3), 9) 37 | self.assertEqual(square(4), 16) 38 | 39 | def test_negative(self): 40 | self.assertEqual(square(-1), 1) 41 | 42 | 43 | class TestCube(unittest.TestCase): 44 | """Test the `cube` function defined in `example_homework.cube`. 45 | 46 | Test the basic properties of the `cube` function. In particular, how it 47 | behaves on the numbers zero and one as well as some example numbers, both 48 | positive and negative. 49 | 50 | """ 51 | 52 | def test_zero(self): 53 | self.assertEqual(cube(0), 0) 54 | 55 | def test_one(self): 56 | self.assertEqual(cube(1), 1) 57 | 58 | def test_two_three_four(self): 59 | self.assertEqual(cube(2), 8) 60 | self.assertEqual(cube(3), 27) 61 | self.assertEqual(cube(4), 64) 62 | 63 | def test_negative(self): 64 | self.assertEqual(cube(-1), -1) 65 | 66 | 67 | if __name__ == '__main__': 68 | unittest.main(verbosity=2) 69 | -------------------------------------------------------------------------------- /w2v/README.md: -------------------------------------------------------------------------------- 1 | # W2V Exercise 2 | Keyword Embedding Model 3 | 4 | Keyword Embedding Technique takes as its input a large corpus of text and produces a high-dimensional space, with each unique word in the corpus being assigned a corresponding vector in the space. 5 | Word vectors are positioned in the vector space such that words that share common contexts in the corpus are located in close proximity to one another in the space 6 | Different View Points for the Same Corpus 7 | Example: 8 | ```text 9 | D1: 教練 評論 王建民 今日 投球 內容 10 | D2: 教練 評論 郭泓志 今日 投球 表現 11 | D3: 教練 談及 林書豪 今日 精彩 表現 12 | ``` 13 | 14 | ## Objective 15 | 16 | - Objective 1: 熟悉Linux環境, Command LINE Interface 17 | - Objective 2: 了解從大量文字資料中找出知識(Knowledge)的可行性 (在這個例子中,是找到關鍵字同位詞的關係) 18 | -------------------------------------------------------------------------------- /w2v/ex_w2v/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NCHU-NLP-Lab/nlp-tutorials/515871e2045c9f2c4459a10238d76d2ca3fd52a6/w2v/ex_w2v/__init__.py -------------------------------------------------------------------------------- /w2v/ex_w2v/w2v.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | def sim(input): 5 | """ 6 | Return 10 most related keywords from word2vec result. 7 | Parameters 8 | ---------- 9 | input : string 10 | only keywords on word2vec dict will return result, otherwise return empty list 11 | Returns 12 | ------- 13 | value : list of dict, result keywords as keys, and its frequency will be value. 14 | """ 15 | pass 16 | -------------------------------------------------------------------------------- /w2v/test_w2v.py: -------------------------------------------------------------------------------- 1 | """Example unit tests for `w2v` 2 | 3 | Important 4 | ========= 5 | 6 | Do not modify the way in which the functions `sim` is imported. 7 | 8 | """ 9 | 10 | import unittest 11 | 12 | # code should be importable 13 | from ex_w2v.w2v import sim 14 | 15 | 16 | class TestKCM(unittest.TestCase): 17 | """Test the `word2vec` function defined in `.hw.word2vec`. 18 | """ 19 | 20 | test_word_list = ['臺灣', '蔡英文', '復仇者聯盟', '中興大學', '肺炎'] 21 | 22 | def test_empty(self): 23 | self.assertEqual(sim(''), []) 24 | 25 | def test_notinlist(self): 26 | self.assertEqual(sim('KNFER'), []) 27 | 28 | def test_existwords(self): 29 | for word in self.test_word_list: 30 | self.assertEqual(len(sim(word)), 10) 31 | 32 | def test_isprob(self): 33 | for word in self.test_word_list: 34 | [self.assertTrue(0 <= float(sim(word)[i][1]) <= 1) for i in range(len(sim(word)))] 35 | 36 | def test_sort(self): 37 | # result should be sort by frequency in descending order 38 | for word in self.test_word_list: 39 | self.assertTrue(all(sim(word)[i][1] >= sim(word)[i + 1][1] for i in range(len(sim(word)) - 1))) 40 | 41 | def test_length(self): 42 | # one word result should be removed 43 | for word in self.test_word_list: 44 | self.assertTrue(len(sim(word)[0]) > 1) 45 | 46 | def test_example(self): 47 | example_list = ['臺灣地區', '臺北', '南臺灣', '臺灣人', '高雄', '宜蘭', '臺南', '臺中', '新竹', '全臺'] 48 | overlap = set([result[0] for result in sim('臺灣')]) & set(example_list) 49 | self.assertTrue(len(overlap) / 10 >= 0.5) 50 | 51 | example_list = ['馬英九', '陳水扁', '李登輝', '蘇貞昌', '韓國瑜', '柯文哲', '吳敦義', '宋楚瑜', '賴清德', '謝長廷'] 52 | overlap = set([result[0] for result in sim('蔡英文')]) & set(example_list) 53 | self.assertTrue(len(overlap) / 10 >= 0.5) 54 | 55 | example_list = ['鋼鐵人', 'X戰警', '正義聯盟', '奧創', '美國隊長', '驚奇隊長', '戰警', '蜘蛛人', '水行俠', '奇異博士'] 56 | overlap = set([result[0] for result in sim('復仇者聯盟')]) & set(example_list) 57 | self.assertTrue(len(overlap) / 10 >= 0.5) 58 | 59 | example_list = ['國立中興大學', '國立成功大學', '逢甲大學', '國立陽明大學', '國立中正大學', '國立屏東科技大學', '東海大學', '國立清華大學', '國立宜蘭大學', '靜宜大學'] 60 | overlap = set([result[0] for result in sim('中興大學')]) & set(example_list) 61 | self.assertTrue(len(overlap) / 10 >= 0.5) 62 | 63 | example_list = ['COVID', '流感', '疫情', '冠狀病毒', '傳染性', '非典型肺炎', '新冠', '流行性感冒', '吸入性', 'Covid'] 64 | overlap = set([result[0] for result in sim('肺炎')]) & set(example_list) 65 | self.assertTrue(len(overlap) / 10 >= 0.5) 66 | 67 | if __name__ == '__main__': 68 | unittest.main(verbosity=2) 69 | -------------------------------------------------------------------------------- /w2v_classifier/README.md: -------------------------------------------------------------------------------- 1 | # W2V Classifier Exercise 2 | Use word2vec to train a sentiment classifier. 3 | 4 | ## Objective 5 | 6 | - Objective 1: 對神經網路可以有基本認識和了解 7 | - Objective 2: 知道怎麼用Pytorch搭建一個的模型 8 | - Objective 2: 了解怎麼樣用word2vec做情緒分類器 9 | -------------------------------------------------------------------------------- /w2v_classifier/ex_classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NCHU-NLP-Lab/nlp-tutorials/515871e2045c9f2c4459a10238d76d2ca3fd52a6/w2v_classifier/ex_classifier/__init__.py -------------------------------------------------------------------------------- /w2v_classifier/ex_classifier/a_sent2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def sent2vec(sent=''): 6 | """ 7 | Parameters 8 | ---------- 9 | sent : string 10 | sentence 11 | Returns 12 | ------- 13 | sentence_vector : torch.FloatTensor 14 | sentence vector from word vector, formatting in torch.tensor 15 | """ 16 | return torch.tensor([1.0, 2.0, 3.0]) 17 | -------------------------------------------------------------------------------- /w2v_classifier/ex_classifier/b_dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | 3 | from .a_sent2vec import sent2vec 4 | 5 | 6 | class SentDataloader(data.Dataset): 7 | def __init__(self, fpath): 8 | 9 | # Initialize path and transform dataset 10 | sample = [] 11 | input = sent2vec('') 12 | target = 1 13 | sample.append([input, target]) 14 | sample.append([input, target]) 15 | sample.append([input, target]) 16 | self.sample = sample 17 | 18 | def __getitem__(self, idx): 19 | 20 | # Return the data (e.g. sentence_vec and label) 21 | return self.sample[idx] 22 | 23 | def __len__(self): 24 | 25 | # Indicate the total size of the dataset 26 | return len(self.sample) -------------------------------------------------------------------------------- /w2v_classifier/ex_classifier/c_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SentClassifier(nn.Module): 6 | def __init__(self, input_dim, n_classes): 7 | super(SentClassifier, self).__init__() 8 | 9 | # dimensionalities 10 | self.input_dim = input_dim 11 | self.n_classes = n_classes 12 | self.hidden_dim = 500 13 | 14 | # creates a MLP 15 | self.classifier = nn.Sequential( 16 | nn.Linear(self.input_dim, self.hidden_dim), 17 | nn.Tanh(), 18 | nn.Linear(self.hidden_dim, self.n_classes)) 19 | 20 | def forward(self, sentence_vec, target=None): 21 | 22 | predicted = torch.tensor([[0.9, 0.1]], requires_grad=True) 23 | predicted_value, predicted_class = torch.max(predicted, 1) 24 | 25 | if target is not None: 26 | criterion = nn.CrossEntropyLoss() 27 | loss = criterion(predicted, torch.tensor([0])) 28 | return predicted_class, loss 29 | else: 30 | return predicted_class 31 | -------------------------------------------------------------------------------- /w2v_classifier/ex_classifier/d_predict.py: -------------------------------------------------------------------------------- 1 | def sent_predictor(input=''): 2 | """ 3 | Parameters 4 | ---------- 5 | input : string 6 | sentence 7 | Returns 8 | ------- 9 | Classifier prediction : string 10 | The positive or negative predicted by the classifier 11 | """ 12 | if "input is positive": 13 | return 'positive' 14 | else: 15 | return 'negative' 16 | -------------------------------------------------------------------------------- /w2v_classifier/test_classifier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from torch import optim 3 | from torch.utils import data 4 | 5 | from ex_classifier.a_sent2vec import * 6 | from ex_classifier.b_dataloader import SentDataloader 7 | from ex_classifier.c_model import SentClassifier 8 | from ex_classifier.d_predict import sent_predictor 9 | 10 | 11 | class TestSent2Vec(unittest.TestCase): 12 | 13 | def test_zero(self): 14 | self.assertEqual(sent2vec(''), []) 15 | 16 | def test_sim(self): 17 | self.assertEqual(sent2vec('台灣').shape, sent2vec('中國').shape) 18 | 19 | def test_shape(self): 20 | self.assertEqual(sent2vec('台灣').shape[0], 250) 21 | 22 | 23 | class TestDataloader(unittest.TestCase): 24 | 25 | def test_workable(self): 26 | batch = 2 27 | dataloader = SentDataloader('udicstm_for_dataloader.csv') 28 | dl = data.DataLoader(dataset=dataloader, 29 | batch_size=batch, 30 | shuffle=True) 31 | for d in dl: 32 | input_vec, target = d 33 | print(input_vec, target) 34 | 35 | 36 | class TestModel(unittest.TestCase): 37 | 38 | def test_workable(self): 39 | batch = 2 40 | dataloader = SentDataloader('udicstm_for_dataloader.csv') 41 | dl = data.DataLoader(dataset=dataloader, 42 | batch_size=batch, 43 | shuffle=True) 44 | classifier = SentClassifier(250, 2) 45 | for d in dl: 46 | input_vec, target = d 47 | preducted, loss = classifier.forward(input_vec, target) 48 | print(preducted) 49 | print(loss) 50 | 51 | 52 | class Overall(unittest.TestCase): 53 | 54 | def test_workable(self): 55 | batch = 20 56 | epoch = 10 57 | dataloader = SentDataloader('udicstm_for_dataloader.csv') 58 | classifier = SentClassifier(250, 2) 59 | optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 60 | for ep in range(epoch): 61 | dl = data.DataLoader(dataset=dataloader, 62 | batch_size=batch, 63 | shuffle=True) 64 | total_loss = 0 65 | for d in dl: 66 | input_vec, target = d 67 | optimizer.zero_grad() 68 | preducted, loss = classifier.forward(input_vec, target) 69 | loss.backward() 70 | optimizer.step() 71 | total_loss += loss 72 | print(total_loss / len(dl)) 73 | 74 | 75 | class Predict(unittest.TestCase): 76 | def test_function(self): 77 | self.assertTrue(isinstance(sent_predictor('這餐廳的送餐速度很快,服務也很好'), str)) 78 | 79 | def test_example(self): 80 | self.assertTrue(sent_predictor('這餐廳的送餐速度很快,服務也很好') == 'positive') 81 | self.assertTrue(sent_predictor('這飯店服務很糟糕') == 'negative') 82 | 83 | if __name__ == '__main__': 84 | unittest.main(verbosity=2) 85 | -------------------------------------------------------------------------------- /w2v_classifier/udicstm_for_dataloader.csv: -------------------------------------------------------------------------------- 1 | 價 錢 還 行 用 着 不 錯 自 提 的 MM 好 漂 亮 服 務 態 度 好 機 器 沒 問 題,positive 2 | 2 房 間 惡 差 我 住 的 房 間 補 補 貼 中 國 八 十 年 代 的 建 築 我 不 介 意 可 你 用 白 漆 塗 塗 吧 我 房 間 牆 是 二 個 顏 色 頂 是 三 個 顏 色 當 然 不 是 配 色 而 是 補 補 貼 貼 的 東 西,negative 3 | 房 間 和 走 廊 裏 味 道 很 重 樓 外 下 面 馬 路 上 的 聲 音 很 吵 早 上 隔 壁 的 電 視 機 的 聲 音 聽 的 很 清 楚 可 以 當 鬧 鐘 使 用 但 時 間 是 隔 壁 定 的 沒 有 給 鬧 鐘 定 時 結 帳 的 時 候 發 票 沒 有 給 我 蓋 章 總 的 感 覺 是 讓 人 失 望 今 後 不 會 再 住 了,negative 4 | 補 充 點 評 2007 年 11 月 23 日 不 好 意 思 評 錯 酒 店 了 正 確 的 應 該 是 房 間 很 大 很 好 缺 點 是 地 點 比 較 偏,positive 5 | 交 通 很 便 利 從 酒 店 步 行 5 分 鐘 就 到 了 最 繁 華 的 開 封 的 小 吃 一 條 街,positive 6 | 優 點 我 不 說 了 說 說 發 現 的 問 題 1 即 使 硬 件 是 DDR 3 由 於 CPU 總 線 只 有 800 所 以 內 存 降 頻 至 800 使 用 2 請 教 各 位 了 調 節 集 成 顯 卡 顯 存 的 選 項 在 哪 裏,positive 7 | 交 通 不 方 便 沒 什 麼 車 隔 音 很 差 雖 然 是 新 開 的 但 是 同 價 位 還 是 選 擇 漢 庭 和 莫 泰 吧 房 間 感 覺 就 像 是 紙 糊 的 浴 室 太 差 了,negative 8 | 總 之 下 回 不 會 在 選 擇 了 還 不 如 花 少 錢 住 三 亞 灣 的 小 型 度 假 村 或 花 同 樣 的 錢 住 環 境 好 的 亞 龍 灣,negative 9 | 掛 的 是 三 星 的 牌 牌 但 室 內 用 具 環 境 及 性 價 比 還 是 差 太 多 現 在 可 能 是 哈 市 的 旅 遊 旺 季 因 爲 是 冰 雪 的 季 節 嗎 房 間 內 有 股 汗 味 讓 人 實 在 受 不 了,negative 10 | 電 池 時 間 有 點 短 們 也 就 兩 小 時 不 到 面 板 易 髒 前 幾 天 我 裝 了 個 酷 狗 音 樂 沒 把 酷 狗 的 什 麼 快 捷 鍵 關 掉 結 果 運 行 的 時 候 不 小 心 按 了 個 右 鍵 導 致 死 機 了,negative 11 | 不 是 鏡 面 屏 風 扇 從 開 機 就 沒 停 聲 音 比 較 大 稍 微 重 了 點,negative 12 | 散 熱 很 好 風 扇 吹 出 來 的 風 始 終 涼 涼 的,positive 13 | 如 果 題 目 改 成 高 盛 歷 史 更 爲 貼 切 文 本 內 容 因 爲 我 按 照 這 個 題 目 一 直 想 找 出 高 盛 究 竟 是 用 什 麼 技 巧 用 什 麼 深 遠 戰 略 贏 得 了 先 今 的 主 導 權 而 本 文 只 是 在 最 後 篇 幅 蜻 蜓 點 水 般 一 帶 而 過 沒 有 預 期 效 果 但 本 書 描 述 了 美 國 投 資 業 及 高 盛 的 發 展 歷 史 和 其 不 同 的 輝 煌 落 敗 還 是 值 得 一 讀 的 只 是 想 勸 誡 當 今 文 壇 作 者 起 書 名 時 一 定 要 對 得 起 觀 衆 而 不 能 用 名 字 先 把 人 給 吸 引 過 來 名 字 要 先 入 爲 主 但 內 容 更 應 名 副 其 實,positive 14 | 許 多 年 前 同 事 爲 養 育 孩 子 就 看 了 這 本 書 多 少 瞭 解 這 本 書 的 一 些 觀 點 這 次 爲 了 養 育 自 己 的 孩 子 買 來 細 細 讀 了 感 覺 還 是 很 有 收 穫 一 個 成 功 的 教 育 個 案 往 往 比 一 些 空 洞 的 教 育 理 論 對 家 長 來 說 更 有 意 義 書 中 的 教 育 觀 點 落 於 實 處 非 常 具 體 很 有 參 考 價 值 基 本 上 同 意 老 卡 爾 威 特 的 教 育 理 念 除 了 個 別 的 一 些 觀 念 一 些 具 體 的 應 對 方 式 和 教 育 方 法 對 於 家 長 來 說 值 得 借 鑑 推 薦 一 下,positive 15 | 硬 盤 到 手 就 發 現 一 個 壞 塊 因 爲 是 完 美 屏 沒 回 京 東 換 新 花 了 兩 天 在 本 地 換 新 硬 盤 發 票 都 不 需 要 電 池 銜 接 很 鬆 可 有 1 mm 間 隙 出 廠 時 A B 面 貼 的 保 護 膜 太 敷 衍 太 多 氣 泡 雖 然 反 正 要 撕 掉 但 說 明 廠 家 態 度 不 嚴 謹,negative 16 | 房 間 裝 修 陳 舊 下 水 管 堵 塞 晚 上 折 騰 了 2 個 多 小 時 還 是 沒 有 修 好,negative 17 | 總 的 來 說 是 不 錯 的 酒 店 但 還 有 一 些 小 細 節 要 注 意 電 腦 寬 帶 接 線 處 的 板 已 壞 早 餐 送 房 後 在 蛋 餅 裏 發 現 鍋 底 灰 經 交 涉 後 店 方 馬 上 來 電 致 歉 這 點 非 常 滿 意 但 希 望 在 下 次 入 住 不 要 再 發 生 此 類 事 件 下 次 去 成 都 還 會 繼 續 入 住 喜 來 登,positive 18 | 沒 我 想 像 的 那 麼 好 還 行 吧 我 兒 子 看 着 還 行,negative 19 | 幼 兒 園 的 老 師 更 應 該 看 孩 子 大 多 時 間 還 是 在 幼 兒 園 在 幼 兒 園 裏 經 歷 的 一 些 事 可 能 會 影 響 孩 子 的 性 格 甚 至 一 生 如 果 幼 教 都 能 領 悟 蒙 特 梭 利 的 育 兒 法 真 正 讓 孩 子 做 主 真 正 愛 孩 子 對 孩 子 有 極 大 的 耐 心 我 相 信 我 們 的 小 苗 苗 都 能 長 成 參 天 大 樹 的 我 們 明 天 的 棟 樑 會 讓 我 們 中 國 更 加 強 大,positive 20 | 現 在 在 用 頂 多 是 一 般 如 果 買 的 話 推 薦 聯 想 那 款,positive 21 | -------------------------------------------------------------------------------- /w2v_gru_classifier/README.md: -------------------------------------------------------------------------------- 1 | # W2V GRU Classifier Exercise 2 | Use word2vec and GRU to train a Taipei_FAQ classifier. 3 | 4 | ## Objective 5 | 6 | - Objective 1: 對RNN類的神經網路可以有基本認識和了解 7 | - Objective 2: 思考改善分類器的方式 8 | - Objective 2: 了解怎麼樣用lstm和word2vec做情緒分類器 9 | -------------------------------------------------------------------------------- /w2v_gru_classifier/Taipei_FAQ_for_dataloader.csv: -------------------------------------------------------------------------------- 1 | 臺北市專為高齡者開辦的課程或活動有哪些?可否有網站資料直接查詢?,臺北市政府教育局終身教育科 2 | 如何查詢藝文推廣處城市舞台檔期?,臺北市藝文推廣處 3 | 個案若安置於機構,是否能使用失能身心障礙日間照顧中心服務?,臺北市政府社會局身心障礙者福利科 4 | 申請低收入戶18歲以上就學生活補助洽辦單位、應備文件、補助資格及補助內容?,臺北市政府社會局社會救助科 5 | 如何查詢訴願案件辦理進度?,臺北市政府法務局 6 | 騎車時,若遇到車輛爆胎等故障情形,該怎麼辦?,臺北市政府交通局 7 | 有關私立高中學雜費補助,孩子與爺爺同住,父母和孩子不同戶籍,請問這樣就不能申請私立高中學雜費補助?,臺北市政府教育局中等教育科 8 | 什麼是茲卡病毒感染症?,臺北市政府衛生局疾病管制科 9 | 請問如何成為衛生保健志工?,臺北市政府衛生局健康管理科 10 | 公務人員能否擔任公司股東或公司之董事或監察人,臺北市商業處 11 | 重型機車(紅牌、黃牌)可以停放汽車停車格嗎?,臺北市停車管理工程處 12 | 106年度臺北市各有線電視業者收視費用是多少?,臺北市政府觀光傳播局 13 | 是否有暑期工讀機會?每年何時辦理?,臺北市就業服務處 14 | 本處稽查人員執法不公、選擇性開單?,臺北市停車管理工程處 15 | 本人或家人持有身心障礙手冊或證明,如何辦理使用牌照稅免稅優惠?,臺北市稅捐稽徵處機會稅科 16 | 契稅如何計算?,臺北市稅捐稽徵處財產稅科 17 | ●軍人權益:「收到教育召集令如何請假?」,臺北市政府兵役局 18 | 什麼是新一代學生悠遊卡?,臺北市公共運輸處綜合規劃科 19 | 是否可以自行製作環保兩用袋?,臺北市政府環境保護局資源循環管理科 20 | 小型車後座繫安全帶法令介紹,臺北市政府交通局 -------------------------------------------------------------------------------- /w2v_gru_classifier/ex_classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NCHU-NLP-Lab/nlp-tutorials/515871e2045c9f2c4459a10238d76d2ca3fd52a6/w2v_gru_classifier/ex_classifier/__init__.py -------------------------------------------------------------------------------- /w2v_gru_classifier/ex_classifier/a_sent2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def sent2vec(sent=''): 6 | """ 7 | Parameters 8 | ---------- 9 | sent : string 10 | sentence 11 | Returns 12 | ------- 13 | sentence_vector : torch.FloatTensor 14 | sentence vector from word vector, formatting in torch.tensor 15 | example 16 | ------- 17 | I love you. 18 | => 19 | tensor([ 20 | [1.0765e-01, -3.6939e+00, 1.2139e+00, -1.0561e+00, -2.0084e+00, # "I" vector 21 | -1.4055e+00, -9.0298e-01, -2.3618e-01, 1.5151e+00, -1.2158e-01, 22 | 2.3321e+00, -5.7944e-01, -2.2252e-01, ...], 23 | [-6.3879e-01, -1.7294e+00, 1.1637e-01, -1.0025e+00, -6.6298e-01, # "love" vector 24 | -1.6146e+00, -1.1563e+00, -1.4284e+00, 1.1772e+00, -1.4051e+00, 25 | -5.2077e-01, -4.0171e-01, -1.9743e-01, ...], 26 | [4.7850e-01, -1.4013e+00, -7.7003e-01, -9.6428e-01, -6.0314e-01, # "you" vector 27 | 1.7834e-01, 6.1909e-02, -2.0041e-01, 4.4003e-01, 5.2138e-01, 28 | -2.2191e-01, -2.6324e-02, -1.1932e+00, ...] 29 | ]) 30 | => 31 | torch.Size([3,250]) #[keywords num, word2vec dim] 32 | """ 33 | inputs = torch.tensor([[1.0, 2.0, 3.0]]) 34 | inputs = torch.cat(inputs) # torch.cat 合併向量 35 | inputs = torch.view(len(inputs), 250) #torch.view 依指定數字做組合 36 | return inputs 37 | -------------------------------------------------------------------------------- /w2v_gru_classifier/ex_classifier/b_dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from .a_sent2vec import sent2vec 3 | from sklearn.preprocessing import LabelEncoder 4 | import pickle 5 | import csv 6 | 7 | class SentDataloader(data.Dataset): 8 | def __init__(self, fpath): 9 | sample = [] 10 | maxlen = 0 11 | 12 | # 使用 LabelEncoder 對 label 做 encoder 13 | le = LabelEncoder() 14 | le_target = le.fit_transform('answers list') 15 | 16 | 17 | 18 | #將 question 與 target 合併成一筆資料 19 | for q, target in zip('questions list', 'le_target'): 20 | input = sent2vec('') 21 | sample.append([input, target]) 22 | 23 | 24 | 25 | # 將資料向量補齊至最大長度,使資料維度都一樣 26 | zero_vetor = torch.zeros(1,250) 27 | for ele in sample: 28 | while ele[0].size(0) < maxlen: 29 | ele[0] = torch.cat((ele[0], zero_vetor), 0) 30 | 31 | self.sample = sample 32 | self.maxlen = maxlen 33 | 34 | def __len__(self): 35 | return len(self.sample) 36 | 37 | def __getitem__(self, idx): 38 | return self.sample[idx] 39 | -------------------------------------------------------------------------------- /w2v_gru_classifier/ex_classifier/c_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SentClassifier(nn.Module): 6 | def __init__(self, input_dim, n_classes): 7 | super(SentClassifier, self).__init__() 8 | 9 | # dimensionalities 10 | self.input_dim = input_dim 11 | self.n_classes = n_classes 12 | self.hidden_dim = 500 13 | 14 | # creates a MLP 15 | self.gru_s2v = nn.GRU(self.input_dim, self.hidden_dim, batch_first=True) 16 | self.acf = nn.ReLU() 17 | self.classifier = nn.Linear(self.hidden_dim, self.n_classes) 18 | 19 | def forward(self, sentence_vec, target=None): 20 | 21 | h0 = torch.zeros(1, 100, 100) # 初始化 hidden state (num_layers (GRU 層數), batch_size, hidden_dim) 22 | 23 | gru_output, hidden = self.gru_s2v(input, h0) # 使用 gru 產生句向量 24 | 25 | gru_output = self.acf() # 將 gru_output 最後一層向量經過 acf 26 | 27 | predicted = self.classifier() # 丟入 linear layer 做分類 28 | 29 | predicted_value, predicted_class = torch.max(predicted, 1) 30 | 31 | if target is not None: 32 | criterion = nn.CrossEntropyLoss() 33 | loss = criterion(predicted, torch.tensor([0])) # predicted 與 target 做 loss 計算 34 | return predicted_class, loss 35 | else: 36 | return predicted_class -------------------------------------------------------------------------------- /w2v_gru_classifier/ex_classifier/d_predict.py: -------------------------------------------------------------------------------- 1 | from .a_sent2vec import sent2vec 2 | from .c_model import SentClassifier 3 | import torch 4 | from sklearn.preprocessing import LabelEncoder 5 | import pickle 6 | 7 | # load 模型 8 | # load LabelEncoder 規則 9 | 10 | def sent_predictor(input=''): 11 | 12 | # 將輸入轉成 input 格式丟入模型,並使用 LabelEncoder 規則將預測結果轉乘類別名稱 13 | 14 | return 'predicted_class' 15 | -------------------------------------------------------------------------------- /w2v_gru_classifier/test_classifier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from torch import optim 3 | from torch.utils import data 4 | 5 | from ex_classifier.a_sent2vec import * 6 | from ex_classifier.b_dataloader import SentDataloader 7 | from ex_classifier.c_model import SentClassifier 8 | 9 | from ex_classifier.d_predict import sent_predictor 10 | 11 | 12 | class TestSent2Vec(unittest.TestCase): 13 | 14 | def test_zero(self): 15 | self.assertEqual(sent2vec(''), []) 16 | 17 | def test_sim(self): 18 | self.assertEqual(sent2vec('台灣').shape, sent2vec('中國').shape) 19 | 20 | def test_shape(self): 21 | self.assertEqual(sent2vec('台灣').shape[1], 250) 22 | 23 | 24 | class TestDataloader(unittest.TestCase): 25 | 26 | def test_workable(self): 27 | batch = 2 28 | dataloader = SentDataloader('Taipei_FAQ_for_dataloader.csv') 29 | dl = data.DataLoader(dataset=dataloader, 30 | batch_size=batch, 31 | shuffle=True) 32 | for d in dl: 33 | input_vec, target = d 34 | print(input_vec, target) 35 | 36 | 37 | class TestModel(unittest.TestCase): 38 | 39 | def test_workable(self): 40 | batch = 2 41 | dataloader = SentDataloader('Taipei_FAQ_for_dataloader.csv') 42 | dl = data.DataLoader(dataset=dataloader, 43 | batch_size=batch, 44 | shuffle=True) 45 | classifier = SentClassifier(250, 78) 46 | for d in dl: 47 | input_vec, target = d 48 | preducted, loss = classifier.forward(input_vec, target) 49 | print(preducted) 50 | print(loss) 51 | 52 | 53 | class Overall(unittest.TestCase): 54 | 55 | def test_workable(self): 56 | batch = 20 57 | epoch = 10 58 | dataloader = SentDataloader('Taipei_FAQ_for_dataloader.csv') 59 | classifier = SentClassifier(250, 78) 60 | # optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 61 | optimizer = optim.Adam(classifier.parameters(), lr=0.001) 62 | for ep in range(epoch): 63 | dl = data.DataLoader(dataset=dataloader, 64 | batch_size=batch, 65 | shuffle=True) 66 | total_loss = 0 67 | for d in dl: 68 | input_vec, target = d 69 | optimizer.zero_grad() 70 | preducted, loss = classifier.forward(input_vec, target) 71 | loss.backward() 72 | optimizer.step() 73 | total_loss += loss 74 | print(total_loss / len(dl)) 75 | 76 | 77 | class Predict(unittest.TestCase): 78 | def test_function(self): 79 | self.assertTrue(isinstance(sent_predictor('如何查詢藝文推廣處城市舞台檔期?'), str)) 80 | 81 | def test_example(self): 82 | self.assertTrue(sent_predictor('如何查詢藝文推廣處城市舞台檔期?') == '臺北市藝文推廣處') 83 | self.assertTrue(sent_predictor('如果感染了登革熱該怎麼辦?') == '臺北市政府衛生局疾病管制科') 84 | 85 | 86 | if __name__ == '__main__': 87 | unittest.main(verbosity=2) 88 | --------------------------------------------------------------------------------