├── .dockerignore ├── .floydignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── RUNNING.md ├── macgraph ├── __init__.py ├── activations.py ├── args.py ├── attention.py ├── attention_test.py ├── cell │ ├── __init__.py │ ├── decode.py │ ├── mac_cell.py │ ├── messaging_cell.py │ ├── messaging_cell_helpers.py │ ├── output_cell.py │ ├── query.py │ └── types.py ├── component.py ├── const.py ├── estimator.py ├── evaluate.py ├── global_args.py ├── hooks.py ├── input │ ├── __init__.py │ ├── args.py │ ├── balancer.py │ ├── build.py │ ├── build_test.py │ ├── graph_util.py │ ├── input.py │ ├── kb.py │ ├── partitioner.py │ ├── print_gqa.py │ ├── print_tfr.py │ ├── text_util.py │ └── util.py ├── layers.py ├── minception.py ├── model.py ├── optimizer.py ├── predict.py ├── print_util.py ├── train.py ├── unit_test.py └── util.py ├── train.sh └── util ├── __init__.py └── file.py /.dockerignore: -------------------------------------------------------------------------------- 1 | 2 | input_data 3 | output 4 | kubernetes 5 | *.DS_Store 6 | .git 7 | -------------------------------------------------------------------------------- /.floydignore: -------------------------------------------------------------------------------- 1 | 2 | # Directories and files to ignore when uploading code to floyd 3 | ./input_data/ 4 | ./output/ 5 | ./output_test/ 6 | 7 | 8 | __pycache__ 9 | 10 | .git 11 | .eggs 12 | eggs 13 | lib 14 | lib64 15 | parts 16 | sdist 17 | var 18 | *.pyc 19 | *.swp 20 | .DS_Store 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | nohup*.out 2 | /input*/ 3 | /output*/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | *~ 11 | .DS_Store 12 | 13 | 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | env/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM gcr.io/tensorflow/tensorflow:1.7.0-rc0-py3 2 | FROM gcr.io/google-appengine/python 3 | RUN pip install --upgrade pip 4 | RUN pip install pipenv 5 | 6 | WORKDIR /source 7 | 8 | # Only do costly pipenv install when needed 9 | COPY Pipfile . 10 | RUN pipenv install --verbose --skip-lock 11 | 12 | COPY . . 13 | 14 | CMD "./run-k8.sh" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | 3 | url = "https://pypi.python.org/simple" 4 | verify_ssl = true 5 | name = "pypi" 6 | 7 | 8 | [dev-packages] 9 | 10 | 11 | 12 | [packages] 13 | 14 | tensorflow = "*" 15 | numpy = "*" 16 | pyyaml = ">=4.2b1" 17 | tqdm = "*" 18 | colored = "*" 19 | coloredlogs = "*" 20 | 21 | 22 | [requires] 23 | 24 | python_version = "3.6" 25 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "4608f6add1cc754c1dd032278d974efe37a041b5a173fec1d92da53032587026" 5 | }, 6 | "host-environment-markers": { 7 | "implementation_name": "cpython", 8 | "implementation_version": "3.6.2", 9 | "os_name": "posix", 10 | "platform_machine": "x86_64", 11 | "platform_python_implementation": "CPython", 12 | "platform_release": "18.0.0", 13 | "platform_system": "Darwin", 14 | "platform_version": "Darwin Kernel Version 18.0.0: Wed Aug 22 20:13:40 PDT 2018; root:xnu-4903.201.2~1/RELEASE_X86_64", 15 | "python_full_version": "3.6.2", 16 | "python_version": "3.6", 17 | "sys_platform": "darwin" 18 | }, 19 | "pipfile-spec": 6, 20 | "requires": { 21 | "python_version": "3.6" 22 | }, 23 | "sources": [ 24 | { 25 | "name": "pypi", 26 | "url": "https://pypi.python.org/simple", 27 | "verify_ssl": true 28 | } 29 | ] 30 | }, 31 | "default": { 32 | "absl-py": { 33 | "hashes": [ 34 | "sha256:87519e3b91a3d573664c6e2ee33df582bb68dca6642ae3cf3a4361b1c0a4e9d6" 35 | ], 36 | "version": "==0.6.1" 37 | }, 38 | "astor": { 39 | "hashes": [ 40 | "sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e", 41 | "sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d" 42 | ], 43 | "version": "==0.7.1" 44 | }, 45 | "backports.weakref": { 46 | "hashes": [ 47 | "sha256:81bc9b51c0abc58edc76aefbbc68c62a787918ffe943a37947e162c3f8e19e82", 48 | "sha256:bc4170a29915f8b22c9e7c4939701859650f2eb84184aee80da329ac0b9825c2" 49 | ], 50 | "version": "==1.0.post1" 51 | }, 52 | "colored": { 53 | "hashes": [ 54 | "sha256:8296ea990e3f6b7822f44eec21408b126dfb9c1c031306b859e3f7d46cc27075" 55 | ], 56 | "version": "==1.3.93" 57 | }, 58 | "coloredlogs": { 59 | "hashes": [ 60 | "sha256:34fad2e342d5a559c31b6c889e8d14f97cb62c47d9a2ae7b5ed14ea10a79eff8", 61 | "sha256:b869a2dda3fa88154b9dd850e27828d8755bfab5a838a1c97fbc850c6e377c36" 62 | ], 63 | "version": "==10.0" 64 | }, 65 | "enum34": { 66 | "hashes": [ 67 | "sha256:6bd0f6ad48ec2aa117d3d141940d484deccda84d4fcd884f5c3d93c23ecd8c79", 68 | "sha256:644837f692e5f550741432dd3f223bbb9852018674981b1664e5dc339387588a", 69 | "sha256:8ad8c4783bf61ded74527bffb48ed9b54166685e4230386a9ed9b1279e2df5b1", 70 | "sha256:2d81cbbe0e73112bdfe6ef8576f2238f2ba27dd0d55752a776c41d38b7da2850" 71 | ], 72 | "version": "==1.1.6" 73 | }, 74 | "funcsigs": { 75 | "hashes": [ 76 | "sha256:330cc27ccbf7f1e992e69fef78261dc7c6569012cf397db8d3de0234e6c937ca", 77 | "sha256:a7bb0f2cf3a3fd1ab2732cb49eba4252c2af4240442415b4abce3b87022a8f50" 78 | ], 79 | "markers": "python_version < '3.3'", 80 | "version": "==1.0.2" 81 | }, 82 | "futures": { 83 | "hashes": [ 84 | "sha256:c4884a65654a7c45435063e14ae85280eb1f111d94e542396717ba9828c4337f", 85 | "sha256:51ecb45f0add83c806c68e4b06106f90db260585b25ef2abfcda0bd95c0132fd" 86 | ], 87 | 
"markers": "python_version < '3'", 88 | "version": "==3.1.1" 89 | }, 90 | "gast": { 91 | "hashes": [ 92 | "sha256:7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930" 93 | ], 94 | "version": "==0.2.0" 95 | }, 96 | "grpcio": { 97 | "hashes": [ 98 | "sha256:0e8ff79b12b8b07198dd847974fc32a4ed8c0d52d5224fabb9d28bf4c2e3f4a9", 99 | "sha256:a7f21a7b48fcd9f51029419b22a9bfea097973cca5d1529b8578f1d2919e6b23", 100 | "sha256:57705e31f76db45b51f3a98bcfd362c89d58e99f846337a25fed957b4d43ae4f", 101 | "sha256:e68e6afbbae2cbfadaabd33ee40314963cd83500feff733c07edb172674a7f8b", 102 | "sha256:62c777f801aee22100d8ea5fa057020e37b65541a8000091879a8560b089da9d", 103 | "sha256:cdea5595b30f027e6603887b71f343ca5b209da74b910fe04fc25e1dfe6df263", 104 | "sha256:145e82aec0a643d7569499b1aa0d5167c99d9d26a2b8c4e4b3f5cd51b99a8cdc", 105 | "sha256:8317d351ab1e80cf20676ef3d4929d3e760df10e6e5c289283c36c4c92ca61f7", 106 | "sha256:1a820ebf0c924cbfa299cb59e4bc9582a24abfec89d9a36c281d78fa941115ae", 107 | "sha256:612e742c748df51c921a7eefd76195d76467e3cc00e084e089af5b111d8210b7", 108 | "sha256:b51d49d89758ea45841130c5c7be79c68612d8834bd600994b8a2672c59dc9b9", 109 | "sha256:2a8b6b569fd23f4d9f2c8201fd8995519dfbddc60ceeffa8bf5bea2a8e9cb72c", 110 | "sha256:fa6e14bce7ad5de2363abb644191489ddfffcdb2751337251f7ef962ab7e3293", 111 | "sha256:d64350156dc4b21914409e0c93ffeeb4ceba193716fb1ae570df699383c4cd63", 112 | "sha256:9a7ed6160e6c14058b4676aac68a8bf268f171f4c371ff0a0c0ab81b90803f70", 113 | "sha256:38b93080df498656aea1dbab632e32013c580c2d00bd8c30d0f1d2c9513b0469", 114 | "sha256:4837ad8fdcf99df0e89214ba42001469cab807851f30481db41fd84fc9358ce7", 115 | "sha256:11c8026a3d35e8b9ad6cda7bf4f5e51b9b82e7f29a590ad194f63957657fa808", 116 | "sha256:8b72721e64becd4a3e9580f12dbdf618d41e80d3ae7585dc8a921dbf76c979bb", 117 | "sha256:8bb7dbe20fe883ee22a6cb2c1317ea228b75a3ef60f3749584ee2634192e3452", 118 | "sha256:f7bb6617bae5e7333e66ec1e7aac1fe419b59e0e34a8717f97e1ce2791ab9d3a", 119 | "sha256:a46c34768f292fa0d97e929591e51ec20dc857321d83b198de1dad9c8183e8cb", 120 | "sha256:f0c0e48c255a63fec78be2f240ff5a3bd4291b1f83976895f6ee0085362568d0", 121 | "sha256:b3bbeadc6b99e4a42bf23803f5e9b292f23f3e37cc7f75a9f5efbfa9b812abc1", 122 | "sha256:284bee4657c4dd7d48835128b31975e8b0ea3a2eeb084c5d46de215b31d1f8f5", 123 | "sha256:e10bbef59706a90672b295c0f82dcb6329d829643b8dd7c3bd120f89a093d740", 124 | "sha256:082bc981d6aabfdb26bfdeab63f5626df3d2c5ac3a9ae8533dfa5ce73432f4fe", 125 | "sha256:cbb95a586fdf3e795eba28b4acc75fdfdb59a14df62e747fe8bc4572ef37b647", 126 | "sha256:adfee9c9099cae92c2a4948bc95cc2cc3185cdf59b371e056b8dd19ed434247e", 127 | "sha256:5447336edd6fea8ab35eca34ff5289e369e22c375bc2ac8156a419fa467949ac", 128 | "sha256:8703efaf03396123426fdea08b369712df1248fa5fdfdbee3f87a410f52e9bac", 129 | "sha256:fd6774bbb6c717f725b39394757445ead4f69c471118364933aadb81a4f16961" 130 | ], 131 | "version": "==1.17.1" 132 | }, 133 | "h5py": { 134 | "hashes": [ 135 | "sha256:f3b49107fbfc77333fc2b1ef4d5de2abcd57e7ea3a1482455229494cf2da56ce", 136 | "sha256:0f94de7a10562b991967a66bbe6dda9808e18088676834c0a4dcec3fdd3bcc6f", 137 | "sha256:713ac19307e11de4d9833af0c4bd6778bde0a3d967cafd2f0f347223711c1e31", 138 | "sha256:30e365e8408759db3778c361f1e4e0fe8e98a875185ae46c795a85e9bafb9cdf", 139 | "sha256:3206bac900e16eda81687d787086f4ffd4f3854980d798e191a9868a6510c3ae", 140 | "sha256:4162953714a9212d373ac953c10e3329f1e830d3c7473f2a2e4f25dd6241eef0", 141 | "sha256:407b5f911a83daa285bbf1ef78a9909ee5957f257d3524b8606be37e8643c5f0", 142 | 
"sha256:106e42e2e01e486a3d32eeb9ba0e3a7f65c12fa8998d63625fa41fb8bdc44cdb", 143 | "sha256:1606c66015f04719c41a9863c156fc0e6b992150de21c067444bcb82e7d75579", 144 | "sha256:e58a25764472af07b7e1c4b10b0179c8ea726446c7141076286e41891bf3a563", 145 | "sha256:0dd2adeb2e9de5081eb8dcec88874e7fd35dae9a21557be3a55a3c7d491842a4", 146 | "sha256:b9e4b8dfd587365bdd719ae178fa1b6c1231f81280b1375eef8626dfd8761bf3", 147 | "sha256:082a27208aa3a2286e7272e998e7e225b2a7d4b7821bd840aebf96d50977abbb", 148 | "sha256:c5dd4ec75985b99166c045909e10f0534704d102848b1d9f0992720e908928e7", 149 | "sha256:8cc4aed71e20d87e0a6f02094d718a95252f11f8ed143bc112d22167f08d4040", 150 | "sha256:2cca17e80ddb151894333377675db90cd0279fa454776e0a4f74308376afd050", 151 | "sha256:71b946d80ef3c3f12db157d7778b1fe74a517ca85e94809358b15580983c2ce2", 152 | "sha256:5fc7aba72a51b2c80605eba1c50dbf84224dcd206279d30a75c154e5652e1fe4", 153 | "sha256:d2b82f23cd862a9d05108fe99967e9edfa95c136f532a71cb3d28dc252771f50", 154 | "sha256:1e9fb6f1746500ea91a00193ce2361803c70c6b13f10aae9a33ad7b5bd28e800", 155 | "sha256:3c23d72058647cee19b30452acc7895621e2de0a0bd5b8a1e34204b9ea9ed43c", 156 | "sha256:a744e13b000f234cd5a5b2a1f95816b819027c57f385da54ad2b7da1adace2f3", 157 | "sha256:1854c4beff9961e477e133143c5e5e355dac0b3ebf19c52cf7cc1b1ef757703c", 158 | "sha256:08e2e8297195f9e813e894b6c63f79372582787795bba2014a2db6a2de95f713", 159 | "sha256:05750b91640273c69989c657eaac34b091abdd75efc8c4824c82aaf898a2da0a", 160 | "sha256:b087ee01396c4b34e9dc41e3a6a0442158206d383c19c7d0396d52067b17c1cb", 161 | "sha256:b0f03af381d33306ce67d18275b61acb4ca111ced645381387a02c8a5ee1b796", 162 | "sha256:9d41ca62daf36d6b6515ab8765e4c8c4388ee18e2a665701fef2b41563821002" 163 | ], 164 | "version": "==2.9.0" 165 | }, 166 | "humanfriendly": { 167 | "hashes": [ 168 | "sha256:42d0aa829f59c710db20ec42eed24a8b7a27688d477da61b5aebd604d0bb2402", 169 | "sha256:1d3a1c157602801c62dfdb321760229df2e0d4f14412a0f41b13ad3f930a936a" 170 | ], 171 | "version": "==4.17" 172 | }, 173 | "keras-applications": { 174 | "hashes": [ 175 | "sha256:721dda4fa4e043e5bbd6f52a2996885c4639a7130ae478059b3798d0706f5ae7", 176 | "sha256:a03af60ddc9c5afdae4d5c9a8dd4ca857550e0b793733a5072e0725829b87017" 177 | ], 178 | "version": "==1.0.6" 179 | }, 180 | "keras-preprocessing": { 181 | "hashes": [ 182 | "sha256:90d04c1750bccceef88ac09475c291b4b5f6aa1eaf0603167061b1aa8b043c61", 183 | "sha256:ef2e482c4336fcf7180244d06f4374939099daa3183816e82aee7755af35b754" 184 | ], 185 | "version": "==1.0.5" 186 | }, 187 | "markdown": { 188 | "hashes": [ 189 | "sha256:c00429bd503a47ec88d5e30a751e147dcb4c6889663cd3e2ba0afe858e009baa", 190 | "sha256:d02e0f9b04c500cde6637c11ad7c72671f359b87b9fe924b2383649d8841db7c" 191 | ], 192 | "version": "==3.0.1" 193 | }, 194 | "mock": { 195 | "hashes": [ 196 | "sha256:5ce3c71c5545b472da17b72268978914d0252980348636840bd34a00b5cc96c1", 197 | "sha256:b158b6df76edd239b8208d481dc46b6afd45a846b7812ff0ce58971cf5bc8bba" 198 | ], 199 | "version": "==2.0.0" 200 | }, 201 | "monotonic": { 202 | "hashes": [ 203 | "sha256:552a91f381532e33cbd07c6a2655a21908088962bb8fa7239ecbcc6ad1140cc7", 204 | "sha256:23953d55076df038541e648a53676fb24980f7a1be290cdda21300b3bc21dfb0" 205 | ], 206 | "markers": "python_version == '2.6' or python_version == '2.7' or python_version == '3.0' or python_version == '3.1' or python_version == '3.2'", 207 | "version": "==1.5" 208 | }, 209 | "numpy": { 210 | "hashes": [ 211 | "sha256:18e84323cdb8de3325e741a7a8dd4a82db74fde363dce32b625324c7b32aa6d7", 212 | 
"sha256:154c35f195fd3e1fad2569930ca51907057ae35e03938f89a8aedae91dd1b7c7", 213 | "sha256:4d8d3e5aa6087490912c14a3c10fbdd380b40b421c13920ff468163bc50e016f", 214 | "sha256:c857ae5dba375ea26a6228f98c195fec0898a0fd91bcf0e8a0cae6d9faf3eca7", 215 | "sha256:0df89ca13c25eaa1621a3f09af4c8ba20da849692dcae184cb55e80952c453fb", 216 | "sha256:36e36b6868e4440760d4b9b44587ea1dc1f06532858d10abba98e851e154ca70", 217 | "sha256:99d59e0bcadac4aa3280616591fb7bcd560e2218f5e31d5223a2e12a1425d495", 218 | "sha256:edfa6fba9157e0e3be0f40168eb142511012683ac3dc82420bee4a3f3981b30e", 219 | "sha256:b261e0cb0d6faa8fd6863af26d30351fd2ffdb15b82e51e81e96b9e9e2e7ba16", 220 | "sha256:db9814ff0457b46f2e1d494c1efa4111ca089e08c8b983635ebffb9c1573361f", 221 | "sha256:df04f4bad8a359daa2ff74f8108ea051670cafbca533bb2636c58b16e962989e", 222 | "sha256:7da99445fd890206bfcc7419f79871ba8e73d9d9e6b82fe09980bc5bb4efc35f", 223 | "sha256:56994e14b386b5c0a9b875a76d22d707b315fa037affc7819cda08b6d0489756", 224 | "sha256:ecf81720934a0e18526177e645cbd6a8a21bb0ddc887ff9738de07a1df5c6b61", 225 | "sha256:cf5bb4a7d53a71bb6a0144d31df784a973b36d8687d615ef6a7e9b1809917a9b", 226 | "sha256:561ef098c50f91fbac2cc9305b68c915e9eb915a74d9038ecf8af274d748f76f", 227 | "sha256:4f41fd159fba1245e1958a99d349df49c616b133636e0cf668f169bce2aeac2d", 228 | "sha256:416a2070acf3a2b5d586f9a6507bb97e33574df5bd7508ea970bbf4fc563fa52", 229 | "sha256:24fd645a5e5d224aa6e39d93e4a722fafa9160154f296fd5ef9580191c755053", 230 | "sha256:23557bdbca3ccbde3abaa12a6e82299bc92d2b9139011f8c16ca1bb8c75d1e95", 231 | "sha256:b1853df739b32fa913cc59ad9137caa9cc3d97ff871e2bbd89c2a2a1d4a69451", 232 | "sha256:73a1f2a529604c50c262179fcca59c87a05ff4614fe8a15c186934d84d09d9a5", 233 | "sha256:1e8956c37fc138d65ded2d96ab3949bd49038cc6e8a4494b1515b0ba88c91565", 234 | "sha256:a4cc09489843c70b22e8373ca3dfa52b3fab778b57cf81462f1203b0852e95e3", 235 | "sha256:4a22dc3f5221a644dfe4a63bf990052cc674ef12a157b1056969079985c92816", 236 | "sha256:b1f44c335532c0581b77491b7715a871d0dd72e97487ac0f57337ccf3ab3469b", 237 | "sha256:a61dc29cfca9831a03442a21d4b5fd77e3067beca4b5f81f1a89a04a71cf93fa", 238 | "sha256:3d734559db35aa3697dadcea492a423118c5c55d176da2f3be9c98d4803fc2a7" 239 | ], 240 | "version": "==1.15.4" 241 | }, 242 | "pbr": { 243 | "hashes": [ 244 | "sha256:f6d5b23f226a2ba58e14e49aa3b1bfaf814d0199144b95d78458212444de1387", 245 | "sha256:f59d71442f9ece3dffc17bc36575768e1ee9967756e6b6535f0ee1f0054c3d68" 246 | ], 247 | "version": "==5.1.1" 248 | }, 249 | "protobuf": { 250 | "hashes": [ 251 | "sha256:10394a4d03af7060fa8a6e1cbf38cea44be1467053b0aea5bbfcb4b13c4b88c4", 252 | "sha256:59cd75ded98094d3cf2d79e84cdb38a46e33e7441b2826f3838dcc7c07f82995", 253 | "sha256:1931d8efce896981fe410c802fd66df14f9f429c32a72dd9cfeeac9815ec6444", 254 | "sha256:92e8418976e52201364a3174e40dc31f5fd8c147186d72380cbda54e0464ee19", 255 | "sha256:a7ee3bb6de78185e5411487bef8bc1c59ebd97e47713cba3c460ef44e99b3db9", 256 | "sha256:5ee0522eed6680bb5bac5b6d738f7b0923b3cafce8c4b1a039a6107f0841d7ed", 257 | "sha256:fcfc907746ec22716f05ea96b7f41597dfe1a1c088f861efb8a0d4f4196a6f10", 258 | "sha256:ceec283da2323e2431c49de58f80e1718986b79be59c266bb0509cbf90ca5b9e", 259 | "sha256:65917cfd5da9dfc993d5684643063318a2e875f798047911a9dd71ca066641c9", 260 | "sha256:46e34fdcc2b1f2620172d3a4885128705a4e658b9b62355ae5e98f9ea19f42c2", 261 | "sha256:9335f79d1940dfb9bcaf8ec881fb8ab47d7a2c721fb8b02949aab8bbf8b68625", 262 | "sha256:685bc4ec61a50f7360c9fd18e277b65db90105adbf9c79938bd315435e526b90", 263 | 
"sha256:574085a33ca0d2c67433e5f3e9a0965c487410d6cb3406c83bdaf549bfc2992e", 264 | "sha256:4b92e235a3afd42e7493b281c8b80c0c65cbef45de30f43d571d1ee40a1f77ef", 265 | "sha256:e7a5ccf56444211d79e3204b05087c1460c212a2c7d62f948b996660d0165d68", 266 | "sha256:196d3a80f93c537f27d2a19a4fafb826fb4c331b0b99110f985119391d170f96", 267 | "sha256:1489b376b0f364bcc6f89519718c057eb191d7ad6f1b395ffd93d1aa45587811" 268 | ], 269 | "version": "==3.6.1" 270 | }, 271 | "pyyaml": { 272 | "hashes": [ 273 | "sha256:d5eef459e30b09f5a098b9cea68bebfeb268697f78d647bd255a085371ac7f3f", 274 | "sha256:e01d3203230e1786cd91ccfdc8f8454c8069c91bee3962ad93b87a4b2860f537", 275 | "sha256:558dd60b890ba8fd982e05941927a3911dc409a63dcb8b634feaa0cda69330d3", 276 | "sha256:d46d7982b62e0729ad0175a9bc7e10a566fc07b224d2c79fafb5e032727eaa04", 277 | "sha256:a7c28b45d9f99102fa092bb213aa12e0aaf9a6a1f5e395d36166639c1f96c3a1", 278 | "sha256:bc558586e6045763782014934bfaf39d48b8ae85a2713117d16c39864085c613", 279 | "sha256:40c71b8e076d0550b2e6380bada1f1cd1017b882f7e16f09a65be98e017f211a", 280 | "sha256:3d7da3009c0f3e783b2c873687652d83b1bbfd5c88e9813fb7e5b03c0dd3108b", 281 | "sha256:e170a9e6fcfd19021dd29845af83bb79236068bf5fd4df3327c1be18182b2531", 282 | "sha256:aa7dd4a6a427aed7df6fb7f08a580d68d9b118d90310374716ae90b710280af1", 283 | "sha256:3ef3092145e9b70e3ddd2c7ad59bdd0252a94dfe3949721633e41344de00a6bf" 284 | ], 285 | "version": "==3.13" 286 | }, 287 | "six": { 288 | "hashes": [ 289 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", 290 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" 291 | ], 292 | "version": "==1.12.0" 293 | }, 294 | "tensorboard": { 295 | "hashes": [ 296 | "sha256:698285207c6749d8d1c562e77d35ba6a07ee747b61362da49fa4bd3feaf28f6f", 297 | "sha256:8f5c158db581a70be4ea432875be8bea20c55e51b92973c8c330643a2bec4f77" 298 | ], 299 | "version": "==1.12.1" 300 | }, 301 | "tensorflow": { 302 | "hashes": [ 303 | "sha256:5cee35f8a6a12e83560f30246811643efdc551c364bc981d27f21fbd0926403d", 304 | "sha256:2681b55d3e434e20fe98e3a3b1bde3588af62d7864b62feee4141a71e29ef594", 305 | "sha256:16fb8a59e724afd37a276d33b7e2ed070e5c84899a8d4cfc3fe1bb446a859da7", 306 | "sha256:6ad6ed495f1a3d445c43d90cb2ce251ff5532fd6436e25f52977ee59ffa583df", 307 | "sha256:42fc8398ce9f9895b488f516ea0143cf6cf2a3a5fc804da4a190b063304bc173", 308 | "sha256:e4f479e6aca595acc98347364288cbdfd3c025ca85389380174ea75a43c327b7", 309 | "sha256:1ae50e44c0b29df5fb5b460118be5a257b4eb3e561008f64d2c4c715651259b7", 310 | "sha256:1b7d09cc26ef727d628dcb74841b89374a38ed81af25bd589a21659ef67443da", 311 | "sha256:d3f3d7cd9bd4cdc7ebf25fd6c2dfc103dcf4b2834ae9276cc4cf897eb1515f6d", 312 | "sha256:f587dc03b5f0d1e50cca39b7159c9f21ffdec96273dbf5f7619d48c622cb21f2", 313 | "sha256:531619ad1c17b4084d09f442a9171318af813e81aae748e5de8274d561461749", 314 | "sha256:cd8c1a899e3befe1ccb774ea1aae077a4b1286f855c956210b23766f4ac85c30" 315 | ], 316 | "version": "==1.12.0" 317 | }, 318 | "termcolor": { 319 | "hashes": [ 320 | "sha256:1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b" 321 | ], 322 | "version": "==1.1.0" 323 | }, 324 | "tqdm": { 325 | "hashes": [ 326 | "sha256:3c4d4a5a41ef162dd61f1edb86b0e1c7859054ab656b2e7c7b77e7fbf6d9f392", 327 | "sha256:5b4d5549984503050883bc126280b386f5f4ca87e6c023c5d015655ad75bdebb" 328 | ], 329 | "version": "==4.28.1" 330 | }, 331 | "werkzeug": { 332 | "hashes": [ 333 | "sha256:d5da73735293558eb1651ee2fddc4d0dedcfa06538b8813a2e20011583c9e49b", 334 | 
"sha256:c3fd7a7d41976d9f44db327260e263132466836cef6f91512889ed60ad26557c" 335 | ], 336 | "version": "==0.14.1" 337 | }, 338 | "wheel": { 339 | "hashes": [ 340 | "sha256:1e53cdb3f808d5ccd0df57f964263752aa74ea7359526d3da6c02114ec1e1d44", 341 | "sha256:029703bf514e16c8271c3821806a1c171220cc5bdd325cbf4e7da1e056a01db6" 342 | ], 343 | "markers": "python_version < '3'", 344 | "version": "==0.32.3" 345 | } 346 | }, 347 | "develop": {} 348 | } 349 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Finding shortest paths with Graph Networks 2 | 3 | > In this article we show how a Graph Network with attention read and write can perform shortest path calculations. This network performs this task with 99.91% accuracy after minimal training. 4 | 5 | Here at Octavian we believe that graphs are a powerful medium for representing diverse knowledge (for example BenevolentAI uses them to represent pharmaceutical research and knowledge). 6 | 7 | Neural networks are a way to create functions that no human could write. They do this by harnessing the power of large datasets. 8 | 9 | On problems for which we have capable neural models, we can use example inputs and outputs to train the network to learn a function that transforms those inputs into those outputs, and hopefully generalizes to other unseen inputs. 10 | 11 | We need to be able to build neural networks that can learn functions on graphs. Those neural networks need the right inductive biases so that they can reliably learn useful graph functions. With that foundation, we can build powerful neural graph systems. 12 | 13 | Here we present a "Graph network with attention read and write", a simple network that can effectively compute shortest path. It is an example of how to combine different neural network components to make a system that readily learns a classical graph algorithm. 14 | 15 | We present this network both as a novel system in of itself, but more importantly as the basis for further investigation into effective neural graph computation. 16 | 17 | Read our [extensive article about this architecture](https://medium.com/octavian-ai/finding-shortest-paths-with-graph-networks-807c5bbfc9c8). 18 | 19 | Download the [pre-compiled YAML dataset](https://storage.googleapis.com/octavian-static/download/clevr-graph/StationShortestCount.zip) or the [fully-compiled TFRecords dataset](https://storage.googleapis.com/octavian-static/download/mac-graph/StationShortestCount.zip). Data is expected to live in input_data/processed/StationShortestCount/. 
20 | 21 | ## Running 22 | 23 | ```shell 24 | pipenv install 25 | pipenv run ./train.sh 26 | ``` 27 | 28 | ## Visualising the attention 29 | 30 | ```shell 31 | pipenv run python -m macgraph.predict --model-dir ./output/StationShortestPath/ 32 | ``` 33 | -------------------------------------------------------------------------------- /RUNNING.md: -------------------------------------------------------------------------------- 1 | 2 | # Running this code 3 | 4 | ## Working with the network locally 5 | 6 | ### Prerequisites 7 | 8 | We use the pipenv dependency/virtualenv framework: 9 | ```shell 10 | $ pipenv install 11 | $ pipenv shell 12 | (mac-graph-sjOzWQ6Y) $ 13 | ``` 14 | 15 | ### Prediction 16 | 17 | You can watch the model predict values from the hold-back data: 18 | ```shell 19 | $ python -m macgraph.predict --name my_dataset --model-version 0ds9f0s 20 | 21 | predicted_label: shabby 22 | actual_label: derilict 23 | src: How clean is 3 ? 24 | ------- 25 | predicted_label: small 26 | actual_label: medium-sized 27 | src: How big is 4 ? 28 | ------- 29 | predicted_label: medium-sized 30 | actual_label: tiny 31 | src: How big is 7 ? 32 | ------- 33 | predicted_label: True 34 | actual_label: True 35 | src: Does 1 have rail connections ? 36 | ------- 37 | predicted_label: True 38 | actual_label: False 39 | src: Does 0 have rail connections ? 40 | ------- 41 | predicted_label: victorian 42 | actual_label: victorian 43 | src: What architectural style is 1 ? 44 | ``` 45 | 46 | **TODO: Get it predicting from your typed input** 47 | 48 | ### Building the data 49 | 50 | To train the model, you need training data. 51 | 52 | If you want to skip this step, you can download the pre-built data from [our public dataset](https://www.floydhub.com/davidmack/datasets/mac-graph). This repo is a work in progress so the format is still in flux. 53 | 54 | The underlying data (a Graph-Question-Answer YAML from CLEVR-graph) must be pre-processed for training and evaluation. The YAML is transformed into TensorFlow records, and split into train-evaluate-predict tranches. 55 | 56 | First [generate](https://github.com/Octavian-ai/clevr-graph) a `gqa.yaml` with the command: 57 | ```shell 58 | clevr-graph$ python -m gqa.generate --count 50000 --int-names 59 | cp data/gqa-some-id.yaml ../mac-graph/input_data/raw/my_dataset.yaml 60 | ``` 61 | Then build (that is, pre-process into a vocab table and tfrecords) the data: 62 | 63 | ```shell 64 | mac-graph$ python -m macgraph.input.build --name my_dataset 65 | ``` 66 | 67 | #### Arguments to build 68 | - `--limit N` will only read N records from the YAML and only output a total of N tf-records (split across three tranches) 69 | - `--type-string-prefix StationProperty` will filter just questions with type string prefix "StationProperty" 70 | 71 | 72 | ### Training 73 | 74 | Let's build a model. (Note, this requires training data from the previous section). 75 | 76 | General advice is to have at least 40,000 training records (e.g. 
build from 50,000 GQA triples) 77 | 78 | ```shell 79 | python -m macgraph.train --name my_dataset 80 | ``` 81 | 82 | 83 | -------------------------------------------------------------------------------- /macgraph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Octavian-ai/shortest-path/7baef8d4cad13297fa2d08b5ac0f19f06bb708e3/macgraph/__init__.py -------------------------------------------------------------------------------- /macgraph/activations.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from .minception import mi_activation 5 | 6 | def absu(x): 7 | return tf.nn.relu(x) + tf.nn.relu(-x) 8 | 9 | # Expand activation args to callables 10 | ACTIVATION_FNS = { 11 | "tanh": tf.tanh, 12 | "relu": tf.nn.relu, 13 | "sigmoid": tf.nn.sigmoid, 14 | "mi": mi_activation, 15 | "abs": absu, 16 | "tanh_abs": lambda x: tf.concat([tf.tanh(x), absu(x)], axis=-1), 17 | "linear": tf.identity, 18 | "id": tf.identity, 19 | "selu": tf.nn.selu, 20 | } -------------------------------------------------------------------------------- /macgraph/args.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os.path 4 | import yaml 5 | import subprocess 6 | import pathlib 7 | import tensorflow as tf 8 | import glob 9 | import logging 10 | import coloredlogs 11 | 12 | from .global_args import global_args 13 | 14 | from .activations import ACTIVATION_FNS 15 | from .input import Vocab 16 | 17 | 18 | 19 | def generate_args_derivatives(args): 20 | 21 | r = {} 22 | r["modes"] = ["eval", "train", "predict"] 23 | 24 | if "gqa_paths" in args: 25 | if args["gqa_paths"] == [] or args["gqa_paths"] == None: 26 | r["gqa_paths"] = [os.path.join(args["gqa_dir"], args["dataset"]) + ".yaml"] 27 | else: 28 | gp = [] 29 | for i in args["gqa_paths"]: 30 | if "*" in i: 31 | gp += glob.glob(i) 32 | else: 33 | gp.append(i) 34 | 35 | r["gqa_paths"] = gp 36 | 37 | if args["input_dir"] is None: 38 | r["input_dir"] = os.path.join(args["input_dir_prefix"], args["dataset"]) 39 | else: 40 | r["input_dir"] = args["input_dir"] 41 | 42 | if args["model_dir"] is None: 43 | r["model_dir"] = os.path.join(args["model_dir_prefix"], args["dataset"], *args["tag"], args["model_version"]) 44 | else: 45 | r["model_dir"] = args["model_dir"] 46 | 47 | 48 | r["profile_path"] = os.path.join(r["model_dir"], "profile") 49 | 50 | # Expand input dirs 51 | for i in [*r["modes"], "all"]: 52 | r[i+"_input_path"] = os.path.join(r["input_dir"], i+"_input.tfrecords") 53 | 54 | if args["vocab_path"] is None: 55 | r["vocab_path"] = os.path.join(r["input_dir"], "vocab.txt") 56 | else: 57 | r["vocab_path"] = args["vocab_path"] 58 | 59 | r["config_path"] = os.path.join(r["model_dir"], "config.yaml") 60 | r["question_types_path"] = os.path.join(r["input_dir"], "types.yaml") 61 | r["answer_classes_path"] = os.path.join(r["input_dir"], "answer_classes.yaml") 62 | r["answer_classes_types_path"] = os.path.join(r["input_dir"], "answer_classes_types.yaml") 63 | 64 | r["mp_head_list"] = ["mp_write", "mp_read0"] 65 | 66 | r["query_sources"] = [ "token_index"] 67 | r["query_taps"] = ["switch_attn", "token_index_attn"] 68 | 69 | 70 | if args["use_fast"]: 71 | r["use_summary_scalar"] = False 72 | r["use_summary_image"] = False 73 | r["use_assert"] = False 74 | 75 | 76 | 77 | try: 78 | r["vocab"] = Vocab.load(r["vocab_path"], args["vocab_size"]) 79 | except 
tf.errors.NotFoundError: 80 | pass 81 | 82 | return r 83 | 84 | def get_git_hash(): 85 | try: 86 | result = subprocess.run( 87 | ['git', '--no-pager', 'log', "--pretty=format:%h", '-n', '1'], 88 | stdout=subprocess.PIPE, 89 | check=True, 90 | universal_newlines=True 91 | ) 92 | return result.stdout 93 | except subprocess.CalledProcessError: 94 | # Git was angry, oh well 95 | return "unknown" 96 | 97 | def get_args(extend=lambda parser:None, argv=None): 98 | 99 | parser = argparse.ArgumentParser() 100 | extend(parser) 101 | 102 | # -------------------------------------------------------------------------- 103 | # General 104 | # -------------------------------------------------------------------------- 105 | 106 | parser.add_argument('--log-level', type=str, default='INFO') 107 | parser.add_argument('--output-dir', type=str, default="./output") 108 | parser.add_argument('--dataset', type=str, default="default", help="Name of dataset") 109 | parser.add_argument('--input-dir', type=str, default=None) 110 | parser.add_argument('--input-dir-prefix', type=str, default="./input_data/processed") 111 | parser.add_argument('--tag', action="append") 112 | 113 | parser.add_argument('--model-dir', type=str, default=None) 114 | parser.add_argument('--model-version', type=str, default=get_git_hash(), help="Model will be saved to a directory with this name, to assist with repeatable experiments") 115 | parser.add_argument('--model-dir-prefix', type=str, default="./output") 116 | 117 | 118 | # Used in train / predict / build 119 | parser.add_argument('--limit', type=int, default=None, help="How many rows of input data to read") 120 | parser.add_argument('--filter-type-prefix', type=str, default=None, help="Filter input data rows to only have this type string prefix") 121 | parser.add_argument('--filter-output-class', action="append", help="Filter input data rows to only have this output class") 122 | 123 | # -------------------------------------------------------------------------- 124 | # Data build 125 | # -------------------------------------------------------------------------- 126 | 127 | parser.add_argument('--eval-holdback', type=float, default=0.1) 128 | parser.add_argument('--predict-holdback', type=float, default=0.005) 129 | 130 | 131 | # -------------------------------------------------------------------------- 132 | # Training 133 | # -------------------------------------------------------------------------- 134 | 135 | parser.add_argument('--warm-start-dir', type=str, default=None, help="Load model initial weights from previous checkpoints") 136 | 137 | parser.add_argument('--batch-size', type=int, default=32, help="Number of items in a full batch") 138 | parser.add_argument('--train-max-steps', type=float, default=None, help="In thousands") 139 | parser.add_argument('--results-path', type=str, default="./results.yaml") 140 | 141 | 142 | parser.add_argument('--max-gradient-norm', type=float, default=0.4) 143 | parser.add_argument('--learning-rate', type=float, default=0.01) 144 | parser.add_argument('--enable-regularization', action='store_true', dest='use_regularization') 145 | parser.add_argument('--regularization-factor', type=float, default=0.0001) 146 | parser.add_argument('--random-seed', type=int, default=3) 147 | parser.add_argument('--enable-gradient-clipping', action='store_true', dest='use_gradient_clipping') 148 | parser.add_argument('--eval-every', type=int, default=2*60, help="Evaluate every X seconds") 149 | 150 | parser.add_argument('--fast', action='store_true', 
dest='use_fast') 151 | 152 | # -------------------------------------------------------------------------- 153 | # Decode 154 | # -------------------------------------------------------------------------- 155 | 156 | parser.add_argument('--max-decode-iterations', type=int, default=1) 157 | parser.add_argument('--finished-steps-loss-factor', type=float, default= 0.001) 158 | parser.add_argument('--enable-dynamic-decode', action='store_true', dest="use_dynamic_decode") 159 | parser.add_argument('--enable-independent-iterations', action='store_true', dest="use_independent_iterations") 160 | 161 | # -------------------------------------------------------------------------- 162 | # Network topology 163 | # -------------------------------------------------------------------------- 164 | 165 | parser.add_argument('--vocab-size', type=int, default=128, help="How many different words are in vocab") 166 | parser.add_argument('--vocab-path', type=str, default=None, help="Custom vocab path") 167 | 168 | parser.add_argument('--max-seq-len', type=int, default=40, help="Maximum length of question token list") 169 | 170 | parser.add_argument('--embed-width', type=int, default=128, help="The width of token embeddings") 171 | parser.add_argument('--disable-embed-const-eye', action='store_false', dest='use_embed_const_eye') 172 | 173 | parser.add_argument('--kb-node-width', type=int, default=7, help="Width of node entry into graph table aka the knowledge base") 174 | parser.add_argument('--kb-node-max-len', type=int, default=40, help="Maximum number of nodes in kb") 175 | parser.add_argument('--kb-edge-width', type=int, default=3, help="Width of edge entry into graph table aka the knowledge base") 176 | parser.add_argument('--kb-edge-max-len', type=int, default=40, help="Maximum number of edges in kb") 177 | 178 | parser.add_argument('--mp-activation', type=str, default="selu", choices=ACTIVATION_FNS.keys()) 179 | parser.add_argument('--mp-state-width', type=int, default=4) 180 | parser.add_argument('--disable-mp-gru', action='store_false', dest='use_mp_gru') 181 | parser.add_argument('--mp-read-heads', type=int, default=1) 182 | 183 | parser.add_argument('--output-activation', type=str, default="selu", choices=ACTIVATION_FNS.keys()) 184 | parser.add_argument('--output-layers', type=int, default=1) 185 | parser.add_argument('--output-width', type=int, default=128, help="The number of different possible answers (e.g. answer classes). Currently tied to vocab size since we attempt to tokenise the output.") 186 | parser.add_argument('--disable-output-lookback', action='store_false', dest="use_output_lookback") 187 | 188 | parser.add_argument('--enable-lr-finder', action='store_true', dest="use_lr_finder") 189 | parser.add_argument('--enable-curriculum', action='store_true', dest="use_curriculum") 190 | 191 | parser.add_argument('--enable-tf-debug', action='store_true', dest="use_tf_debug") 192 | parser.add_argument('--enable-floyd', action='store_true', dest="use_floyd") 193 | parser.add_argument('--disable-assert', action='store_false', dest="use_assert") 194 | 195 | args = vars(parser.parse_args(argv)) 196 | 197 | args.update(generate_args_derivatives(args)) 198 | 199 | # Global singleton var for easy access deep in the codebase (e.g. utility functions) 200 | # Note that this wont play well with PBT!! 
201 | # TODO: Remove 202 | global_args.clear() 203 | global_args.update(args) 204 | 205 | 206 | # Setup logging 207 | logging.basicConfig() 208 | tf.logging.set_verbosity(args["log_level"]) 209 | logging.getLogger("mac-graph").setLevel(args["log_level"]) 210 | 211 | loggers = [logging.getLogger(i) 212 | for i in ["__main__", "pbt", "experiment", "macgraph", "util", "tensorflow"]] 213 | 214 | for i in loggers: 215 | i.handlers = [] 216 | coloredlogs.install(logger=i, level=args["log_level"], fmt='%(levelname)s %(name)s %(message)s') 217 | 218 | return args 219 | 220 | 221 | def save_args(args): 222 | pathlib.Path(args["model_dir"]).mkdir(parents=True, exist_ok=True) 223 | with tf.gfile.GFile(os.path.join(args["config_path"]), "w") as file: 224 | yaml.dump(args, file) 225 | -------------------------------------------------------------------------------- /macgraph/attention.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from .util import * 5 | 6 | from .args import global_args 7 | from .const import EPSILON 8 | from .print_util import * 9 | 10 | from .component import * 11 | 12 | class Attention(Component): 13 | def __init__(self, args, table:Component, query:Component, key_width:int=None, seq_len:int=None, keys_len:Tensor=None, table_representation=None, name:str=None): 14 | super().__init__(args, name) 15 | 16 | self.table = table 17 | self.query = query 18 | self.key_width = key_width 19 | self.seq_len = seq_len 20 | self.keys_len = keys_len 21 | self.table_representation = table_representation 22 | 23 | def forward(self, features): 24 | attn, self.focus, self._taps = attention( 25 | self.table.forward(features), 26 | self.query.forward(features), 27 | self.key_width, 28 | self.keys_len, 29 | name=self.name 30 | ) 31 | 32 | return attn 33 | 34 | def taps(self): 35 | return { 36 | "attn": self._taps['attn'], 37 | "attn_raw": self._taps['attn_raw'], 38 | } 39 | 40 | def tap_sizes(self): 41 | return { 42 | "attn": [self.seq_len], 43 | "attn_raw": [self.seq_len], 44 | } 45 | 46 | def print(self, taps, path, prefix, all_features): 47 | if self.table_representation is None: 48 | table_rep = list(range(len(taps["attn"]))) 49 | elif isinstance(self.table_representation, str): 50 | table_rep = all_features[self.table_representation] 51 | else: 52 | table_rep = self.table_representation 53 | 54 | l = ' '.join(color_text(table_rep, taps["attn"].flatten())) 55 | 56 | print(prefix, l) 57 | 58 | 59 | 60 | 61 | class AttentionByIndex(Component): 62 | def __init__(self, args, table:Component, control:Component, seq_len:int=None, table_representation=None, name:str=None): 63 | super().__init__(args, name) 64 | 65 | self.table = table 66 | self.control = control 67 | self.seq_len = seq_len 68 | self.table_representation = table_representation 69 | 70 | def forward(self, features): 71 | output, self.tap_attn = attention_by_index( 72 | self.table.forward(features), 73 | self.control.forward(features), 74 | name=self.name 75 | ) 76 | 77 | return output 78 | 79 | def taps(self): 80 | return { 81 | "attn": self.tap_attn, 82 | } 83 | 84 | def tap_sizes(self): 85 | return { 86 | "attn": [self.seq_len], 87 | } 88 | 89 | def print(self, taps, path, prefix, all_features): 90 | if self.table_representation is None: 91 | table_rep = list(range(len(taps["attn"]))) 92 | elif isinstance(self.table_representation, str): 93 | table_rep = all_features[self.table_representation] 94 | else: 95 | table_rep = self.table_representation 96 | 97 | l = ' 
'.join(color_text(table_rep, taps["attn"].flatten())) 98 | 99 | print(prefix, l) 100 | 101 | 102 | 103 | def softmax_with_masking(logits, mask, axis, name="", internal_dtype=tf.float64): 104 | with tf.name_scope(name+"_softmax_with_masking"): 105 | 106 | # -------------------------------------------------------------------------- 107 | # Validate inputs 108 | # -------------------------------------------------------------------------- 109 | 110 | logits_shape = tf.shape(logits) 111 | 112 | assert mask.dtype == tf.bool 113 | mask = dynamic_assert_shape(mask, logits_shape) 114 | assert axis < len(logits.shape) 115 | logits = tf.check_numerics(logits, "logits") 116 | mask = dynamic_assert_shape(mask, logits_shape, "mask") 117 | 118 | # -------------------------------------------------------------------------- 119 | # Mask those logits! 120 | # -------------------------------------------------------------------------- 121 | 122 | # masked_logits = tf.boolean_mask(logits, mask) 123 | # masked_logits = tf.reshape(masked_logits, tf.shape(logits)) 124 | 125 | f_mask = tf.cast(mask, internal_dtype) 126 | masked_logits = tf.cast(logits, internal_dtype) * f_mask 127 | masked_logits = dynamic_assert_shape(masked_logits, logits_shape, "masked_logits") 128 | 129 | # masked_logits = tf.Print(masked_logits, [f"{name}: masked_logits", tf.squeeze(masked_logits)], message="\n", summarize=9999) 130 | 131 | # For numerical stability shrink the values 132 | logits_max = tf.reduce_max(masked_logits, axis=axis, keepdims=True) 133 | logits_max = tf.check_numerics(logits_max, "logit_max") 134 | 135 | # Numerator 136 | l_delta = (tf.cast(logits, internal_dtype) - logits_max) * f_mask 137 | l_delta = tf.check_numerics(l_delta, "l_delta") 138 | l_delta = dynamic_assert_shape(l_delta, tf.shape(logits), "l_delta") 139 | 140 | 141 | # l_delta = tf.Print(l_delta, [f"{name}: logits", tf.squeeze(logits,-1)], message="\n", summarize=9999) 142 | # l_delta = tf.Print(l_delta, [f"{name}: logits_max", logits_max], message="\n", summarize=9999) 143 | # l_delta = tf.Print(l_delta, [f"{name}: l_delta", tf.squeeze(l_delta,-1)], message="\n", summarize=9999) 144 | 145 | # This assert fails, howwwww?? 
146 | with tf.control_dependencies([tf.assert_less_equal(l_delta, tf.cast(0.0, l_delta.dtype), summarize=100000, data=[logits_max, mask, logits])]): 147 | 148 | l = tf.exp(l_delta) 149 | l = tf.check_numerics(l, "numerator pre mask") 150 | 151 | # l = tf.Print(l, [f"numerator pre mask {name}", tf.squeeze(l,-1)], message="\n", summarize=9999) 152 | 153 | l *= f_mask 154 | l = tf.check_numerics(l, "numerator") 155 | # l = tf.Print(l, [f"numerator post mask {name}", tf.squeeze(l,-1)], message="\n", summarize=9999) 156 | 157 | # Denominator 158 | d = tf.reduce_sum(l, axis) 159 | d = tf.expand_dims(d, axis) 160 | d = tf.check_numerics(d, "denominator") 161 | 162 | normalized = l / (d + EPSILON) 163 | normalized = tf.cast(normalized, logits.dtype) 164 | 165 | normalized = dynamic_assert_shape(normalized, logits_shape, "normalized_sm_scores") 166 | 167 | # Total, by batch 168 | scores_total = tf.reduce_sum(normalized, axis=axis) 169 | # keys_more_than_zero = tf.where( 170 | # tf.greater(keys_len, 0), 171 | # tf.ones(tf.shape(scores_total)), tf.zeros(tf.shape(scores_total))) 172 | 173 | sum_to_one = tf_assert_almost_equal(scores_total, 1.0, message=f"Checking scores sum to 1.0",summarize=999) 174 | 175 | with tf.control_dependencies([sum_to_one]): 176 | return normalized 177 | 178 | 179 | 180 | 181 | def attention(table:tf.Tensor, query:tf.Tensor, key_width:int=None, keys_len=None, name="attention"): 182 | """ 183 | Returns: 184 | - attention_output 185 | - focus 186 | - taps {"attn", "attn_raw"} 187 | """ 188 | 189 | return attention_key_value( 190 | keys=table, 191 | table=table, 192 | query=query, 193 | key_width=key_width, keys_len=keys_len, name=name) 194 | 195 | def attention_key_value(keys:tf.Tensor, table:tf.Tensor, query:tf.Tensor, key_width:int=None, keys_len=None, name="attention"): 196 | """ 197 | Apply attention 198 | 199 | Arguments: 200 | - `keys`, shape (batch_size, len, key_width) 201 | - `query`, shape (batch_size, key_width) 202 | - `table`, shape (batch_size, len, value_width) 203 | - key_width: The width of the key entries 204 | - `keys_len` A tensor of the lengths of the tables (in the batch) that is used to mask the scores before applying softmax (i.e. 
meaning that any table values after the length are ignored in the lookup) 205 | 206 | 207 | Returns: 208 | - attention_output 209 | - focus 210 | - taps {"attn", "attn_raw"} 211 | """ 212 | 213 | assert len(table.shape) == 3, f"table should be shape [batch, seq_len, value_width] but is len(shape) {len(table.shape)}" 214 | batch_size = tf.shape(table)[0] 215 | seq_len = tf.shape(table)[1] 216 | value_width = tf.shape(table)[2] 217 | 218 | keys = dynamic_assert_shape(keys, [batch_size, seq_len, tf.shape(keys)[2]], "keys") 219 | 220 | scores_sm, attn_focus, scores_raw = attention_compute_scores( 221 | keys=keys, 222 | query=query, 223 | key_width=key_width, 224 | keys_len=keys_len, 225 | name=name) 226 | 227 | scores_sm = dynamic_assert_shape(scores_sm, [batch_size, seq_len, 1], "scores_sm") 228 | 229 | with tf.name_scope(name): 230 | weighted_table = table * scores_sm 231 | 232 | output = tf.reduce_sum(weighted_table, 1) 233 | output = dynamic_assert_shape(output, [batch_size, value_width], "output") 234 | output = tf.check_numerics(output, "attention_output") 235 | 236 | return output, attn_focus, {"attn": scores_sm, "attn_raw": scores_raw} 237 | 238 | def attention_compute_scores(keys:tf.Tensor, query:tf.Tensor, key_width:int=None, keys_len=None, name:str="attention"): 239 | with tf.name_scope(name): 240 | 241 | # -------------------------------------------------------------------------- 242 | # Validate inputs 243 | # -------------------------------------------------------------------------- 244 | 245 | assert query is not None 246 | assert keys is not None 247 | assert len(keys.shape) == 3, "keys should be shape [batch, len, key_width]" 248 | 249 | batch_size = tf.shape(keys)[0] 250 | seq_len = tf.shape(keys)[1] 251 | 252 | if keys_len is not None: 253 | keys_len = dynamic_assert_shape(keys_len, [batch_size], "keys_len") 254 | 255 | if key_width is None: 256 | key_width = tf.shape(keys)[2] 257 | 258 | q_shape = [batch_size, key_width] 259 | scores_shape = [batch_size, seq_len, 1] 260 | keys_shape = [batch_size, seq_len, key_width] 261 | 262 | query = dynamic_assert_shape(query, q_shape, "query") 263 | keys = dynamic_assert_shape(keys, keys_shape, "keys") # Somewhat tautologious 264 | 265 | # -------------------------------------------------------------------------- 266 | # Run model 267 | # -------------------------------------------------------------------------- 268 | 269 | # mul = tf.get_variable("attn_mul", [1], dtype=query.dtype) 270 | # bias = tf.get_variable("attn_bias", [1], dtype=query.dtype) 271 | 272 | scores = tf.matmul(keys, tf.expand_dims(query, 2)) 273 | scores /= tf.sqrt(tf.cast(tf.shape(query)[-1], scores.dtype)) # As per Transformer model 274 | scores = dynamic_assert_shape(scores, scores_shape, "scores") 275 | 276 | if keys_len is not None: 277 | scores_mask = tf.sequence_mask(keys_len, seq_len) 278 | scores_mask = tf.expand_dims(scores_mask, -1) 279 | scores_mask = dynamic_assert_shape(scores_mask, scores_shape, "scores_mask") 280 | 281 | scores = tf.where(scores_mask, scores, tf.fill(scores_shape, -1e9)) 282 | scores_sm = tf.nn.softmax(scores + EPSILON, axis=1) 283 | 284 | # scores_sm = softmax_with_masking(scores, mask=scores_mask, axis=1, name=name) 285 | else: 286 | scores_sm = tf.nn.softmax(scores + EPSILON, axis=1) 287 | 288 | scores_sm = dynamic_assert_shape(scores_sm, scores_shape, "scores_sm") 289 | 290 | return scores_sm, tf.reduce_sum(scores, axis=1), scores 291 | 292 | 293 | def attention_write_by_key(keys, query, value, key_width=None, keys_len=None, 
name="attention"): 294 | """ 295 | Returns: 296 | - attention_output 297 | - softmax_scores 298 | - focus 299 | """ 300 | 301 | batch_size = tf.shape(keys)[0] 302 | seq_len = tf.shape(keys)[1] 303 | value_width = tf.shape(value)[-1] 304 | 305 | assert len(value.shape) == 2, f"Value must be two dimensional, not {len(value.shape)}" 306 | 307 | scores_sm, attn_focus, scores_raw = attention_compute_scores( 308 | keys=keys, query=query, key_width=key_width, keys_len=keys_len, name=name) 309 | 310 | with tf.name_scope(name): 311 | weighted_table = tf.expand_dims(value, 1) * scores_sm 312 | weighted_table = dynamic_assert_shape(weighted_table, [batch_size, seq_len, value_width]) 313 | return weighted_table, attn_focus, {"attn": scores_sm, "attn_raw": scores_raw} 314 | 315 | 316 | 317 | 318 | 319 | def attention_by_index(table, control, name:str="attention_by_index"): 320 | ''' 321 | Essentially a weighted sum over the second-last dimension of table, 322 | using a dense softmax of control for the weights 323 | 324 | Requires table to have fixed seq_len 325 | 326 | 327 | Shapes: 328 | * control [batch, word_size] 329 | * table [batch, seq_len, word_size] 330 | 331 | Returns [batch, word_size] 332 | 333 | ''' 334 | 335 | with tf.name_scope(name): 336 | with tf.variable_scope(name): 337 | 338 | word_size = tf.shape(table)[-1] 339 | seq_len = table.shape[-2] 340 | batch_size = tf.shape(table)[0] 341 | 342 | query_shape = [batch_size, seq_len] 343 | output_shape = [batch_size, word_size] 344 | 345 | assert seq_len is not None, "Seq len must be defined" 346 | 347 | if control is not None: 348 | query = tf.layers.dense(control, seq_len, activation=tf.nn.softmax) 349 | else: 350 | query = tf.get_variable("query", [1, seq_len], trainable=True, initializer=tf.initializers.random_normal) 351 | query = tf.tile(query, [batch_size, 1]) 352 | query = tf.nn.softmax(query) 353 | 354 | weighted_stack = table * tf.expand_dims(query, -1) 355 | weighted_sum = tf.reduce_sum(weighted_stack, -2) 356 | 357 | output = weighted_sum 358 | output = dynamic_assert_shape(output, output_shape) 359 | return output, query 360 | 361 | 362 | -------------------------------------------------------------------------------- /macgraph/attention_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import tensorflow as tf 4 | tf.enable_eager_execution() 5 | 6 | import numpy as np 7 | import math 8 | 9 | from .attention import * 10 | 11 | class TestAttention(unittest.TestCase): 12 | 13 | # def setUp(self): 14 | # tf.enable_eager_execution() 15 | 16 | def test_softmax_masking(self): 17 | 18 | max_len = 3 19 | axis = 1 20 | logits = tf.eye(max_len) 21 | seq_len = [1,2,2] 22 | mask = tf.sequence_mask(seq_len, max_len) 23 | 24 | r = softmax_with_masking(logits, mask, axis) 25 | r = np.array(r) 26 | 27 | d = math.exp(1) + math.exp(0) 28 | 29 | expected = np.array([ 30 | [1,0,0], 31 | [math.exp(0)/d, math.exp(1)/d,0], 32 | [0.5, 0.5, 0], 33 | ]) 34 | 35 | np.testing.assert_almost_equal(r, expected) 36 | 37 | def test_softmax_masking2(self): 38 | 39 | max_len = 3 40 | axis = 1 41 | logits = tf.zeros([max_len, max_len]) 42 | seq_len = [1,2,3] 43 | mask = tf.sequence_mask(seq_len, max_len) 44 | 45 | r = softmax_with_masking(logits, mask, axis) 46 | r = np.array(r) 47 | 48 | expected = np.array([ 49 | [1.0,0.0,0], 50 | [0.5,0.5,0], 51 | [1.0/3.0, 1.0/3.0, 1.0/3.0], 52 | ]) 53 | 54 | np.testing.assert_almost_equal(r, expected) 55 | 56 | def test_softmax_write(self): 57 | 58 | max_len = 
6 59 | keys = tf.expand_dims(tf.eye(max_len), 0) 60 | target = 3 61 | batch_len = 1 62 | 63 | table, focus, taps = attention_write_by_key(keys, keys[:,target,:], tf.ones([batch_len, max_len])) 64 | 65 | d = math.exp(1) + (max_len-1) * math.exp(0) 66 | exp = np.full([batch_len, max_len, max_len], 1/d) 67 | exp[:,target,:] = (d-5)/d 68 | 69 | np.set_printoptions(threshold=np.inf) 70 | np.testing.assert_almost_equal(table.numpy(), exp) 71 | 72 | 73 | if __name__ == '__main__': 74 | unittest.main() -------------------------------------------------------------------------------- /macgraph/cell/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .decode import execute_reasoning 3 | from .mac_cell import MAC_Component -------------------------------------------------------------------------------- /macgraph/cell/decode.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from .mac_cell import * 5 | from ..util import * 6 | 7 | 8 | 9 | 10 | def static_decode(args, features, inputs, labels, vocab_embedding): 11 | with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE): 12 | 13 | d_cell = MAC_RNNCell(args, features, vocab_embedding) 14 | d_cell_initial = d_cell.zero_state(dtype=tf.float32, batch_size=features["d_batch_size"]) 15 | d_cell_empty_output = [tf.zeros([features["d_batch_size"], args["output_width"]])] 16 | 17 | # Hard-coded unroll of the reasoning network for simplicity 18 | states = [(d_cell_empty_output, d_cell_initial)] 19 | for i in range(args["max_decode_iterations"]): 20 | with tf.variable_scope("decoder_cell", reuse=tf.AUTO_REUSE): 21 | inputs_slice = [item[i] for item in inputs] 22 | prev_outputs = [item[0][0] for item in states] 23 | 24 | if len(prev_outputs) < args["max_decode_iterations"]: 25 | for i in range(args["max_decode_iterations"] - len(prev_outputs)): 26 | prev_outputs.append(d_cell_empty_output[0]) 27 | 28 | assert len(prev_outputs) == args["max_decode_iterations"] 29 | prev_outputs = tf.stack(prev_outputs, axis=1) 30 | 31 | inputs_for_iteration = [*inputs_slice, prev_outputs] 32 | prev_state = states[-1][1] 33 | 34 | states.append(d_cell(inputs_for_iteration, prev_state)) 35 | 36 | final_output = states[-1][0][0] 37 | 38 | def get_tap(idx, key): 39 | with tf.name_scope(f"get_tap_{key}"): 40 | tap = [i[0][idx] for i in states[1:] if i[0] is not None] 41 | 42 | for i in tap: 43 | if i is None: 44 | return None 45 | 46 | if len(tap) == 0: 47 | return None 48 | 49 | tap = tf.convert_to_tensor(tap) 50 | 51 | # Deal with batch vs iteration axis layout 52 | if len(tap.shape) == 3: 53 | tap = tf.transpose(tap, [1,0,2]) # => batch, iteration, data 54 | if len(tap.shape) == 4: 55 | tap = tf.transpose(tap, [1,0,2,3]) # => batch, iteration, control_head, data 56 | 57 | return tap 58 | 59 | out_taps = { 60 | key: get_tap(idx+1, key) 61 | for idx, key in enumerate(d_cell.tap_sizes().keys()) 62 | } 63 | 64 | return final_output, out_taps 65 | 66 | 67 | def execute_reasoning(args, features, **kwargs): 68 | 69 | d_eye = tf.eye(args["max_decode_iterations"]) 70 | 71 | iteration_id = [ 72 | tf.tile(tf.expand_dims(d_eye[i], 0), [features["d_batch_size"], 1]) 73 | for i in range(args["max_decode_iterations"]) 74 | ] 75 | 76 | inputs = [iteration_id] 77 | 78 | final_output, out_taps = static_decode(args, features, inputs, **kwargs) 79 | 80 | 81 | final_output = dynamic_assert_shape(final_output, [features["d_batch_size"], args["output_width"]]) 82 | 83 | 84 | 
return final_output, out_taps 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /macgraph/cell/mac_cell.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from ..component import Component 5 | 6 | from .output_cell import * 7 | from .messaging_cell import * 8 | from .types import * 9 | 10 | from ..util import * 11 | from ..minception import * 12 | from ..layers import * 13 | 14 | class MAC_RNNCell(tf.nn.rnn_cell.RNNCell): 15 | 16 | def __init__(self, args, features, vocab_embedding): 17 | 18 | self.args = args 19 | self.features = features 20 | self.vocab_embedding = vocab_embedding 21 | 22 | self.mac = MAC_Component(args) 23 | 24 | super().__init__(self) 25 | 26 | 27 | def __call__(self, inputs, in_state): 28 | '''Build this cell (part of implementing RNNCell) 29 | 30 | This is a wrapper that marshalls our named taps, to 31 | make sure they end up where we expect and are present. 32 | 33 | Args: 34 | inputs: `2-D` tensor with shape `[batch_size, input_size]`. 35 | state: if `self.state_size` is an integer, this should be a `2-D Tensor` 36 | with shape `[batch_size, self.state_size]`. Otherwise, if 37 | `self.state_size` is a tuple of integers, this should be a tuple 38 | with shapes `[batch_size, s] for s in self.state_size`. 39 | scope: VariableScope for the created subgraph; defaults to class name. 40 | Returns: 41 | A pair containing: 42 | - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. 43 | - New state: Either a single `2-D` tensor, or a tuple of tensors matching 44 | the arity and shapes of `state`. 45 | ''' 46 | 47 | output, out_state = self.mac.forward(self.features, inputs, in_state, self.vocab_embedding) 48 | 49 | taps = self.mac.all_taps() 50 | 51 | out_data = [output] 52 | 53 | for k,v in taps.items(): 54 | out_data.append(v) 55 | 56 | return out_data, out_state 57 | 58 | 59 | def tap_sizes(self): 60 | return self.mac.all_tap_sizes() 61 | 62 | 63 | 64 | @property 65 | def state_size(self): 66 | """ 67 | Returns a size tuple 68 | """ 69 | return ( 70 | tf.TensorShape([self.args["kb_node_max_len"], self.args["mp_state_width"]]), 71 | ) 72 | 73 | @property 74 | def output_size(self): 75 | 76 | tap_sizes = self.mac.all_tap_sizes() 77 | 78 | return [ 79 | self.args["output_width"], 80 | ] + tap_sizes 81 | 82 | 83 | 84 | 85 | class MAC_Component(Component): 86 | 87 | def __init__(self, args): 88 | super().__init__(args, name=None) # empty to preserve legacy naming 89 | 90 | self.output_cell = OutputCell(args) 91 | 92 | """ 93 | Special forward. 
Should return output, out_state 94 | """ 95 | def forward(self, features, inputs, in_state, vocab_embedding): 96 | # TODO: remove this transition scaffolding 97 | self.features = features 98 | self.vocab_embedding = vocab_embedding 99 | 100 | with tf.variable_scope("mac_cell", reuse=tf.AUTO_REUSE): 101 | 102 | in_node_state = in_state[0] 103 | 104 | in_iter_id = inputs[0] 105 | in_iter_id = dynamic_assert_shape(in_iter_id, [self.features["d_batch_size"], self.args["max_decode_iterations"]], "in_iter_id") 106 | 107 | in_prev_outputs = inputs[-1] 108 | 109 | embedded_question = tf.nn.embedding_lookup(vocab_embedding, features["src"]) 110 | embedded_question *= tf.sqrt(tf.cast(self.args["embed_width"], embedded_question.dtype)) # As per Transformer model 111 | embedded_question = dynamic_assert_shape(embedded_question, [features["d_batch_size"], features["src_len"], self.args["embed_width"]]) 112 | 113 | context = CellContext( 114 | features=self.features, 115 | args=self.args, 116 | vocab_embedding=self.vocab_embedding, 117 | in_prev_outputs=in_prev_outputs, 118 | in_iter_id=in_iter_id, 119 | in_node_state=in_node_state, 120 | embedded_question=embedded_question 121 | ) 122 | 123 | mp_reads, out_mp_state, mp_taps = messaging_cell(context) 124 | 125 | output = self.output_cell.forward(features, context, mp_reads) 126 | 127 | # TODO: tidy away later 128 | self.mp_taps = mp_taps 129 | 130 | self.mp_state = out_mp_state 131 | self.context = context 132 | 133 | out_state = (out_mp_state,) 134 | 135 | 136 | return output, out_state 137 | 138 | 139 | def taps(self): 140 | 141 | # TODO: Remove all of this and let it run in the subsystem 142 | 143 | mp_taps = self.mp_taps 144 | 145 | empty_attn = tf.fill([self.features["d_batch_size"], self.args["max_seq_len"], 1], 0.0) 146 | empty_query = tf.fill([self.features["d_batch_size"], self.args["max_seq_len"]], 0.0) 147 | 148 | 149 | # TODO: AST this all away 150 | out_taps = { 151 | "mp_node_state": self.mp_state, 152 | "iter_id": self.context.in_iter_id, 153 | } 154 | 155 | mp_reads = [f"mp_read{i}" for i in range(self.args["mp_read_heads"])] 156 | 157 | suffixes = ["_attn", "_attn_raw", "_query", "_signal"] 158 | for qt in ["token_index_attn"]: 159 | suffixes.append("_query_"+qt) 160 | 161 | for mp_head in ["mp_write", *mp_reads]: 162 | for suffix in suffixes: 163 | i = mp_head + suffix 164 | out_taps[i] = mp_taps.get(i, empty_query) 165 | 166 | 167 | 168 | return out_taps 169 | 170 | 171 | 172 | 173 | def tap_sizes(self): 174 | 175 | t = { 176 | "mp_node_state": tf.TensorShape([self.args["kb_node_max_len"], self.args["mp_state_width"]]), 177 | "iter_id": self.args["max_decode_iterations"], 178 | } 179 | 180 | mp_reads = [f"mp_read{i}" for i in range(self.args["mp_read_heads"])] 181 | 182 | for mp_head in ["mp_write", *mp_reads]: 183 | t[f"{mp_head}_attn"] = self.args["kb_node_max_len"] 184 | t[f"{mp_head}_attn_raw"] = self.args["kb_node_max_len"] 185 | t[f"{mp_head}_query"] = self.args["kb_node_width"] * self.args["embed_width"] 186 | t[f"{mp_head}_signal"] = self.args["mp_state_width"] 187 | t[f"{mp_head}_query_token_index_attn" ] = self.args["max_seq_len"] 188 | 189 | return t 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /macgraph/cell/messaging_cell.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import NamedTuple 3 | import tensorflow as tf 4 | 5 | from .types import * 6 | from .query import * 7 | from 
.messaging_cell_helpers import * 8 | 9 | from ..args import ACTIVATION_FNS 10 | from ..attention import * 11 | from ..input import get_table_with_embedding 12 | from ..const import EPSILON 13 | from ..util import * 14 | from ..layers import * 15 | from ..activations import * 16 | 17 | 18 | def messaging_cell(context:CellContext): 19 | 20 | node_table, node_table_width, node_table_len = get_table_with_embedding(context.args, context.features, context.vocab_embedding, "kb_node") 21 | 22 | node_table_width = context.args["embed_width"] 23 | node_table = node_table[:,:,0:node_table_width] 24 | 25 | in_signal = context.in_iter_id 26 | 27 | taps = {} 28 | def add_taps(val, prefix): 29 | ret,tps = val 30 | for k,v in tps.items(): 31 | taps[prefix+"_"+k] = v 32 | return ret 33 | 34 | in_write_signal = layer_dense(in_signal, context.args["mp_state_width"], "sigmoid") 35 | in_write_query = add_taps(generate_token_index_query(context, "mp_write_query"), "mp_write_query") 36 | 37 | read_queries = [] 38 | for i in range(context.args["mp_read_heads"]): 39 | read_queries.append(add_taps(generate_token_index_query(context, f"mp_read{i}_query"), f"mp_read{i}_query")) 40 | 41 | out_read_signals, node_state, taps2 = do_messaging_cell(context, 42 | node_table, node_table_width, node_table_len, 43 | in_write_query, in_write_signal, read_queries) 44 | 45 | 46 | return out_read_signals, node_state, {**taps, **taps2} 47 | 48 | 49 | 50 | 51 | 52 | def calc_normalized_adjacency(context, node_state): 53 | # Aggregate via adjacency matrix with normalisation (that does not include self-edges) 54 | adj = tf.cast(context.features["kb_adjacency"], tf.float32) 55 | degree = tf.reduce_sum(adj, -1, keepdims=True) 56 | inv_degree = tf.reciprocal(degree) 57 | node_mask = tf.expand_dims(tf.sequence_mask(context.features["kb_nodes_len"], context.args["kb_node_max_len"]), -1) 58 | inv_degree = tf.where(node_mask, inv_degree, tf.zeros(tf.shape(inv_degree))) 59 | inv_degree = tf.where(tf.greater(degree, 0), inv_degree, tf.zeros(tf.shape(inv_degree))) 60 | inv_degree = tf.check_numerics(inv_degree, "inv_degree") 61 | adj_norm = inv_degree * adj 62 | adj_norm = tf.cast(adj_norm, node_state.dtype) 63 | adj_norm = tf.check_numerics(adj_norm, "adj_norm") 64 | node_incoming = tf.einsum('bnw,bnm->bmw', node_state, adj_norm) 65 | 66 | return node_incoming 67 | 68 | 69 | 70 | def do_messaging_cell(context:CellContext, 71 | node_table, node_table_width, node_table_len, 72 | in_write_query, in_write_signal, in_read_queries): 73 | 74 | ''' 75 | Operate a message passing cell 76 | Each iteration it'll do one round of message passing 77 | 78 | Returns: read_signal, node_state 79 | 80 | for to_node in nodes: 81 | to_node.state = combine_incoming_signals([ 82 | message_pass(from_node, to_node) for from_node in to_node.neighbors 83 | ] + [node_self_update(to_node)]) 84 | 85 | 86 | ''' 87 | 88 | with tf.name_scope("messaging_cell"): 89 | 90 | taps = {} 91 | taps["mp_write_query"] = in_write_query 92 | taps["mp_write_signal"] = in_write_signal 93 | 94 | node_state_shape = tf.shape(context.in_node_state) 95 | node_state = context.in_node_state 96 | assert len(node_state.shape) == 3, f"Node state should have three dimensions, has {len(node_state.shape)}" 97 | padded_node_table = pad_to_table_len(node_table, node_state, name="padded_node_table") 98 | 99 | # -------------------------------------------------------------------------- 100 | # Write to graph 101 | # -------------------------------------------------------------------------- 102 | 103 | 
write_signal, _, a_taps = attention_write_by_key( 104 | keys=node_table, 105 | key_width=node_table_width, 106 | keys_len=node_table_len, 107 | query=in_write_query, 108 | value=in_write_signal, 109 | name="mp_write_signal" 110 | ) 111 | 112 | for k,v in a_taps.items(): 113 | taps["mp_write_"+k] = v 114 | 115 | write_signal = pad_to_table_len(write_signal, node_state, name="write_signal") 116 | node_state += write_signal 117 | node_state = dynamic_assert_shape(node_state, node_state_shape, "node_state") 118 | 119 | # -------------------------------------------------------------------------- 120 | # Calculate adjacency 121 | # -------------------------------------------------------------------------- 122 | 123 | node_state = calc_normalized_adjacency(context, node_state) 124 | 125 | # -------------------------------------------------------------------------- 126 | # Read from graph 127 | # -------------------------------------------------------------------------- 128 | 129 | out_read_signals = [] 130 | 131 | for idx, qry in enumerate(in_read_queries): 132 | out_read_signal, _, a_taps = attention_key_value( 133 | keys=padded_node_table, 134 | keys_len=node_table_len, 135 | key_width=node_table_width, 136 | query=qry, 137 | table=node_state, 138 | name=f"mp_read{idx}" 139 | ) 140 | out_read_signals.append(out_read_signal) 141 | 142 | for k,v in a_taps.items(): 143 | taps[f"mp_read{idx}_{k}"] = v 144 | taps[f"mp_read{idx}_signal"] = out_read_signal 145 | taps[f"mp_read{idx}_query"] = qry 146 | 147 | 148 | taps["mp_node_state"] = node_state 149 | node_state = dynamic_assert_shape(node_state, node_state_shape, "node_state") 150 | assert node_state.shape[-1] == context.in_node_state.shape[-1], "Node state should not lose dimension" 151 | 152 | return out_read_signals, node_state, taps 153 | 154 | 155 | -------------------------------------------------------------------------------- /macgraph/cell/messaging_cell_helpers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from typing import NamedTuple 4 | import tensorflow as tf 5 | 6 | from .types import * 7 | from .query import * 8 | 9 | from ..args import ACTIVATION_FNS 10 | from ..attention import * 11 | from ..input import get_table_with_embedding 12 | from ..const import EPSILON 13 | from ..util import * 14 | from ..layers import * 15 | from ..activations import * 16 | 17 | MP_State = tf.Tensor 18 | 19 | class MP_Node(NamedTuple): 20 | id: str 21 | properties: tf.Tensor 22 | state: MP_State 23 | 24 | use_message_passing_fn = False 25 | use_self_reference = False 26 | 27 | def layer_normalize(tensor): 28 | '''Apologies if I've abused this term''' 29 | 30 | in_shape = tf.shape(tensor) 31 | axes = list(range(1, len(tensor.shape))) 32 | 33 | # Keep batch axis 34 | t = tf.reduce_sum(tensor, axis=axes ) 35 | t += EPSILON 36 | t = tf.reciprocal(t) 37 | t = tf.check_numerics(t, "1/sum") 38 | 39 | tensor = tf.einsum('brc,b->brc', tensor, t) 40 | 41 | tensor = dynamic_assert_shape(tensor, in_shape, "layer_normalize_tensor") 42 | return tensor 43 | 44 | 45 | def mp_matmul(state, mat, name): 46 | return tf.nn.conv1d(state, mat, 1, 'VALID', name=name) 47 | 48 | 49 | 50 | 51 | 52 | def calc_right_shift(node_incoming): 53 | shape = tf.shape(node_incoming) 54 | node_incoming = tf.concat([node_incoming[:,:,1:],node_incoming[:,:,0:1]], axis=-1) 55 | node_incoming = dynamic_assert_shape(node_incoming, shape, "node_incoming") 56 | return node_incoming 
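The docstring of do_messaging_cell above describes one message-passing round in pseudocode; the aggregation it relies on is the degree-normalised adjacency einsum in calc_normalized_adjacency. The following minimal standalone NumPy sketch of that step is not part of the repository (the propagate helper and the toy graph are illustrative only, and the kb_nodes_len padding mask used in the real code is omitted); it is here purely to make the shapes and the direction of normalisation concrete: each node splits its state evenly among its neighbours, and a node's incoming state is the sum of what its neighbours send it.

import numpy as np

def propagate(node_state, adjacency):
    # node_state: [batch, nodes, width]; adjacency: [batch, nodes, nodes], boolean, no self-edges
    adj = adjacency.astype(np.float32)
    degree = adj.sum(axis=-1, keepdims=True)               # number of neighbours of each sender
    inv_degree = np.where(degree > 0, 1.0 / np.maximum(degree, 1.0), 0.0)
    adj_norm = inv_degree * adj                            # each row sums to 1 (or 0 if isolated)
    # Mirrors tf.einsum('bnw,bnm->bmw', node_state, adj_norm): sender n's state,
    # scaled by 1/degree(n), flows along edge (n, m) into receiver m.
    return np.einsum('bnw,bnm->bmw', node_state, adj_norm)

# Toy path graph 0-1-2 with a 2-wide state per node
state = np.array([[[1., 0.], [0., 1.], [2., 2.]]])
adj = np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 0]]], dtype=bool)
print(propagate(state, adj))
# -> node 0: [0., 0.5], node 1: [3., 2.], node 2: [0., 0.5]
# Node 1 receives the full state of both degree-1 neighbours; nodes 0 and 2 each
# receive half of node 1's state, because node 1 splits its state two ways.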
-------------------------------------------------------------------------------- /macgraph/cell/output_cell.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from ..minception import * 5 | from ..args import ACTIVATION_FNS 6 | from ..util import * 7 | from ..layers import * 8 | from ..attention import * 9 | from ..component import * 10 | 11 | 12 | class OutputCell(Component): 13 | 14 | def __init__(self, args): 15 | super().__init__(args, "output_cell") 16 | 17 | # TODO: generate this from components 18 | 19 | if args["use_output_lookback"]: 20 | tr = [] 21 | for i in range(args["mp_read_heads"]): 22 | tr.append(f"mp{i}") 23 | 24 | for i in range(args["max_decode_iterations"]): 25 | tr.append(f"po{i}") 26 | 27 | 28 | self.output_table = Tensor("table") 29 | self.output_query = Tensor("focus_query") 30 | self.focus = AttentionByIndex(args, 31 | self.output_table, self.output_query, seq_len=6, 32 | table_representation=tr, 33 | name="focus") 34 | 35 | 36 | 37 | def forward(self, features, context, mp_reads): 38 | 39 | with tf.name_scope(self.name): 40 | 41 | if self.args["use_output_lookback"]: 42 | 43 | in_all = [] 44 | 45 | def add(t): 46 | in_all.append(pad_to_len_1d(t, self.args["embed_width"])) 47 | 48 | def add_all(t): 49 | for i in t: 50 | add(i) 51 | 52 | add_all(mp_reads) 53 | 54 | prev_outputs = tf.unstack(context.in_prev_outputs, axis=1) 55 | add_all(prev_outputs) 56 | 57 | in_stack = tf.stack(in_all, axis=1) 58 | in_stack = dynamic_assert_shape(in_stack, [features["d_batch_size"], len(in_all), self.args["embed_width"]]) 59 | 60 | self.output_table.bind(in_stack) 61 | self.output_query.bind(context.in_iter_id) 62 | v = self.focus.forward(features) 63 | v.set_shape([None, self.args["embed_width"]]) 64 | 65 | else: 66 | v = tf.concat(mp_reads, -1) 67 | 68 | for i in range(self.args["output_layers"]): 69 | v = layer_dense(v, self.args["output_width"], self.args["output_activation"], name=f"output{i}") 70 | 71 | return v 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /macgraph/cell/query.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from ..attention import * 5 | from ..util import * 6 | from .types import * 7 | 8 | def generate_token_index_query(context:CellContext, name:str): 9 | with tf.name_scope(name): 10 | with tf.variable_scope(name): 11 | 12 | taps = {} 13 | 14 | master_signal = context.in_iter_id 15 | 16 | tokens = pad_to_table_len(context.embedded_question, seq_len=context.args["max_seq_len"], name=name) 17 | token_index_signal, query = attention_by_index(tokens, None) 18 | 19 | output = token_index_signal 20 | taps["token_index_attn"] = tf.expand_dims(query, 2) 21 | 22 | return output, taps 23 | 24 | 25 | 26 | def generate_query(context:CellContext, name): 27 | with tf.name_scope(name): 28 | 29 | taps = {} 30 | sources = [] 31 | 32 | def add_taps(prefix, extra_taps): 33 | for k, v in extra_taps.items(): 34 | taps[prefix + "_" + k] = v 35 | 36 | # -------------------------------------------------------------------------- 37 | # Produce all the difference sources of addressing query 38 | # -------------------------------------------------------------------------- 39 | 40 | ms = [context.in_iter_id] 41 | 42 | master_signal = tf.concat(ms, -1) 43 | 44 | # Content address the question tokens 45 | token_query = tf.layers.dense(master_signal, context.args["embed_width"]) 46 | 
token_signal, _, x_taps = attention(context.in_question_tokens, token_query) 47 | sources.append(token_signal) 48 | add_taps("token_content", x_taps) 49 | 50 | # Index address the question tokens 51 | padding = [[0,0], [0, tf.maximum(0,context.args["max_seq_len"] - tf.shape(context.in_question_tokens)[1])], [0,0]] # batch, seq_len, token 52 | in_question_tokens_padded = tf.pad(context.in_question_tokens, padding) 53 | in_question_tokens_padded.set_shape([None, context.args["max_seq_len"], None]) 54 | 55 | token_index_signal, query = attention_by_index(in_question_tokens_padded, master_signal) 56 | sources.append(token_index_signal) 57 | taps["token_index_attn"] = tf.expand_dims(query, 2) 58 | 59 | if context.args["use_read_previous_outputs"]: 60 | # Use the previous output of the network 61 | prev_output_query = tf.layers.dense(master_signal, context.args["output_width"]) 62 | in_prev_outputs_padded = tf.pad(context.in_prev_outputs, [[0,0],[0, context.args["max_decode_iterations"] - tf.shape(context.in_prev_outputs)[1]],[0,0]]) 63 | prev_output_signal, _, x_taps = attention(in_prev_outputs_padded, prev_output_query) 64 | sources.append(prev_output_signal) 65 | add_taps("prev_output", x_taps) 66 | 67 | # -------------------------------------------------------------------------- 68 | # Choose a query source 69 | # -------------------------------------------------------------------------- 70 | 71 | query_signal, q_tap = attention_by_index(tf.stack(sources, 1), master_signal) 72 | taps["switch_attn"] = q_tap 73 | 74 | return query_signal, taps 75 | -------------------------------------------------------------------------------- /macgraph/cell/types.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from typing import * 4 | 5 | # Someday TensorFlow will have types! 6 | Tensor = Any 7 | 8 | class CellContext(NamedTuple): 9 | features: Dict 10 | args: Dict 11 | vocab_embedding: Tensor 12 | in_prev_outputs: Tensor 13 | in_iter_id: Tensor 14 | in_node_state: Tensor 15 | embedded_question: Tensor -------------------------------------------------------------------------------- /macgraph/component.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import * 3 | 4 | from abc import * 5 | import numpy as np 6 | 7 | from .print_util import * 8 | 9 | RawTensor = Any 10 | 11 | 12 | class FixedSizeTensor(NamedTuple): 13 | tensor:RawTensor 14 | size:List[int] 15 | 16 | class Component(ABC): 17 | 18 | def __init__(self, args:Dict[str, Any]={}, name:str=None): 19 | ''' 20 | Components should instantiate their sub-components in init 21 | 22 | This is so that methods like print do not need to run 23 | `forward` (which may be impossible outside of a session) 24 | prior to being able to recursively print all sub-components. 25 | ''' 26 | self.args = args 27 | self.name = name 28 | 29 | 30 | @abstractmethod 31 | def forward(self, features:Dict[str, RawTensor]) -> RawTensor: 32 | ''' 33 | Wire the forward pass (e.g. 
take tensors and return transformed tensors) 34 | to ultimately build the whole network 35 | ''' 36 | pass 37 | 38 | def taps(self) -> Dict[str, RawTensor]: 39 | ''' 40 | Get the taps (tensors that provide insight into the workings 41 | of the network) 42 | 43 | Forward will always have been called before this method, so 44 | it's ok to stash tensors as instance members in forward 45 | then recall them here 46 | ''' 47 | return {} 48 | 49 | 50 | def tap_sizes(self) -> Dict[str, List[int]]: 51 | ''' 52 | Get the names and sizes of expected taps 53 | 54 | Will be called independently from forward and taps 55 | ''' 56 | return {} 57 | 58 | def print(self, tap_dict:Dict[str, np.array], path:List[str], prefix:str, all_features:Dict[str, np.array]): 59 | ''' 60 | Print predict output nicely 61 | ''' 62 | pass 63 | 64 | 65 | def _do_recursive_map(self, fn, path:List[str]=[]): 66 | new_path = [*path, self.name] 67 | new_path = [i for i in new_path if i is not None] 68 | 69 | r = fn(self, new_path) 70 | 71 | for k, v in vars(self).items(): 72 | if issubclass(type(v), Component): 73 | r = {**r, **v._do_recursive_map(fn, new_path)} 74 | 75 | return r 76 | 77 | 78 | def all_taps(self) -> Dict[str,RawTensor]: 79 | 80 | def fn(self, path): 81 | r = self.taps() 82 | r_prefixed = {'_'.join([*path, k]): v 83 | for k,v in r.items()} 84 | return r_prefixed 85 | 86 | sizes = self.all_tap_sizes() 87 | taps = self._do_recursive_map(fn) 88 | 89 | sk = set(sizes.keys()) 90 | tk = set(taps.keys()) 91 | 92 | assert sk == tk, f"Set mismatch, in sizes but not taps: {sk - tk}, in taps but not sizes: {tk - sk}. \nFull sets taps:{tk} \ntap_sizes:{sk}" 93 | 94 | return taps 95 | 96 | def all_tap_sizes(self) -> Dict[str, List[int]]: 97 | 98 | def fn(self, path): 99 | r = self.tap_sizes() 100 | r_prefixed = {'_'.join([*path, k]): v 101 | for k,v in r.items()} 102 | 103 | return r_prefixed 104 | 105 | return self._do_recursive_map(fn) 106 | 107 | 108 | # You must call recursive_taps before this 109 | def print_all(self, all_features:Dict[str, np.array]): 110 | 111 | def fn(self, path): 112 | t = self.tap_sizes() 113 | 114 | r = { 115 | k: all_features['_'.join([*path, k])] 116 | for k in t.keys() 117 | } 118 | self.print(r, path, '_'.join(path), all_features) 119 | return {} 120 | 121 | self._do_recursive_map(fn) 122 | 123 | 124 | 125 | 126 | class Tensor(Component): 127 | def __init__(self, name=None): 128 | super().__init__(name=name) 129 | 130 | def bind(self, tensor:RawTensor): 131 | self.tensor = tensor 132 | 133 | def forward(self, features): 134 | return self.tensor 135 | 136 | 137 | 138 | 139 | class PrintTensor(Tensor): 140 | 141 | def __init__(self, width, name): 142 | super().__init__(name=name) 143 | self.width = width 144 | 145 | def taps(self): 146 | return { 147 | "tensor": self.tensor 148 | } 149 | 150 | def tap_sizes(self): 151 | return { 152 | "tensor": [self.width] 153 | } 154 | 155 | def print(self, taps, path, prefix, all): 156 | print(prefix, color_vector(taps["tensor"])) 157 | 158 | 159 | -------------------------------------------------------------------------------- /macgraph/const.py: -------------------------------------------------------------------------------- 1 | 2 | # Constant to avoid circular dependencies 3 | EPSILON = 1E-7 -------------------------------------------------------------------------------- /macgraph/estimator.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from .model import model_fn 5 | 6 | def 
get_estimator(args): 7 | 8 | run_config = tf.estimator.RunConfig( 9 | model_dir=args["model_dir"], 10 | tf_random_seed=args["random_seed"], 11 | save_checkpoints_steps=None, 12 | save_checkpoints_secs=args["eval_every"], 13 | ) 14 | 15 | return tf.estimator.Estimator( 16 | model_fn=model_fn, 17 | config=run_config, 18 | warm_start_from=args["warm_start_dir"], 19 | params=args) -------------------------------------------------------------------------------- /macgraph/evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | from .args import get_args 3 | from .estimator import get_estimator 4 | from .input import gen_input_fn 5 | 6 | if __name__ == "__main__": 7 | 8 | args = get_args() 9 | estimator = get_estimator(args) 10 | 11 | estimator.evaluate(input_fn=gen_input_fn(args, "eval")) -------------------------------------------------------------------------------- /macgraph/global_args.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | global_args = {} -------------------------------------------------------------------------------- /macgraph/hooks.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | class FloydHubMetricHook(tf.train.SessionRunHook): 6 | """An easy way to output your metric_ops to FloydHub's training metric graphs 7 | 8 | This is designed to fit into TensorFlow's EstimatorSpec. Assuming you've 9 | already defined some metric_ops for monitoring your training/evaluation, 10 | this helper class will compute those operations then print them out in 11 | the format that FloydHub is expecting. For example: 12 | 13 | ``` 14 | def model_fn(features, labels, mode, params): 15 | 16 | # Set up your model 17 | loss = ... 18 | my_predictions = ... 19 | 20 | eval_metric_ops = { 21 | "accuracy": tf.metrics.accuracy(labels=labels, predictions=my_predictions), 22 | "loss": tf.metrics.mean(loss) 23 | } 24 | 25 | return EstimatorSpec(mode, 26 | eval_metric_ops = eval_metric_ops, 27 | 28 | # **Here it is! The magic!! ** 29 | evaluation_hooks = [FloydHubMetricHook(eval_metric_ops)] 30 | 31 | ) 32 | ``` 33 | 34 | FloydHubMetricHook has one optional parameter, *prefix* for using it multiple times 35 | (e.g. prefix="train_" for training metrics, prefix="eval_" for evaluation metrics). 
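At the end of the session run the hook averages the readings it has collected and prints one JSON line per metric to stdout, which is the format FloydHub scrapes for its metric graphs. The values below are purely illustrative:

```
{"metric": "eval_accuracy", "value": 0.87}
{"metric": "eval_loss", "value": 0.42}
```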
36 | 37 | 38 | """ 39 | 40 | def __init__(self, metric_ops, prefix=""): 41 | self.metric_ops = metric_ops 42 | self.prefix = prefix 43 | self.readings = {} 44 | 45 | def before_run(self, run_context): 46 | return tf.train.SessionRunArgs(self.metric_ops) 47 | 48 | def after_run(self, run_context, run_values): 49 | if run_values.results is not None: 50 | for k,v in run_values.results.items(): 51 | try: 52 | self.readings[k].append(v[1]) 53 | except KeyError: 54 | self.readings[k] = [v[1]] 55 | 56 | def end(self, session): 57 | for k, v in self.readings.items(): 58 | a = np.average(v) 59 | print(f'{{"metric": "{self.prefix}{k}", "value": {a}}}') 60 | 61 | self.readings = {} 62 | -------------------------------------------------------------------------------- /macgraph/input/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .input import input_fn, gen_input_fn 3 | from .text_util import Vocab, bytes_to_string, pretokenize_json, UNK_ID 4 | from .kb import get_table_with_embedding -------------------------------------------------------------------------------- /macgraph/input/args.py: -------------------------------------------------------------------------------- 1 | 2 | from ..args import get_args as get_args_parent 3 | 4 | def get_args(extend=lambda x:None, argv=None): 5 | def inner_extend(parser): 6 | parser.add_argument('--skip-vocab', action='store_true') 7 | parser.add_argument('--only-build-vocab', action='store_true') 8 | 9 | parser.add_argument('--gqa-dir', type=str, default="./input_data/raw") 10 | parser.add_argument('--gqa-paths', type=str, nargs='+') 11 | parser.add_argument('--balance-batch', type=int, default=1000) 12 | parser.add_argument('--vocab-build-limit', type=int, default=2000, help="It's slow to read all records to build vocab, so just read this many. 
Could cause UNK to slip into dataset depending on distribution of tokens.") 13 | 14 | extend(parser) 15 | 16 | return get_args_parent(inner_extend, argv) -------------------------------------------------------------------------------- /macgraph/input/balancer.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import math 4 | from collections import Counter 5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | class Balancer(object): 10 | """ 11 | It's the caller's duty to close record_writer 12 | """ 13 | 14 | def __init__(self, record_writer, balance_freq, name="", parent=None): 15 | self.batch_i = 0 16 | self.record_writer = record_writer 17 | self.balance_freq = balance_freq 18 | self.name = name 19 | self.parent = parent 20 | self.running_total = None 21 | 22 | def oversampled_so_far(self): 23 | raise NotImplementedException() 24 | 25 | def oversample(self, n): 26 | '''Return over-sampling with n items total''' 27 | raise NotImplementedException() 28 | 29 | def write(self, doc, item): 30 | # Only do this for top level class 31 | if self.parent is None: 32 | self.batch_i += 1 33 | self.pipe_if_ready() 34 | 35 | def pipe(self): 36 | for i in self.oversample(self.batch_i): 37 | self.record_writer.write(*i) 38 | 39 | def pipe_if_ready(self): 40 | if self.batch_i > self.balance_freq: 41 | self.pipe() 42 | self.batch_i = 0 43 | 44 | def __enter__(self): 45 | return self 46 | 47 | def __exit__(self, *vargs): 48 | if self.parent is None: 49 | self.pipe() 50 | print(self.running_total) 51 | 52 | 53 | def resample_list(l, n): 54 | if n < 0: 55 | raise ArgumentError("Cannot sample list to negative size") 56 | elif n == 0: 57 | r = [] 58 | elif n == len(l): 59 | r = l 60 | elif n >= len(l): 61 | r = l + [random.choice(l) for i in range(n - len(l))] 62 | else: 63 | r = random.sample(l, n) 64 | 65 | assert len(r) == n 66 | return r 67 | 68 | 69 | class ListBalancer(Balancer): 70 | 71 | def __init__(self, record_writer, balance_freq, name="", parent=None): 72 | super().__init__(record_writer, balance_freq, name, parent) 73 | self.data = [] 74 | 75 | def write(self, doc, item): 76 | self.data.append((doc,item)) 77 | self.data = self.data[-self.balance_freq:] 78 | super().write(doc, item) 79 | 80 | def oversample(self, n): 81 | if len(self.data) == 0: 82 | raise ValueError("Cannot sample empty list") 83 | 84 | r = resample_list(self.data, n) 85 | return r 86 | 87 | 88 | class DictBalancer(Balancer): 89 | 90 | def __init__(self, key_pred, CtrClzz, record_writer, balance_freq, name="", parent=None): 91 | super().__init__(record_writer, balance_freq, name, parent) 92 | self.data = {} 93 | self.key_pred = key_pred 94 | self.CtrClzz = CtrClzz 95 | self.running_total = Counter() 96 | 97 | def write(self, doc, item): 98 | key = self.key_pred(doc) 99 | 100 | if key not in self.data: 101 | self.data[key] = self.CtrClzz(self.record_writer, self.balance_freq, key, self) 102 | 103 | if key not in self.running_total: 104 | self.running_total[key] = 0 105 | 106 | self.data[key].write(doc, item) 107 | super().write(doc, item) 108 | 109 | def oversampled_so_far(self): 110 | return sum(self.running_total.values()) 111 | 112 | def oversample(self, n): 113 | 114 | total_target = self.oversampled_so_far() + n 115 | 116 | if len(self.data) > 0: 117 | target_per_class = math.ceil(total_target / len(self.data)) 118 | else: 119 | logger.warning(f"Oversample called on {self.name} but no data added") 120 | return [] 121 | 122 | if n <= 0: 123 | return 
[] 124 | 125 | r = [] 126 | 127 | for k, v in self.running_total.items(): 128 | o = self.data[k].oversample(target_per_class - v) 129 | o = [(k, i) for i in o] 130 | r.extend(o) 131 | 132 | r = resample_list(r, n) 133 | 134 | classes, r = zip(*r) # aka unzip 135 | self.running_total.update(classes) 136 | 137 | assert len(r) == n, f"DictBalancer {self.name} tried to return {len(r)} not {n} items" 138 | assert sum(self.running_total.values()) == total_target 139 | return r 140 | 141 | 142 | class TwoLevelBalancer(DictBalancer): 143 | def __init__(self, key1, key2, record_writer, balance_freq, name="TwoLevelBalancer", parent=None): 144 | Inner = lambda record_writer, balance_freq, name, parent: DictBalancer(key2, ListBalancer, record_writer, balance_freq, name, parent) 145 | super().__init__(key1, Inner, record_writer, balance_freq, name, parent) 146 | 147 | -------------------------------------------------------------------------------- /macgraph/input/build.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import pathlib 4 | from collections import Counter 5 | import yaml 6 | from tqdm import tqdm 7 | import contextlib 8 | 9 | from .graph_util import * 10 | from .text_util import * 11 | from .util import * 12 | from .args import * 13 | from .balancer import TwoLevelBalancer 14 | from .partitioner import * 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | # -------------------------------------------------------------------------- 21 | # Helpers 22 | # -------------------------------------------------------------------------- 23 | 24 | 25 | def generate_record(args, vocab, doc): 26 | 27 | q = vocab.english_to_ids(doc["question"]["english"]) 28 | 29 | # May raise exception if unsupported type 30 | label = vocab.lookup(pretokenize_json(doc["answer"])) 31 | 32 | if label == UNK_ID: 33 | raise ValueError(f"We're only including questions that have in-vocab answers ({doc['answer']})") 34 | 35 | if label >= args["output_width"]: 36 | raise ValueError(f"Label {label} greater than answer classes {args['output_width']}") 37 | 38 | nodes, edges, adjacency = graph_to_table(args, vocab, doc["graph"]) 39 | 40 | logger.debug(f""" 41 | Answer={vocab.ids_to_string([label])} 42 | {vocab.ids_to_string(q)} 43 | {[vocab.ids_to_string(g) for g in nodes]} 44 | {[vocab.ids_to_string(g) for g in edges]}""") 45 | 46 | feature = { 47 | "src": write_int64_array_feature(q), 48 | "src_len": write_int64_feature(len(q)), 49 | "kb_edges": write_int64_array_feature(edges.flatten()), 50 | "kb_edges_len": write_int64_feature(edges.shape[0]), 51 | "kb_nodes": write_int64_array_feature(nodes.flatten()), 52 | "kb_nodes_len": write_int64_feature(nodes.shape[0]), 53 | "kb_adjacency": write_boolean_array_feature(adjacency.flatten()), 54 | "label": write_int64_feature(label), 55 | "type_string": write_string_feature(doc["question"]["type_string"]), 56 | } 57 | 58 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 59 | return example.SerializeToString() 60 | 61 | 62 | def build(args): 63 | try: 64 | pathlib.Path(args["input_dir"]).mkdir(parents=True, exist_ok=True) 65 | except FileExistsError: 66 | pass 67 | 68 | logger.info(f"Building {args['dataset']} data from {args['gqa_paths']}") 69 | 70 | if not args["skip_vocab"]: 71 | logger.info(f"Build vocab {args['vocab_path']} ") 72 | vocab = Vocab.build(args, lambda i:gqa_to_tokens(args, i), limit=min(args["limit"], args["vocab_build_limit"])) 73 | 
logger.info(f"Wrote {len(vocab)} vocab entries") 74 | logger.debug(f"vocab: {vocab.table}") 75 | print() 76 | 77 | if args["only_build_vocab"]: 78 | return 79 | else: 80 | vocab = Vocab.load_from_args(args) 81 | 82 | 83 | question_types = Counter() 84 | output_classes = Counter() 85 | 86 | logger.info(f"Generate TFRecords {args['input_dir']}") 87 | 88 | # To close everything nicely later 89 | with ExitStack() as stack: 90 | 91 | balancers = {} 92 | 93 | k_answer = lambda d: d["answer"] 94 | k_type_string = lambda d: d["question"]["type_string"] 95 | 96 | for mode in args["modes"]: 97 | writer = stack.enter_context(RecordWriter(args, mode)) 98 | balancers[mode] = stack.enter_context(TwoLevelBalancer(k_answer, k_type_string, writer, min_none(args["balance_batch"], args["limit"]))) 99 | 100 | with Partitioner(args, balancers) as p: 101 | for doc in tqdm(read_gqa(args), total=args["limit"]): 102 | try: 103 | record = generate_record(args, vocab, doc) 104 | question_types[doc["question"]["type_string"]] += 1 105 | output_classes[doc["answer"]] += 1 106 | p.write(doc, record) 107 | 108 | except ValueError as ex: 109 | logger.debug(ex) 110 | pass 111 | 112 | 113 | with tf.gfile.GFile(args["answer_classes_path"], "w") as file: 114 | yaml.dump(dict(p.answer_classes), file) 115 | 116 | with tf.gfile.GFile(args["answer_classes_types_path"], "w") as file: 117 | yaml.dump(dict(p.answer_classes_types), file) 118 | 119 | logger.info(f"Class distribution: {p.answer_classes}") 120 | 121 | logger.info(f"Wrote {p.written} TFRecords") 122 | 123 | 124 | with tf.gfile.GFile(args["question_types_path"], "w") as file: 125 | yaml.dump(dict(question_types), file) 126 | 127 | 128 | # -------------------------------------------------------------------------- 129 | # Run the script 130 | # -------------------------------------------------------------------------- 131 | 132 | if __name__ == "__main__": 133 | 134 | args = get_args() 135 | 136 | logging.basicConfig() 137 | logger.setLevel(args["log_level"]) 138 | logging.getLogger("mac-graph.input.util").setLevel(args["log_level"]) 139 | 140 | build(args) 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /macgraph/input/build_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import math 6 | import logging 7 | 8 | from .input import input_fn 9 | from .build import build 10 | from .args import get_args 11 | from .util import * 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class TestBuild(unittest.TestCase): 16 | 17 | @classmethod 18 | def setUpClass(cls): 19 | tf.enable_eager_execution() 20 | 21 | argv = [ 22 | '--gqa-paths', 'input_data/raw/test.yaml', 23 | '--input-dir', 'input_data/processed/test', 24 | '--limit', '100', 25 | '--predict-holdback', '0.1', 26 | '--eval-holdback', '0.1', 27 | ] 28 | 29 | args = get_args(argv=argv) 30 | cls.args = args 31 | 32 | build(args) 33 | 34 | def assert_adjacency_valid(self, args, features, batch_index): 35 | 36 | lhs = features["kb_adjacency"][batch_index].numpy() 37 | 38 | # Test adjacency is symmetric 39 | np.testing.assert_array_equal(lhs, np.transpose(lhs)) 40 | 41 | # Test not reflexive 42 | reflexive = np.identity(args["kb_node_max_len"], np.bool) 43 | any_reflexive = np.logical_and(reflexive, lhs) 44 | np.testing.assert_array_equal(any_reflexive, False) 45 | 46 | 47 | # Next, reconstruct adj from edges and check it is 
the same 48 | adj = np.full([args["kb_node_max_len"], args["kb_node_max_len"]], False) 49 | 50 | def get_node_idx(node_id): 51 | for idx, data in enumerate(features["kb_nodes"][batch_index]): 52 | if idx < features["kb_nodes_len"][batch_index].numpy(): 53 | if data[0].numpy() == node_id.numpy(): 54 | return idx 55 | 56 | raise ValueError(f"Node id {node_id} not found in node list {features['kb_nodes'][batch_index]}") 57 | 58 | for idx, edge in enumerate(features["kb_edges"][batch_index]): 59 | if idx < features["kb_edges_len"][batch_index].numpy(): 60 | node_from = edge[0] 61 | node_to = edge[2] 62 | node_from_idx = get_node_idx(node_from) 63 | node_to_idx = get_node_idx(node_to) 64 | adj[node_from_idx][node_to_idx] = True 65 | adj[node_to_idx][node_from_idx] = True 66 | 67 | assert adj.shape == features["kb_adjacency"][batch_index].numpy().shape 68 | 69 | 70 | # So that it prints useful errors 71 | np.set_printoptions(threshold=np.inf) 72 | # for i, j, k in zip(lhs, adj, features["kb_nodes"][batch_index]): 73 | # print("testing node", k.numpy()) 74 | # np.testing.assert_array_equal(i,j) 75 | np.testing.assert_array_equal(adj, lhs) 76 | 77 | 78 | def test_build_adjacency(self): 79 | 80 | dataset = input_fn(TestBuild.args, "train", repeat=False) 81 | 82 | for features, label in dataset: 83 | for i in range(len(list(label))): 84 | self.assert_adjacency_valid(TestBuild.args, features, i) 85 | 86 | 87 | def test_build_basics(self): 88 | 89 | # Validate questions are a unique set 90 | gqa_questions = set() 91 | for i in read_gqa(TestBuild.args): 92 | digest = (i["question"]["english"], len(i["graph"]["edges"]), len(i["graph"]["nodes"])) 93 | self.assertNotIn(digest, gqa_questions) 94 | gqa_questions.add(digest) 95 | 96 | questions = {} 97 | 98 | for mode in TestBuild.args["modes"]: 99 | dataset = input_fn(TestBuild.args, mode, repeat=False) 100 | questions[mode] = set() 101 | 102 | for features, label in dataset: 103 | for batch_index in range(len(list(label))): 104 | questions[mode].add(( 105 | str(features["src"][batch_index]), int(features["kb_edges_len"][batch_index]), int(features["kb_nodes_len"][batch_index]) 106 | )) 107 | 108 | 109 | for mode in TestBuild.args["modes"]: 110 | for mode_b in TestBuild.args["modes"]: 111 | if mode != mode_b: 112 | self.assertTrue(questions[mode].isdisjoint(questions[mode_b]), f"Same question in mode {mode} and {mode_b}") 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | if __name__ == '__main__': 121 | logging.basicConfig() 122 | logger.setLevel('INFO') 123 | 124 | unittest.main() -------------------------------------------------------------------------------- /macgraph/input/graph_util.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | from .text_util import * 5 | 6 | NODE_PROPS = ["name", "cleanliness", "music", "architecture", "size", "has_rail", "disabled_access"] 7 | EDGE_PROPS = ["line_name"] 8 | 9 | 10 | def gqa_to_tokens(args, gqa): 11 | 12 | tokens = list() 13 | 14 | for edge in gqa["graph"]["edges"]: 15 | for key in EDGE_PROPS: 16 | tokens.append(pretokenize_json(edge[key])) 17 | 18 | for node in gqa["graph"]["nodes"]: 19 | for key in NODE_PROPS: 20 | tokens.append(pretokenize_json(node[key])) 21 | 22 | tokens += pretokenize_english(gqa["question"]["english"]).split(' ') 23 | 24 | try: 25 | tokens.append(pretokenize_json(gqa["answer"])) 26 | except ValueError: 27 | pass 28 | 29 | return tokens 30 | 31 | 32 | def graph_to_table(args, vocab, graph): 33 | 34 | def node_to_vec(node, 
props=NODE_PROPS): 35 | return np.array([ 36 | vocab.lookup(pretokenize_json(node[key])) for key in props 37 | ]) 38 | 39 | def edge_to_vec(edge, props=EDGE_PROPS): 40 | return np.array([ 41 | vocab.lookup(pretokenize_json(edge[key])) for key in props 42 | ]) 43 | 44 | def pack(row, width): 45 | if len(row) > width: 46 | r = row[0:width] 47 | elif len(row) < width: 48 | r = np.pad(row, (0, width - len(row)), 'constant', constant_values=UNK_ID) 49 | else: 50 | r = row 51 | 52 | assert len(r) == width, "Extraction functions didn't create the right length of knowledge table data" 53 | return r 54 | 55 | edges = [] 56 | 57 | node_lookup = {i["id"]: i for i in graph["nodes"]} 58 | 59 | nodes = [pack(node_to_vec(i), args["kb_node_width"]) for i in graph["nodes"]] 60 | 61 | assert len(graph["nodes"]) <= args["kb_node_max_len"] 62 | 63 | for edge in graph["edges"]: 64 | s1 = node_to_vec(node_lookup[edge["station1"]], ['name']) 65 | s2 = node_to_vec(node_lookup[edge["station2"]], ['name']) 66 | e = edge_to_vec(edge) 67 | 68 | row = np.concatenate((s1, e, s2), -1) 69 | row = pack(row, args["kb_edge_width"]) 70 | 71 | edges.append(row) 72 | 73 | 74 | # I'm treating edges as bidirectional for the adjacency matrix 75 | # Also, I'm discarding line information. That is still in the edges list 76 | def is_connected(idx_from, idx_to): 77 | 78 | # To produce stable tensor sizes, the adj matrix is padded out to kb_nodes_max_len 79 | if idx_from >= len(graph["nodes"]) or idx_to >= len(graph["nodes"]): 80 | return False 81 | 82 | id_from = graph["nodes"][idx_from]["id"] 83 | id_to = graph["nodes"][idx_to ]["id"] 84 | 85 | for edge in graph["edges"]: 86 | if edge["station1"] == id_from and edge["station2"] == id_to: 87 | return True 88 | if edge["station1"] == id_to and edge["station2"] == id_from: 89 | return True 90 | 91 | return False 92 | 93 | 94 | adjacency = [ 95 | [ 96 | is_connected(i, j) for j in range(args["kb_node_max_len"]) 97 | ] 98 | for i in range(args["kb_node_max_len"]) 99 | ] 100 | 101 | 102 | return np.array(nodes), np.array(edges), np.array(adjacency) 103 | 104 | -------------------------------------------------------------------------------- /macgraph/input/input.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | from .text_util import EOS_ID, UNK_ID 9 | from .graph_util import * 10 | from .util import * 11 | 12 | def parse_single_example(i): 13 | return tf.parse_single_example( 14 | i, 15 | features = { 16 | 'src': parse_feature_int_array(), 17 | 'src_len': parse_feature_int(), 18 | 19 | 'kb_edges': parse_feature_int_array(), 20 | 'kb_edges_len': parse_feature_int(), 21 | 'kb_nodes': parse_feature_int_array(), 22 | 'kb_nodes_len': parse_feature_int(), 23 | 'kb_adjacency': parse_feature_boolean_array(), 24 | 25 | 'label': parse_feature_int(), 26 | 'type_string': parse_feature_string(), 27 | }) 28 | 29 | def reshape_example(args, i): 30 | 31 | def reshape_adj(tensor): 32 | return tf.reshape(tensor, [args["kb_node_max_len"], args["kb_node_max_len"]]) 33 | 34 | return ({ 35 | # Text input 36 | "src": i["src"], 37 | "src_len": i["src_len"], 38 | 39 | # Knowledge base 40 | "kb_nodes": tf.reshape(i["kb_nodes"], [-1, args["kb_node_width"]]), 41 | "kb_nodes_len": i["kb_nodes_len"], 42 | "kb_edges": tf.reshape(i["kb_edges"], [-1, args["kb_edge_width"]]), 43 | "kb_edges_len": i["kb_edges_len"], 44 | "kb_adjacency": 
reshape_adj(i["kb_adjacency"]), 45 | 46 | # Prediction stats 47 | "label": i["label"], 48 | "type_string": i["type_string"], 49 | 50 | }, i["label"]) 51 | 52 | def switch_to_from(db): 53 | return tf.stack([db[:,2], db[:,1], db[:,0]], -1) 54 | 55 | def make_edges_bidirectional(features, labels): 56 | features["kb_edges"] = tf.concat([features["kb_edges"], switch_to_from(features["kb_edges"])], 0) 57 | features["kb_edges_len"] *= 2 58 | return features, labels 59 | 60 | 61 | 62 | def cast_adjacency_to_bool(features, labels): 63 | features["kb_adjacency"] = tf.cast(features["kb_adjacency"], tf.bool) 64 | return features, labels 65 | 66 | def input_fn(args, mode, question=None, repeat=True): 67 | 68 | # -------------------------------------------------------------------------- 69 | # Read TFRecords 70 | # -------------------------------------------------------------------------- 71 | 72 | d = tf.data.TFRecordDataset([args[f"{mode}_input_path"]]) 73 | d = d.map(parse_single_example) 74 | 75 | # -------------------------------------------------------------------------- 76 | # Layout input data 77 | # -------------------------------------------------------------------------- 78 | 79 | d = d.map(lambda i: reshape_example(args,i)) 80 | d = d.map(make_edges_bidirectional) 81 | 82 | 83 | if args["limit"] is not None: 84 | d = d.take(args["limit"]) 85 | 86 | if args["filter_type_prefix"] is not None: 87 | d = d.filter(lambda features, labels: 88 | tf_startswith(features["type_string"], args["filter_type_prefix"])) 89 | 90 | if args["filter_output_class"] is not None: 91 | classes_as_ints = [args["vocab"].lookup(i) for i in args["filter_output_class"]] 92 | d = d.filter(lambda features, labels: 93 | tf.reduce_any(tf.equal(features["label"], classes_as_ints)) 94 | ) 95 | 96 | d = d.shuffle(args["batch_size"]*1000) 97 | 98 | zero_64 = tf.cast(0, tf.int64) 99 | unk_64 = tf.cast(UNK_ID, tf.int64) 100 | 101 | kb_adjacency_shape = tf.TensorShape([args["kb_node_max_len"], args["kb_node_max_len"]]) 102 | 103 | 104 | d = d.padded_batch( 105 | args["batch_size"], 106 | # The first three entries are the source and target line rows; 107 | # these have unknown-length vectors. The last two entries are 108 | # the source and target row sizes; these are scalars. 109 | padded_shapes=( 110 | { 111 | "src": tf.TensorShape([None]), 112 | "src_len": tf.TensorShape([]), 113 | 114 | "kb_nodes": tf.TensorShape([None, args["kb_node_width"]]), 115 | "kb_nodes_len": tf.TensorShape([]), 116 | "kb_edges": tf.TensorShape([None, args["kb_edge_width"]]), 117 | "kb_edges_len": tf.TensorShape([]), 118 | "kb_adjacency": kb_adjacency_shape, 119 | 120 | "label": tf.TensorShape([]), 121 | "type_string": tf.TensorShape([None]), 122 | }, 123 | tf.TensorShape([]), # label 124 | ), 125 | 126 | # Pad the source and target sequences with eos tokens. 127 | # (Though notice we don't generally need to do this since 128 | # later on we will be masking out calculations past the true sequence. 129 | padding_values=( 130 | { 131 | "src": tf.cast(EOS_ID, tf.int64), 132 | "src_len": zero_64, # unused 133 | 134 | "kb_nodes": unk_64, 135 | "kb_nodes_len": zero_64, # unused 136 | "kb_edges": unk_64, 137 | "kb_edges_len": zero_64, # unused 138 | "kb_adjacency": zero_64, 139 | 140 | "label": zero_64, 141 | "type_string": tf.cast("", tf.string), 142 | }, 143 | zero_64 # label (unused) 144 | ), 145 | drop_remainder=(mode == "train") 146 | ) 147 | 148 | d = d.map(cast_adjacency_to_bool) 149 | 150 | # Add dynamic dimensions for convenience (e.g. 
to do shape assertions) 151 | d = d.map(lambda features, labels: ({ 152 | **features, 153 | "d_batch_size": tf.shape(features["src"])[0], 154 | "d_src_len": tf.shape(features["src"])[1], 155 | }, labels)) 156 | 157 | if repeat: 158 | d = d.repeat() 159 | 160 | return d 161 | 162 | 163 | 164 | def gen_input_fn(args, mode): 165 | return lambda: input_fn(args, mode, repeat=(mode == "train")) 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /macgraph/input/kb.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from ..util import dynamic_assert_shape 5 | from .text_util import UNK_ID 6 | 7 | 8 | def get_table_with_embedding(args, features, vocab_embedding, noun): 9 | 10 | # -------------------------------------------------------------------------- 11 | # Constants and validations 12 | # -------------------------------------------------------------------------- 13 | 14 | # TODO: remove these pesky ses 15 | table = features[f"{noun}s"] 16 | table_len = features[f"{noun}s_len"] 17 | width = args[f"{noun}_width"] 18 | full_width = width * args["embed_width"] 19 | 20 | d_len = tf.shape(table)[1] 21 | assert table.shape[-1] == width, f"Table shape {table.shape} did not have expected inner width dimensions of {width}" 22 | 23 | # -------------------------------------------------------------------------- 24 | # Embed graph tokens 25 | # -------------------------------------------------------------------------- 26 | 27 | table = dynamic_assert_shape(table, [features["d_batch_size"], d_len, width]) 28 | 29 | emb_kb = tf.nn.embedding_lookup(vocab_embedding, table) 30 | emb_kb *= tf.sqrt(tf.cast(args["embed_width"], emb_kb.dtype)) # As per Transformer model 31 | 32 | emb_kb = dynamic_assert_shape(emb_kb, 33 | [features["d_batch_size"], d_len, width, args["embed_width"]]) 34 | 35 | emb_kb = tf.reshape(emb_kb, [-1, d_len, full_width]) 36 | emb_kb = dynamic_assert_shape(emb_kb, 37 | [features["d_batch_size"], d_len, full_width]) 38 | 39 | return emb_kb, full_width, table_len 40 | 41 | -------------------------------------------------------------------------------- /macgraph/input/partitioner.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | from collections import Counter 4 | import random 5 | 6 | class RecordWriter(object): 7 | """ 8 | Very basic wrapper that just serves to marshall the args for us and 9 | plug into context manager 10 | """ 11 | 12 | def __init__(self, args, mode): 13 | self.args = args 14 | self.mode = mode 15 | 16 | def __enter__(self, *vargs): 17 | self.file = tf.python_io.TFRecordWriter(self.args[f"{self.mode}_input_path"]) 18 | return self 19 | 20 | def write(self, doc, record): 21 | self.file.write(record) 22 | 23 | def __exit__(self, *vargs): 24 | self.file.close() 25 | self.file = None 26 | 27 | 28 | class Partitioner(object): 29 | """ 30 | Write to this and it'll randomly write to the writer_dict 31 | writers. 32 | 33 | It is the callers responsibility to close the writer_dicts after 34 | this has been disposed of. 
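Each record is routed independently by a single uniform random draw: it goes to the "eval" writer when the draw is below eval_holdback, to "predict" when it is below eval_holdback + predict_holdback, and to "train" otherwise, so the two holdback arguments are approximate fractions of the overall stream.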
35 | 36 | Arguments: 37 | writer_dict: A dictionary of mode strings to something that accepts a write(doc, record) call 38 | """ 39 | 40 | def __init__(self, args, writer_dict): 41 | self.args = args 42 | self.written = 0 43 | self.answer_classes = Counter() 44 | self.answer_classes_types = Counter() 45 | self.writer_dict = writer_dict 46 | 47 | def __enter__(self, *vargs): 48 | """ 49 | Just here in case we've future buffered state to tidy up 50 | """ 51 | return self 52 | 53 | 54 | def write(self, doc, record): 55 | r = random.random() 56 | 57 | if r < self.args["eval_holdback"]: 58 | mode = "eval" 59 | elif r < self.args["eval_holdback"] + self.args["predict_holdback"]: 60 | mode = "predict" 61 | else: 62 | mode = "train" 63 | 64 | key = (str(doc["answer"]), doc["question"]["type_string"]) 65 | 66 | self.writer_dict[mode].write(doc, record) 67 | self.answer_classes[str(doc["answer"])] += 1 68 | self.answer_classes_types[key] += 1 69 | self.written += 1 70 | 71 | 72 | def __exit__(self, *vargs): 73 | pass 74 | 75 | -------------------------------------------------------------------------------- /macgraph/input/print_gqa.py: -------------------------------------------------------------------------------- 1 | 2 | import tableprint 3 | from collections import Counter 4 | from tqdm import tqdm 5 | 6 | from .args import * 7 | from .util import * 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | def extend(parser): 13 | parser.add_argument("--print-rows",action='store_true') 14 | 15 | args = get_args(extend) 16 | 17 | output_classes = Counter() 18 | question_types = Counter() 19 | questions = Counter() 20 | 21 | try: 22 | with tableprint.TableContext(headers=["Type", "Question", "Answer"], width=[40,50,15]) as t: 23 | for i in read_gqa(args): 24 | output_classes[i["answer"]] += 1 25 | question_types[i["question"]["type_string"]] += 1 26 | questions[i["question"]["english"]] += 1 27 | 28 | if args["print_rows"]: 29 | t([ 30 | i["question"]["type_string"], 31 | i["question"]["english"], 32 | i["answer"] 33 | ]) 34 | 35 | except KeyboardInterrupt: 36 | print() 37 | pass 38 | # we want to print the final results! 
39 | 40 | def second(v): 41 | return v[1] 42 | 43 | tableprint.table(headers=["Answer", "Count"], width=[20,5], data=sorted(output_classes.items(), key=second)) 44 | tableprint.table(headers=["Question", "Count"], width=[50,5], data=sorted(questions.items(), key=second)) 45 | tableprint.table(headers=["Question type", "Count"], width=[20,5], data=sorted(question_types.items(), key=second)) 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /macgraph/input/print_tfr.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | from collections import Counter 6 | from tqdm import tqdm 7 | import tableprint as tp 8 | import re 9 | 10 | from .args import * 11 | from .util import * 12 | from .text_util import * 13 | from .input import * 14 | 15 | def eager_to_str(v): 16 | return bytes_to_string(np.array(v)) 17 | 18 | def extend_args(parser): 19 | parser.add_argument('--print-records', action='store_true') 20 | 21 | if __name__ == "__main__": 22 | 23 | args = get_args(extend_args) 24 | 25 | vocab = Vocab.load_from_args(args) 26 | count = 0 27 | tf.enable_eager_execution() 28 | 29 | dist = Counter() 30 | types = set() 31 | labels = set() 32 | 33 | 34 | for i in tqdm(tf.python_io.tf_record_iterator(args["train_input_path"]), total=args["limit"]): 35 | 36 | if args["limit"] is not None and count > args["limit"]: 37 | break 38 | 39 | # Parse 40 | r = parse_single_example(i) 41 | r,label = reshape_example(args, r) 42 | r["type_string"] = eager_to_str(r["type_string"]) 43 | r["src"] = vocab.ids_to_english(np.array(r["src"])) 44 | r["label"] = vocab.inverse_lookup(int(r["label"])) 45 | r["kb_nodes"] = [vocab.ids_to_english(np.array(i)) for i in r["kb_nodes"] if np.array(i).size > 0] 46 | 47 | count += 1 48 | 49 | # Skip non matching prefixes 50 | if args["filter_type_prefix"] is not None: 51 | if not r["type_string"].startswith(args["filter_type_prefix"]): 52 | continue 53 | 54 | types.add(r["type_string"]) 55 | labels.add(r["label"]) 56 | 57 | dist[(r["label"], r["type_string"])] += 1 58 | 59 | if args["print_records"]: 60 | print(r["src"] + " = " + r["label"]) 61 | for j in r["kb_nodes"]: 62 | print("NODE: " + j) 63 | print() 64 | 65 | 66 | 67 | print(f"\nTotal records processed: {count}") 68 | 69 | def shorten(i): 70 | return re.sub('[^A-Z]', '', i) 71 | # return i.replace("Station", "S").replace("Property", "P").replace("Adjacent", "A") 72 | 73 | headers = ["Label"] + [shorten(i) for i in list(types)] + ["Total"] 74 | data = [ [label] + [dist[(label, tpe)] for tpe in types] + [sum([dist[(label, tpe)] for tpe in types])] for label in labels] 75 | data.append(["Total"] + [sum([dist[(label, tpe)] for label in labels]) for tpe in types] + [sum(dist.values())]) 76 | width = [20] + [7 for i in types] + [7] 77 | tp.table(data, headers, width=width) 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /macgraph/input/text_util.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import Counter 3 | import tensorflow as tf 4 | import numpy as np 5 | from typing import List, Set 6 | import re 7 | import string 8 | from tqdm import tqdm 9 | 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | from .util import read_gqa 14 | 15 | 16 | 17 | # -------------------------------------------------------------------------- 18 | # Constants 19 | # 
-------------------------------------------------------------------------- 20 | 21 | 22 | UNK = "<unk>" 23 | SOS = "<sos>" 24 | EOS = "<eos>" 25 | SPACE = "<space>" 26 | 27 | CHARS = ["<"+i+">" for i in string.ascii_lowercase] + ["<"+i+">" for i in string.ascii_uppercase] 28 | SPECIAL_TOKENS = [UNK, SOS, EOS, SPACE] #+ CHARS 29 | 30 | UNK_ID = SPECIAL_TOKENS.index(UNK) 31 | SOS_ID = SPECIAL_TOKENS.index(SOS) 32 | EOS_ID = SPECIAL_TOKENS.index(EOS) 33 | 34 | 35 | 36 | # -------------------------------------------------------------------------- 37 | # Pretokenize 38 | # -------------------------------------------------------------------------- 39 | 40 | 41 | ENGLISH_PUNCTUATION = '!"#$%&()*+,-./:;=?@[\\]^_`{|}~' 42 | 43 | # -------------------------------------------------------------------------- 44 | 45 | 46 | def pretokenize_general(text): 47 | text = text.replace("\n", "") 48 | text = re.sub(r'\s*$', '', text) 49 | text = text.replace(" ", f" {SPACE} ") 50 | return text 51 | 52 | def detokenize_general(text): 53 | text = text.replace(f" {SPACE} ", " ") 54 | return text 55 | 56 | 57 | def pretokenize_json(value): 58 | if isinstance(value, str) or isinstance(value, bool) or isinstance(value, int): 59 | return str(value) 60 | raise ValueError("Unsupported json value type") 61 | 62 | 63 | def pretokenize_english(text): 64 | text = pretokenize_general(text) 65 | 66 | for p in ENGLISH_PUNCTUATION: 67 | text = text.replace(p, f" {p} ") 68 | 69 | text = re.sub(r'\s*$', '', text) 70 | return text 71 | 72 | 73 | def detokenize_english(text): 74 | text = detokenize_general(text) 75 | 76 | for p in ENGLISH_PUNCTUATION: 77 | text = text.replace(f" {p} ", p) 78 | 79 | return text 80 | 81 | 82 | def bytes_to_string(p): 83 | if len(p) == 0: 84 | return "" 85 | 86 | decode_utf8 = np.vectorize(lambda v: v.decode("utf-8")) 87 | p = decode_utf8(p) 88 | s = ''.join(p) 89 | return s 90 | 91 | 92 | # -------------------------------------------------------------------------- 93 | # Vocab 94 | # -------------------------------------------------------------------------- 95 | 96 | 97 | class Vocab(object): 98 | 99 | def __init__(self, table:List[str]): 100 | self.table = table 101 | 102 | def __contains__(self, value): 103 | return value in self.table 104 | 105 | def __iter__(self): 106 | return iter(self.table) 107 | 108 | def __len__(self): 109 | return len(self.table) 110 | 111 | # -------------------------------------------------------------------------- # 112 | 113 | def lookup(self, value): 114 | try: 115 | return self.table.index(value) 116 | except ValueError: 117 | return UNK_ID 118 | 119 | def inverse_lookup(self, value): 120 | try: 121 | return self.table[value] 122 | except IndexError: 123 | return UNK 124 | 125 | def ids_to_string(self, line, output_as_array=False): 126 | d = [self.inverse_lookup(i) for i in line] 127 | if output_as_array: 128 | return d 129 | else: 130 | return ' '.join(d) 131 | 132 | def string_to_ids(self, line): 133 | return [self.lookup(i) for i in line.split(' ')] 134 | 135 | def expand_unknowns(self, line): 136 | unknowns = set(line.split(' ')) 137 | unknowns -= set(self.table) 138 | unknowns -= set(['']) 139 | 140 | for t in unknowns: 141 | spaced = ''.join([f"<{c}> " for c in t]) 142 | line = line.replace(t, spaced) 143 | 144 | return line 145 | 146 | 147 | def english_to_ids(self, line): 148 | # TODO: Make greedy w.r.t.
tokens with spaces in them 149 | line = pretokenize_english(line) 150 | line = self.expand_unknowns(line) 151 | line = self.string_to_ids(line) 152 | return line 153 | 154 | def ids_to_english(self, line): 155 | line = self.ids_to_string(line) 156 | line = detokenize_english(line) 157 | return line 158 | 159 | 160 | def prediction_value_to_string(self, v, output_as_array=False): 161 | """Rough 'n' ready get me the hell outta here fn. 162 | Tries its best to deal with the mess of datatypes that end up coming out""" 163 | 164 | if isinstance(v, np.int64): 165 | s = self.inverse_lookup(v) 166 | elif isinstance(v, np.ndarray): 167 | if v.dtype == np.int64: 168 | s = self.ids_to_string(v, output_as_array) 169 | elif v.dtype == object: 170 | s = bytes_to_string(v) 171 | else: 172 | raise ValueError() 173 | else: 174 | raise ValueError() 175 | 176 | return s 177 | 178 | 179 | 180 | def save(self, args): 181 | with tf.gfile.GFile(args["vocab_path"], 'w') as out_file: 182 | for i in self.table: 183 | out_file.write(i + "\n") 184 | 185 | 186 | 187 | # -------------------------------------------------------------------------- 188 | # Make me a vocab! 189 | # -------------------------------------------------------------------------- 190 | 191 | 192 | @classmethod 193 | def load(cls, path, size): 194 | 195 | tokens = list() 196 | 197 | with tf.gfile.GFile(path) as file: 198 | for line in file.readlines(): 199 | tokens.append(line.replace("\n", "")) 200 | 201 | if len(tokens) == size: 202 | break 203 | 204 | assert len(tokens) == len(set(tokens)), f"Duplicate lines in {path}" 205 | 206 | return Vocab(tokens) 207 | 208 | 209 | @classmethod 210 | def load_from_args(cls, args): 211 | return Vocab.load(args["vocab_path"], args["vocab_size"]) 212 | 213 | 214 | 215 | @classmethod 216 | def build(cls, args, gqa_to_tokens, limit=None): 217 | hits = Counter() 218 | 219 | def add(tokens:List[str]): 220 | for token in tokens: 221 | if token not in ["", " ", "\n"]: 222 | hits[token] += 1 223 | 224 | for i in tqdm(read_gqa(args, limit=limit), total=limit): 225 | add(gqa_to_tokens(i)) 226 | 227 | tokens = list() 228 | tokens.extend(SPECIAL_TOKENS) 229 | 230 | for i, c in hits.most_common(args["vocab_size"]): 231 | if len(tokens) == args["vocab_size"]: 232 | break 233 | 234 | if i not in tokens: 235 | tokens.append(i) 236 | 237 | assert len(tokens) <= args["vocab_size"] 238 | 239 | v = Vocab(tokens) 240 | v.save(args) 241 | 242 | return v 243 | 244 | 245 | 246 | 247 | -------------------------------------------------------------------------------- /macgraph/input/util.py: -------------------------------------------------------------------------------- 1 | 2 | import yaml 3 | import tensorflow as tf 4 | import random 5 | from tqdm import tqdm 6 | from collections import Counter 7 | from contextlib import ExitStack 8 | 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | # -------------------------------------------------------------------------- 13 | # Miscel 14 | # -------------------------------------------------------------------------- 15 | 16 | def min_none(a, b): 17 | if a is None: 18 | return b 19 | if b is None: 20 | return a 21 | return min(a,b) 22 | 23 | 24 | # -------------------------------------------------------------------------- 25 | # TFRecord functions 26 | # -------------------------------------------------------------------------- 27 | 28 | # Why it's so awkward to write a record I do not know 29 | 30 | def write_int64_feature(value): 31 | return 
tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 32 | 33 | def write_int64_array_feature(value): 34 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)), 35 | 36 | def write_boolean_array_feature(value): 37 | return write_int64_array_feature(value) 38 | 39 | def write_string_feature(value): 40 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(value)])) 41 | 42 | 43 | # TODO: Better naming / structure 44 | 45 | def parse_feature_int_array(): 46 | return tf.FixedLenSequenceFeature([],tf.int64, allow_missing=True) 47 | 48 | def parse_feature_boolean_array(): 49 | return parse_feature_int_array() 50 | 51 | def parse_feature_string(): 52 | return tf.FixedLenSequenceFeature([],tf.string, allow_missing=True) 53 | 54 | def parse_feature_int(): 55 | return tf.FixedLenFeature([], tf.int64) 56 | 57 | 58 | # -------------------------------------------------------------------------- 59 | # TF helpers 60 | # -------------------------------------------------------------------------- 61 | 62 | def tf_startswith(tensor, prefix, axis=None): 63 | return tf.reduce_all(tf.equal(tf.substr(tensor, 0, len(prefix)), prefix), axis=axis) 64 | 65 | 66 | 67 | # -------------------------------------------------------------------------- 68 | # File readers and writers 69 | # -------------------------------------------------------------------------- 70 | 71 | def read_gqa(args, limit=None): 72 | 73 | if limit is None: 74 | limit = args["limit"] 75 | 76 | with ExitStack() as stack: 77 | files = [stack.enter_context(open(fname)) for fname in args["gqa_paths"]] 78 | 79 | in_files = [ 80 | stack.enter_context(tf.gfile.GFile(i, 'r')) 81 | for i in args["gqa_paths"] 82 | ] 83 | 84 | yamls = [ 85 | yaml.safe_load_all(i) 86 | for i in in_files 87 | ] 88 | 89 | ctr = 0 90 | 91 | for row in zip(*yamls): 92 | for i in row: 93 | if i is not None: 94 | if args["filter_type_prefix"] is None or i["question"]["type_string"].startswith(args["filter_type_prefix"]): 95 | yield i 96 | ctr += 1 97 | if limit is not None and ctr >= limit: 98 | logger.debug("Hit limit, stop") 99 | return 100 | else: 101 | logger.debug(f"{i['question']['type_string']} does not match prefix {args['filter_type_prefix']}") 102 | else: 103 | logger.debug("Skipping None yaml doc") 104 | 105 | 106 | 107 | 108 | # -------------------------------------------------------------------------- 109 | # Dataset helpers 110 | # -------------------------------------------------------------------------- 111 | 112 | def StringDataset(s): 113 | 114 | def generator(): 115 | yield s 116 | 117 | return tf.data.Dataset.from_generator(generator, tf.string, tf.TensorShape([]) ) 118 | 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /macgraph/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from .activations import * 5 | 6 | def layer_selu(tensor, width, dropout=0.0, name=None): 7 | 8 | if name is None: 9 | name_dense = None 10 | name_drop = None 11 | else: 12 | name_dense = name + "_dense" 13 | name_drop = name + "_drop" 14 | 15 | r = tf.layers.dense(tensor, width, 16 | activation=tf.nn.selu, 17 | kernel_initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0), 18 | name=name_dense) 19 | 20 | if dropout > 0.0: 21 | r = tf.contrib.nn.alpha_dropout(r, dropout, name=name_drop) 22 | 23 | return r 24 | 25 | def layer_dense(tensor, width, activation_str="linear", 
dropout=0.0, name=None): 26 | 27 | if activation_str == "selu": 28 | return layer_selu(tensor, width, dropout, name) 29 | else: 30 | v = tf.layers.dense(tensor, width, activation=ACTIVATION_FNS[activation_str], name=name) 31 | 32 | if dropout > 0: 33 | v = tf.nn.dropout(v, 1.0-dropout) 34 | 35 | return v 36 | 37 | 38 | def deeep(tensor, width, depth=2, residual_depth=3, activation=tf.nn.tanh): 39 | """ 40 | Quick 'n' dirty "let's slap on some layers" function. 41 | 42 | Implements residual connections and applys them when it can. Uses this schematic: 43 | https://blog.waya.ai/deep-residual-learning-9610bb62c355 44 | """ 45 | with tf.name_scope("deeep"): 46 | 47 | if residual_depth is not None: 48 | for i in range(math.floor(depth/residual_depth)): 49 | tensor_in = tensor 50 | 51 | for j in range(residual_depth-1): 52 | tensor = tf.layers.dense(tensor, width, activation=activation) 53 | 54 | tensor = tf.layers.dense(tensor, width) 55 | 56 | if tensor_in.shape[-1] == width: 57 | tensor += tensor_in 58 | 59 | tensor = activation(tensor) 60 | 61 | remaining = depth % residual_depth 62 | else: 63 | remaining = depth 64 | 65 | for i in range(remaining): 66 | tensor = tf.layers.dense(tensor, width, activation=activation) 67 | 68 | return tensor -------------------------------------------------------------------------------- /macgraph/minception.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from .util import * 5 | from .const import EPSILON 6 | 7 | ''' 8 | The mini-inception (mi) library 9 | 10 | This is inspired by Google's inception network 11 | and DARTS architecture search. I didn't get fancy 12 | on the bilevel optimization, so let's see how it goes!! 13 | 14 | ''' 15 | 16 | MI_ACTIVATIONS = 5 17 | 18 | 19 | def mi_activation(tensor, tap=False): 20 | with tf.name_scope("mi_activation"): 21 | activations = [ 22 | tf.nn.relu, 23 | lambda x: tf.nn.relu(x) + tf.nn.relu(-x), # Abslu... yeah.. 
24 | tf.tanh, 25 | tf.nn.sigmoid, 26 | tf.identity, 27 | ] 28 | 29 | axis = 1 30 | 31 | choice = tf.get_variable("darts_choice", [len(activations)]) 32 | choice = tf.nn.softmax(choice) 33 | choice = tf.check_numerics(choice, "activation_choice") 34 | 35 | t = [activations[i](tensor) * choice[i] 36 | for i in range(len(activations))] 37 | 38 | t = sum(t) 39 | t = dynamic_assert_shape(t, tf.shape(tensor)) 40 | 41 | if tap: 42 | return t, choice 43 | else: 44 | return t 45 | 46 | 47 | def mi_activation_control(tensor, control=None, tap=False): 48 | with tf.name_scope("mi_activation"): 49 | activations = [ 50 | tf.nn.relu, 51 | lambda x: tf.nn.relu(-x), # Combining this with previous gives PRelu 52 | tf.tanh, 53 | tf.nn.sigmoid, 54 | tf.identity, 55 | ] 56 | 57 | axis = 1 58 | 59 | if control is None: 60 | v = tf.get_variable("mi_choice", [1, len(activations)]) 61 | choice = tf.tile(v, [tf.shape(tensor)[0], 1]) 62 | else: 63 | choice = tf.layers.dense(control, len(activations)) 64 | 65 | choice = tf.nn.softmax(choice, axis=axis) 66 | choice = tf.check_numerics(choice, "activation_choice") 67 | 68 | t = [activations[i](tensor) * tf.expand_dims(choice[:,i], axis) 69 | for i in range(len(activations))] 70 | 71 | t = sum(t) 72 | t = dynamic_assert_shape(t, tf.shape(tensor)) 73 | 74 | if tap: 75 | return t, choice 76 | else: 77 | return t 78 | 79 | 80 | def mi_residual(tensor, width): 81 | with tf.name_scope("mi_residual"): 82 | 83 | choice = tf.get_variable("choice", [2]) 84 | choice = tf.nn.sigmoid(choice) 85 | 86 | tensor = tf.layers.dense(tensor, width) 87 | 88 | left = choice[0] * tf.layers.dense( 89 | mi_activation( 90 | tf.layers.dense(tensor, width) 91 | ) 92 | , width) 93 | 94 | right = choice[1] * tensor 95 | 96 | join = left + right 97 | out = mi_activation(join) 98 | 99 | return join 100 | 101 | 102 | 103 | def mi_deep(tensor, width, depth): 104 | with tf.name_scope("mi_deep"): 105 | 106 | t = tensor 107 | 108 | for i in range(depth // 2): 109 | t = mi_residual(t, width) 110 | 111 | for i in range(depth % 2): 112 | t = tf.layers.dense(t, width) 113 | t = mi_activation(t) 114 | 115 | return t 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /macgraph/model.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | import yaml 5 | 6 | from .cell import execute_reasoning 7 | from .cell import * 8 | from .util import * 9 | from .hooks import * 10 | from .input import * 11 | from .optimizer import * 12 | 13 | def model_fn(features, labels, mode, params): 14 | 15 | # -------------------------------------------------------------------------- 16 | # Setup input 17 | # -------------------------------------------------------------------------- 18 | 19 | args = params 20 | 21 | # EstimatorSpec slots 22 | loss = None 23 | train_op = None 24 | eval_metric_ops = None 25 | predictions = None 26 | eval_hooks = None 27 | 28 | vocab = Vocab.load_from_args(args) 29 | 30 | # -------------------------------------------------------------------------- 31 | # Vocabulary embedding 32 | # -------------------------------------------------------------------------- 33 | 34 | vocab_shape = [args["vocab_size"], args["embed_width"]] 35 | 36 | if args["use_embed_const_eye"]: 37 | vocab_embedding = tf.constant( 38 | np.eye(*vocab_shape)*2.0 - 1.0, 39 | dtype=tf.float32, 40 | shape=vocab_shape) 41 | 42 | else: 43 | vocab_embedding = tf.get_variable( 44 | "vocab_embedding", 45 
| shape=vocab_shape, 46 | dtype=tf.float32) 47 | 48 | # -------------------------------------------------------------------------- 49 | # Model for realz 50 | # -------------------------------------------------------------------------- 51 | 52 | logits, taps = execute_reasoning(args, 53 | features=features, 54 | labels=labels, 55 | vocab_embedding=vocab_embedding) 56 | 57 | # -------------------------------------------------------------------------- 58 | # Calc loss 59 | # -------------------------------------------------------------------------- 60 | 61 | if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: 62 | crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits) 63 | loss_logit = tf.reduce_sum(crossent) / tf.to_float(features["d_batch_size"]) 64 | loss = loss_logit 65 | # -------------------------------------------------------------------------- 66 | # Optimize 67 | # -------------------------------------------------------------------------- 68 | 69 | global_step = tf.train.get_global_step() 70 | 71 | if mode == tf.estimator.ModeKeys.TRAIN: 72 | 73 | learning_rate = args["learning_rate"] 74 | 75 | if args["use_lr_finder"]: 76 | learning_rate = tf.train.exponential_decay( 77 | args["learning_rate"]/1e3, 78 | global_step, 79 | decay_steps=100, 80 | decay_rate=1.5) 81 | 82 | 83 | var = tf.trainable_variables() 84 | gradients = tf.gradients(loss, var) 85 | norms = [tf.norm(i, 2) for i in gradients if i is not None] 86 | 87 | for i in gradients: 88 | if i is not None: 89 | tf.summary.histogram(f"{i.name}", i, family="gradient_norm") 90 | 91 | tf.summary.scalar("learning_rate", learning_rate, family="hyperparam") 92 | tf.summary.scalar("current_step", global_step, family="hyperparam") 93 | tf.summary.scalar("grad_norm", tf.reduce_max(norms), family="hyperparam") 94 | 95 | optimizer = tf.train.AdamOptimizer(learning_rate) 96 | 97 | if args["use_gradient_clipping"]: 98 | train_op, gradients = minimize_clipped(optimizer, loss, args["max_gradient_norm"]) 99 | else: 100 | train_op = optimizer.minimize(loss, global_step=global_step) 101 | 102 | 103 | # -------------------------------------------------------------------------- 104 | # Predictions 105 | # -------------------------------------------------------------------------- 106 | 107 | if mode in [tf.estimator.ModeKeys.PREDICT, tf.estimator.ModeKeys.EVAL]: 108 | 109 | predicted_labels = tf.argmax(tf.nn.softmax(logits), axis=-1) 110 | 111 | predictions = { 112 | "predicted_label": predicted_labels, 113 | "actual_label": features["label"], 114 | } 115 | 116 | # For diagnostic visualisation 117 | predictions.update(features) 118 | predictions.update(taps) 119 | 120 | # Fake features do not have batch, must be removed 121 | del predictions["d_batch_size"] 122 | del predictions["d_src_len"] 123 | 124 | # -------------------------------------------------------------------------- 125 | # Eval metrics 126 | # -------------------------------------------------------------------------- 127 | 128 | if mode == tf.estimator.ModeKeys.EVAL: 129 | 130 | eval_metric_ops = { 131 | "accuracy": tf.metrics.accuracy(labels=labels, predictions=predicted_labels), 132 | "current_step": tf.metrics.mean(global_step), 133 | } 134 | 135 | try: 136 | with tf.gfile.GFile(args["question_types_path"]) as file: 137 | doc = yaml.load(file) 138 | for type_string in doc.keys(): 139 | if args["filter_type_prefix"] is None or type_string.startswith(args["filter_type_prefix"]): 140 | eval_metric_ops["type_accuracy_"+type_string] = 
tf.metrics.accuracy( 141 | labels=labels, 142 | predictions=predicted_labels, 143 | weights=tf.equal(features["type_string"], type_string)) 144 | 145 | 146 | with tf.gfile.GFile(args["answer_classes_path"]) as file: 147 | doc = yaml.load(file) 148 | for answer_class in doc.keys(): 149 | if args["filter_output_class"] is None or answer_class in args["filter_output_class"]: 150 | e = vocab.lookup(pretokenize_json(answer_class)) 151 | weights = tf.equal(labels, tf.cast(e, tf.int64)) 152 | eval_metric_ops["class_accuracy_"+str(answer_class)] = tf.metrics.accuracy( 153 | labels=labels, 154 | predictions=predicted_labels, 155 | weights=weights) 156 | 157 | except tf.errors.NotFoundError as err: 158 | print(err) 159 | pass 160 | except Exception as err: 161 | print(err) 162 | pass 163 | 164 | 165 | if args["use_floyd"]: 166 | eval_hooks = [FloydHubMetricHook(eval_metric_ops)] 167 | 168 | return tf.estimator.EstimatorSpec(mode, 169 | loss=loss, 170 | train_op=train_op, 171 | predictions=predictions, 172 | eval_metric_ops=eval_metric_ops, 173 | export_outputs=None, 174 | training_chief_hooks=None, 175 | training_hooks=None, 176 | scaffold=None, 177 | evaluation_hooks=eval_hooks, 178 | prediction_hooks=None 179 | ) 180 | -------------------------------------------------------------------------------- /macgraph/optimizer.py: -------------------------------------------------------------------------------- 1 | 2 | from tensorflow.python.training import optimizer 3 | from tensorflow.python.ops import control_flow_ops 4 | from tensorflow.python.ops import math_ops 5 | from tensorflow.python.ops import init_ops 6 | from tensorflow.python.ops import state_ops 7 | from tensorflow.python.framework import ops 8 | 9 | class PercentDeltaOptimizer(optimizer.Optimizer): 10 | 11 | def __init__(self, target=0.2, use_locking=False, name="PercentDelta"): 12 | super(PercentDeltaOptimizer, self).__init__(use_locking, name) 13 | self._target = target 14 | 15 | def _prepare(self): 16 | self._target_t = ops.convert_to_tensor(self._target, name="target") 17 | 18 | def _apply_dense(self, grad, var): 19 | return self._apply_pd(grad, var) 20 | 21 | def _apply_grad_descent(self, grad, var): 22 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 23 | var_update = state_ops.assign_sub(var, grad * lr_t) 24 | return control_flow_ops.group(*[var_update]) 25 | 26 | # Thanks to https://github.com/google/asymproj_edge_dnn/blob/master/edge_nn.py 27 | def _apply_pd(self, grad, var): 28 | 29 | def PlusEpsilon(x, eps=1e-6): 30 | """Element-wise add `eps` to `x` without changing sign of `x`.""" 31 | return x + (tf.sign(x) * eps) 32 | 33 | target_t = math_ops.cast(self._target_t, var.dtype.base_dtype) 34 | mean_percent_grad = tf.reduce_mean(tf.abs(tf.div(grad, PlusEpsilon(var)))) 35 | lr_t = tf.div(target_t, (mean_percent_grad + 1e-5)) 36 | 37 | var_update = state_ops.assign_sub(var, grad * lr_t) 38 | return var_update 39 | 40 | def _apply_sparse(self, grad, var): 41 | raise NotImplementedError("Sparse gradient updates are not supported.") -------------------------------------------------------------------------------- /macgraph/predict.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | from collections import Counter 5 | from colored import fg, bg, stylize 6 | import math 7 | import argparse 8 | import yaml 9 | import os.path 10 | 11 | from .input.text_util import UNK_ID 12 | from .estimator import get_estimator 13 | from .input import * 14 | 
from .const import EPSILON 15 | from .args import get_git_hash 16 | from .global_args import global_args 17 | from .print_util import * 18 | 19 | from .cell import MAC_Component 20 | 21 | import logging 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | # Make TF be quiet 26 | import os 27 | os.environ["TF_CPP_MIN_LOG_LEVEL"]="2" 28 | 29 | 30 | 31 | def predict(args, cmd_args): 32 | estimator = get_estimator(args) 33 | 34 | # Info about the experiment, for the record 35 | tfr_size = sum(1 for _ in tf.python_io.tf_record_iterator(args["predict_input_path"])) 36 | logger.info(args) 37 | logger.info(f"Predicting on {tfr_size} input records") 38 | 39 | # Actually do some work 40 | predictions = estimator.predict(input_fn=gen_input_fn(args, "predict")) 41 | vocab = Vocab.load_from_args(args) 42 | 43 | 44 | # And build the component 45 | mac = MAC_Component(args) 46 | 47 | 48 | def print_query(i, prefix, row): 49 | switch_attn = row[f"{prefix}_switch_attn"][i] 50 | print(f"{i}: {prefix}_switch: ", 51 | ' '.join(color_text(args["query_sources"], row[f"{prefix}_switch_attn"][i]))) 52 | # print(np.squeeze(switch_attn), f"Σ={sum(switch_attn)}") 53 | 54 | for idx, part_noun in enumerate(args["query_sources"]): 55 | if row[f"{prefix}_switch_attn"][i][idx] > ATTN_THRESHOLD: 56 | 57 | if part_noun == "step_const": 58 | print(f"{i}: {prefix}_step_const_signal: {row[f'{prefix}_step_const_signal']}") 59 | db = None 60 | if part_noun.startswith("token"): 61 | db = row["src"] 62 | elif part_noun.startswith("prev_output"): 63 | db = list(range(i+1)) 64 | 65 | if db is not None: 66 | scores = row[f"{prefix}_{part_noun}_attn"][i] 67 | attn_sum = sum(scores) 68 | assert attn_sum > 0.99, f"Attention does not sum to 1.0 {prefix}_{part_noun}_attn" 69 | v = ' '.join(color_text(db, scores)) 70 | print(f"{i}: {prefix}_{part_noun}_attn: {v}") 71 | print(f"{i}: {prefix}_{part_noun}_attn: {color_vector(np.squeeze(scores))} Σ={attn_sum}") 72 | 73 | def print_row(row): 74 | if row["actual_label"] == row["predicted_label"]: 75 | emoji = "✅" 76 | answer_part = f"{stylize(row['predicted_label'], bg(22))}" 77 | else: 78 | emoji = "❌" 79 | answer_part = f"{stylize(row['predicted_label'], bg(1))}, expected {row['actual_label']}" 80 | 81 | 82 | print(emoji, " ", answer_part, " - ", ''.join(row['src']).replace('', ' ').replace('', '')) 83 | 84 | if cmd_args["hide_details"]: 85 | return 86 | 87 | for i in range(frozen_args["max_decode_iterations"]): 88 | 89 | hr_text(f"Iteration {i}") 90 | 91 | def get_slice_if_poss(v,i): 92 | try: 93 | return v[i] 94 | except: 95 | return v 96 | 97 | row_iter_slice = { 98 | k: get_slice_if_poss(v,i) for k, v in row.items() 99 | } 100 | 101 | mac.print_all(row_iter_slice) 102 | 103 | 104 | mp_reads = [f"mp_read{i}" for i in range(args["mp_read_heads"])] 105 | 106 | for mp_head in ["mp_write", *mp_reads]: 107 | 108 | # -- Print node query --- 109 | # print_query(i, mp_head+"_query", row) 110 | 111 | # --- Print node attn --- 112 | db = [vocab.prediction_value_to_string(kb_row[0:1]) for kb_row in row["kb_nodes"]] 113 | db = db[0:row["kb_nodes_len"]] 114 | 115 | tap = mp_head+"_attn" 116 | attn_sum = sum(row[mp_head+"_attn"][i]) 117 | print(f"{i}: {mp_head}_attn: ",', '.join(color_text(db, row[mp_head+"_attn"][i]))) 118 | # print(f"{i}: {tap}: ", list(zip(db, np.squeeze(row[tap][i]))), f"Σ={attn_sum}") 119 | 120 | # for tap in ["signal"]: 121 | # t_v = row[f'{mp_head}_{tap}'][i] 122 | # print(f"{i}: {mp_head}_{tap}: {color_vector(t_v)}") 123 | 124 | # mp_state = 
color_vector(row['mp_node_state'][i][0:row['kb_nodes_len']]) 125 | # node_ids = [' node ' + pad_str(vocab.prediction_value_to_string(row[0])) for row in row['kb_nodes']] 126 | # s = [': '.join(i) for i in zip(node_ids, mp_state)] 127 | # mp_state_str = '\n'.join(s) 128 | # print(f"{i}: mp_node_state:") 129 | # print(mp_state_str) 130 | 131 | 132 | hr() 133 | print("Adjacency:\n", 134 | adj_pretty(row["kb_adjacency"], row["kb_nodes_len"], row["kb_nodes"], vocab)) 135 | 136 | 137 | 138 | def decode_row(row): 139 | for i in ["type_string", "actual_label", "predicted_label", "src"]: 140 | row[i] = vocab.prediction_value_to_string(row[i], True) 141 | 142 | stats = Counter() 143 | output_classes = Counter() 144 | predicted_classes = Counter() 145 | confusion = Counter() 146 | 147 | for count, p in enumerate(predictions): 148 | if count >= cmd_args["n"]: 149 | break 150 | 151 | decode_row(p) 152 | if cmd_args["filter_type_prefix"] is None or p["type_string"].startswith(cmd_args["filter_type_prefix"]): 153 | if cmd_args["filter_output_class"] is None or p["predicted_label"] == cmd_args["filter_output_class"]: 154 | if cmd_args["filter_expected_class"] is None or p["actual_label"] == cmd_args["filter_expected_class"]: 155 | 156 | output_classes[p["actual_label"]] += 1 157 | predicted_classes[p["predicted_label"]] += 1 158 | 159 | correct = p["actual_label"] == p["predicted_label"] 160 | 161 | if cmd_args["failed_only"] and not correct: 162 | print_row(p) 163 | elif cmd_args["correct_only"] and correct: 164 | print_row(p) 165 | elif not cmd_args["failed_only"] and not cmd_args["correct_only"]: 166 | print_row(p) 167 | 168 | 169 | if __name__ == "__main__": 170 | 171 | # -------------------------------------------------------------------------- 172 | # Arguments 173 | # -------------------------------------------------------------------------- 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument("--n",type=int,default=20) 176 | parser.add_argument("--filter-type-prefix",type=str,default=None) 177 | parser.add_argument("--filter-output-class",type=str,default=None) 178 | parser.add_argument("--filter-expected-class",type=str,default=None) 179 | parser.add_argument("--model-dir",type=str,default=None) 180 | parser.add_argument("--model-dir-prefix",type=str,default="output/model") 181 | parser.add_argument('--dataset',type=str, default="default", help="Name of dataset") 182 | parser.add_argument("--model-version",type=str,default=get_git_hash()) 183 | 184 | parser.add_argument("--correct-only",action='store_true') 185 | parser.add_argument("--failed-only",action='store_true') 186 | parser.add_argument("--hide-details",action='store_true') 187 | 188 | cmd_args = vars(parser.parse_args()) 189 | 190 | if cmd_args["model_dir"] is None: 191 | cmd_args["model_dir"] = os.path.join(cmd_args["model_dir_prefix"], cmd_args["dataset"], cmd_args["model_version"]) 192 | 193 | with tf.gfile.GFile(os.path.join(cmd_args["model_dir"], "config.yaml"), "r") as file: 194 | frozen_args = yaml.load(file) 195 | 196 | # If the directory got renamed, the model_dir might be out of sync, convenience hack 197 | frozen_args["model_dir"] = cmd_args["model_dir"] 198 | 199 | global_args.clear() 200 | global_args.update(frozen_args) 201 | 202 | 203 | 204 | # -------------------------------------------------------------------------- 205 | # Logging 206 | # -------------------------------------------------------------------------- 207 | 208 | logging.basicConfig() 209 | tf.logging.set_verbosity("WARN") 210 | 
logger.setLevel("WARN") 211 | logging.getLogger("mac-graph").setLevel("WARN") 212 | 213 | 214 | 215 | # -------------------------------------------------------------------------- 216 | # Lessssss do it! 217 | # -------------------------------------------------------------------------- 218 | 219 | predict(frozen_args, cmd_args) 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /macgraph/print_util.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | from colored import fg, bg, stylize 6 | from .const import EPSILON 7 | import math 8 | 9 | 10 | TARGET_CHAR_WIDTH = 80 11 | 12 | def hr_text(text): 13 | t_len = len(text) + 2 14 | pad_len = (TARGET_CHAR_WIDTH - t_len) // 2 15 | padding = '-'.join(["" for i in range(pad_len)]) 16 | 17 | s = padding + " " + text + " " + padding 18 | 19 | print(stylize(s, fg("yellow"))) 20 | 21 | 22 | def hr(bold=False): 23 | if bold: 24 | print(stylize("--------------------------", fg("yellow"))) 25 | else: 26 | print(stylize("--------------------------", fg("blue"))) 27 | 28 | DARK_GREY = 235 29 | WHITE = 255 30 | 31 | BG_BLACK = 232 32 | BG_DARK_GREY = 237 33 | 34 | ATTN_THRESHOLD = 0.25 35 | 36 | np.set_printoptions(precision=3) 37 | 38 | 39 | def color_text(text_array, levels, color_fg=True): 40 | out = [] 41 | 42 | l_max = np.amax(levels) 43 | l_min = np.amin(levels) 44 | 45 | l_max = max(l_max, 1.0) 46 | 47 | for l, s in zip(levels, text_array): 48 | l_n = (l - l_min) / (l_max + EPSILON) 49 | l_n = max(0.0, min(1.0, l_n)) 50 | if color_fg: 51 | color = fg(int(math.floor(DARK_GREY + l_n * (WHITE-DARK_GREY)))) 52 | else: 53 | color = bg(int(math.floor(BG_BLACK + l_n * (BG_DARK_GREY-BG_BLACK)))) 54 | out.append(stylize(s, color)) 55 | return out 56 | 57 | def color_vector(vec, show_numbers=True): 58 | 59 | v_max = np.amax(vec) 60 | v_min = np.amin(vec) 61 | delta = np.abs(v_max - v_min) 62 | norm = (vec - v_min) / np.maximum(delta, 0.00001) 63 | 64 | def format_element(n): 65 | if show_numbers: 66 | return str(np.around(n, 4)) 67 | else: 68 | return "-" if n < -EPSILON else ("+" if n > EPSILON else "0") 69 | 70 | def to_color(row): 71 | return ' '.join(color_text([format_element(i) for i in row], (row-v_min) / np.maximum(delta, EPSILON))) 72 | 73 | if len(np.shape(vec)) == 1: 74 | return to_color(vec) 75 | else: 76 | return [to_color(row) for row in vec] 77 | 78 | def pad_str(s, target=3): 79 | if len(s) < target: 80 | for i in range(target - len(s)): 81 | s += " " 82 | return s 83 | 84 | def adj_pretty(mtx, kb_nodes_len, kb_nodes, vocab): 85 | output = "" 86 | 87 | for r_idx, row in enumerate(mtx): 88 | if r_idx < kb_nodes_len: 89 | 90 | r_id = kb_nodes[r_idx][0] 91 | r_name = vocab.inverse_lookup(r_id) 92 | output += pad_str(f"{r_name}: ",target=4) 93 | 94 | for c_idx, item in enumerate(row): 95 | if c_idx < kb_nodes_len: 96 | 97 | c_id = kb_nodes[c_idx][0] 98 | c_name = vocab.inverse_lookup(c_id) 99 | 100 | if item: 101 | output += pad_str(f"{c_name}") 102 | else: 103 | output += pad_str(" ") 104 | output += "\n" 105 | 106 | return output -------------------------------------------------------------------------------- /macgraph/train.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | from tensorflow.python import debug as tf_debug 4 | from collections import namedtuple 5 | 6 | from .estimator import get_estimator 7 | from .input import 
gen_input_fn 8 | from .args import * 9 | 10 | # Make TF be quiet 11 | import os 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"]="2" 13 | 14 | import logging 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | 19 | def train(args): 20 | 21 | # So I don't frigging forget what caused working models 22 | save_args(args) 23 | 24 | 25 | if args["use_tf_debug"]: 26 | hooks = [tf_debug.LocalCLIDebugHook()] 27 | else: 28 | hooks = [] 29 | 30 | 31 | train_size = sum(1 for _ in tf.python_io.tf_record_iterator(args["train_input_path"])) 32 | tf.logging.info(f"Training on {train_size} records") 33 | 34 | # ---------------------------------------------------------------------------------- 35 | 36 | 37 | 38 | training_segments = [] 39 | TrainingSegment = namedtuple('TrainingSegment', ['args', 'max_steps']) 40 | 41 | if args["use_curriculum"]: 42 | assert args["train_max_steps"] is not None, "Curriculum training requires --train-max-steps" 43 | 44 | seg_steps = args["train_max_steps"] / float(args["max_decode_iterations"]) 45 | 46 | for i in range(1, args["max_decode_iterations"]+1): 47 | 48 | seg_args = {**args} 49 | seg_args["filter_output_class"] = [str(j) for j in list(range(i+1))] 50 | total_seg_steps = i*seg_steps*1000 51 | 52 | 53 | training_segments.append(TrainingSegment(seg_args, total_seg_steps)) 54 | 55 | else: 56 | training_segments.append(TrainingSegment(args, args["train_max_steps"]*1000 if args["train_max_steps"] is not None else None)) 57 | 58 | 59 | for i in training_segments: 60 | 61 | tf.logging.info(f"Begin training segment {i.max_steps} {i.args['filter_output_class']}") 62 | 63 | estimator = get_estimator(i.args) 64 | 65 | train_spec = tf.estimator.TrainSpec( 66 | input_fn=gen_input_fn(i.args, "train"), 67 | max_steps=int(i.max_steps), 68 | hooks=hooks) 69 | 70 | eval_spec = tf.estimator.EvalSpec( 71 | input_fn=gen_input_fn(i.args, "eval"), 72 | throttle_secs=i.args["eval_every"]) 73 | 74 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | args = get_args() 80 | 81 | # DO IT! 
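# Typical invocation (mirroring the flags used in train.sh at the repo root),
# shown here as an illustrative sketch only:
#
#   python -m macgraph.train \
#       --dataset StationShortestCount \
#       --max-decode-iterations 10 \
#       --train-max-steps 10 \
#       --fast
#
# Any other option defined by get_args() can be appended on the command line.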
82 | train(args) 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /macgraph/unit_test.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | 5 | # -------------------------------------------------------------------------- 6 | # The tests 7 | # -------------------------------------------------------------------------- 8 | 9 | from .cell.test_read_cell import ReadTest 10 | 11 | 12 | 13 | # -------------------------------------------------------------------------- 14 | 15 | if __name__ == '__main__': 16 | tf.test.main() -------------------------------------------------------------------------------- /macgraph/util.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import math 4 | 5 | from .global_args import global_args 6 | 7 | def tf_assert_almost_equal(x, y, delta=0.001, **kwargs): 8 | return tf.assert_less(tf.abs(x-y), delta, **kwargs) 9 | 10 | def assert_shape(tensor, shape, batchless=False): 11 | 12 | read_from = 0 if batchless else 1 13 | 14 | lhs = tf.TensorShape(tensor.shape[read_from:]) 15 | rhs = tf.TensorShape(shape) 16 | 17 | lhs.assert_is_compatible_with(rhs) 18 | 19 | # assert lhs == shape, f"{tensor.name} is wrong shape, expected {shape} found {lhs}" 20 | 21 | def assert_rank(tensor, rank): 22 | assert len(tensor.shape) == rank, f"{tensor.name} is wrong rank, expected {rank} got {len(tensor.shape)}" 23 | 24 | 25 | def dynamic_assert_shape(tensor, shape, name=None): 26 | """ 27 | Check that a tensor has a shape given by a list of constants and tensor values. 28 | 29 | This function will place an operation into your graph that gets executed at runtime. 30 | This is helpful because often tensors have many dynamic sized dimensions that 31 | you cannot otherwise compare / assert are as you expect. 
32 | 33 | For example, measure a dimension at run time: 34 | `batch_size = tf.shape(my_tensor)[0]` 35 | 36 | then assert another tensor does indeed have the right shape: 37 | `other_tensor = dynamic_assert_shape(other_tensor, [batch_size, 16])` 38 | 39 | You should use this as an inline identity function so that the operation it generates 40 | gets added and executed in the graph 41 | 42 | Returns: the argument `tensor` unchanged 43 | """ 44 | 45 | if global_args["use_assert"]: 46 | 47 | tensor_shape = tf.shape(tensor) 48 | tensor_shape = tf.cast(tensor_shape, tf.int64) 49 | 50 | expected_shape = tf.convert_to_tensor(shape) 51 | expected_shape = tf.cast(expected_shape, tf.int64) 52 | 53 | t_name = "tensor" if tf.executing_eagerly() else tensor.name 54 | 55 | if isinstance(shape, list): 56 | assert len(tensor.shape) == len(shape), f"Tensor shape {tensor_shape} and expected shape {expected_shape} have different lengths" 57 | 58 | assert_op = tf.assert_equal(tensor_shape, expected_shape, message=f"Asserting shape of {t_name}", summarize=10, name=name) 59 | 60 | with tf.control_dependencies([assert_op]): 61 | return tf.identity(tensor, name="dynamic_assert_shape") 62 | 63 | else: 64 | return tensor 65 | 66 | 67 | 68 | def minimize_clipped(optimizer, value, max_gradient_norm, var=None): 69 | global_step = tf.train.get_global_step() 70 | 71 | if var is None: 72 | var = tf.trainable_variables() 73 | 74 | gradients = tf.gradients(value, var) 75 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm) 76 | grad_dict = dict(zip(var, clipped_gradients)) 77 | op = optimizer.apply_gradients(zip(clipped_gradients, var), global_step=global_step) 78 | return op, grad_dict 79 | 80 | 81 | 82 | """ 83 | Where length is the second dimension 84 | """ 85 | def pad_to_table_len(tensor, table_to_mimic=None, seq_len=None, name=None): 86 | if seq_len is not None: 87 | delta = seq_len - tf.shape(tensor)[1] 88 | elif table_to_mimic is not None: 89 | delta = tf.shape(table_to_mimic)[1] - tf.shape(tensor)[1] 90 | else: 91 | raise Exception("Must have argument table_to_mimic or seq_len") 92 | 93 | tensor = tf.pad(tensor, [ [0,0], [0,delta], [0,0] ], name=name) # zero pad out 94 | 95 | if seq_len is not None: 96 | tensor.set_shape([None, seq_len, None]) 97 | 98 | return tensor 99 | 100 | 101 | """ 102 | Where length is the second dimension 103 | """ 104 | def pad_to_len_1d(tensor, l:int, name=None): 105 | delta = l - tf.shape(tensor)[1] 106 | tensor = tf.pad(tensor, [ [0,0], [0,tf.maximum(delta, 0)] ], name=name) # zero pad out 107 | # tensor = dynamic_assert_shape(tensor, tf.shape(table_to_mimic)[0:1]+[tf.shape(tensor)[2]], name) 108 | return tensor 109 | 110 | 111 | 112 | 113 | 114 | def vector_to_barcode(tensor): 115 | width = tf.shape(tensor)[-1] 116 | barcode_height = tf.cast(tf.round(tf.div(tf.cast(width, tf.float32), 3.0)), tf.int32) 117 | barcode_image = tf.tile(tf.reshape(tensor, [-1, 1, width, 1]), [1, barcode_height, 1, 1]) 118 | return barcode_image 119 | 120 | 121 | 122 | 123 | def add_positional_encoding_1d(tensor, seq_axis=1, word_axis=2, dtype=tf.float32): 124 | ''' 125 | The function is based on https://github.com/stanfordnlp/mac-network 126 | 127 | Computes sin/cos positional encoding for h x w x (4*dim). 128 | If outDim positive, casts positions to that dimension. 
129 | Based on positional encoding presented in "Attention is all you need" 130 | 131 | Currently hard-coded for one setup of seq_axis and word_axis 132 | ''' 133 | 134 | assert len(tensor.shape) == 3, "Expecting tensor of shape [batch, seq, word]" 135 | 136 | in_tensor_shape = tf.shape(tensor) 137 | 138 | batch_len = tf.shape(tensor)[0] 139 | seq_len = tf.shape(tensor)[seq_axis] 140 | word_len = tf.shape(tensor)[word_axis] 141 | 142 | halfdim = tf.cast(word_len / 2, dtype) 143 | 144 | x = tf.expand_dims(tf.to_float(tf.range(seq_len)), axis=1) 145 | i = tf.expand_dims(tf.to_float(tf.range(halfdim)), axis=0) 146 | 147 | peSinX = tf.sin(x / (tf.pow(10000.0, i / halfdim))) 148 | peCosX = tf.cos(x / (tf.pow(10000.0, i / halfdim))) 149 | 150 | pe = tf.concat([peSinX, peCosX], axis=-1) 151 | pe = tf.expand_dims(pe, 0) 152 | # pe = tf.tile(pe, [batch, 1, 1]) 153 | # pe = dynamic_assert_shape(pe, tf.shape(tensor)) 154 | 155 | # Original paper 156 | tensor = tensor + pe 157 | tensor = dynamic_assert_shape(tensor, in_tensor_shape) 158 | 159 | # Concat method 160 | # tensor = tf.concat([tensor,pe], axis=word_axis) 161 | 162 | 163 | return tensor 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | task=StationShortestCount 4 | iteration=10 5 | 6 | python -m macgraph.train \ 7 | --dataset $task \ 8 | --tag iter_$iteration \ 9 | --tag upto_9 \ 10 | --tag r$RANDOM \ 11 | --filter-output-class 0 \ 12 | --filter-output-class 1 \ 13 | --filter-output-class 2 \ 14 | --filter-output-class 3 \ 15 | --filter-output-class 4 \ 16 | --filter-output-class 5 \ 17 | --filter-output-class 6 \ 18 | --filter-output-class 7 \ 19 | --filter-output-class 8 \ 20 | --filter-output-class 9 \ 21 | --train-max-steps 10 \ 22 | --max-decode-iterations $iteration \ 23 | --fast \ 24 | $@ -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .ploty import Ploty 3 | from .file import FileWritey, FileReadie, path_exists -------------------------------------------------------------------------------- /util/file.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path 3 | import sys 4 | import tensorflow as tf 5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | # Now that I know tf.gfile is a thing, none of this is needed 10 | 11 | 12 | class FileThingy(object): 13 | 14 | def __init__(self, args, filename): 15 | self.args = args 16 | self.filename = filename 17 | 18 | @property 19 | def file_dir(self): 20 | return os.path.join(self.args.output_dir, self.args.run) 21 | 22 | @property 23 | def file_path(self): 24 | return os.path.join(self.args.output_dir, self.args.run, self.filename) 25 | 26 | @property 27 | def gcs_path(self): 28 | return os.path.join(self.args.gcs_dir, self.args.run, self.filename) 29 | 30 | 31 | def path_exists(path): 32 | return tf.gfile.Exists(path) 33 | 34 | 35 | class FileReadie(FileThingy): 36 | """Tries to write on traditional filesystem and Google Cloud storage""" 37 | 38 | def __init__(self, args, filename, binary=False): 39 | super().__init__(args, filename) 40 | self.open_str = "rb" if binary else "r" 41 | 42 | def __enter__(self): 43 | self.file = tf.gfile.GFile(self.file_path, self.open_str) 44 | return self.file 45 | 46 | def 
__exit__(self, type, value, traceback): 47 | self.file.close() 48 | 49 | 50 | 51 | 52 | class FileWritey(FileThingy): 53 | """Tries to write on traditional filesystem and Google Cloud storage""" 54 | 55 | def __init__(self, args, filename, binary=False): 56 | super().__init__(args, filename) 57 | self.open_str = "wb" if binary else "w" 58 | 59 | def __enter__(self): 60 | try: 61 | os.makedirs(self.file_dir, exist_ok=True) 62 | except Exception: 63 | pass 64 | 65 | self.file = tf.gfile.GFile(self.file_path, self.open_str) 66 | 67 | return self.file 68 | 69 | def __exit__(self, type, value, traceback): 70 | self.file.close() 71 | 72 | 73 | 74 | --------------------------------------------------------------------------------
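A minimal usage sketch for the FileWritey / FileReadie helpers above. The args object, run name, path and bucket name are hypothetical (any object exposing output_dir, run and gcs_dir attributes, such as an argparse.Namespace, should do), and it assumes the util package imports cleanly (its __init__.py also pulls in .ploty, which is not shown here):

    from types import SimpleNamespace
    from util.file import FileWritey, FileReadie, path_exists

    # Hypothetical run configuration; gcs_dir is only used by the gcs_path property.
    args = SimpleNamespace(output_dir="output", run="demo-run", gcs_dir="gs://example-bucket/output")

    # Writes output/demo-run/notes.txt via tf.gfile (local disk or GCS paths both work).
    with FileWritey(args, "notes.txt") as f:
        f.write("hello world\n")

    assert path_exists("output/demo-run/notes.txt")

    # Read it back.
    with FileReadie(args, "notes.txt") as f:
        print(f.read())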