├── Readme.md └── baseline_solution.ipynb /Readme.md: -------------------------------------------------------------------------------- 1 | # Онлайн-хакатон Райффайзенбанка в области Data Science 2 | 3 | 4 | Ключевая информация: 5 | 6 | * Скор на паблике: 1.4241 7 | * Примерное место на паблике: 43 8 | * Обучение только на `price_type == 1` 9 | * Рассказ про adversarial validation 10 | * Рассказ про схему валидации 11 | * Визуализация геоданных с помощью библиотеки keplergl 12 | -------------------------------------------------------------------------------- /baseline_solution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Онлайн-хакатон Райффайзенбанка в области Data Science\n", 8 | "\n", 9 | "\n", 10 | "Ключевая информация:\n", 11 | "\n", 12 | "* Скор на паблике: 1.4241\n", 13 | "* Примерное место на паблике: 43\n", 14 | "* Обучение только на `price_type == 1`\n", 15 | "* Рассказ про adversarial validation\n", 16 | "* Рассказ про схему валидации\n", 17 | "* Визуализация геоданных с помощью библиотеки keplergl\n", 18 | "\n", 19 | "\n", 20 | "----" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "Установим необходимые библиотеки" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": { 34 | "ExecuteTime": { 35 | "end_time": "2021-09-25T18:31:26.101470Z", 36 | "start_time": "2021-09-25T18:31:24.514293Z" 37 | } 38 | }, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/site-packages (4.50.2)\n", 45 | "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/site-packages (3.3.2)\n", 46 | "Requirement already satisfied: catboost in /usr/local/lib/python3.8/site-packages (0.26)\n", 47 | "Requirement already satisfied: keplergl in /usr/local/lib/python3.8/site-packages (0.2.1)\n", 48 | "Requirement already satisfied: six in /usr/local/lib/python3.8/site-packages (from catboost) (1.15.0)\n", 49 | "Requirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.8/site-packages (from catboost) (1.19.2)\n", 50 | "Requirement already satisfied: graphviz in /usr/local/lib/python3.8/site-packages (from catboost) (0.16)\n", 51 | "Requirement already satisfied: scipy in /usr/local/lib/python3.8/site-packages (from catboost) (1.5.2)\n", 52 | "Requirement already satisfied: plotly in /usr/local/lib/python3.8/site-packages (from catboost) (5.1.0)\n", 53 | "Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.8/site-packages (from catboost) (1.1.2)\n", 54 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.8/site-packages (from pandas>=0.24.0->catboost) (2020.1)\n", 55 | "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/site-packages (from pandas>=0.24.0->catboost) (2.8.1)\n", 56 | "Requirement already satisfied: Shapely>=1.6.4.post2 in /usr/local/lib/python3.8/site-packages (from keplergl) (1.7.1)\n", 57 | "Requirement already satisfied: ipywidgets<8,>=7.0.0 in /usr/local/lib/python3.8/site-packages (from keplergl) (7.5.1)\n", 58 | "Requirement already satisfied: geopandas>=0.5.0 in /usr/local/lib/python3.8/site-packages (from keplergl) (0.8.1)\n", 59 | "Requirement already satisfied: traittypes>=0.2.1 in /usr/local/lib/python3.8/site-packages (from keplergl) (0.2.1)\n", 60 | "Requirement already satisfied: fiona in /usr/local/lib/python3.8/site-packages (from geopandas>=0.5.0->keplergl) (1.8.17)\n", 61 | "Requirement already satisfied: pyproj>=2.2.0 in /usr/local/lib/python3.8/site-packages (from geopandas>=0.5.0->keplergl) (2.6.1.post1)\n", 62 | "Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.8/site-packages (from ipywidgets<8,>=7.0.0->keplergl) (5.0.7)\n", 63 | "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.8/site-packages (from ipywidgets<8,>=7.0.0->keplergl) (5.0.4)\n", 64 | "Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.8/site-packages (from ipywidgets<8,>=7.0.0->keplergl) (5.3.4)\n", 65 | "Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.8/site-packages (from ipywidgets<8,>=7.0.0->keplergl) (3.5.1)\n", 66 | "Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.8/site-packages (from ipywidgets<8,>=7.0.0->keplergl) (7.18.1)\n", 67 | "Requirement already satisfied: appnope in /usr/local/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets<8,>=7.0.0->keplergl) (0.1.0)\n", 68 | "Requirement already satisfied: jupyter-client in /usr/local/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets<8,>=7.0.0->keplergl) (6.1.7)\n", 69 | "Requirement already satisfied: tornado>=4.2 in /usr/local/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets<8,>=7.0.0->keplergl) (6.0.4)\n", 70 | "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (53.0.0)\n", 71 | "Requirement already satisfied: jedi>=0.10 in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (0.17.2)\n", 72 | "Requirement already satisfied: pickleshare in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (0.7.5)\n", 73 | "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (4.8.0)\n", 74 | "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (3.0.7)\n", 75 | "Requirement already satisfied: pygments in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (2.6.1)\n", 76 | "Requirement already satisfied: backcall in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (0.2.0)\n", 77 | "Requirement already satisfied: decorator in /usr/local/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (4.4.2)\n", 78 | "Requirement already satisfied: parso<0.8.0,>=0.7.0 in /usr/local/lib/python3.8/site-packages (from jedi>=0.10->ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (0.7.1)\n", 79 | "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets<8,>=7.0.0->keplergl) (0.2.0)\n", 80 | "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets<8,>=7.0.0->keplergl) (4.6.3)\n", 81 | "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets<8,>=7.0.0->keplergl) (3.2.0)\n", 82 | "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets<8,>=7.0.0->keplergl) (20.2.0)\n", 83 | "Requirement already satisfied: pyrsistent>=0.14.0 in /usr/local/lib/python3.8/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets<8,>=7.0.0->keplergl) (0.17.3)\n", 84 | "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.8/site-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (0.6.0)\n", 85 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets<8,>=7.0.0->keplergl) (0.2.5)\n", 86 | "Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.8/site-packages (from widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (6.1.4)\n", 87 | "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.9.0)\n", 88 | "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (20.1.0)\n", 89 | "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (19.0.2)\n", 90 | "Requirement already satisfied: Send2Trash in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (1.5.0)\n", 91 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (2.11.2)\n", 92 | "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.8.0)\n", 93 | "Requirement already satisfied: nbconvert in /usr/local/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (6.0.4)\n", 94 | "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/site-packages (from matplotlib) (7.2.0)\n", 95 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/site-packages (from matplotlib) (0.10.0)\n", 96 | "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n", 97 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/site-packages (from matplotlib) (1.2.0)\n", 98 | "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/site-packages (from matplotlib) (2020.6.20)\n", 99 | "Requirement already satisfied: cffi>=1.0.0 in /usr/local/lib/python3.8/site-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (1.14.2)\n", 100 | "Requirement already satisfied: pycparser in /usr/local/lib/python3.8/site-packages (from cffi>=1.0.0->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (2.20)\n", 101 | "Requirement already satisfied: click<8,>=4.0 in /usr/local/lib/python3.8/site-packages (from fiona->geopandas>=0.5.0->keplergl) (7.1.2)\n", 102 | "Requirement already satisfied: munch in /usr/local/lib/python3.8/site-packages (from fiona->geopandas>=0.5.0->keplergl) (2.5.0)\n", 103 | "Requirement already satisfied: click-plugins>=1.0 in /usr/local/lib/python3.8/site-packages (from fiona->geopandas>=0.5.0->keplergl) (1.1.1)\n", 104 | "Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.8/site-packages (from fiona->geopandas>=0.5.0->keplergl) (0.5.0)\n" 105 | ] 106 | }, 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.8/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (1.1.1)\n", 112 | "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.1.1)\n", 113 | "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.8.4)\n", 114 | "Requirement already satisfied: bleach in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (3.2.1)\n", 115 | "Requirement already satisfied: testpath in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.4.4)\n", 116 | "Requirement already satisfied: defusedxml in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.6.0)\n", 117 | "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.5.0)\n", 118 | "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.3)\n", 119 | "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (1.4.2)\n", 120 | "Requirement already satisfied: async-generator in /usr/local/lib/python3.8/site-packages (from nbclient<0.6.0,>=0.5.0->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (1.10)\n", 121 | "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.8/site-packages (from nbclient<0.6.0,>=0.5.0->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (1.4.0)\n", 122 | "Requirement already satisfied: webencodings in /usr/local/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (0.5.1)\n", 123 | "Requirement already satisfied: packaging in /usr/local/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8,>=7.0.0->keplergl) (20.4)\n", 124 | "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.8/site-packages (from plotly->catboost) (7.0.0)\n", 125 | "Note: you may need to restart the kernel to use updated packages.\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "pip install tqdm matplotlib catboost keplergl" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 2, 136 | "metadata": { 137 | "ExecuteTime": { 138 | "end_time": "2021-09-25T18:31:27.230045Z", 139 | "start_time": "2021-09-25T18:31:26.103332Z" 140 | } 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "import pandas as pd\n", 145 | "from matplotlib import pyplot as plt\n", 146 | "import numpy as np\n", 147 | "from catboost import CatBoostClassifier, CatBoostRegressor\n", 148 | "\n", 149 | "from sklearn.metrics import roc_auc_score\n", 150 | "from sklearn.base import clone\n", 151 | "from tqdm.auto import tqdm\n", 152 | "\n", 153 | "from keplergl import KeplerGl\n", 154 | "\n", 155 | "np.random.seed(42)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "Метрика, которую предоставили организаторы" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 3, 168 | "metadata": { 169 | "ExecuteTime": { 170 | "end_time": "2021-09-25T18:31:27.236022Z", 171 | "start_time": "2021-09-25T18:31:27.231896Z" 172 | } 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "def deviation_metric_one_sample(y_true, y_pred):\n", 177 | " deviation = (y_pred - y_true) / np.maximum(1e-8, y_true)\n", 178 | " if np.abs(deviation) <= 0.15: return 0\n", 179 | " elif deviation <= -0.6: return 9.9\n", 180 | " elif deviation < -0.15: return 1.1 * (deviation / 0.15 + 1) ** 2\n", 181 | " elif deviation < 0.6: return (deviation / 0.15 - 1) ** 2\n", 182 | " return 9\n", 183 | "\n", 184 | "def deviation_metric(y_true, y_pred):\n", 185 | " return np.array([deviation_metric_one_sample(y_true[n], y_pred[n]) for n in range(len(y_true))]).mean()" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 4, 191 | "metadata": { 192 | "ExecuteTime": { 193 | "end_time": "2021-09-25T18:31:29.658149Z", 194 | "start_time": "2021-09-25T18:31:27.238094Z" 195 | } 196 | }, 197 | "outputs": [ 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "/usr/local/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3145: DtypeWarning: Columns (1) have mixed types.Specify dtype option on import or set low_memory=False.\n", 203 | " has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "test = pd.read_csv('./data/test.csv').rename({'per_square_meter_price': 'target'}, axis=1)\n", 209 | "test['train'] = 0\n", 210 | "test['target'] = 0\n", 211 | "\n", 212 | "train = pd.read_csv('./data/train.csv').rename({'per_square_meter_price': 'target'}, axis=1)\n", 213 | "train['train'] = 1\n", 214 | "\n", 215 | "dataset = pd.concat([train, test])" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "Выделим важные колонки" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 5, 228 | "metadata": { 229 | "ExecuteTime": { 230 | "end_time": "2021-09-25T18:31:29.662569Z", 231 | "start_time": "2021-09-25T18:31:29.660009Z" 232 | } 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "key_cols = ['id', 'date', 'price_type', 'train', 'month', 'target']\n", 237 | "cat_cols = ['city', 'osm_city_nearest_name', 'region', 'realty_type', 'street', 'floor']" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "Для удобства объединим выборки и добавим некоторые важные колонки.\n", 245 | "\n", 246 | "Также отфильтруем датасет :)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 6, 252 | "metadata": { 253 | "ExecuteTime": { 254 | "end_time": "2021-09-25T18:31:30.558167Z", 255 | "start_time": "2021-09-25T18:31:29.664648Z" 256 | } 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "dataset[cat_cols] = dataset[cat_cols].astype(str).fillna('NAN')\n", 261 | "dataset['date'] = pd.to_datetime(dataset['date'])\n", 262 | "dataset['month'] = (\n", 263 | " dataset['date'].dt.floor('d') + pd.offsets.MonthEnd(0) - pd.offsets.MonthBegin(1)\n", 264 | ")\n", 265 | "\n", 266 | "dataset = dataset[dataset['price_type']==1]" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "# Небольшая визуализация" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 7, 279 | "metadata": { 280 | "ExecuteTime": { 281 | "end_time": "2021-09-25T18:31:30.750349Z", 282 | "start_time": "2021-09-25T18:31:30.559827Z" 283 | } 284 | }, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "image/png": "\n", 289 | "text/plain": [ 290 | "
" 291 | ] 292 | }, 293 | "metadata": { 294 | "needs_background": "light" 295 | }, 296 | "output_type": "display_data" 297 | } 298 | ], 299 | "source": [ 300 | "dataset['month'].value_counts().sort_index().plot(title='Кол-во объектов по месяцам');" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": {}, 306 | "source": [ 307 | "Интересно, что в последние два месяца, объект гораздо больше, чем в другие месяцы...\n", 308 | "\n", 309 | "Посмотрим, как географически распределены train и test:" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 8, 315 | "metadata": { 316 | "ExecuteTime": { 317 | "end_time": "2021-09-25T18:31:30.864935Z", 318 | "start_time": "2021-09-25T18:31:30.753719Z" 319 | } 320 | }, 321 | "outputs": [ 322 | { 323 | "name": "stdout", 324 | "output_type": "stream", 325 | "text": [ 326 | "User Guide: https://docs.kepler.gl/docs/keplergl-jupyter\n" 327 | ] 328 | }, 329 | { 330 | "data": { 331 | "application/vnd.jupyter.widget-view+json": { 332 | "model_id": "e0a8fa94f39249308ae9041d235222f0", 333 | "version_major": 2, 334 | "version_minor": 0 335 | }, 336 | "text/plain": [ 337 | "KeplerGl(data={'test': lat lng\n", 338 | "0 51.709255 36.147908\n", 339 | "1 61.233240 73.462509\n", 340 | "2 …" 341 | ] 342 | }, 343 | "metadata": {}, 344 | "output_type": "display_data" 345 | } 346 | ], 347 | "source": [ 348 | "kepler_data = {\n", 349 | " \"test\": dataset[dataset['train']==0][['lat', 'lng']],\n", 350 | " \"train\": dataset[dataset['train']==1][['lat', 'lng']],\n", 351 | "}\n", 352 | "\n", 353 | "map1 = KeplerGl(height=400, data=kepler_data)\n", 354 | "map1" 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "Выглядит достаточно равномерно" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": {}, 367 | "source": [ 368 | "# Adversarial validation\n", 369 | "\n", 370 | "В данном разделе проверим, насколько сильно наша обучающая выборка отличается от тестовой.\n", 371 | "\n", 372 | "\n", 373 | "Для этого назовем `единичками` объекты из обучающей выбокри, `ноликами` - объекты из тестовой выборки.\n", 374 | "И на исходных признаках построим модель машинного обучения, которая будет пытаться отличить обучающую выборку от тестовой на основе признаков.\n", 375 | "\n", 376 | "Так как данные распределены по времени, то чтобы оценить качество нашей модели оставим по последнему месяцу из train и test.\n", 377 | "\n", 378 | "Если выборки не отличаются, то качество такой модели должно быть случайной.\n", 379 | "\n", 380 | "Для оценки качества модели будем использовать `ROC AUC`. Как мы знаем, `ROC AUC` случайного алгоритма должен быть в районе `0.5`." 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 9, 386 | "metadata": { 387 | "ExecuteTime": { 388 | "end_time": "2021-09-25T18:31:39.888229Z", 389 | "start_time": "2021-09-25T18:31:30.866752Z" 390 | } 391 | }, 392 | "outputs": [ 393 | { 394 | "name": "stdout", 395 | "output_type": "stream", 396 | "text": [ 397 | "Learning rate set to 0.022342\n", 398 | "0:\tlearn: 0.6909881\ttotal: 58.8ms\tremaining: 58.7s\n", 399 | "100:\tlearn: 0.6161894\ttotal: 1.11s\tremaining: 9.9s\n", 400 | "200:\tlearn: 0.5937461\ttotal: 1.97s\tremaining: 7.82s\n", 401 | "300:\tlearn: 0.5767069\ttotal: 2.77s\tremaining: 6.43s\n", 402 | "400:\tlearn: 0.5588018\ttotal: 3.56s\tremaining: 5.32s\n", 403 | "500:\tlearn: 0.5380880\ttotal: 4.49s\tremaining: 4.47s\n", 404 | "600:\tlearn: 0.5179787\ttotal: 5.39s\tremaining: 3.58s\n", 405 | "700:\tlearn: 0.4999933\ttotal: 6.25s\tremaining: 2.67s\n", 406 | "800:\tlearn: 0.4835562\ttotal: 7.14s\tremaining: 1.77s\n", 407 | "900:\tlearn: 0.4679176\ttotal: 7.97s\tremaining: 876ms\n", 408 | "999:\tlearn: 0.4541192\ttotal: 8.81s\tremaining: 0us\n", 409 | "0.48765809873521015\n" 410 | ] 411 | } 412 | ], 413 | "source": [ 414 | "oot_train_month_adv = '2020-08-01'\n", 415 | "oot_val_month_adv = '2020-12-01'\n", 416 | "\n", 417 | "Xy_train_adv = dataset[~dataset['month'].isin([oot_train_month_adv, oot_val_month_adv])].reset_index(drop=True)\n", 418 | "Xy_test_adv = dataset[dataset['month'].isin([oot_train_month_adv, oot_val_month_adv])].reset_index(drop=True)\n", 419 | "\n", 420 | "\n", 421 | "adv_model = CatBoostClassifier(verbose=100)\n", 422 | "adv_model.fit(\n", 423 | " Xy_train_adv.drop(key_cols, axis=1), \n", 424 | " Xy_train_adv['train'],\n", 425 | " cat_features=cat_cols\n", 426 | ")\n", 427 | "\n", 428 | "\n", 429 | "predict_adv = adv_model.predict(Xy_test_adv.drop(key_cols, axis=1))\n", 430 | "\n", 431 | "print(roc_auc_score(Xy_test_adv['train'], predict_adv))" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "Качество нашей модели получилось околослучайное, а значит все в порядке" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "# Модель" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "Прежде, чем строить модель, давайте разберемся, как будем валидироваться.\n", 453 | "\n", 454 | "Как говорилось выше, данные распределены по времени, поэтому валидироваться не обходимо также по времени:\n", 455 | "\n", 456 | "![alt text](https://miro.medium.com/max/558/1*AXRu72CV1hdjLfODFGbMWQ.png \"Title\")\n", 457 | "\n" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 10, 463 | "metadata": { 464 | "ExecuteTime": { 465 | "end_time": "2021-09-25T18:31:39.896342Z", 466 | "start_time": "2021-09-25T18:31:39.889996Z" 467 | } 468 | }, 469 | "outputs": [], 470 | "source": [ 471 | "# месяц с которого начнем валидироваться\n", 472 | "start_month = '2020-04-01'\n", 473 | "\n", 474 | "# месяцы для валидации\n", 475 | "val_months = (\n", 476 | " dataset[(dataset['train']==1) & (dataset['month'] > start_month)]['month']\n", 477 | " .drop_duplicates()\n", 478 | " .sort_values()\n", 479 | " .tolist()\n", 480 | ")" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 11, 486 | "metadata": { 487 | "ExecuteTime": { 488 | "end_time": "2021-09-25T18:32:06.507988Z", 489 | "start_time": "2021-09-25T18:31:39.898647Z" 490 | } 491 | }, 492 | "outputs": [ 493 | { 494 | "data": { 495 | "application/vnd.jupyter.widget-view+json": { 496 | "model_id": "f91d64541e1548d1b3575e83280da1b2", 497 | "version_major": 2, 498 | "version_minor": 0 499 | }, 500 | "text/plain": [ 501 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))" 502 | ] 503 | }, 504 | "metadata": {}, 505 | "output_type": "display_data" 506 | }, 507 | { 508 | "name": "stdout", 509 | "output_type": "stream", 510 | "text": [ 511 | "\n", 512 | "Средняя метрика по бинам: 1.594\n", 513 | "Отклонение метрики по бинам: 0.160\n" 514 | ] 515 | } 516 | ], 517 | "source": [ 518 | "main_model = CatBoostRegressor(loss_function='MAE', verbose=0)\n", 519 | "\n", 520 | "scores = []\n", 521 | "\n", 522 | "for month in tqdm(val_months):\n", 523 | " Xy_train = dataset[(dataset['month'] < month)].reset_index(drop=True)\n", 524 | " Xy_val = dataset[(dataset['month'] == month)].reset_index(drop=True)\n", 525 | " \n", 526 | " model = clone(main_model)\n", 527 | " model = CatBoostRegressor(loss_function='MAE', verbose=0)\n", 528 | " model = model.fit(Xy_train.drop(key_cols, axis=1), Xy_train['target'], cat_features=cat_cols)\n", 529 | "\n", 530 | " metric = deviation_metric(Xy_val['target'], model.predict(Xy_val.drop(key_cols, axis=1)))\n", 531 | " scores.append(metric)\n", 532 | " \n", 533 | "print(f'Средняя метрика по бинам: {np.mean(scores):.3f}')\n", 534 | "print(f'Отклонение метрики по бинам: {np.std(scores):.3f}')" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 12, 540 | "metadata": { 541 | "ExecuteTime": { 542 | "end_time": "2021-09-25T18:32:06.536279Z", 543 | "start_time": "2021-09-25T18:32:06.509814Z" 544 | } 545 | }, 546 | "outputs": [ 547 | { 548 | "data": { 549 | "text/html": [ 550 | "
\n", 551 | "\n", 564 | "\n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | "
Feature IdImportances
0city8.071964
1total_square7.342278
2floor4.817922
3realty_type4.656558
4osm_train_stop_closest_dist3.711538
.........
68osm_building_points_in_0.010.200341
69osm_healthcare_points_in_0.0050.177565
70osm_leisure_points_in_0.00750.161400
71osm_building_points_in_0.0010.131856
72osm_train_stop_points_in_0.00750.000000
\n", 630 | "

73 rows × 2 columns

\n", 631 | "
" 632 | ], 633 | "text/plain": [ 634 | " Feature Id Importances\n", 635 | "0 city 8.071964\n", 636 | "1 total_square 7.342278\n", 637 | "2 floor 4.817922\n", 638 | "3 realty_type 4.656558\n", 639 | "4 osm_train_stop_closest_dist 3.711538\n", 640 | ".. ... ...\n", 641 | "68 osm_building_points_in_0.01 0.200341\n", 642 | "69 osm_healthcare_points_in_0.005 0.177565\n", 643 | "70 osm_leisure_points_in_0.0075 0.161400\n", 644 | "71 osm_building_points_in_0.001 0.131856\n", 645 | "72 osm_train_stop_points_in_0.0075 0.000000\n", 646 | "\n", 647 | "[73 rows x 2 columns]" 648 | ] 649 | }, 650 | "execution_count": 12, 651 | "metadata": {}, 652 | "output_type": "execute_result" 653 | } 654 | ], 655 | "source": [ 656 | "model.get_feature_importance(prettified=True)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "markdown", 661 | "metadata": {}, 662 | "source": [ 663 | "# Финальная модель\n", 664 | "\n", 665 | "Обучим финальную модель и сделаем предикт" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": 13, 671 | "metadata": { 672 | "ExecuteTime": { 673 | "end_time": "2021-09-25T18:32:14.105074Z", 674 | "start_time": "2021-09-25T18:32:06.538334Z" 675 | } 676 | }, 677 | "outputs": [], 678 | "source": [ 679 | "Xy_train = dataset[(dataset['train'] == 1)].reset_index(drop=True)\n", 680 | "Xy_test = dataset[(dataset['train'] == 0)].reset_index(drop=True)\n", 681 | "\n", 682 | "model = clone(main_model)\n", 683 | "model = CatBoostRegressor(loss_function='MAE', verbose=0)\n", 684 | "model = model.fit(Xy_train.drop(key_cols, axis=1), Xy_train['target'], cat_features=cat_cols)" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 14, 690 | "metadata": { 691 | "ExecuteTime": { 692 | "end_time": "2021-09-25T18:32:14.133472Z", 693 | "start_time": "2021-09-25T18:32:14.106470Z" 694 | } 695 | }, 696 | "outputs": [], 697 | "source": [ 698 | "Xy_test['per_square_meter_price'] = model.predict(Xy_test.drop(key_cols, axis=1))\n", 699 | "Xy_test[['id', 'per_square_meter_price']].to_csv('sub1.csv', index=False)" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "metadata": {}, 705 | "source": [ 706 | "# Дальнейшие идеи\n", 707 | "\n", 708 | "* Покрутить признаки\n", 709 | "* потюнить модель\n", 710 | "* Подумать, что можно сделать с отфильтрованной частью датасета\n", 711 | "* Поиспользовать внешние данные :)" 712 | ] 713 | } 714 | ], 715 | "metadata": { 716 | "kernelspec": { 717 | "display_name": "Python 3", 718 | "language": "python", 719 | "name": "python3" 720 | }, 721 | "language_info": { 722 | "codemirror_mode": { 723 | "name": "ipython", 724 | "version": 3 725 | }, 726 | "file_extension": ".py", 727 | "mimetype": "text/x-python", 728 | "name": "python", 729 | "nbconvert_exporter": "python", 730 | "pygments_lexer": "ipython3", 731 | "version": "3.8.8" 732 | }, 733 | "toc": { 734 | "base_numbering": 1, 735 | "nav_menu": {}, 736 | "number_sections": true, 737 | "sideBar": true, 738 | "skip_h1_title": false, 739 | "title_cell": "Table of Contents", 740 | "title_sidebar": "Contents", 741 | "toc_cell": false, 742 | "toc_position": {}, 743 | "toc_section_display": true, 744 | "toc_window_display": false 745 | } 746 | }, 747 | "nbformat": 4, 748 | "nbformat_minor": 4 749 | } 750 | --------------------------------------------------------------------------------