├── README.md ├── dtc.png ├── .gitignore └── oof_model.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # pumpitup 2 | Pump it Up: Data Mining the Water Table 3 | -------------------------------------------------------------------------------- /dtc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagol/pumpitup/HEAD/dtc.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /oof_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-02-14T07:31:53.296831Z", 9 | "start_time": "2021-02-14T07:31:52.347275Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import numpy as np\n", 15 | "import pandas as pd\n", 16 | "import catboost\n", 17 | "import time\n", 18 | "\n", 19 | "from catboost import Pool, sum_models\n", 20 | "from sklearn.model_selection import train_test_split\n", 21 | "from sklearn.metrics import balanced_accuracy_score\n", 22 | "from catboost import CatBoostClassifier\n", 23 | "from sklearn.model_selection import StratifiedKFold" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "ExecuteTime": { 31 | "end_time": "2021-02-14T07:31:53.302889Z", 32 | "start_time": "2021-02-14T07:31:53.299419Z" 33 | } 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "FOLDS = 10\n", 38 | "SEEDS = [0, 42, 888, 1042, 8888]\n", 39 | "VERSION = round(time.time())" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "ExecuteTime": { 47 | "end_time": "2021-02-14T07:31:53.967930Z", 48 | "start_time": 
# --- Load the raw competition files (indexed by waterpoint id) ---------------
df_train_set = pd.read_csv('4910797b-ee55-40a7-8668-10efd5c1b960.csv', index_col='id')
df_train_labels = pd.read_csv('0bf8bc6e-30d0-4c50-956a-603fc693d966.csv', index_col='id')
df_test = pd.read_csv('702ddfc5-68cd-4d1d-a0de-f5f566f76d91.csv', index_col='id')


# Ordered (canonical label, raw spellings) pairs for the 'installer' column.
# The groups are applied in this order, so a spelling listed in two groups is
# claimed by the first one.
_INSTALLER_GROUPS = (
    ('finw', (
        'fini water', 'fin water', 'finn water', 'finwater', 'finwate')),
    ('jaica', ('jaica co',)),
    ('council', (
        'district water department', 'district water depar', 'district council',
        'district counci', 'village council orpha', 'kibaha town council',
        'village council', 'coun', 'village counil', 'council',
        'mbulu district council', 'counc', 'village council .oda',
        'sangea district coun', 'songea district coun', 'villege council',
        'district council', 'quick win project /council', 'mbozi district council',
        'village council', 'municipal council', 'tabora municipal council',
        'wb / district council')),
    ('church', (
        'rc church', 'rc churc', 'rcchurch/cefa', 'irc', 'rc', 'rc ch', 'hw/rc',
        'rc church/central gover', 'kkkt church', 'pentecost church', 'roman church',
        'rc/mission', 'rc church/cefa', 'lutheran church', 'tag church',
        'free pentecoste church of tanz', 'rc c', 'church', 'rc cathoric',
        'morovian church', 'cefa/rc church', 'rc mission', 'anglican church',
        'church of disciples', 'anglikana church', 'cetral government /rc',
        'pentecostal church', 'cg/rc', 'rc missionary', 'sda church', 'methodist church', 'trc',
        'rc msufi', 'haidomu lutheran church', 'baptist church', 'rc church brother',
        'st magreth church', 'anglica church', 'global resource co', 'rc mi',
        'baptist church of tanzania', 'fpct church', 'rc njoro', 'rc .church',
        'rc mis', 'batist church', 'churc', 'dwe/anglican church', 'missi', 'mission',
        'ndanda missions', 'rc/mission', 'cvs miss', 'missionaries', 'hydom luthelani',
        'luthe', 'haydom lutheran hospital', 'lutheran', 'missio', 'germany missionary',
        'grail mission kiseki bar', 'missionary', 'heri mission', 'german missionsry',
        'wamissionari wa kikatoriki', 'neemia mission', 'wamisionari wa kikatoriki')),
    ('tanzanian government', (
        'central government', 'gove', 'central govt', 'gover', 'cipro/government',
        'governme', 'adra /government', 'isf/government', 'adra/government',
        'government /tcrs', 'village govt', 'government', 'government /community',
        'concern /government', 'goverm', 'village government', 'cental government',
        'govern', 'cebtral government', 'government /sda', 'tcrs /government',
        'tanzania government', 'centra govt', 'colonial government', 'misri government',
        'government and community', 'cetral government /rc', 'concern/government',
        # NOTE(review): this entry appeared as 'lwi ¢ral government', which
        # looks like '&cent' mis-decoded as an HTML entity; repaired to the
        # presumed original spelling -- confirm against the raw data.
        'government of misri', 'lwi &central government', 'governmen', 'government/tcrs',
        'government /world vision', 'centra government')),
    ('world vision', (
        'world vission', 'world division', 'word divisio', 'world visiin')),
    ('unicef', ('unicrf',)),
    ('community', (
        'commu', 'olgilai village community', 'adra /community', 'adra/community',
        'rwe/ community', 'killflora /community', 'communit', 'taboma/community',
        'arab community', 'adra/ community', 'sekei village community', 'rwe/community',
        'arabs community', 'village community', 'government /community',
        'dads/village community', 'killflora/ community', 'mtuwasa and community',
        'rwe /community', 'ilwilo community', 'summit for water/community',
        'igolola community', 'ngiresi village community', 'rwe community',
        'african realief committe of ku', 'twesa /community', 'shelisheli commission',
        'twesa/ community', 'marumbo community', 'government and community',
        'community bank', 'kitiangare village community', 'oldadai village community',
        'twesa/community', 'tlc/community', 'maseka community', 'islamic community',
        'district community j', 'village water commission', 'village community members',
        'tcrs/village community', 'village water committee', 'comunity')),
    ('danida', (
        'danid', 'danda', 'danida co', 'danny', 'daniad', 'dannida', 'danids')),
    ('hesawa', (
        'hesaws', 'huches', 'hesaw', 'hesawz', 'hesawq', 'hesewa')),
    ('dwe', (
        'dwsp', 'kkkt _ konde and dwe', 'rwe/dwe', 'rwedwe', 'dwe/', 'dw', 'dwr',
        'dwe}', 'dwt', 'dwe /tassaf', 'dwe/ubalozi wa marekani', 'consultant and dwe',
        'dwe & lwi', 'ubalozi wa marekani /dwe', 'dwe&', 'dwe/tassaf', 'dw$',
        'dw e', 'tcrs/dwe', 'dw#', 'dweb', 'tcrs /dwe', 'water aid/dwe', 'dww')),
    ('muslims', (
        'africa muslim', 'muslimu society(shia)', 'africa muslim agenc',
        'african muslims age', 'muslimehefen international', 'islamic',
        'the isla', 'islamic agency tanzania', 'islam', 'nyabibuye islamic center')),
    ('british', (
        'british colonial government', 'british government', 'britain')),
    ('tcrs', (
        'tcrs/tlc', 'tcrs /care', 'cipro/care/tcrs', 'tcrs kibondo', 'tcrs.tlc',
        'tcrs /twesa', 'tassaf /tcrs', 'tcrs/care', 'tcrs twesa', 'rwe/tcrs',
        'tcrs/twesa', 'tassaf/ tcrs', 'tcrs/ tassaf', 'tcrs/ twesa', 'tcrs a',
        'tassaf/tcrs')),
    ('kkkt', (
        'kkkt-dioces ya pare', 'kkkt leguruki', 'kkkt ndrumangeni', 'kkkt dme',
        'kkkt kilinga', 'kkkt canal', 'kkkt katiti juu', 'kkkt mareu')),
    ('norad', ('norad/',)),
    ('dmdd', ('tasaf/dmdd', 'dmdd/solider')),
    ('cjejow', ('cjejow construction', 'cjej0')),
    ('china', (
        'china henan constuction', 'china henan contractor', 'china co.', 'chinese')),
    ('local', (
        'local contract', 'local technician', 'local', 'local technician',
        'locall technician', 'local te', 'local technitian', 'local technical tec',
        'local fundi', 'local technical', 'localtechnician', 'village local contractor',
        'local l technician')),
    ('africa', (
        'oikos e .africa', 'oikos e.africa', 'africa amini alama',
        'africa islamic agency tanzania', 'africare', 'african development foundation',
        'oikos e. africa', 'oikos e.afrika', 'afroz ismail', 'africa', 'farm-africa',
        'oikos e africa', 'farm africa', 'africaone', 'tina/africare', 'africaone ltd',
        'african reflections foundation', 'africa m')),
)

# Ordered (canonical label, raw spellings) pairs for the 'funder' column.
_FUNDER_GROUPS = (
    ('kkkt', (
        'kkkt_makwale', 'kkkt-dioces ya pare', 'world vision/ kkkt', 'kkkt church',
        'kkkt leguruki', 'kkkt ndrumangeni', 'kkkt dme', 'kkkt canal', 'kkkt usa',
        'kkkt mareu')),
    ('government', (
        'government of tanzania', 'norad /government', 'government/ community',
        'cipro/government', 'isf/government', 'finidagermantanzania govt',
        'government /tassaf', 'finida german tanzania govt', 'village government',
        'tcrs /government', 'village govt', 'government/ world bank',
        'danida /government', 'dhv/gove', 'concern /govern', 'vgovernment',
        'lwi & central government', 'government /sda', 'koica and tanzania government',
        'world bank/government', 'colonial government', 'misri government',
        'government and community', 'concern/governm', 'government of misri',
        'government/tassaf', 'government/school', 'government/tcrs', 'unhcr/government',
        'government /world vision', 'norad/government')),
    ('foreign government', (
        'british colonial government', 'japan government', 'china government',
        'finland government', 'belgian government', 'italy government',
        'irish government', 'egypt government', 'iran gover', 'swedish', 'finland')),
    ('church', (
        'rc church', 'anglican church', 'rc churc', 'rc ch', 'rcchurch/cefa',
        'irc', 'rc', 'churc', 'hw/rc', 'rc church/centr', 'pentecosta church',
        'roman church', 'rc/mission', "ju-sarang church' and bugango",
        'lutheran church', 'roman cathoric church', 'tag church ub', 'aic church',
        'free pentecoste church of tanz', 'tag church', 'fpct church', 'rc cathoric',
        'baptist church', 'morovian church', 'cefa/rcchurch', 'rc mission',
        'bukwang church saints', 'agt church', 'church of disciples', 'rc mofu',
        "gil cafe'church'", 'pentecostal church', 'bukwang church saint',
        'eung am methodist church', 'rc/dwe', 'cg/rc', 'eung-am methodist church',
        'rc missionary', 'sda church', 'methodist church', 'rc msufi',
        'haidomu lutheran church', 'nazareth church', 'st magreth church',
        'agape churc', 'rc missi', 'rc mi', 'rc njoro', 'world vision/rc church',
        'pag church', 'batist church', 'full gospel church', 'nazalet church',
        'dwe/anglican church', 'missi', 'mission', 'missionaries', 'cpps mission',
        'cvs miss', 'grail mission kiseki bar', 'shelisheli commission', 'missionary',
        'heri mission', 'german missionary', 'wamissionari wa kikatoriki',
        'rc missionary', 'germany missionary', 'missio', 'neemia mission', 'rc missi',
        'hydom luthelani', 'luthe', 'lutheran church', 'haydom lutheran hospital',
        'village council/ haydom luther', 'lutheran', 'haidomu lutheran church',
        'resolute golden pride project', 'resolute mininggolden pride',
        'germany cristians')),
    ('community', (
        'olgilai village community', 'commu', 'community', 'arab community',
        'sekei village community', 'arabs community', 'village community',
        'mtuwasa and community', 'ilwilo community', 'igolola community',
        'ngiresi village community', 'marumbo community', 'village communi',
        'comune di roma', 'comunity construction fund', 'community bank',
        "oak'zion' and bugango b' commu", 'kitiangare village community',
        'oldadai village community', 'tlc/community', 'maseka community',
        'islamic community', 'tcrs/village community', 'buluga subvillage community',
        'okutu village community')),
    ('district', (
        'council', 'wb / district council', 'cdtfdistrict council',
        'sangea district council', 'mheza distric counc', 'kyela council',
        'kibaha town council', 'swidish', 'mbozi district council',
        'village council/ rose kawala', 'songea municipal counci',
        'quick win project /council', 'village council', 'villege council',
        'tabora municipal council', 'kilindi district co', 'kigoma municipal council',
        'district council', 'municipal council', 'district medical',
        'sengerema district council', 'town council', 'mkinga distric cou',
        'songea district council', 'district rural project', 'mkinga distric coun',
        'dadis')),
    ('tcrs', (
        'tcrs.tlc', 'tcrs /care', 'tcrst', 'cipro/care/tcrs', 'tcrs/care', 'tcrs kibondo')),
    ('fini', (
        'fini water', 'finw', 'fin water', 'finn water', 'finwater')),
    ('islam', (
        'islamic', 'the isla', 'islamic found', 'islamic agency tanzania',
        'islam', 'muislam', 'the islamic', 'nyabibuye islamic center', 'islamic society',
        'african muslim agency',
        'muslims', 'answeer muslim grou', 'muslimu society(shia)',
        'unicef/african muslim agency', 'muslim world', 'muslimehefen international',
        'shear muslim', 'muslim society')),
    ('danida', ('danida', 'ms-danish', 'unhcr/danida', 'tassaf/ danida')),
    ('hesawa', (
        'hesawa', 'hesawz', 'hesaw', 'hhesawa', 'hesawwa', 'hesawza', 'hesswa',
        'hesawa and concern world wide')),
    ('world vision', ('world vision/adra', 'game division', 'worldvision')),
    ('germany', (
        'germany republi', 'a/co germany', 'aco/germany', 'bingo foundation germany',
        'africa project ev germany', 'tree ways german')),
)

# Raw values that stand for "unknown" once the column is cast to lowercase str.
_PLACEHOLDER_LABELS = ('0', 'nan', '-')


def _normalize_labels(df, column, groups, min_count):
    """Lower-case ``df[column]``, merge spelling variants and pool rare labels.

    The column is rewritten in place on ``df``. Every step assigns the result
    back instead of calling ``df[column].replace(..., inplace=True)``: that
    chained form operates on a temporary and does nothing under pandas
    Copy-on-Write.
    """
    labels = df[column].astype(str).str.lower()
    for canonical, variants in groups:
        labels = labels.replace(list(variants), canonical)
    labels = labels.replace(list(_PLACEHOLDER_LABELS), 'other')
    # Labels seen fewer than ``min_count`` times in *this* frame become 'other'.
    counts = labels.value_counts()
    rare = counts[counts < min_count].index
    df[column] = labels.mask(labels.isin(rare), 'other')


def clean_installer(df, min_count=71):
    """Normalize the 'installer' column of ``df`` in place.

    Spelling variants are merged into canonical names, placeholder values
    become 'other', and labels occurring fewer than ``min_count`` times in
    ``df`` are pooled into 'other' as well.

    NOTE(review): the rare-label cut is computed per frame, so train and test
    may keep different label sets -- confirm this is intended.
    """
    _normalize_labels(df, 'installer', _INSTALLER_GROUPS, min_count)


def clean_funder(df, min_count=98):
    """Normalize the 'funder' column of ``df`` in place.

    Same procedure as ``clean_installer`` with the funder-specific spelling
    groups and a default rare-label threshold of 98.
    """
    _normalize_labels(df, 'funder', _FUNDER_GROUPS, min_count)


def _first_mode(values):
    """Return the most frequent non-null value, or None if there is none.

    ``Series.mode`` returns every tied value; taking the first keeps the
    aggregate a scalar so it can be written into a single cell.
    """
    modes = values.mode()
    return modes.iloc[0] if not modes.empty else None


def get_medians_df(df):
    """Compute the per-region statistics used to impute missing values.

    Returns:
        df_geo: median latitude/longitude per ``region_code``.
        df_subvillage: most frequent subvillage per ``region_code``.
        df_scheme: most frequent scheme_name per ``region``.

    NOTE(review): the coordinate medians include the ``longitude == 0`` rows
    that ``geo_restore`` later overwrites -- consider excluding them.
    """
    df_geo = df.groupby('region_code')[['latitude', 'longitude']].median()
    df_subvillage = df.groupby('region_code')['subvillage'].agg(_first_mode)
    df_scheme = df.groupby('region')['scheme_name'].agg(_first_mode)
    return df_geo, df_subvillage, df_scheme


def geo_restore(df, df_geo):
    """Overwrite coordinates of rows with ``longitude == 0`` in place.

    Both longitude and latitude are replaced by the ``df_geo`` value for the
    row's ``region_code``. A region code absent from ``df_geo`` yields NaN
    instead of raising.
    """
    missing = df['longitude'] == 0
    if not missing.any():
        return
    regions = df.loc[missing, 'region_code']
    for col in ('longitude', 'latitude'):
        df.loc[missing, col] = regions.map(df_geo[col])


def scheme_restore(df, df_scheme):
    """Fill null ``scheme_name`` values in place from the row's ``region``."""
    missing = df['scheme_name'].isnull()
    if missing.any():
        df.loc[missing, 'scheme_name'] = df.loc[missing, 'region'].map(df_scheme)


def subvillage_restore(df, df_subvillage):
    """Fill null ``subvillage`` values in place from the row's ``region_code``."""
    missing = df['subvillage'].isnull()
    if missing.any():
        df.loc[missing, 'subvillage'] = (
            df.loc[missing, 'region_code'].map(df_subvillage))


def get_medians(df):
    """Return the medians of the ``public_meeting`` and ``permit`` columns.

    Nulls are dropped and the remaining values are cast to float before the
    median is taken, so the result does not depend on how pandas reduces an
    object-dtype column.
    """
    df_pm_median = df['public_meeting'].dropna().astype(float).median()
    df_permit_median = df['permit'].dropna().astype(float).median()
    return df_pm_median, df_permit_median


def fill_na(df, df_pm_median, df_permit_median):
    """Replace nulls in ``public_meeting`` / ``permit`` in place."""
    fills = (('public_meeting', df_pm_median), ('permit', df_permit_median))
    for col, value in fills:
        df.loc[df[col].isnull(), col] = value


def create_na_features(df):
    """Add one boolean ``<column>_na`` flag per imputed column, in place.

    Must run before the corresponding column is imputed, otherwise the flag
    is constant.
    """
    na_cols = ['subvillage', 'public_meeting', 'scheme_name', 'permit']
    for c in na_cols:
        # Fixed: the name was the literal 'c_na', so each iteration overwrote
        # the same column and only the last flag survived.
        df[f'{c}_na'] = df[c].isnull()


def drop(df):
    """Remove the columns that are not used as model features, in place."""
    unused = [
        'scheme_management', 'quantity_group', 'water_quality', 'region_code',
        'payment_type', 'extraction_type', 'waterpoint_type_group',
        'date_recorded', 'recorded_by',
    ]
    df.drop(columns=unused, inplace=True)


# --- Preprocess train, then apply the train-derived statistics to test -------
df_train = df_train_set.join(df_train_labels)
clean_installer(df_train)
clean_funder(df_train)
# Flags are taken before any imputation so they record the original gaps.
create_na_features(df_train)
df_geo, df_subvillage, df_scheme = get_medians_df(df_train)
geo_restore(df_train, df_geo)
subvillage_restore(df_train, df_subvillage)
scheme_restore(df_train, df_scheme)
df_pm_median, df_permit_median = get_medians(df_train)
fill_na(df_train, df_pm_median, df_permit_median)
drop(df_train)

clean_installer(df_test)
clean_funder(df_test)
create_na_features(df_test)
geo_restore(df_test, df_geo)
subvillage_restore(df_test, df_subvillage)
scheme_restore(df_test, df_scheme)
fill_na(df_test, df_pm_median, df_permit_median)
drop(df_test)


def fit_model(train_pool, test_pool, **kwargs):
    """Fit a CatBoostClassifier on ``train_pool`` and return the fitted model.

    ``test_pool`` is passed as the evaluation set; with ``use_best_model=True``
    and the iteration-based overfitting detector (``od_wait=500``) it decides
    where training stops. Extra keyword arguments are forwarded to the
    ``CatBoostClassifier`` constructor.
    """
    model = CatBoostClassifier(
        max_ctr_complexity=5,
        task_type='GPU',
        iterations=10000,
        eval_metric='AUC',
        od_type='Iter',
        od_wait=500,
        **kwargs
    )

    return model.fit(
        train_pool,
        eval_set=test_pool,
        verbose=1000,
        plot=False,
        use_best_model=True)
# --- Declare every non-numeric feature as categorical ------------------------
num_cols = [
    'amount_tsh', 'gps_height', 'longitude', 'latitude', 'num_private',
    'district_code', 'population'
]
cat_features = [x for x in df_train.columns if x not in num_cols and x != 'status_group']
df_train[cat_features] = df_train[cat_features].astype('category')
df_test[cat_features] = df_test[cat_features].astype('category')


def classification_rate(y, y_pred):
    """Return the fraction of positions where ``y`` equals ``y_pred``."""
    return np.mean(np.asarray(y) == np.asarray(y_pred))


def get_oof(n_folds, x_train, y, x_test, cat_features, seeds):
    """Train seed-averaged K-fold CatBoost models and collect OOF predictions.

    For every seed a shuffled ``StratifiedKFold`` with ``n_folds`` splits is
    run; each fold model is fitted with ``fit_model``, stored, and saved to
    ``cb_<seed>_<fold>_<VERSION>.cbm``.

    Args:
        n_folds: number of folds per seed.
        x_train: training features (DataFrame, indexed positionally here).
        y: training labels, one per row of ``x_train``.
        x_test: test features.
        cat_features: names of the categorical feature columns.
        seeds: seeds used for both the fold split and the model.

    Returns:
        oof_train: (n_train, n_classes) out-of-fold probabilities, averaged
            over seeds.
        oof_test: (n_test, n_classes) test probabilities, averaged over folds
            and seeds.
        models: dict mapping ``(seed, fold)`` to the fitted model.
    """
    # Positional fold indices are applied to ``y``; an ndarray makes that safe
    # even if a label Series with a non-default index is passed in.
    y = np.asarray(y)
    # Derived from the labels instead of the previous hard-coded 3.
    n_classes = len(np.unique(y))
    ntrain = x_train.shape[0]
    ntest = x_test.shape[0]

    oof_train = np.zeros((len(seeds), ntrain, n_classes))
    oof_test_skf = np.empty((len(seeds), n_folds, ntest, n_classes))

    # Built once and reused for every fold's test prediction.
    test_pool = Pool(data=x_test, cat_features=cat_features)
    models = {}

    for iseed, seed in enumerate(seeds):
        kf = StratifiedKFold(
            n_splits=n_folds,
            shuffle=True,
            random_state=seed)
        for i, (train_index, valid_index) in enumerate(kf.split(x_train, y)):
            print(f'\nSeed {seed}, Fold {i}')
            train_pool = Pool(
                data=x_train.iloc[train_index, :],
                label=y[train_index],
                cat_features=cat_features)
            valid_pool = Pool(
                data=x_train.iloc[valid_index, :],
                label=y[valid_index],
                cat_features=cat_features)

            model = fit_model(
                train_pool, valid_pool,
                loss_function='MultiClass',
                random_seed=seed
            )
            # Predict from the pools so the categorical declaration is the
            # same one used for training.
            oof_train[iseed, valid_index, :] = model.predict_proba(valid_pool)
            oof_test_skf[iseed, i, :, :] = model.predict_proba(test_pool)
            models[(seed, i)] = model
            model.save_model(
                f"cb_{seed}_{i}_{VERSION}.cbm",
                format="cbm", export_parameters=None, pool=None)

    oof_test = oof_test_skf.mean(axis=1).mean(axis=0)
    oof_train = oof_train.mean(axis=0)
    return oof_train, oof_test, models


# --- Train ------------------------------------------------------------------
y_train, X_train = df_train['status_group'], df_train.drop(['status_group'], axis=1)

oof_train, oof_test, models = get_oof(
    n_folds=FOLDS,
    x_train=X_train,
    y=y_train.values,
    x_test=df_test,
    cat_features=cat_features,
    seeds=SEEDS)

# Reference model (first seed, first fold) used for the importance plot.
m = models[(SEEDS[0], 0)]
# --- Feature importance of the reference model (top 40) ----------------------
fea_imp = pd.DataFrame({'importance': m.feature_importances_,
                        'col': m.feature_names_})
fea_imp = fea_imp.sort_values(['importance', 'col'],
                              ascending=[True, False]).iloc[-40:]
fea_imp.plot(kind='barh', x='col', y='importance', figsize=(20, 20))

# --- Out-of-fold evaluation --------------------------------------------------
# Probability columns follow the model's ``classes_`` order; the first model's
# order is used for all of them (as before).
class_labels = np.asarray(models[(SEEDS[0], 0)].classes_)
y_pred_train = class_labels[oof_train.argmax(axis=1)]
print(f"balanced accuracy: {balanced_accuracy_score(y_train, y_pred_train)}")
class_rate = classification_rate(y_train, y_pred_train)
print(f"classification rate: {class_rate}")

# --- Submission --------------------------------------------------------------
y_pred = class_labels[oof_test.argmax(axis=1)]
sub = pd.read_csv("SubmissionFormat.csv", index_col='id')
# Align on the waterpoint id rather than on row position, so a differently
# ordered submission template cannot silently shuffle the predictions.
sub['status_group'] = pd.Series(y_pred, index=df_test.index)
if sub['status_group'].isnull().any():
    raise ValueError("SubmissionFormat.csv contains ids missing from the test set")
sub.to_csv(f"sub_{round(class_rate, 4)}.csv", index=True)
--------------------------------------------------------------------------------