├── .gitignore ├── .travis.yml ├── AUTHORS.rst ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.rst ├── quantgov ├── __init__.py ├── __main__.py ├── corpus.py ├── ml │ ├── __init__.py │ ├── candidate_sets.py │ ├── estimation.py │ ├── evaluation.py │ ├── structures.py │ ├── training.py │ └── utils.py ├── nlp.py └── utils.py ├── setup.cfg ├── setup.py └── tests ├── pseudo_corpus ├── data │ └── clean │ │ ├── cfr.txt │ │ └── moby.txt └── driver.py ├── pseudo_estimator ├── .gitignore └── data │ ├── binary.qge │ └── multiclass.qge ├── test_downloads.py ├── test_ml.py └── test_nlp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .pytest_cache 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | .DS_Store 93 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - '3.5' 4 | - '3.6' 5 | install: 6 | - pip install ".[testing]" 7 | - pip install ".[nlp]" 8 | - pip install ".[s3driver]" 9 | - python -m nltk.downloader punkt stopwords wordnet 10 | script: pytest 11 | deploy: 12 | provider: pypi 13 | user: QuantGov 14 | password: 15 | secure: klIbNSfnAAFP0EA1/KVrov+3mJNxgEtEoGFIV+7meDsKruEKcKAg6EHFmFozo0hpgPW2jPw0d+wFZ0EcrDqlKA7QKX7wYUPHKQ/QLEbJPWT5l70ZjMPzBEy4Vl71KX0SMCH4khB05ZliD4rbMHoviSByA7LuPi53NdH09qhwcll3NS4cm9ModnJBz0gX1k4b/2YkPHzBSBIMXpNsf/AGl76T7YYxaNDGNFvmIvg7ZVzFScRUVzf8UiEQ5M2njlbanbPbySL8rBrmBRGa3RIm1PNtl/nNiEMY0pt8kc/dGVAJcznqsoPdSYjIlxzHVok++ZjrVlbQqN5JqTub9ycUN894z/jdWDEzE81PjJ0FPB/2c2vtZdXR2IFpdM2Mp09GZmqTOpa8ec48hLJUstI0hyW2Rp/mLOVIt4sKwkE3ULCKTI01TDASKAhQJzxyv+UlW17CdkXd/dyuUchjJm6ZJtQC9hiEmaV/Yh+EPrgoE6nVZzDSk+3vJ/cYiRSlMTuHs0rXKVpxkGLtGEYJgrzu5dQzSz1oQA0Hfic2SZSOFvu9R6MYnLm4maxmf2KgI9sr213JWqBpkfuGOkrCTRjdLJOmVoGUPofuivP1IdoBEYzsS9shaiAe5gJ+ivK937OeaiN6c71+lFRL0rVuWddRFq8qw5fa++Z9PEjW9Ki2bcI= 16 | on: 17 | tags: true 18 | distributions: sdist bdist_wheel 19 | -------------------------------------------------------------------------------- /AUTHORS.rst: 
--------------------------------------------------------------------------------
1 | QuantGov is written and maintained by the Mercatus Center at George Mason University and
2 | various contributors:
3 | 
4 | Primary Developers
5 | ``````````````````
6 | 
7 | - Oliver Sherouse `@osherouse <https://github.com/osherouse>`_ Primary Author
8 | - Daniel Francis `@dfrancis <https://github.com/dfrancis>`_
9 | - Michael Gasvoda `@mgasvoda <https://github.com/mgasvoda>`_
10 | - Stephen Strosko `@sstrosko <https://github.com/sstrosko>`_
11 | - Jonathan Nelson `@jnelson16 <https://github.com/jnelson16>`_
12 | 
13 | 
14 | Patches and Suggestions
15 | ```````````````````````
16 | 
17 | - TBD
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 QuantGov
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /Pipfile: --------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.python.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 | 
6 | [packages]
7 | e1839a8 = {path = ".",extras = ["nlp", "s3driver"],editable = true}
8 | textstat = {git = "https://github.com/jnelson16/textstat.git"}
9 | 
10 | [dev-packages]
11 | "pytest-flake8" = "*"
12 | ipython = "*"
13 | 
-------------------------------------------------------------------------------- /Pipfile.lock: --------------------------------------------------------------------------------
1 | {
2 |     "_meta": {
3 |         "hash": {
4 |             "sha256": "fb05682cf6faa7e15048ad5a8688ff5675dc87a47e496516bc7395c4c66a68f8"
5 |         },
6 |         "pipfile-spec": 6,
7 |         "requires": {},
8 |         "sources": [
9 |             {
10 |                 "name": "pypi",
11 |                 "url": "https://pypi.python.org/simple",
12 |                 "verify_ssl": true
13 |             }
14 |         ]
15 |     },
16 |     "default": {
17 |         "boto3": {
18 |             "hashes": [
19 |                 "sha256:366a1f3ec37b9434f25247cbe876f9ca1b53d35e35af18f74c735445100b4bc4",
20 |                 "sha256:e7718b48cd073ad59a99a33d14252319dfaf550be3682b0c6a58da052fb05fcc"
21 |             ],
22 |             "version": "==1.9.217"
23 |         },
24 |         "botocore": {
25 |             "hashes": [
26 |                 "sha256:68a0a22ca4e0e7e7ab482f63e21debfe402841fc49b8503dec0a7307b565d774",
27 |                 "sha256:7a213b876e58b1b5380cf30faa05ba45073692ad4a3cc803ba763082a36436bb"
28 |             ],
29 |             "version": "==1.12.217"
30 |         },
31 |         "certifi": {
32 |             "hashes": [
33 |                 "sha256:046832c04d4e752f37383b628bc601a7ea7211496b4638f6514d0e5b9acc4939",
34 |                 "sha256:945e3ba63a0b9f577b1395204e13c3a231f9bc0223888be653286534e5873695"
35 |             ],
36 |             "version": "==2019.6.16"
37 |         },
38 |         "chardet": {
39 |             "hashes": [
40 |                 "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae",
41 |                 "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"
42 |             ],
43 |             "version": "==3.0.4"
44 |         },
45 |         "decorator": {
46 |             "hashes": [
47 |                 "sha256:86156361c50488b84a3f148056ea716ca587df2f0de1d34750d35c21312725de",
48 |                 "sha256:f069f3a01830ca754ba5258fde2278454a0b5b79e0d7f5c13b3b97e57d4acff6"
49 |             ],
50 |             "version": "==4.4.0"
51 |         },
52 |         "docutils": {
53 |             "hashes": [
54 |                 "sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0",
55 |                 "sha256:9e4d7ecfc600058e07ba661411a2b7de2fd0fafa17d1a7f7361cd47b1175c827",
56 |                 "sha256:a2aeea129088da402665e92e0b25b04b073c04b2dce4ab65caaa38b7ce2e1a99"
57 |             ],
58 |             "version": "==0.15.2"
59 |         },
60 |         "e1839a8": {
61 |             "editable": true,
62 |             "extras": [
63 |                 "nlp",
64 |                 "s3driver"
65 |             ],
66 |             "path": "."
67 | }, 68 | "idna": { 69 | "hashes": [ 70 | "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", 71 | "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" 72 | ], 73 | "version": "==2.8" 74 | }, 75 | "jmespath": { 76 | "hashes": [ 77 | "sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6", 78 | "sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c" 79 | ], 80 | "version": "==0.9.4" 81 | }, 82 | "joblib": { 83 | "hashes": [ 84 | "sha256:21e0c34a69ad7fde4f2b1f3402290e9ec46f545f15f1541c582edfe05d87b63a", 85 | "sha256:315d6b19643ec4afd4c41c671f9f2d65ea9d787da093487a81ead7b0bac94524" 86 | ], 87 | "version": "==0.13.2" 88 | }, 89 | "nltk": { 90 | "hashes": [ 91 | "sha256:bed45551259aa2101381bbdd5df37d44ca2669c5c3dad72439fa459b29137d94" 92 | ], 93 | "version": "==3.4.5" 94 | }, 95 | "numpy": { 96 | "hashes": [ 97 | "sha256:03f2ebcbffcce2dec8860633b89a93e80c6a239d21a77ae8b241450dc21e8c35", 98 | "sha256:078c8025da5ab9e8657edc9c2a1e9642e06e953bc7baa2e65c1aa9d9dfb7e98b", 99 | "sha256:0fbfa98c5d5c3c6489cc1e852ec94395d51f35d9ebe70c6850e47f465038cdf4", 100 | "sha256:1c841033f4fe6801648180c3033c45b3235a8bbd09bc7249010f99ea27bb6790", 101 | "sha256:2c0984a01ddd0aeec89f0ce46ef21d64761048cd76c0074d0658c91f9131f154", 102 | "sha256:4c166dcb0fff7cb3c0bbc682dfb5061852a2547efb6222e043a7932828c08fb5", 103 | "sha256:8c2d98d0623bd63fb883b65256c00454d5f53127a5a7bcdaa8bdc582814e8cb4", 104 | "sha256:8cb4b6ae45aad6d26712a1ce0a3f2556c5e1484867f9649e03496e45d6a5eba4", 105 | "sha256:93050e73c446c82065b7410221b07682e475ac51887cd9368227a5d944afae80", 106 | "sha256:a3f6b3024f8826d8b1490e6e2a9b99e841cd2c375791b1df62991bd8f4c00b89", 107 | "sha256:bede70fd8699695363f39e86c1e869b2c8b74fb5ef135a67b9e1eeebff50322a", 108 | "sha256:c304b2221f33489cd15a915237a84cdfe9420d7e4d4828c78a0820f9d990395c", 109 | "sha256:f11331530f0eff69a758d62c2461cd98cdc2eae0147279d8fc86e0464eb7e8ca", 110 | "sha256:fa5f2a8ef1e07ba258dc07d4dd246de23ef4ab920ae0f3fa2a1cc5e90f0f1888", 111 | "sha256:fb6178b0488b0ce6a54bc4accbdf5225e937383586555604155d64773f6beb2b", 112 | "sha256:fd5e830d4dc31658d61a6452cd3e842213594d8c15578cdae6829e36ad9c0930" 113 | ], 114 | "version": "==1.17.1" 115 | }, 116 | "pandas": { 117 | "hashes": [ 118 | "sha256:18d91a9199d1dfaa01ad645f7540370ba630bdcef09daaf9edf45b4b1bca0232", 119 | "sha256:3f26e5da310a0c0b83ea50da1fd397de2640b02b424aa69be7e0784228f656c9", 120 | "sha256:4182e32f4456d2c64619e97c58571fa5ca0993d1e8c2d9ca44916185e1726e15", 121 | "sha256:426e590e2eb0e60f765271d668a30cf38b582eaae5ec9b31229c8c3c10c5bc21", 122 | "sha256:5eb934a8f0dc358f0e0cdf314072286bbac74e4c124b64371395e94644d5d919", 123 | "sha256:717928808043d3ea55b9bcde636d4a52d2236c246f6df464163a66ff59980ad8", 124 | "sha256:8145f97c5ed71827a6ec98ceaef35afed1377e2d19c4078f324d209ff253ecb5", 125 | "sha256:8744c84c914dcc59cbbb2943b32b7664df1039d99e834e1034a3372acb89ea4d", 126 | "sha256:c1ac1d9590d0c9314ebf01591bd40d4c03d710bfc84a3889e5263c97d7891dee", 127 | "sha256:cb2e197b7b0687becb026b84d3c242482f20cbb29a9981e43604eb67576da9f6", 128 | "sha256:d4001b71ad2c9b84ff18b182cea22b7b6cbf624216da3ea06fb7af28d1f93165", 129 | "sha256:d8930772adccb2882989ab1493fa74bd87d47c8ac7417f5dd3dd834ba8c24dc9", 130 | "sha256:dfbb0173ee2399bc4ed3caf2d236e5c0092f948aafd0a15fbe4a0e77ee61a958", 131 | "sha256:eebfbba048f4fa8ac711b22c78516e16ff8117d05a580e7eeef6b0c2be554c18", 132 | "sha256:f1b21bc5cf3dbea53d33615d1ead892dfdae9d7052fa8898083bec88be20dcd2" 133 | ], 134 | "version": "==0.25.1" 135 | }, 136 | "pyphen": { 137 
| "hashes": [ 138 | "sha256:3b633a50873156d777e1f1075ba4d8e96a6ad0a3ca42aa3ea9a6259f93f18921", 139 | "sha256:e172faf10992c8c9d369bdc83e36dbcf1121f4ed0d881f1a0b521935aee583b5" 140 | ], 141 | "version": "==0.9.5" 142 | }, 143 | "python-dateutil": { 144 | "hashes": [ 145 | "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb", 146 | "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e" 147 | ], 148 | "markers": "python_version >= '2.7'", 149 | "version": "==2.8.0" 150 | }, 151 | "pytz": { 152 | "hashes": [ 153 | "sha256:26c0b32e437e54a18161324a2fca3c4b9846b74a8dccddd843113109e1116b32", 154 | "sha256:c894d57500a4cd2d5c71114aaab77dbab5eabd9022308ce5ac9bb93a60a6f0c7" 155 | ], 156 | "version": "==2019.2" 157 | }, 158 | "repoze.lru": { 159 | "hashes": [ 160 | "sha256:0429a75e19380e4ed50c0694e26ac8819b4ea7851ee1fc7583c8572db80aff77", 161 | "sha256:f77bf0e1096ea445beadd35f3479c5cff2aa1efe604a133e67150bc8630a62ea" 162 | ], 163 | "version": "==0.7" 164 | }, 165 | "requests": { 166 | "hashes": [ 167 | "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", 168 | "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" 169 | ], 170 | "version": "==2.22.0" 171 | }, 172 | "s3transfer": { 173 | "hashes": [ 174 | "sha256:6efc926738a3cd576c2a79725fed9afde92378aa5c6a957e3af010cb019fac9d", 175 | "sha256:b780f2411b824cb541dbcd2c713d0cb61c7d1bcadae204cdddda2b35cef493ba" 176 | ], 177 | "version": "==0.2.1" 178 | }, 179 | "scikit-learn": { 180 | "hashes": [ 181 | "sha256:1ac81293d261747c25ea5a0ee8cd2bb1f3b5ba9ec05421a7f9f0feb4eb7c4116", 182 | "sha256:289361cf003d90b007f5066b27fcddc2d71324c82f1c88e316fedacb0dfdd516", 183 | "sha256:3a14d0abd4281fc3fd2149c486c3ec7cedad848b8d5f7b6f61522029d65a29f8", 184 | "sha256:5083a5e50d9d54548e4ada829598ae63a05651dd2bb319f821ffd9e8388384a6", 185 | "sha256:777cdd5c077b7ca9cb381396c81990cf41d2fa8350760d3cad3b4c460a7db644", 186 | "sha256:8bf2ff63da820d09b96b18e88f9625228457bff8df4618f6b087e12442ef9e15", 187 | "sha256:8d319b71c449627d178f21c57614e21747e54bb3fc9602b6f42906c3931aa320", 188 | "sha256:928050b65781fea9542dfe9bfe02d8c4f5530baa8472ec60782ea77347d2c836", 189 | "sha256:92c903613ff50e22aa95d589f9fff5deb6f34e79f7f21f609680087f137bb524", 190 | "sha256:ae322235def5ce8fae645b439e332e6f25d34bb90d6a6c8e261f17eb476457b7", 191 | "sha256:c1cd6b29eb1fd1cc672ac5e4a8be5f6ea936d094a3dc659ada0746d6fac750b1", 192 | "sha256:c41a6e2685d06bcdb0d26533af2540f54884d40db7e48baed6a5bcbf1a7cc642", 193 | "sha256:d07fcb0c0acbc043faa0e7cf4d2037f71193de3fb04fb8ed5c259b089af1cf5c", 194 | "sha256:d146d5443cda0a41f74276e42faf8c7f283fef49e8a853b832885239ef544e05", 195 | "sha256:eb2b7bed0a26ba5ce3700e15938b28a4f4513578d3e54a2156c29df19ac5fd01", 196 | "sha256:eb9b8ebf59eddd8b96366428238ab27d05a19e89c5516ce294abc35cea75d003" 197 | ], 198 | "version": "==0.21.3" 199 | }, 200 | "scipy": { 201 | "hashes": [ 202 | "sha256:0baa64bf42592032f6f6445a07144e355ca876b177f47ad8d0612901c9375bef", 203 | "sha256:243b04730d7223d2b844bda9500310eecc9eda0cba9ceaf0cde1839f8287dfa8", 204 | "sha256:2643cfb46d97b7797d1dbdb6f3c23fe3402904e3c90e6facfe6a9b98d808c1b5", 205 | "sha256:396eb4cdad421f846a1498299474f0a3752921229388f91f60dc3eda55a00488", 206 | "sha256:3ae3692616975d3c10aca6d574d6b4ff95568768d4525f76222fb60f142075b9", 207 | "sha256:435d19f80b4dcf67dc090cc04fde2c5c8a70b3372e64f6a9c58c5b806abfa5a8", 208 | "sha256:46a5e55850cfe02332998b3aef481d33f1efee1960fe6cfee0202c7dd6fc21ab", 209 | 
"sha256:75b513c462e58eeca82b22fc00f0d1875a37b12913eee9d979233349fce5c8b2", 210 | "sha256:7ccfa44a08226825126c4ef0027aa46a38c928a10f0a8a8483c80dd9f9a0ad44", 211 | "sha256:89dd6a6d329e3f693d1204d5562dd63af0fd7a17854ced17f9cbc37d5b853c8d", 212 | "sha256:a81da2fe32f4eab8b60d56ad43e44d93d392da228a77e229e59b51508a00299c", 213 | "sha256:a9d606d11eb2eec7ef893eb825017fbb6eef1e1d0b98a5b7fc11446ebeb2b9b1", 214 | "sha256:ac37eb652248e2d7cbbfd89619dce5ecfd27d657e714ed049d82f19b162e8d45", 215 | "sha256:cbc0611699e420774e945f6a4e2830f7ca2b3ee3483fca1aa659100049487dd5", 216 | "sha256:d02d813ec9958ed63b390ded463163685af6025cb2e9a226ec2c477df90c6957", 217 | "sha256:dd3b52e00f93fd1c86f2d78243dfb0d02743c94dd1d34ffea10055438e63b99d" 218 | ], 219 | "version": "==1.3.1" 220 | }, 221 | "six": { 222 | "hashes": [ 223 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", 224 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" 225 | ], 226 | "version": "==1.12.0" 227 | }, 228 | "sqlalchemy": { 229 | "hashes": [ 230 | "sha256:2f8ff566a4d3a92246d367f2e9cd6ed3edeef670dcd6dda6dfdc9efed88bcd80" 231 | ], 232 | "version": "==1.3.8" 233 | }, 234 | "textblob": { 235 | "hashes": [ 236 | "sha256:7ff3c00cb5a85a30132ee6768b8c68cb2b9d76432fec18cd1b3ffe2f8594ec8c", 237 | "sha256:b0eafd8b129c9b196c8128056caed891d64b7fa20ba570e1fcde438f4f7dd312" 238 | ], 239 | "version": "==0.15.3" 240 | }, 241 | "textstat": { 242 | "hashes": [ 243 | "sha256:5e1342bf87b4660f5437a36ce0a12cc987885187527c97c6b1f19557811df4d6", 244 | "sha256:c50ad2691763c74508e35e554da2ad8aee748537c999388f93e218fcab9ab12f", 245 | "sha256:fd225f95cb558fa2923b2bea4991f77e8dcd66eb9b544824a297b04a6a0d4425" 246 | ], 247 | "index": "pypi", 248 | "version": "==0.5.6" 249 | }, 250 | "urllib3": { 251 | "hashes": [ 252 | "sha256:b246607a25ac80bedac05c6f282e3cdaf3afb65420fd024ac94435cabe6e18d1", 253 | "sha256:dbe59173209418ae49d485b87d1681aefa36252ee85884c31346debd19463232" 254 | ], 255 | "markers": "python_version >= '3.4'", 256 | "version": "==1.25.3" 257 | } 258 | }, 259 | "develop": { 260 | "appnope": { 261 | "hashes": [ 262 | "sha256:5b26757dc6f79a3b7dc9fab95359328d5747fcb2409d331ea66d0272b90ab2a0", 263 | "sha256:8b995ffe925347a2138d7ac0fe77155e4311a0ea6d6da4f5128fe4b3cbe5ed71" 264 | ], 265 | "markers": "sys_platform == 'darwin'", 266 | "version": "==0.1.0" 267 | }, 268 | "atomicwrites": { 269 | "hashes": [ 270 | "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", 271 | "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6" 272 | ], 273 | "version": "==1.3.0" 274 | }, 275 | "attrs": { 276 | "hashes": [ 277 | "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79", 278 | "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399" 279 | ], 280 | "version": "==19.1.0" 281 | }, 282 | "backcall": { 283 | "hashes": [ 284 | "sha256:38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", 285 | "sha256:bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2" 286 | ], 287 | "version": "==0.1.0" 288 | }, 289 | "decorator": { 290 | "hashes": [ 291 | "sha256:86156361c50488b84a3f148056ea716ca587df2f0de1d34750d35c21312725de", 292 | "sha256:f069f3a01830ca754ba5258fde2278454a0b5b79e0d7f5c13b3b97e57d4acff6" 293 | ], 294 | "version": "==4.4.0" 295 | }, 296 | "entrypoints": { 297 | "hashes": [ 298 | "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", 299 | "sha256:c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451" 300 | 
], 301 | "version": "==0.3" 302 | }, 303 | "flake8": { 304 | "hashes": [ 305 | "sha256:19241c1cbc971b9962473e4438a2ca19749a7dd002dd1a946eaba171b4114548", 306 | "sha256:8e9dfa3cecb2400b3738a42c54c3043e821682b9c840b0448c0503f781130696" 307 | ], 308 | "version": "==3.7.8" 309 | }, 310 | "importlib-metadata": { 311 | "hashes": [ 312 | "sha256:23d3d873e008a513952355379d93cbcab874c58f4f034ff657c7a87422fa64e8", 313 | "sha256:80d2de76188eabfbfcf27e6a37342c2827801e59c4cc14b0371c56fed43820e3" 314 | ], 315 | "markers": "python_version < '3.8'", 316 | "version": "==0.19" 317 | }, 318 | "ipython": { 319 | "hashes": [ 320 | "sha256:1d3a1692921e932751bc1a1f7bb96dc38671eeefdc66ed33ee4cbc57e92a410e", 321 | "sha256:537cd0176ff6abd06ef3e23f2d0c4c2c8a4d9277b7451544c6cbf56d1c79a83d" 322 | ], 323 | "index": "pypi", 324 | "version": "==7.7.0" 325 | }, 326 | "ipython-genutils": { 327 | "hashes": [ 328 | "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", 329 | "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8" 330 | ], 331 | "version": "==0.2.0" 332 | }, 333 | "jedi": { 334 | "hashes": [ 335 | "sha256:786b6c3d80e2f06fd77162a07fed81b8baa22dde5d62896a790a331d6ac21a27", 336 | "sha256:ba859c74fa3c966a22f2aeebe1b74ee27e2a462f56d3f5f7ca4a59af61bfe42e" 337 | ], 338 | "version": "==0.15.1" 339 | }, 340 | "mccabe": { 341 | "hashes": [ 342 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 343 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 344 | ], 345 | "version": "==0.6.1" 346 | }, 347 | "more-itertools": { 348 | "hashes": [ 349 | "sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832", 350 | "sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4" 351 | ], 352 | "version": "==7.2.0" 353 | }, 354 | "packaging": { 355 | "hashes": [ 356 | "sha256:a7ac867b97fdc07ee80a8058fe4435ccd274ecc3b0ed61d852d7d53055528cf9", 357 | "sha256:c491ca87294da7cc01902edbe30a5bc6c4c28172b5138ab4e4aa1b9d7bfaeafe" 358 | ], 359 | "version": "==19.1" 360 | }, 361 | "parso": { 362 | "hashes": [ 363 | "sha256:63854233e1fadb5da97f2744b6b24346d2750b85965e7e399bec1620232797dc", 364 | "sha256:666b0ee4a7a1220f65d367617f2cd3ffddff3e205f3f16a0284df30e774c2a9c" 365 | ], 366 | "version": "==0.5.1" 367 | }, 368 | "pexpect": { 369 | "hashes": [ 370 | "sha256:2094eefdfcf37a1fdbfb9aa090862c1a4878e5c7e0e7e7088bdb511c558e5cd1", 371 | "sha256:9e2c1fd0e6ee3a49b28f95d4b33bc389c89b20af6a1255906e90ff1262ce62eb" 372 | ], 373 | "markers": "sys_platform != 'win32'", 374 | "version": "==4.7.0" 375 | }, 376 | "pickleshare": { 377 | "hashes": [ 378 | "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca", 379 | "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56" 380 | ], 381 | "version": "==0.7.5" 382 | }, 383 | "pluggy": { 384 | "hashes": [ 385 | "sha256:0825a152ac059776623854c1543d65a4ad408eb3d33ee114dff91e57ec6ae6fc", 386 | "sha256:b9817417e95936bf75d85d3f8767f7df6cdde751fc40aed3bb3074cbcb77757c" 387 | ], 388 | "version": "==0.12.0" 389 | }, 390 | "prompt-toolkit": { 391 | "hashes": [ 392 | "sha256:11adf3389a996a6d45cc277580d0d53e8a5afd281d0c9ec71b28e6f121463780", 393 | "sha256:2519ad1d8038fd5fc8e770362237ad0364d16a7650fb5724af6997ed5515e3c1", 394 | "sha256:977c6583ae813a37dc1c2e1b715892461fcbdaa57f6fc62f33a528c4886c8f55" 395 | ], 396 | "version": "==2.0.9" 397 | }, 398 | "ptyprocess": { 399 | "hashes": [ 400 | "sha256:923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0", 401 | 
"sha256:d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f" 402 | ], 403 | "version": "==0.6.0" 404 | }, 405 | "py": { 406 | "hashes": [ 407 | "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", 408 | "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" 409 | ], 410 | "version": "==1.8.0" 411 | }, 412 | "pycodestyle": { 413 | "hashes": [ 414 | "sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56", 415 | "sha256:e40a936c9a450ad81df37f549d676d127b1b66000a6c500caa2b085bc0ca976c" 416 | ], 417 | "version": "==2.5.0" 418 | }, 419 | "pyflakes": { 420 | "hashes": [ 421 | "sha256:17dbeb2e3f4d772725c777fabc446d5634d1038f234e77343108ce445ea69ce0", 422 | "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2" 423 | ], 424 | "version": "==2.1.1" 425 | }, 426 | "pygments": { 427 | "hashes": [ 428 | "sha256:71e430bc85c88a430f000ac1d9b331d2407f681d6f6aec95e8bcfbc3df5b0127", 429 | "sha256:881c4c157e45f30af185c1ffe8d549d48ac9127433f2c380c24b84572ad66297" 430 | ], 431 | "version": "==2.4.2" 432 | }, 433 | "pyparsing": { 434 | "hashes": [ 435 | "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", 436 | "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" 437 | ], 438 | "version": "==2.4.2" 439 | }, 440 | "pytest": { 441 | "hashes": [ 442 | "sha256:95b1f6db806e5b1b5b443efeb58984c24945508f93a866c1719e1a507a957d7c", 443 | "sha256:c3d5020755f70c82eceda3feaf556af9a341334414a8eca521a18f463bcead88" 444 | ], 445 | "version": "==5.1.1" 446 | }, 447 | "pytest-flake8": { 448 | "hashes": [ 449 | "sha256:4d225c13e787471502ff94409dcf6f7927049b2ec251c63b764a4b17447b60c0", 450 | "sha256:d7e2b6b274a255b7ae35e9224c85294b471a83b76ecb6bd53c337ae977a499af" 451 | ], 452 | "index": "pypi", 453 | "version": "==1.0.4" 454 | }, 455 | "six": { 456 | "hashes": [ 457 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", 458 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" 459 | ], 460 | "version": "==1.12.0" 461 | }, 462 | "traitlets": { 463 | "hashes": [ 464 | "sha256:9c4bd2d267b7153df9152698efb1050a5d84982d3384a37b2c1f7723ba3e7835", 465 | "sha256:c6cb5e6f57c5a9bdaa40fa71ce7b4af30298fbab9ece9815b5d995ab6217c7d9" 466 | ], 467 | "version": "==4.3.2" 468 | }, 469 | "wcwidth": { 470 | "hashes": [ 471 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 472 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 473 | ], 474 | "version": "==0.1.7" 475 | }, 476 | "zipp": { 477 | "hashes": [ 478 | "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", 479 | "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" 480 | ], 481 | "version": "==0.6.0" 482 | } 483 | } 484 | } 485 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | QuantGov 2 | ======== 3 | 4 | ========== ==================== 5 | Branch Build Status 6 | ========== ==================== 7 | **Master** |Master Test Status| 8 | **Dev** |Dev Test Status| 9 | ========== ==================== 10 | 11 | The QuantGov library is a companion to the `QuantGov Platform 12 | `_. It provides an easy way to start a new project 13 | using the ``quantgov start`` set of commands, and also provides a set of 14 | classes and functions often used in the QuantGov framework. 
15 | 
16 | To install the library, use ``pip install quantgov``.
17 | 
18 | Documentation is available at http://docs.quantgov.org.
19 | 
20 | .. |Master Test Status| image:: https://travis-ci.org/QuantGov/quantgov.svg?branch=master
21 |     :target: https://travis-ci.org/QuantGov/quantgov
22 | 
23 | .. |Dev Test Status| image:: https://travis-ci.org/QuantGov/quantgov.svg?branch=dev
24 |     :target: https://travis-ci.org/QuantGov/quantgov
25 | 
26 | 
27 | How to Contribute
28 | -----------------
29 | 
30 | #. Check for open issues or open a fresh issue to start a discussion around a feature idea or a bug.
31 | #. Fork `the repository`_ on GitHub to start making your changes to the **dev** branch (or branch off of it).
32 | #. Write a test which shows that the bug was fixed or that the feature works as expected.
33 | #. Send a pull request and bug the maintainer until it gets merged and published. Make sure to add yourself to AUTHORS_.
34 | 
35 | .. _`the repository`: http://github.com/quantgov/quantgov
36 | .. _AUTHORS: https://github.com/quantgov/quantgov/blob/master/AUTHORS.rst
37 | 
-------------------------------------------------------------------------------- /quantgov/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import (absolute_import, division, print_function,
2 |                         unicode_literals)
3 | 
4 | from . import corpus, nlp, ml, utils
5 | from .utils import load_driver
6 | 
7 | __version__ = '0.6.4'
8 | 
-------------------------------------------------------------------------------- /quantgov/__main__.py: --------------------------------------------------------------------------------
1 | """
2 | QuantGov: a policy analytics framework
3 | """
4 | 
5 | import argparse
6 | import csv
7 | import io
8 | import functools
9 | import logging
10 | import shutil
11 | import sys
12 | import zipfile
13 | 
14 | import requests
15 | 
16 | import joblib as jl
17 | import quantgov
18 | 
19 | from pathlib import Path
20 | 
21 | log = logging.getLogger(__name__)
22 | 
23 | 
24 | _URL = 'https://github.com/QuantGov/{component}/archive/{parent}.zip'
25 | ENCODE_OUT = 'utf-8'
26 | 
27 | 
28 | def parse_args():
29 |     parser = argparse.ArgumentParser(description=__doc__)
30 |     subparsers = parser.add_subparsers(dest='command')
31 | 
32 |     # Create command
33 |     create = subparsers.add_parser('start')
34 |     create.add_argument(
35 |         'component', choices=['corpus', 'estimator', 'project'])
36 |     create.add_argument('path', type=Path)
37 |     create.add_argument('--parent', default='master')
38 | 
39 |     # NLP command
40 |     nlp_subparser = subparsers.add_parser('nlp')
41 |     nlp_subcommands = nlp_subparser.add_subparsers(dest='subcommand')
42 |     for command, builtin in quantgov.nlp.commands.items():
43 |         subcommand = nlp_subcommands.add_parser(
44 |             command, help=builtin.cli.help)
45 |         subcommand.add_argument(
46 |             'corpus', help='Path to a QuantGov Corpus directory')
47 |         for argument in builtin.cli.arguments:
48 |             flags = ((argument.flags,) if isinstance(argument.flags, str)
49 |                      else argument.flags)
50 |             kwargs = {} if argument.kwargs is None else argument.kwargs
51 |             subcommand.add_argument(*flags, **kwargs)
52 |         subcommand.add_argument(
53 |             '-o', '--outfile',
54 |             type=lambda x: open(x, 'w', newline='', encoding=ENCODE_OUT),
55 |             default=sys.stdout
56 |         )
57 | 
58 |     # ML Command
59 |     ml_parser = subparsers.add_parser('ml')
60 |     ml_subcommands = ml_parser.add_subparsers(dest='subcommand')
61 | 
62 |     # ML Evaluate
63 |     evaluate = ml_subcommands.add_parser(
64 |         'evaluate', help='Evaluate candidate models')
65 |     evaluate.add_argument(
66 |         'modeldefs', type=Path,
67 |         help='python module containing candidate models'
68 |     )
69 |     evaluate.add_argument(
70 |         'trainers',
71 |         type=quantgov.ml.Trainers.load,
72 |         help='saved Trainers object'
73 |     )
74 |     evaluate.add_argument(
75 |         'labels', type=quantgov.ml.Labels.load, help='saved Labels object')
76 |     evaluate.add_argument(
77 |         'output_results',
78 |         type=lambda x: open(x, 'w', encoding=ENCODE_OUT),
79 |         help='Output file for evaluation results'
80 |     )
81 |     evaluate.add_argument(
82 |         'output_suggestion',
83 |         type=lambda x: open(x, 'w', encoding=ENCODE_OUT),
84 |         help='Output file for model suggestion'
85 |     )
86 |     evaluate.add_argument(
87 |         '--folds', type=int, default=5,
88 |         help='Number of folds for cross-validation')
89 |     evaluate.add_argument('--scoring', default='f1', help='scoring method')
90 | 
91 |     # ML Train
92 |     train = ml_subcommands.add_parser('train', help='Train a model')
93 |     train.add_argument(
94 |         'modeldefs', type=Path,
95 |         help='Python module containing candidate models'
96 |     )
97 |     train.add_argument('configfile', help='Model configuration file')
98 |     train.add_argument(
99 |         'vectorizer',
100 |         type=jl.load,
101 |         help='saved Vectorizer object'
102 |     )
103 |     train.add_argument(
104 |         'trainers',
105 |         type=quantgov.ml.Trainers.load,
106 |         help='saved Trainers object'
107 |     )
108 |     train.add_argument(
109 |         'labels', type=quantgov.ml.Labels.load, help='saved Labels object')
110 |     train.add_argument(
111 |         '-o', '--outfile', help='location to save the trained Estimator'
112 |     )
113 | 
114 |     # ML Estimate
115 |     estimate = ml_subcommands.add_parser(
116 |         'estimate', help='Estimate label values for a target corpus')
117 |     estimate.add_argument(
118 |         'estimator',
119 |         type=quantgov.ml.Estimator.load,
120 |         help='saved Estimator object'
121 |     )
122 |     estimate.add_argument(
123 |         'corpus', type=quantgov.load_driver,
124 |         help='Path to a QuantGov corpus')
125 |     estimate.add_argument(
126 |         '--probability', action='store_true',
127 |         help='output probabilities instead of predictions')
128 |     estimate.add_argument(
129 |         '--precision', default=4, type=int,
130 |         help='number of decimal places to round the probabilities')
131 |     estimate.add_argument(
132 |         '--oneclass', action='store_true',
133 |         help='only return predicted class for multiclass probability estimates')
134 |     estimate.add_argument(
135 |         '-o', '--outfile',
136 |         type=lambda x: open(x, 'w', newline='', encoding='utf-8'),
137 |         default=sys.stdout,
138 |         help='location to save estimation results'
139 |     )
140 | 
141 |     return parser.parse_args()
142 | 
143 | 
144 | def download(component, parent, outdir):
145 |     response = requests.get(
146 |         _URL.format(component=component, parent=parent),
147 |     )
148 |     archive = zipfile.ZipFile(io.BytesIO(response.content))
149 |     for name in archive.namelist():
150 |         if name.split('/', 1)[-1] == '':
151 |             continue
152 |         outfile = outdir.joinpath(name.split('/', 1)[-1])
153 |         if not outfile.parent.exists():
154 |             outfile.parent.mkdir(parents=True)
155 |         if name.endswith('/'):
156 |             outfile.mkdir()
157 |             continue
158 |         with outfile.open('wb') as outf, archive.open(name) as inf:
159 |             outf.write(inf.read())
160 | 
161 | 
162 | def start_component(args):
163 |     if args.path.exists():
164 |         log.error("A file or folder with that name already exists")
165 |         exit(1)
166 |     args.path.mkdir()
167 |     try:
168 |         download(args.component, args.parent, args.path)
169 |     except Exception:
170 |         shutil.rmtree(str(args.path))
171 |         raise
172 | 
173 | 174 | def run_corpus_builtin(args): 175 | driver = quantgov.load_driver(args.corpus) 176 | writer = csv.writer(args.outfile) 177 | builtin = quantgov.nlp.commands[args.subcommand] 178 | func_args = {i: j for i, j in vars(args).items() 179 | if i not in {'command', 'subcommand', 'outfile', 'corpus'}} 180 | writer.writerow(driver.index_labels + builtin.get_columns(func_args)) 181 | partial = functools.partial( 182 | builtin.process_document, 183 | **func_args 184 | ) 185 | for result in quantgov.utils.lazy_parallel(partial, driver.stream()): 186 | if result: 187 | writer.writerow(result) 188 | args.outfile.flush() 189 | 190 | 191 | def run_estimator(args): 192 | if args.subcommand == "evaluate": 193 | quantgov.ml.evaluate( 194 | args.modeldefs, args.trainers, args.labels, args.folds, 195 | args.scoring, args.output_results, args.output_suggestion 196 | ) 197 | elif args.subcommand == "train": 198 | quantgov.ml.train_and_save_model( 199 | args.modeldefs, args.configfile, args.vectorizer, args.trainers, 200 | args.labels, args.outfile) 201 | elif args.subcommand == "estimate": 202 | writer = csv.writer(args.outfile) 203 | labels = args.corpus.index_labels 204 | if args.probability: 205 | if args.estimator.multilabel: 206 | if args.estimator.multiclass: 207 | writer.writerow(labels + ('label', 'class', 'probability')) 208 | else: 209 | writer.writerow(labels + ('label', 'probability')) 210 | elif args.estimator.multiclass: 211 | writer.writerow(labels + ('class', 'probability')) 212 | else: 213 | writer.writerow( 214 | labels + ('{}_prob'.format(args.estimator.label_names[0]),) 215 | ) 216 | else: 217 | if args.estimator.multilabel: 218 | writer.writerow(labels + ('label', 'prediction')) 219 | else: 220 | writer.writerow( 221 | labels + ('{}'.format(args.estimator.label_names[0]),) 222 | ) 223 | writer.writerows( 224 | docidx + result for docidx, 225 | result in quantgov.ml.estimate( 226 | args.estimator, 227 | args.corpus, 228 | args.probability, 229 | args.precision, 230 | args.oneclass) 231 | ) 232 | 233 | 234 | def main(): 235 | args = parse_args() 236 | { 237 | 'start': start_component, 238 | 'nlp': run_corpus_builtin, 239 | 'ml': run_estimator, 240 | }[args.command](args) 241 | 242 | 243 | if __name__ == '__main__': 244 | main() 245 | -------------------------------------------------------------------------------- /quantgov/corpus.py: -------------------------------------------------------------------------------- 1 | """ 2 | quantgov.corpus 3 | 4 | Classes for Writing QuantGov Corpora 5 | """ 6 | 7 | import re 8 | import collections 9 | import csv 10 | import logging 11 | 12 | from decorator import decorator 13 | from collections import namedtuple 14 | from pathlib import Path 15 | 16 | from . 
import utils as qgutils
17 | 
18 | try:
19 |     import boto3
20 | except ImportError:
21 |     boto3 = None
22 | try:
23 |     import sqlalchemy
24 | except ImportError:
25 |     sqlalchemy = None
26 | 
27 | log = logging.getLogger(__name__)
28 | 
29 | Document = namedtuple('Document', ['index', 'text'])
30 | 
31 | 
32 | @decorator
33 | def check_boto(func, *args, **kwargs):
34 |     if boto3 is None:
35 |         raise RuntimeError('Must install boto3 to use {}'.format(func))
36 |     return func(*args, **kwargs)
37 | 
38 | 
39 | @decorator
40 | def check_sqlalchemy(func, *args, **kwargs):
41 |     if sqlalchemy is None:
42 |         raise RuntimeError('Must install sqlalchemy to use {}'.format(func))
43 |     return func(*args, **kwargs)
44 | 
45 | 
46 | class CorpusStreamer(object):
47 |     """
48 |     A knowledgeable wrapper for a CorpusDriver stream
49 |     """
50 | 
51 |     def __init__(self, iterable):
52 |         self.iterable = iterable
53 |         self.finished = False
54 |         self.index = []
55 | 
56 |     @property
57 |     def documents_streamed(self):
58 |         return len(self.index)
59 | 
60 |     def __iter__(self):
61 |         for document in self.iterable:
62 |             self.index.append(document.index)
63 |             yield document
64 |         self.finished = True
65 | 
66 | 
67 | class CorpusDriver(object):
68 |     """
69 |     A base class for Corpus Drivers
70 | 
71 |     This class defines the Corpus Driver interface
72 |     """
73 | 
74 |     def __init__(self, index_labels):
75 |         if isinstance(index_labels, str):
76 |             index_labels = (index_labels,)
77 |         else:
78 |             try:
79 |                 index_labels = tuple(index_labels)
80 |                 assert(all(isinstance(i, str) for i in index_labels))
81 |             except (ValueError, AssertionError):
82 |                 raise ValueError(
83 |                     "Index Labels must be a string or sequence of strings")
84 |         self.index_labels = index_labels
85 | 
86 |     def get_streamer(self, *args, **kwargs):
87 |         """
88 |         Return a CorpusStreamer object that wraps this corpus's stream method.
89 |         """
90 |         return CorpusStreamer(self.stream(*args, **kwargs))
91 | 
92 |     def stream(self):
93 |         """
94 |         Iterate over the corpus
95 | 
96 |         Return a generator of Document Objects.
97 |         """
98 |         raise NotImplementedError
99 | 
100 |     def validate_key(self, key):
101 |         if not isinstance(key, collections.abc.Sequence):
102 |             key = (key,)
103 |         if not len(key) == len(self.index_labels):
104 |             raise ValueError("Expected index value of length {}, got length {}"
105 |                              .format(len(self.index_labels), len(key)))
106 | 
107 |     def __getitem__(self, key):
108 |         self.validate_key(key)
109 |         for idx, text in self.stream():
110 |             if idx == key:
111 |                 return text
112 |         raise KeyError("Index value not found: {}".format(key))
113 | 
114 | 
115 | class FlatFileCorpusDriver(CorpusDriver):
116 |     """
117 |     Superclass for drivers that keep each document in a separate file.
118 |     """
119 | 
120 |     def __init__(self, index_labels, encoding="utf-8", cache=True):
121 |         super(FlatFileCorpusDriver, self).__init__(index_labels)
122 |         self.encoding = encoding
123 |         self.cache = cache
124 |         self._mapping = None
125 | 
126 |     @property
127 |     def mapping(self):
128 |         if self._mapping is None:
129 |             self._mapping = {
130 |                 idx: path for idx, path in self.gen_indices_and_paths()
131 |             }
132 |         return self._mapping
133 | 
134 |     def gen_indices_and_paths(self):
135 |         """
136 |         Return an iterator over the indices and paths of the corpus
137 |         """
138 |         raise NotImplementedError
139 | 
140 |     def read(self, docinfo):
141 |         """
142 |         Given an index and a path, return a Document
143 |         """
144 |         idx, path = docinfo
145 |         log.debug("Reading {}".format(path))
146 |         with path.open(encoding=self.encoding) as inf:
147 |             return Document(idx, inf.read())
148 | 
149 |     def __getitem__(self, key):
150 |         self.validate_key(key)
151 |         if self.cache:
152 |             path = self.mapping[key]
153 |         else:
154 |             for idx, path in self.gen_indices_and_paths():
155 |                 if idx == key:
156 |                     break
157 |             else:
158 |                 raise KeyError()
159 |         return self.read((key, path))
160 | 
161 |     def stream(self):
162 |         return qgutils.lazy_parallel(self.read, self.gen_indices_and_paths())
163 | 
164 | 
165 | class RecursiveDirectoryCorpusDriver(FlatFileCorpusDriver):
166 |     """Serve a corpus stored as a directory tree, one document per file;
167 |     subdirectory and file names supply the index values."""
168 | 
169 |     def __init__(self, directory, index_labels, encoding='utf-8', cache=True):
170 |         super(RecursiveDirectoryCorpusDriver, self).__init__(
171 |             index_labels, encoding, cache=cache)
172 |         self.directory = Path(directory).resolve()
173 |         self.encoding = encoding
174 | 
175 |     def _gen_docinfo(self, directory=None, level=0, restraint=None):
176 |         """
177 |         Recursively generates indices and paths.
178 |         """
179 | 
180 |         if restraint is None:
181 |             restraint = {}
182 | 
183 |         if directory is None:
184 |             directory = self.directory
185 | 
186 |         subpaths = sorted(i for i in directory.iterdir()
187 |                           if not i.name.startswith('.'))
188 | 
189 |         for subpath in subpaths:
190 |             if subpath.is_dir():
191 |                 if self.index_labels[level] in restraint.keys():
192 |                     if subpath.name in restraint[self.index_labels[level]]:
193 |                         for idx, path in self._gen_docinfo(
194 |                             subpath, level=level + 1, restraint=restraint
195 |                         ):
196 |                             yield (subpath.name,) + idx, path
197 |                 else:
198 |                     for idx, path in self._gen_docinfo(
199 |                             subpath, level=level + 1, restraint=restraint):
200 |                         yield (subpath.name,) + idx, path
201 |             else:
202 |                 yield (subpath.stem,), subpath
203 | 
204 |     def gen_indices_and_paths(self):
205 |         return self._gen_docinfo()
206 | 
207 |     def gen_indices_and_paths_restrained(self, restraint):
208 |         return self._gen_docinfo(restraint=restraint)
209 | 
210 |     def extract(self, restraint):
211 |         """
212 |         Allows specification of index values to restrict corpus. 'restraint'
213 |         must be a dictionary of index names and tuples of allowable index
214 |         values, i.e. {'index_name':('restraint_value',)}.
215 |         """
216 |         return qgutils.lazy_parallel(
217 |             self.read,
218 |             self.gen_indices_and_paths_restrained(restraint=restraint)
219 |         )
220 | 
221 | 
222 | class NamePatternCorpusDriver(FlatFileCorpusDriver):
223 |     """
224 |     Serve a corpus with all files in a single directory and filenames defined
225 |     by a regular expression.
226 | 
227 |     The index labels are the group names contained in the regular expression,
228 |     in the order in which they appear.
229 |     """
230 | 
231 |     def __init__(self, pattern, directory, encoding='utf-8', cache=True):
232 |         self.pattern = re.compile(pattern)
233 |         index_labels = (
234 |             i[0] for i in
235 |             sorted(self.pattern.groupindex.items(), key=lambda x: x[1])
236 |         )
237 |         super(NamePatternCorpusDriver, self).__init__(
238 |             index_labels=index_labels, encoding=encoding, cache=cache)
239 |         self.directory = Path(directory)
240 | 
241 |     def gen_indices_and_paths(self):
242 |         subpaths = sorted(i for i in self.directory.iterdir()
243 |                           if not i.name.startswith('.'))
244 |         for subpath in subpaths:
245 |             match = self.pattern.search(subpath.stem)
246 |             index = tuple(match.groupdict()[i] for i in self.index_labels)
247 |             yield index, subpath
248 | 
249 | 
250 | class IndexDriver(FlatFileCorpusDriver):
251 |     """
252 |     Serve a corpus using an index csv where the final column is the path to the
253 |     file and the other columns form the index. Index label names are taken from
254 |     the csv header.
255 |     """
256 | 
257 |     def __init__(self, index, encoding='utf-8', cache=True):
258 |         self.index = Path(index)
259 |         with self.index.open(encoding=encoding) as inf:
260 |             index_labels = next(csv.reader(inf))[:-1]
261 |         super(IndexDriver, self).__init__(
262 |             index_labels=index_labels, encoding=encoding, cache=cache)
263 | 
264 |     def gen_indices_and_paths(self):
265 |         with self.index.open(encoding=self.encoding) as inf:
266 |             reader = csv.reader(inf)
267 |             next(reader)
268 |             for row in reader:
269 |                 yield tuple(row[:-1]), Path(row[-1])
270 | 
271 | 
272 | class S3Driver(IndexDriver):
273 |     """
274 |     Serve a whole or partial corpus from a remote file location in s3.
275 |     Filtering can be done using the values provided in the index file.
276 |     """
277 | 
278 |     @check_boto
279 |     def __init__(self, index, bucket, encoding='utf-8', cache=True):
280 |         self.index = Path(index)
281 |         self.bucket = bucket
282 |         self.client = boto3.client('s3')
283 |         self.encoding = encoding
284 |         with self.index.open(encoding=encoding) as inf:
285 |             index_labels = next(csv.reader(inf))[:-1]
286 |         super(IndexDriver, self).__init__(
287 |             index_labels=index_labels, encoding=encoding, cache=cache)
288 | 
289 |     def read(self, docinfo):
290 |         idx, path = docinfo
291 |         body = self.client.get_object(Bucket=self.bucket,
292 |                                       Key=str(path).replace('\\', '/'))['Body']
293 |         return Document(idx, body.read().decode(self.encoding))
294 | 
295 |     def filter(self, pattern):
296 |         """ Filter paths based on index values. """
297 |         raise NotImplementedError
298 | 
299 |     def stream(self):
300 |         """Yield a Document for each object listed in the index, read from S3."""
301 |         return qgutils.lazy_parallel(self.read, self.gen_indices_and_paths())
302 | 
303 | 
304 | class S3DatabaseDriver(S3Driver):
305 |     """
306 |     Retrieves an index table from a database with an arbitrary, user-provided
307 |     query and serves documents like a normal S3Driver.
308 |     """
309 | 
310 |     @check_boto
311 |     @check_sqlalchemy
312 |     def __init__(self, protocol, user, password, host, db, port, query,
313 |                  bucket, cache=True, encoding='utf-8'):
314 |         self.bucket = bucket
315 |         self.client = boto3.client('s3')
316 |         self.index = []
317 |         engine = sqlalchemy.create_engine('{}://{}:{}@{}:{}/{}'
318 |                                           .format(protocol, user, password,
319 |                                                   host, port, db))
320 |         conn = engine.connect()
321 |         result = conn.execute(query)
322 |         for doc in result:
323 |             self.index.append(doc)
324 |             index_labels = list(doc.keys())[:-1]
325 |         super(IndexDriver, self).__init__(
326 |             index_labels=index_labels, encoding=encoding, cache=cache)
327 | 
328 |     def gen_indices_and_paths(self):
329 |         for row in self.index:
330 |             yield tuple(row[:-1]), row[-1]
331 | 
-------------------------------------------------------------------------------- /quantgov/ml/__init__.py: --------------------------------------------------------------------------------
1 | __all__ = [
2 |     'candidate_sets',
3 |     'estimation',
4 |     'evaluate',
5 |     'structures',
6 |     'training',
7 | ]
8 | 
9 | from .structures import (
10 |     Labels,
11 |     Trainers,
12 |     Estimator,
13 |     CandidateModel
14 | )
15 | 
16 | from .evaluation import evaluate
17 | from .training import train_and_save_model
18 | from .estimation import estimate
19 | 
-------------------------------------------------------------------------------- /quantgov/ml/candidate_sets.py: --------------------------------------------------------------------------------
1 | """
2 | quantgov.ml.candidate_sets: Starter model candidate sets
3 | 
4 | 
5 | This module provides a few sample sets of models for common problems. These are
6 | mostly helpful for initial analysis; in general, you will want to customize
7 | these.
8 | 
9 | The currently included candidate sets are:
10 | * `classification`: Random Forests and Logit with TF-IDF preprocessor
11 | * `multilabel_classification`: same as classification, with the Logit
12 |   classifier wrapped in a MultiOutputClassifier
13 | """
14 | import numpy as np
15 | import sklearn.ensemble
16 | import sklearn.linear_model
17 | import sklearn.multioutput
18 | import sklearn.pipeline
19 | import sklearn.feature_extraction
20 | 
21 | import quantgov.ml
22 | 
23 | classification = [
24 |     quantgov.ml.CandidateModel(
25 |         name="Random Forests",
26 |         model=sklearn.pipeline.Pipeline(steps=(
27 |             ('tfidf', sklearn.feature_extraction.text.TfidfTransformer()),
28 |             ('rf', sklearn.ensemble.RandomForestClassifier(n_jobs=-1)),
29 |         )),
30 |         parameters={
31 |             'rf__n_estimators': [5, 10, 25, 50, 100],
32 |         }
33 |     ),
34 |     quantgov.ml.CandidateModel(
35 |         name="Logistic Regression",
36 |         model=sklearn.pipeline.Pipeline(steps=(
37 |             ('tfidf', sklearn.feature_extraction.text.TfidfTransformer()),
38 |             ('logit', sklearn.linear_model.LogisticRegression()),
39 |         )),
40 |         parameters={
41 |             'logit__C': np.logspace(-2, 2, 5)
42 |         }
43 |     ),
44 | ]
45 | 
46 | 
47 | multilabel_classification = [
48 |     quantgov.ml.CandidateModel(
49 |         name="Random Forests",
50 |         model=sklearn.pipeline.Pipeline(steps=(
51 |             ('tfidf', sklearn.feature_extraction.text.TfidfTransformer()),
52 |             ('rf', sklearn.ensemble.RandomForestClassifier(n_jobs=-1)),
53 |         )),
54 |         parameters={
55 |             'rf__n_estimators': [5, 10, 25, 50, 100],
56 |         }
57 |     ),
58 |     quantgov.ml.CandidateModel(
59 |         name="Logistic Regression",
60 |         model=sklearn.pipeline.Pipeline(steps=(
61 |             ('tfidf', sklearn.feature_extraction.text.TfidfTransformer()),
62 |             ('logit', sklearn.multioutput.MultiOutputClassifier(
63 |                 sklearn.linear_model.LogisticRegression(),
64 |                 n_jobs=-1
65 |             )),
66 |         )),
67 |         parameters={
68 |             'logit__estimator__C': np.logspace(-2, 2, 5)
69 |         }
70 |     ),
71 | ]
72 | 
-------------------------------------------------------------------------------- /quantgov/ml/estimation.py: --------------------------------------------------------------------------------
1 | """
2 | quantgov.ml.estimation
3 | 
4 | Functionality for making predictions with an estimator
5 | """
6 | import logging
7 | import numpy as np
8 | 
9 | log = logging.getLogger(__name__)
10 | 
11 | 
12 | def estimate_simple(estimator, streamer):
13 |     """
14 |     Generate predictions for a one-label estimator
15 | 
16 |     Arguments:
17 |         * estimator: a quantgov.ml.Estimator
18 |         * streamer: a quantgov.corpus.CorpusStreamer
19 | 
20 |     Yields:
21 |         2-tuples of docindex, (prediction,)
22 | 
23 |     """
24 |     texts = (doc.text for doc in streamer)
25 |     predicted = estimator.pipeline.predict(texts)
26 |     for docidx, prediction in zip(streamer.index, predicted):
27 |         yield docidx, (prediction,)
28 | 
29 | 
30 | def estimate_multilabel(estimator, streamer):
31 |     """
32 |     Generate predictions for a multi-label estimator
33 | 
34 |     Arguments:
35 |         * estimator: a quantgov.ml.Estimator
36 |         * streamer: a quantgov.corpus.CorpusStreamer
37 | 
38 |     Yields:
39 |         2-tuples of docindex, (label, prediction,)
40 | 
41 |     """
42 |     for docidx, (prediction,) in estimate_simple(estimator, streamer):
43 |         for label, label_prediction in zip(estimator.label_names, prediction):
44 |             yield docidx, (label, label_prediction)
45 | 
46 | 
47 | def estimate_probability(estimator, streamer, precision):
48 |     """
49 |     Generate probabilities for a one-label estimator
50 | 
51 |     Arguments:
52 |         * estimator: a quantgov.ml.Estimator
53 |         * streamer: a quantgov.corpus.CorpusStreamer
54 | 
55 |     Yields:
56 |         2-tuples of docindex, (probability,)
57 | 
58 |     """
59 |     texts = (doc.text for doc in streamer)
60 |     truecol = list(int(i) for i in estimator.pipeline.classes_).index(1)
61 |     predicted = (
62 |         estimator.pipeline.predict_proba(texts)[:, truecol].round(precision))
63 |     yield from zip(streamer.index, ((prob,) for prob in predicted))
64 | 
65 | 
66 | def estimate_probability_multilabel(estimator, streamer, precision):
67 |     """
68 |     Generate probabilities for a multilabel binary estimator
69 | 
70 |     Arguments:
71 |         * estimator: a quantgov.ml.Estimator
72 |         * streamer: a quantgov.corpus.CorpusStreamer
73 | 
74 |     Yields:
75 |         2-tuples of docindex, (label, probability)
76 | 
77 |     """
78 |     texts = (doc.text for doc in streamer)
79 |     model = estimator.pipeline.steps[-1][1]
80 |     try:
81 |         truecols = tuple(
82 |             list(int(i) for i in label_classes).index(1)
83 |             for label_classes in model.classes_
84 |         )
85 |     except (AttributeError, TypeError):
86 |         truecols = tuple(
87 |             list(int(i) for i in label_classes).index(1)
88 |             for label_classes in (
89 |                 est.classes_ for est in model.steps[-1][1].estimators_
90 |             )
91 |         )
92 |     predicted = estimator.pipeline.predict_proba(texts).round(int(precision))
93 | 
94 |     try:
95 |         yield from (
96 |             (docidx, (label, label_prediction[truecol]))
97 |             for docidx, doc_predictions in zip(streamer.index, predicted)
98 |             for label, label_prediction, truecol
99 |             in zip(estimator.label_names, doc_predictions, truecols)
100 |         )
101 |     except IndexError:
102 |         yield from (
103 |             (docidx, (label, label_prediction))
104 |             for docidx, doc_predictions in zip(streamer.index, predicted)
105 |             for (label, label_prediction)
106 |             in zip(estimator.label_names, doc_predictions)
107 |         )
108 | 
109 | 
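# Usage sketch for the generators above, written as comments so the module is
# unchanged (the estimator file and corpus path here are hypothetical, and
# assume a saved one-label binary Estimator):
#
#     import quantgov
#     estimator = quantgov.ml.Estimator.load('estimator.qge')
#     streamer = quantgov.load_driver('path/to/corpus').get_streamer()
#     for docidx, (prob,) in estimate_probability(estimator, streamer,
#                                                 precision=4):
#         print(docidx, prob)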
110 | def estimate_probability_multiclass(estimator, streamer, precision, oneclass):
111 |     """
112 |     Generate probabilities for a one-label, multiclass estimator
113 | 
114 |     Arguments:
115 |         * estimator: a quantgov.ml.Estimator
116 |         * streamer: a quantgov.corpus.CorpusStreamer
117 | 
118 |     Yields:
119 |         2-tuples of docindex, (class, probability)
120 | 
121 |     """
122 |     texts = (doc.text for doc in streamer)
123 |     probs = estimator.pipeline.predict_proba(texts)
124 |     # If oneclass flag is true, only returns the predicted class
125 |     if oneclass:
126 |         class_indices = list(i[-1] for i in np.argsort(probs, axis=1))
127 |         yield from (
128 |             (docidx, (estimator.pipeline.classes_[class_index],
129 |                       doc_probs[class_index].round(precision)))
130 |             for docidx, doc_probs, class_index in zip(
131 |                 streamer.index, probs, class_indices)
132 |         )
133 |     # Else returns probability values for all classes
134 |     else:
135 |         yield from (
136 |             (docidx, (class_, probability.round(precision)))
137 |             for docidx, doc_probs in zip(streamer.index, probs)
138 |             for class_, probability in zip(
139 |                 estimator.pipeline.classes_, doc_probs)
140 |         )
141 | 
142 | 
143 | def estimate_probability_multilabel_multiclass(estimator, streamer, precision):
144 |     """
145 |     Generate probabilities for a multilabel, multiclass estimator
146 | 
147 |     Arguments:
148 |         * estimator: a quantgov.ml.Estimator
149 |         * streamer: a quantgov.corpus.CorpusStreamer
150 | 
151 |     Yields:
152 |         2-tuples of docindex, (label, class, probability)
153 | 
154 |     """
155 |     texts = (doc.text for doc in streamer)
156 |     probs = estimator.pipeline.predict_proba(texts).round(precision)
157 |     yield from (
158 |         (docidx, (label_name, class_, prob))
159 |         for label_name, label_probs in zip(estimator.label_names, probs)
160 |         for docidx, doc_probs in zip(streamer.index, label_probs)
161 |         for class_, prob in zip(estimator.pipeline.classes_, doc_probs)
162 |     )
163 | 
164 | 
165 | def estimate(estimator, corpus, probability, precision=4, oneclass=False,
166 |              *args, **kwargs):
167 |     """
168 |     Estimate label values for documents in corpus
169 | 
170 |     Arguments:
171 | 
172 |     * **estimator**: a trained `quantgov.ml.Estimator` object
173 |     * **corpus**: a QuantGov corpus driver
174 |     * **probability**: if True, predict probability
175 |     * **precision**: precision for probability prediction
176 |     """
177 |     streamer = corpus.get_streamer(*args, **kwargs)
178 |     if probability:
179 |         if estimator.multilabel:
180 |             if estimator.multiclass:  # Multilabel-multiclass probability
181 |                 yield from estimate_probability_multilabel_multiclass(
182 |                     estimator, streamer, precision)
183 |             else:  # Multilabel probability
184 |                 yield from estimate_probability_multilabel(
185 |                     estimator, streamer, precision)
186 |         elif estimator.multiclass:  # Multiclass probability
187 |             yield from estimate_probability_multiclass(
188 |                 estimator, streamer, precision, oneclass)
189 |         else:  # Simple probability
190 |             yield from estimate_probability(
191 |                 estimator, streamer, precision)
192 |     elif estimator.multilabel:  # Multilabel Prediction
193 |         yield from estimate_multilabel(estimator, streamer)
194 |     else:  # Binary and Multiclass
195 |         yield from estimate_simple(estimator, streamer)
196 | 
-------------------------------------------------------------------------------- /quantgov/ml/evaluation.py: --------------------------------------------------------------------------------
1 | import configparser
2 | import logging
3 | 
4 | import pandas as pd
5 | 
6 | try:
7 |     from sklearn.model_selection import KFold, GridSearchCV
8 | except ImportError:  # sklearn 0.17
9 |     from sklearn.cross_validation import KFold
10 |     from sklearn.grid_search import GridSearchCV
11 | 
12 | from . import utils as eutils
13 | 
14 | log = logging.getLogger(name=__name__)
15 | 
16 | 
17 | def evaluate_model(model, X, y, folds, scoring):
18 |     """
19 |     Evaluate a single model
20 | 
21 |     Arguments:
22 |         * model: a quantgov.ml.CandidateModel
23 |         * X: array-like of document vectors with shape [n_samples x n_features]
24 |         * y: array-like of labels with shape [n_samples X n_labels]
25 |         * folds: folds to use in cross-validation
26 |         * scoring: scoring method
27 | 
28 | 
29 |     Returns: pandas DataFrame with model evaluation results
30 |     """
31 |     log.info('Evaluating {}'.format(model.name))
32 |     cv = KFold(folds, shuffle=True)
33 |     if hasattr(y[0], '__getitem__'):
34 |         # Multilabel targets need an explicit averaging method for scoring
35 |         if '_' not in scoring:
36 |             log.warning("No averaging method specified, assuming macro")
37 |             scoring += '_macro'
38 | 
39 |     gs = GridSearchCV(
40 |         estimator=model.model,
41 |         param_grid=model.parameters,
42 |         cv=cv,
43 |         scoring=scoring,
44 |         verbose=100,
45 |         refit=False
46 |     )
47 |     gs.fit(X, y)
48 |     return pd.DataFrame(gs.cv_results_).assign(model=model.name)
49 | 
50 | 
51 | def evaluate_all_models(models, X, y, folds, scoring):
52 |     """
53 |     Evaluate a number of models
54 | 
55 |     Arguments:
56 |         * models: a sequence of quantgov.ml.CandidateModel objects
57 |         * X: array-like of document vectors with shape [n_samples x n_features]
58 |         * y: array-like of labels with shape [n_samples X n_labels]
59 |         * folds: folds to use in cross-validation
60 |         * scoring: scoring method
61 | 
62 |     Returns: pandas DataFrame with model evaluation results
63 |     """
64 |     results = pd.concat(
65 |         [evaluate_model(model, X, y, folds, scoring) for model in models],
66 |         ignore_index=True
67 |     )
68 |     results = results[
69 |         ['model', 'mean_test_score', 'std_test_score',
70 |          'mean_fit_time', 'std_fit_time',
71 |          'mean_score_time', 'std_score_time']
72 |         + sorted(i for i in results if i.startswith('param_'))
73 |         + sorted(i for i in results
74 |                  if i.startswith('split')
75 |                  and '_train_' not in i
76 |                  )
77 |         + ['params']
78 |     ]
79 |     return results
80 | 
81 | 
82 | def write_suggestion(results, file):
83 |     """
84 |     Given results, write the best performer to a config file.
85 | 
86 |     Arguments:
87 | 
88 |     * **results**: a DataFrame as returned by `evaluate_all_models`
89 |     * **file**: an open file-like object
90 |     """
91 |     best_model = results.loc[results['mean_test_score'].idxmax()]
92 |     config = configparser.ConfigParser()
93 |     config.optionxform = str
94 |     config['Model'] = {'name': best_model['model']}
95 |     config['Parameters'] = {i: j for i, j in best_model['params'].items()}
96 |     config.write(file)
97 | 
98 | 
99 | def evaluate(modeldefs, trainers, labels, folds, scoring, results_file,
100 |              suggestion_file):
101 |     """
102 |     Evaluate Candidate Models and write out a suggestion
103 | 
104 |     Arguments:
105 | 
106 |     * **modeldefs**: Path to a python module containing a list of
107 |       `quantgov.ml.CandidateModel` objects in a module-level
108 |       variable named `models`.
--------------------------------------------------------------------------------
/quantgov/ml/structures.py:
--------------------------------------------------------------------------------
1 | """
2 | quantgov.ml.structures
3 | 
4 | Useful structures for evaluating and training estimators
5 | """
6 | import collections
7 | import joblib as jl
8 | 
9 | 
10 | class _PersistenceMixin(object):
11 |     """
12 |     A mixin that uses joblib to add `.load` and `.save` methods to any
13 |     class it is mixed into
14 |     """
15 | 
16 |     @classmethod
17 |     def load(cls, path):
18 |         """
19 |         Load a saved object at path `path`
20 |         """
21 |         loaded = jl.load(path)
22 |         if not isinstance(loaded, cls):
23 |             raise ValueError(
24 |                 'Expected saved type {}, path {} contained saved type {}'
25 |                 .format(cls, path, type(loaded))
26 |             )
27 |         return loaded
28 | 
29 |     def save(self, path):
30 |         """
31 |         Use joblib to save the object.
32 | 
33 |         Arguments:
34 |             path: an open file object or string holding the path to where
35 |                 the object should be saved
36 |         """
37 |         jl.dump(self, path, compress=True)
38 | 
39 | 
40 | class Labels(
41 |         collections.namedtuple('Labels', ['index', 'label_names', 'labels']),
42 |         _PersistenceMixin
43 | ):
44 |     """
45 |     A set of labels for training a model.
46 | 
47 |     Arguments:
48 |     * index: a sequence holding the index values for each document being
49 |       labeled
50 |     * label_names: a sequence holding one name for each label
51 |     * labels: an array-like of label values with
52 |       shape [n_samples x n_labels]
53 |     """
54 |     pass
55 | 
56 | 
57 | class Trainers(
58 |         collections.namedtuple('Trainers', ['index', 'vectors']),
59 |         _PersistenceMixin
60 | ):
61 |     """
62 |     A set of vectorized documents for training a model
63 | 
64 |     Arguments:
65 |     * index: a sequence holding the index values for each document
66 |       represented
67 |     * vectors: array-like of document vectors [n_samples x n_features]
68 |     """
69 |     pass
70 | 
71 | 
72 | def is_multiclass(classes):
73 |     """
74 |     Returns True if values in classes are anything but 1, 0, True, or False,
75 |     otherwise returns False.
76 |     """
77 |     try:
78 |         return len(set(int(i) for i in classes) - {0, 1}) != 0
79 |     except ValueError:
80 |         return True
81 | 
82 | 
83 | class Estimator(
84 |         collections.namedtuple('Estimator', ['label_names', 'pipeline']),
85 |         _PersistenceMixin
86 | ):
87 |     """
88 |     A trained estimator
89 | 
90 |     Arguments:
91 |     * label_names: sequence of names for each label the model estimates
92 |     * pipeline: a trained sklearn-like pipeline, implementing `.fit`,
93 |       `.fit_transform`, and `.predict` methods, where the X inputs are a
94 |       sequence of strings.
95 |     """
96 | 
97 |     def __init__(self, *args, **kwargs):
98 |         super().__init__()
99 |         self.multilabel = len(self.label_names) > 1
100 |         model = self.pipeline.steps[-1][1]
101 |         if self.multilabel:
102 |             try:
103 |                 self.multiclass = any(is_multiclass(i) for i in model.classes_)
104 |             except (AttributeError, TypeError):
105 |                 self.multiclass = any(
106 |                     is_multiclass(i.classes_)
107 |                     for i in model.steps[-1][-1].estimators_
108 |                 )
109 |         else:
110 |             self.multiclass = is_multiclass(model.classes_)
111 | 
112 | 
113 | class CandidateModel(
114 |         collections.namedtuple('CandidateModel', ['name', 'model', 'parameters'])
115 | ):
116 |     """
117 |     A candidate model for testing
118 | 
119 |     Arguments:
120 |     * name: an identifier for this model, unique among candidates
121 |     * model: an untrained sklearn-like model, implementing `.fit`,
122 |       `.fit_transform`, and `.predict` methods
123 |     * parameters: a dictionary with parameter names as keys and possible
124 |       parameter values to test as values
125 |     """
126 |     pass
127 | 
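128 | # Sketch of the persistence round-trip these structures share (the file
129 | # name 'labels.qgl' is a placeholder, not a convention of the package):
130 | #
131 | #     labels = Labels(index=[('doc1',), ('doc2',)],
132 | #                     label_names=['is_world'],
133 | #                     labels=[[True], [False]])
134 | #     labels.save('labels.qgl')
135 | #     restored = Labels.load('labels.qgl')
136 | #     assert restored == labels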
95 | """ 96 | 97 | def __init__(self, *args, **kwargs): 98 | super().__init__() 99 | self.multilabel = len(self.label_names) > 1 100 | model = self.pipeline.steps[-1][1] 101 | if self.multilabel: 102 | try: 103 | self.multiclass = any(is_multiclass(i) for i in model.classes_) 104 | except (AttributeError, TypeError): 105 | self.multiclass = any( 106 | is_multiclass(i.classes_) 107 | for i in model.steps[-1][-1].estimators_ 108 | ) 109 | else: 110 | self.multiclass = is_multiclass(model.classes_) 111 | 112 | 113 | class CandidateModel( 114 | collections.namedtuple('CandidateModel', ['name', 'model', 'parameters']) 115 | ): 116 | """ 117 | A Candidate Model for testing 118 | 119 | Arguments: 120 | * name: an identifier for this model, unique among candidates 121 | * model: a trained sklearn-like model, implementing `.fit`, 122 | `.fit_transform`, and `.predict` methods 123 | * parameters: a dictionary with parameters names as keys and possible 124 | parameter values to test as values 125 | """ 126 | pass 127 | -------------------------------------------------------------------------------- /quantgov/ml/training.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | 3 | import sklearn.pipeline 4 | 5 | import quantgov.ml 6 | 7 | 8 | def _autoconvert(value): 9 | """Convert to int or float if possible, otherwise return string""" 10 | try: 11 | return int(value) 12 | except ValueError: 13 | pass 14 | try: 15 | return float(value) 16 | except ValueError: 17 | return value 18 | 19 | 20 | def get_model(modeldefs, configfile): 21 | """ 22 | Parse config file and configure relevant model 23 | """ 24 | config = configparser.ConfigParser() 25 | config.optionxform = str 26 | config.read(configfile) 27 | models = {i.name: i for i in 28 | quantgov.ml.utils.load_models(modeldefs)} 29 | model = models[config['Model']['name']].model 30 | model.set_params( 31 | **{i: _autoconvert(j) for i, j in config['Parameters'].items()}) 32 | return model 33 | 34 | 35 | def train_and_save_model( 36 | modeldefs, 37 | configfile, 38 | vectorizer, 39 | trainers, 40 | labels, 41 | outfile): 42 | """ 43 | Train and save model described in config file 44 | 45 | Arguments: 46 | 47 | * **modeldefs**: Path to a python module containing a list of 48 | `quantgov.ml.CandidateModel` objects in a module-level 49 | variable named `models'. 
--------------------------------------------------------------------------------
/quantgov/ml/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 | 
4 | from pathlib import Path
5 | 
6 | 
7 | def load_models(path):
8 |     """
9 |     Load models list from path
10 | 
11 |     Arguments:
12 | 
13 |     * **path**: path to a python module containing a list of
14 |       `quantgov.ml.CandidateModel` objects in a module-level
15 |       variable named `models`
16 |     """
17 |     path = Path(path).resolve()
18 |     if ' ' in path.stem:
19 |         raise ValueError("models file name must contain no spaces")
20 |     sys.path.insert(0, str(path.parent))
21 |     module = importlib.import_module(path.stem)
22 |     models = module.models
23 |     sys.path.pop(0)
24 |     return models
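25 | # Usage sketch: pointed at a modeldefs file like the models.py sketched
26 | # in evaluation.py above (the path is a placeholder):
27 | #
28 | #     candidates = load_models('models.py')
29 | #     print([c.name for c in candidates])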
--------------------------------------------------------------------------------
/quantgov/nlp.py:
--------------------------------------------------------------------------------
1 | """
2 | quantgov.nlp: Text-based analysis of documents
3 | """
4 | import collections
5 | import math
6 | import re
7 | 
8 | from decorator import decorator
9 | 
10 | from . import utils
11 | 
12 | try:
13 |     import nltk.corpus
14 |     NLTK = True
15 | except ImportError:
16 |     NLTK = None
17 | 
18 | try:
19 |     import textblob
20 | except ImportError:
21 |     textblob = None
22 | 
23 | try:
24 |     import textstat
25 | except ImportError:
26 |     textstat = None
27 | 
28 | if NLTK:
29 |     try:
30 |         nltk.corpus.wordnet.ensure_loaded()
31 |     except LookupError:
32 |         nltk.download('wordnet')
33 |         nltk.corpus.wordnet.ensure_loaded()
34 | 
35 | commands = {}
36 | 
37 | 
38 | @decorator
39 | def check_nltk(func, *args, **kwargs):
40 |     if NLTK is None:
41 |         raise RuntimeError('Must install NLTK to use {}'.format(func))
42 |     return func(*args, **kwargs)
43 | 
44 | 
45 | @decorator
46 | def check_textblob(func, *args, **kwargs):
47 |     if textblob is None:
48 |         raise RuntimeError('Must install textblob to use {}'.format(func))
49 |     return func(*args, **kwargs)
50 | 
51 | 
52 | @decorator
53 | def check_textstat(func, *args, **kwargs):
54 |     if textstat is None:
55 |         raise RuntimeError('Must install textstat to use {}'.format(func))
56 |     return func(*args, **kwargs)
57 | 
58 | 
59 | class WordCounter():
60 | 
61 |     cli = utils.CLISpec(
62 |         help='Word Counter',
63 |         arguments=[
64 |             utils.CLIArg(
65 |                 flags=('--word_pattern', '-wp'),
66 |                 kwargs={
67 |                     'help': 'regular expression defining a "word"',
68 |                     'type': re.compile,
69 |                     'default': re.compile(r'\b\w+\b')
70 |                 }
71 |             )
72 |         ]
73 |     )
74 | 
75 |     @staticmethod
76 |     def get_columns(args):
77 |         return ('words',)
78 | 
79 |     @staticmethod
80 |     def process_document(doc, word_pattern):
81 |         return doc.index + (len(word_pattern.findall(doc.text)),)
82 | 
83 | 
84 | commands['count_words'] = WordCounter
85 | 
86 | 
87 | class OccurrenceCounter():
88 | 
89 |     cli = utils.CLISpec(
90 |         help="Term Counter for Specific Words",
91 |         arguments=[
92 |             utils.CLIArg(
93 |                 flags=('terms',),
94 |                 kwargs={
95 |                     'help': 'list of terms to be counted',
96 |                     'nargs': '+'
97 |                 }
98 |             ),
99 |             utils.CLIArg(
100 |                 flags=('--total_label',),
101 |                 kwargs={
102 |                     'metavar': 'LABEL',
103 |                     'help': (
104 |                         'output a column with sum of occurrences of all terms'
105 |                         ' with column name LABEL'
106 |                     ),
107 |                 }
108 |             ),
109 |             utils.CLIArg(
110 |                 flags=('--pattern',),
111 |                 kwargs={
112 |                     'help': 'pattern to use in identifying words',
113 |                     'default': r'\b(?P<match>{})\b'
114 |                 }
115 |             )
116 |         ]
117 |     )
118 | 
119 |     @staticmethod
120 |     def get_columns(args):
121 |         if args['total_label'] is not None:
122 |             return tuple(args['terms']) + (args['total_label'],)
123 |         return tuple(args['terms'])
124 | 
125 |     @staticmethod
126 |     def process_document(doc, terms, pattern, total_label):
127 |         text = ' '.join(doc.text.split()).lower()
128 |         terms_sorted = sorted(terms, key=len, reverse=True)
129 |         combined_pattern = re.compile(pattern.format('|'.join(terms_sorted)))
130 |         term_counts = collections.Counter(
131 |             i.groupdict()['match'] for i in combined_pattern.finditer(text)
132 |         )
133 |         if total_label is not None:
134 |             return (
135 |                 doc.index
136 |                 + tuple(term_counts[i] for i in terms)
137 |                 + (sum(term_counts.values()),)
138 |             )
139 |         return (doc.index + tuple(term_counts[i] for i in terms))
140 | 
141 | 
142 | commands['count_occurrences'] = OccurrenceCounter
143 | 
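144 | # Sketch of the combined pattern process_document builds for terms
145 | # ('shall', 'must') under the default --pattern (values illustrative):
146 | #
147 | #     combined = re.compile(r'\b(?P<match>shall|must)\b')
148 | #     counts = collections.Counter(
149 | #         m.groupdict()['match']
150 | #         for m in combined.finditer('the permittee shall file and must sign')
151 | #     )
152 | #     # counts == Counter({'shall': 1, 'must': 1})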
153 | 
154 | 
155 | class ShannonEntropy():
156 |     lemmas = {}
157 |     cli = utils.CLISpec(
158 |         help='Shannon Entropy',
159 |         arguments=[
160 |             utils.CLIArg(
161 |                 flags=('--word_pattern', '-wp'),
162 |                 kwargs={
163 |                     'help': 'regular expression defining a "word"',
164 |                     'type': re.compile,
165 |                     'default': re.compile(r'\b\w+\b')
166 |                 }
167 |             ),
168 |             utils.CLIArg(
169 |                 flags=('--stopwords', '-sw'),
170 |                 kwargs={
171 |                     'help': 'stopwords to ignore',
172 |                     'default': (
173 |                         None if not NLTK else
174 |                         nltk.corpus.stopwords.words('english')
175 |                     )
176 |                 }
177 |             ),
178 |             utils.CLIArg(
179 |                 flags=('--precision',),
180 |                 kwargs={
181 |                     'help': 'decimal places to round',
182 |                     'default': 2
183 |                 }
184 |             )
185 |         ]
186 |     )
187 | 
188 |     @staticmethod
189 |     def get_columns(args):
190 |         return ('shannon_entropy',)
191 | 
192 |     @staticmethod
193 |     @check_nltk
194 |     @check_textblob
195 |     def process_document(doc, word_pattern, precision, stopwords,
196 |                          textblob=textblob, nltk=NLTK):
197 |         words = word_pattern.findall(doc.text)
198 |         lemmas = [
199 |             lemma for lemma in (
200 |                 ShannonEntropy.lemmatize(word) for word in words
201 |             )
202 |             if lemma not in stopwords
203 |         ]
204 |         counts = collections.Counter(lemmas)
205 |         return doc.index + (round(sum(
206 |             -(count / len(lemmas) * math.log(count / len(lemmas), 2))
207 |             for count in counts.values()
208 |         ), int(precision)),)
209 | 
210 |     @staticmethod
211 |     def lemmatize(word):
212 |         # Cache lemmas so repeated words are only lemmatized once
213 |         if word in ShannonEntropy.lemmas:
214 |             lemma = ShannonEntropy.lemmas[word]
215 |         else:
216 |             lemma = textblob.Word(word).lemmatize()
217 |             ShannonEntropy.lemmas[word] = lemma
218 |         return lemma
219 | 
220 | 
221 | commands['shannon_entropy'] = ShannonEntropy
222 | 
223 | 
224 | class ConditionalCounter():
225 |     cli = utils.CLISpec(
226 |         help=('Count conditional words and phrases. Included terms are:'
227 |               ' "if", "but", "except", "provided", "when", "where",'
228 |               ' "whenever", "unless", "notwithstanding", "in the event",'
229 |               ' and "in no event"'),
230 |         arguments=[]
231 |     )
232 |     pattern = re.compile(
233 |         r'\b(if|but|except|provided|when|where'
234 |         r'|whenever|unless|notwithstanding'
235 |         r'|in\s+the\s+event|in\s+no\s+event)\b'
236 |     )
237 | 
238 |     @staticmethod
239 |     def get_columns(args):
240 |         return ('conditionals',)
241 | 
242 |     @staticmethod
243 |     def process_document(doc):
244 |         return doc.index + (len(ConditionalCounter.pattern.findall(
245 |             ' '.join((doc.text).splitlines()))),)
246 | 
247 | 
248 | commands['count_conditionals'] = ConditionalCounter
249 | 
250 | 
251 | class SentenceLength():
252 | 
253 |     cli = utils.CLISpec(
254 |         help='Sentence Length',
255 |         arguments=[
256 |             utils.CLIArg(
257 |                 flags=('--precision',),
258 |                 kwargs={
259 |                     'help': 'decimal places to round',
260 |                     'default': 2
261 |                 }
262 |             ),
263 |             utils.CLIArg(
264 |                 flags=('--threshold',),
265 |                 kwargs={
266 |                     'help': ('maximum average sentence length to allow '
267 |                              '(set to 0 for no filtering)'),
268 |                     'type': int,
269 |                     'default': 100
270 |                 }
271 |             )
272 |         ]
273 |     )
274 | 
275 |     @staticmethod
276 |     def get_columns(args):
277 |         return ('sentence_length',)
278 | 
279 |     @staticmethod
280 |     @check_nltk
281 |     @check_textblob
282 |     def process_document(doc, precision, threshold):
283 |         sentences = textblob.TextBlob(doc.text).sentences
284 |         if not len(sentences):
285 |             return doc.index + (None,)
286 |         # Allows for rounding to a specified number of decimals
287 |         elif precision:
288 |             sentence_length = round(sum(len(
289 |                 sentence.words) for sentence in sentences) / len(sentences),
290 |                 int(precision))
291 |         else:
292 |             sentence_length = sum(len(
293 |                 sentence.words) for sentence in sentences) / len(sentences)
294 |         # Filters values based on threshold
295 |         if not threshold or sentence_length < threshold:
296 |             return doc.index + (sentence_length,)
297 |         else:
298 |             return doc.index + (None,)
299 | 
300 | 
301 | commands['sentence_length'] = SentenceLength
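302 | 
303 | # Worked check of the entropy sum in ShannonEntropy.process_document:
304 | # lemma counts of 3 and 1 give
305 | # H = -(0.75*log2(0.75) + 0.25*log2(0.25)) ≈ 0.81:
306 | #
307 | #     counts = collections.Counter(['whale', 'whale', 'whale', 'sea'])
308 | #     total = sum(counts.values())
309 | #     h = -sum(c / total * math.log(c / total, 2) for c in counts.values())
310 | #     round(h, 2)  # 0.81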
311 | 
312 | 
313 | class SentimentAnalysis():
314 | 
315 |     cli = utils.CLISpec(
316 |         help='Performs sentiment analysis on the text',
317 |         arguments=[
318 |             utils.CLIArg(
319 |                 flags=('--backend',),
320 |                 kwargs={
321 |                     'help': 'which program to use for the analysis',
322 |                     'default': 'textblob'
323 |                 }
324 |             ),
325 |             utils.CLIArg(
326 |                 flags=('--precision',),
327 |                 kwargs={
328 |                     'help': 'decimal places to round',
329 |                     'default': 2
330 |                 }
331 |             )
332 |         ]
333 |     )
334 | 
335 |     @staticmethod
336 |     def get_columns(args):
337 |         if args['backend'] == 'textblob':
338 |             return ('sentiment_polarity', 'sentiment_subjectivity',)
339 |         else:
340 |             raise NotImplementedError
341 | 
342 |     @staticmethod
343 |     @check_nltk
344 |     @check_textblob
345 |     def process_document(doc, backend, precision):
346 |         if backend != 'textblob':
347 |             raise NotImplementedError
348 |         sentiment = textblob.TextBlob(doc.text)
349 |         # Allows for rounding to a specified number of decimals
350 |         if precision:
351 |             return (doc.index + (round(
352 |                 sentiment.polarity, int(precision)),
353 |                 round(sentiment.subjectivity, int(precision)),))
354 |         else:
355 |             return (doc.index + (sentiment.polarity,
356 |                                  sentiment.subjectivity,))
357 | 
358 | 
359 | commands['sentiment_analysis'] = SentimentAnalysis
360 | 
361 | 
362 | class FleschReadingEase():
363 | 
364 |     cli = utils.CLISpec(
365 |         help='Flesch Reading Ease metric',
366 |         arguments=[
367 |             utils.CLIArg(
368 |                 flags=('--threshold',),
369 |                 kwargs={
370 |                     'help': ('minimum score to allow '
371 |                              '(set to 0 for no filtering)'),
372 |                     'type': int,
373 |                     'default': -100
374 |                 }
375 |             )
376 |         ]
377 |     )
378 | 
379 |     @staticmethod
380 |     def get_columns(args):
381 |         return ('flesch_reading_ease',)
382 | 
383 |     @staticmethod
384 |     @check_textstat
385 |     def process_document(doc, threshold):
386 |         score = textstat.flesch_reading_ease(doc.text)
387 |         # Filters values based on threshold
388 |         if not threshold or score > threshold:
389 |             return doc.index + (int(score),)
390 |         else:
391 |             return doc.index + (None,)
392 | 
393 | 
394 | commands['flesch_reading_ease'] = FleschReadingEase
395 | 
396 | 
397 | class TextStandard():
398 | 
399 |     cli = utils.CLISpec(
400 |         help='Combines all of the readability metrics in textstat',
401 |         arguments=[]
402 |     )
403 | 
404 |     @staticmethod
405 |     def get_columns(args):
406 |         return ('text_standard',)
407 | 
408 |     @staticmethod
409 |     @check_textstat
410 |     def process_document(doc):
411 |         # textstat.text_standard returns a consensus grade level as a string
412 |         score = textstat.text_standard(doc.text)
413 |         return doc.index + (score,)
414 | 
415 | 
416 | commands['text_standard'] = TextStandard
417 | 
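418 | # Sketch: driving a metric class directly, assuming a minimal document
419 | # object with `index` and `text` attributes (the corpus streamer supplies
420 | # these in normal use):
421 | #
422 | #     Document = collections.namedtuple('Document', ['index', 'text'])
423 | #     doc = Document(index=('example',),
424 | #                    text='the permit lapses if it rains, but not otherwise')
425 | #     ConditionalCounter.process_document(doc)  # ('example', 2)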
--------------------------------------------------------------------------------
/quantgov/utils.py:
--------------------------------------------------------------------------------
1 | # TODO: Docstrings
2 | 
3 | import collections
4 | import concurrent.futures
5 | import os
6 | import sys
7 | 
8 | from pathlib import Path
9 | 
10 | 
11 | def load_driver(corpus):
12 |     corpus = Path(corpus)
13 |     if corpus.name == 'driver.py' or corpus.name == 'timestamp':
14 |         corpus = corpus.parent
15 |     sys.path.insert(0, str(corpus))
16 |     from driver import driver
17 |     sys.path.pop(0)
18 |     return driver
19 | 
20 | 
21 | _POOLS = {
22 |     'thread': concurrent.futures.ThreadPoolExecutor,
23 |     'process': concurrent.futures.ProcessPoolExecutor
24 | }
25 | 
26 | 
27 | def lazy_parallel(func, *iterables, **kwargs):
28 |     """
29 |     Parallel execution without fully loading iterables
30 | 
31 |     Arguments:
32 |     * func: function to call
33 |     * iterables: any number of iterables, which will be passed to func as
34 |       arguments
35 | 
36 |     Keyword Arguments:
37 |     * max_workers: max number of threads or processes. Defaults to None.
38 |     * worker: 'thread' (default) or 'process'
39 |     """
40 |     worker = kwargs.get('worker', 'thread')
41 |     max_workers = kwargs.get('max_workers')
42 |     if max_workers is None:  # Mirror concurrent.futures defaults (absent in the backport)
43 |         max_workers = (os.cpu_count() or 1)
44 |         if worker == 'thread':
45 |             max_workers *= 5
46 |     try:
47 |         pooltype = _POOLS[worker]
48 |     except KeyError:
49 |         raise ValueError("Valid choices for worker are: {}"
50 |                          .format(', '.join(_POOLS.keys())))
51 |     jobs = []
52 |     argsets = zip(*iterables)
53 |     with pooltype(max_workers) as pool:
54 |         for argset in argsets:
55 |             jobs.append(pool.submit(func, *argset))
56 |             # Yield as soon as the pool is saturated to keep memory bounded
57 |             if len(jobs) == pool._max_workers:
58 |                 yield jobs.pop(0).result()
59 |         for job in jobs:
60 |             yield job.result()
61 | 
62 | 
63 | CLISpec = collections.namedtuple('CLISpec', ['help', 'arguments'])
64 | CLIArg = collections.namedtuple('CLIArg', ['flags', 'kwargs'])
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal = 1
3 | [tool:pytest]
4 | addopts = --flake8
5 | flake8-ignore =
6 |     *.py W391 W503
7 |     */__init__.py F401
8 |     tests/* F401 E402
9 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | A setuptools-based setup module.
3 | """
4 | 
5 | import os
6 | import re
7 | 
8 | from setuptools import setup, find_packages
9 | from codecs import open
10 | 
11 | 
12 | def read(*names, **kwargs):
13 |     with open(
14 |         os.path.join(os.path.dirname(__file__), *names),
15 |         encoding=kwargs.get("encoding", "utf8")
16 |     ) as fp:
17 |         return fp.read()
18 | 
19 | 
20 | def find_version(*file_paths):
21 |     version_file = read(*file_paths)
22 |     version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
23 |                               version_file, re.M)
24 |     if version_match:
25 |         return version_match.group(1)
26 |     raise RuntimeError("Unable to find version string.")
27 | 
28 | 
29 | long_description = read("README.rst")
30 | version = find_version("quantgov", "__init__.py")
31 | 
32 | setup(
33 |     name='quantgov',
34 |     version=version,
35 | 
36 |     description='A Policy Analytics Framework',
37 |     long_description=long_description,
38 |     url='https://www.quantgov.org',
39 |     author='Oliver Sherouse',
40 |     author_email='quantgov.info@gmail.com',
41 |     license='MIT',
42 |     classifiers=[
43 |         'Development Status :: 3 - Alpha',
44 |         'Intended Audience :: Science/Research',
45 |         'Topic :: Scientific/Engineering :: Information Analysis',
46 |         'License :: OSI Approved :: MIT License',
47 |         'Programming Language :: Python :: 3',
48 |         'Programming Language :: Python :: 3.5',
49 |         'Programming Language :: Python :: 3.6',
50 |     ],
51 |     keywords='quantgov economics policy government machine learning',
52 |     packages=find_packages(
53 |         exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
54 |     install_requires=[
55 |         'decorator',
56 |         'joblib',
57 |         'pandas',
58 |         'requests',
59 |         'scikit-learn',
60 |         'scipy',
61 |         'textstat'
62 |     ],
63 |     extras_require={
64 |         'testing': ['pytest-flake8'],
65 |         'nlp': [
66 |             'textblob',
67 |             'nltk',
68 |         ],
69 |         's3driver': [
70 |             'sqlalchemy',
71 |             'boto3'
72 |         ]
73 |     },
74 |     entry_points={
75 |         'console_scripts': [
76 |             'quantgov=quantgov.__main__:main',
77 |         ],
78 |     },
79 | )
80 | 
--------------------------------------------------------------------------------
/tests/pseudo_corpus/driver.py:
--------------------------------------------------------------------------------
1 | import quantgov
2 | 
3 | from pathlib import Path
4 | 
5 | driver = quantgov.corpus.RecursiveDirectoryCorpusDriver(
6 |     directory=Path(__file__).parent.joinpath('data', 'clean'),
7 |     index_labels=('file',)
8 | )
--------------------------------------------------------------------------------
/tests/pseudo_estimator/.gitignore:
--------------------------------------------------------------------------------
1 | .snakemake
2 | notebooks/
3 | 
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | 
9 | # C extensions
10 | *.so
11 | 
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *,cover
49 | .hypothesis/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # IPython Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # dotenv
82 | .env
83 | 
84 | # virtualenv
85 | venv/
86 | ENV/
87 | 
88 | # Spyder project settings
89 | .spyderproject
90 | 
91 | # Rope project settings
92 | .ropeproject
--------------------------------------------------------------------------------
/tests/pseudo_estimator/data/binary.qge:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QuantGov/quantgov/f0e702339c0e1eedff28f9879f89346236b12efd/tests/pseudo_estimator/data/binary.qge
--------------------------------------------------------------------------------
/tests/pseudo_estimator/data/multiclass.qge:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QuantGov/quantgov/f0e702339c0e1eedff28f9879f89346236b12efd/tests/pseudo_estimator/data/multiclass.qge
--------------------------------------------------------------------------------
/tests/test_downloads.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | 
3 | import pytest
4 | 
5 | 
6 | @pytest.mark.parametrize('component', ['corpus', 'estimator', 'project'])
7 | def test_download(tmpdir, component):
8 |     comp_dir = tmpdir.join(component)
9 |     subprocess.check_call(['quantgov', 'start', component, str(comp_dir)])
10 |     assert comp_dir.join('Snakefile').check()
11 | 
12 | 
13 | def test_noclobber(tmpdir):
14 |     comp_dir = tmpdir.join('corpus')
15 |     comp_dir.mkdir()
16 |     with pytest.raises(subprocess.CalledProcessError):
17 |         subprocess.check_call(['quantgov', 'start', 'corpus', str(comp_dir)])
--------------------------------------------------------------------------------
/tests/test_ml.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import quantgov.ml
3 | import subprocess
4 | 
5 | from pathlib import Path
6 | 
7 | 
8 | PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent.joinpath('pseudo_corpus')
9 | PSEUDO_ESTIMATOR_PATH = (
10 |     Path(__file__).resolve().parent
11 |     .joinpath('pseudo_estimator')
12 | )
13 | 
14 | 
15 | def check_output(cmd):
16 |     return (
17 |         subprocess.check_output(cmd, universal_newlines=True)
18 |         .replace('\n\n', '\n')
19 |     )
20 | 
21 | 
22 | def test_simple_estimator():
23 |     output = check_output(
24 |         ['quantgov', 'ml', 'estimate',
25 |          str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'binary.qge')),
26 |          str(PSEUDO_CORPUS_PATH)]
27 |     )
28 |     assert output == 'file,is_world\ncfr,False\nmoby,False\n'
29 | 
30 | 
31 | def test_probability_estimator():
32 |     output = check_output(
33 |         ['quantgov', 'ml', 'estimate',
34 |          str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'binary.qge')),
35 |          str(PSEUDO_CORPUS_PATH), '--probability']
36 |     )
37 |     assert output == ('file,is_world_prob\ncfr,0.0899\nmoby,0.0216\n')
38 | 
39 | 
40 | def test_probability_estimator_6decimals():
41 |     output = check_output(
42 |         ['quantgov', 'ml', 'estimate',
43 |          str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'binary.qge')),
44 |          str(PSEUDO_CORPUS_PATH), '--probability', '--precision', '6']
45 |     )
46 |     assert output == ('file,is_world_prob\ncfr,0.089898\nmoby,0.02162\n')
47 | 
48 | 
49 | def test_multiclass_probability_estimator():
50 |     output = check_output(
51 |         ['quantgov', 'ml', 'estimate',
52 |          str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'multiclass.qge')),
53 |          str(PSEUDO_CORPUS_PATH), '--probability']
54 |     )
55 |     assert output == ('file,class,probability\n'
56 |                       'cfr,business-and-industry,0.1765\n'
57 |                       'cfr,environment,0.1294\n'
58 |                       'cfr,health-and-public-welfare,0.1785\n'
59 |                       'cfr,money,0.169\n'
60 |                       'cfr,science-and-technology,0.147\n'
61 |                       'cfr,world,0.1997\n'
62 |                       'moby,business-and-industry,0.1804\n'
63 |                       'moby,environment,0.1529\n'
64 |                       'moby,health-and-public-welfare,0.205\n'
65 |                       'moby,money,0.1536\n'
66 |                       'moby,science-and-technology,0.1671\n'
67 |                       'moby,world,0.141\n')
68 | 
69 | 
70 | def test_multiclass_probability_oneclass_estimator():
71 |     output = check_output(
72 |         ['quantgov', 'ml', 'estimate',
73 |          str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'multiclass.qge')),
74 |          str(PSEUDO_CORPUS_PATH), '--probability', '--oneclass']
75 |     )
76 |     assert output == ('file,class,probability\n'
77 |                       'cfr,world,0.1997\n'
78 |                       'moby,health-and-public-welfare,0.205\n')
--------------------------------------------------------------------------------
/tests/test_nlp.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import quantgov.corpus
3 | import subprocess
4 | 
5 | from pathlib import Path
6 | 
7 | 
8 | def build_recursive_directory_corpus(directory):
9 |     for path, text in (('a/1.txt', 'foo'), ('b/2.txt', 'bar')):
10 |         directory.join(path).write_text(text, encoding='utf-8', ensure=True)
11 |     return quantgov.corpus.RecursiveDirectoryCorpusDriver(
12 |         directory=str(directory), index_labels=('letter', 'number'))
13 | 
14 | 
15 | def build_name_pattern_corpus(directory):
16 |     for path, text in (('a_1.txt', 'foo'), ('b_2.txt', 'bar')):
17 |         directory.join(path).write_text(
18 |             text, encoding='utf-8', ensure=True)
19 |     return quantgov.corpus.NamePatternCorpusDriver(
20 |         pattern=r'(?P<letter>[a-z])_(?P<number>\d)',
21 |         directory=str(directory)
22 |     )
23 | 
24 | 
25 | def build_index_corpus(directory):
26 |     rows = []
27 |     for letter, number, path, text in (
28 |             ('a', '1', 'first.txt', 'foo'),
29 |             ('b', '2', 'second.txt', 'bar')
30 |     ):
31 |         outpath = directory.join(path, abs=1)
32 |         outpath.write_text(text, encoding='utf-8')
33 |         rows.append((letter, number, str(outpath)))
34 |     index_path = directory.join('index.csv')
35 |     with index_path.open('w', encoding='utf-8') as outf:
36 |         outf.write('letter,number,path\n')
37 |         outf.write('\n'.join(','.join(row) for row in rows))
38 |     return quantgov.corpus.IndexDriver(str(index_path))
39 | 
40 | 
41 | def build_s3_corpus(directory):
42 |     rows = []
43 |     for letter, number, path in (
44 |             ('a', '1', 'quantgov_tests/first.txt'),
45 |             ('b', '2', 'quantgov_tests/second.txt')
46 |     ):
47 |         rows.append((letter, number, path))
48 |     index_path = directory.join('index.csv')
49 |     with index_path.open('w', encoding='utf-8') as outf:
50 |         outf.write('letter,number,path\n')
51 |         outf.write('\n'.join(','.join(row) for row in rows))
52 |     return quantgov.corpus.S3Driver(str(index_path),
53 |                                     bucket='quantgov-databanks')
54 | 
55 | 
56 | BUILDERS = {
57 |     'RecursiveDirectoryCorpusDriver': build_recursive_directory_corpus,
58 |     'NamePatternCorpusDriver': build_name_pattern_corpus,
59 |     'IndexDriver': build_index_corpus,
60 |     'S3Driver': build_s3_corpus
61 | }
62 | 
63 | 
64 | @pytest.fixture(scope='module', params=list(BUILDERS.keys()))
65 | def corpus(request, tmpdir_factory):
66 |     tmpdir = tmpdir_factory.mktemp(request.param, numbered=True)
67 |     return BUILDERS[request.param](tmpdir)
68 | 
69 | 
70 | def test_index_labels(corpus):
71 |     assert corpus.index_labels == ('letter', 'number')
72 | 
73 | 
74 | def test_simple_stream(corpus):
75 |     served = tuple(corpus.stream())
76 |     assert served == (
77 |         (('a', '1'), 'foo'),
78 |         (('b', '2'), 'bar')
79 |     )
80 | 
81 | 
82 | def test_corpus_streamer(corpus):
83 |     streamer = corpus.get_streamer()
84 |     served = []
85 |     for i in streamer:
86 |         served.append(i)
87 |         assert streamer.documents_streamed == len(served)
88 |         assert not streamer.finished
89 |     assert streamer.documents_streamed == len(served)
90 |     assert streamer.finished
91 |     assert tuple(served) == (
92 |         (('a', '1'), 'foo'),
93 |         (('b', '2'), 'bar')
94 |     )
95 |     assert streamer.index == [('a', '1'), ('b', '2')]
96 | 
97 | 
98 | PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent.joinpath('pseudo_corpus')
99 | 
100 | 
101 | def check_output(cmd):
102 |     return (
103 |         subprocess.check_output(cmd, universal_newlines=True)
104 |         .replace('\n\n', '\n')
105 |     )
106 | 
107 | 
108 | def test_wordcount():
109 |     output = check_output(
110 |         ['quantgov', 'nlp', 'count_words', str(PSEUDO_CORPUS_PATH)],
111 |     )
112 |     assert output == 'file,words\ncfr,349153\nmoby,216645\n'
113 | 
114 | 
115 | def test_wordcount_pattern():
116 |     output = check_output(
117 |         ['quantgov', 'nlp', 'count_words', str(PSEUDO_CORPUS_PATH),
118 |          '--word_pattern', r'\S+']
119 |     )
120 |     assert output == 'file,words\ncfr,333237\nmoby,210130\n'
121 | 
122 | 
123 | def test_termcount():
124 |     output = check_output(
125 |         ['quantgov', 'nlp', 'count_occurrences', str(PSEUDO_CORPUS_PATH),
126 |          'shall'],
127 |     )
128 |     assert output == 'file,shall\ncfr,1946\nmoby,94\n'
129 | 
130 | 
131 | def test_termcount_multiple():
132 |     output = check_output(
133 |         ['quantgov', 'nlp', 'count_occurrences', str(PSEUDO_CORPUS_PATH),
134 |          'shall', 'must', 'may not'],
135 |     )
136 |     assert output == ('file,shall,must,may not\n'
137 |                       'cfr,1946,744,122\nmoby,94,285,5\n')
138 | 
139 | 
140 | def test_termcount_multiple_with_label():
141 |     output = check_output(
142 |         ['quantgov', 'nlp', 'count_occurrences', str(PSEUDO_CORPUS_PATH),
143 |          'shall', 'must', 'may not', '--total_label', 'allofthem'],
144 |     )
145 |     assert output == ('file,shall,must,may not,allofthem\n'
146 |                       'cfr,1946,744,122,2812\nmoby,94,285,5,384\n')
147 | 
148 | 
149 | def test_shannon_entropy():
150 |     output = check_output(
151 |         ['quantgov', 'nlp', 'shannon_entropy', str(PSEUDO_CORPUS_PATH)],
152 |     )
153 |     assert output == 'file,shannon_entropy\ncfr,10.71\nmoby,11.81\n'
154 | 
155 | 
156 | def test_shannon_entropy_no_stopwords():
157 |     output = check_output(
158 |         ['quantgov', 'nlp', 'shannon_entropy', str(PSEUDO_CORPUS_PATH),
159 |          '--stopwords', 'None'],
160 |     )
161 |     assert output == 'file,shannon_entropy\ncfr,9.52\nmoby,10.03\n'
162 | 
163 | 
164 | def test_shannon_entropy_4decimals():
165 |     output = check_output(
166 |         ['quantgov', 'nlp', 'shannon_entropy', str(PSEUDO_CORPUS_PATH),
167 |          '--precision', '4'],
168 |     )
169 |     assert output == 'file,shannon_entropy\ncfr,10.7127\nmoby,11.813\n'
170 | 
171 | 
172 | def test_conditionalcount():
173 |     output = check_output(
174 |         ['quantgov', 'nlp', 'count_conditionals', str(PSEUDO_CORPUS_PATH)],
175 |     )
176 |     assert output == 'file,conditionals\ncfr,2132\nmoby,2374\n'
177 | 
178 | 
179 | def test_sentencelength():
180 |     output = check_output(
181 |         ['quantgov', 'nlp', 'sentence_length', str(PSEUDO_CORPUS_PATH)],
182 |     )
183 |     assert output == 'file,sentence_length\ncfr,18.68\nmoby,25.09\n'
184 | 
185 | 
186 | def test_sentiment_analysis():
187 |     output = check_output(
188 |         ['quantgov', 'nlp', 'sentiment_analysis', str(PSEUDO_CORPUS_PATH)],
189 |     )
190 |     assert output == ('file,sentiment_polarity,sentiment_subjectivity'
191 |                       '\ncfr,0.01,0.42\nmoby,0.08,0.48\n')
192 | 
193 | 
194 | def test_sentiment_analysis_4decimals():
195 |     output = check_output(
196 |         ['quantgov', 'nlp', 'sentiment_analysis', str(PSEUDO_CORPUS_PATH),
197 |          '--precision', '4'],
198 |     )
199 |     assert output == ('file,sentiment_polarity,sentiment_subjectivity'
200 |                       '\ncfr,0.0114,0.421\nmoby,0.0816,0.4777\n')
201 | 
--------------------------------------------------------------------------------