├── README.md
├── data
│   ├── 1.txt
│   ├── others
│   │   ├── models
│   │   │   ├── .gitkeep
│   │   │   └── no_reg_save_0917.pth
│   │   ├── no_reg_0917
│   │   ├── test_df.pkl.zip
│   │   └── train.csv.zip
│   ├── test_clean
│   └── train_clean.zip
├── exp
│   └── nb_.py
├── interpret_tabular.ipynb
├── p_fastai.ipynb
└── test_tablr_mixup_quick_n_dirty_messy_code.ipynb

/README.md:
--------------------------------------------------------------------------------
# fastai-shared-notebooks

Some useful functions for analysing fastai tabular models.
Hope the [notebook](https://github.com/Pak911/fastai-shared-notebooks/blob/master/interpret_tabular.ipynb) is self-explanatory :)

Contains functions and examples of:
- How to make **predictions on a new dataset** with a trained fastai model (learner)
- How to **use trained embeddings** in another process (to train a **Random Forest** in this case)
- How to calculate **feature importance** in fastai
- How to calculate **partial dependence** for categorical features
- How to plot **dendrograms** for the data
- How to plot **embeddings**

Thanks to:
- the [fastai](https://github.com/fastai/fastai) framework itself
- the fastai Machine Learning course, [lesson 3](https://youtu.be/YSFG_W8JxBo?t=4048) and [lesson 4](https://www.youtube.com/watch?v=YSFG_W8JxBo). Or, if you prefer, the [notebook](https://github.com/fastai/fastai/blob/master/courses/ml1/lesson2-rf_interpretation.ipynb) behind these tutorials.
  My notebook is pretty much an implementation of these ideas for the fastai tabular model case (with some of my own thoughts and experiment results).

--------------------------------------------------------------------------------
/data/1.txt:
--------------------------------------------------------------------------------
Test

--------------------------------------------------------------------------------
/data/others/models/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/others/models/no_reg_save_0917.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pak911/fastai-shared-notebooks/ae5e01216e6a31ccb55b23bc2ae73bc23350b987/data/others/models/no_reg_save_0917.pth

--------------------------------------------------------------------------------
/data/others/no_reg_0917:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pak911/fastai-shared-notebooks/ae5e01216e6a31ccb55b23bc2ae73bc23350b987/data/others/no_reg_0917

--------------------------------------------------------------------------------
/data/others/test_df.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pak911/fastai-shared-notebooks/ae5e01216e6a31ccb55b23bc2ae73bc23350b987/data/others/test_df.pkl.zip

--------------------------------------------------------------------------------
/data/others/train.csv.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pak911/fastai-shared-notebooks/ae5e01216e6a31ccb55b23bc2ae73bc23350b987/data/others/train.csv.zip

--------------------------------------------------------------------------------
/data/test_clean:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pak911/fastai-shared-notebooks/ae5e01216e6a31ccb55b23bc2ae73bc23350b987/data/test_clean

--------------------------------------------------------------------------------
/data/train_clean.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pak911/fastai-shared-notebooks/ae5e01216e6a31ccb55b23bc2ae73bc23350b987/data/train_clean.zip

--------------------------------------------------------------------------------
/exp/nb_.py:
--------------------------------------------------------------------------------

#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
# files to edit: 01-main_train.ipynb 02-main_train-experiments.ipynb 03-main_simple-FI.ipynb 04-main_retrain-FI.ipynb 05-main_part_dep.ipynb 06-main_dendrogram-and-dem-red.ipynb _functions.ipynb check_data.ipynb contract_till_interpret_importance_clean.ipynb interpret_tabular.ipynb test_mixup.ipynb

from fastai.layers import FlattenedLoss
from fastai.tabular import *
from fastai.basic_train import _loss_func2activ
from fastai.callbacks import CSVLogger
from scipy.cluster import hierarchy as hc
from sklearn import manifold
import pickle

def _list_diff(list_1, list_2):
    diff = set(list_1) - set(list_2)
    return [item for item in list_1 if item in diff]

def list_diff(list1, list2, *args):
    diff = _list_diff(list1, list2)
    for arg in args:
        diff = _list_diff(diff, arg)
    return diff

def exp_mmape(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Exp median absolute percentage error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    pred, targ = torch.exp(pred), torch.exp(targ)
    pct_var = (targ - pred)/targ
    return torch.abs(pct_var).median()

def MAELossFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.L1Loss` (MAE), but flattens input and target."
    return FlattenedLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

def _list_diff(list_1, list_2):
    diff = set(list_1) - set(list_2)
    return [item for item in list_1 if item in diff]

def list_diff(list1, list2, *args):
    diff = _list_diff(list1, list2)
    for arg in args:
        diff = _list_diff(diff, arg)
    return diff

def which_elms(values, in_list):
    '''
    Just outputs the elements from values that are in the list in_list
    '''
    return [x for x in values if (x in in_list)]

def is_in_list(values, in_list):
    '''
    Just outputs whether one of the elements from values is in the list in_list
    '''
    if (len(which_elms(values, in_list)) > 0):
        return True
    else:
        return False
def apply_fill_n_catf(df:DataFrame, learn:Learner)->DataFrame:
    '''
    Reapplies FillMissing and Categorify to the given dataframe.
    '''
    df_copy = df.copy()
    fill, catf = None, None
    is_alone = True if (len(df) == 1) else False

    proc = learn.data.processor[0]
    if (is_alone):
        df_copy = df_copy.append(df_copy.iloc[0])

    for prc in proc.procs:
        if (type(prc) == FillMissing):
            fill = prc
        elif (type(prc) == Categorify):
            catf = prc
    if (fill is not None):
        fill.apply_test(df_copy)

    if (catf is not None):
        catf.apply_test(df_copy)
        for c in catf.cat_names:
            df_copy[c] = (df_copy[c].cat.codes).astype(np.int64) + 1
        cats = df_copy[catf.cat_names].to_numpy()

    # ugly workaround, as apparently apply_test doesn't work with a lone row
    if (is_alone):
        df_copy = df_copy[:1]

    return df_copy


def apply_fill(df:DataFrame, learn:Learner)->DataFrame:
    '''
    Reapplies FillMissing to the given dataframe.
    '''
    df_copy = df.copy()
    fill = None
    is_alone = True if (len(df) == 1) else False

    proc = learn.data.processor[0]
    if (is_alone):
        df_copy = df_copy.append(df_copy.iloc[0])

    for prc in proc.procs:
        if (type(prc) == FillMissing):
            fill = prc
    if (fill is not None):
        fill.apply_test(df_copy)

    # ugly workaround, as apparently apply_test doesn't work with a lone row
    if (is_alone):
        df_copy = df_copy[:1]

    return df_copy


def get_model_real_input(df:DataFrame, learn:Learner, bs:int=None)->Tensor:

    df_copy = df.copy()
    fill, catf, norm = None, None, None
    cats, conts = None, None
    is_alone = True if (len(df) == 1) else False

    proc = learn.data.processor[0]
    if (is_alone):
        df_copy = df_copy.append(df_copy.iloc[0])

    for prc in proc.procs:
        if (type(prc) == FillMissing):
            fill = prc
        elif (type(prc) == Categorify):
            catf = prc
        elif (type(prc) == Normalize):
            norm = prc
    if (fill is not None):
        fill.apply_test(df_copy)
    if (catf is not None):
        catf.apply_test(df_copy)
        for c in catf.cat_names:
            df_copy[c] = (df_copy[c].cat.codes).astype(np.int64) + 1
        cats = df_copy[catf.cat_names].to_numpy()

    if (norm is not None):
        norm.apply_test(df_copy)
        conts = df_copy[norm.cont_names].to_numpy().astype('float32')

    # ugly workaround, as apparently apply_test doesn't work with a lone row
    if (is_alone):
        xs = [torch.tensor([cats[0]], device=learn.data.device), torch.tensor([conts[0]], device=learn.data.device)]
    else:
        if (bs is None):
            xs = [torch.tensor(cats, device=learn.data.device), torch.tensor(conts, device=learn.data.device)]
        elif (bs > 0):
            xs = [list(chunks(l=torch.tensor(cats, device=learn.data.device), n=bs)),
                  list(chunks(l=torch.tensor(conts, device=learn.data.device), n=bs))]

    return xs


def get_cust_preds(df:DataFrame, learn:Learner, bs:int=None, parent=None)->Tensor:
    '''
    Uses an existing model (learn.model) to predict the output on a whole new dataframe at once
    (learn.predict does it for one row at a time, which is pretty slow).
    '''
    def turn_to_activ(learn, acts):
        activ = _loss_func2activ(learn.loss_func)
        if activ is not None:
            return to_np(activ(acts))
        else:
            return to_np(acts)

    xs = get_model_real_input(df=df, learn=learn, bs=bs)
    learn.model.eval();
    if (bs is None):
        outp = learn.model(x_cat=xs[0], x_cont=xs[1])
    elif (bs > 0):
        res = []
        for ca, co in zip(xs[0], xs[1]):
            res.append(to_np(learn.model(x_cat=ca, x_cont=co)))
        # double conversion to save gpu memory
        outp = tensor(np.concatenate(res, axis=0))
    return turn_to_activ(learn=learn, acts=outp)
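
# A minimal usage sketch of get_cust_preds (editor's example, not part of the
# autogenerated module). `learn` is assumed to be a trained tabular learner and
# `test_df` a dataframe with the columns the learner was trained on, e.g. the
# one shipped in data/others/test_df.pkl.zip.
#
# For ex.
# test_df = pd.read_pickle('data/others/test_df.pkl.zip')
# preds = get_cust_preds(df=test_df, learn=learn, bs=1024)  # bs=None predicts everything in one go
# preds[:5]  # numpy array of (activated) model outputs, one row per input row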

def convert_dep_col(df:DataFrame, dep_col:AnyStr, learn:Learner)->Tensor:
    '''
    Converts a dataframe column, named the "dependent column", into a tensor that can later
    be used to compare with predictions.
    Log will be applied if it was applied in the training dataset
    '''
    actls = df[dep_col].T.to_numpy()[np.newaxis].T.astype('float32')
    actls = np.log(actls) if (hasattr(learn.data, 'log') and learn.data.log) else actls
    return torch.tensor(actls, device=learn.data.device)


def calc_loss(func:Callable, pred:Tensor, targ:Tensor, device=None)->Rank0Tensor:
    '''
    Calculates the error between predictions and actuals with a given metrics function
    '''
    if (device is None):
        return func(pred, targ)
    else:
        return func(torch.tensor(pred, device=device), targ)


def calc_error(df:DataFrame, learn:Learner, dep_col:AnyStr,
               func:Callable, bs:int=None)->float:
    '''
    Wrapping function to calculate the error for a new dataframe on an existing learner (learn.model)
    See the docstrings of the functions above for details
    '''
    preds = get_cust_preds(df=df, learn=learn, bs=bs)
    actls = convert_dep_col(df, dep_col, learn)
    error = calc_loss(func, pred=preds, targ=actls, device=learn.data.device)
    return float(error)

def emb_fwrd_sim(model, x_cat:Tensor, x_cont:Tensor)->Tensor:
    '''
    Part that was taken directly from the fastai tabular model source :)
    Gets the inner representation of the input dataframe (Categorified, Filled and Normalized)
    and processes it with the embedding 'prelayer'. Continuous variables are also processed
    with BatchNorm if needed.
    As a result, the output is what the model gets as input to its layers (the embeddings are
    in fact not layers, but come before them)
    '''
    if model.n_emb != 0:
        x = [e(x_cat[:,i]) for i,e in enumerate(model.embeds)]
        x = torch.cat(x, 1)
        x = model.emb_drop(x)
    if model.n_cont != 0:
        x_cont = model.bn_cont(x_cont)
        x = torch.cat([x, x_cont], 1) if model.n_emb != 0 else x_cont
    return x


def get_inner_repr(df:DataFrame, learn:Learner)->Tensor:
    '''
    Takes a new dataframe that has the categorical and continuous columns the learner was
    trained with (they are taken from the learner automatically)
    and outputs the inner representation of the data -- what the model gets after the embeddings.
    This is useful e.g. for using the learnt embeddings in a random forest:
    the output can be fed directly to an RF learner (after turning it into numpy if needed)
    '''
    xs = get_model_real_input(df=df, learn=learn)
    return emb_fwrd_sim(model=learn.model, x_cat=xs[0], x_cont=xs[1])
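
# Editor's sketch of the "use trained embeddings in a Random Forest" idea from
# the README. `train_df` and the 'hand' target are illustrative assumptions.
#
# For ex.
# from sklearn.ensemble import RandomForestClassifier
# X = to_np(get_inner_repr(df=train_df, learn=learn))  # inner representation after embeddings
# y = train_df['hand'].values
# rf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
# rf.fit(X, y)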

def calc_error_mixed_col(df:DataFrame,
                         learn:Learner,
                         dep_col:AnyStr,
                         sampl_col:AnyStr,
                         func:Callable,
                         bs:int=None,
                         rounds=5)->float:
    df_temp = pd.concat([df]*rounds, ignore_index=True).copy()
    df_temp[sampl_col] = np.random.permutation(df_temp[sampl_col].values)
    return calc_error(df=df_temp, learn=learn, dep_col=dep_col, func=func, bs=bs)


def get_columns(learn:Learner)->tuple:
    cats, cats_temp, conts, conts_temp = [], [], [], []
    proc = learn.data.processor[0]
    for prc in proc.procs:
        if (type(prc) == Categorify):
            cats_temp = prc.cat_names
        elif (type(prc) == Normalize):
            conts = prc.cont_names

    # delete _na columns
    conts_temp = [cont+'_na' for cont in conts]
    for cat in cats_temp:
        if (cat not in conts_temp):
            cats.append(cat)

    return cats, conts


def calc_feat_importance(df:DataFrame,
                         learn:Learner,
                         dep_col:AnyStr,
                         func:Callable,
                         bs:int=None,
                         rounds=5)->OrderedDict:

    base_error = calc_error(df=df, learn=learn, dep_col=dep_col, func=func, bs=bs)
    cats, conts = get_columns(learn=learn)
    importance = {}
    pbar = master_bar(cats+conts, total=len(cats+conts))
    for col in pbar:
        importance[col] = calc_error_mixed_col(df=df, learn=learn, dep_col=dep_col,
                                               sampl_col=col, func=func, bs=bs, rounds=rounds)
        _ = progress_bar(range(1), display=False, parent=pbar)  # looks like fastprogress doesn't work without 2nd bar :(
    for key, value in importance.items():
        importance[key] = (value - base_error)/base_error
    return collections.OrderedDict(sorted(importance.items(), key=lambda kv: kv[1], reverse=True))


def calc_fi_custom(df:DataFrame,
                   learn:Learner,
                   dep_col:AnyStr,
                   fields:List,
                   func:Callable,
                   bs:int=None,
                   rounds=5)->OrderedDict:

    base_error = calc_error(df=df, learn=learn, dep_col=dep_col, func=func, bs=bs)
    importance = {}
    pbar = master_bar(fields, total=len(fields))
    for field in pbar:
        key = field if isinstance(field, str) else ', '.join(str(e) for e in field)
        importance[key] = calc_error_mixed_col(df=df, learn=learn, dep_col=dep_col,
                                               sampl_col=field, func=func, bs=bs, rounds=rounds)
        _ = progress_bar(range(1), display=False, parent=pbar)  # looks like fastprogress doesn't work without 2nd bar :(
    for key, value in importance.items():
        importance[key] = (value - base_error)/base_error
    return collections.OrderedDict(sorted(importance.items(), key=lambda kv: kv[1], reverse=True))

def ord_dic_to_df(ord_dict:OrderedDict)->DataFrame:
    return pd.DataFrame([[k, v] for k, v in ord_dict.items()], columns=['feature', 'importance'])

def plot_importance(df:DataFrame, limit=20, asc=False):
    df_copy = df.copy()
    df_copy['feature'] = df_copy['feature'].str.slice(0,25)
    ax = df_copy.sort_values(by='importance', ascending=asc)[:limit].sort_values(by='importance', ascending=not(asc)).plot.barh(x="feature", y="importance", sort_columns=True, figsize=(10, 10))
    for p in ax.patches:
        ax.annotate(f'{p.get_width():.4f}', ((p.get_width() * 1.005), p.get_y() * 1.005))
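
# Editor's sketch of the permutation feature-importance flow above. `test_df`,
# 'hand' and `accuracy` are assumptions borrowed from p_fastai.ipynb, and the
# dependent column is assumed to be numeric. Note that calc_feat_importance
# reports the relative change of the metric after shuffling a column, so with
# an accuracy-like metric the most important features come out most negative.
#
# For ex.
# fi = calc_feat_importance(df=test_df, learn=learn, dep_col='hand', func=accuracy, bs=1024)
# plot_importance(ord_dic_to_df(fi))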


# Implement a function that returns a learner object in your notebook
#
# For ex.
# def build_learner_cur(df:DataFrame,
#                       bs:int,
#                       acc_func:Callable,
#                       dep_var:str,
#                       to_drop_cat:tuple=(),
#                       to_drop_cont:tuple=()):
#     cat_vars_mod = list_diff(cat_vars, to_drop_cat)
#     cont_vars_mod = list_diff(cont_vars, to_drop_cont)
#     data = (TabularList.from_df(df, path=path, cat_names=cat_vars_mod, cont_names=cont_vars_mod, procs=procs)
#             .split_by_idx(valid_idx)
#             .label_from_df(cols=dep_var, label_cls=FloatList, log=True)
#             .databunch(bs=bs))
#     np.random.seed(1001)
#     learn = tabular_learner(data,
#                             layers=p['layers'],
#                             ps=p['layers_drop'],
#                             emb_drop=p['emb_drop'],
#                             y_range=y_range,
#                             metrics=acc_func,
#                             loss_func=MAELossFlat(),
#                             callback_fns=[CSVLogger])
#     return learn
#

# Implement a function that does one training loop in your notebook
#
# For ex.
# def do_train_loop_cur(learn:Learner, cycles):
#     learn.fit_one_cycle(cyc_len=cycles, max_lr=p['max_lr'], wd=p['w_decay'])

def clear_pbar():
    # Just to clear the output. Yes, I know, I agree it's awful and should be refactored
    for _ in progress_bar(range(1), parent=None, leave=False):
        1==1

def extract_metrics_median(metrics_df:DataFrame, acc_func:Callable, bottom_X:float=0.2)->float:
    func_name = acc_func.__name__
    metr = metrics_df[func_name].to_numpy()
    subset = metr[np.argsort(metr)][-math.ceil(len(metr)*bottom_X):] if (func_name == 'accuracy') else metr[np.argsort(metr)][:math.ceil(len(metr)*bottom_X)]
    metrics = np.median(subset)
    return float(metrics)

def calc_valid_acc(learn:Learner, func:Callable)->float:
    metr = learn.csv_logger.read_logged_file()
    acc = extract_metrics_median(metrics_df=metr, acc_func=func)
    return float(acc)

def calc_acc(df:DataFrame,
             bs:int,
             acc_func:Callable,
             dep_var:str,
             to_drop_cat:tuple=(),
             to_drop_cont:tuple=(),
             load_learn:str=None,
             trains:int=1,
             cycles:int=80,
             is_overall_mode:bool=None)->float:
    learn = build_learner(df=df,
                          bs=bs,
                          acc_func=acc_func,
                          dep_var=dep_var,
                          to_drop_cat=to_drop_cat,
                          to_drop_cont=to_drop_cont)
    if (load_learn is not None):
        learn = learn.load(load_learn)
    else:
        for i in range(trains):
            print(f"Train {i+1} of {trains}")
            do_train_loop(learn, cycles)
    clear_pbar()
    if (is_overall_mode is None) or (is_overall_mode == False):
        acc = calc_valid_acc(learn=learn, func=acc_func)
    else:
        acc = calc_error(df=df, learn=learn, dep_col=dep_var, func=acc_func, bs=bs)
    return acc

def calc_1_imp_relearn(base_error:float,
                       df:DataFrame,
                       bs:int,
                       acc_func:Callable,
                       dep_var:str,
                       to_drop_cat:tuple=(),
                       to_drop_cont:tuple=(),
                       load_learn:str=None,
                       trains:int=1,
                       cycles:int=80,
                       is_overall_mode:bool=None)->float:
    error = calc_acc(df, bs, acc_func, dep_var, to_drop_cat, to_drop_cont, load_learn,
                     trains=trains, cycles=cycles, is_overall_mode=is_overall_mode)
    if (acc_func.__name__ == 'accuracy'):
        base_acc, accuracy = base_error, error  # just renamed for better understanding
        importance = (base_acc - accuracy)/base_acc
    else:
        importance = (error - base_error)/base_error
    return (list(to_drop_cat)+list(to_drop_cont), importance)

def print_importance_res(dropped:List, importance:float):
    print('Features '+', '.join(dropped)+' have an accumulated importance of')
    print(importance)

def calc_many_imps_relearn(base_error:float,
                           df:DataFrame,
                           bs:int,
                           acc_func:Callable,
                           dep_var:str,
                           to_drop_cats:tuple=(),
                           to_drop_conts:tuple=(),
                           load_learn:str=None,
                           trains:int=1,
                           cycles:int=80,
                           is_overall_mode:bool=None)->float:

    to_drop_cats = listify(to_drop_cats)
    to_drop_conts = listify(to_drop_conts)
    importances = {}

    overall = len(list(to_drop_cats)+list(to_drop_conts))
    for i, var in enumerate(to_drop_cats):
        var = listify(var)
        print(f"Categorical feature {i+1} of {len(to_drop_cats)}")
        imp = calc_1_imp_relearn(base_error, df, bs, acc_func,
                                 dep_var=dep_var, to_drop_cat=var, trains=trains,
                                 cycles=cycles, is_overall_mode=is_overall_mode)
        key = imp[0] if isinstance(imp[0], str) else ', '.join(str(e) for e in imp[0])
        importances[key] = imp

    for i, var in enumerate(to_drop_conts):
        var = listify(var)
        print(f"Continuous feature {i+1} of {len(to_drop_conts)}")
        imp = calc_1_imp_relearn(base_error, df, bs, acc_func,
                                 dep_var=dep_var, to_drop_cont=var, trains=trains,
                                 cycles=cycles, is_overall_mode=is_overall_mode)
        key = imp[0] if isinstance(imp[0], str) else ', '.join(str(e) for e in imp[0])
        importances[key] = imp

    return importances

def calc_mean_dict(lst):
    mean_dict = {}
    ln = len(lst)
    for key, value in lst[0].items():
        mean_dict[key] = np.zeros(ln)
    for i, row in enumerate(lst):
        for key, value in row.items():
            mean_dict[key][i] = value[1]
    for key, value in mean_dict.items():
        mean_dict[key] = np.median(value)

    return mean_dict


def calc_many_imps_relearn_steps(base_error:float,
                                 df:DataFrame,
                                 bs:int,
                                 acc_func:Callable,
                                 dep_var:str,
                                 to_drop_cats:tuple=(),
                                 to_drop_conts:tuple=(),
                                 load_learn:str=None,
                                 trains=1,
                                 cycles=80,
                                 rounds=5,
                                 is_overall_mode:bool=None)->dict:
    '''
    to_drop_cats and to_drop_conts can be tuples of tuples (lists of lists);
    this means we measure every item of the outer list by retraining without every
    column of that inner item in one turn (it is treated as one entity)
    '''
    acc = []
    for i in range(rounds):
        print(f"Round {i+1} of {rounds}")
        acc_ = calc_many_imps_relearn(base_error=base_error,
                                      df=df,
                                      bs=bs,
                                      acc_func=acc_func,
                                      dep_var=dep_var,
                                      to_drop_cats=to_drop_cats,
                                      to_drop_conts=to_drop_conts,
                                      trains=trains,
                                      cycles=cycles,
                                      is_overall_mode=is_overall_mode)
        acc.append(acc_)
    imp = calc_mean_dict(acc)
    return collections.OrderedDict(sorted(imp.items(), key=lambda kv: kv[1], reverse=True))


def calc_base_acc_steps(df:DataFrame,
                        bs:int,
                        acc_func:Callable,
                        dep_var:str,
                        trains=1,
                        cycles=80,
                        rounds=5,
                        is_overall_mode:bool=None)->float:
    base_acc = np.empty((rounds))
    for i in range(rounds):
        print(f"Round {i+1} of {rounds}")
        base_acc[i] = calc_acc(df=df, bs=bs,
                               acc_func=acc_func, dep_var=dep_var, trains=trains,
                               cycles=cycles, is_overall_mode=is_overall_mode)
    return np.median(base_acc)
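
# Editor's sketch of the retrain-based feature-importance flow. It assumes
# build_learner and do_train_loop are defined in your notebook as described
# above; column names and hyperparameters here are illustrative assumptions.
#
# For ex.
# base_acc = calc_base_acc_steps(df=df, bs=128, acc_func=accuracy, dep_var='hand',
#                                trains=1, cycles=20, rounds=3)
# imps = calc_many_imps_relearn_steps(base_error=base_acc, df=df, bs=128, acc_func=accuracy,
#                                     dep_var='hand', to_drop_cats=('S1', ('S1', 'C1')),
#                                     trains=1, cycles=20, rounds=3)
# plot_importance(ord_dic_to_df(imps))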

def get_field_uniq_x_coef(df:DataFrame, field:str, coef:float)->list:
    '''
    Outputs the threshold number of occurrences for the different variants of a list of
    columns (fields).
    In short: if coef is e.g. 0.9, the function outputs the number of occurrences that keeps
    everything but the 10% least used variants.
    If coef is more than 1.0, then 'coef' itself is used as the threshold
    '''
    if (coef > 1):
        return math.ceil(coef)
    coef = 0. if (coef < 0) else coef
    occs = df.groupby(field).size().reset_index(name="Times").sort_values(['Times'], ascending=False)
    num = math.ceil(coef*len(occs))
    if (num <= 0):
        # the number of occurrences is now = max_occs+1 (so no items will pass this filter)
        return occs.iloc[0]['Times'] + 1
    else:
        return occs.iloc[num-1]['Times']


def get_part_dep_one_list(df:DataFrame,
                          learn:Learner, bs:int=None, fields:list=(), coef:float=1.0, to_int:bool=False,
                          dep_name:str=None, is_sorted:bool=True)->DataFrame:
    '''
    Calculates the partial dependence for the columns in fields.
    fields is a list of lists of the columns we want to test. The inner items are treated
    as connected fields.
    For ex. fields = [['Store','StoreType']] means that Store and StoreType are treated as
    one entity (their values are substituted as a pair, not as separate values).
    coef is useful when we don't want to deal with all the variants, but only with the most
    common ones
    '''
    NAN_SUBST = '###na###'
    CONT_COLS = get_cont_cols(learn)
    if (dep_name is None):
        dep_name = 'dep_var'

    fields = listify(fields)
    df = apply_fill(df=df, learn=learn)

    # divide cont variables into groups
    if is_in_list(values=fields, in_list=CONT_COLS):
        for col in which_elms(values=fields, in_list=CONT_COLS):
            edges = np.histogram_bin_edges(a=df[col].dropna(), bins='auto')
            for x,y in zip(edges[::],edges[1::]):
                df.loc[(df[col] > x) & (df[col] < y), col] = (x+y)/2

    field_min_occ = get_field_uniq_x_coef(df=df, field=fields, coef=coef)
    df[fields] = df[fields].fillna(NAN_SUBST)  # to treat None as a separate field
    occs = df.groupby(fields).size().reset_index(name="Times").sort_values(['Times'], ascending=False)
    occs[fields] = occs[fields].replace(to_replace=NAN_SUBST, value=np.nan)  # get back Nones from NAN_SUBST
    df[fields] = df[fields].replace(to_replace=NAN_SUBST, value=np.nan)  # get back Nones from NAN_SUBST
    occs = occs[occs['Times'] >= field_min_occ]
    df_copy = df.merge(occs[fields]).copy()

    frame = []
    ln = len(occs)
    if (ln > 0):
        pbar = master_bar(occs.iterrows(), total=ln)
        for _, row in pbar:
            # We don't need to do df_copy = df.merge(occs[field]).copy() every time,
            # as we change the same column (set of columns) every time
            record = []
            pb = progress_bar(fields, display=False, parent=pbar)
            for fld in pb:
                df_copy[fld] = row[fld]
            preds = get_cust_preds(df=df_copy, learn=learn, bs=bs)
            preds = np.exp(np.median(preds)) if (hasattr(learn.data, 'log') and learn.data.log) else np.median(preds)
            pred = int(preds) if to_int else preds
            for fld in fields:
                record.append(row[fld])
            record.append(pred)
            record.append(row['Times'])
            frame.append(record)
    out = pd.DataFrame(frame, columns=fields+[dep_name, 'times'])
    median = out[dep_name].median()
    out[dep_name] /= median
    if (is_sorted == True):
        out = out.sort_values(by=dep_name, ascending=False)
    return out

def get_cat_cols(learn:Learner, is_wo_na=True)->List:
    '''
    Just outputs the category fields from a LabelLists object
    '''
    catf = None
    result = []
    proc = learn.data.processor[0]
    for prc in proc.procs:
        if (type(prc) == Categorify):
            catf = prc
    if (catf is not None):
        result = [c for c in catf.cat_names if ((is_wo_na is not None) and (is_wo_na == True) and (c[-3:] != "_na"))]
    return result


def get_cont_cols(learn:Learner)->List:
    '''
    Just outputs the continuous fields from a LabelLists object
    '''
    norm = None
    result = []
    proc = learn.data.processor[0]

    for prc in proc.procs:
        if (type(prc) == Normalize):
            norm = prc

    if (norm is not None):
        result = norm.cont_names

    return result


def get_part_dep(df:DataFrame, learn:Learner, bs:int=None,
                 fields:tuple=None, coef:float=1.0, to_int:bool=False,
                 dep_name:str=None, is_sorted:bool=True)->List:
    '''
    Makes a dataframe with partial dependencies for every categorical variable in df
    '''
    result = []
    if (fields is None):
        fields = get_cat_cols(learn=learn) + get_cont_cols(learn=learn)

    for field in fields:
        new_df = get_part_dep_one_list(df=df, learn=learn, bs=bs, fields=field, to_int=to_int,
                                       dep_name=dep_name, coef=coef, is_sorted=is_sorted)
        new_df['feature'] = str(field)
        if is_listy(field):
            new_df['value'] = new_df[field].values.tolist()
            new_df.drop(columns=field, inplace=True)
        else:
            new_df = new_df.rename(index=str, columns={str(field): "value"})
        result.append(new_df)
        clear_pbar()
    result = pd.concat(result, ignore_index=True, sort=True)
    result = result[['feature', 'value', dep_name, 'times']]

    return result
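
# Editor's sketch for the partial-dependence helpers. `test_df`, 'hand' and the
# coef value are illustrative assumptions; with coef=0.8 only the most common
# variants of each field are evaluated.
#
# For ex.
# pd_one = get_part_dep_one_list(df=test_df, learn=learn, bs=1024, fields=['S1'],
#                                dep_name='hand', coef=0.8)
# pd_all = get_part_dep(df=test_df, learn=learn, bs=1024, coef=0.8, dep_name='hand')
# pd_all.query("feature == 'S1'")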

def build_correlation_matr(df:DataFrame):
    '''
    Builds a Spearman rank-order correlation matrix.
    NAs in df should be fixed before it is passed here
    '''
    corr = np.round(scipy.stats.spearmanr(df).correlation, 4)
    corr[np.isnan(corr)] = 0.0
    np.fill_diagonal(corr, 1.0)
    return corr

def plot_dendrogram_corr(corr_matr, columns, figsize=None, leaf_font_size=16):
    '''
    Plots a dendrogram for a given correlation matrix
    '''
    if (figsize is None):
        figsize = (15, 0.02*leaf_font_size*len(columns))
    corr_condensed = hc.distance.squareform(1-corr_matr)
    z = hc.linkage(corr_condensed, method='average')
    fig = plt.figure(figsize=figsize)
    dendrogram = hc.dendrogram(z, labels=columns, orientation='left', leaf_font_size=leaf_font_size)
    plt.show()

def plot_dendrogram(df:DataFrame, figsize=None, leaf_font_size=16):
    corr = build_correlation_matr(df)
    plot_dendrogram_corr(corr_matr=corr, columns=df.columns, figsize=figsize, leaf_font_size=leaf_font_size)

def cramers_corrected_stat(confusion_matrix):
    """ Calculates Cramér's V statistic for categorical-categorical association.
        Uses the correction from Bergsma and Wicher,
        Journal of the Korean Statistical Society 42 (2013): 323-328
    """
    chi2 = scipy.stats.chi2_contingency(confusion_matrix)[0]
    if (chi2 == 0):
        return 0.0
    n = confusion_matrix.sum().sum()
    phi2 = chi2/n
    r,k = confusion_matrix.shape
    phi2corr = max(0, phi2 - ((k-1)*(r-1))/(n-1))
    rcorr = r - ((r-1)**2)/(n-1)
    kcorr = k - ((k-1)**2)/(n-1)
    return np.sqrt(phi2corr / min( (kcorr-1), (rcorr-1)))

def get_cramer_v_matr(df:DataFrame)->np.ndarray:
    '''
    Calculates Cramér's V statistic for every pair of df's columns
    '''
    cols = list(df.columns)
    corrM = np.zeros((len(cols), len(cols)))
    pbar = master_bar(list(itertools.combinations(cols, 2)))
    for col1, col2 in pbar:
        _ = progress_bar(range(1), parent=pbar)  # looks like fastprogress doesn't work without 2nd bar :(
        idx1, idx2 = cols.index(col1), cols.index(col2)
        corrM[idx1, idx2] = cramers_corrected_stat(pd.crosstab(df[col1], df[col2]))
        corrM[idx2, idx1] = corrM[idx1, idx2]
    np.fill_diagonal(corrM, 1.0)
    return corrM

def get_top_corr_df(df:DataFrame, corr_thr:float=0.8, corr_matr:array=None)->DataFrame:
    if (corr_matr is not None):
        corr = corr_matr
    else:
        corr = build_correlation_matr(df=df)
    # zero out correlations below the threshold, then keep the rows that still
    # have more than 2 nonzero entries
    corr = np.where(abs(corr) < corr_thr, 0, corr)
    idxs = []
    for i in range(corr.shape[0]):
        if (np.count_nonzero(corr[i]) > 2):
            idxs.append(i)
    cols = df.columns[idxs]
    return pd.DataFrame(corr[np.ix_(idxs, idxs)], columns=cols, index=cols)

def get_top_corr_dict_corrs(top_corrs:DataFrame)->OrderedDict:
    cols = top_corrs.columns
    top_corrs_np = top_corrs.to_numpy()
    corr_dict = {}
    for i in range(top_corrs_np.shape[0]):
        for j in range(i+1, top_corrs_np.shape[0]):
            if (top_corrs_np[i, j] > 0):
                corr_dict[cols[i]+' vs '+cols[j]] = np.round(top_corrs_np[i, j], 3)
    return collections.OrderedDict(sorted(corr_dict.items(), key=lambda kv: abs(kv[1]), reverse=True))

def get_top_corr_dict(df:DataFrame, corr_thr:float=0.8, corr_matr:array=None)->OrderedDict:
    '''
    Outputs the top correlated pairs in a given dataframe with a given correlation matrix.
    Filters the output with a minimal correlation of corr_thr
    '''
    top_corrs = get_top_corr_df(df, corr_thr, corr_matr)
    return get_top_corr_dict_corrs(top_corrs)
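
# Editor's sketch for the correlation / dendrogram helpers. For a purely
# categorical dataframe (as in p_fastai.ipynb) Cramér's V is the suitable
# matrix; cat_vars and the threshold are illustrative assumptions.
#
# For ex.
# corr = get_cramer_v_matr(df[cat_vars])
# plot_dendrogram_corr(corr_matr=corr, columns=df[cat_vars].columns)
# get_top_corr_dict(df[cat_vars], corr_thr=0.3, corr_matr=corr)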

def get_classes_o_list(learn:Learner):
    procs = learn.data.processor[0]
    return procs.classes

def get_rev_emb_idxs(learn:Learner)->dict:
    classes_dict = get_classes_o_list(learn=learn)
    return {c:i for i, (c, _) in enumerate(classes_dict.items()) if (c[-3:] != "_na")}


def get_emb_outp(learn:Learner, field:str, inp:str, rev_emb_idxs:dict, classes, embs):
    emb = embs[rev_emb_idxs[field]]
    idx, = np.where(classes[field] == inp)
    if (len(idx) == 1):
        cat_idx = idx[0]
    else:
        cat_idx = 0
    return emb(torch.tensor(cat_idx, device=learn.data.device))


def get_embs_map(learn:Learner)->OrderedDict:
    '''
    Outputs the embedding vector for every item of every categorical column,
    as a dictionary of dicts
    '''
    cat_cols = get_cat_cols(learn=learn, is_wo_na=True)
    rev_emb_idxs = get_rev_emb_idxs(learn=learn)
    classes = get_classes_o_list(learn=learn)
    embs = learn.model.embeds
    learn.model.eval();
    result = OrderedDict()

    for cat in cat_cols:
        cat_res = OrderedDict()
        for val in classes[cat]:
            cat_res[val] = get_emb_outp(learn=learn,
                                        field=cat, inp=str(val),
                                        rev_emb_idxs=rev_emb_idxs,
                                        classes=classes, embs=embs)
        result[cat] = cat_res

    return result


def emb_map_reduce_dim(embs_map:OrderedDict, outp_dim:int=3, to_df:bool=True, method:str='pytorch', exclude:list=None):
    '''
    Reduces the dimension of an embedding map down to outp_dim.
    Can use the 'pytorch' approach (pca)
    or 'scilearn' for manifold.TSNE (slower, and I'm not sure it is better)
    '''
    exclude = listify(exclude)
    result = OrderedDict()
    for feat, val in embs_map.items():
        reformat = []
        names = []
        for k,v in val.items():
            reformat.append(v)
            names.append(k)
        reformat = torch.stack(reformat)
        if (exclude is not None) and (feat in exclude):
            continue
        if (method == 'scilearn'):
            tsne = manifold.TSNE(n_components=outp_dim, init='pca')
            reduced = tsne.fit_transform(to_np(reformat))
        else:
            reduced = reformat.pca(outp_dim)
        record = OrderedDict({k:v for k, v in zip(names, reduced)})
        result[feat] = record

    if (to_df == True):
        data = []
        for feat, val in result.items():
            for k,v in val.items():
                dt = list(v) if (method == 'scilearn') else list(to_np(v))
                data.append([feat] + [k] + dt)
        names = ['feature', 'value'] + ['axis_' + str(i) for i in range(outp_dim)]
        result = pd.DataFrame(data, columns=names)

    return result


def add_times_col(embs_map:DataFrame, df:DataFrame)->DataFrame:
    '''
    Adds a new column to the embeddings map dataframe with the number of occurrences
    of each value.
    Useful for estimating how accurate a value is (the more times it occurs, the more
    sure you can be)
    '''
    times = np.zeros(len(embs_map))
    last_feat = ''
    vc = None
    for i, (f, v) in enumerate(zip(embs_map['feature'], embs_map['value'])):
        if (f != last_feat):
            vc = df[f].value_counts(dropna=False)
            vc.index = vc.index.map(str)
            last_feat = f
        if (v != '#na#'):
            times[i] = vc[v]
        else:
            times[i] = vc['nan'] if ('nan' in vc.index) else 0
    result = embs_map.copy()
    result['times'] = times
    return result

# Little helpers for saving/loading variables with pickle
def sv_var(var, name, path):
    f = open(path/f"{name}.pkl","wb")
    pickle.dump(var, f)
    f.close()

def ld_var(name, path):
    f = open(path/f"{name}.pkl","rb")
    var = pickle.load(f)
    f.close()
    return var

def plot_2d_emb(emb_map:DataFrame, feature:str, top_x:int=10):
    sub_df = emb_map.query(f"feature == '{feature}'").sort_values('times', ascending=False).head(top_x)
    X = sub_df['axis_0']
    Y = sub_df['axis_1']
    plt.figure(figsize=(15, 8))
    plt.scatter(X, Y)
    for name, x, y in zip(sub_df['value'], X, Y):
        plt.text(x, y, name, color=np.random.rand(3)*0.7, fontsize=11)
    plt.show()
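
# Editor's sketch of the embedding-plotting pipeline. The 'S1' feature name is
# an assumption taken from the poker dataset used in p_fastai.ipynb.
#
# For ex.
# embs = get_embs_map(learn=learn)
# emb_df = emb_map_reduce_dim(embs, outp_dim=2, method='pytorch')
# emb_df = add_times_col(emb_df, df)
# plot_2d_emb(emb_map=emb_df, feature='S1')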
2" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": { 22 | "ExecuteTime": { 23 | "end_time": "2020-02-04T07:52:00.775667Z", 24 | "start_time": "2020-02-04T07:51:59.956726Z" 25 | } 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "from fastai.tabular import *" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": { 36 | "ExecuteTime": { 37 | "end_time": "2020-02-04T07:52:00.789494Z", 38 | "start_time": "2020-02-04T07:52:00.776873Z" 39 | } 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "from fastai import tabular" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": { 50 | "ExecuteTime": { 51 | "end_time": "2020-02-04T07:52:00.803458Z", 52 | "start_time": "2020-02-04T07:52:00.790694Z" 53 | } 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "# import sys\n", 58 | "# sys.path.append(\"../common\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": { 65 | "ExecuteTime": { 66 | "end_time": "2020-02-04T07:52:00.860672Z", 67 | "start_time": "2020-02-04T07:52:00.804450Z" 68 | } 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "from exp.nb_ import *" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 6, 78 | "metadata": { 79 | "ExecuteTime": { 80 | "end_time": "2020-02-04T07:52:00.875788Z", 81 | "start_time": "2020-02-04T07:52:00.861602Z" 82 | } 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "# path=Path('../data/other/')" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 7, 92 | "metadata": { 93 | "ExecuteTime": { 94 | "end_time": "2020-02-04T07:52:00.904266Z", 95 | "start_time": "2020-02-04T07:52:00.876759Z" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "df = pd.read_csv(path/'train.csv.zip', compression='zip')" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 8, 106 | "metadata": { 107 | "ExecuteTime": { 108 | "end_time": "2020-02-04T07:52:00.926246Z", 109 | "start_time": "2020-02-04T07:52:00.905799Z" 110 | } 111 | }, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/html": [ 116 | "
\n", 117 | "\n", 130 | "\n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | "
S1C1S2C2S3C3S4C4S5C5hand
579371632311280
61863124911317210
642026432134104120
19868411493612450
108202101114923360
\n", 220 | "
" 221 | ], 222 | "text/plain": [ 223 | " S1 C1 S2 C2 S3 C3 S4 C4 S5 C5 hand\n", 224 | "579 3 7 1 6 3 2 3 11 2 8 0\n", 225 | "6186 3 12 4 9 1 13 1 7 2 1 0\n", 226 | "6420 2 6 4 3 2 13 4 10 4 12 0\n", 227 | "19868 4 11 4 9 3 6 1 2 4 5 0\n", 228 | "10820 2 10 1 11 4 9 2 3 3 6 0" 229 | ] 230 | }, 231 | "execution_count": 8, 232 | "metadata": {}, 233 | "output_type": "execute_result" 234 | } 235 | ], 236 | "source": [ 237 | "df.sample(5)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 9, 243 | "metadata": { 244 | "ExecuteTime": { 245 | "end_time": "2020-02-04T07:52:00.942073Z", 246 | "start_time": "2020-02-04T07:52:00.927846Z" 247 | } 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "suits = {1:'Hearts', 2:'Spades', 3:'Diamonds', 4:'Clubs'}" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 10, 257 | "metadata": { 258 | "ExecuteTime": { 259 | "end_time": "2020-02-04T07:52:00.957886Z", 260 | "start_time": "2020-02-04T07:52:00.943066Z" 261 | } 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "cards = {1:'Ace', 2:'2', 3:'3', 4:'4', 5:'5', 6:'6', 7:'7', 8:'8', 9:'9', 10:'10', 11:'Jack', 12:'Queen', 13:'King'}" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 11, 271 | "metadata": { 272 | "ExecuteTime": { 273 | "end_time": "2020-02-04T07:52:00.973366Z", 274 | "start_time": "2020-02-04T07:52:00.958856Z" 275 | } 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "hands = {0: 'Nothing', 1: 'Pair', 2: 'Two pairs', 3: 'Three of a kind',\n", 280 | " 4: 'Straight', 5: 'Flush', 6: 'Full house', 7: 'Four of a kind',\n", 281 | " 8: 'Straight flush', 9: 'Royal flush'}" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 12, 287 | "metadata": { 288 | "ExecuteTime": { 289 | "end_time": "2020-02-04T07:52:01.008323Z", 290 | "start_time": "2020-02-04T07:52:00.974322Z" 291 | } 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "df = df.replace({'C1':cards, 'C2':cards, 'C3':cards, 'C4':cards, 'C5':cards})" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 13, 301 | "metadata": { 302 | "ExecuteTime": { 303 | "end_time": "2020-02-04T07:52:01.035289Z", 304 | "start_time": "2020-02-04T07:52:01.009265Z" 305 | } 306 | }, 307 | "outputs": [], 308 | "source": [ 309 | "df = df.replace({'S1':suits, 'S2':suits, 'S3':suits, 'S4':suits, 'S5':suits})" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 14, 315 | "metadata": { 316 | "ExecuteTime": { 317 | "end_time": "2020-02-04T07:52:01.057186Z", 318 | "start_time": "2020-02-04T07:52:01.036211Z" 319 | } 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "df = df.replace({'hand':hands})" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 15, 329 | "metadata": { 330 | "ExecuteTime": { 331 | "end_time": "2020-02-04T07:52:01.080732Z", 332 | "start_time": "2020-02-04T07:52:01.058131Z" 333 | } 334 | }, 335 | "outputs": [ 336 | { 337 | "data": { 338 | "text/html": [ 339 | "
\n", 340 | "\n", 353 | "\n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | "
S1C1S2C2S3C3S4C4S5C5hand
0Clubs9SpadesAceSpades2Clubs7Spades8Nothing
1Hearts4Diamonds6HeartsQueenDiamondsJackSpades7Nothing
2HeartsJackClubsAceDiamonds7ClubsJackSpadesAceTwo pairs
3Spades9Spades4Diamonds6Hearts9Clubs9Three of a kind
4Hearts8Spades4SpadesJackSpades2SpadesAceNothing
....................................
25005Clubs9Clubs6Diamonds6ClubsQueenClubs5Pair
25006Diamonds8Diamonds5ClubsJackSpades2HeartsKingNothing
25007Hearts8Clubs5DiamondsJackDiamonds2SpadesKingNothing
25008ClubsQueenDiamonds5SpadesAceSpades7Clubs6Nothing
25009HeartsAceHearts3Hearts7Hearts2Clubs2Pair
\n", 527 | "

25010 rows × 11 columns

\n", 528 | "
" 529 | ], 530 | "text/plain": [ 531 | " S1 C1 S2 C2 S3 C3 S4 C4 \\\n", 532 | "0 Clubs 9 Spades Ace Spades 2 Clubs 7 \n", 533 | "1 Hearts 4 Diamonds 6 Hearts Queen Diamonds Jack \n", 534 | "2 Hearts Jack Clubs Ace Diamonds 7 Clubs Jack \n", 535 | "3 Spades 9 Spades 4 Diamonds 6 Hearts 9 \n", 536 | "4 Hearts 8 Spades 4 Spades Jack Spades 2 \n", 537 | "... ... ... ... ... ... ... ... ... \n", 538 | "25005 Clubs 9 Clubs 6 Diamonds 6 Clubs Queen \n", 539 | "25006 Diamonds 8 Diamonds 5 Clubs Jack Spades 2 \n", 540 | "25007 Hearts 8 Clubs 5 Diamonds Jack Diamonds 2 \n", 541 | "25008 Clubs Queen Diamonds 5 Spades Ace Spades 7 \n", 542 | "25009 Hearts Ace Hearts 3 Hearts 7 Hearts 2 \n", 543 | "\n", 544 | " S5 C5 hand \n", 545 | "0 Spades 8 Nothing \n", 546 | "1 Spades 7 Nothing \n", 547 | "2 Spades Ace Two pairs \n", 548 | "3 Clubs 9 Three of a kind \n", 549 | "4 Spades Ace Nothing \n", 550 | "... ... ... ... \n", 551 | "25005 Clubs 5 Pair \n", 552 | "25006 Hearts King Nothing \n", 553 | "25007 Spades King Nothing \n", 554 | "25008 Clubs 6 Nothing \n", 555 | "25009 Clubs 2 Pair \n", 556 | "\n", 557 | "[25010 rows x 11 columns]" 558 | ] 559 | }, 560 | "execution_count": 15, 561 | "metadata": {}, 562 | "output_type": "execute_result" 563 | } 564 | ], 565 | "source": [ 566 | "df" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 16, 572 | "metadata": { 573 | "ExecuteTime": { 574 | "end_time": "2020-02-04T07:52:01.096132Z", 575 | "start_time": "2020-02-04T07:52:01.081656Z" 576 | } 577 | }, 578 | "outputs": [], 579 | "source": [ 580 | "cat_vars_tpl = ('S1', 'C1', 'S2', 'C2', 'S3', 'C3', 'S4', 'C4', 'S5', 'C5')" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 17, 586 | "metadata": { 587 | "ExecuteTime": { 588 | "end_time": "2020-02-04T07:52:01.110767Z", 589 | "start_time": "2020-02-04T07:52:01.096947Z" 590 | } 591 | }, 592 | "outputs": [], 593 | "source": [ 594 | "cont_vars_tpl = ()" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": 18, 600 | "metadata": { 601 | "ExecuteTime": { 602 | "end_time": "2020-02-04T07:52:01.125733Z", 603 | "start_time": "2020-02-04T07:52:01.111710Z" 604 | } 605 | }, 606 | "outputs": [], 607 | "source": [ 608 | "cat_vars = list(cat_vars_tpl)\n", 609 | "cont_vars = list(cont_vars_tpl)\n", 610 | "all_vars = cat_vars + cont_vars" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 19, 616 | "metadata": { 617 | "ExecuteTime": { 618 | "end_time": "2020-02-04T07:52:01.141399Z", 619 | "start_time": "2020-02-04T07:52:01.126635Z" 620 | } 621 | }, 622 | "outputs": [ 623 | { 624 | "data": { 625 | "text/plain": [ 626 | "['hand']" 627 | ] 628 | }, 629 | "execution_count": 19, 630 | "metadata": {}, 631 | "output_type": "execute_result" 632 | } 633 | ], 634 | "source": [ 635 | "list_diff(df.columns, cat_vars, cont_vars)" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": 20, 641 | "metadata": { 642 | "ExecuteTime": { 643 | "end_time": "2020-02-04T07:52:01.156417Z", 644 | "start_time": "2020-02-04T07:52:01.142273Z" 645 | } 646 | }, 647 | "outputs": [], 648 | "source": [ 649 | "dep_var = 'hand'" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": 21, 655 | "metadata": { 656 | "ExecuteTime": { 657 | "end_time": "2020-02-04T07:52:01.172226Z", 658 | "start_time": "2020-02-04T07:52:01.157283Z" 659 | } 660 | }, 661 | "outputs": [], 662 | "source": [ 663 | "np.random.seed(1001)\n", 664 | "ln = len(df)\n", 665 | "valid_idx = np.random.choice(ln, 
    "valid_idx = np.random.choice(ln, int(ln*0.2), replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:01.187319Z", "start_time": "2020-02-04T07:52:01.173050Z" } },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5002"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(valid_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:01.202673Z", "start_time": "2020-02-04T07:52:01.188745Z" } },
   "outputs": [],
   "source": [
    "procs=[FillMissing, Categorify, Normalize]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:01.217812Z", "start_time": "2020-02-04T07:52:01.203632Z" } },
   "outputs": [],
   "source": [
    "def emb_sz_rule_reduced(n_cat:int)->int: return min(10, round(1.6 * n_cat**0.56))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:01.232510Z", "start_time": "2020-02-04T07:52:01.218743Z" } },
   "outputs": [],
   "source": [
    "# monkey patch the embedding size rule, as 600 floats is too much for our case\n",
    "tabular.data.emb_sz_rule = emb_sz_rule_reduced"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:04.803907Z", "start_time": "2020-02-04T07:52:04.784128Z" } },
   "outputs": [],
   "source": [
    "BS = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 85%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:06.561774Z", "start_time": "2020-02-04T07:52:06.544310Z" } },
   "outputs": [],
   "source": [
    "layers = [1000, 500, 200]\n",
    "layers_drop = [0.001, 0.005, 0.01]\n",
    "emb_drop = 0.01\n",
    "cycles = 60\n",
    "w_decay = 0.01\n",
    "max_lr = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:35.978771Z", "start_time": "2020-02-04T07:52:35.865662Z" } },
   "outputs": [],
   "source": [
    "data = (TabularList.from_df(df, path=path, cat_names=cat_vars, cont_names=cont_vars, procs=procs)\n",
    "        .split_by_idx(valid_idx)\n",
    "        .label_from_df(cols=dep_var, label_cls=CategoryList)\n",
    "        .databunch(bs=BS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:38.774091Z", "start_time": "2020-02-04T07:52:37.769101Z" } },
   "outputs": [],
   "source": [
    "learn = None\n",
    "np.random.seed(1001)\n",
    "learn = tabular_learner(data, \n",
    "                        layers=layers, \n",
    "                        ps=layers_drop, \n",
    "                        emb_drop=emb_drop, \n",
    "                        metrics=accuracy,\n",
    "                        callback_fns=[CSVLogger])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:52:40.957424Z", "start_time": "2020-02-04T07:52:40.937148Z" } },
   "outputs": [],
   "source": [
    "max_lr = 3e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": { "ExecuteTime": { "end_time": "2020-02-04T07:53:47.679547Z", "start_time": "2020-02-04T07:52:41.994384Z" }, "scrolled": true },
   "outputs": [
    {
     "data": {
      "text/html": [
       "[fastai training progress table; HTML was stripped during extraction. Recovered values:]",
       "epoch  train_loss  valid_loss  accuracy  time",
       "0      2.205918    2.112571    0.401439  00:01",
       "1      1.838952    1.732437    0.563575  00:01",
       "2      1.309539    1.130222    0.666933  00:01",
       "3      0.813635    0.750517    0.730308  00:01",
       "4      0.678973    0.668179    0.736505  00:01",
       "5      0.605285    0.624136    0.764894  00:01",
       "6      0.583843    0.585979    0.784286  00:01",
       "7      0.578587    0.574682    0.783886  00:01",
       "8      0.555972    0.589546    0.765094  00:01",
       "9      0.555005    0.550606    0.795082  00:01",
       "10     0.549625    0.568322    0.774090  00:01",
       "11     0.536985    0.534980    0.799480  00:01",
       "12     0.515323    0.556685    0.778888  00:01",
       "13     0.527872    0.526369    0.788285  00:01",
       "14     0.491941    0.509346    0.798281  00:01",
       "15     0.491977    0.509173    0.804678  00:01",
       "16     0.485478    0.507428    0.807677  00:01",
       "17     0.441469    0.497824    0.812675  00:01",
180.4453930.4988610.80707700:01
190.4264250.5003760.80307900:01
200.4044930.4861020.81407400:01
210.3729390.5057040.81487400:01
220.3719520.5288410.81167500:01
230.3632070.5350340.81547400:01
240.3294140.5183360.81647300:01
250.3115020.5521920.81127500:01
260.2856780.5705820.80267900:01
270.2644990.6044540.81427400:01
280.2408770.6076140.80347900:01
290.2294440.6592520.79908000:01
300.1901710.6846410.80567800:01
310.1837580.6883260.80967600:01
320.1604000.7217590.79988000:01
330.1394230.7604390.80247900:01
340.1371230.7614620.80407800:01
350.1113740.8515560.81347500:01
360.1035500.8416990.80627800:01
370.0931930.8623190.81347500:01
380.0909560.8601480.81327500:01
390.0745390.8934420.80587800:01
400.0727760.8974820.80547800:01
410.0554410.9275290.80727700:01
420.0546990.9528900.80627800:01
430.0497600.9610540.80827700:01
440.0393950.9780220.81467400:01
450.0375571.0291760.80467800:01
460.0411091.0146670.81307500:01
470.0294321.0195570.81467400:01
480.0339821.0158390.81167500:01
490.0259871.0164070.81727300:01
500.0238341.0516630.81087600:01
510.0237751.0717840.81387400:01
520.0203471.0629830.81567400:01
530.0207891.0857310.81647300:01
540.0153531.0440570.81427400:01
550.0187531.0611730.81227500:01
560.0177491.0693160.81627400:01
570.0190031.0477130.81707300:01
580.0177091.0670660.81487400:01
590.0167461.0748620.81647300:01
" 1275 | ], 1276 | "text/plain": [ 1277 | "" 1278 | ] 1279 | }, 1280 | "metadata": {}, 1281 | "output_type": "display_data" 1282 | } 1283 | ], 1284 | "source": [ 1285 | "learn.fit_one_cycle(cyc_len=cycles, max_lr=max_lr, wd=w_decay)" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "execution_count": 33, 1291 | "metadata": { 1292 | "ExecuteTime": { 1293 | "end_time": "2020-02-04T07:53:56.552998Z", 1294 | "start_time": "2020-02-04T07:53:56.522725Z" 1295 | } 1296 | }, 1297 | "outputs": [], 1298 | "source": [ 1299 | "learn = None\n", 1300 | "np.random.seed(1001)\n", 1301 | "learn = tabular_learner(data, \n", 1302 | " layers=layers, \n", 1303 | " ps=layers_drop, \n", 1304 | " emb_drop=emb_drop, \n", 1305 | " metrics=accuracy,\n", 1306 | " callback_fns=[CSVLogger])" 1307 | ] 1308 | }, 1309 | { 1310 | "cell_type": "code", 1311 | "execution_count": 34, 1312 | "metadata": { 1313 | "ExecuteTime": { 1314 | "end_time": "2020-02-04T07:54:02.286166Z", 1315 | "start_time": "2020-02-04T07:54:02.265641Z" 1316 | } 1317 | }, 1318 | "outputs": [], 1319 | "source": [ 1320 | "max_lr = 1e-3" 1321 | ] 1322 | }, 1323 | { 1324 | "cell_type": "code", 1325 | "execution_count": 35, 1326 | "metadata": { 1327 | "ExecuteTime": { 1328 | "end_time": "2020-02-04T07:55:13.309107Z", 1329 | "start_time": "2020-02-04T07:54:07.155296Z" 1330 | } 1331 | }, 1332 | "outputs": [ 1333 | { 1334 | "data": { 1335 | "text/html": [ 1336 | "\n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | " \n", 1371 | " \n", 1372 | " \n", 1373 | " \n", 1374 | " \n", 1375 | " \n", 1376 | " \n", 1377 | " \n", 1378 | " \n", 1379 | " \n", 1380 | " \n", 1381 | " \n", 1382 | " \n", 1383 | " \n", 1384 | " \n", 1385 | " \n", 1386 | " \n", 1387 | " \n", 1388 | " \n", 1389 | " \n", 1390 | " \n", 1391 | " \n", 1392 | " \n", 1393 | " \n", 1394 | " \n", 1395 | " \n", 1396 | " \n", 1397 | " \n", 1398 | " \n", 1399 | " \n", 1400 | " \n", 1401 | " \n", 1402 | " \n", 1403 | " \n", 1404 | " \n", 1405 | " \n", 1406 | " \n", 1407 | " \n", 1408 | " \n", 1409 | " \n", 1410 | " \n", 1411 | " \n", 1412 | " \n", 1413 | " \n", 1414 | " \n", 1415 | " \n", 1416 | " \n", 1417 | " \n", 1418 | " \n", 1419 | " \n", 1420 | " \n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | " \n", 1428 | " \n", 1429 | " \n", 1430 | " \n", 1431 | " \n", 1432 | " \n", 1433 | " \n", 1434 | " \n", 1435 | " \n", 1436 | " \n", 1437 | " \n", 1438 | " \n", 1439 | " \n", 1440 | " \n", 1441 | " \n", 1442 | " \n", 1443 | " \n", 1444 | " \n", 1445 | " \n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | " \n", 1455 | " \n", 1456 | " \n", 1457 | " \n", 1458 | " \n", 1459 | " \n", 1460 | " \n", 1461 | " \n", 1462 | " \n", 1463 | " \n", 1464 | " \n", 1465 | " \n", 1466 | " \n", 1467 | " \n", 1468 | " \n", 1469 | " \n", 1470 | " \n", 1471 | " \n", 1472 | " \n", 1473 | " \n", 1474 | " \n", 1475 | " \n", 1476 | " \n", 1477 | " \n", 1478 | " \n", 1479 | " \n", 1480 | " \n", 1481 | " \n", 1482 | " \n", 1483 | " \n", 1484 | " \n", 1485 | " \n", 1486 | " \n", 1487 | " \n", 1488 | 
" \n", 1489 | " \n", 1490 | " \n", 1491 | " \n", 1492 | " \n", 1493 | " \n", 1494 | " \n", 1495 | " \n", 1496 | " \n", 1497 | " \n", 1498 | " \n", 1499 | " \n", 1500 | " \n", 1501 | " \n", 1502 | " \n", 1503 | " \n", 1504 | " \n", 1505 | " \n", 1506 | " \n", 1507 | " \n", 1508 | " \n", 1509 | " \n", 1510 | " \n", 1511 | " \n", 1512 | " \n", 1513 | " \n", 1514 | " \n", 1515 | " \n", 1516 | " \n", 1517 | " \n", 1518 | " \n", 1519 | " \n", 1520 | " \n", 1521 | " \n", 1522 | " \n", 1523 | " \n", 1524 | " \n", 1525 | " \n", 1526 | " \n", 1527 | " \n", 1528 | " \n", 1529 | " \n", 1530 | " \n", 1531 | " \n", 1532 | " \n", 1533 | " \n", 1534 | " \n", 1535 | " \n", 1536 | " \n", 1537 | " \n", 1538 | " \n", 1539 | " \n", 1540 | " \n", 1541 | " \n", 1542 | " \n", 1543 | " \n", 1544 | " \n", 1545 | " \n", 1546 | " \n", 1547 | " \n", 1548 | " \n", 1549 | " \n", 1550 | " \n", 1551 | " \n", 1552 | " \n", 1553 | " \n", 1554 | " \n", 1555 | " \n", 1556 | " \n", 1557 | " \n", 1558 | " \n", 1559 | " \n", 1560 | " \n", 1561 | " \n", 1562 | " \n", 1563 | " \n", 1564 | " \n", 1565 | " \n", 1566 | " \n", 1567 | " \n", 1568 | " \n", 1569 | " \n", 1570 | " \n", 1571 | " \n", 1572 | " \n", 1573 | " \n", 1574 | " \n", 1575 | " \n", 1576 | " \n", 1577 | " \n", 1578 | " \n", 1579 | " \n", 1580 | " \n", 1581 | " \n", 1582 | " \n", 1583 | " \n", 1584 | " \n", 1585 | " \n", 1586 | " \n", 1587 | " \n", 1588 | " \n", 1589 | " \n", 1590 | " \n", 1591 | " \n", 1592 | " \n", 1593 | " \n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 1606 | " \n", 1607 | " \n", 1608 | " \n", 1609 | " \n", 1610 | " \n", 1611 | " \n", 1612 | " \n", 1613 | " \n", 1614 | " \n", 1615 | " \n", 1616 | " \n", 1617 | " \n", 1618 | " \n", 1619 | " \n", 1620 | " \n", 1621 | " \n", 1622 | " \n", 1623 | " \n", 1624 | " \n", 1625 | " \n", 1626 | " \n", 1627 | " \n", 1628 | " \n", 1629 | " \n", 1630 | " \n", 1631 | " \n", 1632 | " \n", 1633 | " \n", 1634 | " \n", 1635 | " \n", 1636 | " \n", 1637 | " \n", 1638 | " \n", 1639 | " \n", 1640 | " \n", 1641 | " \n", 1642 | " \n", 1643 | " \n", 1644 | " \n", 1645 | " \n", 1646 | " \n", 1647 | " \n", 1648 | " \n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " \n", 1735 | " \n", 1736 | " \n", 1737 | " \n", 1738 | " \n", 1739 | " \n", 1740 | " \n", 1741 | " \n", 1742 
| " \n", 1743 | " \n", 1744 | " \n", 1745 | " \n", 1746 | " \n", 1747 | " \n", 1748 | " \n", 1749 | " \n", 1750 | " \n", 1751 | " \n", 1752 | " \n", 1753 | " \n", 1754 | " \n", 1755 | " \n", 1756 | " \n", 1757 | " \n", 1758 | " \n", 1759 | " \n", 1760 | " \n", 1761 | " \n", 1762 | " \n", 1763 | " \n", 1764 | " \n", 1765 | " \n", 1766 | " \n", 1767 | " \n", 1768 | "
epochtrain_lossvalid_lossaccuracytime
02.3319942.2984190.19432200:01
12.1945432.1700360.33806500:01
22.0212341.9743620.47800900:01
31.7215861.6348950.58956400:01
41.3348061.2436460.65193900:01
50.9490420.9329300.69372200:01
60.7531700.7580990.71331500:01
70.6484770.7017690.72930800:01
80.5840610.6225910.75809700:01
90.5287250.5725850.78188700:01
100.4946630.5385720.79928000:01
110.4610230.5144710.81307500:01
120.4140330.4588770.83306700:01
130.4033420.4665450.83286700:01
140.3714480.4603640.83466600:01
150.3490660.4564470.83166700:01
160.3286230.4312010.84666100:01
170.2982620.4406470.84206300:01
180.2873410.4512920.84186300:01
190.2762020.4371120.85185900:01
200.2389420.4408750.85245900:01
210.2474440.4571170.84226300:01
220.2106030.4666230.84046400:01
230.1993700.4671430.84666100:01
240.1889430.4808170.84266300:01
250.1701040.5153870.85145900:01
260.1645230.5012440.84966000:01
270.1427040.5193140.84966000:01
280.1292320.5706130.84386200:01
290.1178490.5300750.84926000:01
300.1196870.5617510.84106400:01
310.1011570.5723760.84946000:01
320.0994570.5895220.84846100:01
330.0892800.6006240.84326300:01
340.0817080.6050030.84326300:01
350.0780500.6048640.84346300:01
360.0660390.6052190.84286300:01
370.0642910.6333930.84606200:01
380.0565190.6310110.84846100:01
390.0519510.6505490.84926000:01
400.0443990.6851670.84426200:01
410.0444200.6688950.84686100:01
420.0410400.6708170.84686100:01
430.0397310.6718130.85185900:01
440.0317310.6872650.84646100:01
450.0321730.6654320.85026000:01
460.0237290.6817000.85245900:01
470.0275600.6743410.85385800:01
480.0231270.6789070.85605800:01
490.0264710.6839940.85465800:01
500.0220800.6964830.85545800:01
510.0189410.6777580.85705700:01
520.0237970.6957700.85525800:01
530.0205670.6814170.85505800:01
540.0171710.6925620.85585800:01
550.0194570.6946070.85605800:01
560.0158360.7078120.85705700:01
570.0176510.7047570.85765700:01
580.0180220.6906860.85685700:01
590.0183670.6922750.85725700:01
" 1769 | ], 1770 | "text/plain": [ 1771 | "" 1772 | ] 1773 | }, 1774 | "metadata": {}, 1775 | "output_type": "display_data" 1776 | } 1777 | ], 1778 | "source": [ 1779 | "learn.fit_one_cycle(cyc_len=cycles, max_lr=max_lr, wd=w_decay)" 1780 | ] 1781 | }, 1782 | { 1783 | "cell_type": "code", 1784 | "execution_count": 36, 1785 | "metadata": { 1786 | "ExecuteTime": { 1787 | "end_time": "2020-02-04T07:55:39.855236Z", 1788 | "start_time": "2020-02-04T07:55:39.814219Z" 1789 | } 1790 | }, 1791 | "outputs": [], 1792 | "source": [ 1793 | "learn.save('no_reg_save_0857')" 1794 | ] 1795 | }, 1796 | { 1797 | "cell_type": "code", 1798 | "execution_count": 37, 1799 | "metadata": { 1800 | "ExecuteTime": { 1801 | "end_time": "2020-02-04T07:55:43.203481Z", 1802 | "start_time": "2020-02-04T07:55:43.162057Z" 1803 | } 1804 | }, 1805 | "outputs": [], 1806 | "source": [ 1807 | "learn.export('no_reg_0857')" 1808 | ] 1809 | }, 1810 | { 1811 | "cell_type": "markdown", 1812 | "metadata": {}, 1813 | "source": [ 1814 | "### 91%" 1815 | ] 1816 | }, 1817 | { 1818 | "cell_type": "code", 1819 | "execution_count": 38, 1820 | "metadata": { 1821 | "ExecuteTime": { 1822 | "end_time": "2020-02-04T07:56:24.556232Z", 1823 | "start_time": "2020-02-04T07:56:24.538453Z" 1824 | } 1825 | }, 1826 | "outputs": [], 1827 | "source": [ 1828 | "layers = [1000, 500, 200]\n", 1829 | "layers_drop = [0, 0, 0]\n", 1830 | "emb_drop = 0\n", 1831 | "cycles = 60\n", 1832 | "w_decay = 0\n", 1833 | "max_lr = 1e-3" 1834 | ] 1835 | }, 1836 | { 1837 | "cell_type": "code", 1838 | "execution_count": 39, 1839 | "metadata": { 1840 | "ExecuteTime": { 1841 | "end_time": "2020-02-04T07:56:27.686208Z", 1842 | "start_time": "2020-02-04T07:56:27.602630Z" 1843 | } 1844 | }, 1845 | "outputs": [], 1846 | "source": [ 1847 | "data = (TabularList.from_df(df, path=path, cat_names=cat_vars, cont_names=cont_vars, procs=procs)\n", 1848 | " .split_by_idx(valid_idx)\n", 1849 | " .label_from_df(cols=dep_var, label_cls=CategoryList)\n", 1850 | " .databunch(bs=BS))" 1851 | ] 1852 | }, 1853 | { 1854 | "cell_type": "code", 1855 | "execution_count": 40, 1856 | "metadata": { 1857 | "ExecuteTime": { 1858 | "end_time": "2020-02-04T07:56:29.035207Z", 1859 | "start_time": "2020-02-04T07:56:28.995806Z" 1860 | } 1861 | }, 1862 | "outputs": [], 1863 | "source": [ 1864 | "np.random.seed(1001)\n", 1865 | "learn = tabular_learner(data, \n", 1866 | " layers=layers, \n", 1867 | " ps=layers_drop, \n", 1868 | " emb_drop=emb_drop, \n", 1869 | " metrics=accuracy,\n", 1870 | " callback_fns=[CSVLogger])" 1871 | ] 1872 | }, 1873 | { 1874 | "cell_type": "code", 1875 | "execution_count": 41, 1876 | "metadata": { 1877 | "ExecuteTime": { 1878 | "end_time": "2020-02-04T07:57:38.378645Z", 1879 | "start_time": "2020-02-04T07:56:31.075407Z" 1880 | } 1881 | }, 1882 | "outputs": [ 1883 | { 1884 | "data": { 1885 | "text/html": [ 1886 | "\n", 1887 | " \n", 1888 | " \n", 1889 | " \n", 1890 | " \n", 1891 | " \n", 1892 | " \n", 1893 | " \n", 1894 | " \n", 1895 | " \n", 1896 | " \n", 1897 | " \n", 1898 | " \n", 1899 | " \n", 1900 | " \n", 1901 | " \n", 1902 | " \n", 1903 | " \n", 1904 | " \n", 1905 | " \n", 1906 | " \n", 1907 | " \n", 1908 | " \n", 1909 | " \n", 1910 | " \n", 1911 | " \n", 1912 | " \n", 1913 | " \n", 1914 | " \n", 1915 | " \n", 1916 | " \n", 1917 | " \n", 1918 | " \n", 1919 | " \n", 1920 | " \n", 1921 | " \n", 1922 | " \n", 1923 | " \n", 1924 | " \n", 1925 | " \n", 1926 | " \n", 1927 | " \n", 1928 | " \n", 1929 | " \n", 1930 | " \n", 1931 | " \n", 1932 | " \n", 1933 | " \n", 1934 | " \n", 1935 | " 
\n", 1936 | " \n", 1937 | " \n", 1938 | " \n", 1939 | " \n", 1940 | " \n", 1941 | " \n", 1942 | " \n", 1943 | " \n", 1944 | " \n", 1945 | " \n", 1946 | " \n", 1947 | " \n", 1948 | " \n", 1949 | " \n", 1950 | " \n", 1951 | " \n", 1952 | " \n", 1953 | " \n", 1954 | " \n", 1955 | " \n", 1956 | " \n", 1957 | " \n", 1958 | " \n", 1959 | " \n", 1960 | " \n", 1961 | " \n", 1962 | " \n", 1963 | " \n", 1964 | " \n", 1965 | " \n", 1966 | " \n", 1967 | " \n", 1968 | " \n", 1969 | " \n", 1970 | " \n", 1971 | " \n", 1972 | " \n", 1973 | " \n", 1974 | " \n", 1975 | " \n", 1976 | " \n", 1977 | " \n", 1978 | " \n", 1979 | " \n", 1980 | " \n", 1981 | " \n", 1982 | " \n", 1983 | " \n", 1984 | " \n", 1985 | " \n", 1986 | " \n", 1987 | " \n", 1988 | " \n", 1989 | " \n", 1990 | " \n", 1991 | " \n", 1992 | " \n", 1993 | " \n", 1994 | " \n", 1995 | " \n", 1996 | " \n", 1997 | " \n", 1998 | " \n", 1999 | " \n", 2000 | " \n", 2001 | " \n", 2002 | " \n", 2003 | " \n", 2004 | " \n", 2005 | " \n", 2006 | " \n", 2007 | " \n", 2008 | " \n", 2009 | " \n", 2010 | " \n", 2011 | " \n", 2012 | " \n", 2013 | " \n", 2014 | " \n", 2015 | " \n", 2016 | " \n", 2017 | " \n", 2018 | " \n", 2019 | " \n", 2020 | " \n", 2021 | " \n", 2022 | " \n", 2023 | " \n", 2024 | " \n", 2025 | " \n", 2026 | " \n", 2027 | " \n", 2028 | " \n", 2029 | " \n", 2030 | " \n", 2031 | " \n", 2032 | " \n", 2033 | " \n", 2034 | " \n", 2035 | " \n", 2036 | " \n", 2037 | " \n", 2038 | " \n", 2039 | " \n", 2040 | " \n", 2041 | " \n", 2042 | " \n", 2043 | " \n", 2044 | " \n", 2045 | " \n", 2046 | " \n", 2047 | " \n", 2048 | " \n", 2049 | " \n", 2050 | " \n", 2051 | " \n", 2052 | " \n", 2053 | " \n", 2054 | " \n", 2055 | " \n", 2056 | " \n", 2057 | " \n", 2058 | " \n", 2059 | " \n", 2060 | " \n", 2061 | " \n", 2062 | " \n", 2063 | " \n", 2064 | " \n", 2065 | " \n", 2066 | " \n", 2067 | " \n", 2068 | " \n", 2069 | " \n", 2070 | " \n", 2071 | " \n", 2072 | " \n", 2073 | " \n", 2074 | " \n", 2075 | " \n", 2076 | " \n", 2077 | " \n", 2078 | " \n", 2079 | " \n", 2080 | " \n", 2081 | " \n", 2082 | " \n", 2083 | " \n", 2084 | " \n", 2085 | " \n", 2086 | " \n", 2087 | " \n", 2088 | " \n", 2089 | " \n", 2090 | " \n", 2091 | " \n", 2092 | " \n", 2093 | " \n", 2094 | " \n", 2095 | " \n", 2096 | " \n", 2097 | " \n", 2098 | " \n", 2099 | " \n", 2100 | " \n", 2101 | " \n", 2102 | " \n", 2103 | " \n", 2104 | " \n", 2105 | " \n", 2106 | " \n", 2107 | " \n", 2108 | " \n", 2109 | " \n", 2110 | " \n", 2111 | " \n", 2112 | " \n", 2113 | " \n", 2114 | " \n", 2115 | " \n", 2116 | " \n", 2117 | " \n", 2118 | " \n", 2119 | " \n", 2120 | " \n", 2121 | " \n", 2122 | " \n", 2123 | " \n", 2124 | " \n", 2125 | " \n", 2126 | " \n", 2127 | " \n", 2128 | " \n", 2129 | " \n", 2130 | " \n", 2131 | " \n", 2132 | " \n", 2133 | " \n", 2134 | " \n", 2135 | " \n", 2136 | " \n", 2137 | " \n", 2138 | " \n", 2139 | " \n", 2140 | " \n", 2141 | " \n", 2142 | " \n", 2143 | " \n", 2144 | " \n", 2145 | " \n", 2146 | " \n", 2147 | " \n", 2148 | " \n", 2149 | " \n", 2150 | " \n", 2151 | " \n", 2152 | " \n", 2153 | " \n", 2154 | " \n", 2155 | " \n", 2156 | " \n", 2157 | " \n", 2158 | " \n", 2159 | " \n", 2160 | " \n", 2161 | " \n", 2162 | " \n", 2163 | " \n", 2164 | " \n", 2165 | " \n", 2166 | " \n", 2167 | " \n", 2168 | " \n", 2169 | " \n", 2170 | " \n", 2171 | " \n", 2172 | " \n", 2173 | " \n", 2174 | " \n", 2175 | " \n", 2176 | " \n", 2177 | " \n", 2178 | " \n", 2179 | " \n", 2180 | " \n", 2181 | " \n", 2182 | " \n", 2183 | " \n", 2184 | " \n", 2185 | " \n", 2186 | " \n", 2187 | " \n", 2188 | " \n", 2189 | 
" \n", 2190 | " \n", 2191 | " \n", 2192 | " \n", 2193 | " \n", 2194 | " \n", 2195 | " \n", 2196 | " \n", 2197 | " \n", 2198 | " \n", 2199 | " \n", 2200 | " \n", 2201 | " \n", 2202 | " \n", 2203 | " \n", 2204 | " \n", 2205 | " \n", 2206 | " \n", 2207 | " \n", 2208 | " \n", 2209 | " \n", 2210 | " \n", 2211 | " \n", 2212 | " \n", 2213 | " \n", 2214 | " \n", 2215 | " \n", 2216 | " \n", 2217 | " \n", 2218 | " \n", 2219 | " \n", 2220 | " \n", 2221 | " \n", 2222 | " \n", 2223 | " \n", 2224 | " \n", 2225 | " \n", 2226 | " \n", 2227 | " \n", 2228 | " \n", 2229 | " \n", 2230 | " \n", 2231 | " \n", 2232 | " \n", 2233 | " \n", 2234 | " \n", 2235 | " \n", 2236 | " \n", 2237 | " \n", 2238 | " \n", 2239 | " \n", 2240 | " \n", 2241 | " \n", 2242 | " \n", 2243 | " \n", 2244 | " \n", 2245 | " \n", 2246 | " \n", 2247 | " \n", 2248 | " \n", 2249 | " \n", 2250 | " \n", 2251 | " \n", 2252 | " \n", 2253 | " \n", 2254 | " \n", 2255 | " \n", 2256 | " \n", 2257 | " \n", 2258 | " \n", 2259 | " \n", 2260 | " \n", 2261 | " \n", 2262 | " \n", 2263 | " \n", 2264 | " \n", 2265 | " \n", 2266 | " \n", 2267 | " \n", 2268 | " \n", 2269 | " \n", 2270 | " \n", 2271 | " \n", 2272 | " \n", 2273 | " \n", 2274 | " \n", 2275 | " \n", 2276 | " \n", 2277 | " \n", 2278 | " \n", 2279 | " \n", 2280 | " \n", 2281 | " \n", 2282 | " \n", 2283 | " \n", 2284 | " \n", 2285 | " \n", 2286 | " \n", 2287 | " \n", 2288 | " \n", 2289 | " \n", 2290 | " \n", 2291 | " \n", 2292 | " \n", 2293 | " \n", 2294 | " \n", 2295 | " \n", 2296 | " \n", 2297 | " \n", 2298 | " \n", 2299 | " \n", 2300 | " \n", 2301 | " \n", 2302 | " \n", 2303 | " \n", 2304 | " \n", 2305 | " \n", 2306 | " \n", 2307 | " \n", 2308 | " \n", 2309 | " \n", 2310 | " \n", 2311 | " \n", 2312 | " \n", 2313 | " \n", 2314 | " \n", 2315 | " \n", 2316 | " \n", 2317 | " \n", 2318 | "
epochtrain_lossvalid_lossaccuracytime
02.3139082.2735490.19692100:01
12.1721572.1790610.34806100:01
21.9735042.0072260.48180700:01
31.6890381.6606260.58076800:01
41.3074211.2370870.64214300:01
50.9353130.9068350.68932400:01
60.7200140.7369570.72910800:01
70.6154890.6687740.73290700:01
80.5578200.6158240.76129500:01
90.5274870.6160450.75929600:01
100.4900040.5925570.78088800:01
110.4664180.5940110.78708500:01
120.4064370.5516910.79568200:01
130.3777570.5076830.81787300:01
140.3446300.4948620.82926800:01
150.3068550.4737360.83326700:01
160.2672850.4574510.84226300:01
170.2396220.4339770.84806100:01
180.2120510.4415240.85685700:01
190.1758360.4270190.86005600:01
200.1792920.4131490.86805300:01
210.1457760.4352020.87185100:01
220.1357230.4267250.87245100:01
230.1109330.3996590.87964800:01
240.1124470.4054410.88564600:01
250.1095050.4458750.87624900:01
260.0853300.4148860.89244300:01
270.0937030.4243970.88384600:01
280.0690070.4172130.89204300:01
290.0716600.4313380.88604600:01
300.0574730.4321870.88924400:01
310.0591030.4491400.88684500:01
320.0521040.4474110.89324300:01
330.0386680.4391080.89504200:01
340.0416380.4584630.88964400:01
350.0351980.4273330.89884000:01
360.0319950.4553370.89464200:01
370.0274990.4353170.90143900:01
380.0220930.4291250.90263900:01
390.0192650.4489100.90323900:01
400.0171390.4434200.89884000:01
410.0163070.4509930.90024000:01
420.0111150.4464120.90163900:01
430.0083310.4672980.90463800:01
440.0101570.4710970.89924000:01
450.0135440.4917570.90303900:01
460.0074860.4851410.90163900:01
470.0077910.4984240.90243900:01
480.0039880.4919840.90423800:01
490.0050470.4820240.90743700:01
500.0040360.4754490.91023600:01
510.0054840.4844450.90563800:01
520.0025720.4933490.90663700:01
530.0026690.4879430.90903600:01
540.0024480.4864310.90683700:01
550.0027410.4724130.90903600:01
560.0027610.4822040.91003600:01
570.0020240.4888530.90683700:01
580.0029980.4761790.90783700:01
590.0038970.4824910.90823700:01
" 2319 | ], 2320 | "text/plain": [ 2321 | "" 2322 | ] 2323 | }, 2324 | "metadata": {}, 2325 | "output_type": "display_data" 2326 | } 2327 | ], 2328 | "source": [ 2329 | "learn.fit_one_cycle(cyc_len=cycles, max_lr=max_lr, wd=w_decay)" 2330 | ] 2331 | }, 2332 | { 2333 | "cell_type": "code", 2334 | "execution_count": 42, 2335 | "metadata": { 2336 | "ExecuteTime": { 2337 | "end_time": "2020-02-04T07:58:08.683290Z", 2338 | "start_time": "2020-02-04T07:58:08.622617Z" 2339 | } 2340 | }, 2341 | "outputs": [], 2342 | "source": [ 2343 | "learn.save('no_reg_save_0908')\n", 2344 | "learn.export('no_reg_0908')" 2345 | ] 2346 | }, 2347 | { 2348 | "cell_type": "code", 2349 | "execution_count": 43, 2350 | "metadata": { 2351 | "ExecuteTime": { 2352 | "end_time": "2020-02-04T07:58:31.888308Z", 2353 | "start_time": "2020-02-04T07:58:31.867844Z" 2354 | } 2355 | }, 2356 | "outputs": [], 2357 | "source": [ 2358 | "max_lr = 1e-4" 2359 | ] 2360 | }, 2361 | { 2362 | "cell_type": "code", 2363 | "execution_count": 44, 2364 | "metadata": { 2365 | "ExecuteTime": { 2366 | "end_time": "2020-02-04T07:59:54.280873Z", 2367 | "start_time": "2020-02-04T07:58:49.076527Z" 2368 | } 2369 | }, 2370 | "outputs": [ 2371 | { 2372 | "data": { 2373 | "text/html": [ 2374 | "\n", 2375 | " \n", 2376 | " \n", 2377 | " \n", 2378 | " \n", 2379 | " \n", 2380 | " \n", 2381 | " \n", 2382 | " \n", 2383 | " \n", 2384 | " \n", 2385 | " \n", 2386 | " \n", 2387 | " \n", 2388 | " \n", 2389 | " \n", 2390 | " \n", 2391 | " \n", 2392 | " \n", 2393 | " \n", 2394 | " \n", 2395 | " \n", 2396 | " \n", 2397 | " \n", 2398 | " \n", 2399 | " \n", 2400 | " \n", 2401 | " \n", 2402 | " \n", 2403 | " \n", 2404 | " \n", 2405 | " \n", 2406 | " \n", 2407 | " \n", 2408 | " \n", 2409 | " \n", 2410 | " \n", 2411 | " \n", 2412 | " \n", 2413 | " \n", 2414 | " \n", 2415 | " \n", 2416 | " \n", 2417 | " \n", 2418 | " \n", 2419 | " \n", 2420 | " \n", 2421 | " \n", 2422 | " \n", 2423 | " \n", 2424 | " \n", 2425 | " \n", 2426 | " \n", 2427 | " \n", 2428 | " \n", 2429 | " \n", 2430 | " \n", 2431 | " \n", 2432 | " \n", 2433 | " \n", 2434 | " \n", 2435 | " \n", 2436 | " \n", 2437 | " \n", 2438 | " \n", 2439 | " \n", 2440 | " \n", 2441 | " \n", 2442 | " \n", 2443 | " \n", 2444 | " \n", 2445 | " \n", 2446 | " \n", 2447 | " \n", 2448 | " \n", 2449 | " \n", 2450 | " \n", 2451 | " \n", 2452 | " \n", 2453 | " \n", 2454 | " \n", 2455 | " \n", 2456 | " \n", 2457 | " \n", 2458 | " \n", 2459 | " \n", 2460 | " \n", 2461 | " \n", 2462 | " \n", 2463 | " \n", 2464 | " \n", 2465 | " \n", 2466 | " \n", 2467 | " \n", 2468 | " \n", 2469 | " \n", 2470 | " \n", 2471 | " \n", 2472 | " \n", 2473 | " \n", 2474 | " \n", 2475 | " \n", 2476 | " \n", 2477 | " \n", 2478 | " \n", 2479 | " \n", 2480 | " \n", 2481 | " \n", 2482 | " \n", 2483 | " \n", 2484 | " \n", 2485 | " \n", 2486 | " \n", 2487 | " \n", 2488 | " \n", 2489 | " \n", 2490 | " \n", 2491 | " \n", 2492 | " \n", 2493 | " \n", 2494 | " \n", 2495 | " \n", 2496 | " \n", 2497 | " \n", 2498 | " \n", 2499 | " \n", 2500 | " \n", 2501 | " \n", 2502 | " \n", 2503 | " \n", 2504 | " \n", 2505 | " \n", 2506 | " \n", 2507 | " \n", 2508 | " \n", 2509 | " \n", 2510 | " \n", 2511 | " \n", 2512 | " \n", 2513 | " \n", 2514 | " \n", 2515 | " \n", 2516 | " \n", 2517 | " \n", 2518 | " \n", 2519 | " \n", 2520 | " \n", 2521 | " \n", 2522 | " \n", 2523 | " \n", 2524 | " \n", 2525 | " \n", 2526 | " \n", 2527 | " \n", 2528 | " \n", 2529 | " \n", 2530 | " \n", 2531 | " \n", 2532 | " \n", 2533 | " \n", 2534 | " \n", 2535 | " \n", 2536 | " \n", 2537 | " \n", 2538 | " \n", 2539 
| " \n", 2540 | " \n", 2541 | " \n", 2542 | " \n", 2543 | " \n", 2544 | " \n", 2545 | " \n", 2546 | " \n", 2547 | " \n", 2548 | " \n", 2549 | " \n", 2550 | " \n", 2551 | " \n", 2552 | " \n", 2553 | " \n", 2554 | " \n", 2555 | " \n", 2556 | " \n", 2557 | " \n", 2558 | " \n", 2559 | " \n", 2560 | " \n", 2561 | " \n", 2562 | " \n", 2563 | " \n", 2564 | " \n", 2565 | " \n", 2566 | " \n", 2567 | " \n", 2568 | " \n", 2569 | " \n", 2570 | " \n", 2571 | " \n", 2572 | " \n", 2573 | " \n", 2574 | " \n", 2575 | " \n", 2576 | " \n", 2577 | " \n", 2578 | " \n", 2579 | " \n", 2580 | " \n", 2581 | " \n", 2582 | " \n", 2583 | " \n", 2584 | " \n", 2585 | " \n", 2586 | " \n", 2587 | " \n", 2588 | " \n", 2589 | " \n", 2590 | " \n", 2591 | " \n", 2592 | " \n", 2593 | " \n", 2594 | " \n", 2595 | " \n", 2596 | " \n", 2597 | " \n", 2598 | " \n", 2599 | " \n", 2600 | " \n", 2601 | " \n", 2602 | " \n", 2603 | " \n", 2604 | " \n", 2605 | " \n", 2606 | " \n", 2607 | " \n", 2608 | " \n", 2609 | " \n", 2610 | " \n", 2611 | " \n", 2612 | " \n", 2613 | " \n", 2614 | " \n", 2615 | " \n", 2616 | " \n", 2617 | " \n", 2618 | " \n", 2619 | " \n", 2620 | " \n", 2621 | " \n", 2622 | " \n", 2623 | " \n", 2624 | " \n", 2625 | " \n", 2626 | " \n", 2627 | " \n", 2628 | " \n", 2629 | " \n", 2630 | " \n", 2631 | " \n", 2632 | " \n", 2633 | " \n", 2634 | " \n", 2635 | " \n", 2636 | " \n", 2637 | " \n", 2638 | " \n", 2639 | " \n", 2640 | " \n", 2641 | " \n", 2642 | " \n", 2643 | " \n", 2644 | " \n", 2645 | " \n", 2646 | " \n", 2647 | " \n", 2648 | " \n", 2649 | " \n", 2650 | " \n", 2651 | " \n", 2652 | " \n", 2653 | " \n", 2654 | " \n", 2655 | " \n", 2656 | " \n", 2657 | " \n", 2658 | " \n", 2659 | " \n", 2660 | " \n", 2661 | " \n", 2662 | " \n", 2663 | " \n", 2664 | " \n", 2665 | " \n", 2666 | " \n", 2667 | " \n", 2668 | " \n", 2669 | " \n", 2670 | " \n", 2671 | " \n", 2672 | " \n", 2673 | " \n", 2674 | " \n", 2675 | " \n", 2676 | " \n", 2677 | " \n", 2678 | " \n", 2679 | " \n", 2680 | " \n", 2681 | " \n", 2682 | " \n", 2683 | " \n", 2684 | " \n", 2685 | " \n", 2686 | " \n", 2687 | " \n", 2688 | " \n", 2689 | " \n", 2690 | " \n", 2691 | " \n", 2692 | " \n", 2693 | " \n", 2694 | " \n", 2695 | " \n", 2696 | " \n", 2697 | " \n", 2698 | " \n", 2699 | " \n", 2700 | " \n", 2701 | " \n", 2702 | " \n", 2703 | " \n", 2704 | " \n", 2705 | " \n", 2706 | " \n", 2707 | " \n", 2708 | " \n", 2709 | " \n", 2710 | " \n", 2711 | " \n", 2712 | " \n", 2713 | " \n", 2714 | " \n", 2715 | " \n", 2716 | " \n", 2717 | " \n", 2718 | " \n", 2719 | " \n", 2720 | " \n", 2721 | " \n", 2722 | " \n", 2723 | " \n", 2724 | " \n", 2725 | " \n", 2726 | " \n", 2727 | " \n", 2728 | " \n", 2729 | " \n", 2730 | " \n", 2731 | " \n", 2732 | " \n", 2733 | " \n", 2734 | " \n", 2735 | " \n", 2736 | " \n", 2737 | " \n", 2738 | " \n", 2739 | " \n", 2740 | " \n", 2741 | " \n", 2742 | " \n", 2743 | " \n", 2744 | " \n", 2745 | " \n", 2746 | " \n", 2747 | " \n", 2748 | " \n", 2749 | " \n", 2750 | " \n", 2751 | " \n", 2752 | " \n", 2753 | " \n", 2754 | " \n", 2755 | " \n", 2756 | " \n", 2757 | " \n", 2758 | " \n", 2759 | " \n", 2760 | " \n", 2761 | " \n", 2762 | " \n", 2763 | " \n", 2764 | " \n", 2765 | " \n", 2766 | " \n", 2767 | " \n", 2768 | " \n", 2769 | " \n", 2770 | " \n", 2771 | " \n", 2772 | " \n", 2773 | " \n", 2774 | " \n", 2775 | " \n", 2776 | " \n", 2777 | " \n", 2778 | " \n", 2779 | " \n", 2780 | " \n", 2781 | " \n", 2782 | " \n", 2783 | " \n", 2784 | " \n", 2785 | " \n", 2786 | " \n", 2787 | " \n", 2788 | " \n", 2789 | " \n", 2790 | " \n", 2791 | " \n", 2792 | " \n", 
2793 | " \n", 2794 | " \n", 2795 | " \n", 2796 | " \n", 2797 | " \n", 2798 | " \n", 2799 | " \n", 2800 | " \n", 2801 | " \n", 2802 | " \n", 2803 | " \n", 2804 | " \n", 2805 | " \n", 2806 | "
epochtrain_lossvalid_lossaccuracytime
00.0023460.5019660.90623700:01
10.0018000.4849590.90723700:01
20.0027400.5030990.90683700:01
30.0027390.5165480.90423800:01
40.0028840.4907640.90723700:01
50.0063040.5466470.90163900:01
60.0082270.5320480.90423800:01
70.0112130.5399050.89924000:01
80.0149060.5525670.89524200:01
90.0229290.6131380.89024400:01
100.0372770.5971360.88744500:01
110.0337550.5451980.89264300:01
120.0431620.5791430.88464600:01
130.0441360.5819040.88224700:01
140.0549390.5392700.89384200:01
150.0591450.5189040.89184300:01
160.0591680.5168590.88984400:01
170.0555480.5288000.88784500:01
180.0438890.5002830.89444200:01
190.0549160.5130640.89184300:01
200.0498620.4899530.88804500:01
210.0604760.4847380.89524200:01
220.0430350.4717260.90104000:01
230.0447360.5066840.89164300:01
240.0435470.4476430.89984000:01
250.0349260.4882990.89724100:01
260.0362340.4380570.90423800:01
270.0347460.4571430.90143900:01
280.0327960.4644430.90483800:01
290.0390690.4538790.90443800:01
300.0189690.4717430.90383800:01
310.0212710.4612860.90403800:01
320.0289990.4570900.90503800:01
330.0207660.4760780.90064000:01
340.0204450.4856030.90383800:01
350.0135570.4889220.90463800:01
360.0191660.4804850.90723700:01
370.0106790.4545220.91143500:01
380.0123570.4444170.91023600:01
390.0081350.4717800.90663700:01
400.0059190.4927160.90963600:01
410.0055580.4848250.91043600:01
420.0072300.5070190.90983600:01
430.0118920.5010900.91443400:01
440.0057600.4849830.91283500:01
450.0033290.4926220.91203500:01
460.0057620.4945830.91383400:01
470.0040020.4892600.91543400:01
480.0036670.4594720.91663300:01
490.0030550.4911180.91783300:01
500.0022620.4849790.91963200:01
510.0023460.4562070.91823300:01
520.0014850.4595860.91923200:01
530.0015360.4702430.91883200:01
540.0012430.4695940.91763300:01
550.0008900.4742300.91703300:01
560.0008830.4802280.91983200:01
570.0009400.4785310.91743300:01
580.0008580.4788180.91543400:01
590.0032460.4719250.91723300:01
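(Editorial note on the two persistence calls used throughout this notebook: in fastai v1, `learn.save(name)` writes only the model weights, and by default the optimizer state, to `path/'models'/name.pth` so training can be resumed, while `learn.export(fname)` pickles the entire `Learner`, including the tabular preprocessing state, for later inference. A minimal reload sketch, assuming the fastai v1 API used here and the file names saved in the cells above:)

```python
from fastai.tabular import *  # fastai v1 imports, as in this notebook

# Restore weights into an existing Learner (pairs with learn.save('no_reg_save_0908')):
learn = learn.load('no_reg_save_0908')

# Rebuild a standalone inference Learner (pairs with learn.export('no_reg_0908')):
infer_learn = load_learner(path, 'no_reg_0908')
```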
" 2807 | ], 2808 | "text/plain": [ 2809 | "" 2810 | ] 2811 | }, 2812 | "metadata": {}, 2813 | "output_type": "display_data" 2814 | } 2815 | ], 2816 | "source": [ 2817 | "learn.fit_one_cycle(cyc_len=cycles, max_lr=max_lr, wd=w_decay)" 2818 | ] 2819 | }, 2820 | { 2821 | "cell_type": "code", 2822 | "execution_count": 45, 2823 | "metadata": { 2824 | "ExecuteTime": { 2825 | "end_time": "2020-02-04T08:00:15.883521Z", 2826 | "start_time": "2020-02-04T08:00:15.822953Z" 2827 | } 2828 | }, 2829 | "outputs": [], 2830 | "source": [ 2831 | "learn.save('no_reg_save_0917')\n", 2832 | "learn.export('no_reg_0917')" 2833 | ] 2834 | }, 2835 | { 2836 | "cell_type": "markdown", 2837 | "metadata": {}, 2838 | "source": [ 2839 | "### Test" 2840 | ] 2841 | }, 2842 | { 2843 | "cell_type": "code", 2844 | "execution_count": 46, 2845 | "metadata": { 2846 | "ExecuteTime": { 2847 | "end_time": "2020-02-04T08:00:29.345047Z", 2848 | "start_time": "2020-02-04T08:00:28.132602Z" 2849 | } 2850 | }, 2851 | "outputs": [], 2852 | "source": [ 2853 | "test_df = ld_var(name='test_df', path=path)" 2854 | ] 2855 | }, 2856 | { 2857 | "cell_type": "code", 2858 | "execution_count": 47, 2859 | "metadata": { 2860 | "ExecuteTime": { 2861 | "end_time": "2020-02-04T08:00:30.571916Z", 2862 | "start_time": "2020-02-04T08:00:30.340314Z" 2863 | } 2864 | }, 2865 | "outputs": [ 2866 | { 2867 | "data": { 2868 | "text/html": [ 2869 | "
\n", 2870 | "\n", 2883 | "\n", 2884 | " \n", 2885 | " \n", 2886 | " \n", 2887 | " \n", 2888 | " \n", 2889 | " \n", 2890 | " \n", 2891 | " \n", 2892 | " \n", 2893 | " \n", 2894 | " \n", 2895 | " \n", 2896 | " \n", 2897 | " \n", 2898 | " \n", 2899 | " \n", 2900 | " \n", 2901 | " \n", 2902 | " \n", 2903 | " \n", 2904 | " \n", 2905 | " \n", 2906 | " \n", 2907 | " \n", 2908 | " \n", 2909 | " \n", 2910 | " \n", 2911 | " \n", 2912 | " \n", 2913 | " \n", 2914 | " \n", 2915 | " \n", 2916 | " \n", 2917 | " \n", 2918 | " \n", 2919 | " \n", 2920 | " \n", 2921 | " \n", 2922 | " \n", 2923 | " \n", 2924 | " \n", 2925 | " \n", 2926 | " \n", 2927 | " \n", 2928 | " \n", 2929 | " \n", 2930 | " \n", 2931 | " \n", 2932 | " \n", 2933 | " \n", 2934 | " \n", 2935 | " \n", 2936 | " \n", 2937 | " \n", 2938 | " \n", 2939 | " \n", 2940 | " \n", 2941 | " \n", 2942 | " \n", 2943 | " \n", 2944 | " \n", 2945 | " \n", 2946 | " \n", 2947 | " \n", 2948 | " \n", 2949 | " \n", 2950 | " \n", 2951 | " \n", 2952 | " \n", 2953 | " \n", 2954 | " \n", 2955 | " \n", 2956 | " \n", 2957 | " \n", 2958 | " \n", 2959 | " \n", 2960 | " \n", 2961 | " \n", 2962 | " \n", 2963 | " \n", 2964 | " \n", 2965 | " \n", 2966 | " \n", 2967 | " \n", 2968 | " \n", 2969 | " \n", 2970 | " \n", 2971 | " \n", 2972 | " \n", 2973 | " \n", 2974 | " \n", 2975 | " \n", 2976 | " \n", 2977 | " \n", 2978 | " \n", 2979 | " \n", 2980 | " \n", 2981 | " \n", 2982 | " \n", 2983 | " \n", 2984 | "
idS1C1S2C2S3C3S4C4S5C5cardsnew_hand
679436679437Diamonds10Spades2HeartsKingClubsJackClubs3[2, 3, 10, 11, 13]Nothing
607096607097Diamonds3SpadesAceClubs8Hearts6Diamonds8[1, 3, 6, 8, 8, 14]Pair
690751690752SpadesQueenDiamondsJackHeartsQueenSpades8HeartsJack[8, 11, 11, 12, 12]Two pairs
887951887952Hearts8Hearts2Hearts6Clubs7Hearts3[2, 3, 6, 7, 8]Nothing
667297667298ClubsAceHearts9Hearts3SpadesJackSpadesQueen[1, 3, 9, 11, 12, 14]Nothing
\n", 2985 | "
" 2986 | ], 2987 | "text/plain": [ 2988 | " id S1 C1 S2 C2 S3 C3 S4 C4 \\\n", 2989 | "679436 679437 Diamonds 10 Spades 2 Hearts King Clubs Jack \n", 2990 | "607096 607097 Diamonds 3 Spades Ace Clubs 8 Hearts 6 \n", 2991 | "690751 690752 Spades Queen Diamonds Jack Hearts Queen Spades 8 \n", 2992 | "887951 887952 Hearts 8 Hearts 2 Hearts 6 Clubs 7 \n", 2993 | "667297 667298 Clubs Ace Hearts 9 Hearts 3 Spades Jack \n", 2994 | "\n", 2995 | " S5 C5 cards new_hand \n", 2996 | "679436 Clubs 3 [2, 3, 10, 11, 13] Nothing \n", 2997 | "607096 Diamonds 8 [1, 3, 6, 8, 8, 14] Pair \n", 2998 | "690751 Hearts Jack [8, 11, 11, 12, 12] Two pairs \n", 2999 | "887951 Hearts 3 [2, 3, 6, 7, 8] Nothing \n", 3000 | "667297 Spades Queen [1, 3, 9, 11, 12, 14] Nothing " 3001 | ] 3002 | }, 3003 | "execution_count": 47, 3004 | "metadata": {}, 3005 | "output_type": "execute_result" 3006 | } 3007 | ], 3008 | "source": [ 3009 | "test_df.sample(5)" 3010 | ] 3011 | }, 3012 | { 3013 | "cell_type": "code", 3014 | "execution_count": 48, 3015 | "metadata": { 3016 | "ExecuteTime": { 3017 | "end_time": "2020-02-04T08:00:38.664736Z", 3018 | "start_time": "2020-02-04T08:00:38.644077Z" 3019 | } 3020 | }, 3021 | "outputs": [], 3022 | "source": [ 3023 | "learn_test = copy(learn)" 3024 | ] 3025 | }, 3026 | { 3027 | "cell_type": "code", 3028 | "execution_count": 49, 3029 | "metadata": { 3030 | "ExecuteTime": { 3031 | "end_time": "2020-02-04T08:00:42.753519Z", 3032 | "start_time": "2020-02-04T08:00:42.733048Z" 3033 | } 3034 | }, 3035 | "outputs": [], 3036 | "source": [ 3037 | "n = 100000" 3038 | ] 3039 | }, 3040 | { 3041 | "cell_type": "code", 3042 | "execution_count": 50, 3043 | "metadata": { 3044 | "ExecuteTime": { 3045 | "end_time": "2020-02-04T08:00:44.121349Z", 3046 | "start_time": "2020-02-04T08:00:43.572327Z" 3047 | } 3048 | }, 3049 | "outputs": [], 3050 | "source": [ 3051 | "np.random.seed(1001)\n", 3052 | "preds = get_cust_preds(df=test_df.iloc[:n], learn=learn_test, bs=BS)" 3053 | ] 3054 | }, 3055 | { 3056 | "cell_type": "code", 3057 | "execution_count": 51, 3058 | "metadata": { 3059 | "ExecuteTime": { 3060 | "end_time": "2020-02-04T08:00:45.393955Z", 3061 | "start_time": "2020-02-04T08:00:45.371880Z" 3062 | } 3063 | }, 3064 | "outputs": [], 3065 | "source": [ 3066 | "y_hat = np.argmax(preds, axis = 1)" 3067 | ] 3068 | }, 3069 | { 3070 | "cell_type": "code", 3071 | "execution_count": 52, 3072 | "metadata": { 3073 | "ExecuteTime": { 3074 | "end_time": "2020-02-04T08:00:46.145068Z", 3075 | "start_time": "2020-02-04T08:00:46.124356Z" 3076 | } 3077 | }, 3078 | "outputs": [ 3079 | { 3080 | "data": { 3081 | "text/plain": [ 3082 | "100000" 3083 | ] 3084 | }, 3085 | "execution_count": 52, 3086 | "metadata": {}, 3087 | "output_type": "execute_result" 3088 | } 3089 | ], 3090 | "source": [ 3091 | "len(y_hat)" 3092 | ] 3093 | }, 3094 | { 3095 | "cell_type": "code", 3096 | "execution_count": 53, 3097 | "metadata": { 3098 | "ExecuteTime": { 3099 | "end_time": "2020-02-04T08:00:46.921427Z", 3100 | "start_time": "2020-02-04T08:00:46.900461Z" 3101 | } 3102 | }, 3103 | "outputs": [ 3104 | { 3105 | "data": { 3106 | "text/plain": [ 3107 | "{'Flush': 0,\n", 3108 | " 'Four of a kind': 1,\n", 3109 | " 'Full house': 2,\n", 3110 | " 'Nothing': 3,\n", 3111 | " 'Pair': 4,\n", 3112 | " 'Royal flush': 5,\n", 3113 | " 'Straight': 6,\n", 3114 | " 'Straight flush': 7,\n", 3115 | " 'Three of a kind': 8,\n", 3116 | " 'Two pairs': 9}" 3117 | ] 3118 | }, 3119 | "execution_count": 53, 3120 | "metadata": {}, 3121 | "output_type": "execute_result" 3122 | } 3123 | 
], 3124 | "source": [ 3125 | "c2i = learn_test.data.train_ds.c2i\n", 3126 | "c2i" 3127 | ] 3128 | }, 3129 | { 3130 | "cell_type": "code", 3131 | "execution_count": 54, 3132 | "metadata": { 3133 | "ExecuteTime": { 3134 | "end_time": "2020-02-04T08:00:49.227012Z", 3135 | "start_time": "2020-02-04T08:00:49.180522Z" 3136 | } 3137 | }, 3138 | "outputs": [], 3139 | "source": [ 3140 | "y = test_df.iloc[:n]['new_hand']\n", 3141 | "y = y.replace(c2i)" 3142 | ] 3143 | }, 3144 | { 3145 | "cell_type": "code", 3146 | "execution_count": 55, 3147 | "metadata": { 3148 | "ExecuteTime": { 3149 | "end_time": "2020-02-04T08:00:50.305792Z", 3150 | "start_time": "2020-02-04T08:00:50.284309Z" 3151 | } 3152 | }, 3153 | "outputs": [], 3154 | "source": [ 3155 | "tr = np.count_nonzero(y==y_hat)" 3156 | ] 3157 | }, 3158 | { 3159 | "cell_type": "code", 3160 | "execution_count": 56, 3161 | "metadata": { 3162 | "ExecuteTime": { 3163 | "end_time": "2020-02-04T08:00:51.104537Z", 3164 | "start_time": "2020-02-04T08:00:51.084501Z" 3165 | } 3166 | }, 3167 | "outputs": [], 3168 | "source": [ 3169 | "accuracy = tr/len(y)" 3170 | ] 3171 | }, 3172 | { 3173 | "cell_type": "code", 3174 | "execution_count": 57, 3175 | "metadata": { 3176 | "ExecuteTime": { 3177 | "end_time": "2020-02-04T08:00:51.951914Z", 3178 | "start_time": "2020-02-04T08:00:51.931661Z" 3179 | } 3180 | }, 3181 | "outputs": [ 3182 | { 3183 | "data": { 3184 | "text/plain": [ 3185 | "0.91744" 3186 | ] 3187 | }, 3188 | "execution_count": 57, 3189 | "metadata": {}, 3190 | "output_type": "execute_result" 3191 | } 3192 | ], 3193 | "source": [ 3194 | "accuracy" 3195 | ] 3196 | }, 3197 | { 3198 | "cell_type": "markdown", 3199 | "metadata": {}, 3200 | "source": [ 3201 | "### 91%" 3202 | ] 3203 | } 3204 | ], 3205 | "metadata": { 3206 | "kernelspec": { 3207 | "display_name": "Python 3", 3208 | "language": "python", 3209 | "name": "python3" 3210 | }, 3211 | "language_info": { 3212 | "codemirror_mode": { 3213 | "name": "ipython", 3214 | "version": 3 3215 | }, 3216 | "file_extension": ".py", 3217 | "mimetype": "text/x-python", 3218 | "name": "python", 3219 | "nbconvert_exporter": "python", 3220 | "pygments_lexer": "ipython3", 3221 | "version": "3.7.3" 3222 | } 3223 | }, 3224 | "nbformat": 4, 3225 | "nbformat_minor": 2 3226 | } 3227 | -------------------------------------------------------------------------------- /test_tablr_mixup_quick_n_dirty_messy_code.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Mixup Test - proof of concept (doesn't work)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "If we split our model into 2 parts embedding + rest_of_the_model, then we can use the second part as a model and just shift the input of it (in dataloader or callback). We just pass out initial data though embeddings layer and then blend the result. That will be our inputs. I think it’s fair to call in this case the second part (rest_of_the_model) as ‘the model’ as only this part can be trained (I cannot think of the way how to train embeddings as well in a mixup). And feedforward though embeddings is now just a part of preprocessing step. Definetly, first of all we have to train your model in a normal way, as we want to produce our embeddings. Then we can use the_rest_of_the_model and retrain it or throw it away and use only embeddings (and new the_rest_of_the_model) for a mixup training." 
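(The scheme described in the paragraph above can be made concrete with a short, self-contained sketch. This is an editorial illustration of the idea, not the author's implementation, which the notebook's own heading flags as a proof of concept that doesn't work yet; `embed` and `rest_of_the_model` here are toy stand-ins for the frozen embedding step and the trainable tail:)

```python
import numpy as np
import torch
import torch.nn as nn

# Toy stand-ins: 'embed' plays the role of the frozen embeddings-plus-cont
# "preprocessing" step, 'rest_of_the_model' is the only part that trains.
embed = nn.Embedding(10, 4)
for p in embed.parameters():
    p.requires_grad_(False)
rest_of_the_model = nn.Linear(4, 1)

# Two batches to blend, as a dataloader or callback would supply them.
x1, x2 = torch.randint(0, 10, (32,)), torch.randint(0, 10, (32,))
y1, y2 = torch.randn(32, 1), torch.randn(32, 1)

lam = float(np.random.beta(0.4, 0.4))                # mixup blending coefficient
x_mixed = lam * embed(x1) + (1 - lam) * embed(x2)    # blend AFTER the embeddings
y_mixed = lam * y1 + (1 - lam) * y2                  # blend the targets the same way
loss = nn.functional.mse_loss(rest_of_the_model(x_mixed), y_mixed)
loss.backward()                                      # gradients reach only the tail
```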
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2019-08-08T22:40:50.726941Z", 23 | "start_time": "2019-08-08T22:40:50.712470Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "%reload_ext autoreload\n", 29 | "%autoreload 2" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": { 36 | "ExecuteTime": { 37 | "end_time": "2019-08-08T22:40:52.966895Z", 38 | "start_time": "2019-08-08T22:40:51.645281Z" 39 | } 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "from fastai.tabular import *" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": { 50 | "ExecuteTime": { 51 | "end_time": "2019-08-08T22:40:53.127382Z", 52 | "start_time": "2019-08-08T22:40:52.969777Z" 53 | } 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "from exp.nb_ import *" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2019-08-08T22:40:55.252254Z", 66 | "start_time": "2019-08-08T22:40:53.129235Z" 67 | } 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "path=Path('data/')\n", 72 | "train_df = pd.read_pickle(path/'train_clean.zip', compression='zip')\n", 73 | "test_df = pd.read_pickle(path/'test_clean')" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 5, 79 | "metadata": { 80 | "ExecuteTime": { 81 | "end_time": "2019-08-08T22:40:55.281407Z", 82 | "start_time": "2019-08-08T22:40:55.254653Z" 83 | } 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "procs=[FillMissing, Categorify, Normalize]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 6, 93 | "metadata": { 94 | "ExecuteTime": { 95 | "end_time": "2019-08-08T22:40:55.316131Z", 96 | "start_time": "2019-08-08T22:40:55.283465Z" 97 | } 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "cat_vars_tpl = ('Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',\n", 102 | " 'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',\n", 103 | " 'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',\n", 104 | " 'SchoolHoliday_fw', 'SchoolHoliday_bw')\n", 105 | "\n", 106 | "cont_vars_tpl = ('CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',\n", 107 | " 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', \n", 108 | " 'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',\n", 109 | " 'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday')\n", 110 | "cat_vars = list(cat_vars_tpl)\n", 111 | "cont_vars = list(cont_vars_tpl)\n", 112 | "all_vars = cat_vars + cont_vars" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 7, 118 | "metadata": { 119 | "ExecuteTime": { 120 | "end_time": "2019-08-08T22:40:57.494560Z", 121 | "start_time": "2019-08-08T22:40:56.983581Z" 122 | } 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "dep_var = 'Sales'\n", 127 | "df = train_df[cat_vars + cont_vars + [dep_var,'Date']].copy()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 8, 133 | "metadata": { 134 | "ExecuteTime": { 135 | "end_time": "2019-08-08T22:40:58.835333Z", 136 | "start_time": "2019-08-08T22:40:58.354100Z" 137 | } 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "#this step reduces the data as whole dataset doesn't fit into my memory after preprocessing\n", 142 | 
"np.random.seed(1001)\n", 143 | "coef = 0.3\n", 144 | "ln = len(df)\n", 145 | "part_idx = np.random.choice(ln, int(ln*coef), replace=False)\n", 146 | "df = df.iloc[part_idx]\n", 147 | "df.sort_values(by='Date', ascending=False, inplace=True)\n", 148 | "df = df.reset_index()" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 11, 154 | "metadata": { 155 | "ExecuteTime": { 156 | "end_time": "2019-08-08T22:42:30.120648Z", 157 | "start_time": "2019-08-08T22:42:30.080530Z" 158 | } 159 | }, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "12443" 165 | ] 166 | }, 167 | "execution_count": 11, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "cut = df['Date'][(df['Date'] == df['Date'][int(len(test_df)*coef)])].index.max()\n", 174 | "cut" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 12, 180 | "metadata": { 181 | "ExecuteTime": { 182 | "end_time": "2019-08-08T22:42:48.932050Z", 183 | "start_time": "2019-08-08T22:42:48.902394Z" 184 | } 185 | }, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "range(0, 12443)" 191 | ] 192 | }, 193 | "execution_count": 12, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "valid_idx = range(cut)\n", 200 | "valid_idx" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 13, 206 | "metadata": { 207 | "ExecuteTime": { 208 | "end_time": "2019-08-08T22:42:51.887831Z", 209 | "start_time": "2019-08-08T22:42:51.862088Z" 210 | } 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "BS = 1024" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 14, 220 | "metadata": { 221 | "ExecuteTime": { 222 | "end_time": "2019-08-08T22:42:54.362351Z", 223 | "start_time": "2019-08-08T22:42:52.853880Z" 224 | } 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "data = (TabularList.from_df(df, path=path, cat_names=cat_vars, cont_names=cont_vars, procs=procs)\n", 229 | " .split_by_idx(valid_idx)\n", 230 | " .label_from_df(cols=dep_var, label_cls=FloatList, log=True)\n", 231 | " .databunch(bs=BS))" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 15, 237 | "metadata": { 238 | "ExecuteTime": { 239 | "end_time": "2019-08-08T22:42:57.955043Z", 240 | "start_time": "2019-08-08T22:42:55.174795Z" 241 | } 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "max_log_y = np.log(np.max(train_df['Sales'])*1.2)\n", 246 | "y_range = torch.tensor([0, max_log_y], device=defaults.device)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 16, 252 | "metadata": { 253 | "ExecuteTime": { 254 | "end_time": "2019-08-08T22:42:59.064136Z", 255 | "start_time": "2019-08-08T22:42:58.926096Z" 256 | } 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "np.random.seed(1001)\n", 261 | "learn = tabular_learner(data, layers=[1000,500], ps=[0.001,0.01], emb_drop=0.04, \n", 262 | " y_range=y_range, metrics=exp_rmspe)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 17, 268 | "metadata": { 269 | "ExecuteTime": { 270 | "end_time": "2019-08-08T22:43:40.416017Z", 271 | "start_time": "2019-08-08T22:43:00.904031Z" 272 | } 273 | }, 274 | "outputs": [ 275 | { 276 | "data": { 277 | "text/html": [ 278 | "Total time: 00:39

\n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | "
epochtrain_lossvalid_lossexp_rmspetime
00.2314960.0916580.37554700:07
10.0260470.0185670.13976900:06
20.0183470.0201240.15387000:06
30.0153850.0220240.16202600:06
40.0111400.0124280.11332000:06
50.0086640.0122980.10853700:06
\n" 329 | ], 330 | "text/plain": [ 331 | "" 332 | ] 333 | }, 334 | "metadata": {}, 335 | "output_type": "display_data" 336 | } 337 | ], 338 | "source": [ 339 | "learn.fit_one_cycle(6, 1e-2, wd=0.2)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 18, 345 | "metadata": { 346 | "ExecuteTime": { 347 | "end_time": "2019-08-08T22:44:23.400878Z", 348 | "start_time": "2019-08-08T22:43:44.887710Z" 349 | } 350 | }, 351 | "outputs": [ 352 | { 353 | "data": { 354 | "text/html": [ 355 | "Total time: 00:38

\n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | "
epochtrain_lossvalid_lossexp_rmspetime
00.0080300.0123630.10842900:06
10.0079760.0122010.10811900:06
20.0078390.0122100.10768600:06
30.0078390.0120340.10756800:06
40.0076040.0121830.10837000:06
50.0075200.0122030.10859900:06
\n" 406 | ], 407 | "text/plain": [ 408 | "" 409 | ] 410 | }, 411 | "metadata": {}, 412 | "output_type": "display_data" 413 | } 414 | ], 415 | "source": [ 416 | "learn.fit_one_cycle(6, 1e-4, wd=0.2)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "metadata": {}, 422 | "source": [ 423 | "### Santity test" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "metadata": {}, 429 | "source": [ 430 | "Here we try to separate embeddings from the model. First of all we preprocess the data (normalize, categorize and fill missing if needed) Then we take a model and feed it all our data (train and valid in separate dataframes) to embedding layers only. Then we concat these values with cont values. So now we have a bunch of floats for each row of data. This is what our NN (apart from embeddings) really gets as input. Our last step is to pretend that these floats are just a bunch of cont values and try to teach NN in a normal way (without preprocessing (!)) If we will get the similar results as in normal training, then our methon does work and we can think of a mixup." 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "By the way we already have a function that makes all the preprocess and outputs the 'real model input'. I've made it for Random Forrest with embedding case (RF vs NN) in https://github.com/Pak911/fastai-shared-notebooks/blob/master/interpret_tabular.ipynb" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 57, 443 | "metadata": { 444 | "ExecuteTime": { 445 | "end_time": "2019-08-08T22:57:43.262470Z", 446 | "start_time": "2019-08-08T22:57:43.183749Z" 447 | } 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "ln = len(df)\n", 452 | "train_idx = list_diff(list1=range(ln), list2=valid_idx)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 58, 458 | "metadata": { 459 | "ExecuteTime": { 460 | "end_time": "2019-08-08T22:57:44.027688Z", 461 | "start_time": "2019-08-08T22:57:43.924000Z" 462 | } 463 | }, 464 | "outputs": [], 465 | "source": [ 466 | "tr_df = df.iloc[train_idx]\n", 467 | "val_df = df.iloc[valid_idx]" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 59, 473 | "metadata": { 474 | "ExecuteTime": { 475 | "end_time": "2019-08-08T22:57:58.953878Z", 476 | "start_time": "2019-08-08T22:57:57.487145Z" 477 | } 478 | }, 479 | "outputs": [], 480 | "source": [ 481 | "tr_data_inner = to_np(get_inner_repr(df=tr_df[all_vars], learn=learn))" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 60, 487 | "metadata": { 488 | "ExecuteTime": { 489 | "end_time": "2019-08-08T22:57:59.118763Z", 490 | "start_time": "2019-08-08T22:57:58.956303Z" 491 | } 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "val_data_inner = to_np(get_inner_repr(df=val_df[all_vars], learn=learn))" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 61, 501 | "metadata": { 502 | "ExecuteTime": { 503 | "end_time": "2019-08-08T22:57:59.886874Z", 504 | "start_time": "2019-08-08T22:57:59.845032Z" 505 | } 506 | }, 507 | "outputs": [ 508 | { 509 | "data": { 510 | "text/plain": [ 511 | "(240858, 12443)" 512 | ] 513 | }, 514 | "execution_count": 61, 515 | "metadata": {}, 516 | "output_type": "execute_result" 517 | } 518 | ], 519 | "source": [ 520 | "len(tr_data_inner), len(val_data_inner)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 62, 526 | "metadata": { 527 | "ExecuteTime": { 528 | "end_time": 
"2019-08-08T22:58:04.878227Z", 529 | "start_time": "2019-08-08T22:58:04.839808Z" 530 | } 531 | }, 532 | "outputs": [ 533 | { 534 | "data": { 535 | "text/plain": [ 536 | "array([-0.041985, -0.018466, -0.071389, -0.03445 , ..., -0.12327 , 0.193699, 0.361726, -0.061271], dtype=float32)" 537 | ] 538 | }, 539 | "execution_count": 62, 540 | "metadata": {}, 541 | "output_type": "execute_result" 542 | } 543 | ], 544 | "source": [ 545 | "tr_data_inner[0]" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 63, 551 | "metadata": { 552 | "ExecuteTime": { 553 | "end_time": "2019-08-08T22:58:05.967112Z", 554 | "start_time": "2019-08-08T22:58:05.937636Z" 555 | } 556 | }, 557 | "outputs": [ 558 | { 559 | "data": { 560 | "text/plain": [ 561 | "(233, 233)" 562 | ] 563 | }, 564 | "execution_count": 63, 565 | "metadata": {}, 566 | "output_type": "execute_result" 567 | } 568 | ], 569 | "source": [ 570 | "len(tr_data_inner[0]), len(val_data_inner[0])" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 64, 576 | "metadata": { 577 | "ExecuteTime": { 578 | "end_time": "2019-08-08T22:58:07.301877Z", 579 | "start_time": "2019-08-08T22:58:07.273359Z" 580 | } 581 | }, 582 | "outputs": [], 583 | "source": [ 584 | "tr_inner_df = pd.DataFrame(tr_data_inner)\n", 585 | "val_inner_df = pd.DataFrame(val_data_inner)" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 65, 591 | "metadata": { 592 | "ExecuteTime": { 593 | "end_time": "2019-08-08T22:58:08.046808Z", 594 | "start_time": "2019-08-08T22:58:07.964978Z" 595 | } 596 | }, 597 | "outputs": [], 598 | "source": [ 599 | "tr_inner_df[dep_var] = tr_df.reset_index()[dep_var]\n", 600 | "val_inner_df[dep_var] = val_df.reset_index()[dep_var]" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 66, 606 | "metadata": { 607 | "ExecuteTime": { 608 | "end_time": "2019-08-08T22:58:09.359024Z", 609 | "start_time": "2019-08-08T22:58:09.223079Z" 610 | } 611 | }, 612 | "outputs": [], 613 | "source": [ 614 | "merge_inner_df = pd.concat([tr_inner_df, val_inner_df])" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 67, 620 | "metadata": { 621 | "ExecuteTime": { 622 | "end_time": "2019-08-08T22:58:09.933757Z", 623 | "start_time": "2019-08-08T22:58:09.905949Z" 624 | } 625 | }, 626 | "outputs": [], 627 | "source": [ 628 | "inner_val_idx = range(len(tr_inner_df), len(merge_inner_df))" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 68, 634 | "metadata": { 635 | "ExecuteTime": { 636 | "end_time": "2019-08-08T22:58:10.698564Z", 637 | "start_time": "2019-08-08T22:58:10.670942Z" 638 | } 639 | }, 640 | "outputs": [], 641 | "source": [ 642 | "inner_cont_vars = list_diff(merge_inner_df.columns, [dep_var])" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 69, 648 | "metadata": { 649 | "ExecuteTime": { 650 | "end_time": "2019-08-08T22:58:11.914443Z", 651 | "start_time": "2019-08-08T22:58:11.887842Z" 652 | } 653 | }, 654 | "outputs": [], 655 | "source": [ 656 | "inner_procs=[]" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 70, 662 | "metadata": { 663 | "ExecuteTime": { 664 | "end_time": "2019-08-08T22:58:14.550660Z", 665 | "start_time": "2019-08-08T22:58:12.453987Z" 666 | } 667 | }, 668 | "outputs": [], 669 | "source": [ 670 | "inner_data = (TabularList.from_df(merge_inner_df, path=path, cat_names=[], cont_names=inner_cont_vars, procs=inner_procs)\n", 671 | " .split_by_idx(inner_val_idx)\n", 
672 | " .label_from_df(cols=dep_var, label_cls=FloatList, log=True)\n", 673 | " .databunch(bs=BS))" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 71, 679 | "metadata": { 680 | "ExecuteTime": { 681 | "end_time": "2019-08-08T22:58:14.666037Z", 682 | "start_time": "2019-08-08T22:58:14.552706Z" 683 | } 684 | }, 685 | "outputs": [], 686 | "source": [ 687 | "np.random.seed(1001)\n", 688 | "inner_learn = tabular_learner(inner_data, layers=[1000,500], ps=[0.001,0.01], emb_drop=0.04, \n", 689 | " y_range=y_range, metrics=exp_rmspe)" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 72, 695 | "metadata": { 696 | "ExecuteTime": { 697 | "end_time": "2019-08-08T22:58:56.539386Z", 698 | "start_time": "2019-08-08T22:58:14.667901Z" 699 | } 700 | }, 701 | "outputs": [ 702 | { 703 | "data": { 704 | "text/html": [ 705 | "Total time: 00:41

\n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | "
epochtrain_lossvalid_lossexp_rmspetime
00.2245780.0236410.15864200:07
10.0198420.0316790.18690100:06
20.0148210.0217960.16209000:06
30.0114550.0138800.11380900:07
40.0093950.0124890.10861300:07
50.0076420.0119040.10722300:07
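For reference, the 'inner representation' trained on above can be sketched as follows. This is a minimal sketch, not the actual get_inner_repr from the linked notebook: it assumes the fastai v1 TabularModel attribute names (embeds, emb_drop, bn_cont) and a dataloader whose batches arrive as ((x_cat, x_cont), y).

```python
import torch

def inner_repr_sketch(model, dl):
    "Feed batches through the embedding layers only, then concat with cont values."
    model.eval()
    outs = []
    with torch.no_grad():
        for xb, _ in dl:
            x_cat, x_cont = xb                       # fastai v1 tabular batch layout
            # Look up each categorical column in its embedding matrix
            x = torch.cat([e(x_cat[:, i]) for i, e in enumerate(model.embeds)], 1)
            x = model.emb_drop(x)                    # identity in eval mode
            x_cont = model.bn_cont(x_cont)           # as in TabularModel.forward
            outs.append(torch.cat([x, x_cont], 1))
    return torch.cat(outs)
```

Called with learn.model and a non-shuffled dataloader over the same rows, this should produce the same kind of float matrix as seen above: 233 columns here, i.e. the total embedding width plus the number of continuous variables.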
\n" 756 | ], 757 | "text/plain": [ 758 | "" 759 | ] 760 | }, 761 | "metadata": {}, 762 | "output_type": "display_data" 763 | } 764 | ], 765 | "source": [ 766 | "inner_learn.fit_one_cycle(6, 1e-2, wd=0.2)" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": 73, 772 | "metadata": { 773 | "ExecuteTime": { 774 | "end_time": "2019-08-08T22:59:36.306310Z", 775 | "start_time": "2019-08-08T22:58:56.542125Z" 776 | } 777 | }, 778 | "outputs": [ 779 | { 780 | "data": { 781 | "text/html": [ 782 | "Total time: 00:39

\n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | "
epochtrain_lossvalid_lossexp_rmspetime
00.0072680.0120540.10692300:06
10.0074650.0123560.10767200:06
20.0072910.0118560.10659400:06
30.0071550.0121970.10702900:06
40.0069920.0120170.10682000:07
50.0069670.0121350.10694300:06
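Before the Mixup section below, it is worth stating what mixup (Zhang et al., 2018) actually computes: each training example is replaced by a convex combination of two examples and of their targets,

$$\tilde{x} = \lambda x_i + (1 - \lambda)\,x_j, \qquad \tilde{y} = \lambda y_i + (1 - \lambda)\,y_j, \qquad \lambda \sim \mathrm{Beta}(\alpha, \alpha).$$

With the $\alpha = 0.2$ used below, $\mathrm{Beta}(0.2, 0.2)$ is strongly U-shaped, so most mixed rows stay close to one of the two originals.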
\n" 833 | ], 834 | "text/plain": [ 835 | "" 836 | ] 837 | }, 838 | "metadata": {}, 839 | "output_type": "display_data" 840 | } 841 | ], 842 | "source": [ 843 | "inner_learn.fit_one_cycle(6, 1e-4, wd=0.2)" 844 | ] 845 | }, 846 | { 847 | "cell_type": "markdown", 848 | "metadata": {}, 849 | "source": [ 850 | "So now we see that this method produce pretty the same result.\n", 851 | "\n", 852 | "Let's try to incorporate mixup here" 853 | ] 854 | }, 855 | { 856 | "cell_type": "markdown", 857 | "metadata": {}, 858 | "source": [ 859 | "### Mixup" 860 | ] 861 | }, 862 | { 863 | "cell_type": "markdown", 864 | "metadata": {}, 865 | "source": [ 866 | "And here we just add some interpolation in data (and depended valiable as we have regression here)" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": 74, 872 | "metadata": { 873 | "ExecuteTime": { 874 | "end_time": "2019-08-08T22:59:36.340253Z", 875 | "start_time": "2019-08-08T22:59:36.309182Z" 876 | } 877 | }, 878 | "outputs": [], 879 | "source": [ 880 | "alpha = 0.2" 881 | ] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "execution_count": 75, 886 | "metadata": { 887 | "ExecuteTime": { 888 | "end_time": "2019-08-08T22:59:36.373120Z", 889 | "start_time": "2019-08-08T22:59:36.343066Z" 890 | } 891 | }, 892 | "outputs": [], 893 | "source": [ 894 | "def interp(var1, var2, alpha):\n", 895 | " lam = np.random.beta(alpha, alpha)\n", 896 | " return lam*var1 + (1.-lam)*var2" 897 | ] 898 | }, 899 | { 900 | "cell_type": "code", 901 | "execution_count": 76, 902 | "metadata": { 903 | "ExecuteTime": { 904 | "end_time": "2019-08-08T22:59:36.400959Z", 905 | "start_time": "2019-08-08T22:59:36.374779Z" 906 | } 907 | }, 908 | "outputs": [ 909 | { 910 | "data": { 911 | "text/plain": [ 912 | "range(240858, 253301)" 913 | ] 914 | }, 915 | "execution_count": 76, 916 | "metadata": {}, 917 | "output_type": "execute_result" 918 | } 919 | ], 920 | "source": [ 921 | "inner_val_idx" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": 77, 927 | "metadata": { 928 | "ExecuteTime": { 929 | "end_time": "2019-08-08T22:59:36.488045Z", 930 | "start_time": "2019-08-08T22:59:36.402398Z" 931 | } 932 | }, 933 | "outputs": [], 934 | "source": [ 935 | "inner_tr_idx = list_diff(list1=range(len(merge_inner_df)), list2=inner_val_idx)" 936 | ] 937 | }, 938 | { 939 | "cell_type": "code", 940 | "execution_count": 78, 941 | "metadata": { 942 | "ExecuteTime": { 943 | "end_time": "2019-08-08T22:59:37.252769Z", 944 | "start_time": "2019-08-08T22:59:36.489737Z" 945 | } 946 | }, 947 | "outputs": [], 948 | "source": [ 949 | "np_merge_df = merge_inner_df.iloc[inner_tr_idx].to_numpy()" 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "execution_count": 79, 955 | "metadata": { 956 | "ExecuteTime": { 957 | "end_time": "2019-08-08T22:59:37.279645Z", 958 | "start_time": "2019-08-08T22:59:37.254797Z" 959 | } 960 | }, 961 | "outputs": [], 962 | "source": [ 963 | "augmented = []" 964 | ] 965 | }, 966 | { 967 | "cell_type": "code", 968 | "execution_count": 80, 969 | "metadata": { 970 | "ExecuteTime": { 971 | "end_time": "2019-08-08T22:59:43.325030Z", 972 | "start_time": "2019-08-08T22:59:37.281214Z" 973 | } 974 | }, 975 | "outputs": [], 976 | "source": [ 977 | "for _ in range(6):\n", 978 | " shfld = np_merge_df.copy()\n", 979 | " np.random.shuffle(shfld)\n", 980 | " augmented.append(pd.DataFrame(interp(shfld, np_merge_df, alpha)))" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": 83, 986 | "metadata": { 987 | "ExecuteTime": { 
988 | "end_time": "2019-08-08T23:00:40.489800Z", 989 | "start_time": "2019-08-08T23:00:40.460802Z" 990 | } 991 | }, 992 | "outputs": [], 993 | "source": [ 994 | "del np_merge_df; del shfld" 995 | ] 996 | }, 997 | { 998 | "cell_type": "code", 999 | "execution_count": 84, 1000 | "metadata": { 1001 | "ExecuteTime": { 1002 | "end_time": "2019-08-08T23:00:48.173024Z", 1003 | "start_time": "2019-08-08T23:00:47.478772Z" 1004 | } 1005 | }, 1006 | "outputs": [], 1007 | "source": [ 1008 | "augmented = pd.concat(augmented)" 1009 | ] 1010 | }, 1011 | { 1012 | "cell_type": "code", 1013 | "execution_count": 85, 1014 | "metadata": { 1015 | "ExecuteTime": { 1016 | "end_time": "2019-08-08T23:00:54.078941Z", 1017 | "start_time": "2019-08-08T23:00:49.094286Z" 1018 | } 1019 | }, 1020 | "outputs": [], 1021 | "source": [ 1022 | "augmented.rename(columns={augmented.columns[-1]:dep_var}, inplace = True)" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "code", 1027 | "execution_count": 86, 1028 | "metadata": { 1029 | "ExecuteTime": { 1030 | "end_time": "2019-08-08T23:00:54.109139Z", 1031 | "start_time": "2019-08-08T23:00:54.080985Z" 1032 | } 1033 | }, 1034 | "outputs": [ 1035 | { 1036 | "data": { 1037 | "text/plain": [ 1038 | "1445148" 1039 | ] 1040 | }, 1041 | "execution_count": 86, 1042 | "metadata": {}, 1043 | "output_type": "execute_result" 1044 | } 1045 | ], 1046 | "source": [ 1047 | "len(augmented)" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "code", 1052 | "execution_count": 87, 1053 | "metadata": { 1054 | "ExecuteTime": { 1055 | "end_time": "2019-08-08T23:00:56.035243Z", 1056 | "start_time": "2019-08-08T23:00:55.310475Z" 1057 | } 1058 | }, 1059 | "outputs": [], 1060 | "source": [ 1061 | "merge_inner_df = pd.concat([augmented, val_inner_df])" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "code", 1066 | "execution_count": 88, 1067 | "metadata": { 1068 | "ExecuteTime": { 1069 | "end_time": "2019-08-08T23:00:57.658692Z", 1070 | "start_time": "2019-08-08T23:00:57.629383Z" 1071 | } 1072 | }, 1073 | "outputs": [], 1074 | "source": [ 1075 | "inner_val_idx = range(len(augmented), len(merge_inner_df))" 1076 | ] 1077 | }, 1078 | { 1079 | "cell_type": "code", 1080 | "execution_count": 89, 1081 | "metadata": { 1082 | "ExecuteTime": { 1083 | "end_time": "2019-08-08T23:01:00.166661Z", 1084 | "start_time": "2019-08-08T23:01:00.052413Z" 1085 | } 1086 | }, 1087 | "outputs": [], 1088 | "source": [ 1089 | "del augmented" 1090 | ] 1091 | }, 1092 | { 1093 | "cell_type": "code", 1094 | "execution_count": 90, 1095 | "metadata": { 1096 | "ExecuteTime": { 1097 | "end_time": "2019-08-08T23:01:03.510495Z", 1098 | "start_time": "2019-08-08T23:01:03.478078Z" 1099 | } 1100 | }, 1101 | "outputs": [], 1102 | "source": [ 1103 | "inner_cont_vars = list_diff(merge_inner_df.columns, [dep_var])" 1104 | ] 1105 | }, 1106 | { 1107 | "cell_type": "code", 1108 | "execution_count": 91, 1109 | "metadata": { 1110 | "ExecuteTime": { 1111 | "end_time": "2019-08-08T23:01:26.734526Z", 1112 | "start_time": "2019-08-08T23:01:06.150737Z" 1113 | } 1114 | }, 1115 | "outputs": [], 1116 | "source": [ 1117 | "inner_data = (TabularList.from_df(merge_inner_df, path=path, cat_names=[], cont_names=inner_cont_vars, procs=inner_procs)\n", 1118 | " .split_by_idx(inner_val_idx)\n", 1119 | " .label_from_df(cols=dep_var, label_cls=FloatList, log=True)\n", 1120 | " .databunch(bs=BS))" 1121 | ] 1122 | }, 1123 | { 1124 | "cell_type": "code", 1125 | "execution_count": 92, 1126 | "metadata": { 1127 | "ExecuteTime": { 1128 | "end_time": "2019-08-08T23:01:32.282150Z", 1129 
| "start_time": "2019-08-08T23:01:31.918102Z" 1130 | } 1131 | }, 1132 | "outputs": [], 1133 | "source": [ 1134 | "np.random.seed(1001)\n", 1135 | "inner_learn = tabular_learner(inner_data, layers=[1000,500], ps=[0.001,0.01], emb_drop=0.04, \n", 1136 | " y_range=y_range, metrics=exp_rmspe)" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "code", 1141 | "execution_count": 93, 1142 | "metadata": { 1143 | "ExecuteTime": { 1144 | "end_time": "2019-08-08T23:05:08.028289Z", 1145 | "start_time": "2019-08-08T23:01:40.010934Z" 1146 | } 1147 | }, 1148 | "outputs": [ 1149 | { 1150 | "data": { 1151 | "text/html": [ 1152 | "Total time: 03:27

\n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | "
epochtrain_lossvalid_lossexp_rmspetime
00.0138960.0179680.14116900:34
10.0122990.0159370.12086200:35
20.0107560.0158440.12760100:35
30.0095290.0137520.11490300:34
40.0068790.0116830.10953000:34
50.0055380.0116510.10722100:34
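Note that interp draws a single lambda from Beta(alpha, alpha) and applies it to an entire shuffled copy of the training set, so all rows within one of the six copies share the same mixing coefficient, whereas the original mixup recipe draws a fresh lambda per example (or per batch). A per-row variant is a small change; interp_rowwise below is a hypothetical helper, not taken from the notebook:

```python
import numpy as np

def interp_rowwise(a, b, alpha):
    "Mix each row of `a` with the matching row of `b` under its own lambda."
    lam = np.random.beta(alpha, alpha, size=(len(a), 1))  # one lambda per row
    return lam * a + (1. - lam) * b

# Drop-in replacement inside the augmentation loop above:
# augmented.append(pd.DataFrame(interp_rowwise(shfld, np_merge_df, alpha)))
```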
\n" 1203 | ], 1204 | "text/plain": [ 1205 | "" 1206 | ] 1207 | }, 1208 | "metadata": {}, 1209 | "output_type": "display_data" 1210 | } 1211 | ], 1212 | "source": [ 1213 | "inner_learn.fit_one_cycle(6, 1e-2, wd=0.2)" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "code", 1218 | "execution_count": 94, 1219 | "metadata": { 1220 | "ExecuteTime": { 1221 | "end_time": "2019-08-08T23:08:40.335039Z", 1222 | "start_time": "2019-08-08T23:05:08.030842Z" 1223 | } 1224 | }, 1225 | "outputs": [ 1226 | { 1227 | "data": { 1228 | "text/html": [ 1229 | "Total time: 03:32

\n", 1230 | " \n", 1231 | " \n", 1232 | " \n", 1233 | " \n", 1234 | " \n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " \n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | "
epochtrain_lossvalid_lossexp_rmspetime
00.0056240.0117750.10827600:33
10.0055470.0118590.10850500:38
20.0053250.0120780.10892400:34
30.0050320.0121310.11033100:34
40.0050080.0122840.11062300:35
50.0049990.0122650.11036100:34
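One possible confounder before reading too much into this result: interp mixes whole rows of np_merge_df, dep_var column included, so the targets are interpolated on the raw scale, while the labels the loss actually sees are logged (label_from_df(..., log=True)); the mixed labels are therefore not convex combinations in log space. Mixing the target in log space instead could be sketched like this, with mix_log_target as a hypothetical helper:

```python
import numpy as np

def mix_log_target(y_a, y_b, lam):
    "Raw-scale target whose log is the convex combination of the log targets."
    return np.exp(lam * np.log(y_a) + (1. - lam) * np.log(y_b))
```

Keeping log=True downstream, the loss would then see exactly lam*log(y_i) + (1-lam)*log(y_j). Whether this changes the outcome is untested here.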
\n" 1280 | ], 1281 | "text/plain": [ 1282 | "" 1283 | ] 1284 | }, 1285 | "metadata": {}, 1286 | "output_type": "display_data" 1287 | } 1288 | ], 1289 | "source": [ 1290 | "inner_learn.fit_one_cycle(6, 1e-4, wd=0.2)" 1291 | ] 1292 | }, 1293 | { 1294 | "cell_type": "markdown", 1295 | "metadata": {}, 1296 | "source": [ 1297 | "So here we don't see any improvments in terms of validation error :(" 1298 | ] 1299 | } 1300 | ], 1301 | "metadata": { 1302 | "kernelspec": { 1303 | "display_name": "Python [conda env:fastai] *", 1304 | "language": "python", 1305 | "name": "conda-env-fastai-py" 1306 | }, 1307 | "language_info": { 1308 | "codemirror_mode": { 1309 | "name": "ipython", 1310 | "version": 3 1311 | }, 1312 | "file_extension": ".py", 1313 | "mimetype": "text/x-python", 1314 | "name": "python", 1315 | "nbconvert_exporter": "python", 1316 | "pygments_lexer": "ipython3", 1317 | "version": "3.6.8" 1318 | } 1319 | }, 1320 | "nbformat": 4, 1321 | "nbformat_minor": 2 1322 | } 1323 | --------------------------------------------------------------------------------