├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── data └── natsume.txt ├── deepdialog ├── __init__.py ├── datasource │ ├── __init__.py │ └── aozora_converter.ipynb ├── rnnlm │ ├── __init__.py │ └── rnnlm.ipynb └── transformer │ ├── README.md │ ├── __init__.py │ ├── attention.py │ ├── common_layer.py │ ├── embedding.py │ ├── metrics.py │ ├── preprocess │ ├── __init__.py │ ├── batch_generator.py │ ├── create_tokenizer.sh │ ├── spm_natsume.model │ └── spm_natsume.vocab │ ├── training.ipynb │ └── transformer.py ├── test ├── __init__.py ├── deepdialog │ ├── __init__.py │ └── transformer │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── common_layer.py │ │ ├── embedding.py │ │ ├── metrics.py │ │ ├── preprocess │ │ ├── __init__.py │ │ └── batch_generator.py │ │ └── transformer.py ├── run └── run.py ├── tmp └── .gitkeep └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/* 2 | !.gitkeep 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Harumitsu Nobuta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [scripts] 7 | test = "python -m test.run" 8 | lint = "flake8" 9 | mypy = "mypy --ignore-missing" 10 | 11 | [packages] 12 | jupyterlab = "*" 13 | tensorflow-gpu = "==1.12.0" 14 | tqdm = "*" 15 | ipython = "*" 16 | "flake8" = "*" 17 | mypy = "*" 18 | sentencepiece = "*" 19 | 20 | [dev-packages] 21 | 22 | [requires] 23 | python_version = "3.6" 24 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "c32de2ac2b1539be81d6d4f82c242dcb56136632db4d7a45bf5ba48da327dfb9" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.6" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "absl-py": { 20 | "hashes": [ 21 | "sha256:87519e3b91a3d573664c6e2ee33df582bb68dca6642ae3cf3a4361b1c0a4e9d6" 22 | ], 23 | "version": "==0.6.1" 24 | }, 25 | "astor": { 26 | "hashes": [ 27 | "sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d", 28 | "sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e" 29 | ], 30 | "version": "==0.7.1" 31 | }, 32 | "backcall": { 33 | "hashes": [ 34 | "sha256:38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", 35 | "sha256:bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2" 36 | ], 37 | "version": "==0.1.0" 38 | }, 39 | "bleach": { 40 | "hashes": [ 41 | "sha256:48d39675b80a75f6d1c3bdbffec791cf0bbbab665cf01e20da701c77de278718", 42 | "sha256:73d26f018af5d5adcdabf5c1c974add4361a9c76af215fe32fdec8a6fc5fb9b9" 43 | ], 44 | "version": "==3.0.2" 45 | }, 46 | "decorator": { 47 | "hashes": [ 48 | "sha256:2c51dff8ef3c447388fe5e4453d24a2bf128d3a4c32af3fabef1f01c6851ab82", 49 | "sha256:c39efa13fbdeb4506c476c9b3babf6a718da943dab7811c206005a4a956c080c" 50 | ], 51 | "version": "==4.3.0" 52 | }, 53 | "defusedxml": { 54 | "hashes": [ 55 | "sha256:24d7f2f94f7f3cb6061acb215685e5125fbcdc40a857eff9de22518820b0a4f4", 56 | "sha256:702a91ade2968a82beb0db1e0766a6a273f33d4616a6ce8cde475d8e09853b20" 57 | ], 58 | "version": "==0.5.0" 59 | }, 60 | "entrypoints": { 61 | "hashes": [ 62 | "sha256:10ad569bb245e7e2ba425285b9fa3e8178a0dc92fc53b1e1c553805e15a8825b", 63 | "sha256:d2d587dde06f99545fb13a383d2cd336a8ff1f359c5839ce3a64c917d10c029f" 64 | ], 65 | "version": "==0.2.3" 66 | }, 67 | "flake8": { 68 | "hashes": [ 69 | "sha256:6a35f5b8761f45c5513e3405f110a86bea57982c3b75b766ce7b65217abe1670", 70 | "sha256:c01f8a3963b3571a8e6bd7a4063359aff90749e160778e03817cd9b71c9e07d2" 71 | ], 72 | "index": "pypi", 73 | "version": "==3.6.0" 74 | }, 75 | "gast": { 76 | "hashes": [ 77 | 
"sha256:7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930" 78 | ], 79 | "version": "==0.2.0" 80 | }, 81 | "grpcio": { 82 | "hashes": [ 83 | "sha256:09a37a0121215740fb39c5e20afa2d40a4731799d8b7dc61d854a75c9299b6a2", 84 | "sha256:2d06536bcfb6c83e9074b86ce88ae7b2faedfaac5e8cb261b8edb47b3b14dc59", 85 | "sha256:32d967ed8e09c1ef71126fbbe0d2f3bc8d1f7500ae6065b4dd271e0822b83650", 86 | "sha256:41bf456665d666d0a07cb1b363626745dc1a98b02395acd666b36c929133a72a", 87 | "sha256:446b40a37250acb2555962df900e0ce7d5c7e8b9c5353bd786b75a389d2b67a7", 88 | "sha256:5315dfc229223c6e18134e1ad1bec1306372ab74ef4ccee4eebdf4f6e8e149ff", 89 | "sha256:5764615ae4c2bb3de6b8498714c64507ea6bca9461d4fc4421af632491e55139", 90 | "sha256:58b09826e26f7842aca2f81fb4c7169fdff3e58c4b1b635a8a2cd0b3a216ea74", 91 | "sha256:6a6f5bfd3e23ac43dcb11e7ae8ca499c5a283df138d0f70f05cc2a46dd0599d2", 92 | "sha256:703e431ca770b924f85d14bb02a1ebcebf70541a068b4cbc4a1f2fe88cc4e997", 93 | "sha256:78b99a433b8ee41d878facd76c6877a688b5f1e07634968bf3f86fdf0557510e", 94 | "sha256:7907cc0119988532e30594bb3c97d78d3e0998d15a0623d7223bbd4db2fcace3", 95 | "sha256:7b4467b7ba0db4c506c8c689461d75c4d4e626edde2e1b666fba35f473b9c671", 96 | "sha256:7e78e05df820136eff85056253411bc2598ef3c3395508ade7373800825511fb", 97 | "sha256:7e8c15407b5d85cb7ea580e2e272af4063dcdbb1b0ee93f0131ba3b345679bd2", 98 | "sha256:7eae9240a6ad2097f835f5f93050e0ad9440ff50799215b70c9950e743b7c685", 99 | "sha256:817b6c479ff3edd05bc89bbff5ab1ba89392af81894cc27ae6a47d741ca375c6", 100 | "sha256:836e3ccac59c4b3222915d2b6440b1ab13191be15d004cd7ab9fcac5946249fe", 101 | "sha256:8b9b2c5084b883b52c705838b132ddbd5138f64bf21c1fdbeaf854598f9131f3", 102 | "sha256:9a6eaa71d328347fb13f6a3fb4d1564cc393dc37b6d07f37e84c78d8f605b548", 103 | "sha256:9afe4584a7c9928588be3b6340eea887f241e3b470a6cad9827e8f2cd3a90273", 104 | "sha256:9bff46dd43773329fbca3f19b2b07c0be9ec43c5a57a98ef77b7faa810d452e3", 105 | "sha256:a73f989e45b34d211719a62d565ea13db32c7ae741fff5746126b2aacb31a0be", 106 | "sha256:a7a0fa9df943ba46fde64083cf18579c34ae73a56e765e8b3dcf36eed0ad1bdb", 107 | "sha256:a7e6c986b0d12e7fa70faba37fec4cf7366cdba603a6548a79c6e2ed1db906a5", 108 | "sha256:bdad37e6dfcd70524b712e45e7bac7cc05caa2eca563b0c072b5fcdc9dc34468", 109 | "sha256:c0c624efc1fc1433588efb38011a570d1939b23001ef1dfec06ef1734cf00e7e", 110 | "sha256:d2c17d4a1fee746e7d122c84ca9733347beb449bfc0afdba36ad292871d62f4f", 111 | "sha256:d6c798506312648758ee774281f64469109b834f19e5de1a800451ef1d4e276b", 112 | "sha256:dd2dfc067acea55c89f6b2b63a4c96b84534a3073509277ff980c44bfcf3314f", 113 | "sha256:df316ce5b353d8ecb9fdff4c5bedb86964d4f46cf979825a444cc3e03d5ce2d5", 114 | "sha256:e6dc1ed826107f782f300774dd933eadfe54784a5225a0a5af4a31821a440136" 115 | ], 116 | "version": "==1.16.1" 117 | }, 118 | "h5py": { 119 | "hashes": [ 120 | "sha256:0f8cd2acbacf3177b4427ed42639c911667b1f24d923388ab1f8ad466a12be5e", 121 | "sha256:11277e3879098f921ee9e29105b20591e1dfdd44963357399f2abaa1a280c560", 122 | "sha256:1241dec0c94ac32f3285cac1d6f44beabf80423e422ab03bd2686d731a8a9294", 123 | "sha256:17b8187de0b3a945d8e8d031e7eb6ece2fce90791f9c5fde36f4396bf38fdde1", 124 | "sha256:2f30007d0796788a454c1293262f19f25e6428317d3d386f78138fba2a44e37d", 125 | "sha256:308e0758587ee16d4e73e7f2f8aae8351091e343bf0a43d2f697f9535465c816", 126 | "sha256:37cacddf0e8209905f52537a8cf71da0dd9a4de62bd79247274c97b24a408997", 127 | "sha256:38a23bb599748adf23d77f74885c0de6f4a7d9baa42f74e476bbf90fba2b47dd", 128 | "sha256:47ab18b7b7bbc36fd2b606289b703b6f0ee915b923d6ad94dd17ac80ebffc280", 129 | 
"sha256:486c78330af0bf33f5077b51d1888c0739c3cd1a03d5aade0d48572b3b5690ca", 130 | "sha256:4e2183458d6ef1ae87dfb5d6acd0786359336cd9ac0ece6396c09b59fdaa3bd6", 131 | "sha256:51d0595c3e58814c831f6cd2b664a5bf9590e26262c1d541b380d041e4fcb3c0", 132 | "sha256:56d259d56822b70881760b243957f04a0cf133f0ec65eae6a33f562826aee899", 133 | "sha256:5e6e777653169a3cc24ea56bb3d8c845ea391f8914c35bb6f350b0753a52891c", 134 | "sha256:62bfb0ebb0f59e5dccc0b0dbbc0fc40dd1d1e09d04c0dc71f89790231531d4a2", 135 | "sha256:67d89b64debfa021b54aa6f24bbf008403bd144748a0148596b518bce80d2fc4", 136 | "sha256:6bf38571f555fa214493ec6349d29024cc5f313bf1715b09f236c553fd22ae4d", 137 | "sha256:9214ca445c18a37bfe9c165982c0e317e2f21f035c8d635d1c6d9fcbaf35b7a8", 138 | "sha256:ab0c52850428d2e86029935389379c2c97f752e76b616da851deec8a4484f8ec", 139 | "sha256:b2eff336697d8dfd712c5d93fef9f4e4d3e97d9d8c258801836b8664a239e07a", 140 | "sha256:bb33fabc0b8f3fe3bb0f8d6821b2fad5b2a64c27a0808e8d1c5c1e3362062064", 141 | "sha256:bd5353ab342bae1262b04745934cc1565df4cbc8d6a979a0c98f42209bd5c265", 142 | "sha256:bd73444efd1ac06dac27b8405bbe8791a02fd1bc8a2fa0e575257f90b7b57467", 143 | "sha256:bd932236a2ef91a75fee5d7f4ace80ab494c5a59cd092a67c9785ddb7fdc218c", 144 | "sha256:c45650de228ace7731e4280e14fb687f6d5c29cd666c5b22b42492b035e994d6", 145 | "sha256:d5c0c01da45f901a3d429e7ef9e7e22baa869e1affb8715f1bf94e6a30020740", 146 | "sha256:d75035db5bde802a29f4f29f18bb7548863d29ac90ccbf2c04c11799bbbba2c3", 147 | "sha256:dda88206dc9464923f27f601000bc5b152ac0bd6d0122f098d4f239150a70076", 148 | "sha256:e1c2ac5d0aa232c0f60fecc6bd1122346885086a176f939b91058c4c980cc226", 149 | "sha256:e626c65a8587921ebc7fb8d31a49addfdd0b9a9aa96315ea484c09803337b955" 150 | ], 151 | "version": "==2.8.0" 152 | }, 153 | "ipykernel": { 154 | "hashes": [ 155 | "sha256:0aeb7ec277ac42cc2b59ae3d08b10909b2ec161dc6908096210527162b53675d", 156 | "sha256:0fc0bf97920d454102168ec2008620066878848fcfca06c22b669696212e292f" 157 | ], 158 | "version": "==5.1.0" 159 | }, 160 | "ipython": { 161 | "hashes": [ 162 | "sha256:6a9496209b76463f1dec126ab928919aaf1f55b38beb9219af3fe202f6bbdd12", 163 | "sha256:f69932b1e806b38a7818d9a1e918e5821b685715040b48e59c657b3c7961b742" 164 | ], 165 | "index": "pypi", 166 | "version": "==7.2.0" 167 | }, 168 | "ipython-genutils": { 169 | "hashes": [ 170 | "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", 171 | "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8" 172 | ], 173 | "version": "==0.2.0" 174 | }, 175 | "jedi": { 176 | "hashes": [ 177 | "sha256:0191c447165f798e6a730285f2eee783fff81b0d3df261945ecb80983b5c3ca7", 178 | "sha256:b7493f73a2febe0dc33d51c99b474547f7f6c0b2c8fb2b21f453eef204c12148" 179 | ], 180 | "version": "==0.13.1" 181 | }, 182 | "jinja2": { 183 | "hashes": [ 184 | "sha256:74c935a1b8bb9a3947c50a54766a969d4846290e1e788ea44c1392163723c3bd", 185 | "sha256:f84be1bb0040caca4cea721fcbbbbd61f9be9464ca236387158b0feea01914a4" 186 | ], 187 | "version": "==2.10" 188 | }, 189 | "jsonschema": { 190 | "hashes": [ 191 | "sha256:000e68abd33c972a5248544925a0cae7d1125f9bf6c58280d37546b946769a08", 192 | "sha256:6ff5f3180870836cae40f06fa10419f557208175f13ad7bc26caa77beb1f6e02" 193 | ], 194 | "version": "==2.6.0" 195 | }, 196 | "jupyter-client": { 197 | "hashes": [ 198 | "sha256:27befcf0446b01e29853014d6a902dd101ad7d7f94e2252b1adca17c3466b761", 199 | "sha256:59e6d791e22a8002ad0e80b78c6fd6deecab4f9e1b1aa1a22f4213de271b29ea" 200 | ], 201 | "version": "==5.2.3" 202 | }, 203 | "jupyter-core": { 204 | "hashes": [ 205 | 
"sha256:927d713ffa616ea11972534411544589976b2493fc7e09ad946e010aa7eb9970", 206 | "sha256:ba70754aa680300306c699790128f6fbd8c306ee5927976cbe48adacf240c0b7" 207 | ], 208 | "version": "==4.4.0" 209 | }, 210 | "jupyterlab": { 211 | "hashes": [ 212 | "sha256:c48f092526f6d5f12b039118bd92401ab605f49d17050ac71c0d809e86b15036", 213 | "sha256:deba0b2803640fcad72c61366bff11d5945173015961586d5e3b2f629ffeb455" 214 | ], 215 | "index": "pypi", 216 | "version": "==0.35.4" 217 | }, 218 | "jupyterlab-server": { 219 | "hashes": [ 220 | "sha256:65eaf85b27a37380329fbdd8ebd095a0bd65fe9261d73ef6a1abee1dbaeaac1f", 221 | "sha256:72d916a73957a880cdb885def6d8664a6d1b2760ef5dca5ad665aa1e8d1bb783" 222 | ], 223 | "version": "==0.2.0" 224 | }, 225 | "keras-applications": { 226 | "hashes": [ 227 | "sha256:721dda4fa4e043e5bbd6f52a2996885c4639a7130ae478059b3798d0706f5ae7", 228 | "sha256:a03af60ddc9c5afdae4d5c9a8dd4ca857550e0b793733a5072e0725829b87017" 229 | ], 230 | "version": "==1.0.6" 231 | }, 232 | "keras-preprocessing": { 233 | "hashes": [ 234 | "sha256:90d04c1750bccceef88ac09475c291b4b5f6aa1eaf0603167061b1aa8b043c61", 235 | "sha256:ef2e482c4336fcf7180244d06f4374939099daa3183816e82aee7755af35b754" 236 | ], 237 | "version": "==1.0.5" 238 | }, 239 | "markdown": { 240 | "hashes": [ 241 | "sha256:c00429bd503a47ec88d5e30a751e147dcb4c6889663cd3e2ba0afe858e009baa", 242 | "sha256:d02e0f9b04c500cde6637c11ad7c72671f359b87b9fe924b2383649d8841db7c" 243 | ], 244 | "version": "==3.0.1" 245 | }, 246 | "markupsafe": { 247 | "hashes": [ 248 | "sha256:048ef924c1623740e70204aa7143ec592504045ae4429b59c30054cb31e3c432", 249 | "sha256:130f844e7f5bdd8e9f3f42e7102ef1d49b2e6fdf0d7526df3f87281a532d8c8b", 250 | "sha256:19f637c2ac5ae9da8bfd98cef74d64b7e1bb8a63038a3505cd182c3fac5eb4d9", 251 | "sha256:1b8a7a87ad1b92bd887568ce54b23565f3fd7018c4180136e1cf412b405a47af", 252 | "sha256:1c25694ca680b6919de53a4bb3bdd0602beafc63ff001fea2f2fc16ec3a11834", 253 | "sha256:1f19ef5d3908110e1e891deefb5586aae1b49a7440db952454b4e281b41620cd", 254 | "sha256:1fa6058938190ebe8290e5cae6c351e14e7bb44505c4a7624555ce57fbbeba0d", 255 | "sha256:31cbb1359e8c25f9f48e156e59e2eaad51cd5242c05ed18a8de6dbe85184e4b7", 256 | "sha256:3e835d8841ae7863f64e40e19477f7eb398674da6a47f09871673742531e6f4b", 257 | "sha256:4e97332c9ce444b0c2c38dd22ddc61c743eb208d916e4265a2a3b575bdccb1d3", 258 | "sha256:525396ee324ee2da82919f2ee9c9e73b012f23e7640131dd1b53a90206a0f09c", 259 | "sha256:52b07fbc32032c21ad4ab060fec137b76eb804c4b9a1c7c7dc562549306afad2", 260 | "sha256:52ccb45e77a1085ec5461cde794e1aa037df79f473cbc69b974e73940655c8d7", 261 | "sha256:5c3fbebd7de20ce93103cb3183b47671f2885307df4a17a0ad56a1dd51273d36", 262 | "sha256:5e5851969aea17660e55f6a3be00037a25b96a9b44d2083651812c99d53b14d1", 263 | "sha256:5edfa27b2d3eefa2210fb2f5d539fbed81722b49f083b2c6566455eb7422fd7e", 264 | "sha256:7d263e5770efddf465a9e31b78362d84d015cc894ca2c131901a4445eaa61ee1", 265 | "sha256:83381342bfc22b3c8c06f2dd93a505413888694302de25add756254beee8449c", 266 | "sha256:857eebb2c1dc60e4219ec8e98dfa19553dae33608237e107db9c6078b1167856", 267 | "sha256:98e439297f78fca3a6169fd330fbe88d78b3bb72f967ad9961bcac0d7fdd1550", 268 | "sha256:bf54103892a83c64db58125b3f2a43df6d2cb2d28889f14c78519394feb41492", 269 | "sha256:d9ac82be533394d341b41d78aca7ed0e0f4ba5a2231602e2f05aa87f25c51672", 270 | "sha256:e982fe07ede9fada6ff6705af70514a52beb1b2c3d25d4e873e82114cf3c5401", 271 | "sha256:edce2ea7f3dfc981c4ddc97add8a61381d9642dc3273737e756517cc03e84dd6", 272 | "sha256:efdc45ef1afc238db84cb4963aa689c0408912a0239b0721cb172b4016eb31d6", 273 | 
"sha256:f137c02498f8b935892d5c0172560d7ab54bc45039de8805075e19079c639a9c", 274 | "sha256:f82e347a72f955b7017a39708a3667f106e6ad4d10b25f237396a7115d8ed5fd", 275 | "sha256:fb7c206e01ad85ce57feeaaa0bf784b97fa3cad0d4a5737bc5295785f5c613a1" 276 | ], 277 | "version": "==1.1.0" 278 | }, 279 | "mccabe": { 280 | "hashes": [ 281 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 282 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 283 | ], 284 | "version": "==0.6.1" 285 | }, 286 | "mistune": { 287 | "hashes": [ 288 | "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", 289 | "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4" 290 | ], 291 | "version": "==0.8.4" 292 | }, 293 | "mypy": { 294 | "hashes": [ 295 | "sha256:8e071ec32cc226e948a34bbb3d196eb0fd96f3ac69b6843a5aff9bd4efa14455", 296 | "sha256:fb90c804b84cfd8133d3ddfbd630252694d11ccc1eb0166a1b2efb5da37ecab2" 297 | ], 298 | "index": "pypi", 299 | "version": "==0.641" 300 | }, 301 | "mypy-extensions": { 302 | "hashes": [ 303 | "sha256:37e0e956f41369209a3d5f34580150bcacfabaa57b33a15c0b25f4b5725e0812", 304 | "sha256:b16cabe759f55e3409a7d231ebd2841378fb0c27a5d1994719e340e4f429ac3e" 305 | ], 306 | "version": "==0.4.1" 307 | }, 308 | "nbconvert": { 309 | "hashes": [ 310 | "sha256:08d21cf4203fabafd0d09bbd63f06131b411db8ebeede34b0fd4be4548351779", 311 | "sha256:a8a2749f972592aa9250db975304af6b7337f32337e523a2c995cc9e12c07807" 312 | ], 313 | "version": "==5.4.0" 314 | }, 315 | "nbformat": { 316 | "hashes": [ 317 | "sha256:b9a0dbdbd45bb034f4f8893cafd6f652ea08c8c1674ba83f2dc55d3955743b0b", 318 | "sha256:f7494ef0df60766b7cabe0a3651556345a963b74dbc16bc7c18479041170d402" 319 | ], 320 | "version": "==4.4.0" 321 | }, 322 | "notebook": { 323 | "hashes": [ 324 | "sha256:661341909008d1e7bfa1541904006f9789fa3de1cbec8379d2879819454cc04b", 325 | "sha256:91705b109fc785198faed892489cddb233265564d5e2dad5e4f7974af05ee8dd" 326 | ], 327 | "version": "==5.7.2" 328 | }, 329 | "numpy": { 330 | "hashes": [ 331 | "sha256:0df89ca13c25eaa1621a3f09af4c8ba20da849692dcae184cb55e80952c453fb", 332 | "sha256:154c35f195fd3e1fad2569930ca51907057ae35e03938f89a8aedae91dd1b7c7", 333 | "sha256:18e84323cdb8de3325e741a7a8dd4a82db74fde363dce32b625324c7b32aa6d7", 334 | "sha256:1e8956c37fc138d65ded2d96ab3949bd49038cc6e8a4494b1515b0ba88c91565", 335 | "sha256:23557bdbca3ccbde3abaa12a6e82299bc92d2b9139011f8c16ca1bb8c75d1e95", 336 | "sha256:24fd645a5e5d224aa6e39d93e4a722fafa9160154f296fd5ef9580191c755053", 337 | "sha256:36e36b6868e4440760d4b9b44587ea1dc1f06532858d10abba98e851e154ca70", 338 | "sha256:3d734559db35aa3697dadcea492a423118c5c55d176da2f3be9c98d4803fc2a7", 339 | "sha256:416a2070acf3a2b5d586f9a6507bb97e33574df5bd7508ea970bbf4fc563fa52", 340 | "sha256:4a22dc3f5221a644dfe4a63bf990052cc674ef12a157b1056969079985c92816", 341 | "sha256:4d8d3e5aa6087490912c14a3c10fbdd380b40b421c13920ff468163bc50e016f", 342 | "sha256:4f41fd159fba1245e1958a99d349df49c616b133636e0cf668f169bce2aeac2d", 343 | "sha256:561ef098c50f91fbac2cc9305b68c915e9eb915a74d9038ecf8af274d748f76f", 344 | "sha256:56994e14b386b5c0a9b875a76d22d707b315fa037affc7819cda08b6d0489756", 345 | "sha256:73a1f2a529604c50c262179fcca59c87a05ff4614fe8a15c186934d84d09d9a5", 346 | "sha256:7da99445fd890206bfcc7419f79871ba8e73d9d9e6b82fe09980bc5bb4efc35f", 347 | "sha256:99d59e0bcadac4aa3280616591fb7bcd560e2218f5e31d5223a2e12a1425d495", 348 | "sha256:a4cc09489843c70b22e8373ca3dfa52b3fab778b57cf81462f1203b0852e95e3", 349 | 
"sha256:a61dc29cfca9831a03442a21d4b5fd77e3067beca4b5f81f1a89a04a71cf93fa", 350 | "sha256:b1853df739b32fa913cc59ad9137caa9cc3d97ff871e2bbd89c2a2a1d4a69451", 351 | "sha256:b1f44c335532c0581b77491b7715a871d0dd72e97487ac0f57337ccf3ab3469b", 352 | "sha256:b261e0cb0d6faa8fd6863af26d30351fd2ffdb15b82e51e81e96b9e9e2e7ba16", 353 | "sha256:c857ae5dba375ea26a6228f98c195fec0898a0fd91bcf0e8a0cae6d9faf3eca7", 354 | "sha256:cf5bb4a7d53a71bb6a0144d31df784a973b36d8687d615ef6a7e9b1809917a9b", 355 | "sha256:db9814ff0457b46f2e1d494c1efa4111ca089e08c8b983635ebffb9c1573361f", 356 | "sha256:df04f4bad8a359daa2ff74f8108ea051670cafbca533bb2636c58b16e962989e", 357 | "sha256:ecf81720934a0e18526177e645cbd6a8a21bb0ddc887ff9738de07a1df5c6b61", 358 | "sha256:edfa6fba9157e0e3be0f40168eb142511012683ac3dc82420bee4a3f3981b30e" 359 | ], 360 | "version": "==1.15.4" 361 | }, 362 | "pandocfilters": { 363 | "hashes": [ 364 | "sha256:b3dd70e169bb5449e6bc6ff96aea89c5eea8c5f6ab5e207fc2f521a2cf4a0da9" 365 | ], 366 | "version": "==1.4.2" 367 | }, 368 | "parso": { 369 | "hashes": [ 370 | "sha256:35704a43a3c113cce4de228ddb39aab374b8004f4f2407d070b6a2ca784ce8a2", 371 | "sha256:895c63e93b94ac1e1690f5fdd40b65f07c8171e3e53cbd7793b5b96c0e0a7f24" 372 | ], 373 | "version": "==0.3.1" 374 | }, 375 | "pexpect": { 376 | "hashes": [ 377 | "sha256:2a8e88259839571d1251d278476f3eec5db26deb73a70be5ed5dc5435e418aba", 378 | "sha256:3fbd41d4caf27fa4a377bfd16fef87271099463e6fa73e92a52f92dfee5d425b" 379 | ], 380 | "markers": "sys_platform != 'win32'", 381 | "version": "==4.6.0" 382 | }, 383 | "pickleshare": { 384 | "hashes": [ 385 | "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca", 386 | "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56" 387 | ], 388 | "version": "==0.7.5" 389 | }, 390 | "prometheus-client": { 391 | "hashes": [ 392 | "sha256:046cb4fffe75e55ff0e6dfd18e2ea16e54d86cc330f369bebcc683475c8b68a9" 393 | ], 394 | "version": "==0.4.2" 395 | }, 396 | "prompt-toolkit": { 397 | "hashes": [ 398 | "sha256:c1d6aff5252ab2ef391c2fe498ed8c088066f66bc64a8d5c095bbf795d9fec34", 399 | "sha256:d4c47f79b635a0e70b84fdb97ebd9a274203706b1ee5ed44c10da62755cf3ec9", 400 | "sha256:fd17048d8335c1e6d5ee403c3569953ba3eb8555d710bfc548faf0712666ea39" 401 | ], 402 | "version": "==2.0.7" 403 | }, 404 | "protobuf": { 405 | "hashes": [ 406 | "sha256:10394a4d03af7060fa8a6e1cbf38cea44be1467053b0aea5bbfcb4b13c4b88c4", 407 | "sha256:1489b376b0f364bcc6f89519718c057eb191d7ad6f1b395ffd93d1aa45587811", 408 | "sha256:1931d8efce896981fe410c802fd66df14f9f429c32a72dd9cfeeac9815ec6444", 409 | "sha256:196d3a80f93c537f27d2a19a4fafb826fb4c331b0b99110f985119391d170f96", 410 | "sha256:46e34fdcc2b1f2620172d3a4885128705a4e658b9b62355ae5e98f9ea19f42c2", 411 | "sha256:4b92e235a3afd42e7493b281c8b80c0c65cbef45de30f43d571d1ee40a1f77ef", 412 | "sha256:574085a33ca0d2c67433e5f3e9a0965c487410d6cb3406c83bdaf549bfc2992e", 413 | "sha256:59cd75ded98094d3cf2d79e84cdb38a46e33e7441b2826f3838dcc7c07f82995", 414 | "sha256:5ee0522eed6680bb5bac5b6d738f7b0923b3cafce8c4b1a039a6107f0841d7ed", 415 | "sha256:65917cfd5da9dfc993d5684643063318a2e875f798047911a9dd71ca066641c9", 416 | "sha256:685bc4ec61a50f7360c9fd18e277b65db90105adbf9c79938bd315435e526b90", 417 | "sha256:92e8418976e52201364a3174e40dc31f5fd8c147186d72380cbda54e0464ee19", 418 | "sha256:9335f79d1940dfb9bcaf8ec881fb8ab47d7a2c721fb8b02949aab8bbf8b68625", 419 | "sha256:a7ee3bb6de78185e5411487bef8bc1c59ebd97e47713cba3c460ef44e99b3db9", 420 | "sha256:ceec283da2323e2431c49de58f80e1718986b79be59c266bb0509cbf90ca5b9e", 421 | 
"sha256:fcfc907746ec22716f05ea96b7f41597dfe1a1c088f861efb8a0d4f4196a6f10" 422 | ], 423 | "version": "==3.6.1" 424 | }, 425 | "ptyprocess": { 426 | "hashes": [ 427 | "sha256:923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0", 428 | "sha256:d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f" 429 | ], 430 | "markers": "os_name != 'nt'", 431 | "version": "==0.6.0" 432 | }, 433 | "pycodestyle": { 434 | "hashes": [ 435 | "sha256:cbc619d09254895b0d12c2c691e237b2e91e9b2ecf5e84c26b35400f93dcfb83", 436 | "sha256:cbfca99bd594a10f674d0cd97a3d802a1fdef635d4361e1a2658de47ed261e3a" 437 | ], 438 | "version": "==2.4.0" 439 | }, 440 | "pyflakes": { 441 | "hashes": [ 442 | "sha256:9a7662ec724d0120012f6e29d6248ae3727d821bba522a0e6b356eff19126a49", 443 | "sha256:f661252913bc1dbe7fcfcbf0af0db3f42ab65aabd1a6ca68fe5d466bace94dae" 444 | ], 445 | "version": "==2.0.0" 446 | }, 447 | "pygments": { 448 | "hashes": [ 449 | "sha256:6301ecb0997a52d2d31385e62d0a4a4cf18d2f2da7054a5ddad5c366cd39cee7", 450 | "sha256:82666aac15622bd7bb685a4ee7f6625dd716da3ef7473620c192c0168aae64fc" 451 | ], 452 | "version": "==2.3.0" 453 | }, 454 | "python-dateutil": { 455 | "hashes": [ 456 | "sha256:063df5763652e21de43de7d9e00ccf239f953a832941e37be541614732cdfc93", 457 | "sha256:88f9287c0174266bb0d8cedd395cfba9c58e87e5ad86b2ce58859bc11be3cf02" 458 | ], 459 | "version": "==2.7.5" 460 | }, 461 | "pyzmq": { 462 | "hashes": [ 463 | "sha256:25a0715c8f69cf72f67cfe5a68a3f3ed391c67c063d2257bec0fe7fc2c7f08f8", 464 | "sha256:2bab63759632c6b9e0d5bf19cc63c3b01df267d660e0abcf230cf0afaa966349", 465 | "sha256:30ab49d99b24bf0908ebe1cdfa421720bfab6f93174e4883075b7ff38cc555ba", 466 | "sha256:32c7ca9fc547a91e3c26fc6080b6982e46e79819e706eb414dd78f635a65d946", 467 | "sha256:41219ae72b3cc86d97557fe5b1ef5d1adc1057292ec597b50050874a970a39cf", 468 | "sha256:4b8c48a9a13cea8f1f16622f9bd46127108af14cd26150461e3eab71e0de3e46", 469 | "sha256:55724997b4a929c0d01b43c95051318e26ddbae23565018e138ae2dc60187e59", 470 | "sha256:65f0a4afae59d4fc0aad54a917ab599162613a761b760ba167d66cc646ac3786", 471 | "sha256:6f88591a8b246f5c285ee6ce5c1bf4f6bd8464b7f090b1333a446b6240a68d40", 472 | "sha256:75022a4c60dcd8765bb9ca32f6de75a0ec83b0d96e0309dc479f4c7b21f26cb7", 473 | "sha256:76ea493bfab18dcb090d825f3662b5612e2def73dffc196d51a5194b0294a81d", 474 | "sha256:7b60c045b80709e4e3c085bab9b691e71761b44c2b42dbb047b8b498e7bc16b3", 475 | "sha256:8e6af2f736734aef8ed6f278f9f552ec7f37b1a6b98e59b887484a840757f67d", 476 | "sha256:9ac2298e486524331e26390eac14e4627effd3f8e001d4266ed9d8f1d2d31cce", 477 | "sha256:9ba650f493a9bc1f24feca1d90fce0e5dd41088a252ac9840131dfbdbf3815ca", 478 | "sha256:a02a4a385e394e46012dc83d2e8fd6523f039bb52997c1c34a2e0dd49ed839c1", 479 | "sha256:a3ceee84114d9f5711fa0f4db9c652af0e4636c89eabc9b7f03a3882569dd1ed", 480 | "sha256:a72b82ac1910f2cf61a49139f4974f994984475f771b0faa730839607eeedddf", 481 | "sha256:ab136ac51027e7c484c53138a0fab4a8a51e80d05162eb7b1585583bcfdbad27", 482 | "sha256:c095b224300bcac61e6c445e27f9046981b1ac20d891b2f1714da89d34c637c8", 483 | "sha256:c5cc52d16c06dc2521340d69adda78a8e1031705924e103c0eb8fc8af861d810", 484 | "sha256:d612e9833a89e8177f8c1dc68d7b4ff98d3186cd331acd616b01bbdab67d3a7b", 485 | "sha256:e828376a23c66c6fe90dcea24b4b72cd774f555a6ee94081670872918df87a19", 486 | "sha256:e9767c7ab2eb552796440168d5c6e23a99ecaade08dda16266d43ad461730192", 487 | "sha256:ebf8b800d42d217e4710d1582b0c8bff20cdcb4faad7c7213e52644034300924" 488 | ], 489 | "version": "==17.1.2" 490 | }, 491 | "send2trash": { 492 | "hashes": [ 493 | 
"sha256:60001cc07d707fe247c94f74ca6ac0d3255aabcb930529690897ca2a39db28b2", 494 | "sha256:f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b" 495 | ], 496 | "version": "==1.5.0" 497 | }, 498 | "sentencepiece": { 499 | "hashes": [ 500 | "sha256:0e57932107f6c1c6e28d9c159a18f88cdcef1830382f53e6030ea613cd4ca2a9", 501 | "sha256:25227735000796a5c8e4f1ae8e440a84d9a42b7a82027782622b3e48922d02e6", 502 | "sha256:2587288fcb9b89e74550911ef484a849192518ae2e0676d83eea67512b99a0ed", 503 | "sha256:3945dd4a72de68bf139d1f0cbed28dc0fada2c615a5a30dccd8e5fa929b75918", 504 | "sha256:40b786265ab0f19aaeb1d5c9528a20c8c3fcf96d0db2b968ac4b883f07e9ff83", 505 | "sha256:4423e4471820467b9d569c0967e52c2cf7e2a841a494023210ab363430f7acb9", 506 | "sha256:4460cf088f815a6a56cb3048bb0b8254f6463fa50ea35ba5ee304dc153c46289", 507 | "sha256:4537054788ddad5abc02c506a1097666c5dc39a98d55f3990a2112463adda6da", 508 | "sha256:523426fc87b38cbdb6cfd137572a24d8eee8ae6382beb10b9799c513015ff184", 509 | "sha256:6a4f03da810707591fca43b3cbbd5db3ee1e4b3bfccfeeb3f9ef562f918839f7", 510 | "sha256:7121cd814aa91c858de5dfed241540a00592b5df436d59bfc3348462ebdc0a0d", 511 | "sha256:7464b29bdfa59c86a4c084e51255c0c7443a1b824bb08e5199be15c6c911b5f1", 512 | "sha256:7b32da16dd6b0eed6741705d76205de54b8a501c884633c83f74b3f339732c5a", 513 | "sha256:7b68c5f2ddbeec16e2f82e2b2c13f3a2e8fb319b5d30ce814dc6ab0df11bea87", 514 | "sha256:8ada819bdc982c70c8b5f2e9df8dba1dd7599a6ee6fef8d8fd5613f97e176280", 515 | "sha256:91d8b1b8b22767272504d49c1749cda746d3e0287a9872c3628bd883b810b4e6", 516 | "sha256:99c3d4a263a062282b10cadb8c1f909ffb6e78106514a52cbaf5b8c1524ba569", 517 | "sha256:9deeaa940b380615d017b016721281d7ee9f9d512a8f7919a4c9c958f6d6b5cf", 518 | "sha256:a78fa55a683b45369698cd00350e7ee3d8dd1e6ea4eaa939311282ac79433798", 519 | "sha256:c8001d6cd43d80780d26b5b3c0033bd00afe849ec6f9577781ab6eab14c92c74", 520 | "sha256:d3a510283181b11b65af86ddbd444d51bcaaf6ff2b0c02f04171e7f247c82dc0", 521 | "sha256:df156c6bc13fa8210abb84d124c0c6beb6a4da537de2dd37eda639a871cb5006", 522 | "sha256:e05d10d4e0a8834a9d0f09a40ba93659b2c926f16aecad06c91266247318c18f" 523 | ], 524 | "index": "pypi", 525 | "version": "==0.1.6" 526 | }, 527 | "six": { 528 | "hashes": [ 529 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", 530 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" 531 | ], 532 | "version": "==1.11.0" 533 | }, 534 | "tensorboard": { 535 | "hashes": [ 536 | "sha256:537603db949e10d2f5f201d88b073f3f8fb4e4c311d5541e1d4518aa59aa8daa", 537 | "sha256:ca275a7e39797946930d7d4460999369b73968e8191f2256e23bfb7924004d59" 538 | ], 539 | "version": "==1.12.0" 540 | }, 541 | "tensorflow-gpu": { 542 | "hashes": [ 543 | "sha256:12902549817d2f093f3045f7861df84a5936e8f14469d11c5a5622c85455b96c", 544 | "sha256:435a9a4a37c1a92f9bc80f577f0328775539c593b9bc9e943712a204ada11db5", 545 | "sha256:6e9e6b73cc6dc6b82a8e09f9688a8806f44dbe02c4e92cb9c36efea30a7cd47e", 546 | "sha256:bf2c1e660c533102db2a81fad21a26213f4e4ff5ce6b841c0d9adc4ac3c5c6bc", 547 | "sha256:ce47aaa4ddf8446c9c9a83d968c2beba93feefaf796f1255ec6e361e4dd0e13a", 548 | "sha256:d02f018e46ee0d45a86bd27c5635b936330ab7e180c43029d1b3c4cebc7c2c45", 549 | "sha256:da799ad89780c21380fdbb99f3ecf73488dbfdca0715493c6931c2710c710e62" 550 | ], 551 | "index": "pypi", 552 | "version": "==1.12.0" 553 | }, 554 | "termcolor": { 555 | "hashes": [ 556 | "sha256:1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b" 557 | ], 558 | "version": "==1.1.0" 559 | }, 560 | "terminado": { 561 | "hashes": [ 562 | 
"sha256:55abf9ade563b8f9be1f34e4233c7b7bde726059947a593322e8a553cc4c067a", 563 | "sha256:65011551baff97f5414c67018e908110693143cfbaeb16831b743fe7cad8b927" 564 | ], 565 | "version": "==0.8.1" 566 | }, 567 | "testpath": { 568 | "hashes": [ 569 | "sha256:46c89ebb683f473ffe2aab0ed9f12581d4d078308a3cb3765d79c6b2317b0109", 570 | "sha256:b694b3d9288dbd81685c5d2e7140b81365d46c29f5db4bc659de5aa6b98780f8" 571 | ], 572 | "version": "==0.4.2" 573 | }, 574 | "tornado": { 575 | "hashes": [ 576 | "sha256:0662d28b1ca9f67108c7e3b77afabfb9c7e87bde174fbda78186ecedc2499a9d", 577 | "sha256:4e5158d97583502a7e2739951553cbd88a72076f152b4b11b64b9a10c4c49409", 578 | "sha256:732e836008c708de2e89a31cb2fa6c0e5a70cb60492bee6f1ea1047500feaf7f", 579 | "sha256:8154ec22c450df4e06b35f131adc4f2f3a12ec85981a203301d310abf580500f", 580 | "sha256:8e9d728c4579682e837c92fdd98036bd5cdefa1da2aaf6acf26947e6dd0c01c5", 581 | "sha256:d4b3e5329f572f055b587efc57d29bd051589fb5a43ec8898c77a47ec2fa2bbb", 582 | "sha256:e5f2585afccbff22390cddac29849df463b252b711aa2ce7c5f3f342a5b3b444" 583 | ], 584 | "version": "==5.1.1" 585 | }, 586 | "tqdm": { 587 | "hashes": [ 588 | "sha256:3c4d4a5a41ef162dd61f1edb86b0e1c7859054ab656b2e7c7b77e7fbf6d9f392", 589 | "sha256:5b4d5549984503050883bc126280b386f5f4ca87e6c023c5d015655ad75bdebb" 590 | ], 591 | "index": "pypi", 592 | "version": "==4.28.1" 593 | }, 594 | "traitlets": { 595 | "hashes": [ 596 | "sha256:9c4bd2d267b7153df9152698efb1050a5d84982d3384a37b2c1f7723ba3e7835", 597 | "sha256:c6cb5e6f57c5a9bdaa40fa71ce7b4af30298fbab9ece9815b5d995ab6217c7d9" 598 | ], 599 | "version": "==4.3.2" 600 | }, 601 | "typed-ast": { 602 | "hashes": [ 603 | "sha256:0948004fa228ae071054f5208840a1e88747a357ec1101c17217bfe99b299d58", 604 | "sha256:10703d3cec8dcd9eef5a630a04056bbc898abc19bac5691612acba7d1325b66d", 605 | "sha256:1f6c4bd0bdc0f14246fd41262df7dfc018d65bb05f6e16390b7ea26ca454a291", 606 | "sha256:25d8feefe27eb0303b73545416b13d108c6067b846b543738a25ff304824ed9a", 607 | "sha256:29464a177d56e4e055b5f7b629935af7f49c196be47528cc94e0a7bf83fbc2b9", 608 | "sha256:2e214b72168ea0275efd6c884b114ab42e316de3ffa125b267e732ed2abda892", 609 | "sha256:3e0d5e48e3a23e9a4d1a9f698e32a542a4a288c871d33ed8df1b092a40f3a0f9", 610 | "sha256:519425deca5c2b2bdac49f77b2c5625781abbaf9a809d727d3a5596b30bb4ded", 611 | "sha256:57fe287f0cdd9ceaf69e7b71a2e94a24b5d268b35df251a88fef5cc241bf73aa", 612 | "sha256:668d0cec391d9aed1c6a388b0d5b97cd22e6073eaa5fbaa6d2946603b4871efe", 613 | "sha256:68ba70684990f59497680ff90d18e756a47bf4863c604098f10de9716b2c0bdd", 614 | "sha256:6de012d2b166fe7a4cdf505eee3aaa12192f7ba365beeefaca4ec10e31241a85", 615 | "sha256:79b91ebe5a28d349b6d0d323023350133e927b4de5b651a8aa2db69c761420c6", 616 | "sha256:8550177fa5d4c1f09b5e5f524411c44633c80ec69b24e0e98906dd761941ca46", 617 | "sha256:898f818399cafcdb93cbbe15fc83a33d05f18e29fb498ddc09b0214cdfc7cd51", 618 | "sha256:94b091dc0f19291adcb279a108f5d38de2430411068b219f41b343c03b28fb1f", 619 | "sha256:a26863198902cda15ab4503991e8cf1ca874219e0118cbf07c126bce7c4db129", 620 | "sha256:a8034021801bc0440f2e027c354b4eafd95891b573e12ff0418dec385c76785c", 621 | "sha256:bc978ac17468fe868ee589c795d06777f75496b1ed576d308002c8a5756fb9ea", 622 | "sha256:c05b41bc1deade9f90ddc5d988fe506208019ebba9f2578c622516fd201f5863", 623 | "sha256:c9b060bd1e5a26ab6e8267fd46fc9e02b54eb15fffb16d112d4c7b1c12987559", 624 | "sha256:edb04bdd45bfd76c8292c4d9654568efaedf76fe78eb246dde69bdb13b2dad87", 625 | "sha256:f19f2a4f547505fe9072e15f6f4ae714af51b5a681a97f187971f50c283193b6" 626 | ], 627 | "version": "==1.1.0" 628 | }, 629 | 
"wcwidth": { 630 | "hashes": [ 631 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 632 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 633 | ], 634 | "version": "==0.1.7" 635 | }, 636 | "webencodings": { 637 | "hashes": [ 638 | "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", 639 | "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923" 640 | ], 641 | "version": "==0.5.1" 642 | }, 643 | "werkzeug": { 644 | "hashes": [ 645 | "sha256:c3fd7a7d41976d9f44db327260e263132466836cef6f91512889ed60ad26557c", 646 | "sha256:d5da73735293558eb1651ee2fddc4d0dedcfa06538b8813a2e20011583c9e49b" 647 | ], 648 | "version": "==0.14.1" 649 | }, 650 | "wheel": { 651 | "hashes": [ 652 | "sha256:029703bf514e16c8271c3821806a1c171220cc5bdd325cbf4e7da1e056a01db6", 653 | "sha256:1e53cdb3f808d5ccd0df57f964263752aa74ea7359526d3da6c02114ec1e1d44" 654 | ], 655 | "markers": "python_version >= '3'", 656 | "version": "==0.32.3" 657 | } 658 | }, 659 | "develop": {} 660 | } 661 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Dialog Tutorial 2 | 会話モデルネタでいろいろ追加していくリポジトリ 3 | 4 | - [Transformer](https://github.com/halhorn/deep_dialog_tutorial/tree/master/deepdialog/transformer) 5 | - [RNNLM](https://github.com/halhorn/deep_dialog_tutorial/tree/master/deepdialog/rnnlm) 6 | 7 | # Install 8 | python は python3 を想定してます。 9 | 10 | ```zsh 11 | git clone git@github.com:halhorn/deep_dialog_tutorial.git 12 | cd deep_dialog_tutorial 13 | pip install pipenv 14 | pipenv install 15 | 16 | # 起動 17 | pipenv run jupyter lab 18 | ``` 19 | 20 | # Transformer 21 | [コード](https://github.com/halhorn/deep_dialog_tutorial/tree/master/deepdialog/transformer) 22 | [作って理解する Transformer / Attention](https://qiita.com/halhorn/private/c91497522be27bde17ce) 23 | 24 | # RNNLM 25 | rnnlm.ipynb 26 | 27 | RNN の言語モデル。 28 | たくさんの文章集合から、それっぽい文章を生成するモデルです。 29 | 30 | - 学習時:上から順に Train のセクションまで実行してください 31 | - 生成時:Train 以外のそれより上と、 Restore, Generate を実行してください。 32 | - Restore 時のモデルのパスは適宜変えてください。 33 | 34 | # Testing 35 | ```py3 36 | ./test/run 37 | # or 38 | ./test/run deepdialog/transformer/transformer.py 39 | ``` -------------------------------------------------------------------------------- /deepdialog/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/deepdialog/__init__.py -------------------------------------------------------------------------------- /deepdialog/datasource/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/deepdialog/datasource/__init__.py -------------------------------------------------------------------------------- /deepdialog/datasource/aozora_converter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 青空文庫データコンバーター\n", 8 | "\n", 9 | "青空文庫の .txt 形式のファイルを適当なディレクトリに複数置いておけば、それらをつなげて本文の文(。で区切られる)毎に一行となるフォーマットに変換してくれます。\n", 10 | "雑にやってるので章や物語の切れ目などが考えられていなかったり、発話の途中の。で切られてたりします。" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": 
[], 18 | "source": [ 19 | "# カレントディレクトリをリポジトリ直下にするおまじない\n", 20 | "import os\n", 21 | "while os.getcwd().split('/')[-1] != 'deep_dialog_tutorial': os.chdir('..')\n", 22 | "print('current dir:', os.getcwd())" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "import re\n", 33 | "import codecs\n", 34 | "from tqdm import tqdm" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# ここに青空文庫からダウンロードしたファイルを配置してください\n", 44 | "data_dir = 'tmp/natsume/'\n", 45 | "out_path = 'data/natsume.txt'\n", 46 | "\n", 47 | "os.makedirs(data_dir, exist_ok=True)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "def convert_line(line):\n", 57 | " line = line.replace('\\r\\n', '').replace('○', '').replace('\\u3000', '').replace('※', '').replace('|', '')\n", 58 | " line = re.sub(r'《.*?》', '', line)\n", 59 | " line = re.sub(r'\\[.*?\\]', '', line)\n", 60 | " if not line:\n", 61 | " return []\n", 62 | " sentence_list = line.split('。')\n", 63 | " return [sentence + '。' for sentence in sentence_list if sentence]\n", 64 | "\n", 65 | "def get_text_list(file_path):\n", 66 | " with codecs.open(file_path, mode='r', encoding='shift_jis') as f:\n", 67 | " header_sep_count = 0\n", 68 | " text_list = []\n", 69 | " for line in f.readlines():\n", 70 | " if line.startswith('-----------------'):\n", 71 | " header_sep_count += 1\n", 72 | " continue\n", 73 | " if header_sep_count < 2:\n", 74 | " continue\n", 75 | " if line.startswith('底本:'):\n", 76 | " break\n", 77 | " text_list += convert_line(line)\n", 78 | " return text_list" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "text_list = []\n", 88 | "\n", 89 | "file_name_list = os.listdir(data_dir)\n", 90 | "for file_name in tqdm(file_name_list):\n", 91 | " file_path = os.path.join(data_dir, file_name)\n", 92 | " text_list += get_text_list(file_path)\n", 93 | "print('{} lines'.format(len(text_list)))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "with codecs.open(out_path, mode='w', encoding='utf-8') as f:\n", 103 | " f.write('\\n'.join(text_list) + '\\n')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [] 112 | } 113 | ], 114 | "metadata": { 115 | "kernelspec": { 116 | "display_name": "Python 3", 117 | "language": "python", 118 | "name": "python3" 119 | }, 120 | "language_info": { 121 | "codemirror_mode": { 122 | "name": "ipython", 123 | "version": 3 124 | }, 125 | "file_extension": ".py", 126 | "mimetype": "text/x-python", 127 | "name": "python", 128 | "nbconvert_exporter": "python", 129 | "pygments_lexer": "ipython3", 130 | "version": "3.6.5" 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 2 135 | } 136 | -------------------------------------------------------------------------------- /deepdialog/rnnlm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/deepdialog/rnnlm/__init__.py 
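Editor's illustration (not a file from the repository): a minimal, runnable sketch of what the `convert_line` cell in `deepdialog/datasource/aozora_converter.ipynb` above produces for a single line of Aozora Bunko text. The function body is copied from the notebook; the sample input string and the expected output shown in the comment are invented for demonstration only.

```python
# Standalone sketch of the sentence-splitting logic from aozora_converter.ipynb.
# The sample text below is a made-up demonstration input, not repository data.
import re

def convert_line(line):
    # Drop newlines, layout symbols and ruby markers, then split into sentences on '。'
    line = line.replace('\r\n', '').replace('○', '').replace('\u3000', '').replace('※', '').replace('|', '')
    line = re.sub(r'《.*?》', '', line)   # remove ruby readings like 《わがはい》
    line = re.sub(r'\[.*?\]', '', line)   # remove editorial notes like [#...]
    if not line:
        return []
    sentence_list = line.split('。')
    return [sentence + '。' for sentence in sentence_list if sentence]

sample = '吾輩《わがはい》は猫である。名前はまだ無い。\r\n'
print(convert_line(sample))
# -> ['吾輩は猫である。', '名前はまだ無い。']
```

When the notebook is run end to end, this same transformation is applied to every file placed under tmp/natsume/ and the resulting sentences are written one per line to data/natsume.txt.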
-------------------------------------------------------------------------------- /deepdialog/rnnlm/rnnlm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# RNNLM\n", 8 | "Recurrent Neural Network Language Model\n", 9 | "RNN による言語モデルです。\n", 10 | "文章の集団を学習させることで、それっぽい文章を生成できます。\n", 11 | "\n", 12 | "これが発展して Seq2Seq のデコーダー部分になっていきます。" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "current dir: /home/harumitsu.nobuta/git/deep_dialog_tutorial\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "# カレントディレクトリをリポジトリ直下にするおまじない\n", 30 | "import os\n", 31 | "while os.getcwd().split('/')[-1] != 'deep_dialog_tutorial': os.chdir('..')\n", 32 | "print('current dir:', os.getcwd())" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import tensorflow as tf\n", 42 | "from tensorflow.python.layers import core as layers_core\n", 43 | "import numpy as np\n", 44 | "import os\n", 45 | "import random\n", 46 | "import collections\n", 47 | "from tqdm import tqdm" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "# Create Model" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "hidden_dim = 1024\n", 64 | "embedding_dim = 256\n", 65 | "vocab_size = 1000" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# 入出力部分\n", 75 | "in_ph = tf.placeholder(tf.int32, shape=[None, None], name='in_ph')\n", 76 | "out_ph = tf.placeholder(tf.int32, shape=[None, None], name='out_ph')\n", 77 | "len_ph = tf.placeholder(tf.int32, shape=[None], name='len_ph')\n", 78 | "gen_start_token_ph = tf.placeholder(tf.int32, shape=[], name='gen_start_token_ph')" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def debug(ops):\n", 88 | " '''与えられた計算ノードの値を表示します。'''\n", 89 | " with tf.Session() as sess:\n", 90 | " sess.run(tf.global_variables_initializer())\n", 91 | " result = sess.run(ops, {\n", 92 | " in_ph: [[30, 40, 50], [160, 170, 180]],\n", 93 | " out_ph:[[40, 50, 60], [170, 180, 190]],\n", 94 | " len_ph:[3, 3]\n", 95 | " })\n", 96 | " print('## {}\\nshape: {}'.format(ops.name, ops.shape))\n", 97 | " print(result)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "## embedding_lookup/Identity:0\n", 110 | "shape: (?, ?, 256)\n", 111 | "[[[-0.20361237 -0.11308474 -0.41880432 ... 2.7252874 -1.0681399\n", 112 | " 1.3086659 ]\n", 113 | " [-0.3352617 -1.0744342 0.9656708 ... 1.1087787 1.8505251\n", 114 | " 0.02086403]\n", 115 | " [-0.48931482 1.2667885 0.58199185 ... 0.02560114 0.500132\n", 116 | " -3.2564793 ]]\n", 117 | "\n", 118 | " [[-0.12822877 -0.22769526 1.352034 ... 1.9360523 -0.34742078\n", 119 | " -0.40487522]\n", 120 | " [-2.3179045 0.57485205 -0.754861 ... -1.3065025 1.2339923\n", 121 | " 0.01515338]\n", 122 | " [ 0.6592995 -0.6290501 -0.36402264 ... 
0.42282453 0.3705973\n", 123 | " 0.5638999 ]]]\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "# embeddings - 文字の ID から分散表現のベクトルに変換します。\n", 129 | "# データは [batch_size, sentence_len, embedding_dim] の形になります。\n", 130 | "embeddings = tf.Variable(tf.random_normal([vocab_size, embedding_dim], stddev=1), name='embeddings', dtype=tf.float32)\n", 131 | "in_embedded = tf.nn.embedding_lookup(embeddings, in_ph)\n", 132 | "debug(in_embedded)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 7, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "## output_layer/Tensordot:0\n", 145 | "shape: (?, ?, 1000)\n", 146 | "[[[ 0.02734021 0.11152687 -0.12108772 ... -0.02212429 -0.09032315\n", 147 | " 0.03886116]\n", 148 | " [ 0.15023798 0.12774038 0.01580057 ... 0.13109946 -0.19433439\n", 149 | " 0.08362214]\n", 150 | " [ 0.11481191 0.05790965 -0.00183003 ... 0.05068074 -0.42197746\n", 151 | " 0.042301 ]]\n", 152 | "\n", 153 | " [[ 0.17572726 0.10455676 -0.05352221 ... 0.10981397 -0.15552926\n", 154 | " -0.07315858]\n", 155 | " [ 0.07044679 -0.0157798 0.10679971 ... 0.0757293 -0.04247508\n", 156 | " 0.08267353]\n", 157 | " [-0.11602538 -0.07365637 -0.05860609 ... 0.08297199 -0.10835387\n", 158 | " 0.24690573]]]\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "# RNN 部分\n", 164 | "cell = tf.nn.rnn_cell.GRUCell(hidden_dim, kernel_initializer=tf.orthogonal_initializer)\n", 165 | "rnn_out, final_state = tf.nn.dynamic_rnn(\n", 166 | " cell=cell,\n", 167 | " inputs=in_embedded,\n", 168 | " sequence_length=len_ph,\n", 169 | " dtype=tf.float32,\n", 170 | " scope='rnn',\n", 171 | ")\n", 172 | "# 隠れ層から全結合をかませて、各単語の生成確率っぽい値にする。\n", 173 | "# (i番目のニューロンの出力が id: i の単語の生成確率っぽいものになる)\n", 174 | "output_layer = layers_core.Dense(vocab_size, use_bias=False, name='output_layer')\n", 175 | "onehot_logits = output_layer.apply(rnn_out)\n", 176 | "debug(onehot_logits)\n", 177 | "output_ids_op = tf.argmax(onehot_logits, -1)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 8, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "## loss:0\n", 190 | "shape: ()\n", 191 | "6.907025\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(\n", 197 | " labels=out_ph,\n", 198 | " logits=onehot_logits,\n", 199 | ")\n", 200 | "loss_op = tf.reduce_mean(cross_entropy, name='loss')\n", 201 | "debug(loss_op)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 9, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "# 生成時用の RNN\n", 211 | "beam_width = 20\n", 212 | "gen_max_len = 500\n", 213 | "start_tokens = tf.ones([1], tf.int32) * gen_start_token_ph # 生成時の batch_size は1\n", 214 | "\n", 215 | "decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n", 216 | " cell=cell,\n", 217 | " embedding=embeddings,\n", 218 | " start_tokens=start_tokens, \n", 219 | " end_token=0, # dummy\n", 220 | " initial_state=cell.zero_state(beam_width, tf.float32),\n", 221 | " beam_width=beam_width,\n", 222 | " output_layer=output_layer,\n", 223 | ")\n", 224 | "\n", 225 | "beam_decoder_output = tf.contrib.seq2seq.dynamic_decode(\n", 226 | " decoder=decoder,\n", 227 | " maximum_iterations=500,\n", 228 | " scope='generator_decode'\n", 229 | ")[0]\n", 230 | "generate_op = beam_decoder_output.predicted_ids" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": 
{}, 236 | "source": [ 237 | "# Load and Convert Data" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 10, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "sentence_len = 50\n", 247 | "batch_size = 512\n", 248 | "data_path = 'data/natsume.txt'" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 11, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "class Tokenizer:\n", 258 | " def __init__(self, vocab):\n", 259 | " self.vocab = vocab\n", 260 | " self.rev_dict = {c: i for i, c in enumerate(vocab)}\n", 261 | " self.pad = 0\n", 262 | " self.bos = 1\n", 263 | " self.eos = 2\n", 264 | " self.unk = 3\n", 265 | " \n", 266 | " @classmethod\n", 267 | " def from_text(cls, text):\n", 268 | " char_freq_tuples = collections.Counter(text).most_common(vocab_size - 4)\n", 269 | " vocab, _ = zip(*char_freq_tuples)\n", 270 | " vocab = ['<pad>', '<bos>', '<eos>', '<unk>'] + list(vocab)\n", 271 | " return cls(vocab)\n", 272 | "\n", 273 | " @property\n", 274 | " def vocab_size(self):\n", 275 | " return len(self.vocab)\n", 276 | " \n", 277 | " def text2id(self, text):\n", 278 | " return [self.rev_dict[c] if c in self.rev_dict else self.unk for c in text]\n", 279 | "\n", 280 | " def id2text(self, ids):\n", 281 | " return ''.join(self.vocab[i] for i in ids)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 12, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "with open(data_path) as f:\n", 291 | " text = f.read().replace('\\n', '')\n", 292 | "\n", 293 | "tokenizer = Tokenizer.from_text(text)\n", 294 | "ids = tokenizer.text2id(text)\n", 295 | "\n", 296 | "def split_ndlist(ndlist, size):\n", 297 | " return [np.array(ndlist[i - size:i]) for i in range(size, len(ndlist) + 1, size)]\n", 298 | "\n", 299 | "# (1文字目, 2文字目), (2文字目, 3文字目), ... 
というペアを作る\n", 300 | "# ある時刻の入力に対しその次時刻の出力を学習させるため\n", 301 | "in_sequence_list = split_ndlist(ids[:-1], size=sentence_len)\n", 302 | "out_sequence_list = split_ndlist(ids[1:], size=sentence_len)\n", 303 | "\n", 304 | "in_batch_list = split_ndlist(in_sequence_list, batch_size)\n", 305 | "out_batch_list = split_ndlist(out_sequence_list, batch_size)\n", 306 | "\n", 307 | "# batch_size 個ごとに切り分け\n", 308 | "batch_list = [\n", 309 | " {\n", 310 | " 'in': in_batch,\n", 311 | " 'out': out_batch,\n", 312 | " 'len': np.array([len(seq) for seq in in_batch]),\n", 313 | " }\n", 314 | " for in_batch, out_batch\n", 315 | " in zip(in_batch_list, out_batch_list)\n", 316 | "]" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 13, 322 | "metadata": { 323 | "scrolled": true 324 | }, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "[33, 27, 8, 51, 14, 3]\n", 331 | "こんにちは\n", 332 | "batch list num: 129\n", 333 | "{'in': array([[ 3, 77, 8, ..., 17, 224, 38],\n", 334 | " [ 12, 16, 55, ..., 4, 317, 14],\n", 335 | " [491, 3, 120, ..., 27, 25, 18],\n", 336 | " ...,\n", 337 | " [ 19, 25, 12, ..., 190, 255, 165],\n", 338 | " [ 11, 23, 4, ..., 10, 49, 266],\n", 339 | " [ 30, 12, 15, ..., 4, 14, 55]]), 'out': array([[ 77, 8, 3, ..., 224, 38, 12],\n", 340 | " [ 16, 55, 46, ..., 317, 14, 491],\n", 341 | " [ 3, 120, 3, ..., 25, 18, 7],\n", 342 | " ...,\n", 343 | " [ 25, 12, 10, ..., 255, 165, 11],\n", 344 | " [ 23, 4, 19, ..., 49, 266, 30],\n", 345 | " [ 12, 15, 13, ..., 14, 55, 109]]), 'len': array([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 346 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 347 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 348 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 349 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 350 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 351 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 352 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 353 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 354 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 355 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 356 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 357 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 358 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 359 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 360 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 361 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 362 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 363 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 364 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 365 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 366 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 367 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 368 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 369 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 
50, 50,\n", 370 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 371 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 372 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 373 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 374 | " 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,\n", 375 | " 50, 50])}\n" 376 | ] 377 | } 378 | ], 379 | "source": [ 380 | "print(tokenizer.text2id('こんにちは😁'))\n", 381 | "print(tokenizer.id2text([33, 27, 8, 51, 14, 3]))\n", 382 | "print('batch list num: {}'.format(len(batch_list)))\n", 383 | "print(batch_list[0])" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "# Training" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 14, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "max_epoch = 50\n", 400 | "save_path = 'tmp/rnnlm/model.ckpt'\n", 401 | "log_dir = 'tmp/rnnlm/log/'\n", 402 | "learning_rate = 0.001" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 15, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "if not os.path.isdir(os.path.dirname(save_path)):\n", 412 | " os.makedirs(os.path.dirname(save_path))\n", 413 | "if not os.path.isdir(log_dir):\n", 414 | " os.makedirs(log_dir)\n", 415 | "\n", 416 | "global_step = tf.Variable(0, name='global_step', trainable=False)\n", 417 | "optimizer = tf.train.AdamOptimizer(learning_rate)\n", 418 | "train_op = optimizer.minimize(loss_op, global_step=global_step)\n", 419 | "tf.summary.scalar('loss', loss_op)\n", 420 | "summary_op = tf.summary.merge_all()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "min_loss = 100000.0\n", 430 | "sess = tf.Session()\n", 431 | "summary_writer = tf.summary.FileWriter(log_dir, sess.graph)\n", 432 | "saver = tf.train.Saver()\n", 433 | "\n", 434 | "sess.run(tf.global_variables_initializer())\n", 435 | "for epoch in range(max_epoch):\n", 436 | " random.shuffle(batch_list)\n", 437 | " for batch in tqdm(batch_list):\n", 438 | " feed_dict = {\n", 439 | " in_ph: batch['in'],\n", 440 | " out_ph: batch['out'],\n", 441 | " len_ph: batch['len'],\n", 442 | " }\n", 443 | " _, loss, summary, step = sess.run([train_op, loss_op, summary_op, global_step], feed_dict)\n", 444 | " summary_writer.add_summary(summary, step)\n", 445 | " if loss < min_loss:\n", 446 | " saver.save(sess, save_path)\n", 447 | " min_loss = loss\n", 448 | " print('epoch {}/{} - loss: {}'.format(epoch, max_epoch, loss))\n" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "# Restore" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "load_path = 'learned_model/rnnlm/model.ckpt'\n", 465 | "sess = tf.Session()\n", 466 | "sess.run(tf.global_variables_initializer())\n", 467 | "saver = tf.train.Saver()\n", 468 | "saver.restore(sess, load_path)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "# Generate" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "start_char = '私'\n", 485 | "generated_ids = sess.run(generate_op, {\n", 486 | " gen_start_token_ph: 
tokenizer.text2id(start_char)[0]\n", 487 | "})[0, :, 0]\n", 488 | "generated_text = start_char + tokenizer.id2text(generated_ids)\n", 489 | "print(generated_text)" 490 | ] 491 | } 492 | ], 493 | "metadata": { 494 | "kernelspec": { 495 | "display_name": "Python 3", 496 | "language": "python", 497 | "name": "python3" 498 | }, 499 | "language_info": { 500 | "codemirror_mode": { 501 | "name": "ipython", 502 | "version": 3 503 | }, 504 | "file_extension": ".py", 505 | "mimetype": "text/x-python", 506 | "name": "python", 507 | "nbconvert_exporter": "python", 508 | "pygments_lexer": "ipython3", 509 | "version": "3.6.5" 510 | } 511 | }, 512 | "nbformat": 4, 513 | "nbformat_minor": 2 514 | } 515 | -------------------------------------------------------------------------------- /deepdialog/transformer/README.md: -------------------------------------------------------------------------------- 1 | # Transformer 2 | This is a tf.keras implementation of the Transformer, which as of 2018 is becoming the de facto standard for deep-learning-based natural language processing. 3 | It works in both eager mode and graph mode. 4 | 5 | ## Motivation 6 | The [official Transformer](https://github.com/tensorflow/models/tree/master/official/transformer) is unfortunately written on top of the deprecated tf.layers API, so this is a reimplementation based on tf.keras.(layers|models), which is becoming the standard in TensorFlow 2.0. 7 | To the best of my understanding, it aims at the style that will be recommended for TensorFlow code going forward. 8 | 9 | This implementation also serves as the teaching material for [作って理解する Transformer / Attention](https://qiita.com/halhorn/private/c91497522be27bde17ce). 10 | 11 | ## Install 12 | ```sh 13 | git clone git@github.com:halhorn/deep_dialog_tutorial.git 14 | cd deep_dialog_tutorial 15 | pip install pipenv 16 | pipenv install 17 | ``` 18 | 19 | ## Training 20 | ```sh 21 | pipenv run jupyter lab 22 | ``` 23 | In Jupyter, open deepdialog/transformer/training.ipynb. 24 | -------------------------------------------------------------------------------- /deepdialog/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/deepdialog/transformer/__init__.py -------------------------------------------------------------------------------- /deepdialog/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ''' 4 | Google 公式の Transformer の Attention を tf.keras ベースとして実装しなおしたモデルです。 5 | c.f. 
https://github.com/tensorflow/models/blob/master/official/transformer/model/attention_layer.py 6 | ''' 7 | 8 | 9 | class MultiheadAttention(tf.keras.models.Model): 10 | ''' 11 | Multi-head Attention のモデルです。 12 | 13 | model = MultiheadAttention( 14 | hidden_dim=512, 15 | head_num=8, 16 | dropout_rate=0.1, 17 | ) 18 | model(query, memory, mask, training=True) 19 | ''' 20 | 21 | def __init__(self, hidden_dim: int, head_num: int, dropout_rate: float, *args, **kwargs): 22 | ''' 23 | コンストラクタです。 24 | :param hidden_dim: 隠れ層及び出力の次元 25 | head_num の倍数である必要があります。 26 | :param head_num: ヘッドの数 27 | :param dropout_rate: ドロップアウトする確率 28 | ''' 29 | super().__init__(*args, **kwargs) 30 | self.hidden_dim = hidden_dim 31 | self.head_num = head_num 32 | self.dropout_rate = dropout_rate 33 | 34 | self.q_dense_layer = tf.keras.layers.Dense(hidden_dim, use_bias=False, name='q_dense_layer') 35 | self.k_dense_layer = tf.keras.layers.Dense(hidden_dim, use_bias=False, name='k_dense_layer') 36 | self.v_dense_layer = tf.keras.layers.Dense(hidden_dim, use_bias=False, name='v_dense_layer') 37 | self.output_dense_layer = tf.keras.layers.Dense(hidden_dim, use_bias=False, name='output_dense_layer') 38 | self.attention_dropout_layer = tf.keras.layers.Dropout(dropout_rate) 39 | 40 | def call( 41 | self, 42 | input: tf.Tensor, 43 | memory: tf.Tensor, 44 | attention_mask: tf.Tensor, 45 | training: bool, 46 | ) -> tf.Tensor: 47 | ''' 48 | モデルの実行を行います。 49 | :param input: query のテンソル 50 | :param memory: query に情報を与える memory のテンソル 51 | :param attention_mask: attention weight に適用される mask 52 | shape = [batch_size, 1, q_length, k_length] のものです。 53 | pad 等無視する部分が True となるようなものを指定してください。 54 | :param training: 学習時か推論時かのフラグ 55 | ''' 56 | q = self.q_dense_layer(input) # [batch_size, q_length, hidden_dim] 57 | k = self.k_dense_layer(memory) # [batch_size, m_length, hidden_dim] 58 | v = self.v_dense_layer(memory) 59 | 60 | q = self._split_head(q) # [batch_size, head_num, q_length, hidden_dim/head_num] 61 | k = self._split_head(k) # [batch_size, head_num, m_length, hidden_dim/head_num] 62 | v = self._split_head(v) # [batch_size, head_num, m_length, hidden_dim/head_num] 63 | 64 | depth = self.hidden_dim // self.head_num 65 | q *= depth ** -0.5 # for scaled dot production 66 | 67 | # ここで q と k の内積を取ることで、query と key の関連度のようなものを計算します。 68 | logit = tf.matmul(q, k, transpose_b=True) # [batch_size, head_num, q_length, k_length] 69 | logit += tf.to_float(attention_mask) * input.dtype.min # mask は pad 部分などが1, 他は0 70 | 71 | # softmax を取ることで正規化します 72 | attention_weight = tf.nn.softmax(logit, name='attention_weight') 73 | attention_weight = self.attention_dropout_layer(attention_weight, training=training) 74 | 75 | # 重みに従って value から情報を引いてきます 76 | attention_output = tf.matmul(attention_weight, v) # [batch_size, head_num, q_length, hidden_dim/head_num] 77 | attention_output = self._combine_head(attention_output) # [batch_size, q_length, hidden_dim] 78 | return self.output_dense_layer(attention_output) 79 | 80 | def _split_head(self, x: tf.Tensor) -> tf.Tensor: 81 | ''' 82 | 入力の tensor の hidden_dim の次元をいくつかのヘッドに分割します。 83 | 84 | 入力 shape: [batch_size, length, hidden_dim] の時 85 | 出力 shape: [batch_size, head_num, length, hidden_dim//head_num] 86 | となります。 87 | ''' 88 | with tf.name_scope('split_head'): 89 | batch_size, length, hidden_dim = tf.unstack(tf.shape(x)) 90 | x = tf.reshape(x, [batch_size, length, self.head_num, self.hidden_dim // self.head_num]) 91 | return tf.transpose(x, [0, 2, 1, 3]) 92 | 93 | def _combine_head(self, x: tf.Tensor) -> tf.Tensor: 94 | 
''' 95 | 入力の tensor の各ヘッドを結合します。 _split_head の逆変換です。 96 | 97 | 入力 shape: [batch_size, head_num, length, hidden_dim//head_num] の時 98 | 出力 shape: [batch_size, length, hidden_dim] 99 | となります。 100 | ''' 101 | with tf.name_scope('combine_head'): 102 | batch_size, _, length, _ = tf.unstack(tf.shape(x)) 103 | x = tf.transpose(x, [0, 2, 1, 3]) 104 | return tf.reshape(x, [batch_size, length, self.hidden_dim]) 105 | 106 | 107 | class SelfAttention(MultiheadAttention): 108 | ''' 109 | Multi-head Attention による自己注意です。 110 | ''' 111 | def call( # type: ignore 112 | self, 113 | input: tf.Tensor, 114 | attention_mask: tf.Tensor, 115 | training: bool, 116 | ) -> tf.Tensor: 117 | return super().call( 118 | input=input, 119 | memory=input, 120 | attention_mask=attention_mask, 121 | training=training, 122 | ) 123 | 124 | 125 | class SimpleAttention(tf.keras.models.Model): 126 | ''' 127 | Attention の説明をするための、 Multi-head ではない単純な Attention です。 128 | ''' 129 | 130 | def __init__(self, depth: int, *args, **kwargs): 131 | ''' 132 | コンストラクタです。 133 | :param depth: 隠れ層及び出力の次元 134 | ''' 135 | super().__init__(*args, **kwargs) 136 | self.depth = depth 137 | 138 | self.q_dense_layer = tf.keras.layers.Dense(depth, use_bias=False, name='q_dense_layer') 139 | self.k_dense_layer = tf.keras.layers.Dense(depth, use_bias=False, name='k_dense_layer') 140 | self.v_dense_layer = tf.keras.layers.Dense(depth, use_bias=False, name='v_dense_layer') 141 | self.output_dense_layer = tf.keras.layers.Dense(depth, use_bias=False, name='output_dense_layer') 142 | 143 | def call(self, input: tf.Tensor, memory: tf.Tensor) -> tf.Tensor: 144 | ''' 145 | モデルの実行を行います。 146 | :param input: query のテンソル 147 | :param memory: query に情報を与える memory のテンソル 148 | ''' 149 | q = self.q_dense_layer(input) # [batch_size, q_length, depth] 150 | k = self.k_dense_layer(memory) # [batch_size, m_length, depth] 151 | v = self.v_dense_layer(memory) 152 | 153 | # ここで q と k の内積を取ることで、query と key の関連度のようなものを計算します。 154 | logit = tf.matmul(q, k, transpose_b=True) # [batch_size, q_length, k_length] 155 | 156 | # softmax を取ることで正規化します 157 | attention_weight = tf.nn.softmax(logit, name='attention_weight') 158 | 159 | # 重みに従って value から情報を引いてきます 160 | attention_output = tf.matmul(attention_weight, v) # [batch_size, q_length, depth] 161 | return self.output_dense_layer(attention_output) 162 | -------------------------------------------------------------------------------- /deepdialog/transformer/common_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class FeedForwardNetwork(tf.keras.models.Model): 5 | ''' 6 | Transformer 用の Position-wise Feedforward Neural Network です。 7 | ''' 8 | def __init__(self, hidden_dim: int, dropout_rate: float, *args, **kwargs) -> None: 9 | super().__init__(*args, **kwargs) 10 | self.hidden_dim = hidden_dim 11 | self.dropout_rate = dropout_rate 12 | 13 | self.filter_dense_layer = tf.keras.layers.Dense(hidden_dim * 4, use_bias=True, 14 | activation=tf.nn.relu, name='filter_layer') 15 | self.output_dense_layer = tf.keras.layers.Dense(hidden_dim, use_bias=True, name='output_layer') 16 | self.dropout_layer = tf.keras.layers.Dropout(dropout_rate) 17 | 18 | def call(self, input: tf.Tensor, training: bool) -> tf.Tensor: 19 | ''' 20 | FeedForwardNetwork を適用します。 21 | :param input: shape = [batch_size, length, hidden_dim] 22 | :return: shape = [batch_size, length, hidden_dim] 23 | ''' 24 | tensor = self.filter_dense_layer(input) 25 | tensor = self.dropout_layer(tensor, 
training=training) 26 | return self.output_dense_layer(tensor) 27 | 28 | 29 | class ResidualNormalizationWrapper(tf.keras.models.Model): 30 | ''' 31 | 与えられたレイヤー(もしくはモデル)に対して、下記のノーマライゼーションを行います。 32 | - Layer Normalization 33 | - Dropout 34 | - Residual Connection 35 | ''' 36 | def __init__(self, layer: tf.keras.layers.Layer, dropout_rate: float, *args, **kwargs) -> None: 37 | super().__init__(*args, **kwargs) 38 | self.layer = layer 39 | self.layer_normalization = LayerNormalization() 40 | self.dropout_layer = tf.keras.layers.Dropout(dropout_rate) 41 | 42 | def call(self, input: tf.Tensor, training: bool, *args, **kwargs) -> tf.Tensor: 43 | tensor = self.layer_normalization(input) 44 | tensor = self.layer(tensor, training=training, *args, **kwargs) 45 | tensor = self.dropout_layer(tensor, training=training) 46 | return input + tensor 47 | 48 | 49 | class LayerNormalization(tf.keras.layers.Layer): 50 | ''' 51 | レイヤーノーマライゼーションです。 52 | レイヤーの出力が平均 bias, 標準偏差 scale になるように調整します。 53 | ''' 54 | def build(self, input_shape: tf.TensorShape) -> None: 55 | hidden_dim = input_shape[-1] 56 | self.scale = self.add_weight('layer_norm_scale', shape=[hidden_dim], 57 | initializer=tf.ones_initializer()) 58 | self.bias = self.add_weight('layer_norm_bias', [hidden_dim], 59 | initializer=tf.zeros_initializer()) 60 | super().build(input_shape) 61 | 62 | def call(self, x: tf.Tensor, epsilon: float = 1e-6) -> tf.Tensor: 63 | mean = tf.reduce_mean(x, axis=[-1], keepdims=True) 64 | variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True) 65 | norm_x = (x - mean) * tf.rsqrt(variance + epsilon) 66 | 67 | return norm_x * self.scale + self.bias 68 | -------------------------------------------------------------------------------- /deepdialog/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import math 3 | 4 | PAD_ID = 0 5 | 6 | 7 | class TokenEmbedding(tf.keras.layers.Layer): 8 | ''' 9 | トークン列を Embedded Vector 列に変換します。 10 | ''' 11 | def __init__(self, vocab_size: int, embedding_dim: int, dtype=tf.float32, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.vocab_size = vocab_size 14 | self.embedding_dim = embedding_dim 15 | self.dtype_ = dtype 16 | 17 | def build(self, input_shape: tf.TensorShape) -> None: 18 | self.lookup_table = self.add_variable( 19 | name='token_embedding', 20 | shape=[self.vocab_size, self.embedding_dim], 21 | dtype=self.dtype_, 22 | initializer=tf.random_normal_initializer(0., self.embedding_dim ** -0.5), 23 | ) 24 | super().build(input_shape) 25 | 26 | def call(self, input: tf.Tensor) -> tf.Tensor: 27 | mask = tf.to_float(tf.not_equal(input, PAD_ID)) 28 | embedding = tf.nn.embedding_lookup(self.lookup_table, input) 29 | embedding *= tf.expand_dims(mask, -1) # 元々 PAD だった部分を0にする 30 | return embedding * self.embedding_dim ** 0.5 31 | 32 | 33 | class AddPositionalEncoding(tf.keras.layers.Layer): 34 | ''' 35 | 入力テンソルに対し、位置の情報を付与して返すレイヤーです。 36 | see: https://arxiv.org/pdf/1706.03762.pdf 37 | 38 | PE_{pos, 2i} = sin(pos / 10000^{2i / d_model}) 39 | PE_{pos, 2i+1} = cos(pos / 10000^{2i / d_model}) 40 | ''' 41 | def call(self, inputs: tf.Tensor) -> tf.Tensor: 42 | fl_type = inputs.dtype 43 | batch_size, max_length, depth = tf.unstack(tf.shape(inputs)) 44 | 45 | depth_counter = tf.range(depth) // 2 * 2 # 0, 0, 2, 2, 4, ... 
46 | depth_matrix = tf.tile(tf.expand_dims(depth_counter, 0), [max_length, 1]) # [max_length, depth] 47 | depth_matrix = tf.pow(10000.0, tf.cast(depth_matrix / depth, fl_type)) # [max_length, depth] 48 | 49 | # cos(x) == sin(x + π/2) 50 | phase = tf.cast(tf.range(depth) % 2, fl_type) * math.pi / 2 # 0, π/2, 0, π/2, ... 51 | phase_matrix = tf.tile(tf.expand_dims(phase, 0), [max_length, 1]) # [max_length, depth] 52 | 53 | pos_counter = tf.range(max_length) 54 | pos_matrix = tf.cast(tf.tile(tf.expand_dims(pos_counter, 1), [1, depth]), fl_type) # [max_length, depth] 55 | 56 | positional_encoding = tf.sin(pos_matrix / depth_matrix + phase_matrix) 57 | # [batch_size, max_length, depth] 58 | positional_encoding = tf.tile(tf.expand_dims(positional_encoding, 0), [batch_size, 1, 1]) 59 | 60 | return inputs + positional_encoding 61 | -------------------------------------------------------------------------------- /deepdialog/transformer/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # copy from https://github.com/tensorflow/models/blob/master/official/transformer/utils/metrics.py 4 | 5 | 6 | def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size): 7 | """Calculate cross entropy loss while ignoring padding. 8 | Args: 9 | logits: Tensor of size [batch_size, length_logits, vocab_size] 10 | labels: Tensor of size [batch_size, length_labels] 11 | smoothing: Label smoothing constant, used to determine the on and off values 12 | vocab_size: int size of the vocabulary 13 | Returns: 14 | Returns the cross entropy loss and weight tensors: float32 tensors with 15 | shape [batch_size, max(length_logits, length_labels)] 16 | """ 17 | with tf.name_scope("loss", values=[logits, labels]): 18 | logits, labels = _pad_tensors_to_same_length(logits, labels) 19 | 20 | # Calculate smoothing cross entropy 21 | with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]): 22 | confidence = 1.0 - smoothing 23 | low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1) 24 | soft_targets = tf.one_hot( 25 | tf.cast(labels, tf.int32), 26 | depth=vocab_size, 27 | on_value=confidence, 28 | off_value=low_confidence) 29 | xentropy = tf.nn.softmax_cross_entropy_with_logits_v2( 30 | logits=logits, labels=soft_targets) 31 | 32 | # Calculate the best (lowest) possible value of cross entropy, and 33 | # subtract from the cross entropy loss. 
34 | normalizing_constant = -( 35 | confidence * tf.log(confidence) + tf.to_float(vocab_size - 1) * 36 | low_confidence * tf.log(low_confidence + 1e-20)) 37 | xentropy -= normalizing_constant 38 | 39 | weights = tf.to_float(tf.not_equal(labels, 0)) 40 | return xentropy * weights, weights 41 | 42 | 43 | def padded_accuracy(logits, labels): 44 | """Percentage of times that predictions matches labels on non-0s.""" 45 | with tf.variable_scope("padded_accuracy", values=[logits, labels]): 46 | logits, labels = _pad_tensors_to_same_length(logits, labels) 47 | weights = tf.to_float(tf.not_equal(labels, 0)) 48 | outputs = tf.to_int32(tf.argmax(logits, axis=-1)) 49 | padded_labels = tf.to_int32(labels) 50 | return tf.to_float(tf.equal(outputs, padded_labels)), weights 51 | 52 | 53 | def _pad_tensors_to_same_length(x, y): 54 | """Pad x and y so that the results have the same length (second dimension).""" 55 | with tf.name_scope("pad_to_same_length"): 56 | x_length = tf.shape(x)[1] 57 | y_length = tf.shape(y)[1] 58 | 59 | max_length = tf.maximum(x_length, y_length) 60 | 61 | x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]]) 62 | y = tf.pad(y, [[0, 0], [0, max_length - y_length]]) 63 | return x, y 64 | -------------------------------------------------------------------------------- /deepdialog/transformer/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/deepdialog/transformer/preprocess/__init__.py -------------------------------------------------------------------------------- /deepdialog/transformer/preprocess/batch_generator.py: -------------------------------------------------------------------------------- 1 | import sentencepiece as spm 2 | import numpy as np 3 | import random 4 | from typing import List, Sequence, Tuple 5 | 6 | ENCODER_INPUT_NODE = 'transformer/encoder_input:0' 7 | DECODER_INPUT_NODE = 'transformer/decoder_input:0' 8 | IS_TRAINING_NODE = 'transformer/is_training:0' 9 | 10 | 11 | class BatchGenerator: 12 | def __init__( 13 | self, 14 | max_length=50, 15 | spm_model_path: str = 'deepdialog/transformer/preprocess/spm_natsume.model' 16 | ) -> None: 17 | self.max_length = max_length 18 | self.sp = spm.SentencePieceProcessor() 19 | self.sp.load(spm_model_path) 20 | self.bos = self.sp.piece_to_id('<s>') 21 | self.eos = self.sp.piece_to_id('</s>') 22 | self.pad = 0 23 | 24 | @property 25 | def vocab_size(self) -> int: 26 | return self.sp.get_piece_size() 27 | 28 | def load(self, file_path: str) -> None: 29 | with open(file_path) as f: 30 | lines = [line.strip() for line in f.readlines()] 31 | self.data = self._create_data(lines) 32 | 33 | def get_batch(self, batch_size: int = 128, shuffle=True): 34 | while True: 35 | if shuffle: 36 | random.shuffle(self.data) 37 | raw_batch_list = self._split(self.data, batch_size) 38 | for raw_batch in raw_batch_list: 39 | questions, answers = zip(*raw_batch) 40 | yield { 41 | ENCODER_INPUT_NODE: self._convert_to_array(questions), 42 | DECODER_INPUT_NODE: self._convert_to_array(answers), 43 | IS_TRAINING_NODE: True, 44 | } 45 | 46 | def _create_data(self, lines: Sequence[str]) -> List[Tuple[List[int], List[int]]]: 47 | questions = [self._create_question(line) for line in lines[:-1]] 48 | answers = [self._create_answer(line) for line in lines[1:]] 49 | return list(zip(questions, answers)) 50 | 51 | def _create_question(self, sentence) -> List[int]: 52 | ids = self.sp.encode_as_ids(sentence) 53 
| return ids[:self.max_length] 54 | 55 | def _create_answer(self, sentence: str) -> List[int]: 56 | ids = self.sp.encode_as_ids(sentence) 57 | return [self.bos] + ids[:self.max_length - 2] + [self.eos] 58 | 59 | def _split(self, nd_list: Sequence, batch_size: int) -> List[List]: 60 | return [list(nd_list[i - batch_size:i]) for i in range(batch_size, len(nd_list) + 1, batch_size)] 61 | 62 | def _convert_to_array(self, id_list_list: Sequence[Sequence[int]]) -> np.ndarray: 63 | max_len = max([len(id_list) for id_list in id_list_list]) 64 | 65 | return np.array( 66 | [list(id_list) + [self.pad] * (max_len - len(id_list)) for id_list in id_list_list], 67 | dtype=np.int32, 68 | ) 69 | -------------------------------------------------------------------------------- /deepdialog/transformer/preprocess/create_tokenizer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | pipenv run spm_train --pad_id=0 --bos_id=1 --eos_id=2 --unk_id=3 --input=data/natsume.txt --model_prefix=deepdialog/transformer/preprocess/spm_natsume --vocab_size=8000 3 | -------------------------------------------------------------------------------- /deepdialog/transformer/preprocess/spm_natsume.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/deepdialog/transformer/preprocess/spm_natsume.model -------------------------------------------------------------------------------- /deepdialog/transformer/training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "current dir: /home/harumitsu.nobuta/git/deep_dialog_tutorial\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "# カレントディレクトリをリポジトリ直下にするおまじない\n", 18 | "import os\n", 19 | "while os.getcwd().split('/')[-1] != 'deep_dialog_tutorial': os.chdir('..')\n", 20 | "print('current dir:', os.getcwd())" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import tensorflow as tf\n", 30 | "from deepdialog.transformer.transformer import Transformer\n", 31 | "from deepdialog.transformer.preprocess.batch_generator import BatchGenerator" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Create Data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "data_path = 'data/natsume.txt'" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 4, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "batch_generator = BatchGenerator()\n", 57 | "batch_generator.load(data_path)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 5, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "vocab_size = batch_generator.vocab_size" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "# Create Model" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "graph = tf.Graph()\n", 83 | "with graph.as_default():\n", 84 | " transformer = Transformer(\n", 85 | " vocab_size=vocab_size,\n", 86 | " 
hopping_num=4,\n", 87 | " head_num=8,\n", 88 | " hidden_dim=512,\n", 89 | " dropout_rate=0.1,\n", 90 | " max_length=50,\n", 91 | " )\n", 92 | " transformer.build_graph()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "# Create Training Graph" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "save_dir = 'tmp/learning/transformer/'\n", 109 | "log_dir = os.path.join(save_dir, 'log')\n", 110 | "ckpt_path = os.path.join(save_dir, 'checkpoints/model.ckpt')\n", 111 | "\n", 112 | "os.makedirs(log_dir, exist_ok=True)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "with graph.as_default():\n", 122 | " global_step = tf.train.get_or_create_global_step()\n", 123 | " \n", 124 | " learning_rate = tf.placeholder(dtype=tf.float32, name='learning_rate')\n", 125 | " optimizer = tf.train.AdamOptimizer(\n", 126 | " learning_rate=learning_rate,\n", 127 | " beta2=0.98,\n", 128 | " )\n", 129 | " optimize_op = optimizer.minimize(transformer.loss, global_step=global_step)\n", 130 | "\n", 131 | " summary_op = tf.summary.merge([\n", 132 | " tf.summary.scalar('train/loss', transformer.loss),\n", 133 | " tf.summary.scalar('train/acc', transformer.acc),\n", 134 | " tf.summary.scalar('train/learning_rate', learning_rate),\n", 135 | " ], name='train_summary')\n", 136 | " summary_writer = tf.summary.FileWriter(log_dir, graph)\n", 137 | " saver = tf.train.Saver()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "# Train" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "max_step = 100000\n", 154 | "batch_size = 128\n", 155 | "max_learning_rate = 0.0001\n", 156 | "warmup_step = 4000" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def get_learning_rate(step: int) -> float:\n", 166 | " rate = min(step ** -0.5, step * warmup_step ** -1.5) / warmup_step ** -0.5\n", 167 | " return max_learning_rate * rate" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "with graph.as_default():\n", 177 | " sess = tf.Session()\n", 178 | " sess.run(tf.global_variables_initializer())\n", 179 | " step = 0" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "0: loss: 8.456110000610352,\t acc: 0.00042753314482979476\n", 192 | "100: loss: 8.063234329223633,\t acc: 0.061625875532627106\n", 193 | "200: loss: 7.624130725860596,\t acc: 0.08877591043710709\n", 194 | "300: loss: 7.2388014793396,\t acc: 0.15279187262058258\n", 195 | "400: loss: 6.831792831420898,\t acc: 0.15193502604961395\n", 196 | "800: loss: 6.131741523742676,\t acc: 0.16190476715564728\n", 197 | "900: loss: 6.099096298217773,\t acc: 0.16284014284610748\n", 198 | "1000: loss: 6.00535774230957,\t acc: 0.17789646983146667\n", 199 | "1100: loss: 5.965171813964844,\t acc: 0.175105482339859\n", 200 | "1200: loss: 6.056082248687744,\t acc: 0.16189385950565338\n", 201 | "1300: loss: 5.734684944152832,\t acc: 0.19673576951026917\n", 202 | 
"1400: loss: 5.750892162322998,\t acc: 0.19291338324546814\n", 203 | "1500: loss: 5.762808322906494,\t acc: 0.19190600514411926\n", 204 | "1600: loss: 5.654571056365967,\t acc: 0.20242369174957275\n", 205 | "1700: loss: 5.622186660766602,\t acc: 0.20016610622406006\n", 206 | "1800: loss: 5.621791362762451,\t acc: 0.19756199419498444\n", 207 | "1900: loss: 5.568434238433838,\t acc: 0.20691144466400146\n", 208 | "2000: loss: 5.44687557220459,\t acc: 0.21381579339504242\n", 209 | "2100: loss: 5.414419174194336,\t acc: 0.2064831256866455\n", 210 | "2200: loss: 5.508380889892578,\t acc: 0.19874311983585358\n", 211 | "2300: loss: 5.487494945526123,\t acc: 0.20331446826457977\n", 212 | "2400: loss: 5.319317817687988,\t acc: 0.21157850325107574\n", 213 | "2500: loss: 5.343756675720215,\t acc: 0.20568837225437164\n", 214 | "2600: loss: 5.286867141723633,\t acc: 0.20811519026756287\n", 215 | "2700: loss: 5.276029586791992,\t acc: 0.2098710536956787\n", 216 | "2800: loss: 5.206023216247559,\t acc: 0.22122378647327423\n", 217 | "2900: loss: 5.038154602050781,\t acc: 0.23236124217510223\n", 218 | "3000: loss: 5.104355812072754,\t acc: 0.22743521630764008\n", 219 | "3100: loss: 5.126926422119141,\t acc: 0.2107037454843521\n", 220 | "3200: loss: 5.014754772186279,\t acc: 0.22401656210422516\n", 221 | "3300: loss: 5.025139808654785,\t acc: 0.23066666722297668\n", 222 | "3400: loss: 4.941408634185791,\t acc: 0.23014168441295624\n", 223 | "3500: loss: 4.912785053253174,\t acc: 0.24003392457962036\n", 224 | "3600: loss: 5.00833797454834,\t acc: 0.22020934522151947\n", 225 | "3700: loss: 4.822469711303711,\t acc: 0.24779582023620605\n", 226 | "3800: loss: 4.809948444366455,\t acc: 0.24245116114616394\n", 227 | "3900: loss: 4.824687480926514,\t acc: 0.2477102428674698\n", 228 | "4000: loss: 4.854027271270752,\t acc: 0.23831196129322052\n", 229 | "4100: loss: 4.789676666259766,\t acc: 0.23141223192214966\n", 230 | "4200: loss: 4.783527374267578,\t acc: 0.2415730357170105\n", 231 | "4300: loss: 4.59663724899292,\t acc: 0.24965955317020416\n", 232 | "4400: loss: 4.752450466156006,\t acc: 0.23962344229221344\n", 233 | "4500: loss: 4.8018035888671875,\t acc: 0.23905034363269806\n", 234 | "4600: loss: 4.587310791015625,\t acc: 0.2567099630832672\n", 235 | "4700: loss: 4.544100761413574,\t acc: 0.2522633671760559\n", 236 | "4800: loss: 4.494418621063232,\t acc: 0.2575145661830902\n", 237 | "4900: loss: 4.54245662689209,\t acc: 0.24596431851387024\n", 238 | "5000: loss: 4.445837497711182,\t acc: 0.2686775028705597\n", 239 | "5100: loss: 4.514632701873779,\t acc: 0.2603395879268646\n", 240 | "5200: loss: 4.388525009155273,\t acc: 0.2585214674472809\n", 241 | "5300: loss: 4.37705659866333,\t acc: 0.27642637491226196\n", 242 | "5400: loss: 4.34771728515625,\t acc: 0.2738979756832123\n", 243 | "5500: loss: 4.450374603271484,\t acc: 0.25080257654190063\n", 244 | "5600: loss: 4.400920867919922,\t acc: 0.27194955945014954\n", 245 | "5700: loss: 4.484810829162598,\t acc: 0.24290083348751068\n", 246 | "5800: loss: 4.316364288330078,\t acc: 0.2655201256275177\n", 247 | "5900: loss: 4.273972511291504,\t acc: 0.2746448516845703\n", 248 | "6000: loss: 4.320553302764893,\t acc: 0.2808062434196472\n", 249 | "6100: loss: 4.388421058654785,\t acc: 0.2707231044769287\n", 250 | "6200: loss: 4.133350372314453,\t acc: 0.2922716736793518\n", 251 | "6300: loss: 4.329468727111816,\t acc: 0.2611134946346283\n", 252 | "6400: loss: 4.322432041168213,\t acc: 0.26906222105026245\n", 253 | "6500: loss: 4.262185573577881,\t acc: 
0.2718527913093567\n", 254 | "6600: loss: 4.347004413604736,\t acc: 0.26810672879219055\n", 255 | "6700: loss: 4.2577223777771,\t acc: 0.26779359579086304\n", 256 | "6800: loss: 4.383606910705566,\t acc: 0.25628742575645447\n", 257 | "6900: loss: 4.17828893661499,\t acc: 0.29600733518600464\n", 258 | "7000: loss: 4.274170875549316,\t acc: 0.2644462287425995\n", 259 | "7100: loss: 4.192056655883789,\t acc: 0.2744651734828949\n", 260 | "7200: loss: 4.157262325286865,\t acc: 0.2859618663787842\n", 261 | "7300: loss: 4.143908977508545,\t acc: 0.27880510687828064\n", 262 | "7400: loss: 4.1946563720703125,\t acc: 0.27994734048843384\n", 263 | "7500: loss: 3.973661422729492,\t acc: 0.29752808809280396\n", 264 | "7600: loss: 4.075747966766357,\t acc: 0.2974516749382019\n", 265 | "7700: loss: 4.045965671539307,\t acc: 0.2951042652130127\n", 266 | "7800: loss: 4.085124492645264,\t acc: 0.28605425357818604\n", 267 | "7900: loss: 4.138719081878662,\t acc: 0.2756514549255371\n", 268 | "8000: loss: 4.048675060272217,\t acc: 0.29752808809280396\n", 269 | "8100: loss: 4.119937419891357,\t acc: 0.27346569299697876\n", 270 | "8200: loss: 4.129490852355957,\t acc: 0.27090984582901\n", 271 | "8300: loss: 4.024595260620117,\t acc: 0.29987505078315735\n", 272 | "8400: loss: 4.097468376159668,\t acc: 0.2881355881690979\n", 273 | "8500: loss: 4.125740051269531,\t acc: 0.28338098526000977\n", 274 | "8600: loss: 4.064797878265381,\t acc: 0.28390368819236755\n", 275 | "8700: loss: 4.081852436065674,\t acc: 0.2783898413181305\n", 276 | "8800: loss: 4.1344313621521,\t acc: 0.2897196114063263\n", 277 | "8900: loss: 4.147453308105469,\t acc: 0.2754112184047699\n", 278 | "9000: loss: 4.041755676269531,\t acc: 0.2966066300868988\n", 279 | "9100: loss: 4.057901859283447,\t acc: 0.2802101671695709\n", 280 | "9200: loss: 3.9369938373565674,\t acc: 0.2987436056137085\n", 281 | "9300: loss: 4.0047502517700195,\t acc: 0.2983802258968353\n", 282 | "9400: loss: 4.050186634063721,\t acc: 0.2909336984157562\n", 283 | "9500: loss: 4.042887210845947,\t acc: 0.2975391447544098\n", 284 | "9600: loss: 3.9739620685577393,\t acc: 0.2875226140022278\n", 285 | "9700: loss: 4.015842437744141,\t acc: 0.28633594512939453\n", 286 | "9800: loss: 4.048672199249268,\t acc: 0.27763205766677856\n", 287 | "9900: loss: 4.000374794006348,\t acc: 0.2982300817966461\n", 288 | "10000: loss: 3.9310991764068604,\t acc: 0.2945859730243683\n", 289 | "10100: loss: 3.913878917694092,\t acc: 0.2960662543773651\n", 290 | "10200: loss: 3.9307632446289062,\t acc: 0.2983333468437195\n", 291 | "10300: loss: 3.889249563217163,\t acc: 0.30311354994773865\n", 292 | "10400: loss: 3.831475019454956,\t acc: 0.3099730312824249\n", 293 | "10500: loss: 4.028707027435303,\t acc: 0.2801155149936676\n", 294 | "10600: loss: 3.9097371101379395,\t acc: 0.3073878586292267\n", 295 | "10700: loss: 3.912473678588867,\t acc: 0.3038083016872406\n", 296 | "10800: loss: 3.845147132873535,\t acc: 0.30451127886772156\n", 297 | "10900: loss: 3.8536312580108643,\t acc: 0.29784536361694336\n", 298 | "11000: loss: 3.7893378734588623,\t acc: 0.31664469838142395\n", 299 | "11100: loss: 3.8203961849212646,\t acc: 0.31642410159111023\n", 300 | "11200: loss: 3.7602591514587402,\t acc: 0.32076290249824524\n", 301 | "11300: loss: 3.8646557331085205,\t acc: 0.3108544945716858\n", 302 | "11400: loss: 3.8545830249786377,\t acc: 0.30572083592414856\n", 303 | "11500: loss: 3.8321175575256348,\t acc: 0.302325576543808\n", 304 | "11600: loss: 3.719156265258789,\t acc: 0.3173781931400299\n", 305 | "11700: 
loss: 3.8117899894714355,\t acc: 0.3128444254398346\n", 306 | "11800: loss: 3.886993408203125,\t acc: 0.30528542399406433\n", 307 | "11900: loss: 3.775373935699463,\t acc: 0.31563544273376465\n", 308 | "12000: loss: 3.7622268199920654,\t acc: 0.3165532946586609\n", 309 | "12100: loss: 3.7508909702301025,\t acc: 0.31331828236579895\n", 310 | "12200: loss: 3.8010976314544678,\t acc: 0.30648064613342285\n", 311 | "12300: loss: 3.8352155685424805,\t acc: 0.3165552616119385\n", 312 | "12400: loss: 3.7904624938964844,\t acc: 0.311710923910141\n", 313 | "12500: loss: 3.7200119495391846,\t acc: 0.30683282017707825\n", 314 | "12600: loss: 3.667607069015503,\t acc: 0.31440070271492004\n", 315 | "12700: loss: 3.7935903072357178,\t acc: 0.31179186701774597\n", 316 | "12800: loss: 3.629826545715332,\t acc: 0.32864511013031006\n", 317 | "12900: loss: 3.8675429821014404,\t acc: 0.30778688192367554\n", 318 | "13000: loss: 3.7820820808410645,\t acc: 0.3183133602142334\n", 319 | "13100: loss: 3.780679702758789,\t acc: 0.3036717176437378\n", 320 | "13200: loss: 3.7960171699523926,\t acc: 0.31540894508361816\n", 321 | "13300: loss: 3.7389650344848633,\t acc: 0.3232235610485077\n", 322 | "13400: loss: 3.7431070804595947,\t acc: 0.3018943965435028\n", 323 | "13500: loss: 3.7474818229675293,\t acc: 0.32277923822402954\n", 324 | "13600: loss: 3.6472673416137695,\t acc: 0.33138489723205566\n", 325 | "13700: loss: 3.659785270690918,\t acc: 0.33216631412506104\n", 326 | "18500: loss: 3.3232548236846924,\t acc: 0.3600183129310608\n", 327 | "18600: loss: 3.391157388687134,\t acc: 0.3583032488822937\n", 328 | "18700: loss: 3.426806926727295,\t acc: 0.3510917127132416\n", 329 | "18800: loss: 3.3631210327148438,\t acc: 0.3566824197769165\n", 330 | "18900: loss: 3.453577995300293,\t acc: 0.3440541625022888\n", 331 | "19000: loss: 3.4107024669647217,\t acc: 0.3549356162548065\n", 332 | "19100: loss: 3.2855074405670166,\t acc: 0.3664220869541168\n", 333 | "19200: loss: 3.357651948928833,\t acc: 0.3394950330257416\n", 334 | "19300: loss: 3.3485193252563477,\t acc: 0.3612521290779114\n", 335 | "19400: loss: 3.316305637359619,\t acc: 0.36931312084198\n", 336 | "19500: loss: 3.320740222930908,\t acc: 0.36003559827804565\n", 337 | "19600: loss: 3.3496923446655273,\t acc: 0.3532053828239441\n", 338 | "19700: loss: 3.364785671234131,\t acc: 0.35251179337501526\n", 339 | "19800: loss: 3.3690717220306396,\t acc: 0.3566029369831085\n", 340 | "19900: loss: 3.2894585132598877,\t acc: 0.36442893743515015\n", 341 | "20000: loss: 3.3753769397735596,\t acc: 0.3489666283130646\n", 342 | "20100: loss: 3.261539936065674,\t acc: 0.357371062040329\n", 343 | "20200: loss: 3.282179117202759,\t acc: 0.3590339124202728\n", 344 | "20300: loss: 3.3219830989837646,\t acc: 0.3520814776420593\n", 345 | "20400: loss: 3.375070095062256,\t acc: 0.35904356837272644\n", 346 | "20500: loss: 3.3187315464019775,\t acc: 0.3668763041496277\n", 347 | "20600: loss: 3.3546907901763916,\t acc: 0.3472447395324707\n", 348 | "20700: loss: 3.2267580032348633,\t acc: 0.3754822015762329\n", 349 | "20800: loss: 3.3306498527526855,\t acc: 0.3590814173221588\n", 350 | "20900: loss: 3.3414225578308105,\t acc: 0.3506008982658386\n", 351 | "21000: loss: 3.3138365745544434,\t acc: 0.36109796166419983\n", 352 | "21100: loss: 3.304713726043701,\t acc: 0.3556230962276459\n", 353 | "21200: loss: 3.2739338874816895,\t acc: 0.3778371214866638\n", 354 | "21300: loss: 3.33601450920105,\t acc: 0.35638731718063354\n", 355 | "21400: loss: 3.2664527893066406,\t acc: 0.3623490333557129\n", 
356 | "21500: loss: 3.1983115673065186,\t acc: 0.38477054238319397\n", 357 | "21600: loss: 3.1049516201019287,\t acc: 0.3921971321105957\n", 358 | "21700: loss: 3.235291004180908,\t acc: 0.3630598485469818\n", 359 | "21800: loss: 3.2732791900634766,\t acc: 0.36550888419151306\n", 360 | "21900: loss: 3.322505235671997,\t acc: 0.3609052002429962\n", 361 | "22000: loss: 3.2371673583984375,\t acc: 0.3626425862312317\n", 362 | "22100: loss: 3.2942678928375244,\t acc: 0.35512155294418335\n", 363 | "22200: loss: 3.2588412761688232,\t acc: 0.356609582901001\n", 364 | "22300: loss: 3.2395517826080322,\t acc: 0.3725900948047638\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "with graph.as_default():\n", 370 | " for batch in batch_generator.get_batch(batch_size=batch_size):\n", 371 | " feed = {\n", 372 | " **batch,\n", 373 | " learning_rate: get_learning_rate(step + 1),\n", 374 | " }\n", 375 | " _, loss, acc, step, summary = sess.run([optimize_op, transformer.loss, transformer.acc, global_step, summary_op], feed_dict=feed)\n", 376 | " summary_writer.add_summary(summary, step)\n", 377 | " \n", 378 | " if step % 100 == 0:\n", 379 | " print(f'{step}: loss: {loss},\\t acc: {acc}')\n", 380 | " saver.save(sess, ckpt_path, global_step=step)" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [] 389 | } 390 | ], 391 | "metadata": { 392 | "kernelspec": { 393 | "display_name": "Python 3", 394 | "language": "python", 395 | "name": "python3" 396 | }, 397 | "language_info": { 398 | "codemirror_mode": { 399 | "name": "ipython", 400 | "version": 3 401 | }, 402 | "file_extension": ".py", 403 | "mimetype": "text/x-python", 404 | "name": "python", 405 | "nbconvert_exporter": "python", 406 | "pygments_lexer": "ipython3", 407 | "version": "3.6.5" 408 | } 409 | }, 410 | "nbformat": 4, 411 | "nbformat_minor": 2 412 | } 413 | -------------------------------------------------------------------------------- /deepdialog/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from typing import List 3 | from .common_layer import FeedForwardNetwork, ResidualNormalizationWrapper, LayerNormalization 4 | from .embedding import TokenEmbedding, AddPositionalEncoding 5 | from .attention import MultiheadAttention, SelfAttention 6 | from .metrics import padded_cross_entropy_loss, padded_accuracy 7 | 8 | PAD_ID = 0 9 | 10 | 11 | class Transformer(tf.keras.models.Model): 12 | ''' 13 | Transformer モデルです。 14 | ''' 15 | def __init__( 16 | self, 17 | vocab_size: int, 18 | hopping_num: int = 4, 19 | head_num: int = 8, 20 | hidden_dim: int = 512, 21 | dropout_rate: float = 0.1, 22 | max_length: int = 50, 23 | *args, 24 | **kwargs, 25 | ) -> None: 26 | super().__init__(*args, **kwargs) 27 | self.vocab_size = vocab_size 28 | self.hopping_num = hopping_num 29 | self.head_num = head_num 30 | self.hidden_dim = hidden_dim 31 | self.dropout_rate = dropout_rate 32 | self.max_length = max_length 33 | 34 | self.encoder = Encoder( 35 | vocab_size=vocab_size, 36 | hopping_num=hopping_num, 37 | head_num=head_num, 38 | hidden_dim=hidden_dim, 39 | dropout_rate=dropout_rate, 40 | max_length=max_length, 41 | ) 42 | self.decoder = Decoder( 43 | vocab_size=vocab_size, 44 | hopping_num=hopping_num, 45 | head_num=head_num, 46 | hidden_dim=hidden_dim, 47 | dropout_rate=dropout_rate, 48 | max_length=max_length, 49 | ) 50 | 51 | def build_graph(self, name='transformer') -> None: 52 | ''' 
53 | 学習/推論のためのグラフを構築します。 54 | ''' 55 | with tf.name_scope(name): 56 | self.is_training = tf.placeholder(dtype=tf.bool, name='is_training') 57 | # [batch_size, max_length] 58 | self.encoder_input = tf.placeholder(dtype=tf.int32, shape=[None, None], name='encoder_input') 59 | # [batch_size] 60 | self.decoder_input = tf.placeholder(dtype=tf.int32, shape=[None, None], name='decoder_input') 61 | 62 | logit = self.call( 63 | encoder_input=self.encoder_input, 64 | decoder_input=self.decoder_input[:, :-1], # 入力は EOS を含めない 65 | training=self.is_training, 66 | ) 67 | decoder_target = self.decoder_input[:, 1:] # 出力は BOS を含めない 68 | 69 | self.prediction = tf.nn.softmax(logit, name='prediction') 70 | 71 | with tf.name_scope('metrics'): 72 | xentropy, weights = padded_cross_entropy_loss( 73 | logit, decoder_target, smoothing=0.05, vocab_size=self.vocab_size) 74 | self.loss = tf.identity(tf.reduce_sum(xentropy) / tf.reduce_sum(weights), name='loss') 75 | 76 | accuracies, weights = padded_accuracy(logit, decoder_target) 77 | self.acc = tf.identity(tf.reduce_sum(accuracies) / tf.reduce_sum(weights), name='acc') 78 | 79 | def call(self, encoder_input: tf.Tensor, decoder_input: tf.Tensor, training: bool) -> tf.Tensor: 80 | enc_attention_mask = self._create_enc_attention_mask(encoder_input) 81 | dec_self_attention_mask = self._create_dec_self_attention_mask(decoder_input) 82 | 83 | encoder_output = self.encoder( 84 | encoder_input, 85 | self_attention_mask=enc_attention_mask, 86 | training=training, 87 | ) 88 | decoder_output = self.decoder( 89 | decoder_input, 90 | encoder_output, 91 | self_attention_mask=dec_self_attention_mask, 92 | enc_dec_attention_mask=enc_attention_mask, 93 | training=training, 94 | ) 95 | return decoder_output 96 | 97 | def _create_enc_attention_mask(self, encoder_input: tf.Tensor): 98 | with tf.name_scope('enc_attention_mask'): 99 | batch_size, length = tf.unstack(tf.shape(encoder_input)) 100 | pad_array = tf.equal(encoder_input, PAD_ID) # [batch_size, m_length] 101 | # shape broadcasting で [batch_size, head_num, (m|q)_length, m_length] になる 102 | return tf.reshape(pad_array, [batch_size, 1, 1, length]) 103 | 104 | def _create_dec_self_attention_mask(self, decoder_input: tf.Tensor): 105 | with tf.name_scope('dec_self_attention_mask'): 106 | batch_size, length = tf.unstack(tf.shape(decoder_input)) 107 | pad_array = tf.equal(decoder_input, PAD_ID) # [batch_size, m_length] 108 | pad_array = tf.reshape(pad_array, [batch_size, 1, 1, length]) 109 | 110 | autoregression_array = tf.logical_not( 111 | tf.matrix_band_part(tf.ones([length, length], dtype=tf.bool), -1, 0)) # 下三角が False 112 | autoregression_array = tf.reshape(autoregression_array, [1, 1, length, length]) 113 | 114 | return tf.logical_or(pad_array, autoregression_array) 115 | 116 | 117 | class Encoder(tf.keras.models.Model): 118 | ''' 119 | トークン列をベクトル列にエンコードする Encoder です。 120 | ''' 121 | def __init__( 122 | self, 123 | vocab_size: int, 124 | hopping_num: int, 125 | head_num: int, 126 | hidden_dim: int, 127 | dropout_rate: float, 128 | max_length: int, 129 | *args, 130 | **kwargs, 131 | ) -> None: 132 | super().__init__(*args, **kwargs) 133 | self.hopping_num = hopping_num 134 | self.head_num = head_num 135 | self.hidden_dim = hidden_dim 136 | self.dropout_rate = dropout_rate 137 | 138 | self.token_embedding = TokenEmbedding(vocab_size, hidden_dim) 139 | self.add_position_embedding = AddPositionalEncoding() 140 | self.input_dropout_layer = tf.keras.layers.Dropout(dropout_rate) 141 | 142 | self.attention_block_list: 
List[List[tf.keras.models.Model]] = [] 143 | for _ in range(hopping_num): 144 | attention_layer = SelfAttention(hidden_dim, head_num, dropout_rate, name='self_attention') 145 | ffn_layer = FeedForwardNetwork(hidden_dim, dropout_rate, name='ffn') 146 | self.attention_block_list.append([ 147 | ResidualNormalizationWrapper(attention_layer, dropout_rate, name='self_attention_wrapper'), 148 | ResidualNormalizationWrapper(ffn_layer, dropout_rate, name='ffn_wrapper'), 149 | ]) 150 | self.output_normalization = LayerNormalization() 151 | 152 | def call( 153 | self, 154 | input: tf.Tensor, 155 | self_attention_mask: tf.Tensor, 156 | training: bool, 157 | ) -> tf.Tensor: 158 | ''' 159 | モデルを実行します 160 | 161 | :param input: shape = [batch_size, length] 162 | :param training: 学習時は True 163 | :return: shape = [batch_size, length, hidden_dim] 164 | ''' 165 | # [batch_size, length, hidden_dim] 166 | embedded_input = self.token_embedding(input) 167 | embedded_input = self.add_position_embedding(embedded_input) 168 | query = self.input_dropout_layer(embedded_input, training=training) 169 | 170 | for i, layers in enumerate(self.attention_block_list): 171 | attention_layer, ffn_layer = tuple(layers) 172 | with tf.name_scope(f'hopping_{i}'): 173 | query = attention_layer(query, attention_mask=self_attention_mask, training=training) 174 | query = ffn_layer(query, training=training) 175 | # [batch_size, length, hidden_dim] 176 | return self.output_normalization(query) 177 | 178 | 179 | class Decoder(tf.keras.models.Model): 180 | ''' 181 | エンコードされたベクトル列からトークン列を生成する Decoder です。 182 | ''' 183 | def __init__( 184 | self, 185 | vocab_size: int, 186 | hopping_num: int, 187 | head_num: int, 188 | hidden_dim: int, 189 | dropout_rate: float, 190 | max_length: int, 191 | *args, 192 | **kwargs, 193 | ) -> None: 194 | super().__init__(*args, **kwargs) 195 | self.hopping_num = hopping_num 196 | self.head_num = head_num 197 | self.hidden_dim = hidden_dim 198 | self.dropout_rate = dropout_rate 199 | 200 | self.token_embedding = TokenEmbedding(vocab_size, hidden_dim) 201 | self.add_position_embedding = AddPositionalEncoding() 202 | self.input_dropout_layer = tf.keras.layers.Dropout(dropout_rate) 203 | 204 | self.attention_block_list: List[List[tf.keras.models.Model]] = [] 205 | for _ in range(hopping_num): 206 | self_attention_layer = SelfAttention(hidden_dim, head_num, dropout_rate, name='self_attention') 207 | enc_dec_attention_layer = MultiheadAttention(hidden_dim, head_num, dropout_rate, name='enc_dec_attention') 208 | ffn_layer = FeedForwardNetwork(hidden_dim, dropout_rate, name='ffn') 209 | self.attention_block_list.append([ 210 | ResidualNormalizationWrapper(self_attention_layer, dropout_rate, name='self_attention_wrapper'), 211 | ResidualNormalizationWrapper(enc_dec_attention_layer, dropout_rate, name='enc_dec_attention_wrapper'), 212 | ResidualNormalizationWrapper(ffn_layer, dropout_rate, name='ffn_wrapper'), 213 | ]) 214 | self.output_normalization = LayerNormalization() 215 | # 注:本家ではここは TokenEmbedding の重みを転置したものを使っている 216 | self.output_dense_layer = tf.keras.layers.Dense(vocab_size, use_bias=False) 217 | 218 | def call( 219 | self, 220 | input: tf.Tensor, 221 | encoder_output: tf.Tensor, 222 | self_attention_mask: tf.Tensor, 223 | enc_dec_attention_mask: tf.Tensor, 224 | training: bool, 225 | ) -> tf.Tensor: 226 | ''' 227 | モデルを実行します 228 | 229 | :param input: shape = [batch_size, length] 230 | :param training: 学習時は True 231 | :return: shape = [batch_size, length, vocab_size] 232 | ''' 233 | # [batch_size, length, 
hidden_dim] 234 | embedded_input = self.token_embedding(input) 235 | embedded_input = self.add_position_embedding(embedded_input) 236 | query = self.input_dropout_layer(embedded_input, training=training) 237 | 238 | for i, layers in enumerate(self.attention_block_list): 239 | self_attention_layer, enc_dec_attention_layer, ffn_layer = tuple(layers) 240 | with tf.name_scope(f'hopping_{i}'): 241 | query = self_attention_layer(query, attention_mask=self_attention_mask, training=training) 242 | query = enc_dec_attention_layer(query, memory=encoder_output, 243 | attention_mask=enc_dec_attention_mask, training=training) 244 | query = ffn_layer(query, training=training) 245 | 246 | query = self.output_normalization(query) # [batch_size, length, hidden_dim] 247 | return self.output_dense_layer(query) # [batch_size, length, vocab_size] 248 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/test/__init__.py -------------------------------------------------------------------------------- /test/deepdialog/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/test/deepdialog/__init__.py -------------------------------------------------------------------------------- /test/deepdialog/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/test/deepdialog/transformer/__init__.py -------------------------------------------------------------------------------- /test/deepdialog/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | import numpy as np 4 | from deepdialog.transformer.attention import MultiheadAttention, SelfAttention, SimpleAttention 5 | 6 | tf.enable_eager_execution() 7 | 8 | 9 | class TestMultiheadAttention(unittest.TestCase): 10 | def test_call(self): 11 | batch_size = 3 12 | q_length = 5 13 | m_length = 7 14 | hidden_dim = 32 15 | head_num = 4 16 | with tf.Graph().as_default(), tf.Session() as sess: 17 | q = tf.placeholder(dtype=tf.float32, shape=[None, None, hidden_dim]) 18 | k = tf.placeholder(dtype=tf.float32, shape=[None, None, hidden_dim]) 19 | 20 | mask_numpy = np.zeros(shape=[batch_size, 1, q_length, m_length]) 21 | mask_numpy[0, 0, :, -1] = 1 22 | mask = tf.constant(mask_numpy, dtype=tf.bool) 23 | 24 | model = MultiheadAttention(hidden_dim=hidden_dim, head_num=head_num, dropout_rate=0.1) 25 | result_op = model(q, k, mask, training=True) 26 | sess.run(tf.global_variables_initializer()) 27 | result, attention_weight = sess.run([result_op, 'multihead_attention/attention_weight:0'], feed_dict={ 28 | q: np.ones(shape=[batch_size, q_length, hidden_dim]), 29 | k: np.ones(shape=[batch_size, m_length, hidden_dim]), 30 | }) 31 | self.assertEqual(result.shape, (batch_size, q_length, hidden_dim)) 32 | self.assertEqual(attention_weight[0, 0, :, -1].tolist(), [0.0] * q_length) 33 | 34 | def test_split_head(self): 35 | batch_size = 3 36 | length = 5 37 | hidden_dim = 32 38 | head_num = 4 39 | model = MultiheadAttention(hidden_dim=hidden_dim, head_num=head_num, 
dropout_rate=0.1) 40 | x = tf.ones(shape=[batch_size, length, hidden_dim]) 41 | y = model._split_head(x) 42 | self.assertEqual(y.shape, [batch_size, head_num, length, hidden_dim // head_num]) 43 | 44 | def test_split_head_graph(self): 45 | batch_size = 3 46 | length = 5 47 | hidden_dim = 32 48 | head_num = 4 49 | with tf.Graph().as_default(), tf.Session() as sess: 50 | model = MultiheadAttention(hidden_dim=hidden_dim, head_num=head_num, dropout_rate=0.1) 51 | x = tf.placeholder(dtype=tf.float32, shape=[batch_size, None, None]) 52 | y = model._split_head(x) 53 | sess.run(y, feed_dict={x: np.ones(shape=[batch_size, length, hidden_dim])}) 54 | 55 | def test_combine_head(self): 56 | batch_size = 3 57 | length = 5 58 | hidden_dim = 32 59 | head_num = 4 60 | model = MultiheadAttention(hidden_dim=hidden_dim, head_num=head_num, dropout_rate=0.1) 61 | x = tf.ones(shape=[batch_size, head_num, length, hidden_dim // head_num]) 62 | y = model._combine_head(x) 63 | self.assertEqual(y.shape, [batch_size, length, hidden_dim]) 64 | 65 | x = tf.reshape(tf.range(batch_size * length * hidden_dim), [batch_size, length, hidden_dim]) 66 | reconstructed = model._combine_head(model._split_head(x)) 67 | self.assertEqual(reconstructed.numpy().tolist(), x.numpy().tolist()) 68 | 69 | 70 | class TestSelfAttention(unittest.TestCase): 71 | def test_call(self): 72 | batch_size = 3 73 | q_length = 5 74 | hidden_dim = 32 75 | head_num = 4 76 | q = tf.ones(shape=[batch_size, q_length, hidden_dim]) 77 | mask = tf.zeros(shape=[batch_size, 1, 1, q_length]) 78 | model = SelfAttention(hidden_dim=hidden_dim, head_num=head_num, dropout_rate=0.1) 79 | result = model(q, mask, training=True) 80 | self.assertEqual(result.shape, [batch_size, q_length, hidden_dim]) 81 | 82 | 83 | class TestSimpleAttention(unittest.TestCase): 84 | def test_call(self): 85 | batch_size = 3 86 | q_length = 5 87 | m_length = 7 88 | depth = 32 89 | 90 | model = SimpleAttention(depth=depth) 91 | query = tf.ones(shape=[batch_size, q_length, depth]) 92 | memory = tf.ones(shape=[batch_size, m_length, depth]) 93 | result = model(query, memory) 94 | self.assertEqual(result.shape, [batch_size, q_length, depth]) 95 | -------------------------------------------------------------------------------- /test/deepdialog/transformer/common_layer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | import numpy as np 4 | from deepdialog.transformer.common_layer import ( 5 | FeedForwardNetwork, LayerNormalization, ResidualNormalizationWrapper 6 | ) 7 | 8 | tf.enable_eager_execution() 9 | 10 | 11 | class TestFeedForwardNetwork(unittest.TestCase): 12 | def test_call(self): 13 | batch_size = 3 14 | length = 5 15 | hidden_dim = 32 16 | input_dim = 16 17 | model = FeedForwardNetwork(hidden_dim, dropout_rate=0.1) 18 | x = tf.ones(shape=[batch_size, length, input_dim]) 19 | result = model(x, training=True) 20 | self.assertEqual(result.shape, [batch_size, length, hidden_dim]) 21 | 22 | 23 | class TestResidualNormalization(unittest.TestCase): 24 | def test_call(self): 25 | batch_size = 3 26 | length = 5 27 | hidden_dim = 32 28 | layer = FeedForwardNetwork(hidden_dim, dropout_rate=0.1) 29 | wrapped_layer = ResidualNormalizationWrapper(layer, dropout_rate=0.1) 30 | 31 | x = tf.ones(shape=[batch_size, length, hidden_dim]) 32 | y = wrapped_layer(x, training=True) 33 | self.assertEqual(y.shape, [batch_size, length, hidden_dim]) 34 | 35 | 36 | class TestLayerNormalization(unittest.TestCase): 37 | def 
test_call(self): 38 | x = tf.constant([[0, 4], [-2, 2]], dtype=tf.float32) 39 | layer = LayerNormalization() 40 | y = layer(x) 41 | expect = [[-1, 1], [-1, 1]] 42 | for y1, e1 in zip(y.numpy(), expect): 43 | for y2, e2 in zip(y1, e1): 44 | self.assertAlmostEqual(y2, e2, places=5) 45 | 46 | def test_call_graph(self): 47 | batch_size = 2 48 | length = 3 49 | hidden_dim = 5 50 | 51 | with tf.Graph().as_default(), tf.Session() as sess: 52 | layer = LayerNormalization() 53 | x = tf.placeholder(dtype=tf.float32, shape=[None, None, hidden_dim]) 54 | y = layer(x) 55 | sess.run(tf.global_variables_initializer()) 56 | sess.run(y, feed_dict={x: np.ones(shape=[batch_size, length, hidden_dim])}) 57 | -------------------------------------------------------------------------------- /test/deepdialog/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | import numpy as np 4 | import itertools 5 | from deepdialog.transformer.embedding import TokenEmbedding, AddPositionalEncoding 6 | 7 | tf.enable_eager_execution() 8 | 9 | 10 | class TestTokenEmbedding(unittest.TestCase): 11 | def test_call(self): 12 | vocab_size = 3 13 | embedding_dim = 4 14 | layer = TokenEmbedding(vocab_size=vocab_size, embedding_dim=embedding_dim) 15 | embedded = layer(tf.constant([[0, 1, 2]])) 16 | embedded_tokens = embedded[0] 17 | self.assertEqual(embedded_tokens[0].numpy().tolist(), [0] * embedding_dim) 18 | self.assertNotEqual(embedded_tokens[1].numpy().tolist(), [0] * embedding_dim) 19 | 20 | 21 | class TestAddPositionalEncoding(unittest.TestCase): 22 | def test_call(self): 23 | max_length = 2 24 | batch_size = 3 25 | depth = 7 26 | 27 | layer = AddPositionalEncoding() 28 | input = tf.ones(shape=[batch_size, max_length, depth]) 29 | result = layer(input) 30 | self.assertEqual(result.shape, [batch_size, max_length, depth]) 31 | positional_encoding = (result - input).numpy() 32 | 33 | # PE_{pos, 2i} = sin(pos / 10000^{2i / d_model}) 34 | # PE_{pos, 2i+1} = cos(pos / 10000^{2i / d_model}) 35 | for batch, i, pos in itertools.product(range(batch_size), range(depth // 2), range(max_length)): 36 | self.assertAlmostEqual( 37 | positional_encoding[batch, pos, i * 2], 38 | np.sin(pos / 10000 ** (i * 2 / depth)), 39 | places=6, 40 | ) 41 | self.assertAlmostEqual( 42 | positional_encoding[batch, pos, i * 2 + 1], 43 | np.cos(pos / 10000 ** (i * 2 / depth)), 44 | places=6, 45 | ) 46 | 47 | def test_call_graph(self): 48 | batch_size = 3 49 | max_length = 5 50 | depth = 7 51 | data = np.ones(shape=[batch_size, max_length, depth]) 52 | 53 | with tf.Graph().as_default(): 54 | with tf.Session() as sess: 55 | layer = AddPositionalEncoding() 56 | input = tf.placeholder(shape=[None, None, None], dtype=tf.float32) 57 | result_op = layer(input) 58 | result = sess.run(result_op, feed_dict={ 59 | input: data, 60 | }) 61 | self.assertEqual(result.shape, (batch_size, max_length, depth)) 62 | 63 | positional_encoding = result - data 64 | 65 | # PE_{pos, 2i} = sin(pos / 10000^{2i / d_model}) 66 | # PE_{pos, 2i+1} = cos(pos / 10000^{2i / d_model}) 67 | for batch, i, pos in itertools.product(range(batch_size), range(depth // 2), range(max_length)): 68 | self.assertAlmostEqual( 69 | positional_encoding[batch, pos, i * 2], 70 | np.sin(pos / 10000 ** (i * 2 / depth)), 71 | places=6, 72 | ) 73 | self.assertAlmostEqual( 74 | positional_encoding[batch, pos, i * 2 + 1], 75 | np.cos(pos / 10000 ** (i * 2 / depth)), 76 | places=6, 77 | ) 78 | 
-------------------------------------------------------------------------------- /test/deepdialog/transformer/metrics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | from deepdialog.transformer import metrics 4 | 5 | tf.enable_eager_execution() 6 | 7 | 8 | class TestMetrics(unittest.TestCase): 9 | def test_padded_cross_entropy_loss(self): 10 | logits = tf.constant([[ 11 | [0.1, -0.1, 1., 0., 0.], 12 | [0.1, -0.1, 1., 0., 0.], 13 | ]]) 14 | labels = tf.constant([[2, 2]]) 15 | metrics.padded_cross_entropy_loss(logits, labels, smoothing=0.05, vocab_size=5) 16 | 17 | def test_padded_accuracy(self): 18 | logits = tf.constant([[ 19 | [0.1, -0.1, 1., 0., 0.], 20 | [0.1, -0.1, 1., 0., 0.], 21 | [0.1, -0.1, 1., 0., 0.], 22 | ]]) 23 | labels = tf.constant([[2, 3, 0]]) # 0 == PAD 24 | result, weight = metrics.padded_accuracy(logits, labels) 25 | self.assertEqual(result.numpy().tolist(), [[1., 0., 0.]]) 26 | self.assertEqual(weight.numpy().tolist(), [[1., 1., 0.]]) 27 | -------------------------------------------------------------------------------- /test/deepdialog/transformer/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/test/deepdialog/transformer/preprocess/__init__.py -------------------------------------------------------------------------------- /test/deepdialog/transformer/preprocess/batch_generator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from deepdialog.transformer.preprocess.batch_generator import BatchGenerator, ENCODER_INPUT_NODE, DECODER_INPUT_NODE 3 | 4 | 5 | class TestBatchGenerator(unittest.TestCase): 6 | def test_vocab_size(self): 7 | batch_generator = BatchGenerator() 8 | self.assertEqual(batch_generator.vocab_size, 8000) 9 | 10 | def test_get_batch(self): 11 | batch_generator = BatchGenerator() 12 | batch_generator.data = [ 13 | ([1, 2, 3], [1, 4, 5, 2]), 14 | ([4, 5], [1, 6, 7, 8, 2]), 15 | ] 16 | gen = batch_generator.get_batch(batch_size=2, shuffle=False) 17 | result = gen.__next__() 18 | self.assertEqual(result[ENCODER_INPUT_NODE].tolist(), [ 19 | [1, 2, 3], 20 | [4, 5, 0], 21 | ]) 22 | self.assertEqual(result[DECODER_INPUT_NODE].tolist(), [ 23 | [1, 4, 5, 2, 0], 24 | [1, 6, 7, 8, 2], 25 | ]) 26 | 27 | def test_create_data(self): 28 | batch_generator = BatchGenerator() 29 | lines = ['こんにちは', 'やあ', 'いい天気だ'] 30 | result = batch_generator._create_data(lines) 31 | self.assertEqual(len(result), 2) 32 | self.assertEqual(batch_generator.sp.decode_ids(result[0][0]), 'こんにちは') 33 | 34 | def test_create_question(self): 35 | batch_generator = BatchGenerator() 36 | ids = batch_generator._create_question('こんにちは') 37 | self.assertEqual(batch_generator.sp.decode_ids(ids), 'こんにちは') 38 | 39 | def test_create_answer(self): 40 | batch_generator = BatchGenerator() 41 | ids = batch_generator._create_answer('こんにちは') 42 | self.assertEqual(batch_generator.sp.id_to_piece(ids[0]), '<s>')  # BOS piece 43 | self.assertEqual(batch_generator.sp.id_to_piece(ids[-1]), '</s>')  # EOS piece 44 | self.assertEqual(batch_generator.sp.decode_ids(ids), 'こんにちは') 45 | 46 | def test_split(self): 47 | batch_generator = BatchGenerator() 48 | 49 | for data in ( 50 | [0, 1, 2, 3, 4, 5, 6, 7, 8], 51 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 52 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 53 | ): 54 | splited = batch_generator._split(data, 3) 55 |
self.assertEqual(splited, [ 56 | [0, 1, 2], 57 | [3, 4, 5], 58 | [6, 7, 8], 59 | ], 'test with {0}'.format(data)) 60 | 61 | def test_convert_to_array(self): 62 | batch_generator = BatchGenerator() 63 | id_list_list = [ 64 | [1, 2], 65 | [3, 4, 5, 6], 66 | [7], 67 | ] 68 | self.assertEqual(batch_generator._convert_to_array(id_list_list).tolist(), [ 69 | [1, 2, 0, 0], 70 | [3, 4, 5, 6], 71 | [7, 0, 0, 0], 72 | ]) 73 | -------------------------------------------------------------------------------- /test/deepdialog/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | import numpy as np 4 | from deepdialog.transformer.transformer import ( 5 | Transformer, Encoder, Decoder, PAD_ID 6 | ) 7 | 8 | tf.enable_eager_execution() 9 | 10 | 11 | class TestTransformer(unittest.TestCase): 12 | def test_build_graph(self): 13 | vocab_size = 17 14 | max_length = 13 15 | hidden_dim = 32 16 | with tf.Graph().as_default(), tf.Session() as sess: 17 | model = Transformer(vocab_size, hopping_num=3, head_num=4, hidden_dim=hidden_dim, 18 | dropout_rate=0.1, max_length=max_length) 19 | model.build_graph() 20 | sess.run(tf.global_variables_initializer()) 21 | loss, acc, prediction = sess.run([model.loss, model.acc, model.prediction], feed_dict={ 22 | model.is_training: True, 23 | model.encoder_input: np.array([[10, 11, 12], [13, 14, 15]]), 24 | model.decoder_input: np.array([[1, 20, 21, 2], [1, 22, 23, 2]]), 25 | }) 26 | self.assertIsInstance(loss, np.float32) 27 | self.assertIsInstance(acc, np.float32) 28 | self.assertEqual(prediction.shape, (2, 3, vocab_size)) # 3 == decoder_len - 1 29 | 30 | # Test that the weights held globally by the Graph match the weights the model exposes as properties 31 | graph_weight_set = set(tf.trainable_variables()) 32 | model_weight_set = set(model.weights) 33 | self.assertEqual(model_weight_set, graph_weight_set) 34 | 35 | def test_call(self): 36 | vocab_size = 17 37 | batch_size = 7 38 | max_length = 13 39 | enc_length = 11 40 | dec_length = 10 41 | hidden_dim = 32 42 | model = Transformer(vocab_size, hopping_num=3, head_num=4, hidden_dim=hidden_dim, 43 | dropout_rate=0.1, max_length=max_length) 44 | encoder_input = tf.ones(shape=[batch_size, enc_length], dtype=tf.int32) 45 | decoder_input = tf.ones(shape=[batch_size, dec_length], dtype=tf.int32) 46 | y = model( 47 | encoder_input, 48 | decoder_input, 49 | training=True, 50 | ) 51 | self.assertEqual(y.shape, [batch_size, dec_length, vocab_size]) 52 | 53 | def test_create_enc_attention_mask(self): 54 | P = PAD_ID 55 | x = tf.constant([ 56 | [1, 2, 3, P], 57 | [1, 2, P, P], 58 | ]) 59 | model = Transformer(vocab_size=17) 60 | self.assertEqual(model._create_enc_attention_mask(x).numpy().tolist(), [ 61 | [[[False, False, False, True]]], 62 | [[[False, False, True, True]]], 63 | ]) 64 | 65 | def test_create_dec_self_attention_mask(self): 66 | P = PAD_ID 67 | x = tf.constant([ 68 | [1, 2, 3, P], 69 | [1, 2, P, P], 70 | ]) 71 | model = Transformer(vocab_size=17) 72 | self.assertEqual(model._create_dec_self_attention_mask(x).numpy().tolist(), [ 73 | [[ 74 | [False, True, True, True], 75 | [False, False, True, True], 76 | [False, False, False, True], 77 | [False, False, False, True], 78 | ]], 79 | [[ 80 | [False, True, True, True], 81 | [False, False, True, True], 82 | [False, False, True, True], 83 | [False, False, True, True], 84 | ]], 85 | ]) 86 | 87 | 88 | class TestEncoder(unittest.TestCase): 89 | def test_call(self): 90 | vocab_size = 17 91 | batch_size = 7 92 | max_length = 13
93 | length = 11 94 | hidden_dim = 32 95 | model = Encoder(vocab_size, hopping_num=3, head_num=4, hidden_dim=hidden_dim, 96 | dropout_rate=0.1, max_length=max_length) 97 | x = tf.ones(shape=[batch_size, length], dtype=tf.int32) 98 | mask = tf.cast(tf.zeros(shape=[batch_size, 1, 1, length]), tf.bool) 99 | y = model(x, self_attention_mask=mask, training=True) 100 | self.assertEqual(y.shape, [batch_size, length, hidden_dim]) 101 | 102 | 103 | class TestDecoder(unittest.TestCase): 104 | def test_call(self): 105 | vocab_size = 17 106 | batch_size = 7 107 | max_length = 13 108 | enc_length = 11 109 | dec_length = 10 110 | hidden_dim = 32 111 | model = Decoder(vocab_size, hopping_num=3, head_num=4, hidden_dim=hidden_dim, 112 | dropout_rate=0.1, max_length=max_length) 113 | decoder_input = tf.ones(shape=[batch_size, dec_length], dtype=tf.int32) 114 | encoder_output = tf.ones(shape=[batch_size, enc_length, hidden_dim]) 115 | dec_self_attention_mask = tf.cast(tf.zeros(shape=[batch_size, 1, dec_length, dec_length]), tf.bool) 116 | enc_dec_attention_mask = tf.cast(tf.zeros(shape=[batch_size, 1, 1, enc_length]), tf.bool) 117 | y = model( 118 | decoder_input, 119 | encoder_output, 120 | self_attention_mask=dec_self_attention_mask, 121 | enc_dec_attention_mask=enc_dec_attention_mask, 122 | training=True, 123 | ) 124 | self.assertEqual(y.shape, [batch_size, dec_length, vocab_size]) 125 | -------------------------------------------------------------------------------- /test/run: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ $# -eq 1 ];then 4 | echo '# flake8 ################################' 5 | pipenv run flake8 $1 6 | echo '' 7 | echo '# mypy ################################' 8 | pipenv run mypy $1 9 | echo '' 10 | echo '# test ##################################' 11 | pipenv run test $1 12 | else 13 | echo '# flake8 ################################' 14 | pipenv run flake8 ./ 15 | echo '' 16 | echo '# mypy ################################' 17 | pipenv run mypy ./ 18 | echo '' 19 | echo '# test ##################################' 20 | pipenv run test 21 | fi 22 | -------------------------------------------------------------------------------- /test/run.py: -------------------------------------------------------------------------------- 1 | from unittest import TestLoader 2 | from unittest import TextTestRunner 3 | import sys 4 | import os 5 | 6 | 7 | def run(path=None): 8 | ''' 9 | Run the tests. 10 | 11 | :param str path: If given, run only the tests corresponding to that path. 12 | ''' 13 | 14 | project_dir = './' 15 | 16 | if path: 17 | tests = _get_tests_from_file_path(path, project_dir) 18 | else: 19 | tests = TestLoader().discover( 20 | os.path.join(project_dir, 'test/'), 21 | pattern='*.py', 22 | top_level_dir=project_dir 23 | ) 24 | 25 | return_code = not TextTestRunner().run(tests).wasSuccessful() 26 | sys.exit(return_code) 27 | 28 | 29 | def _get_tests_from_file_path(path, project_dir): 30 | if not path.endswith('.py'): 31 | raise Exception('test file path must point to a .py file, not a directory') 32 | 33 | # path is given as something like test/hoge/fuga.py 34 | path = os.path.relpath(path, project_dir) 35 | if not path.startswith('test/'): 36 | path = 'test/' + path 37 | 38 | # convert it to test.hoge.fuga 39 | module_name = path.replace('.py', '').replace('/', '.') 40 | return TestLoader().loadTestsFromName(module_name) 41 | 42 | 43 | if __name__ == '__main__': 44 | run(*sys.argv[1:]) 45 | -------------------------------------------------------------------------------- /tmp/.gitkeep:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/halhorn/deep_dialog_tutorial/134d68db94d2a4cb6267a2fa64d2db0a5b70449e/tmp/.gitkeep -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # setting for flake8 (python) 2 | 3 | [flake8] 4 | max-line-length = 120 5 | exclude = venv/* 6 | --------------------------------------------------------------------------------