├── README.md ├── docs └── present.pdf ├── ensemble.py ├── include.py ├── leaks ├── exp_to_group.csv └── plate_groups.csv ├── loss ├── __init__.py ├── cyclic_lr.py ├── loss.py └── rate.py ├── main.py ├── net ├── __init__.py ├── archead.py ├── imagenet_pretrain_model │ ├── __init__.py │ ├── include.py │ └── xception.py ├── model_xception_6channel.py └── model_xception_large_6channel.py ├── process ├── __init__.py ├── augmentation.py ├── data.py ├── folds │ ├── train_fold_0.csv │ ├── train_fold_1.csv │ ├── train_fold_2.csv │ ├── train_fold_3.csv │ ├── train_fold_4.csv │ ├── valid_fold_0.csv │ ├── valid_fold_1.csv │ ├── valid_fold_2.csv │ ├── valid_fold_3.csv │ └── valid_fold_4.csv ├── mean_std_batch.csv ├── mean_std_plate.csv ├── sample_submission.csv └── utils.py ├── settings.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ## DeepNoise: Signal and Noise Disentanglement based on Classifying Fluorescent Microscopy Images via Deep Learningg(Genomics, Proteomics and Bioinformatics) 2 | 3 | [ArXiv](https://arxiv.org/abs/2209.05772) | [Cite](#reference) 4 | 5 | ### kaggle NIPS2019 Recursion Cellular Image Classification Challenge 2nd place [code](https://www.kaggle.com/c/recursion-cellular-image-classification) 6 | 7 | #### Dependencies 8 | - imgaug == 0.2.8 9 | - opencv-python==3.4.2 10 | - scikit-image==0.14.0 11 | - scikit-learn==0.19.1 12 | - scipy==1.1.0 13 | - torch==1.0.1. 14 | - torchvision==0.2.2 15 | 16 | ## Solution Development 17 | 18 | #### Input and preprocessing: 19 | - 6 channel input, image size 512*512 20 | - per image standardization: normalization of 6 channels in images per plate, with small randomization 21 | - augmentation: random flip, random rotation multiple of 90 degrees. 22 | 23 | #### Backbone CNN: 24 | - Xception and Xception_large; 25 | - Xception is trained from imagenet pretrained weights. 26 | - Xception_large is trained from srcatch which need more trainning epochs. 27 | 28 | #### Loss fuction: 29 | - Arcface Loss s=30, m=0.1: 30 | 31 | 32 | #### Path Setup 33 | Set the following path to your own in ./setting.py 34 | ``` 35 | data_dir = r'your-own-path/recursion-cellular-image-classification'#data path 36 | ``` 37 | 38 | #### Single Model Training 39 | pretrain xception_large 512x512 fold 0 with all cell types: 40 | ``` 41 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --mode=pretrain --model=xception_large --image_size=512 --fold_index=0 --batch_size=64 42 | ``` 43 | 44 | finetune xception_large 512x512 fold 0 on each cell type: 45 | ``` 46 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --mode=semi_finetune --model=xception_large --image_size=512 --fold_index=0 --batch_size=64 47 | ``` 48 | 49 | predict xception_large 512x512 fold 0 model: 50 | ``` 51 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --mode=infer --model=xception_large --image_size=512 --fold_index=0 --batch_size=64 52 | ``` 53 | 54 | #### generate prediction 55 | 56 | ``` 57 | python ensemble.py 58 | ``` 59 | 60 | ## Reference 61 | If you find our work useful in your research or if you use parts of this code please consider citing our paper: 62 | 63 | ```@article{WANG2021102785, 64 | title = {DeepNoise: Signal and Noise Disentanglement based on Classifying Fluorescent Microscopy Images via Deep Learning}, 65 | author={Sen,Yang and Tao,Shen and Yuqi,Fang and Xiyue,Wang and Jun,Zhang and Wei,Yang and Junzhou,Huang and Xiao,Han}, 66 | journal = {Genomics, Proteomics and Bioinformatics}, 67 | year = {2023} 68 | ``` 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /docs/present.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Scu-sen/Recursion-Cellular-Image-Classification-Challenge/a45e48905b673ea89763f0648623c89fbddb85f1/docs/present.pdf -------------------------------------------------------------------------------- /ensemble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import numpy as np 4 | import pandas as pd 5 | from collections import defaultdict, Counter 6 | SIGFIGS = 6 7 | import pandas as pd 8 | from settings import * 9 | from utils import * 10 | import random 11 | # ### Hungarian 12 | import scipy 13 | import scipy.special 14 | import scipy.optimize 15 | 16 | ############################################################# 17 | plate = pd.read_csv('./leaks/plate_groups.csv') 18 | assign = pd.read_csv('./leaks/exp_to_group.csv') 19 | plate_dict = {} 20 | plate_dict[0] = plate['p0'] 21 | plate_dict[1] = plate['p1'] 22 | plate_dict[2] = plate['p2'] 23 | plate_dict[3] = plate['p3'] 24 | exp_to_group = assign.assignment.values 25 | all_test_exp = assign.experiment.values 26 | a_dict = {} 27 | 28 | for a_tmp, test_exp in zip(exp_to_group, all_test_exp): 29 | a_dict[test_exp] = a_tmp 30 | ############################################################# 31 | 32 | def _get_predicts(predicts, coefficients): 33 | return torch.einsum("ij,j->ij", (predicts, coefficients)) 34 | 35 | def _get_labels_distribution(predicts, coefficients): 36 | predicts = _get_predicts(predicts, coefficients) 37 | labels = predicts.argmax(dim=-1) 38 | counter = torch.bincount(labels, minlength=predicts.shape[1]) 39 | return counter 40 | 41 | def _compute_score_with_coefficients(predicts, coefficients): 42 | counter = _get_labels_distribution(predicts, coefficients).float() 43 | counter = counter * 100 / len(predicts) 44 | max_scores = torch.ones(len(coefficients)).cuda().float() * 100 / len(coefficients) 45 | result, _ = torch.min(torch.cat([counter.unsqueeze(0), max_scores.unsqueeze(0)], dim=0), dim=0) 46 | 47 | return float(result.sum().cpu()) 48 | 49 | def _find_best_coefficients(predicts, coefficients, alpha=0.001, iterations=100): 50 | best_coefficients = coefficients.clone() 51 | best_score = _compute_score_with_coefficients(predicts, coefficients) 52 | 53 | for _ in range(iterations): 54 | counter = _get_labels_distribution(predicts, coefficients) 55 | label = int(torch.argmax(counter).cpu()) 56 | coefficients[label] -= alpha 57 | score = _compute_score_with_coefficients(predicts, coefficients) 58 | 59 | if score > best_score: 60 | best_score = score 61 | best_coefficients = coefficients.clone() 62 | 63 | return best_coefficients 64 | 65 | def pre_balance(pre, iter = 3000, is_show = True): 66 | ############ input: pre like this [n_sample,n_class] cuda tensor 67 | ##########output the same as pre ,but it is numpt tensor 68 | y = pre 69 | start_alpha=0.01 70 | alpha = start_alpha 71 | min_alpha=0.0001 72 | 73 | coefs = torch.ones(y.shape[1]).cuda().float() 74 | last_score = _compute_score_with_coefficients(y, coefs) 75 | 76 | if is_show: 77 | print("Start score", last_score) 78 | 79 | while alpha >= min_alpha: 80 | coefs = _find_best_coefficients(y, coefs, iterations=iter, alpha=alpha) 81 | new_score = _compute_score_with_coefficients(y, coefs) 82 | 83 | if new_score <= last_score: 84 | alpha *= 0.5 85 | 86 | last_score = new_score 87 | 88 | if last_score == 100: 89 | break 90 | 91 | if is_show: 92 | print("Score: {}, alpha: {}".format(last_score, alpha)) 93 | 94 | predicts = _get_predicts(y, coefs) 95 | return predicts.cpu().numpy() 96 | 97 | def from_numpy(predict): 98 | 99 | y = torch.FloatTensor(predict).cuda() 100 | start_alpha=0.01 101 | alpha = start_alpha 102 | min_alpha=0.0001 103 | 104 | coefs = torch.ones(y.shape[1]).cuda().float() 105 | last_score = _compute_score_with_coefficients(y, coefs) 106 | print("Start score", last_score) 107 | 108 | while alpha >= min_alpha: 109 | coefs = _find_best_coefficients(y, coefs, iterations=3000, alpha=alpha) 110 | new_score = _compute_score_with_coefficients(y, coefs) 111 | 112 | if new_score <= last_score: 113 | alpha *= 0.5 114 | 115 | last_score = new_score 116 | print("Score: {}, alpha: {}".format(last_score, alpha)) 117 | 118 | predicts = _get_predicts(y, coefs) 119 | return predicts.cpu().numpy() 120 | 121 | def pre_balance_batch_probability(preds1,ids): 122 | ########### input : preds1 is like thie [n_sample,n_class],ids is predictived ids 123 | ########### output :pre is prediction of class,ids is predictived id 124 | pixel_mean={'id_code':ids} 125 | submission = pd.DataFrame(pixel_mean,columns=['id_code']) 126 | submission ['experiment'] = submission.id_code.str.split("_").apply(lambda a: a[0]) 127 | exper=submission['experiment'] 128 | ee=exper.drop_duplicates() 129 | pre_id=[] 130 | pre=[] 131 | 132 | for index1 in tqdm.tqdm(ee): 133 | 134 | tmp=submission[submission['experiment']==index1].index.tolist() 135 | st=tmp[0] 136 | end=tmp[-1] 137 | 138 | tmp_pre=preds1[st:end+1] 139 | id_tm=ids[st:end+1] 140 | 141 | tmp_pre_torch= torch.from_numpy(np.array(tmp_pre)).cuda().float() 142 | tmp=pre_balance(tmp_pre_torch) 143 | 144 | pre_id.extend(id_tm) 145 | pre.extend(tmp.tolist()) 146 | 147 | return pre,pre_id 148 | 149 | def pre_balance_plate_probability(preds1, ids, a_dict, a_index_dict, only_public = True, is_softmax = True): 150 | ########### input : preds1 is like thie [n_sample,n_class],ids is predictived ids 151 | ########### output :pre is prediction of class,ids is predictived id 152 | 153 | pixel_mean={'id_code': ids} 154 | submission = pd.DataFrame(pixel_mean, columns=['id_code']) 155 | submission ['experiment'] = submission.id_code.str.split("_").apply(lambda a: a[0]+'_'+a[1]) 156 | exper=submission['experiment'] 157 | ee=exper.drop_duplicates() 158 | pre_id=[] 159 | pre=[] 160 | 161 | public_set = {'HUVEC-17','HEPG2-08','RPE-08','U2OS-04'} 162 | 163 | for index1 in tqdm.tqdm(ee): 164 | a_index = a_index_dict[index1.split('_')[0]] 165 | assignment = a_dict[int(a_index)] 166 | print('assignment',a_index) 167 | 168 | plate = int(index1.split('_')[1]) 169 | print(index1) 170 | print(plate) 171 | plate_label_index = list(assignment == plate) 172 | 173 | tmp=submission[submission['experiment']==index1].index.tolist() 174 | st=tmp[0] 175 | end=tmp[-1] 176 | pre_tmp=np.array(preds1[st:end+1]) 177 | id_tm=ids[st:end+1] 178 | pre_tmp = np.array(pre_tmp[:,plate_label_index]) 179 | 180 | if is_softmax: 181 | pre_tmp = softmax_np(pre_tmp) 182 | 183 | tmp_pre_torch= torch.from_numpy(pre_tmp).cuda().float() 184 | 185 | if only_public: 186 | if index1.split('_')[0] in public_set : 187 | print('IN THE PUBLIC!!!!!!!!!!!') 188 | tmp=pre_balance(tmp_pre_torch, iter = 3000) 189 | else: 190 | tmp=pre_balance(tmp_pre_torch, iter = 3000, is_show = False) 191 | else: 192 | tmp = pre_balance(tmp_pre_torch, iter=3000) 193 | 194 | pre_id.extend(id_tm) 195 | final_predict = tmp.reshape([len(id_tm), 277]) 196 | predict_tmp = np.zeros([len(id_tm), 1108]) 197 | predict_tmp[:,plate_label_index] = final_predict 198 | pre.extend(predict_tmp.tolist()) 199 | 200 | return pre, pre_id 201 | 202 | def softmax_np(x): 203 | re = np.exp(x) / np.sum(np.exp(x), axis=0) 204 | return re 205 | 206 | def balance_plate_probability_training(preds1, ids, a_dict, a_index_dict, iters = 3000, is_show = True): 207 | ########### input : preds1 is like thie [n_sample,n_class],ids is predictived ids 208 | ########### output :pre is prediction of class,ids is predictived id 209 | 210 | pixel_mean={'id_code': ids} 211 | submission = pd.DataFrame(pixel_mean, columns=['id_code']) 212 | submission ['experiment'] = submission.id_code.str.split("_").apply(lambda a: a[0]+'_'+a[1]) 213 | exper=submission['experiment'] 214 | ee=exper.drop_duplicates() 215 | pre_id=[] 216 | pre=[] 217 | 218 | for index1 in tqdm.tqdm(ee): 219 | # print(index1) 220 | a_index = a_index_dict[index1.split('_')[0]] 221 | assignment = a_dict[int(a_index)] 222 | # print('assignment',a_index) 223 | plate = int(index1.split('_')[1]) 224 | # print(plate) 225 | plate_label_index = list(assignment == plate) 226 | 227 | tmp=submission[submission['experiment']==index1].index.tolist() 228 | st=tmp[0] 229 | end=tmp[-1] 230 | pre_tmp=np.array(preds1[st:end+1]) 231 | id_tm=ids[st:end+1] 232 | pre_tmp = np.array(pre_tmp[:,plate_label_index]) 233 | pre_id.extend(id_tm) 234 | 235 | pre_tmp = softmax_np(pre_tmp) 236 | 237 | if iters > 0: 238 | tmp_pre_torch= torch.from_numpy(pre_tmp).cuda().float() 239 | tmp=pre_balance(tmp_pre_torch, iter = iters, is_show = is_show) 240 | else: 241 | tmp = pre_tmp 242 | 243 | final_predict = tmp.reshape([len(id_tm), 277]) 244 | predict_tmp = np.zeros([len(id_tm), 1108]) 245 | predict_tmp[:,plate_label_index] = final_predict 246 | pre.extend(predict_tmp.tolist()) 247 | 248 | return pre, pre_id 249 | 250 | def split_277_acc(preds1, ids, val_label, a_dict, a_index_dict, is_softmax = True, remove_list = {}): 251 | ########### input : preds1 is like thie [n_sample,n_class],ids is predictived ids 252 | ########### output :pre is prediction of class,ids is predictived id 253 | 254 | pixel_mean={'id_code': ids} 255 | submission = pd.DataFrame(pixel_mean, columns=['id_code']) 256 | submission ['experiment'] = submission.id_code.str.split("_").apply(lambda a: a[0]+'_'+a[1]) 257 | exper=submission['experiment'] 258 | ee=exper.drop_duplicates() 259 | pre_id=[] 260 | pre=[] 261 | label = [] 262 | 263 | # public_set = {'HUVEC-17','HEPG2-08','RPE-08','U2OS-04'} 264 | 265 | for index1 in ee: 266 | a_index = a_index_dict[index1.split('_')[0]] 267 | assignment = a_dict[int(a_index)] 268 | 269 | # print(index1.split('_')[0]) 270 | plate = int(index1.split('_')[1]) 271 | plate_label_index = list(assignment == plate) 272 | 273 | tmp=submission[submission['experiment']==index1].index.tolist() 274 | st=tmp[0] 275 | end=tmp[-1] 276 | pre_tmp=np.array(preds1[st:end+1]) 277 | id_tm=ids[st:end+1] 278 | lb_tm=val_label[st:end+1] 279 | 280 | pre_tmp = np.array(pre_tmp[:,plate_label_index]) 281 | 282 | if is_softmax: 283 | pre_tmp = softmax_np(pre_tmp) 284 | 285 | tmp = np.asarray(pre_tmp) 286 | final_predict = tmp.reshape([len(id_tm), 277]) 287 | predict_tmp = np.zeros([len(id_tm), 1108]) 288 | 289 | if index1.split('_')[0] not in remove_list: 290 | predict_tmp[:, plate_label_index] = final_predict 291 | # continue 292 | 293 | pre_id.extend(id_tm) 294 | pre.extend(predict_tmp.tolist()) 295 | label.extend(lb_tm) 296 | 297 | return np.asarray(pre), pre_id, label 298 | 299 | def assign_plate(plate): 300 | probabilities = np.array(plate) 301 | cost = probabilities * -1 302 | rows, cols = scipy.optimize.linear_sum_assignment(cost) 303 | chosen_elements = set(zip(rows.tolist(), cols.tolist())) 304 | 305 | for sample in range(cost.shape[0]): 306 | for sirna in range(cost.shape[1]): 307 | if (sample, sirna) not in chosen_elements: 308 | probabilities[sample, sirna] = 0 309 | 310 | return probabilities.argmax(axis=1).tolist() 311 | 312 | def final_pre_balance_plate_probability(preds1, ids, a_dict, a_index_dict, is_softmax = True): 313 | ########### input : preds1 is like thie [n_sample,n_class],ids is predictived ids 314 | ########### output :pre is prediction of class,ids is predictived id 315 | pixel_mean={'id_code': ids} 316 | submission = pd.DataFrame(pixel_mean, columns=['id_code']) 317 | submission ['experiment'] = submission.id_code.str.split("_").apply(lambda a: a[0]+'_'+a[1]) 318 | exper=submission['experiment'] 319 | ee=exper.drop_duplicates() 320 | pre_id=[] 321 | pre=[] 322 | 323 | for index1 in ee: 324 | a_index = a_index_dict[index1.split('_')[0]] 325 | assignment = a_dict[int(a_index)] 326 | print('assignment',a_index) 327 | 328 | plate = int(index1.split('_')[1]) 329 | print(index1) 330 | print(plate) 331 | plate_label_index = list(assignment == plate) 332 | 333 | tmp=submission[submission['experiment']==index1].index.tolist() 334 | st=tmp[0] 335 | end=tmp[-1] 336 | pre_tmp=np.array(preds1[st:end+1]) 337 | id_tm=ids[st:end+1] 338 | pre_tmp = np.array(pre_tmp[:,plate_label_index]) 339 | 340 | if is_softmax: 341 | pre_tmp = softmax_np(pre_tmp) 342 | 343 | tmp_pre_torch= torch.from_numpy(pre_tmp).cuda().float() 344 | tmp = pre_balance(tmp_pre_torch, iter=3000, is_show = True) 345 | 346 | pre_id.extend(id_tm) 347 | final_predict = tmp.reshape([len(id_tm), 277]) 348 | predict_tmp = np.zeros([len(id_tm), 1108]) 349 | predict_tmp[:,plate_label_index] = final_predict 350 | 351 | predict_tmp = assign_plate(predict_tmp) 352 | pre.extend(predict_tmp) 353 | 354 | return pre, pre_id 355 | 356 | def ave_balance(): 357 | import os 358 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 359 | name = r'se__finetune_cell_fold0-4' 360 | avg = read_models(model_pred) 361 | csv_name = write_models(avg, name, is_top1=True) 362 | 363 | n_samples = len(avg) 364 | predict = np.zeros([n_samples, 1108]) 365 | ids = [] 366 | i = 0 367 | for id in avg: 368 | sum = 0.0 369 | for index in avg[id]: 370 | predict[i, index] = avg[id][index] 371 | sum += avg[id][index] 372 | 373 | predict[i] = predict[i] / sum 374 | ids.append(id) 375 | i += 1 376 | 377 | predict = predict.tolist() 378 | prob_to_csv_top5(predict, ids, name+'_top5.csv') 379 | predict, ids = pre_balance_batch_probability(predict, ids) 380 | prob_to_csv_top5(predict, ids, name+'_batch_balance_top5.csv') 381 | 382 | prob = np.asarray(predict) 383 | print(prob.shape) 384 | top = np.argsort(-prob,1)[:,:5] 385 | index = 0 386 | rs = [] 387 | 388 | for (t0,t1,t2,t3,t4) in top: 389 | top_k_label_name = r'' 390 | top_k_label_name += str(t0) 391 | index += 1 392 | rs.append(top_k_label_name) 393 | 394 | pd.DataFrame({'id_code':ids, 'sirna':rs}).to_csv( '{}.csv'.format(name+'_batch_balance'), index=None) 395 | 396 | def balance_ave(): 397 | import os 398 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 399 | 400 | for item in model_pred: 401 | weight = model_pred[item] 402 | tmp_dict = {} 403 | tmp_dict[item] = weight 404 | avg = read_models(tmp_dict) 405 | 406 | n_samples = len(avg) 407 | predict = np.zeros([n_samples, 1108]) 408 | ids = [] 409 | i = 0 410 | 411 | for id in avg: 412 | sum = 0.0 413 | for index in avg[id]: 414 | predict[i, index] = avg[id][index] 415 | sum += avg[id][index] 416 | 417 | predict[i] = predict[i] / sum 418 | ids.append(id) 419 | i += 1 420 | 421 | predict = predict.tolist() 422 | predict, ids = pre_balance_batch_probability(predict, ids) 423 | prob_to_csv_top5(predict, ids, './ave_predict/'+os.path.split(item)[1]) 424 | 425 | # prob = np.asarray(predict) 426 | # print(prob.shape) 427 | # top = np.argsort(-prob,1)[:,:5] 428 | # index = 0 429 | # rs = [] 430 | # 431 | # for (t0,t1,t2,t3,t4) in top: 432 | # top_k_label_name = r'' 433 | # top_k_label_name += str(t0) 434 | # index += 1 435 | # rs.append(top_k_label_name) 436 | # 437 | # pd.DataFrame({'id_code':ids, 'sirna':rs}).to_csv( '{}.csv'.format(name+'_batch_balance'), index=None) 438 | 439 | def ave_plate_balance(): 440 | import os 441 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 442 | 443 | name = r'tmp' 444 | avg = read_models(model_pred) 445 | 446 | n_samples = len(avg) 447 | predict = np.zeros([n_samples, 1108]) 448 | ids = [] 449 | i = 0 450 | for id in avg: 451 | sum = 0.0 452 | for index in avg[id]: 453 | predict[i, index] = avg[id][index] 454 | sum += avg[id][index] 455 | 456 | predict[i] = predict[i] / sum 457 | ids.append(id) 458 | i += 1 459 | 460 | predict = predict.tolist() 461 | predict, ids = pre_balance_plate_probability(predict, ids) 462 | 463 | prob = np.asarray(predict) 464 | print(prob.shape) 465 | top = np.argsort(-prob,1)[:,:5] 466 | index = 0 467 | rs = [] 468 | 469 | for (t0,t1,t2,t3,t4) in top: 470 | top_k_label_name = r'' 471 | top_k_label_name += str(t0) 472 | index += 1 473 | rs.append(top_k_label_name) 474 | 475 | pd.DataFrame({'id_code':ids, 'sirna':rs}).to_csv( '{}.csv'.format(name+'_batch_balance'), index=None) 476 | 477 | def filter_csv(csv, cell): 478 | df = pd.read_csv(csv) 479 | 480 | id_code = df['id_code'] 481 | siRNA = df['sirna'] 482 | 483 | siRNA_tmp = [] 484 | 485 | for id_tmp, rna_tmp in zip(id_code, siRNA): 486 | if cell not in id_tmp: 487 | print(id_tmp) 488 | print(rna_tmp) 489 | siRNA_tmp.append(1138) 490 | else: 491 | siRNA_tmp.append(rna_tmp) 492 | 493 | pd.DataFrame({'id_code': id_code, 'sirna': siRNA_tmp}).to_csv(csv.replace('.csv','_with_all.csv'),index=None) 494 | 495 | def replace(cell, origin, new, csv ): 496 | df = pd.read_csv(origin) 497 | id_code = df['id_code'] 498 | siRNA = df['sirna'] 499 | 500 | id_code_tmp = [] 501 | siRNA_tmp = [] 502 | 503 | for id_tmp, rna_tmp in zip(id_code, siRNA): 504 | if cell in id_tmp: 505 | continue 506 | else: 507 | siRNA_tmp.append(rna_tmp) 508 | id_code_tmp.append(id_tmp) 509 | 510 | df = pd.read_csv(new) 511 | id_code = df['id_code'] 512 | siRNA = df['sirna'] 513 | 514 | for id_tmp, rna_tmp in zip(id_code, siRNA): 515 | if cell in id_tmp: 516 | siRNA_tmp.append(rna_tmp) 517 | id_code_tmp.append(id_tmp) 518 | print(id_tmp) 519 | print(rna_tmp) 520 | 521 | pd.DataFrame({'id_code': id_code_tmp, 'sirna': siRNA_tmp}).to_csv(csv,index=None) 522 | return csv 523 | 524 | def get_valid_label(fold = 0): 525 | df = pd.read_csv(path_data + '/valid_fold_'+str(fold)+'.csv') 526 | cell = 'U2OS' 527 | if cell is not None: 528 | print('cell type: ' + cell) 529 | df_all = [] 530 | for i in range(100): 531 | df_ = pd.DataFrame(df.loc[df['experiment'] == (cell + '-' + str(100 + i)[-2:])]) 532 | df_all.append(df_) 533 | df = pd.concat(df_all) 534 | 535 | val_label = np.asarray(list(df[r'sirna'])).reshape([-1]) 536 | 537 | return val_label 538 | 539 | def get_test_id(cell): 540 | df = pd.read_csv(data_dir + '/test.csv') 541 | 542 | if cell is not None: 543 | df_all = [] 544 | for i in range(100): 545 | df_ = pd.DataFrame(df.loc[df['experiment'] == (cell + '-' + str(100 + i)[-2:])]) 546 | df_all.append(df_) 547 | df = pd.concat(df_all) 548 | 549 | return list(df[r'id_code']) 550 | 551 | def get_npy( dir, cell='HUVEC', mode='test.npy'): 552 | 553 | ids = get_test_id(cell=cell) 554 | print(len(ids)) 555 | 556 | all_list = [] 557 | dir_list = os.listdir(dir) 558 | 559 | for dir_tmp in dir_list: 560 | npy_tmp = os.path.join(dir, dir_tmp) 561 | if cell in npy_tmp and mode in npy_tmp: 562 | all_list.append(npy_tmp) 563 | 564 | tmp_list = [] 565 | for npy in all_list: 566 | tmp = np.load(npy) 567 | tmp_list.append(np.asarray(tmp)) 568 | 569 | tmp = np.mean(tmp_list, axis=0) 570 | return tmp 571 | 572 | if __name__ == '__main__': 573 | import os 574 | 575 | only_public = False 576 | 577 | for cell in ['U2OS', 'RPE', 'HEPG2', 'HUVEC']: 578 | ids = get_test_id(cell=cell) 579 | 580 | test = get_npy(dir=model_save_path+r'/xception_large_6channel_512_fold0_final/checkpoint/max_valid_model_' 581 | +cell+r'_semi_snapshot_all_npy', 582 | cell=cell, mode='test.npy') 583 | 584 | predict, ids = pre_balance_plate_probability(test, ids, plate_dict, a_dict, only_public=False) 585 | prob = np.asarray(predict) 586 | 587 | print(prob.shape) 588 | top = np.argsort(-prob, 1)[:, :5] 589 | index = 0 590 | rs = [] 591 | 592 | for (t0, t1, t2, t3, t4) in top: 593 | top_k_label_name = r'' 594 | top_k_label_name += str(t0) 595 | rs.append(top_k_label_name) 596 | 597 | name = cell 598 | pd.DataFrame({'id_code': ids, 'sirna': rs}).to_csv('{}.csv'.format(name), index=None) 599 | print(name) 600 | 601 | 602 | csv = replace(cell = 'RPE', 603 | origin = r'./process/sample_submission.csv', 604 | new = 'RPE.csv', 605 | csv = 'submission.csv') 606 | 607 | csv = replace(cell = 'HEPG2', 608 | origin = csv, 609 | new = 'HEPG2.csv', 610 | csv = csv) 611 | 612 | csv = replace(cell = 'HUVEC', 613 | origin = csv, 614 | new ='HUVEC.csv', 615 | csv = csv) 616 | 617 | csv = replace(cell = 'U2OS', 618 | origin = csv, 619 | new = 'U2OS.csv', 620 | csv = csv) 621 | 622 | os.remove('./U2OS.csv') 623 | os.remove('./RPE.csv') 624 | os.remove('./HEPG2.csv') 625 | os.remove('./HUVEC.csv') 626 | print(csv) -------------------------------------------------------------------------------- /include.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | PROJECT_PATH = os.path.dirname(os.path.realpath(__file__)) 4 | IDENTIFIER = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 5 | 6 | #numerical libs 7 | import math 8 | import numpy as np 9 | import random 10 | import PIL 11 | import cv2 12 | import matplotlib 13 | matplotlib.use('TkAgg') 14 | #matplotlib.use('WXAgg') 15 | #matplotlib.use('Qt4Agg') 16 | #matplotlib.use('Qt5Agg') #Qt4Agg 17 | print(matplotlib.get_backend()) 18 | #print(matplotlib.__version__) 19 | 20 | # torch libs 21 | import torch 22 | from torch.utils.data.dataset import Dataset 23 | from torch.utils.data import DataLoader 24 | from torch.utils.data.sampler import * 25 | 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | import torch.optim as optim 29 | from torch.autograd import Variable 30 | from torch.nn.parallel.data_parallel import data_parallel 31 | 32 | # std libs 33 | import collections 34 | import copy 35 | import numbers 36 | import inspect 37 | import shutil 38 | from timeit import default_timer as timer 39 | import itertools 40 | from collections import OrderedDict 41 | 42 | import csv 43 | import pandas as pd 44 | import pickle 45 | import glob 46 | import sys 47 | from distutils.dir_util import copy_tree 48 | import time 49 | 50 | import matplotlib.pyplot as plt 51 | from mpl_toolkits.mplot3d import Axes3D 52 | 53 | from skimage.transform import resize as skimage_resize 54 | 55 | # constant # 56 | PI = np.pi 57 | INF = np.inf 58 | EPS = 1e-12 59 | 60 | #--------------------------------------------------------------------------------- 61 | class Struct(object): 62 | def __init__(self, **kwargs): 63 | self.__dict__.update(kwargs) 64 | 65 | #--------------------------------------------------------------------------------- 66 | def remove_comments(lines, token='#'): 67 | """ Generator. Strips comments and whitespace from input lines. 68 | """ 69 | 70 | l = [] 71 | for line in lines: 72 | s = line.split(token, 1)[0].strip() 73 | if s != '': 74 | l.append(s) 75 | return l 76 | 77 | def remove(file): 78 | if os.path.exists(file): os.remove(file) 79 | 80 | 81 | def empty(dir): 82 | if os.path.isdir(dir): 83 | shutil.rmtree(dir, ignore_errors=True) 84 | else: 85 | os.makedirs(dir) 86 | 87 | 88 | # http://stackoverflow.com/questions/34950201/pycharm-print-end-r-statement-not-working 89 | class Logger(object): 90 | def __init__(self): 91 | self.terminal = sys.stdout #stdout 92 | self.file = None 93 | 94 | def open(self, file, mode=None): 95 | if mode is None: mode ='w' 96 | self.file = open(file, mode) 97 | 98 | def write(self, message, is_terminal=1, is_file=1 ): 99 | if '\r' in message: is_file=0 100 | 101 | if is_terminal == 1: 102 | self.terminal.write(message) 103 | self.terminal.flush() 104 | #time.sleep(1) 105 | 106 | if is_file == 1: 107 | self.file.write(message) 108 | self.file.flush() 109 | 110 | def flush(self): 111 | # this flush method is needed for python 3 compatibility. 112 | # this handles the flush command by doing nothing. 113 | # you might want to specify some extra behavior here. 114 | pass 115 | 116 | # io ------------------------------------ 117 | def write_list_to_file(strings, list_file): 118 | with open(list_file, 'w') as f: 119 | for s in strings: 120 | f.write('%s\n'%str(s)) 121 | pass 122 | 123 | 124 | def read_list_from_file(list_file, comment='#', func=None): 125 | with open(list_file) as f: 126 | lines = f.readlines() 127 | 128 | strings=[] 129 | for line in lines: 130 | s = line.split(comment, 1)[0].strip() 131 | if s != '': 132 | strings.append(s) 133 | if func is not None: 134 | strings=[func(s) for s in strings] 135 | 136 | return strings 137 | 138 | def load_pickle_file(pickle_file): 139 | with open(pickle_file,'rb') as f: 140 | x = pickle.load(f) 141 | return x 142 | 143 | def save_pickle_file(pickle_file, x): 144 | with open(pickle_file, 'wb') as f: 145 | pickle.dump(x, f, pickle.HIGHEST_PROTOCOL) 146 | 147 | 148 | def backup_project_as_zip(project_dir, zip_file): 149 | assert(os.path.isdir(project_dir)) 150 | assert(os.path.isdir(os.path.dirname(zip_file))) 151 | shutil.make_archive(zip_file.replace('.zip',''), 'zip', project_dir) 152 | pass 153 | 154 | # etc ------------------------------------ 155 | def time_to_str(t, mode='min'): 156 | if mode=='min': 157 | t = int(t)/60 158 | hr = t//60 159 | min = t%60 160 | return '%2d hr %02d min'%(hr,min) 161 | elif mode=='sec': 162 | t = int(t) 163 | min = t//60 164 | sec = t%60 165 | return '%2d min %02d sec'%(min,sec) 166 | else: 167 | raise NotImplementedError 168 | 169 | def np_float32_to_uint8(x, scale=255): 170 | return (x*scale).astype(np.uint8) 171 | 172 | def np_uint8_to_float32(x, scale=255): 173 | return (x/scale).astype(np.float32) 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /leaks/exp_to_group.csv: -------------------------------------------------------------------------------- 1 | assignment,experiment 2 | 0,HEPG2-01 3 | 0,HEPG2-02 4 | 0,HEPG2-03 5 | 0,HEPG2-04 6 | 1,HEPG2-05 7 | 0,HEPG2-06 8 | 0,HEPG2-07 9 | 3,HEPG2-08 10 | 1,HEPG2-09 11 | 0,HEPG2-10 12 | 0,HEPG2-11 13 | 0,RPE-01 14 | 0,RPE-02 15 | 2,RPE-03 16 | 0,RPE-04 17 | 0,RPE-05 18 | 1,RPE-06 19 | 0,RPE-07 20 | 1,RPE-08 21 | 0,RPE-09 22 | 0,RPE-10 23 | 0,RPE-11 24 | 1,U2OS-01 25 | 0,U2OS-02 26 | 0,U2OS-03 27 | 2,U2OS-04 28 | 3,U2OS-05 29 | 0,HUVEC-01 30 | 0,HUVEC-02 31 | 0,HUVEC-03 32 | 0,HUVEC-04 33 | 0,HUVEC-05 34 | 0,HUVEC-06 35 | 1,HUVEC-07 36 | 0,HUVEC-08 37 | 2,HUVEC-09 38 | 2,HUVEC-10 39 | 1,HUVEC-11 40 | 0,HUVEC-12 41 | 1,HUVEC-13 42 | 1,HUVEC-14 43 | 2,HUVEC-15 44 | 0,HUVEC-16 45 | 0,HUVEC-17 46 | 0,HUVEC-18 47 | 2,HUVEC-19 48 | 2,HUVEC-20 49 | 3,HUVEC-21 50 | 0,HUVEC-22 51 | 0,HUVEC-23 52 | 3,HUVEC-24 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /leaks/plate_groups.csv: -------------------------------------------------------------------------------- 1 | p0,p1,p2,p3,class 2 | 4,2,3,1,0 3 | 1,3,4,2,1 4 | 2,4,1,3,2 5 | 1,3,4,2,3 6 | 3,1,2,4,4 7 | 1,3,4,2,5 8 | 1,3,4,2,6 9 | 2,4,1,3,7 10 | 1,3,4,2,8 11 | 4,2,3,1,9 12 | 1,3,4,2,10 13 | 2,4,1,3,11 14 | 4,2,3,1,12 15 | 4,2,3,1,13 16 | 4,2,3,1,14 17 | 4,2,3,1,15 18 | 4,2,3,1,16 19 | 4,2,3,1,17 20 | 4,2,3,1,18 21 | 2,4,1,3,19 22 | 2,4,1,3,20 23 | 3,1,2,4,21 24 | 2,4,1,3,22 25 | 4,2,3,1,23 26 | 3,1,2,4,24 27 | 2,4,1,3,25 28 | 1,3,4,2,26 29 | 2,4,1,3,27 30 | 2,4,1,3,28 31 | 3,1,2,4,29 32 | 4,2,3,1,30 33 | 4,2,3,1,31 34 | 3,1,2,4,32 35 | 1,3,4,2,33 36 | 3,1,2,4,34 37 | 3,1,2,4,35 38 | 4,2,3,1,36 39 | 4,2,3,1,37 40 | 3,1,2,4,38 41 | 3,1,2,4,39 42 | 1,3,4,2,40 43 | 2,4,1,3,41 44 | 2,4,1,3,42 45 | 2,4,1,3,43 46 | 2,4,1,3,44 47 | 4,2,3,1,45 48 | 1,3,4,2,46 49 | 2,4,1,3,47 50 | 1,3,4,2,48 51 | 3,1,2,4,49 52 | 4,2,3,1,50 53 | 3,1,2,4,51 54 | 2,4,1,3,52 55 | 3,1,2,4,53 56 | 4,2,3,1,54 57 | 1,3,4,2,55 58 | 1,3,4,2,56 59 | 2,4,1,3,57 60 | 4,2,3,1,58 61 | 1,3,4,2,59 62 | 2,4,1,3,60 63 | 2,4,1,3,61 64 | 3,1,2,4,62 65 | 2,4,1,3,63 66 | 2,4,1,3,64 67 | 3,1,2,4,65 68 | 2,4,1,3,66 69 | 4,2,3,1,67 70 | 4,2,3,1,68 71 | 4,2,3,1,69 72 | 1,3,4,2,70 73 | 4,2,3,1,71 74 | 2,4,1,3,72 75 | 4,2,3,1,73 76 | 1,3,4,2,74 77 | 2,4,1,3,75 78 | 3,1,2,4,76 79 | 1,3,4,2,77 80 | 2,4,1,3,78 81 | 4,2,3,1,79 82 | 1,3,4,2,80 83 | 2,4,1,3,81 84 | 4,2,3,1,82 85 | 2,4,1,3,83 86 | 3,1,2,4,84 87 | 4,2,3,1,85 88 | 3,1,2,4,86 89 | 2,4,1,3,87 90 | 1,3,4,2,88 91 | 1,3,4,2,89 92 | 3,1,2,4,90 93 | 1,3,4,2,91 94 | 2,4,1,3,92 95 | 2,4,1,3,93 96 | 1,3,4,2,94 97 | 2,4,1,3,95 98 | 3,1,2,4,96 99 | 4,2,3,1,97 100 | 3,1,2,4,98 101 | 3,1,2,4,99 102 | 2,4,1,3,100 103 | 2,4,1,3,101 104 | 1,3,4,2,102 105 | 4,2,3,1,103 106 | 4,2,3,1,104 107 | 1,3,4,2,105 108 | 4,2,3,1,106 109 | 4,2,3,1,107 110 | 2,4,1,3,108 111 | 1,3,4,2,109 112 | 4,2,3,1,110 113 | 2,4,1,3,111 114 | 1,3,4,2,112 115 | 1,3,4,2,113 116 | 3,1,2,4,114 117 | 3,1,2,4,115 118 | 4,2,3,1,116 119 | 2,4,1,3,117 120 | 2,4,1,3,118 121 | 3,1,2,4,119 122 | 3,1,2,4,120 123 | 3,1,2,4,121 124 | 3,1,2,4,122 125 | 3,1,2,4,123 126 | 1,3,4,2,124 127 | 3,1,2,4,125 128 | 1,3,4,2,126 129 | 3,1,2,4,127 130 | 4,2,3,1,128 131 | 3,1,2,4,129 132 | 2,4,1,3,130 133 | 2,4,1,3,131 134 | 2,4,1,3,132 135 | 1,3,4,2,133 136 | 3,1,2,4,134 137 | 3,1,2,4,135 138 | 1,3,4,2,136 139 | 4,2,3,1,137 140 | 2,4,1,3,138 141 | 4,2,3,1,139 142 | 4,2,3,1,140 143 | 4,2,3,1,141 144 | 3,1,2,4,142 145 | 4,2,3,1,143 146 | 1,3,4,2,144 147 | 3,1,2,4,145 148 | 2,4,1,3,146 149 | 1,3,4,2,147 150 | 3,1,2,4,148 151 | 2,4,1,3,149 152 | 4,2,3,1,150 153 | 3,1,2,4,151 154 | 4,2,3,1,152 155 | 2,4,1,3,153 156 | 4,2,3,1,154 157 | 1,3,4,2,155 158 | 4,2,3,1,156 159 | 1,3,4,2,157 160 | 1,3,4,2,158 161 | 3,1,2,4,159 162 | 2,4,1,3,160 163 | 4,2,3,1,161 164 | 1,3,4,2,162 165 | 3,1,2,4,163 166 | 2,4,1,3,164 167 | 2,4,1,3,165 168 | 2,4,1,3,166 169 | 3,1,2,4,167 170 | 3,1,2,4,168 171 | 2,4,1,3,169 172 | 3,1,2,4,170 173 | 3,1,2,4,171 174 | 1,3,4,2,172 175 | 3,1,2,4,173 176 | 3,1,2,4,174 177 | 2,4,1,3,175 178 | 4,2,3,1,176 179 | 1,3,4,2,177 180 | 2,4,1,3,178 181 | 4,2,3,1,179 182 | 3,1,2,4,180 183 | 1,3,4,2,181 184 | 2,4,1,3,182 185 | 3,1,2,4,183 186 | 3,1,2,4,184 187 | 3,1,2,4,185 188 | 3,1,2,4,186 189 | 4,2,3,1,187 190 | 1,3,4,2,188 191 | 2,4,1,3,189 192 | 3,1,2,4,190 193 | 3,1,2,4,191 194 | 1,3,4,2,192 195 | 4,2,3,1,193 196 | 2,4,1,3,194 197 | 1,3,4,2,195 198 | 2,4,1,3,196 199 | 3,1,2,4,197 200 | 3,1,2,4,198 201 | 4,2,3,1,199 202 | 4,2,3,1,200 203 | 4,2,3,1,201 204 | 4,2,3,1,202 205 | 1,3,4,2,203 206 | 4,2,3,1,204 207 | 4,2,3,1,205 208 | 2,4,1,3,206 209 | 2,4,1,3,207 210 | 4,2,3,1,208 211 | 4,2,3,1,209 212 | 3,1,2,4,210 213 | 1,3,4,2,211 214 | 1,3,4,2,212 215 | 4,2,3,1,213 216 | 3,1,2,4,214 217 | 3,1,2,4,215 218 | 2,4,1,3,216 219 | 1,3,4,2,217 220 | 2,4,1,3,218 221 | 2,4,1,3,219 222 | 4,2,3,1,220 223 | 1,3,4,2,221 224 | 1,3,4,2,222 225 | 4,2,3,1,223 226 | 1,3,4,2,224 227 | 2,4,1,3,225 228 | 4,2,3,1,226 229 | 4,2,3,1,227 230 | 2,4,1,3,228 231 | 3,1,2,4,229 232 | 4,2,3,1,230 233 | 4,2,3,1,231 234 | 2,4,1,3,232 235 | 2,4,1,3,233 236 | 3,1,2,4,234 237 | 4,2,3,1,235 238 | 2,4,1,3,236 239 | 3,1,2,4,237 240 | 1,3,4,2,238 241 | 1,3,4,2,239 242 | 4,2,3,1,240 243 | 4,2,3,1,241 244 | 1,3,4,2,242 245 | 4,2,3,1,243 246 | 3,1,2,4,244 247 | 2,4,1,3,245 248 | 3,1,2,4,246 249 | 4,2,3,1,247 250 | 3,1,2,4,248 251 | 4,2,3,1,249 252 | 3,1,2,4,250 253 | 4,2,3,1,251 254 | 3,1,2,4,252 255 | 1,3,4,2,253 256 | 1,3,4,2,254 257 | 3,1,2,4,255 258 | 1,3,4,2,256 259 | 2,4,1,3,257 260 | 1,3,4,2,258 261 | 3,1,2,4,259 262 | 4,2,3,1,260 263 | 4,2,3,1,261 264 | 2,4,1,3,262 265 | 2,4,1,3,263 266 | 4,2,3,1,264 267 | 3,1,2,4,265 268 | 1,3,4,2,266 269 | 3,1,2,4,267 270 | 1,3,4,2,268 271 | 2,4,1,3,269 272 | 4,2,3,1,270 273 | 3,1,2,4,271 274 | 1,3,4,2,272 275 | 4,2,3,1,273 276 | 1,3,4,2,274 277 | 4,2,3,1,275 278 | 1,3,4,2,276 279 | 1,3,4,2,277 280 | 3,1,2,4,278 281 | 3,1,2,4,279 282 | 4,2,3,1,280 283 | 1,3,4,2,281 284 | 1,3,4,2,282 285 | 3,1,2,4,283 286 | 2,4,1,3,284 287 | 2,4,1,3,285 288 | 4,2,3,1,286 289 | 1,3,4,2,287 290 | 4,2,3,1,288 291 | 3,1,2,4,289 292 | 1,3,4,2,290 293 | 3,1,2,4,291 294 | 4,2,3,1,292 295 | 1,3,4,2,293 296 | 4,2,3,1,294 297 | 1,3,4,2,295 298 | 1,3,4,2,296 299 | 4,2,3,1,297 300 | 4,2,3,1,298 301 | 2,4,1,3,299 302 | 3,1,2,4,300 303 | 1,3,4,2,301 304 | 1,3,4,2,302 305 | 2,4,1,3,303 306 | 4,2,3,1,304 307 | 2,4,1,3,305 308 | 4,2,3,1,306 309 | 3,1,2,4,307 310 | 1,3,4,2,308 311 | 1,3,4,2,309 312 | 3,1,2,4,310 313 | 3,1,2,4,311 314 | 1,3,4,2,312 315 | 1,3,4,2,313 316 | 3,1,2,4,314 317 | 2,4,1,3,315 318 | 4,2,3,1,316 319 | 2,4,1,3,317 320 | 4,2,3,1,318 321 | 3,1,2,4,319 322 | 1,3,4,2,320 323 | 4,2,3,1,321 324 | 3,1,2,4,322 325 | 4,2,3,1,323 326 | 1,3,4,2,324 327 | 4,2,3,1,325 328 | 3,1,2,4,326 329 | 2,4,1,3,327 330 | 4,2,3,1,328 331 | 2,4,1,3,329 332 | 2,4,1,3,330 333 | 4,2,3,1,331 334 | 4,2,3,1,332 335 | 2,4,1,3,333 336 | 2,4,1,3,334 337 | 3,1,2,4,335 338 | 2,4,1,3,336 339 | 3,1,2,4,337 340 | 3,1,2,4,338 341 | 1,3,4,2,339 342 | 1,3,4,2,340 343 | 1,3,4,2,341 344 | 2,4,1,3,342 345 | 4,2,3,1,343 346 | 4,2,3,1,344 347 | 1,3,4,2,345 348 | 1,3,4,2,346 349 | 3,1,2,4,347 350 | 1,3,4,2,348 351 | 2,4,1,3,349 352 | 1,3,4,2,350 353 | 2,4,1,3,351 354 | 1,3,4,2,352 355 | 2,4,1,3,353 356 | 4,2,3,1,354 357 | 1,3,4,2,355 358 | 4,2,3,1,356 359 | 2,4,1,3,357 360 | 4,2,3,1,358 361 | 3,1,2,4,359 362 | 1,3,4,2,360 363 | 1,3,4,2,361 364 | 3,1,2,4,362 365 | 3,1,2,4,363 366 | 2,4,1,3,364 367 | 3,1,2,4,365 368 | 1,3,4,2,366 369 | 4,2,3,1,367 370 | 3,1,2,4,368 371 | 3,1,2,4,369 372 | 4,2,3,1,370 373 | 1,3,4,2,371 374 | 1,3,4,2,372 375 | 1,3,4,2,373 376 | 4,2,3,1,374 377 | 3,1,2,4,375 378 | 2,4,1,3,376 379 | 3,1,2,4,377 380 | 4,2,3,1,378 381 | 1,3,4,2,379 382 | 3,1,2,4,380 383 | 2,4,1,3,381 384 | 3,1,2,4,382 385 | 2,4,1,3,383 386 | 2,4,1,3,384 387 | 1,3,4,2,385 388 | 4,2,3,1,386 389 | 3,1,2,4,387 390 | 1,3,4,2,388 391 | 1,3,4,2,389 392 | 2,4,1,3,390 393 | 1,3,4,2,391 394 | 1,3,4,2,392 395 | 2,4,1,3,393 396 | 3,1,2,4,394 397 | 1,3,4,2,395 398 | 3,1,2,4,396 399 | 1,3,4,2,397 400 | 1,3,4,2,398 401 | 1,3,4,2,399 402 | 1,3,4,2,400 403 | 3,1,2,4,401 404 | 3,1,2,4,402 405 | 4,2,3,1,403 406 | 2,4,1,3,404 407 | 3,1,2,4,405 408 | 4,2,3,1,406 409 | 2,4,1,3,407 410 | 2,4,1,3,408 411 | 2,4,1,3,409 412 | 3,1,2,4,410 413 | 1,3,4,2,411 414 | 2,4,1,3,412 415 | 1,3,4,2,413 416 | 2,4,1,3,414 417 | 1,3,4,2,415 418 | 4,2,3,1,416 419 | 1,3,4,2,417 420 | 4,2,3,1,418 421 | 4,2,3,1,419 422 | 4,2,3,1,420 423 | 2,4,1,3,421 424 | 3,1,2,4,422 425 | 2,4,1,3,423 426 | 2,4,1,3,424 427 | 2,4,1,3,425 428 | 4,2,3,1,426 429 | 4,2,3,1,427 430 | 2,4,1,3,428 431 | 2,4,1,3,429 432 | 4,2,3,1,430 433 | 1,3,4,2,431 434 | 3,1,2,4,432 435 | 4,2,3,1,433 436 | 2,4,1,3,434 437 | 3,1,2,4,435 438 | 1,3,4,2,436 439 | 3,1,2,4,437 440 | 3,1,2,4,438 441 | 1,3,4,2,439 442 | 4,2,3,1,440 443 | 4,2,3,1,441 444 | 1,3,4,2,442 445 | 3,1,2,4,443 446 | 3,1,2,4,444 447 | 1,3,4,2,445 448 | 3,1,2,4,446 449 | 2,4,1,3,447 450 | 3,1,2,4,448 451 | 3,1,2,4,449 452 | 1,3,4,2,450 453 | 4,2,3,1,451 454 | 2,4,1,3,452 455 | 3,1,2,4,453 456 | 2,4,1,3,454 457 | 4,2,3,1,455 458 | 1,3,4,2,456 459 | 2,4,1,3,457 460 | 4,2,3,1,458 461 | 1,3,4,2,459 462 | 1,3,4,2,460 463 | 2,4,1,3,461 464 | 2,4,1,3,462 465 | 3,1,2,4,463 466 | 3,1,2,4,464 467 | 3,1,2,4,465 468 | 2,4,1,3,466 469 | 2,4,1,3,467 470 | 4,2,3,1,468 471 | 3,1,2,4,469 472 | 4,2,3,1,470 473 | 1,3,4,2,471 474 | 4,2,3,1,472 475 | 2,4,1,3,473 476 | 4,2,3,1,474 477 | 2,4,1,3,475 478 | 4,2,3,1,476 479 | 3,1,2,4,477 480 | 1,3,4,2,478 481 | 2,4,1,3,479 482 | 3,1,2,4,480 483 | 1,3,4,2,481 484 | 1,3,4,2,482 485 | 2,4,1,3,483 486 | 3,1,2,4,484 487 | 3,1,2,4,485 488 | 1,3,4,2,486 489 | 3,1,2,4,487 490 | 4,2,3,1,488 491 | 1,3,4,2,489 492 | 3,1,2,4,490 493 | 3,1,2,4,491 494 | 2,4,1,3,492 495 | 4,2,3,1,493 496 | 3,1,2,4,494 497 | 2,4,1,3,495 498 | 4,2,3,1,496 499 | 1,3,4,2,497 500 | 4,2,3,1,498 501 | 2,4,1,3,499 502 | 1,3,4,2,500 503 | 3,1,2,4,501 504 | 2,4,1,3,502 505 | 1,3,4,2,503 506 | 3,1,2,4,504 507 | 1,3,4,2,505 508 | 1,3,4,2,506 509 | 2,4,1,3,507 510 | 1,3,4,2,508 511 | 4,2,3,1,509 512 | 3,1,2,4,510 513 | 2,4,1,3,511 514 | 3,1,2,4,512 515 | 1,3,4,2,513 516 | 1,3,4,2,514 517 | 3,1,2,4,515 518 | 3,1,2,4,516 519 | 2,4,1,3,517 520 | 4,2,3,1,518 521 | 4,2,3,1,519 522 | 3,1,2,4,520 523 | 2,4,1,3,521 524 | 4,2,3,1,522 525 | 1,3,4,2,523 526 | 4,2,3,1,524 527 | 3,1,2,4,525 528 | 4,2,3,1,526 529 | 3,1,2,4,527 530 | 4,2,3,1,528 531 | 2,4,1,3,529 532 | 3,1,2,4,530 533 | 4,2,3,1,531 534 | 1,3,4,2,532 535 | 3,1,2,4,533 536 | 4,2,3,1,534 537 | 2,4,1,3,535 538 | 2,4,1,3,536 539 | 1,3,4,2,537 540 | 1,3,4,2,538 541 | 2,4,1,3,539 542 | 4,2,3,1,540 543 | 1,3,4,2,541 544 | 2,4,1,3,542 545 | 1,3,4,2,543 546 | 2,4,1,3,544 547 | 3,1,2,4,545 548 | 3,1,2,4,546 549 | 1,3,4,2,547 550 | 3,1,2,4,548 551 | 1,3,4,2,549 552 | 2,4,1,3,550 553 | 3,1,2,4,551 554 | 1,3,4,2,552 555 | 2,4,1,3,553 556 | 3,1,2,4,554 557 | 1,3,4,2,555 558 | 1,3,4,2,556 559 | 1,3,4,2,557 560 | 2,4,1,3,558 561 | 1,3,4,2,559 562 | 4,2,3,1,560 563 | 4,2,3,1,561 564 | 2,4,1,3,562 565 | 2,4,1,3,563 566 | 3,1,2,4,564 567 | 1,3,4,2,565 568 | 1,3,4,2,566 569 | 1,3,4,2,567 570 | 4,2,3,1,568 571 | 2,4,1,3,569 572 | 3,1,2,4,570 573 | 2,4,1,3,571 574 | 1,3,4,2,572 575 | 4,2,3,1,573 576 | 4,2,3,1,574 577 | 4,2,3,1,575 578 | 1,3,4,2,576 579 | 1,3,4,2,577 580 | 3,1,2,4,578 581 | 2,4,1,3,579 582 | 2,4,1,3,580 583 | 1,3,4,2,581 584 | 1,3,4,2,582 585 | 3,1,2,4,583 586 | 2,4,1,3,584 587 | 1,3,4,2,585 588 | 3,1,2,4,586 589 | 4,2,3,1,587 590 | 1,3,4,2,588 591 | 4,2,3,1,589 592 | 3,1,2,4,590 593 | 4,2,3,1,591 594 | 4,2,3,1,592 595 | 3,1,2,4,593 596 | 4,2,3,1,594 597 | 3,1,2,4,595 598 | 4,2,3,1,596 599 | 2,4,1,3,597 600 | 1,3,4,2,598 601 | 3,1,2,4,599 602 | 2,4,1,3,600 603 | 2,4,1,3,601 604 | 4,2,3,1,602 605 | 3,1,2,4,603 606 | 1,3,4,2,604 607 | 1,3,4,2,605 608 | 1,3,4,2,606 609 | 1,3,4,2,607 610 | 3,1,2,4,608 611 | 4,2,3,1,609 612 | 4,2,3,1,610 613 | 1,3,4,2,611 614 | 4,2,3,1,612 615 | 1,3,4,2,613 616 | 3,1,2,4,614 617 | 2,4,1,3,615 618 | 3,1,2,4,616 619 | 3,1,2,4,617 620 | 1,3,4,2,618 621 | 4,2,3,1,619 622 | 1,3,4,2,620 623 | 4,2,3,1,621 624 | 1,3,4,2,622 625 | 2,4,1,3,623 626 | 1,3,4,2,624 627 | 3,1,2,4,625 628 | 1,3,4,2,626 629 | 3,1,2,4,627 630 | 4,2,3,1,628 631 | 3,1,2,4,629 632 | 4,2,3,1,630 633 | 4,2,3,1,631 634 | 3,1,2,4,632 635 | 3,1,2,4,633 636 | 4,2,3,1,634 637 | 4,2,3,1,635 638 | 3,1,2,4,636 639 | 3,1,2,4,637 640 | 4,2,3,1,638 641 | 2,4,1,3,639 642 | 2,4,1,3,640 643 | 4,2,3,1,641 644 | 2,4,1,3,642 645 | 3,1,2,4,643 646 | 3,1,2,4,644 647 | 4,2,3,1,645 648 | 3,1,2,4,646 649 | 1,3,4,2,647 650 | 3,1,2,4,648 651 | 3,1,2,4,649 652 | 2,4,1,3,650 653 | 2,4,1,3,651 654 | 3,1,2,4,652 655 | 1,3,4,2,653 656 | 2,4,1,3,654 657 | 1,3,4,2,655 658 | 4,2,3,1,656 659 | 1,3,4,2,657 660 | 3,1,2,4,658 661 | 2,4,1,3,659 662 | 1,3,4,2,660 663 | 2,4,1,3,661 664 | 2,4,1,3,662 665 | 4,2,3,1,663 666 | 3,1,2,4,664 667 | 3,1,2,4,665 668 | 3,1,2,4,666 669 | 2,4,1,3,667 670 | 4,2,3,1,668 671 | 2,4,1,3,669 672 | 1,3,4,2,670 673 | 3,1,2,4,671 674 | 3,1,2,4,672 675 | 1,3,4,2,673 676 | 3,1,2,4,674 677 | 4,2,3,1,675 678 | 1,3,4,2,676 679 | 3,1,2,4,677 680 | 3,1,2,4,678 681 | 3,1,2,4,679 682 | 4,2,3,1,680 683 | 1,3,4,2,681 684 | 2,4,1,3,682 685 | 2,4,1,3,683 686 | 4,2,3,1,684 687 | 2,4,1,3,685 688 | 3,1,2,4,686 689 | 1,3,4,2,687 690 | 4,2,3,1,688 691 | 1,3,4,2,689 692 | 3,1,2,4,690 693 | 3,1,2,4,691 694 | 1,3,4,2,692 695 | 4,2,3,1,693 696 | 2,4,1,3,694 697 | 1,3,4,2,695 698 | 4,2,3,1,696 699 | 2,4,1,3,697 700 | 4,2,3,1,698 701 | 1,3,4,2,699 702 | 1,3,4,2,700 703 | 2,4,1,3,701 704 | 4,2,3,1,702 705 | 1,3,4,2,703 706 | 1,3,4,2,704 707 | 1,3,4,2,705 708 | 2,4,1,3,706 709 | 3,1,2,4,707 710 | 1,3,4,2,708 711 | 2,4,1,3,709 712 | 4,2,3,1,710 713 | 4,2,3,1,711 714 | 4,2,3,1,712 715 | 3,1,2,4,713 716 | 2,4,1,3,714 717 | 1,3,4,2,715 718 | 3,1,2,4,716 719 | 2,4,1,3,717 720 | 3,1,2,4,718 721 | 3,1,2,4,719 722 | 4,2,3,1,720 723 | 2,4,1,3,721 724 | 2,4,1,3,722 725 | 2,4,1,3,723 726 | 2,4,1,3,724 727 | 4,2,3,1,725 728 | 2,4,1,3,726 729 | 3,1,2,4,727 730 | 4,2,3,1,728 731 | 2,4,1,3,729 732 | 2,4,1,3,730 733 | 2,4,1,3,731 734 | 2,4,1,3,732 735 | 2,4,1,3,733 736 | 1,3,4,2,734 737 | 2,4,1,3,735 738 | 2,4,1,3,736 739 | 4,2,3,1,737 740 | 3,1,2,4,738 741 | 2,4,1,3,739 742 | 3,1,2,4,740 743 | 4,2,3,1,741 744 | 1,3,4,2,742 745 | 1,3,4,2,743 746 | 1,3,4,2,744 747 | 1,3,4,2,745 748 | 4,2,3,1,746 749 | 2,4,1,3,747 750 | 3,1,2,4,748 751 | 3,1,2,4,749 752 | 1,3,4,2,750 753 | 2,4,1,3,751 754 | 2,4,1,3,752 755 | 2,4,1,3,753 756 | 2,4,1,3,754 757 | 3,1,2,4,755 758 | 4,2,3,1,756 759 | 1,3,4,2,757 760 | 4,2,3,1,758 761 | 3,1,2,4,759 762 | 1,3,4,2,760 763 | 2,4,1,3,761 764 | 1,3,4,2,762 765 | 3,1,2,4,763 766 | 1,3,4,2,764 767 | 1,3,4,2,765 768 | 2,4,1,3,766 769 | 4,2,3,1,767 770 | 3,1,2,4,768 771 | 3,1,2,4,769 772 | 1,3,4,2,770 773 | 4,2,3,1,771 774 | 4,2,3,1,772 775 | 4,2,3,1,773 776 | 3,1,2,4,774 777 | 1,3,4,2,775 778 | 2,4,1,3,776 779 | 3,1,2,4,777 780 | 4,2,3,1,778 781 | 2,4,1,3,779 782 | 4,2,3,1,780 783 | 2,4,1,3,781 784 | 3,1,2,4,782 785 | 2,4,1,3,783 786 | 1,3,4,2,784 787 | 4,2,3,1,785 788 | 1,3,4,2,786 789 | 2,4,1,3,787 790 | 2,4,1,3,788 791 | 2,4,1,3,789 792 | 2,4,1,3,790 793 | 1,3,4,2,791 794 | 4,2,3,1,792 795 | 1,3,4,2,793 796 | 4,2,3,1,794 797 | 3,1,2,4,795 798 | 3,1,2,4,796 799 | 1,3,4,2,797 800 | 1,3,4,2,798 801 | 4,2,3,1,799 802 | 3,1,2,4,800 803 | 2,4,1,3,801 804 | 3,1,2,4,802 805 | 4,2,3,1,803 806 | 2,4,1,3,804 807 | 3,1,2,4,805 808 | 1,3,4,2,806 809 | 3,1,2,4,807 810 | 3,1,2,4,808 811 | 3,1,2,4,809 812 | 1,3,4,2,810 813 | 1,3,4,2,811 814 | 4,2,3,1,812 815 | 2,4,1,3,813 816 | 4,2,3,1,814 817 | 3,1,2,4,815 818 | 3,1,2,4,816 819 | 4,2,3,1,817 820 | 3,1,2,4,818 821 | 3,1,2,4,819 822 | 3,1,2,4,820 823 | 3,1,2,4,821 824 | 4,2,3,1,822 825 | 4,2,3,1,823 826 | 4,2,3,1,824 827 | 1,3,4,2,825 828 | 2,4,1,3,826 829 | 2,4,1,3,827 830 | 1,3,4,2,828 831 | 3,1,2,4,829 832 | 2,4,1,3,830 833 | 1,3,4,2,831 834 | 4,2,3,1,832 835 | 4,2,3,1,833 836 | 3,1,2,4,834 837 | 4,2,3,1,835 838 | 4,2,3,1,836 839 | 2,4,1,3,837 840 | 1,3,4,2,838 841 | 4,2,3,1,839 842 | 1,3,4,2,840 843 | 4,2,3,1,841 844 | 3,1,2,4,842 845 | 4,2,3,1,843 846 | 2,4,1,3,844 847 | 1,3,4,2,845 848 | 4,2,3,1,846 849 | 1,3,4,2,847 850 | 3,1,2,4,848 851 | 2,4,1,3,849 852 | 4,2,3,1,850 853 | 2,4,1,3,851 854 | 2,4,1,3,852 855 | 1,3,4,2,853 856 | 4,2,3,1,854 857 | 4,2,3,1,855 858 | 2,4,1,3,856 859 | 2,4,1,3,857 860 | 3,1,2,4,858 861 | 2,4,1,3,859 862 | 3,1,2,4,860 863 | 3,1,2,4,861 864 | 4,2,3,1,862 865 | 3,1,2,4,863 866 | 3,1,2,4,864 867 | 2,4,1,3,865 868 | 3,1,2,4,866 869 | 2,4,1,3,867 870 | 1,3,4,2,868 871 | 3,1,2,4,869 872 | 1,3,4,2,870 873 | 2,4,1,3,871 874 | 4,2,3,1,872 875 | 3,1,2,4,873 876 | 2,4,1,3,874 877 | 1,3,4,2,875 878 | 2,4,1,3,876 879 | 1,3,4,2,877 880 | 4,2,3,1,878 881 | 2,4,1,3,879 882 | 4,2,3,1,880 883 | 1,3,4,2,881 884 | 4,2,3,1,882 885 | 1,3,4,2,883 886 | 1,3,4,2,884 887 | 4,2,3,1,885 888 | 3,1,2,4,886 889 | 2,4,1,3,887 890 | 3,1,2,4,888 891 | 1,3,4,2,889 892 | 1,3,4,2,890 893 | 1,3,4,2,891 894 | 4,2,3,1,892 895 | 2,4,1,3,893 896 | 1,3,4,2,894 897 | 3,1,2,4,895 898 | 2,4,1,3,896 899 | 1,3,4,2,897 900 | 2,4,1,3,898 901 | 4,2,3,1,899 902 | 1,3,4,2,900 903 | 3,1,2,4,901 904 | 2,4,1,3,902 905 | 2,4,1,3,903 906 | 4,2,3,1,904 907 | 2,4,1,3,905 908 | 4,2,3,1,906 909 | 1,3,4,2,907 910 | 3,1,2,4,908 911 | 2,4,1,3,909 912 | 1,3,4,2,910 913 | 4,2,3,1,911 914 | 4,2,3,1,912 915 | 2,4,1,3,913 916 | 1,3,4,2,914 917 | 2,4,1,3,915 918 | 3,1,2,4,916 919 | 2,4,1,3,917 920 | 2,4,1,3,918 921 | 3,1,2,4,919 922 | 2,4,1,3,920 923 | 1,3,4,2,921 924 | 3,1,2,4,922 925 | 2,4,1,3,923 926 | 4,2,3,1,924 927 | 4,2,3,1,925 928 | 3,1,2,4,926 929 | 2,4,1,3,927 930 | 4,2,3,1,928 931 | 4,2,3,1,929 932 | 4,2,3,1,930 933 | 4,2,3,1,931 934 | 2,4,1,3,932 935 | 3,1,2,4,933 936 | 2,4,1,3,934 937 | 4,2,3,1,935 938 | 3,1,2,4,936 939 | 3,1,2,4,937 940 | 1,3,4,2,938 941 | 1,3,4,2,939 942 | 3,1,2,4,940 943 | 2,4,1,3,941 944 | 3,1,2,4,942 945 | 3,1,2,4,943 946 | 3,1,2,4,944 947 | 1,3,4,2,945 948 | 2,4,1,3,946 949 | 3,1,2,4,947 950 | 4,2,3,1,948 951 | 4,2,3,1,949 952 | 1,3,4,2,950 953 | 2,4,1,3,951 954 | 2,4,1,3,952 955 | 4,2,3,1,953 956 | 4,2,3,1,954 957 | 3,1,2,4,955 958 | 2,4,1,3,956 959 | 2,4,1,3,957 960 | 4,2,3,1,958 961 | 1,3,4,2,959 962 | 1,3,4,2,960 963 | 3,1,2,4,961 964 | 1,3,4,2,962 965 | 3,1,2,4,963 966 | 4,2,3,1,964 967 | 2,4,1,3,965 968 | 2,4,1,3,966 969 | 2,4,1,3,967 970 | 2,4,1,3,968 971 | 3,1,2,4,969 972 | 3,1,2,4,970 973 | 4,2,3,1,971 974 | 1,3,4,2,972 975 | 4,2,3,1,973 976 | 1,3,4,2,974 977 | 2,4,1,3,975 978 | 4,2,3,1,976 979 | 3,1,2,4,977 980 | 1,3,4,2,978 981 | 4,2,3,1,979 982 | 4,2,3,1,980 983 | 1,3,4,2,981 984 | 1,3,4,2,982 985 | 1,3,4,2,983 986 | 1,3,4,2,984 987 | 4,2,3,1,985 988 | 1,3,4,2,986 989 | 1,3,4,2,987 990 | 2,4,1,3,988 991 | 3,1,2,4,989 992 | 3,1,2,4,990 993 | 2,4,1,3,991 994 | 3,1,2,4,992 995 | 4,2,3,1,993 996 | 4,2,3,1,994 997 | 3,1,2,4,995 998 | 2,4,1,3,996 999 | 4,2,3,1,997 1000 | 4,2,3,1,998 1001 | 3,1,2,4,999 1002 | 3,1,2,4,1000 1003 | 4,2,3,1,1001 1004 | 2,4,1,3,1002 1005 | 2,4,1,3,1003 1006 | 4,2,3,1,1004 1007 | 4,2,3,1,1005 1008 | 1,3,4,2,1006 1009 | 3,1,2,4,1007 1010 | 1,3,4,2,1008 1011 | 4,2,3,1,1009 1012 | 4,2,3,1,1010 1013 | 4,2,3,1,1011 1014 | 3,1,2,4,1012 1015 | 1,3,4,2,1013 1016 | 3,1,2,4,1014 1017 | 1,3,4,2,1015 1018 | 2,4,1,3,1016 1019 | 2,4,1,3,1017 1020 | 3,1,2,4,1018 1021 | 2,4,1,3,1019 1022 | 1,3,4,2,1020 1023 | 4,2,3,1,1021 1024 | 4,2,3,1,1022 1025 | 1,3,4,2,1023 1026 | 2,4,1,3,1024 1027 | 1,3,4,2,1025 1028 | 4,2,3,1,1026 1029 | 1,3,4,2,1027 1030 | 1,3,4,2,1028 1031 | 3,1,2,4,1029 1032 | 2,4,1,3,1030 1033 | 1,3,4,2,1031 1034 | 2,4,1,3,1032 1035 | 2,4,1,3,1033 1036 | 3,1,2,4,1034 1037 | 2,4,1,3,1035 1038 | 4,2,3,1,1036 1039 | 4,2,3,1,1037 1040 | 4,2,3,1,1038 1041 | 1,3,4,2,1039 1042 | 2,4,1,3,1040 1043 | 1,3,4,2,1041 1044 | 1,3,4,2,1042 1045 | 2,4,1,3,1043 1046 | 2,4,1,3,1044 1047 | 4,2,3,1,1045 1048 | 2,4,1,3,1046 1049 | 2,4,1,3,1047 1050 | 4,2,3,1,1048 1051 | 2,4,1,3,1049 1052 | 4,2,3,1,1050 1053 | 3,1,2,4,1051 1054 | 2,4,1,3,1052 1055 | 3,1,2,4,1053 1056 | 4,2,3,1,1054 1057 | 4,2,3,1,1055 1058 | 1,3,4,2,1056 1059 | 3,1,2,4,1057 1060 | 2,4,1,3,1058 1061 | 3,1,2,4,1059 1062 | 1,3,4,2,1060 1063 | 4,2,3,1,1061 1064 | 1,3,4,2,1062 1065 | 2,4,1,3,1063 1066 | 2,4,1,3,1064 1067 | 1,3,4,2,1065 1068 | 1,3,4,2,1066 1069 | 1,3,4,2,1067 1070 | 3,1,2,4,1068 1071 | 4,2,3,1,1069 1072 | 2,4,1,3,1070 1073 | 3,1,2,4,1071 1074 | 4,2,3,1,1072 1075 | 4,2,3,1,1073 1076 | 3,1,2,4,1074 1077 | 4,2,3,1,1075 1078 | 1,3,4,2,1076 1079 | 2,4,1,3,1077 1080 | 2,4,1,3,1078 1081 | 3,1,2,4,1079 1082 | 2,4,1,3,1080 1083 | 4,2,3,1,1081 1084 | 4,2,3,1,1082 1085 | 2,4,1,3,1083 1086 | 1,3,4,2,1084 1087 | 3,1,2,4,1085 1088 | 3,1,2,4,1086 1089 | 3,1,2,4,1087 1090 | 2,4,1,3,1088 1091 | 1,3,4,2,1089 1092 | 2,4,1,3,1090 1093 | 3,1,2,4,1091 1094 | 2,4,1,3,1092 1095 | 3,1,2,4,1093 1096 | 4,2,3,1,1094 1097 | 4,2,3,1,1095 1098 | 1,3,4,2,1096 1099 | 4,2,3,1,1097 1100 | 2,4,1,3,1098 1101 | 3,1,2,4,1099 1102 | 1,3,4,2,1100 1103 | 1,3,4,2,1101 1104 | 1,3,4,2,1102 1105 | 4,2,3,1,1103 1106 | 1,3,4,2,1104 1107 | 3,1,2,4,1105 1108 | 1,3,4,2,1106 1109 | 4,2,3,1,1107 1110 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Scu-sen/Recursion-Cellular-Image-Classification-Challenge/a45e48905b673ea89763f0648623c89fbddb85f1/loss/__init__.py -------------------------------------------------------------------------------- /loss/cyclic_lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | class CosineAnnealingLR_with_Restart(_LRScheduler): 6 | """Set the learning rate of each parameter group using a cosine annealing 7 | schedule, where :math:`\eta_{max}` is set to the initial lr and 8 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 9 | 10 | .. math:: 11 | 12 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 13 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 14 | 15 | When last_epoch=-1, sets initial lr as lr. 16 | 17 | It has been proposed in 18 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. The original pytorch 19 | implementation only implements the cosine annealing part of SGDR, 20 | I added my own implementation of the restarts part. 21 | 22 | Args: 23 | optimizer (Optimizer): Wrapped optimizer. 24 | T_max (int): Maximum number of iterations. 25 | T_mult (float): Increase T_max by a factor of T_mult 26 | eta_min (float): Minimum learning rate. Default: 0. 27 | last_epoch (int): The index of last epoch. Default: -1. 28 | model (pytorch model): The model to save. 29 | out_dir (str): Directory to save snapshots 30 | take_snapshot (bool): Whether to save snapshots at every restart 31 | 32 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 33 | https://arxiv.org/abs/1608.03983 34 | """ 35 | 36 | def __init__(self, optimizer, T_max, T_mult, model, out_dir, take_snapshot, eta_min=0, last_epoch=-1): 37 | self.T_max = T_max 38 | self.T_mult = T_mult 39 | self.Te = self.T_max 40 | self.eta_min = eta_min 41 | self.current_epoch = last_epoch 42 | 43 | self.model = model 44 | self.out_dir = out_dir 45 | self.take_snapshot = take_snapshot 46 | 47 | self.lr_history = [] 48 | 49 | super(CosineAnnealingLR_with_Restart, self).__init__(optimizer, last_epoch) 50 | 51 | def get_lr(self): 52 | new_lrs = [self.eta_min + (base_lr - self.eta_min) * 53 | (1 + math.cos(math.pi * self.current_epoch / self.Te)) / 2 54 | for base_lr in self.base_lrs] 55 | 56 | self.lr_history.append(new_lrs) 57 | return new_lrs 58 | 59 | def step(self, epoch=None): 60 | if epoch is None: 61 | epoch = self.last_epoch + 1 62 | self.last_epoch = epoch 63 | self.current_epoch += 1 64 | 65 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 66 | param_group['lr'] = lr 67 | 68 | ## restart 69 | if self.current_epoch == self.Te: 70 | print("restart at epoch {:03d}".format(self.last_epoch + 1)) 71 | 72 | if self.take_snapshot: 73 | torch.save({ 74 | 'epoch': self.T_max, 75 | 'state_dict': self.model.state_dict() 76 | }, self.out_dir + "Weight/" + 'snapshot_e_{:03d}.pth.tar'.format(self.T_max)) 77 | 78 | ## reset epochs since the last reset 79 | self.current_epoch = 0 80 | 81 | ## reset the next goal 82 | self.Te = int(self.Te * self.T_mult) 83 | self.T_max = self.T_max + self.Te -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | from include import * 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def l2_norm(input, axis=1): 7 | norm = torch.norm(input,2, axis, True) 8 | output = torch.div(input, norm) 9 | return output 10 | 11 | def euclidean_dist(x, y): 12 | 13 | m, n = x.size(0), y.size(0) 14 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 15 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 16 | dist = xx + yy 17 | dist.addmm_(1, -2, x, y.t()) 18 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 19 | return dist 20 | 21 | def hard_example_mining(dist_mat, labels, return_inds=False): 22 | 23 | assert len(dist_mat.size()) == 2 24 | assert dist_mat.size(0) == dist_mat.size(1) 25 | N = dist_mat.size(0) 26 | 27 | # shape [N, N] 28 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 29 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 30 | 31 | # `dist_ap` means distance(anchor, positive) 32 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 33 | dist_ap, relative_p_inds = torch.max((dist_mat*is_pos.float()).contiguous().view(N, -1), 1, keepdim=True) 34 | 35 | # `dist_an` means distance(anchor, negative) 36 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 37 | temp = dist_mat*is_neg.float() 38 | temp[temp==0] = 10e5 39 | 40 | dist_an, relative_n_inds = torch.min(temp.contiguous().view(N, -1), 1, keepdim=True) 41 | # shape [N] 42 | dist_ap = dist_ap.squeeze(1) 43 | dist_an = dist_an.squeeze(1) 44 | 45 | if return_inds: 46 | # shape [N, N] 47 | ind = (labels.new().resize_as_(labels) 48 | .copy_(torch.arange(0, N).long()) 49 | .unsqueeze(0).expand(N, N)) 50 | # shape [N, 1] 51 | p_inds = torch.gather( 52 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 53 | n_inds = torch.gather( 54 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 55 | # shape [N] 56 | p_inds = p_inds.squeeze(1) 57 | n_inds = n_inds.squeeze(1) 58 | return dist_ap, dist_an, p_inds, n_inds 59 | 60 | return dist_ap, dist_an 61 | 62 | class TripletLoss(object): 63 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 64 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 65 | Loss for Person Re-Identification'.""" 66 | 67 | def __init__(self, margin=None): 68 | self.margin = margin 69 | 70 | if margin is not None: 71 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 72 | else: 73 | self.ranking_loss = nn.SoftMarginLoss() 74 | 75 | def __call__(self, global_feat, labels, normalize_feature=False): 76 | 77 | if normalize_feature: 78 | global_feat = l2_norm(global_feat) 79 | 80 | dist_mat = euclidean_dist(global_feat, global_feat) 81 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 82 | 83 | y = dist_an.new().resize_as_(dist_an).fill_(1) 84 | 85 | if self.margin is not None: 86 | loss = self.ranking_loss(dist_an, dist_ap, y) 87 | else: 88 | loss = self.ranking_loss(dist_an - dist_ap, y) 89 | 90 | return loss 91 | 92 | def softmax_loss(results, labels, ignore_index = -1): 93 | labels = labels.view(-1) 94 | loss = F.cross_entropy(results, labels, reduction='mean', ignore_index=ignore_index) 95 | return loss 96 | 97 | def focal_loss(input, target, OHEM_percent=None): 98 | gamma = 2 99 | assert target.size() == input.size() 100 | 101 | max_val = (-input).clamp(min=0) 102 | loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log() 103 | invprobs = F.logsigmoid(-input * (target * 2 - 1)) 104 | loss = (invprobs * gamma).exp() * loss 105 | 106 | if OHEM_percent is None: 107 | return loss.mean() 108 | else: 109 | OHEM, _ = loss.topk(k=int(10008 * OHEM_percent), dim=1, largest=True, sorted=True) 110 | return OHEM.mean() 111 | 112 | def bce_loss(input, target, OHEM_percent=None): 113 | if OHEM_percent is None: 114 | loss = F.binary_cross_entropy_with_logits(input, target, reduce=True) 115 | return loss 116 | else: 117 | loss = F.binary_cross_entropy_with_logits(input, target, reduce=False) 118 | value, index= loss.topk(int(10008 * OHEM_percent), dim=1, largest=True, sorted=True) 119 | return value.mean() 120 | 121 | def focal_OHEM(results, labels, labels_onehot, OHEM_percent=100): 122 | batch_size, class_num = results.shape 123 | labels = labels.view(-1) 124 | loss0 = bce_loss(results, labels_onehot, OHEM_percent) 125 | loss1 = focal_loss(results, labels_onehot, OHEM_percent) 126 | indexs_ = (labels != class_num).nonzero().view(-1) 127 | if len(indexs_) == 0: 128 | return loss0 + loss1 129 | results_ = results[torch.arange(0,len(results))[indexs_],labels[indexs_]].contiguous() 130 | loss2 = focal_loss(results_, torch.ones_like(results_).float().cuda()) 131 | return loss0 + loss1 + loss2 -------------------------------------------------------------------------------- /loss/rate.py: -------------------------------------------------------------------------------- 1 | # learning rate schduler 2 | from whale.net.include import * 3 | 4 | # http://elgoacademy.org/anatomy-matplotlib-part-1/ 5 | def plot_rates(fig, lrs, title=''): 6 | N = len(lrs) 7 | epoches = np.arange(0,N) 8 | 9 | #get limits 10 | max_lr = np.max(lrs) 11 | xmin=0 12 | xmax=N 13 | dx=2 14 | 15 | ymin=0 16 | ymax=max_lr*1.2 17 | dy=(ymax-ymin)/10 18 | dy=10**math.ceil(math.log10(dy)) 19 | 20 | ax = fig.add_subplot(111) 21 | #ax = fig.gca() 22 | ax.set_axisbelow(True) 23 | ax.minorticks_on() 24 | ax.set_xticks(np.arange(xmin,xmax+0.0001, dx)) 25 | ax.set_yticks(np.arange(ymin,ymax+0.0001, dy)) 26 | ax.set_xlim(xmin,xmax+0.0001) 27 | ax.set_ylim(ymin,ymax+0.0001) 28 | ax.grid(b=True, which='minor', color='black', alpha=0.1, linestyle='dashed') 29 | ax.grid(b=True, which='major', color='black', alpha=0.4, linestyle='dashed') 30 | 31 | ax.set_xlabel('iter') 32 | ax.set_ylabel('learning rate') 33 | ax.set_title(title) 34 | ax.plot(epoches, lrs) 35 | 36 | 37 | 38 | ## simple stepping rates 39 | class StepScheduler(): 40 | def __init__(self, pairs): 41 | super(StepScheduler, self).__init__() 42 | 43 | N=len(pairs) 44 | rates=[] 45 | steps=[] 46 | for n in range(N): 47 | steps.append(pairs[n][0]) 48 | rates.append(pairs[n][1]) 49 | 50 | self.rates = rates 51 | self.steps = steps 52 | 53 | def __call__(self, epoch): 54 | 55 | N = len(self.steps) 56 | lr = -1 57 | for n in range(N): 58 | if epoch >= self.steps[n]: 59 | lr = self.rates[n] 60 | return lr 61 | 62 | def __str__(self): 63 | string = 'Step Learning Rates\n' \ 64 | + 'rates=' + str(['%7.4f' % i for i in self.rates]) + '\n' \ 65 | + 'steps=' + str(['%7.0f' % i for i in self.steps]) + '' 66 | return string 67 | 68 | 69 | ## https://github.com/pytorch/tutorials/blob/master/beginner_source/transfer_learning_tutorial.py 70 | class DecayScheduler(): 71 | def __init__(self, base_lr, decay, step): 72 | # super(DecayScheduler, self).__init__() 73 | self.step = step 74 | self.decay = decay 75 | self.base_lr = base_lr 76 | 77 | def get_rate(self, epoch): 78 | lr = self.base_lr * (self.decay**(epoch // self.step)) 79 | return lr 80 | 81 | def __str__(self): 82 | string = '(Exp) Decay Learning Rates\n' \ 83 | + 'base_lr=%0.3f, decay=%0.3f, step=%0.3f'%(self.base_lr, self.decay, self.step) 84 | return string 85 | 86 | 87 | # 'Cyclical Learning Rates for Training Neural Networks'- Leslie N. Smith, arxiv 2017 88 | # https://arxiv.org/abs/1506.01186 89 | # https://github.com/bckenstler/CLR 90 | 91 | class CyclicScheduler0(): 92 | 93 | def __init__(self, min_lr=0.001, max_lr=0.01, period=10 ): 94 | super(CyclicScheduler, self).__init__() 95 | 96 | self.min_lr = min_lr 97 | self.max_lr = max_lr 98 | self.period = period 99 | 100 | def __call__(self, time): 101 | 102 | #sawtooth 103 | #r = (1-(time%self.period)/self.period) 104 | 105 | #cosine 106 | time= time%self.period 107 | r = (np.cos(time/self.period *PI)+1)/2 108 | 109 | lr = self.min_lr + r*(self.max_lr-self.min_lr) 110 | return lr 111 | 112 | def __str__(self): 113 | string = 'CyclicScheduler\n' \ 114 | + 'min_lr=%0.3f, max_lr=%0.3f, period=%8.1f'%(self.min_lr, self.max_lr, self.period) 115 | return string 116 | 117 | 118 | class CyclicScheduler1(): 119 | 120 | def __init__(self, min_lr=0.001, max_lr=0.01, period=10, max_decay=0.99, warm_start=0 ): 121 | super(CyclicScheduler, self).__init__() 122 | 123 | self.min_lr = min_lr 124 | self.max_lr = max_lr 125 | self.period = period 126 | self.max_decay = max_decay 127 | self.warm_start = warm_start 128 | self.cycle = -1 129 | 130 | def __call__(self, time): 131 | if time\n') 181 | log.write('\t ... xxx baseline ... \n') 182 | log.write('\n') 183 | 184 | ## dataset ---------------------------------------- 185 | log.write('** dataset setting **\n') 186 | assert(len(train_dataset)>=batch_size) 187 | log.write('batch_size = %d\n'%(batch_size)) 188 | log.write('train_dataset : \n%s\n'%(train_dataset)) 189 | log.write('valid_dataset : \n%s\n'%(valid_dataset)) 190 | log.write('\n') 191 | 192 | ## net ---------------------------------------- 193 | log.write('** net setting **\n') 194 | if initial_checkpoint is not None: 195 | log.write('\tinitial_checkpoint = %s\n' % initial_checkpoint) 196 | net.load_state_dict(torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)) 197 | print('\tinitial_checkpoint = %s\n' % initial_checkpoint) 198 | 199 | log.write('%s\n'%(type(net))) 200 | log.write('\n') 201 | 202 | if config.bias_no_decay: 203 | print('bias no decay !!!!!!!!!!!!!!!!!!') 204 | train_params = split_weights(net) 205 | else: 206 | train_params = filter(lambda p: p.requires_grad, net.parameters()) 207 | 208 | optimizer = torch.optim.Adam(train_params, lr=base_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0002) 209 | 210 | iter_smooth = 20 211 | start_iter = 0 212 | 213 | log.write('\n') 214 | ## start training here! ############################################## 215 | log.write('** top_n step 100,60,60,60 **\n') 216 | log.write('** start training here! **\n') 217 | log.write(' |------------ VALID -------------|-------- TRAIN/BATCH ----------| \n') 218 | log.write('rate iter epoch | loss acc-1 acc-5 lb | loss acc-1 acc-5 lb | time \n') 219 | log.write('----------------------------------------------------------------------------------------------------\n') 220 | 221 | print('** start training here! **\n') 222 | print(' |------------ VALID -------------|-------- TRAIN/BATCH ----------| \n') 223 | print('rate iter epoch | loss acc-1 acc-5 lb | loss acc-1 acc-5 lb | time \n') 224 | print('----------------------------------------------------------------------------------------------------\n') 225 | 226 | valid_loss = np.zeros(6,np.float32) 227 | batch_loss = np.zeros(6,np.float32) 228 | 229 | i = 0 230 | start = timer() 231 | max_valid = 0 232 | 233 | for epoch in range(config.train_epoch): 234 | sum_train_loss = np.zeros(6,np.float32) 235 | sum = 0 236 | optimizer.zero_grad() 237 | 238 | rate, hard_ratio = adjust_lr_and_hard_ratio(optimizer, epoch + 1) 239 | print('change lr: '+str(rate)) 240 | 241 | for input, truth_, _ in train_loader: 242 | iter = i + start_iter 243 | # one iteration update ------------- 244 | net.train() 245 | input = input.cuda() 246 | truth_ = truth_.cuda() 247 | 248 | input = to_var(input) 249 | truth_ = to_var(truth_) 250 | 251 | if config.is_arc: 252 | logit = net(input, truth_) 253 | else: 254 | logit = net(input) 255 | 256 | loss = softmax_loss(logit, truth_) * config.softmax_w 257 | _, precision = metric(logit, truth_) 258 | 259 | loss.backward() 260 | optimizer.step() 261 | optimizer.zero_grad() 262 | 263 | batch_loss[:4] = np.array((loss.data.cpu().numpy(), 264 | precision[0].data.cpu().numpy(), 265 | precision[2].data.cpu().numpy(), 266 | loss.data.cpu().numpy())).reshape([4]) 267 | 268 | sum_train_loss += batch_loss 269 | sum += 1 270 | 271 | if iter%iter_smooth == 0: 272 | sum_train_loss = np.zeros(6,np.float32) 273 | sum = 0 274 | 275 | if i % 10 == 0: 276 | print(model_name + ' %0.7f %5.1f %6.1f | %0.3f %0.3f %0.3f (%0.3f)%s | %0.3f %0.3f %0.3f (%0.3f) | %s' % (\ 277 | rate, iter, epoch, 278 | valid_loss[0], valid_loss[1], valid_loss[2], valid_loss[3],' ', 279 | batch_loss[0], batch_loss[1], batch_loss[2], batch_loss[3], 280 | time_to_str((timer() - start),'min'))) 281 | 282 | if i % 100 == 0: 283 | log.write('%0.7f %5.1f %6.1f | %0.3f %0.3f %0.3f (%0.3f)%s | %0.3f %0.3f %0.3f (%0.3f) | %s' % (\ 284 | rate, iter, epoch, 285 | valid_loss[0], valid_loss[1], valid_loss[2], valid_loss[3],' ', 286 | batch_loss[0], batch_loss[1], batch_loss[2], batch_loss[3], 287 | time_to_str((timer() - start),'min'))) 288 | 289 | log.write('\n') 290 | 291 | i=i+1 292 | if i%1000==0: 293 | net.eval() 294 | valid_loss = do_valid(net, valid_loader) 295 | net.train() 296 | 297 | if max_valid < valid_loss[2]: 298 | max_valid = valid_loss[2] 299 | print('save max valid!!!!!! : ' + str(max_valid)) 300 | log.write('save max valid!!!!!! : ' + str(max_valid)) 301 | log.write('\n') 302 | torch.save(net.state_dict(), out_dir + '/checkpoint/max_valid_model.pth') 303 | 304 | if (epoch+1) % config.iter_save_interval ==0 and epoch>0: 305 | torch.save(net.state_dict(), out_dir + '/checkpoint/%08d_model.pth' % (epoch)) 306 | 307 | net.eval() 308 | valid_loss = do_valid(net, valid_loader) 309 | net.train() 310 | 311 | if max_valid < valid_loss[2]: 312 | max_valid = valid_loss[2] 313 | print('save max valid!!!!!! : ' + str(max_valid)) 314 | log.write('save max valid!!!!!! : ' + str(max_valid)) 315 | log.write('\n') 316 | torch.save(net.state_dict(), out_dir + '/checkpoint/max_valid_model.pth') 317 | 318 | def run_finetune_CELL_cyclic_semi_final(config, 319 | cell, 320 | tag = None, 321 | fold_index = 0, 322 | epoch_ratio = 1.0, 323 | online_pseudo = True): 324 | 325 | if config.rgb: 326 | model = config.model + '_rgb' 327 | else: 328 | model = config.model + '_6channel' 329 | 330 | model_name = model + '_' + str(config.image_size)+'_fold'+str(config.train_fold_index)+'_'+config.tag 331 | 332 | batch_size = config.batch_size 333 | ## setup ----------------------------------------------------------------------------- 334 | out_dir = os.path.join(model_save_path, model_name) 335 | print(out_dir) 336 | 337 | if config.pretrained_model is not None: 338 | initial_checkpoint = os.path.join(out_dir, 'checkpoint', config.pretrained_model) 339 | else: 340 | initial_checkpoint = None 341 | 342 | train_dataset = Dataset_(path_data + r'/train_fold_'+str(fold_index)+'.csv', 343 | img_dir=data_dir, 344 | mode='train', 345 | image_size=config.image_size, 346 | rgb=config.rgb, 347 | cell = cell, 348 | augment=[0,0,0], 349 | is_TransTwice=True) 350 | 351 | semi_valid_dataset = Dataset_(path_data + r'/valid_fold_'+str(fold_index)+'.csv', 352 | img_dir=data_dir, 353 | mode='semi_valid', 354 | image_size=config.image_size, 355 | rgb=config.rgb, 356 | cell=cell, 357 | is_TransTwice=True) 358 | 359 | infer_dataset = Dataset_(data_dir + '/test.csv', 360 | data_dir, 361 | image_size=config.image_size, 362 | mode='semi_test', 363 | rgb=config.rgb, 364 | cell=cell, 365 | is_TransTwice=True) 366 | 367 | train_loader = DataLoader(train_dataset + semi_valid_dataset + infer_dataset, 368 | shuffle = True, 369 | batch_size = batch_size, 370 | drop_last = True, 371 | num_workers = 8, 372 | pin_memory = True) 373 | 374 | pseudo_loader = DataLoader(semi_valid_dataset + infer_dataset, 375 | shuffle = False, 376 | batch_size = batch_size, 377 | drop_last = False, 378 | num_workers = 8, 379 | pin_memory = True) 380 | 381 | valid_dataset = Dataset_(path_data + r'/valid_fold_'+str(fold_index)+'.csv', 382 | img_dir=data_dir, 383 | mode='valid', 384 | image_size=config.image_size, 385 | rgb=config.rgb, 386 | cell=cell, 387 | augment=[0,0,0]) 388 | 389 | valid_loader = DataLoader(valid_dataset, 390 | shuffle = False, 391 | batch_size = batch_size, 392 | drop_last = False, 393 | num_workers = 8, 394 | pin_memory = True) 395 | 396 | out_dir = os.path.join(out_dir, 'checkpoint') 397 | if not os.path.exists(out_dir): 398 | os.makedirs(out_dir) 399 | 400 | if tag is not None: 401 | cell += '_' + tag 402 | 403 | log = open(out_dir+'/log.'+cell+'_finetune_train.txt', mode='a') 404 | log.write('\t__file__ = %s\n' % __file__) 405 | log.write('\tout_dir = %s\n' % out_dir) 406 | log.write('\n') 407 | log.write('\t\n') 408 | log.write('\t ... xxx baseline ... \n') 409 | log.write('\n') 410 | 411 | ## dataset ---------------------------------------- 412 | log.write('** dataset setting **\n') 413 | assert(len(train_dataset)>=batch_size) 414 | log.write('batch_size = %d\n'%(batch_size)) 415 | log.write('train_dataset : \n%s\n'%(train_dataset)) 416 | log.write('valid_dataset : \n%s\n'%(valid_dataset)) 417 | log.write('\n') 418 | 419 | net = get_model(model) 420 | ema_net = get_model(model) 421 | 422 | for param in ema_net.parameters(): 423 | param.detach_() 424 | 425 | ## net ---------------------------------------- 426 | log.write('** net setting **\n') 427 | if initial_checkpoint is not None: 428 | log.write('\tinitial_checkpoint = %s\n' % initial_checkpoint) 429 | net.load_pretrain(initial_checkpoint, is_skip_fc=True) 430 | ema_net.load_pretrain(initial_checkpoint, is_skip_fc=False) 431 | print('\tinitial_checkpoint = %s\n' % initial_checkpoint) 432 | 433 | ## optimiser ---------------------------------- 434 | net = torch.nn.DataParallel(net) 435 | net = net.cuda() 436 | ema_net = torch.nn.DataParallel(ema_net) 437 | ema_net = ema_net.cuda() 438 | 439 | log.write('%s\n'%(type(net))) 440 | log.write('\n') 441 | 442 | if config.bias_no_decay: 443 | print('bias no decay') 444 | train_params = split_weights(net) 445 | else: 446 | train_params = filter(lambda p: p.requires_grad, net.parameters()) 447 | 448 | iter_smooth = 20 449 | start_iter = 0 450 | 451 | log.write('\n') 452 | ## start training here! ############################################## 453 | log.write('** top_n step 100,60,60,60 **\n') 454 | log.write('** start training here! **\n') 455 | log.write(' |------------ VALID -------------|-------- TRAIN/BATCH ----------| \n') 456 | log.write('rate iter epoch | loss acc-1 acc-5 lb | loss acc-1 acc-5 lb | time \n') 457 | log.write('----------------------------------------------------------------------------------------------------\n') 458 | 459 | print('** start training here! **\n') 460 | print(' |------------ VALID -------------|-------- TRAIN/BATCH ----------| \n') 461 | print('rate iter epoch | loss acc-1 acc-5 lb | loss acc-1 acc-5 lb | time \n') 462 | print('----------------------------------------------------------------------------------------------------\n') 463 | 464 | ##### pay attention to this 465 | all_semi_ids = list(semi_valid_dataset.records['id_code']) + list(infer_dataset.records['id_code']) 466 | pseudo_logit = np.zeros([len(infer_dataset) + len(semi_valid_dataset), 1108]) 467 | pseudo_labels = np.zeros([len(infer_dataset) + len(semi_valid_dataset), 1]) 468 | 469 | def get_pseudo_labels(pseudo_labels, ids, all_semi_ids): 470 | indexs = [all_semi_ids.index(tmp) for tmp in ids] 471 | labels = np.argmax(pseudo_labels, axis=1) 472 | batch_labels = torch.LongTensor(np.asarray(labels[indexs])) 473 | return batch_labels 474 | 475 | def update_pseudo_labels(pseudo_logit, ids, ratio = 0.9, tta_num = 1, iters = 3000): 476 | probs_all = np.zeros([len(infer_dataset) + len(valid_dataset), 1108]) 477 | 478 | for i in range(tta_num): 479 | pseudo_num = 0 480 | probs = [] 481 | 482 | with torch.no_grad(): 483 | for _, input, _, _, _ in pseudo_loader: 484 | input = input.cuda() 485 | input = to_var(input) 486 | logit = ema_net.forward(input) 487 | logit = logit[:, 0: 1108] 488 | probs.append(logit) 489 | pseudo_num += len(input) 490 | 491 | assert (pseudo_num == len(pseudo_loader.sampler)) 492 | probs_ = torch.cat(probs).cpu().numpy().reshape([pseudo_num, 1108]) 493 | probs_all += probs_ / tta_num 494 | 495 | pseudo_logit = pseudo_logit * (1-ratio) + probs_all * ratio 496 | pseudo_labels, ids = balance_plate_probability_training(pseudo_logit, ids, plate_dict, a_dict, iters) 497 | return pseudo_logit, pseudo_labels 498 | 499 | i = 0 500 | start = timer() 501 | 502 | base_lr = 1e-1 503 | optimizer = torch.optim.SGD(train_params, lr=base_lr, weight_decay=0.0002) 504 | 505 | cycle_inter = int(40 * epoch_ratio) 506 | cycle_num = 4 507 | 508 | sgdr = CosineAnnealingLR_with_Restart(optimizer, 509 | T_max=cycle_inter, 510 | T_mult=1, 511 | model=net, 512 | out_dir='../input/', 513 | take_snapshot=False, 514 | eta_min=1e-3) 515 | 516 | last_lr = 1e-4 517 | last_epoch = 5 518 | 519 | epoch_all = 0 520 | epoch_stop = cycle_inter * cycle_num + last_epoch 521 | 522 | max_valid_all = 0 523 | max_valid_ema = 0 524 | valid_loss_ema = np.zeros(6, np.float32) 525 | 526 | for cycle_index in range(cycle_num+1): 527 | print('cycle index: ' + str(cycle_index)) 528 | 529 | max_valid = 0 530 | valid_loss = np.zeros(6, np.float32) 531 | batch_loss = np.zeros(6, np.float32) 532 | 533 | for epoch in range(cycle_inter): 534 | epoch_all += 1 535 | 536 | if epoch_all > epoch_stop: 537 | return 538 | 539 | elif epoch_all > cycle_inter * cycle_num: 540 | for p in optimizer.param_groups: 541 | p['lr'] = last_lr 542 | rate = last_lr 543 | else: 544 | sgdr.step() 545 | rate = optimizer.param_groups[0]['lr'] 546 | 547 | print('change lr: ' + str(rate)) 548 | 549 | net.eval() 550 | valid_loss = do_valid(net, valid_loader) 551 | net.train() 552 | 553 | ema_net.eval() 554 | valid_loss_ema = do_valid(ema_net, valid_loader) 555 | ema_net.train() 556 | 557 | if max_valid < valid_loss[2] and epoch > 0 : 558 | max_valid = valid_loss[2] 559 | print('save max valid!!!!!! : ' + str(max_valid)) 560 | log.write('save max valid!!!!!! : ' + str(max_valid)) 561 | log.write('\n') 562 | 563 | if max_valid_all < valid_loss[2] and epoch > 0 : 564 | max_valid_all = valid_loss[2] 565 | print('save max valid all!!!!!! : ' + str(max_valid_all)) 566 | log.write('save max valid all!!!!!! : ' + str(max_valid_all)) 567 | log.write('\n') 568 | torch.save(net.state_dict(), out_dir + '/max_valid_model_' + cell + 569 | '_snapshot_all.pth') 570 | 571 | if epoch_all == (10 * 2 * epoch_ratio - 1): 572 | iters_balance = 500 573 | pseudo_logit, pseudo_labels = update_pseudo_labels(pseudo_logit, 574 | all_semi_ids, 575 | ratio=1.0, 576 | tta_num=4, 577 | iters=iters_balance) 578 | 579 | 580 | if max_valid_ema < valid_loss_ema[2] and epoch > 0 : 581 | max_valid_ema = valid_loss_ema[2] 582 | print('save max valid ema!!!!!! : ' + str(max_valid_ema)) 583 | log.write('save max valid ema!!!!!! : ' + str(max_valid_ema)) 584 | log.write('\n') 585 | torch.save(net.state_dict(), out_dir + '/max_valid_model_' + cell + 586 | '_snapshot_ema.pth') 587 | 588 | if online_pseudo: 589 | print('update_pseudo_labels') 590 | if epoch_all > (10*2*epoch_ratio-1): 591 | iters_balance = 500 592 | pseudo_logit, pseudo_labels = update_pseudo_labels(pseudo_logit, 593 | all_semi_ids, 594 | ratio=0.9, 595 | tta_num=4, 596 | iters=iters_balance) 597 | 598 | elif epoch_all > (20*2*epoch_ratio-1): 599 | iters_balance = 1000 600 | pseudo_logit, pseudo_labels = update_pseudo_labels(pseudo_logit, 601 | all_semi_ids, 602 | ratio=0.9, 603 | tta_num=4, 604 | iters=iters_balance) 605 | 606 | elif epoch_all > (40*2*epoch_ratio-1): 607 | iters_balance = 3000 608 | pseudo_logit, pseudo_labels = update_pseudo_labels(pseudo_logit, 609 | all_semi_ids, 610 | ratio=0.9, 611 | tta_num=8, 612 | iters=iters_balance) 613 | 614 | 615 | sum_train_loss = np.zeros(6,np.float32) 616 | sum = 0 617 | optimizer.zero_grad() 618 | 619 | for input, input_easy, input_hard, truth_, id in train_loader: 620 | iter = i + start_iter 621 | 622 | # one iteration update ------------- 623 | net.train() 624 | input = input.cuda() 625 | input_easy = input_easy.cuda() 626 | input_hard = input_hard.cuda() 627 | truth_ = truth_.cuda() 628 | 629 | input = to_var(input) 630 | input_easy = to_var(input_easy) 631 | input_hard = to_var(input_hard) 632 | truth_ = to_var(truth_) 633 | 634 | indexs_supervised = (truth_ != -1).nonzero().view(-1) 635 | indexs_semi = (truth_ == -1).nonzero().view(-1) 636 | 637 | if len(indexs_semi) == 0 or len(indexs_supervised) == 0 : 638 | continue 639 | 640 | if config.is_arc: 641 | logit_arc = net(input[indexs_supervised], truth_[indexs_supervised]) 642 | logit = net(input_easy) 643 | ema_logit = ema_net(input_hard) 644 | 645 | ##################### consistency loss ######################################################################## 646 | consistency = 100.0 647 | consistency_rp = 5 648 | 649 | if consistency: 650 | consistency_weight = get_current_consistency_weight(epoch_all, consistency, consistency_rp) 651 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 652 | consistency_loss = consistency_weight * softmax_mse_loss(logit, ema_logit) / batch_size 653 | else: 654 | consistency_loss = 0 655 | 656 | ##################### online_pseudo label ##################################################################### 657 | if online_pseudo: 658 | if epoch_all < (10*2*epoch_ratio): 659 | weight = 0.0 660 | else: 661 | weight = 0.05 662 | 663 | id_smi = [id[index] for index in list(indexs_semi.cpu().numpy().reshape([-1]))] 664 | id_smi_pseudo = get_pseudo_labels(pseudo_labels, id_smi, all_semi_ids).cuda() 665 | id_smi_pseudo = to_var(id_smi_pseudo) 666 | 667 | pseudo_loss = softmax_loss(logit[indexs_semi], id_smi_pseudo) * weight 668 | pseudo_loss_log = pseudo_loss.data.cpu().numpy() 669 | supervised_loss = softmax_loss(logit_arc, truth_[indexs_supervised]) * (1.0 - weight) 670 | 671 | else: 672 | pseudo_loss = 0.0 673 | pseudo_loss_log = 0.0 674 | supervised_loss = softmax_loss(logit_arc, truth_[indexs_supervised]) 675 | 676 | ##################### loss #################################################################################### 677 | loss = supervised_loss + consistency_loss + pseudo_loss 678 | 679 | _, precision = metric(logit_arc, truth_[indexs_supervised]) 680 | 681 | loss.backward() 682 | optimizer.step() 683 | optimizer.zero_grad() 684 | 685 | batch_loss[:4] = np.array((precision[0].data.cpu().numpy(), 686 | supervised_loss.data.cpu().numpy(), 687 | pseudo_loss_log, 688 | consistency_loss.data.cpu().numpy())).reshape([4]) 689 | 690 | sum_train_loss += batch_loss 691 | sum += 1 692 | 693 | if epoch_all > 20: 694 | alpha_ema = 0.999 695 | elif epoch_all > 10: 696 | alpha_ema = 0.99 697 | elif epoch_all > 1: 698 | alpha_ema = 0.9 699 | else: 700 | alpha_ema = 0.5 701 | 702 | update_ema_variables(net, ema_net, alpha_ema, i) 703 | 704 | if iter%iter_smooth == 0: 705 | sum_train_loss = np.zeros(6,np.float32) 706 | sum = 0 707 | 708 | if i % 10 == 0: 709 | print(model_name +' finetune '+ cell +'%6.1f %0.7f %5.1f %6.1f | %0.3f %0.3f %0.3f (%0.3f)%s | %0.3f %0.3f %0.3f (%0.3f)%s | %0.3f %0.3f %0.3f (%0.3f) | %s' % (\ 710 | cycle_index, rate, iter, epoch, 711 | valid_loss[0], valid_loss[1], valid_loss[2], valid_loss[3],' ', 712 | valid_loss_ema[0], valid_loss_ema[1], valid_loss_ema[2], valid_loss_ema[3], ' ', 713 | batch_loss[0], batch_loss[1], batch_loss[2], batch_loss[3], 714 | time_to_str((timer() - start),'min'))) 715 | 716 | 717 | 718 | if i % 100 == 0: 719 | log.write('%6.1f %0.7f %5.1f %6.1f | %0.3f %0.3f %0.3f (%0.3f)%s |%0.3f %0.3f %0.3f (%0.3f)%s | %0.3f %0.3f %0.3f (%0.3f) | %s' % (\ 720 | cycle_index, rate, iter, epoch, 721 | valid_loss[0], valid_loss[1], valid_loss[2], valid_loss[3],' ', 722 | valid_loss_ema[0], valid_loss_ema[1], valid_loss_ema[2], valid_loss_ema[3], ' ', 723 | batch_loss[0], batch_loss[1], batch_loss[2], batch_loss[3], 724 | time_to_str((timer() - start),'min'))) 725 | 726 | log.write('\n') 727 | 728 | i=i+1 729 | 730 | def run_infer_oof_finetuned(config, cell, initial_checkpoint): 731 | 732 | if config.rgb: 733 | model = config.model + '_rgb' 734 | else: 735 | model = config.model + '_6channel' 736 | 737 | batch_size = config.batch_size 738 | ## setup ----------------------------------------------------------------------------- 739 | net = get_model(model) 740 | net = torch.nn.DataParallel(net) 741 | 742 | if initial_checkpoint is not None: 743 | print(initial_checkpoint) 744 | net.load_state_dict(torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)) 745 | 746 | net = net.cuda() 747 | net.eval() 748 | 749 | augments = [] 750 | # 8TTA 751 | augments.append([0, 0, 0]) 752 | augments.append([0, 0, 1]) 753 | augments.append([0, 1, 0]) 754 | augments.append([0, 1, 1]) 755 | augments.append([1, 0, 0]) 756 | augments.append([1, 0, 1]) 757 | augments.append([1, 1, 0]) 758 | augments.append([1, 1, 1]) 759 | 760 | predict_num = 1108 761 | 762 | def get_valid_label(fold=0): 763 | df = pd.read_csv(path_data + '/valid_fold_' + str(fold) + '.csv') 764 | if cell is not None: 765 | df_all = [] 766 | for i in range(100): 767 | df_ = pd.DataFrame(df.loc[df['experiment'] == (cell + '-' + str(100 + i)[-2:])]) 768 | df_all.append(df_) 769 | df = pd.concat(df_all) 770 | 771 | val_label = np.asarray(list(df[r'sirna'])).reshape([-1]) 772 | return val_label, list(df[r'id_code']) 773 | 774 | val_label, ids = get_valid_label(fold=config.train_fold_index) 775 | 776 | #####################################infer valid############################################# 777 | probs_all = [] 778 | for index in range(len(augments)): 779 | for site in range(2): 780 | 781 | valid_dataset = Dataset_(path_data + '/valid_fold_' + str(config.train_fold_index) + '.csv', 782 | data_dir, 783 | mode='valid', 784 | image_size=config.image_size, 785 | site=site+1, 786 | rgb=config.rgb, 787 | cell=cell, 788 | augment=augments[index]) 789 | 790 | valid_loader = DataLoaderX(valid_dataset, 791 | shuffle=False, 792 | batch_size=batch_size, 793 | drop_last=False, 794 | num_workers=4, 795 | pin_memory=True) 796 | 797 | # infer test 798 | test_ids = [] 799 | probs = [] 800 | from tqdm import tqdm 801 | for i,(input, label, id) in enumerate(tqdm(valid_loader)): 802 | test_ids += id 803 | input = input.cuda() 804 | input = to_var(input) 805 | logit = net.forward(input) 806 | logit = logit[:, 0:predict_num] 807 | 808 | prob = logit 809 | probs += prob.data.cpu().numpy().tolist() 810 | 811 | probs = np.asarray(probs) 812 | 813 | val_ = np.argmax(probs, axis=1).reshape([-1]) 814 | print(np.mean(val_ == val_label)) 815 | 816 | probs_all.append(probs) 817 | 818 | probs_all = np.mean(probs_all,axis=0) 819 | print(probs_all.shape) 820 | 821 | val_ = np.argmax(probs_all, axis=1).reshape([-1]) 822 | print(np.mean(val_ == val_label)) 823 | 824 | save_path = initial_checkpoint.replace('.pth', '_npy') 825 | if not os.path.exists(save_path): 826 | os.makedirs(save_path) 827 | 828 | save_path = os.path.join(save_path, cell+'_val.npy') 829 | print(save_path) 830 | np.save(save_path, probs_all) 831 | 832 | # return 833 | #####################################infer test############################################# 834 | probs_all = [] 835 | for index in range(len(augments)): 836 | for site in range(2): 837 | valid_dataset = Dataset_(data_dir+'/test.csv', 838 | data_dir, 839 | mode='test', 840 | image_size=config.image_size, 841 | site=site+1, 842 | rgb=config.rgb, 843 | cell=cell, 844 | augment=augments[index]) 845 | 846 | valid_loader = DataLoaderX(valid_dataset, 847 | shuffle=False, 848 | batch_size=batch_size, 849 | drop_last=False, 850 | num_workers=4, 851 | pin_memory=True) 852 | 853 | # infer test 854 | test_ids = [] 855 | probs = [] 856 | from tqdm import tqdm 857 | for i,(id, input) in enumerate(tqdm(valid_loader)): 858 | test_ids += id 859 | input = input.cuda() 860 | input = to_var(input) 861 | logit = net.forward(input) 862 | logit = logit[:, 0:predict_num] 863 | 864 | prob = logit 865 | probs += prob.data.cpu().numpy().tolist() 866 | 867 | probs = np.asarray(probs) 868 | probs_all.append(probs) 869 | 870 | probs_all = np.mean(probs_all,axis=0) 871 | print(probs_all.shape) 872 | 873 | save_path = initial_checkpoint.replace('.pth', '_npy') 874 | if not os.path.exists(save_path): 875 | os.makedirs(save_path) 876 | 877 | save_path = os.path.join(save_path, cell + '_test.npy') 878 | print(save_path) 879 | np.save(save_path, probs_all) 880 | 881 | def main(config): 882 | 883 | if config.mode == 'pretrain': 884 | # pretrain on all cell types 885 | run_pretrain(config) 886 | 887 | 888 | elif config.mode == 'semi_finetune': 889 | # semi-supervised learning finetune on each cell type 890 | cell_types = ['U2OS','RPE','HEPG2','HUVEC'] 891 | epoch_ratios = [ 1.0, 0.5, 0.5, 0.25] 892 | 893 | for cell, epoch_ratio in zip(cell_types, epoch_ratios): 894 | config.pretrained_model = r'max_valid_model.pth' 895 | config.finetune_tag = r'semi' 896 | run_finetune_CELL_cyclic_semi_final(config, 897 | cell = cell, 898 | fold_index = config.train_fold_index, 899 | tag = config.finetune_tag, 900 | epoch_ratio = epoch_ratio, 901 | online_pseudo = True) 902 | 903 | 904 | elif config.mode == 'infer': 905 | 906 | if config.rgb: 907 | model = config.model + '_rgb' 908 | else: 909 | model = config.model + '_6channel' 910 | 911 | cell_types = ['U2OS','RPE','HEPG2','HUVEC'] 912 | for cell in cell_types: 913 | model_name = model + '_' + str(config.image_size) + '_fold' + str(config.train_fold_index) + '_' + config.tag 914 | out_dir = os.path.join(model_save_path, model_name) 915 | initial_checkpoint = os.path.join(out_dir, 'checkpoint', 'max_valid_model_' + cell + '_semi_snapshot_all.pth') 916 | 917 | with torch.no_grad(): 918 | run_infer_oof_finetuned(config, cell=cell, initial_checkpoint=initial_checkpoint) 919 | 920 | if __name__ == '__main__': 921 | 922 | parser = argparse.ArgumentParser() 923 | parser.add_argument('--tag', type=str, default='final') 924 | parser.add_argument('--train_fold_index', type=int, default = 0) 925 | parser.add_argument('--rgb', type=bool, default=False) 926 | parser.add_argument('--is_arc', type=bool, default=True) 927 | parser.add_argument('--arc_s', type=float, default=30.0) 928 | parser.add_argument('--arc_m', type=float, default=0.1) 929 | parser.add_argument('--bias_no_decay', type=bool, default=False) 930 | 931 | parser.add_argument('--model', type=str, default='xception_large') 932 | parser.add_argument('--batch_size', type=int, default=1) 933 | parser.add_argument('--image_size', type=int, default=512) 934 | parser.add_argument('--softmax_w', type=float, default=1.0) 935 | 936 | parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'semi_finetune', 'infer']) 937 | parser.add_argument('--pretrained_model', type=str, default=None) 938 | 939 | parser.add_argument('--iter_save_interval', type=int, default=10) 940 | parser.add_argument('--train_epoch', type=int, default=80) 941 | 942 | config = parser.parse_args() 943 | print(config) 944 | main(config) 945 | 946 | 947 | 948 | -------------------------------------------------------------------------------- /net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Scu-sen/Recursion-Cellular-Image-Classification-Challenge/a45e48905b673ea89763f0648623c89fbddb85f1/net/__init__.py -------------------------------------------------------------------------------- /net/archead.py: -------------------------------------------------------------------------------- 1 | import types 2 | import math 3 | #-*-coding:utf-8-*- 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | from torch.nn import Parameter 8 | from torch.autograd import Variable 9 | from torch.autograd import Function 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class ArcModule(nn.Module): 18 | def __init__(self, in_features, out_features, s=65, m=0.5): 19 | super().__init__() 20 | self.in_features = in_features 21 | self.out_features = out_features 22 | self.s = s 23 | self.m = m 24 | self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features)) 25 | nn.init.xavier_normal_(self.weight) 26 | 27 | self.cos_m = math.cos(m) 28 | self.sin_m = math.sin(m) 29 | self.th = math.cos(math.pi - m) 30 | self.mm = math.sin(math.pi - m) * m 31 | 32 | def forward(self, inputs, labels = None): 33 | cos_th = F.linear(inputs, F.normalize(self.weight)) 34 | cos_th = cos_th.clamp(-1, 1) 35 | sin_th = torch.sqrt(1.0 - torch.pow(cos_th, 2)) 36 | cos_th_m = cos_th * self.cos_m - sin_th * self.sin_m 37 | cos_th_m = torch.where(cos_th > self.th, cos_th_m, cos_th - self.mm) 38 | 39 | cond_v = cos_th - self.th 40 | cond = cond_v <= 0 41 | cos_th_m[cond] = (cos_th - self.mm)[cond] 42 | 43 | onehot = torch.zeros(cos_th.size()).cuda() 44 | 45 | if labels is not None: 46 | if labels.dim() == 1: 47 | labels = labels.unsqueeze(-1) 48 | onehot.scatter_(1, labels, 1) 49 | 50 | outputs = onehot * cos_th_m + (1.0 - onehot) * cos_th 51 | outputs = outputs * self.s 52 | return outputs 53 | 54 | def reset_parameters(self): 55 | stdv = 1. / math.sqrt(self.weight.size(1)) 56 | self.weight.data.uniform_(-stdv, stdv) 57 | 58 | 59 | class CombineMarginModule(nn.Module): 60 | def __init__(self, in_features, out_features, s=65, m1 = 1.0, m2 = 0.1, m3 = 0.05): 61 | super().__init__() 62 | self.in_features = in_features 63 | self.out_features = out_features 64 | self.s = s 65 | self.m1 = m1 66 | self.m2 = m2 67 | self.m3 = m3 68 | 69 | self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features)) 70 | nn.init.xavier_normal_(self.weight) 71 | 72 | self.cos_m2 = math.cos(m2) 73 | self.sin_m2 = math.sin(m2) 74 | self.th2 = math.cos(math.pi - m2) 75 | self.mm2 = math.sin(math.pi - m2) * m2 76 | 77 | def forward(self, inputs, labels = None): 78 | cos_th = F.linear(inputs, F.normalize(self.weight)) 79 | cos_th = cos_th.clamp(-1, 1) 80 | sin_th = torch.sqrt(1.0 - torch.pow(cos_th, 2)) 81 | 82 | onehot = torch.zeros(cos_th.size()).cuda() 83 | if labels is not None: 84 | if labels.dim() == 1: 85 | labels = labels.unsqueeze(-1) 86 | onehot.scatter_(1, labels, 1) 87 | 88 | # m2 89 | cos_th_m2 = cos_th * self.cos_m2 - sin_th * self.sin_m2 90 | cos_th_m2 = torch.where(cos_th > self.th2, cos_th_m2, cos_th - self.mm2) 91 | cond_v = cos_th - self.th2 92 | cond = cond_v <= 0 93 | cos_th_m2[cond] = (cos_th - self.mm2)[cond] 94 | 95 | # m3 96 | cos_th_m2_m3 = cos_th_m2 - self.m3 97 | 98 | outputs = onehot * cos_th_m2_m3 + (1.0 - onehot) * cos_th 99 | outputs = outputs * self.s 100 | return outputs 101 | 102 | 103 | -------------------------------------------------------------------------------- /net/imagenet_pretrain_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Scu-sen/Recursion-Cellular-Image-Classification-Challenge/a45e48905b673ea89763f0648623c89fbddb85f1/net/imagenet_pretrain_model/__init__.py -------------------------------------------------------------------------------- /net/imagenet_pretrain_model/include.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | PROJECT_PATH = os.path.dirname(os.path.realpath(__file__)) 4 | IDENTIFIER = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 5 | 6 | 7 | #numerical libs 8 | import math 9 | import numpy as np 10 | import random 11 | import PIL 12 | import cv2 13 | import matplotlib 14 | matplotlib.use('TkAgg') 15 | #matplotlib.use('WXAgg') 16 | #matplotlib.use('Qt4Agg') 17 | #matplotlib.use('Qt5Agg') #Qt4Agg 18 | print(matplotlib.get_backend()) 19 | #print(matplotlib.__version__) 20 | 21 | 22 | # torch libs 23 | import torch 24 | from torch.utils.data.dataset import Dataset 25 | from torch.utils.data import DataLoader 26 | from torch.utils.data.sampler import * 27 | 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import torch.optim as optim 31 | from torch.nn.parallel.data_parallel import data_parallel 32 | 33 | 34 | # std libs 35 | import collections 36 | import copy 37 | import numbers 38 | import inspect 39 | import shutil 40 | from timeit import default_timer as timer 41 | import itertools 42 | from collections import OrderedDict 43 | 44 | import csv 45 | import pandas as pd 46 | import pickle 47 | import glob 48 | import sys 49 | from distutils.dir_util import copy_tree 50 | import time 51 | 52 | import matplotlib.pyplot as plt 53 | from mpl_toolkits.mplot3d import Axes3D 54 | 55 | from skimage.transform import resize as skimage_resize 56 | 57 | 58 | 59 | 60 | # constant # 61 | PI = np.pi 62 | INF = np.inf 63 | EPS = 1e-12 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /net/imagenet_pretrain_model/xception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 3 | 4 | @author: tstandley 5 | Adapted by cadene 6 | 7 | Creates an Xception Model as defined in: 8 | 9 | Francois Chollet 10 | Xception: Deep Learning with Depthwise Separable Convolutions 11 | https://arxiv.org/pdf/1610.02357.pdf 12 | 13 | This weights ported from the Keras implementation. Achieves the following performance on the validation set: 14 | 15 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 16 | 17 | REMEMBER to set your image size to 3x299x299 for both test and validation 18 | 19 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 20 | std=[0.5, 0.5, 0.5]) 21 | 22 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 23 | """ 24 | from __future__ import print_function, division, absolute_import 25 | import math 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | import torch.utils.model_zoo as model_zoo 30 | from torch.nn import init 31 | 32 | __all__ = ['xception', 'xception_large'] 33 | 34 | pretrained_settings = { 35 | 'xception': { 36 | 'imagenet': { 37 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', 38 | 'input_space': 'RGB', 39 | 'input_size': [3, 299, 299], 40 | 'input_range': [0, 1], 41 | 'mean': [0.5, 0.5, 0.5], 42 | 'std': [0.5, 0.5, 0.5], 43 | 'num_classes': 1000, 44 | 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 45 | } 46 | } 47 | } 48 | 49 | 50 | class SeparableConv2d(nn.Module): 51 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 52 | super(SeparableConv2d,self).__init__() 53 | 54 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) 55 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 56 | 57 | def forward(self,x): 58 | x = self.conv1(x) 59 | x = self.pointwise(x) 60 | return x 61 | 62 | 63 | class Block(nn.Module): 64 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): 65 | super(Block, self).__init__() 66 | 67 | if out_filters != in_filters or strides!=1: 68 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 69 | self.skipbn = nn.BatchNorm2d(out_filters) 70 | else: 71 | self.skip=None 72 | 73 | self.relu = nn.ReLU(inplace=True) 74 | rep=[] 75 | 76 | filters=in_filters 77 | if grow_first: 78 | rep.append(self.relu) 79 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 80 | rep.append(nn.BatchNorm2d(out_filters)) 81 | filters = out_filters 82 | 83 | for i in range(reps-1): 84 | rep.append(self.relu) 85 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) 86 | rep.append(nn.BatchNorm2d(filters)) 87 | 88 | if not grow_first: 89 | rep.append(self.relu) 90 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 91 | rep.append(nn.BatchNorm2d(out_filters)) 92 | 93 | if not start_with_relu: 94 | rep = rep[1:] 95 | else: 96 | rep[0] = nn.ReLU(inplace=False) 97 | 98 | if strides != 1: 99 | rep.append(nn.MaxPool2d(3,strides,1)) 100 | self.rep = nn.Sequential(*rep) 101 | 102 | def forward(self,inp): 103 | x = self.rep(inp) 104 | 105 | if self.skip is not None: 106 | skip = self.skip(inp) 107 | skip = self.skipbn(skip) 108 | else: 109 | skip = inp 110 | 111 | x+=skip 112 | return x 113 | 114 | 115 | class Xception(nn.Module): 116 | """ 117 | Xception optimized for the ImageNet dataset, as specified in 118 | https://arxiv.org/pdf/1610.02357.pdf 119 | """ 120 | def __init__(self, num_classes=1000): 121 | """ Constructor 122 | Args: 123 | num_classes: number of classes 124 | """ 125 | super(Xception, self).__init__() 126 | self.num_classes = num_classes 127 | 128 | self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) 129 | self.bn1 = nn.BatchNorm2d(32) 130 | self.relu = nn.ReLU(inplace=True) 131 | 132 | self.conv2 = nn.Conv2d(32,64,3,bias=False) 133 | self.bn2 = nn.BatchNorm2d(64) 134 | #do relu here 135 | 136 | self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True) 137 | self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True) 138 | self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True) 139 | 140 | self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True) 141 | self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True) 142 | self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True) 143 | self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True) 144 | 145 | self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True) 146 | self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True) 147 | self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True) 148 | self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True) 149 | 150 | self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 151 | 152 | self.conv3 = SeparableConv2d(1024,1536,3,1,1) 153 | self.bn3 = nn.BatchNorm2d(1536) 154 | 155 | #do relu here 156 | self.conv4 = SeparableConv2d(1536,2048,3,1,1) 157 | self.bn4 = nn.BatchNorm2d(2048) 158 | 159 | self.fc = nn.Linear(2048, num_classes) 160 | 161 | # #------- init weights -------- 162 | # for m in self.modules(): 163 | # if isinstance(m, nn.Conv2d): 164 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 165 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 166 | # elif isinstance(m, nn.BatchNorm2d): 167 | # m.weight.data.fill_(1) 168 | # m.bias.data.zero_() 169 | # #----------------------------- 170 | 171 | def features(self, input): 172 | x = self.conv1(input) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | 176 | x = self.conv2(x) 177 | x = self.bn2(x) 178 | x = self.relu(x) 179 | 180 | x = self.block1(x) 181 | x = self.block2(x) 182 | x = self.block3(x) 183 | x = self.block4(x) 184 | x = self.block5(x) 185 | x = self.block6(x) 186 | x = self.block7(x) 187 | x = self.block8(x) 188 | x = self.block9(x) 189 | x = self.block10(x) 190 | x = self.block11(x) 191 | x = self.block12(x) 192 | 193 | x = self.conv3(x) 194 | x = self.bn3(x) 195 | x = self.relu(x) 196 | 197 | x = self.conv4(x) 198 | x = self.bn4(x) 199 | return x 200 | 201 | def logits(self, features): 202 | x = self.relu(features) 203 | 204 | x = F.adaptive_avg_pool2d(x, (1, 1)) 205 | x = x.view(x.size(0), -1) 206 | x = self.last_linear(x) 207 | return x 208 | 209 | def forward(self, input): 210 | x = self.features(input) 211 | x = self.logits(x) 212 | return x 213 | 214 | class Xception_Large(nn.Module): 215 | """ 216 | Xception optimized for the ImageNet dataset, as specified in 217 | https://arxiv.org/pdf/1610.02357.pdf 218 | """ 219 | def __init__(self, num_classes=1000): 220 | """ Constructor 221 | Args: 222 | num_classes: number of classes 223 | """ 224 | super(Xception_Large, self).__init__() 225 | self.num_classes = num_classes 226 | 227 | self.conv1 = nn.Conv2d(3, 64, 3,2, 0, bias=False) 228 | self.bn1 = nn.BatchNorm2d(64) 229 | self.relu = nn.ReLU(inplace=True) 230 | 231 | self.conv2 = nn.Conv2d(64,128,3,bias=False) 232 | self.bn2 = nn.BatchNorm2d(128) 233 | #do relu here 234 | 235 | self.block1=Block(128,256,2,2,start_with_relu=False,grow_first=True) 236 | self.block2=Block(256,384,2,2,start_with_relu=True,grow_first=True) 237 | self.block3=Block(384,1024,2,2,start_with_relu=True,grow_first=True) 238 | 239 | self.block4=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 240 | self.block5=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 241 | self.block6=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 242 | self.block7=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 243 | 244 | self.block8=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 245 | self.block9=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 246 | self.block10=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 247 | self.block11=Block(1024,1024,3,1,start_with_relu=True,grow_first=True) 248 | 249 | self.block12=Block(1024,1024,2,2,start_with_relu=True,grow_first=False) 250 | 251 | self.conv3 = SeparableConv2d(1024,1536,3,1,1) 252 | self.bn3 = nn.BatchNorm2d(1536) 253 | 254 | #do relu here 255 | self.conv4 = SeparableConv2d(1536,2048,3,1,1) 256 | self.bn4 = nn.BatchNorm2d(2048) 257 | 258 | self.fc = nn.Linear(2048, num_classes) 259 | 260 | # #------- init weights -------- 261 | # for m in self.modules(): 262 | # if isinstance(m, nn.Conv2d): 263 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 264 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 265 | # elif isinstance(m, nn.BatchNorm2d): 266 | # m.weight.data.fill_(1) 267 | # m.bias.data.zero_() 268 | # #----------------------------- 269 | 270 | def features(self, input): 271 | x = self.conv1(input) 272 | x = self.bn1(x) 273 | x = self.relu(x) 274 | 275 | x = self.conv2(x) 276 | x = self.bn2(x) 277 | x = self.relu(x) 278 | 279 | x = self.block1(x) 280 | x = self.block2(x) 281 | x = self.block3(x) 282 | x = self.block4(x) 283 | x = self.block5(x) 284 | x = self.block6(x) 285 | x = self.block7(x) 286 | x = self.block8(x) 287 | x = self.block9(x) 288 | x = self.block10(x) 289 | x = self.block11(x) 290 | x = self.block12(x) 291 | 292 | x = self.conv3(x) 293 | x = self.bn3(x) 294 | x = self.relu(x) 295 | 296 | x = self.conv4(x) 297 | x = self.bn4(x) 298 | return x 299 | 300 | def logits(self, features): 301 | x = self.relu(features) 302 | 303 | x = F.adaptive_avg_pool2d(x, (1, 1)) 304 | x = x.view(x.size(0), -1) 305 | x = self.last_linear(x) 306 | return x 307 | 308 | def forward(self, input): 309 | x = self.features(input) 310 | x = self.logits(x) 311 | return x 312 | 313 | 314 | def xception_large(num_classes=1000, pretrained='imagenet'): 315 | model = Xception_Large(num_classes=num_classes) 316 | return model 317 | 318 | 319 | def xception(num_classes=1000, pretrained='imagenet'): 320 | model = Xception(num_classes=num_classes) 321 | if pretrained: 322 | settings = pretrained_settings['xception'][pretrained] 323 | assert num_classes == settings['num_classes'], \ 324 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 325 | 326 | model = Xception(num_classes=num_classes) 327 | model.load_state_dict(model_zoo.load_url(settings['url'])) 328 | 329 | model.input_space = settings['input_space'] 330 | model.input_size = settings['input_size'] 331 | model.input_range = settings['input_range'] 332 | model.mean = settings['mean'] 333 | model.std = settings['std'] 334 | 335 | # TODO: ugly 336 | model.last_linear = model.fc 337 | del model.fc 338 | return model 339 | -------------------------------------------------------------------------------- /net/model_xception_6channel.py: -------------------------------------------------------------------------------- 1 | from net.imagenet_pretrain_model.xception import * 2 | from net.archead import * 3 | BatchNorm2d = nn.BatchNorm2d 4 | 5 | ###########################################################################################3 6 | class Net(nn.Module): 7 | 8 | def load_pretrain(self, pretrain_file, is_skip_fc = False): 9 | pretrain_state_dict = torch.load(pretrain_file) 10 | state_dict = self.state_dict() 11 | keys = list(state_dict.keys()) 12 | for key in keys: 13 | if is_skip_fc and 'fc' in key: 14 | print('skip: ' + key) 15 | continue 16 | state_dict[key] = pretrain_state_dict[r'module.'+key] 17 | 18 | self.load_state_dict(state_dict) 19 | print('') 20 | 21 | def __init__(self, num_class=340, is_arc = False, arc_s = None, arc_m = None): 22 | super(Net,self).__init__() 23 | self.basemodel = xception() 24 | self.basemodel.conv1 = nn.Conv2d(6, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 25 | self.basemodel.last_linear = nn.Sequential() 26 | 27 | self.is_arc = is_arc 28 | if self.is_arc: 29 | self.fc = ArcModule(2048, num_class, s=arc_s, m=arc_m) 30 | else: 31 | self.fc = nn.Sequential(nn.Dropout(), 32 | nn.Linear(2048, num_class)) 33 | 34 | 35 | def forward(self, x, label = None): 36 | x = self.basemodel.features(x) 37 | x = F.adaptive_avg_pool2d(x, (1, 1)) 38 | fea = x.view(x.size(0), -1) 39 | 40 | if self.is_arc: 41 | x = F.normalize(fea) 42 | x = self.fc(x, label) 43 | else: 44 | x = self.fc(fea) 45 | return x, fea 46 | 47 | def set_mode(self, mode, is_freeze_bn=False ): 48 | self.mode = mode 49 | if mode in ['eval', 'valid', 'test']: 50 | self.eval() 51 | elif mode in ['train']: 52 | self.train() 53 | if is_freeze_bn==True: ##freeze 54 | for m in self.modules(): 55 | if isinstance(m, BatchNorm2d): 56 | m.eval() 57 | m.weight.requires_grad = False 58 | m.bias.requires_grad = False 59 | 60 | ### run ############################################################################## 61 | def run_check_net(): 62 | net = Net(num_class=1000).cuda() 63 | net.set_mode('train') 64 | print(net) 65 | 66 | 67 | ######################################################################################## 68 | if __name__ == '__main__': 69 | # print( '%s: calling main function ... ' % os.path.basename(__file__)) 70 | run_check_net() 71 | print( 'sucessful!') -------------------------------------------------------------------------------- /net/model_xception_large_6channel.py: -------------------------------------------------------------------------------- 1 | from net.imagenet_pretrain_model.xception import * 2 | from net.archead import * 3 | BatchNorm2d = nn.BatchNorm2d 4 | 5 | ###########################################################################################3 6 | class Net(nn.Module): 7 | 8 | def load_pretrain(self, pretrain_file, is_skip_fc = False): 9 | pretrain_state_dict = torch.load(pretrain_file) 10 | state_dict = self.state_dict() 11 | keys = list(state_dict.keys()) 12 | for key in keys: 13 | if is_skip_fc and 'fc' in key: 14 | print('skip: ' + key) 15 | continue 16 | state_dict[key] = pretrain_state_dict[r'module.'+key] 17 | 18 | self.load_state_dict(state_dict) 19 | print('') 20 | 21 | def __init__(self, num_class=340, is_arc = False, arc_s = None, arc_m = None): 22 | super(Net,self).__init__() 23 | self.basemodel = xception_large() 24 | self.basemodel.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 25 | self.basemodel.last_linear = nn.Sequential() 26 | 27 | self.is_arc = is_arc 28 | if self.is_arc: 29 | self.fc = ArcModule(2048, num_class, s=arc_s, m=arc_m) 30 | else: 31 | self.fc = nn.Sequential(nn.Dropout(), 32 | nn.Linear(2048, num_class)) 33 | 34 | 35 | def forward(self, x, label = None): 36 | x = self.basemodel.features(x) 37 | x = F.adaptive_avg_pool2d(x, (1, 1)) 38 | x = x.view(x.size(0), -1) 39 | 40 | if self.is_arc: 41 | x = F.normalize(x) 42 | x = self.fc(x, label) 43 | else: 44 | x = self.fc(x) 45 | return x 46 | 47 | def set_mode(self, mode, is_freeze_bn=False ): 48 | self.mode = mode 49 | if mode in ['eval', 'valid', 'test']: 50 | self.eval() 51 | elif mode in ['train']: 52 | self.train() 53 | if is_freeze_bn==True: ##freeze 54 | for m in self.modules(): 55 | if isinstance(m, BatchNorm2d): 56 | m.eval() 57 | m.weight.requires_grad = False 58 | m.bias.requires_grad = False 59 | 60 | ### run ############################################################################## 61 | def run_check_net(): 62 | net = Net(num_class=1000).cuda() 63 | net.set_mode('train') 64 | print(net) 65 | 66 | 67 | ######################################################################################## 68 | if __name__ == '__main__': 69 | # print( '%s: calling main function ... ' % os.path.basename(__file__)) 70 | run_check_net() 71 | print( 'sucessful!') -------------------------------------------------------------------------------- /process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Scu-sen/Recursion-Cellular-Image-Classification-Challenge/a45e48905b673ea89763f0648623c89fbddb85f1/process/__init__.py -------------------------------------------------------------------------------- /process/augmentation.py: -------------------------------------------------------------------------------- 1 | from imgaug import augmenters as iaa 2 | import cv2 3 | import numpy as np 4 | import random 5 | import math 6 | #===================================================paug=============================================================== 7 | def order_points(pts): 8 | # initialzie a list of coordinates that will be ordered 9 | # such that the first entry in the list is the top-left, 10 | # the second entry is the top-right, the third is the 11 | # bottom-right, and the fourth is the bottom-left 12 | rect = np.zeros((4, 2), dtype="float32") 13 | # the top-left point will have the smallest sum, whereas 14 | # the bottom-right point will have the largest sum 15 | s = pts.sum(axis=1) 16 | rect[0] = pts[np.argmin(s)] 17 | rect[2] = pts[np.argmax(s)] 18 | # now, compute the difference between the points, the 19 | # top-right point will have the smallest difference, 20 | # whereas the bottom-left will have the largest difference 21 | diff = np.diff(pts, axis=1) 22 | rect[1] = pts[np.argmin(diff)] 23 | rect[3] = pts[np.argmax(diff)] 24 | # return the ordered coordinates 25 | return rect 26 | 27 | def four_point_transform(image, pts): 28 | # obtain a consistent order of the points and unpack them 29 | # individually 30 | rect = order_points(pts) 31 | original = np.array([[0, 0], 32 | [image.shape[1] - 1, 0], 33 | [image.shape[1] - 1, image.shape[0] - 1], 34 | [0, image.shape[0] - 1]], dtype="float32") 35 | 36 | M = cv2.getPerspectiveTransform(original, rect) 37 | warped = cv2.warpPerspective(image, M, (image.shape[1], image.shape[0])) 38 | return warped 39 | 40 | def Perspective_aug(img, threshold1 = 0.25, threshold2 = 0.75): 41 | # img = cv2.imread(img_name) 42 | rows, cols, ch = img.shape 43 | 44 | x0,y0 = random.randint(0, int(cols * threshold1)), random.randint(0, int(rows * threshold1)) 45 | x1,y1 = random.randint(int(cols * threshold2), cols - 1), random.randint(0, int(rows * threshold1)) 46 | x2,y2 = random.randint(int(cols * threshold2), cols - 1), random.randint(int(rows * threshold2), rows - 1) 47 | x3,y3 = random.randint(0, int(cols * threshold1)), random.randint(int(rows * threshold2), rows - 1) 48 | pts = np.float32([(x0,y0), 49 | (x1,y1), 50 | (x2,y2), 51 | (x3,y3)]) 52 | 53 | warped = four_point_transform(img, pts) 54 | 55 | x_ = np.asarray([x0, x1, x2, x3]) 56 | y_ = np.asarray([y0, y1, y2, y3]) 57 | 58 | min_x = np.min(x_) 59 | max_x = np.max(x_) 60 | min_y = np.min(y_) 61 | max_y = np.max(y_) 62 | 63 | warped = warped[min_y:max_y,min_x:max_x,:] 64 | return warped 65 | 66 | def randomHorizontalFlip(image, u=0.5): 67 | if np.random.random() < u: 68 | image = cv2.flip(image, 1) 69 | return image 70 | 71 | def randomVerticleFlip(image, u=0.5): 72 | if np.random.random() < u: 73 | image = cv2.flip(image, 0) 74 | return image 75 | 76 | def randomRotate90(image, u=0.5): 77 | if np.random.random() < u: 78 | image[:,:,0:3] = np.rot90(image[:,:,0:3]) 79 | image[:,:,3:6] = np.rot90(image[:,:,3:6]) 80 | 81 | return image 82 | #===================================================origin============================================================= 83 | def random_cropping(image, ratio=0.8, is_random = True): 84 | height, width, _ = image.shape 85 | target_h = int(height*ratio) 86 | target_w = int(width*ratio) 87 | 88 | if is_random: 89 | start_x = random.randint(0, width - target_w) 90 | start_y = random.randint(0, height - target_h) 91 | else: 92 | start_x = ( width - target_w ) // 2 93 | start_y = ( height - target_h ) // 2 94 | 95 | zeros = image[start_y:start_y+target_h,start_x:start_x+target_w,:] 96 | zeros = cv2.resize(zeros ,(width,height)) 97 | return zeros 98 | 99 | def cropping(image, ratio=0.8, code = 0): 100 | height, width, _ = image.shape 101 | target_h = int(height*ratio) 102 | target_w = int(width*ratio) 103 | 104 | if code==0: 105 | start_x = ( width - target_w ) // 2 106 | start_y = ( height - target_h ) // 2 107 | 108 | elif code == 1: 109 | start_x = 0 110 | start_y = 0 111 | 112 | elif code == 2: 113 | start_x = width - target_w 114 | start_y = 0 115 | 116 | elif code == 3: 117 | start_x = 0 118 | start_y = height - target_h 119 | 120 | elif code == 4: 121 | start_x = width - target_w 122 | start_y = height - target_h 123 | 124 | elif code == -1: 125 | return image 126 | 127 | zeros = image[start_y:start_y+target_h,start_x:start_x+target_w,:] 128 | zeros = cv2.resize(zeros ,(width,height)) 129 | return zeros 130 | 131 | def random_erasing(img, probability=0.5, sl=0.02, sh=0.4, r1=0.3): 132 | if random.uniform(0, 1) > probability: 133 | return img 134 | 135 | for attempt in range(100): 136 | area = img.shape[0] * img.shape[1] 137 | 138 | target_area = random.uniform(sl, sh) * area 139 | aspect_ratio = random.uniform(r1, 1 / r1) 140 | 141 | h = int(round(math.sqrt(target_area * aspect_ratio))) 142 | w = int(round(math.sqrt(target_area / aspect_ratio))) 143 | 144 | if w < img.shape[1] and h < img.shape[0]: 145 | x1 = random.randint(0, img.shape[0] - h) 146 | y1 = random.randint(0, img.shape[1] - w) 147 | if img.shape[2] == 6: 148 | img[x1:x1 + h, y1:y1 + w, :] = 0.0 149 | else: 150 | print('!!!!!!!! random_erasing dim wrong!!!!!!!!!!!') 151 | return 152 | 153 | return img 154 | 155 | return img 156 | 157 | def randomShiftScaleRotate(image, 158 | shift_limit=(-0.0, 0.0), 159 | scale_limit=(-0.0, 0.0), 160 | rotate_limit=(-0.0, 0.0), 161 | aspect_limit=(-0.0, 0.0), 162 | borderMode=cv2.BORDER_CONSTANT, u=0.5): 163 | if np.random.random() < u: 164 | height, width, channel = image.shape 165 | 166 | angle = np.random.uniform(rotate_limit[0], rotate_limit[1]) 167 | scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1]) 168 | aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1]) 169 | sx = scale * aspect / (aspect ** 0.5) 170 | sy = scale / (aspect ** 0.5) 171 | dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width) 172 | dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height) 173 | 174 | cc = np.math.cos(angle / 180 * np.math.pi) * sx 175 | ss = np.math.sin(angle / 180 * np.math.pi) * sy 176 | rotate_matrix = np.array([[cc, -ss], [ss, cc]]) 177 | 178 | box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ]) 179 | box1 = box0 - np.array([width / 2, height / 2]) 180 | box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy]) 181 | 182 | box0 = box0.astype(np.float32) 183 | box1 = box1.astype(np.float32) 184 | mat = cv2.getPerspectiveTransform(box0, box1) 185 | image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode, 186 | borderValue=( 187 | 0, 0, 188 | 0,)) 189 | 190 | return image 191 | 192 | 193 | def aug_image(image, is_infer=False, augment = None, type = None): 194 | 195 | if is_infer: 196 | # return image 197 | image = randomHorizontalFlip(image, augment[0]) 198 | image = randomVerticleFlip(image, augment[1]) 199 | image = randomRotate90(image, augment[2]) 200 | 201 | return image 202 | 203 | else: 204 | 205 | image = randomHorizontalFlip(image) 206 | image = randomVerticleFlip(image) 207 | image = randomRotate90(image) 208 | 209 | height, width, _ = image.shape 210 | image = randomShiftScaleRotate(image, 211 | shift_limit=(-0.1, 0.1), 212 | scale_limit=(-0.1, 0.1), 213 | aspect_limit=(-0.1, 0.1), 214 | rotate_limit=(-0, 0)) 215 | image = cv2.resize(image, (width, height)) 216 | 217 | if type == 'easy': 218 | return image 219 | 220 | image = random_erasing(image, probability=0.5, sl=0.02, sh=0.4, r1=0.3) 221 | 222 | if type == 'normal': 223 | return image 224 | 225 | ratio = random.uniform(0.8,0.99) 226 | image = random_cropping(image, ratio=ratio, is_random=True) 227 | 228 | if type == 'hard': 229 | return image 230 | 231 | print('!!!!!!!!!!!WRONG!!!!!!!!!!!') 232 | 233 | 234 | -------------------------------------------------------------------------------- /process/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | from PIL import Image 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.data as D 8 | import torch.nn.functional as F 9 | import torchvision 10 | from sklearn.model_selection import KFold 11 | from tqdm import tqdm 12 | import cv2 13 | from torchvision import transforms as T 14 | from process.augmentation import * 15 | from settings import * 16 | 17 | 18 | id_index_dict = {'HEPG':0,'HUVEC':1,'RPE':2,'U2OS':3,} 19 | 20 | class Dataset_(D.Dataset): 21 | def __init__(self, csv_file, img_dir, mode='train', cell= None, image_size=512, site=1, 22 | augment = [], 23 | rgb=False, 24 | pseudo_csv = None, 25 | # Pseudo_Batch=r'U2OS-05', 26 | augment_param_dict = None, 27 | is_TransTwice=False): 28 | 29 | print(csv_file) 30 | 31 | self.channels = [1, 2, 3, 4, 5, 6] 32 | df = pd.read_csv(csv_file) 33 | self.image_size = image_size 34 | self.is_TransTwice = is_TransTwice 35 | 36 | self.site = site 37 | self.mode = mode 38 | self.img_dir = img_dir 39 | 40 | self.augment = augment 41 | self.rgb = rgb 42 | self.augment_param_dict = augment_param_dict 43 | 44 | if cell is not None: 45 | print('cell type: '+ cell) 46 | df_all = [] 47 | for i in range(100): 48 | df_ = pd.DataFrame(df.loc[df['experiment'] == (cell+'-'+str(100+i)[-2:])]) 49 | df_all.append(df_) 50 | df = pd.concat(df_all) 51 | 52 | self.records = df.to_records(index=False) 53 | self.len = self.records.shape[0] 54 | self.mean_std = pd.read_csv(os.path.join(mean_std_path, r'mean_std_plate.csv')) 55 | 56 | print(self.len) 57 | 58 | @staticmethod 59 | def _load_img_as_tensor(file_name): 60 | with Image.open(file_name) as img: 61 | return T.ToTensor()(img) 62 | 63 | def _get_img_path(self, index, channel): 64 | 65 | experiment, well, plate = self.records[index].experiment, self.records[index].well, self.records[index].plate 66 | path = '/'.join([self.img_dir, 'train', experiment, f'Plate{plate}', f'{well}_s{self.site}_w{channel}.png']) 67 | 68 | if not os.path.exists(path): 69 | experiment, well, plate = self.records[index].experiment, self.records[index].well, self.records[ 70 | index].plate 71 | path = '/'.join([self.img_dir, 'test', experiment, f'Plate{plate}', f'{well}_s{self.site}_w{channel}.png']) 72 | 73 | return path 74 | 75 | 76 | def __getitem__(self, index): 77 | 78 | if self.mode == 'train' or 'semi' in self.mode: 79 | if np.random.random() < 0.5: 80 | self.site = 2 81 | else: 82 | self.site = 1 83 | 84 | experiment, plate = self.records[index].experiment, self.records[index].plate 85 | 86 | def get_img(paths): 87 | mean = [] 88 | std = [] 89 | for ch in self.channels: 90 | tmp = self.mean_std[(self.mean_std['experiment'] == experiment) & (self.mean_std['plate'] == plate) & ( 91 | self.mean_std['channel'] == ch)]['mean'].values 92 | tmp1 = self.mean_std[(self.mean_std['experiment'] == experiment) & (self.mean_std['plate'] == plate) & ( 93 | self.mean_std['channel'] == ch)]['std'].values 94 | mean.append(tmp) 95 | std.append(tmp1) 96 | 97 | img = [] 98 | for j, img_path in enumerate(paths): 99 | tmp_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 100 | tmp_img = cv2.resize(tmp_img, (self.image_size, self.image_size)) 101 | 102 | if self.mode == 'train' or 'semi' in self.mode: 103 | m_r = 0.1 * random.uniform(-1,1) 104 | s_r = 0.1 * random.uniform(-1,1) 105 | 106 | mean[j] = mean[j]*(1+m_r) 107 | std[j] = std[j]*(1+s_r) 108 | 109 | tmp_img = (tmp_img - mean[j]) / std[j] 110 | img.append(tmp_img) 111 | 112 | img = np.array(img) 113 | image = np.transpose(img, (1, 2, 0)) 114 | return image 115 | 116 | paths = [self._get_img_path(index, ch) for ch in self.channels] 117 | image = get_img(paths) 118 | 119 | if self.is_TransTwice: 120 | image_hard = image.copy() 121 | image_hard = aug_image(image_hard, type = 'hard') 122 | img_hard= np.transpose(image_hard, (2, 0, 1)) 123 | img_hard = img_hard.reshape([6, self.image_size, self.image_size]) 124 | img_hard = torch.FloatTensor(img_hard) 125 | 126 | image_easy = image.copy() 127 | image_easy = aug_image(image_easy, type = 'easy') 128 | img_easy = np.transpose(image_easy, (2, 0, 1)) 129 | img_easy= img_easy.reshape([6, self.image_size, self.image_size]) 130 | img_easy = torch.FloatTensor(img_easy) 131 | 132 | image = aug_image(image, type = 'normal') 133 | else: 134 | if self.mode == 'train': 135 | image = aug_image(image, type = 'normal') 136 | else: 137 | image = aug_image(image, is_infer=True, augment=self.augment) 138 | 139 | img = np.transpose(image, (2, 0, 1)) 140 | img = img.reshape([6, self.image_size, self.image_size]) 141 | img = torch.FloatTensor(img) 142 | 143 | if self.mode == 'train' or self.mode == 'valid' or self.mode == 'semi_valid': 144 | label = torch.from_numpy(np.array(self.records[index].sirna)) 145 | id_code = self.records[index].id_code 146 | 147 | if self.mode == 'semi_valid': 148 | label = torch.from_numpy(np.array(-1)) 149 | 150 | if self.is_TransTwice: 151 | return img, img_easy, img_hard, label, id_code 152 | 153 | return img, label, id_code 154 | 155 | elif self.mode == 'semi_test' and self.is_TransTwice: 156 | label = torch.from_numpy(np.array(-1)) 157 | id_code = self.records[index].id_code 158 | return img, img_easy, img_hard, label, id_code 159 | 160 | else: 161 | return self.records[index].id_code, img 162 | 163 | def __len__(self): 164 | """ 165 | Total number of samples in the dataset 166 | """ 167 | return self.len 168 | 169 | def run_check_train_data(): 170 | dataset = Dataset_(path_data + r'/train_fold_0.csv', 171 | data_dir, 172 | cell='U2OS', 173 | image_size=256, 174 | rgb=False) 175 | 176 | print(dataset) 177 | num = len(dataset) 178 | for m in range(num): 179 | i = np.random.choice(num) 180 | image, label, id_index = dataset[i] 181 | print(image.shape) 182 | print(label) 183 | print(id_index) 184 | 185 | def run_check_test_data(): 186 | dataset = Dataset_(data_dir+'/test.csv', data_dir, mode='semi_test',is_TransTwice=True) 187 | print(dataset) 188 | 189 | num = len(dataset) 190 | for m in range(num): 191 | i = np.random.choice(num) 192 | image,image_pwc,id,_ = dataset[i] 193 | 194 | print(image.size()) 195 | print(image_pwc.size()) 196 | print(id) 197 | # print(id) 198 | 199 | if __name__ == '__main__': 200 | run_check_train_data() 201 | -------------------------------------------------------------------------------- /process/mean_std_batch.csv: -------------------------------------------------------------------------------- 1 | ,experiment,channel,mean,std 2 | 0,HEPG2-01,1,37.947925205354565,33.33256088525673 3 | 1,HEPG2-01,2,26.399948734741706,13.752735954749486 4 | 2,HEPG2-01,3,34.72383564478391,17.484090425217307 5 | 3,HEPG2-01,4,40.01408919575927,19.907502303064415 6 | 4,HEPG2-01,5,43.96559181461087,26.302431225116717 7 | 5,HEPG2-01,6,30.80435258072692,15.014549661764766 8 | 6,HEPG2-02,1,33.40507252340193,33.035450522673024 9 | 7,HEPG2-02,2,27.128853740630213,16.185302620744363 10 | 8,HEPG2-02,3,29.58229715328712,15.628693276987999 11 | 9,HEPG2-02,4,46.01093650793101,26.953835862608113 12 | 10,HEPG2-02,5,32.501941801665666,20.137803742173144 13 | 11,HEPG2-02,6,31.14411666950622,16.697220067725826 14 | 12,HEPG2-03,1,12.908154589789254,17.6308216659134 15 | 13,HEPG2-03,2,12.969320964503599,11.322577154575745 16 | 14,HEPG2-03,3,10.257114345377142,6.209581089594636 17 | 15,HEPG2-03,4,18.05504085955682,14.777835623039344 18 | 16,HEPG2-03,5,4.275886320448541,3.4066356000947593 19 | 17,HEPG2-03,6,11.605919741964959,7.908519677347332 20 | 18,HEPG2-04,1,6.008256912231445,8.9205260307892 21 | 19,HEPG2-04,2,10.877087134819526,10.212378835059424 22 | 20,HEPG2-04,3,9.151439789053681,5.634087783476116 23 | 21,HEPG2-04,4,6.801836224345418,5.782148847029047 24 | 22,HEPG2-04,5,3.885868547798751,2.992832923477055 25 | 23,HEPG2-04,6,10.282782769822456,7.285510938695546 26 | 24,HEPG2-05,1,3.7524226362055,6.27931404928419 27 | 25,HEPG2-05,2,9.528092981933,9.575428212397807 28 | 26,HEPG2-05,3,6.468095378442244,5.048884256358043 29 | 27,HEPG2-05,4,6.044398538478009,5.977280910274096 30 | 28,HEPG2-05,5,2.43571556543375,2.8533324327259217 31 | 29,HEPG2-05,6,5.110681262883273,3.977204392054807 32 | 30,HEPG2-06,1,6.73843179739915,9.007139900169207 33 | 31,HEPG2-06,2,15.640715828189602,12.600068178733265 34 | 32,HEPG2-06,3,10.31616521655739,6.583289883332775 35 | 33,HEPG2-06,4,10.91490338375042,8.432110306930019 36 | 34,HEPG2-06,5,4.401273939516638,3.968203148074765 37 | 35,HEPG2-06,6,8.346202419949817,5.2121956442105 38 | 36,HEPG2-07,1,4.123609423443554,5.69532848880153 39 | 37,HEPG2-07,2,13.233933416227016,10.828468581031968 40 | 38,HEPG2-07,3,6.198933475773509,3.588383613138075 41 | 39,HEPG2-07,4,7.382063738505045,5.655692651761542 42 | 40,HEPG2-07,5,4.086066574778983,3.622635224546427 43 | 41,HEPG2-07,6,5.482998941002823,3.15007765113072 44 | 42,HEPG2-08,1,4.031396865844727,4.979457139812356 45 | 43,HEPG2-08,2,17.484409415944622,13.635926921055773 46 | 44,HEPG2-08,3,9.890447173440098,5.875292734938908 47 | 45,HEPG2-08,4,9.318601953799313,6.703630534403889 48 | 46,HEPG2-08,5,11.368786411300894,7.967187650483456 49 | 47,HEPG2-08,6,7.621498163689645,4.296077049199903 50 | 48,HEPG2-09,1,2.357357628933795,4.091330364732308 51 | 49,HEPG2-09,2,7.790863263142573,8.414867928285013 52 | 50,HEPG2-09,3,4.936864837423547,3.585973624879819 53 | 51,HEPG2-09,4,3.575661749034733,3.5423538547808207 54 | 52,HEPG2-09,5,3.0554958628369615,3.1835327582584796 55 | 53,HEPG2-09,6,3.5790727184964464,2.4331839084622784 56 | 54,HEPG2-10,1,5.526568466966802,7.02068163959014 57 | 55,HEPG2-10,2,16.88304150878609,12.524164857167689 58 | 56,HEPG2-10,3,10.278806615185427,6.052535981255654 59 | 57,HEPG2-10,4,10.200859483186301,7.192107275998087 60 | 58,HEPG2-10,5,8.432076283863612,6.602570540015926 61 | 59,HEPG2-10,6,8.729963407888041,5.006957386183275 62 | 60,HEPG2-11,1,4.518259136851241,6.198137149425403 63 | 61,HEPG2-11,2,15.709275299940652,12.29539700635501 64 | 62,HEPG2-11,3,9.901140841042123,5.752858902134519 65 | 63,HEPG2-11,4,8.253368946788758,5.979289479900843 66 | 64,HEPG2-11,5,8.26331015951265,6.5117157701660044 67 | 65,HEPG2-11,6,7.532963196048891,4.20487212514502 68 | 66,HUVEC-01,1,4.123237998454602,6.183822374295504 69 | 67,HUVEC-01,2,19.956227019235687,17.410373249709515 70 | 68,HUVEC-01,3,10.233502017987238,5.398731573324602 71 | 69,HUVEC-01,4,10.657919504425742,9.064901321697583 72 | 70,HUVEC-01,5,2.3267373713580044,2.482055290223624 73 | 71,HUVEC-01,6,12.136189259491958,6.825840134701069 74 | 72,HUVEC-02,1,4.448384289617662,7.142490812848778 75 | 73,HUVEC-02,2,20.802759995708218,17.162895221906414 76 | 74,HUVEC-02,3,8.89785906711182,4.8149996835136974 77 | 75,HUVEC-02,4,10.869584569683322,8.893834941525412 78 | 76,HUVEC-02,5,2.4429691070085995,2.896274343197857 79 | 77,HUVEC-02,6,11.343739261874905,6.062351429151439 80 | 78,HUVEC-03,1,4.523923174127356,6.936634859009913 81 | 79,HUVEC-03,2,16.64078368923881,14.336938879135944 82 | 80,HUVEC-03,3,9.584943955594843,4.255130282673752 83 | 81,HUVEC-03,4,8.839367722536062,7.1524739369239185 84 | 82,HUVEC-03,5,2.6843516965965173,2.503303756338322 85 | 83,HUVEC-03,6,11.400608490039776,5.504176030245998 86 | 84,HUVEC-04,1,4.4525722590359775,7.101131335707065 87 | 85,HUVEC-04,2,18.80776567428143,16.30188774996097 88 | 86,HUVEC-04,3,8.794710694969474,4.387230120693119 89 | 87,HUVEC-04,4,9.666565463140413,8.087362310190969 90 | 88,HUVEC-04,5,2.32545758996691,2.3326163080798894 91 | 89,HUVEC-04,6,10.819912167338583,5.873843199043228 92 | 90,HUVEC-05,1,2.934234235193822,5.250856734336974 93 | 91,HUVEC-05,2,6.696135661818764,6.457119011881106 94 | 92,HUVEC-05,3,5.510527243861905,3.154173834068408 95 | 93,HUVEC-05,4,4.123352420794499,3.79445163740529 96 | 94,HUVEC-05,5,2.4013383791044163,2.562447671717749 97 | 95,HUVEC-05,6,6.3478354912299615,4.199810136895209 98 | 96,HUVEC-06,1,3.0381007159656863,4.834226661220529 99 | 97,HUVEC-06,2,11.76814136613587,11.188808127315676 100 | 98,HUVEC-06,3,7.971630344344386,4.7724321395826905 101 | 99,HUVEC-06,4,7.109078022254182,6.502477963742212 102 | 100,HUVEC-06,5,2.036674448380521,2.2585344052791285 103 | 101,HUVEC-06,6,9.376877147138458,5.6348230442018465 104 | 102,HUVEC-07,1,3.147763698132007,4.948217047766337 105 | 103,HUVEC-07,2,18.878245271645582,14.331536042005979 106 | 104,HUVEC-07,3,6.402360728808811,3.5296017086301745 107 | 105,HUVEC-07,4,10.288628733003295,8.138533111916402 108 | 106,HUVEC-07,5,2.2331130009192925,2.3168058642258327 109 | 107,HUVEC-07,6,4.913339720143901,2.5223188309630338 110 | 108,HUVEC-08,1,3.551637954526133,5.490524710485044 111 | 109,HUVEC-08,2,15.416924391474042,13.207647096389017 112 | 110,HUVEC-08,3,8.318035749645976,4.831448935103155 113 | 111,HUVEC-08,4,8.743074609087659,7.615154318232122 114 | 112,HUVEC-08,5,1.9909261325737098,2.0952244603336445 115 | 113,HUVEC-08,6,5.89808654630339,3.1845551853728438 116 | 114,HUVEC-09,1,3.5162445647375926,5.533914518890725 117 | 115,HUVEC-09,2,14.176696026480043,12.823176885266896 118 | 116,HUVEC-09,3,6.050853154876015,3.312548520259971 119 | 117,HUVEC-09,4,7.628807792415866,6.744788943963276 120 | 118,HUVEC-09,5,2.914426989369578,3.374631916191481 121 | 119,HUVEC-09,6,4.356150560564809,2.207772372080506 122 | 120,HUVEC-10,1,2.8759766135896956,4.85476472001889 123 | 121,HUVEC-10,2,14.706772414120762,12.962872346166757 124 | 122,HUVEC-10,3,6.529636311840702,3.41942202082507 125 | 123,HUVEC-10,4,7.984274610296472,6.984703508086695 126 | 124,HUVEC-10,5,2.6392515402335626,2.6206051364172454 127 | 125,HUVEC-10,6,4.874001066406052,2.3180348182618427 128 | 126,HUVEC-11,1,3.4272671365118645,5.211219382682406 129 | 127,HUVEC-11,2,17.993010565832062,14.212818096753484 130 | 128,HUVEC-11,3,7.901747076542346,4.323031143266195 131 | 129,HUVEC-11,4,9.84448110747647,8.008731844849814 132 | 130,HUVEC-11,5,2.08433321389285,2.582982946767482 133 | 131,HUVEC-11,6,6.036166847526253,3.072909684947843 134 | 132,HUVEC-12,1,7.005183219909668,5.683179980268542 135 | 133,HUVEC-12,2,15.376991566125449,12.914875044758428 136 | 134,HUVEC-12,3,5.774422323548949,2.931309480846939 137 | 135,HUVEC-12,4,7.90397580877527,6.3971604366049 138 | 136,HUVEC-12,5,3.5876278706959317,3.8606157025280052 139 | 137,HUVEC-12,6,4.16717335156032,1.9612488549373648 140 | 138,HUVEC-13,1,6.848372023090039,5.39756600301248 141 | 139,HUVEC-13,2,15.660099144854577,11.963631587741473 142 | 140,HUVEC-13,3,5.808692977319356,3.0512440512406815 143 | 141,HUVEC-13,4,8.72532674533869,6.617421898316723 144 | 142,HUVEC-13,5,3.12055616596945,3.3025084908564604 145 | 143,HUVEC-13,6,4.524055680418326,2.2755258684752078 146 | 144,HUVEC-14,1,4.921122430186522,3.6758997370594213 147 | 145,HUVEC-14,2,15.985091818006415,13.261861126689737 148 | 146,HUVEC-14,3,5.354569533937855,2.8705251893767887 149 | 147,HUVEC-14,4,7.090410823884763,5.600686193866689 150 | 148,HUVEC-14,5,4.30511569192535,5.262959467888591 151 | 149,HUVEC-14,6,3.9811147984705473,2.1278019847845124 152 | 150,HUVEC-15,1,4.549013319768403,3.0972630348790005 153 | 151,HUVEC-15,2,17.446095463476684,14.10846391112002 154 | 152,HUVEC-15,3,5.6876649087981175,3.0841760742995805 155 | 153,HUVEC-15,4,8.2799316079993,6.573590838540447 156 | 154,HUVEC-15,5,3.947403697591079,5.0355317214428075 157 | 155,HUVEC-15,6,4.70804415094225,2.7240523003632235 158 | 156,HUVEC-16,1,6.900458394707023,5.963448904493482 159 | 157,HUVEC-16,2,13.046737879901738,10.77576123341358 160 | 158,HUVEC-16,3,6.834947643341956,3.3807141058593957 161 | 159,HUVEC-16,4,7.354048518391399,5.9341500794221185 162 | 160,HUVEC-16,5,2.6300981478257612,2.8816127963229228 163 | 161,HUVEC-16,6,5.127939541618545,2.398072316438165 164 | 162,HUVEC-17,1,3.0777270221090935,4.7347117435889565 165 | 163,HUVEC-17,2,7.903858717385825,7.939172050968376 166 | 164,HUVEC-17,3,7.8793937788381205,3.772278859341094 167 | 165,HUVEC-17,4,5.4769744222814385,4.953986058152201 168 | 166,HUVEC-17,5,2.2676011231038475,2.1094071000518095 169 | 167,HUVEC-17,6,9.365418180242761,4.722853025652434 170 | 168,HUVEC-18,1,3.641753905580646,5.537115317780045 171 | 169,HUVEC-18,2,19.893317832296017,15.173063803987269 172 | 170,HUVEC-18,3,8.445262569610323,4.456788399186 173 | 171,HUVEC-18,4,11.644640167199142,9.116957616438642 174 | 172,HUVEC-18,5,2.4914967361065936,2.701815373168584 175 | 173,HUVEC-18,6,6.540758349854805,3.1845737548454296 176 | 174,HUVEC-19,1,3.4957298798994585,5.284973487356911 177 | 175,HUVEC-19,2,19.38256341915626,14.374612522162147 178 | 176,HUVEC-19,3,6.491130763834173,3.238432956465161 179 | 177,HUVEC-19,4,12.12870635769584,9.57733679780621 180 | 178,HUVEC-19,5,1.158154447357376,1.1874609869842132 181 | 179,HUVEC-19,6,6.3886334958014555,3.3462037397029376 182 | 180,HUVEC-20,1,3.4611970201715248,5.466144107294636 183 | 181,HUVEC-20,2,17.499450229979182,13.847036464693208 184 | 182,HUVEC-20,3,7.495805565412942,4.133217843528006 185 | 183,HUVEC-20,4,9.473571285024866,7.784712957324313 186 | 184,HUVEC-20,5,1.9756623692326731,2.0812880238496554 187 | 185,HUVEC-20,6,5.754493968827384,2.9674065770627234 188 | 186,HUVEC-21,1,5.282639717126822,3.9124161382595584 189 | 187,HUVEC-21,2,14.940510666215575,11.90895446643926 190 | 188,HUVEC-21,3,6.552785952369888,3.040851276341721 191 | 189,HUVEC-21,4,6.896931847968659,5.299137620203697 192 | 190,HUVEC-21,5,3.864989975830177,3.90540540622969 193 | 191,HUVEC-21,6,4.978786070625503,2.1558607878377325 194 | 192,HUVEC-22,1,4.766385728662664,4.0786842641117635 195 | 193,HUVEC-22,2,15.602398400182848,13.53302042932937 196 | 194,HUVEC-22,3,5.02526576488049,2.336921172077658 197 | 195,HUVEC-22,4,6.705845619176889,5.618456361562256 198 | 196,HUVEC-22,5,2.9821719674320963,2.902539018709608 199 | 197,HUVEC-22,6,3.950447500526131,1.719001985273377 200 | 198,HUVEC-23,1,4.89968963331754,3.7855653397231683 201 | 199,HUVEC-23,2,12.424805679135357,10.546565494487181 202 | 200,HUVEC-23,3,7.32482897605865,3.6161458095872394 203 | 201,HUVEC-23,4,5.83766254267007,4.666346575202206 204 | 202,HUVEC-23,5,4.419775754431222,4.809825772851262 205 | 203,HUVEC-23,6,5.054036886492558,2.313404792566226 206 | 204,HUVEC-24,1,5.144148798732014,3.882020600506075 207 | 205,HUVEC-24,2,18.252287008545615,13.667799638162546 208 | 206,HUVEC-24,3,6.849889515282272,3.0282464374139217 209 | 207,HUVEC-24,4,8.151514894002444,5.8256170211706335 210 | 208,HUVEC-24,5,4.920666720960047,4.522007574576416 211 | 209,HUVEC-24,6,5.4202206150277865,2.2471104098872496 212 | 210,RPE-01,1,3.618763621751364,4.192666875458602 213 | 211,RPE-01,2,10.956120983346716,9.001264884009728 214 | 212,RPE-01,3,16.646254347516344,9.515929079978783 215 | 213,RPE-01,4,7.8296966769478535,5.754399411414793 216 | 214,RPE-01,5,3.690766427424047,3.5254472138315758 217 | 215,RPE-01,6,13.503023982048035,6.8188086397576955 218 | 216,RPE-02,1,2.978969154419837,4.008428452386732 219 | 217,RPE-02,2,8.812340422110124,8.543697223340121 220 | 218,RPE-02,3,14.565686269239945,9.09850304361312 221 | 219,RPE-02,4,6.347397101389897,5.473768006232086 222 | 220,RPE-02,5,3.7991883568949514,3.6151026051546182 223 | 221,RPE-02,6,12.52660627334149,7.268195615217497 224 | 222,RPE-03,1,3.4160109185553216,4.266135877338148 225 | 223,RPE-03,2,8.254199461503463,7.39236393042369 226 | 224,RPE-03,3,16.256040862628392,9.814869868404422 227 | 225,RPE-03,4,6.239201920373099,4.823814026681506 228 | 226,RPE-03,5,4.509512128768029,4.026306541241899 229 | 227,RPE-03,6,12.78915332664143,6.843365937493819 230 | 228,RPE-04,1,2.2836156778467274,2.930678134369584 231 | 229,RPE-04,2,13.32587173021876,13.348558043170117 232 | 230,RPE-04,3,10.601482098515106,7.098331663036705 233 | 231,RPE-04,4,6.574460301527059,6.138121967099072 234 | 232,RPE-04,5,5.650054214254421,6.076721981938296 235 | 233,RPE-04,6,7.266003784563946,4.329859798104113 236 | 234,RPE-05,1,2.9065533226186577,4.060372213010841 237 | 235,RPE-05,2,13.485283075989067,12.269472544047769 238 | 236,RPE-05,3,10.648747136066486,6.140987144882112 239 | 237,RPE-05,4,7.449366996814678,6.407338636527132 240 | 238,RPE-05,5,5.123378236572464,5.026366380213089 241 | 239,RPE-05,6,6.332968919308154,3.2830630270162096 242 | 240,RPE-06,1,2.188530039477658,3.0609088790437333 243 | 241,RPE-06,2,13.898793576599715,13.077941216927924 244 | 242,RPE-06,3,10.336899046774034,6.103305922461448 245 | 243,RPE-06,4,6.9152455298931566,6.004647246325296 246 | 244,RPE-06,5,4.3894192599630975,4.1641680927578735 247 | 245,RPE-06,6,6.901995021027404,3.673627897509623 248 | 246,RPE-07,1,2.2592327718610887,3.070365939754178 249 | 247,RPE-07,2,13.241409530887356,13.852669399070603 250 | 248,RPE-07,3,11.424235316066,7.60091397254857 251 | 249,RPE-07,4,6.587675913587793,6.350641250210778 252 | 250,RPE-07,5,5.645078669894826,6.214571469890471 253 | 251,RPE-07,6,7.627715812100993,4.607619576815733 254 | 252,RPE-08,1,2.5991130169335896,3.5723584481857316 255 | 253,RPE-08,2,14.882595944714236,13.41075350539315 256 | 254,RPE-08,3,11.71162163437187,8.456188316417437 257 | 255,RPE-08,4,8.982621449928779,7.79834068002911 258 | 256,RPE-08,5,1.2401140401889752,1.3682175789849884 259 | 257,RPE-08,6,7.530617105496394,4.7924648885095005 260 | 258,RPE-09,1,2.6513966747070112,3.5012979941239237 261 | 259,RPE-09,2,14.728037704790054,13.15006618112863 262 | 260,RPE-09,3,13.677852137896416,7.430553927489134 263 | 261,RPE-09,4,7.141990971507143,5.737139058590704 264 | 262,RPE-09,5,5.234366510671682,4.6118545325318765 265 | 263,RPE-09,6,8.75836055826889,4.241113320352657 266 | 264,RPE-10,1,2.6262122253318885,3.0182835547231077 267 | 265,RPE-10,2,14.410689332268454,12.647798339176033 268 | 266,RPE-10,3,11.1440017966481,7.017884623806237 269 | 267,RPE-10,4,7.211873741893025,5.819814148850524 270 | 268,RPE-10,5,4.089006951877049,4.218589471999256 271 | 269,RPE-10,6,7.566016957357332,4.354793941567297 272 | 270,RPE-11,1,2.1429493495843994,2.9858613095615216 273 | 271,RPE-11,2,13.5351067843159,12.875010643428732 274 | 272,RPE-11,3,11.350918044981254,7.570879061454019 275 | 273,RPE-11,4,6.537267311929971,5.847637281700208 276 | 274,RPE-11,5,4.419598450108248,4.716024596556777 277 | 275,RPE-11,6,7.588953272083693,4.556357167060993 278 | 276,U2OS-01,1,7.53757672960108,10.91069202608281 279 | 277,U2OS-01,2,10.093597218587801,7.884009805214971 280 | 278,U2OS-01,3,10.044440435124683,5.319648123227273 281 | 279,U2OS-01,4,7.385856851354822,5.481308144410095 282 | 280,U2OS-01,5,3.0146624995516493,3.2446417264903373 283 | 281,U2OS-01,6,13.574783184311606,8.392859214846434 284 | 282,U2OS-02,1,11.160868369139633,12.695885051692606 285 | 283,U2OS-02,2,12.521636337428898,8.454749862284237 286 | 284,U2OS-02,3,12.274935527281327,5.065677728392378 287 | 285,U2OS-02,4,10.068334166105691,5.925579158020393 288 | 286,U2OS-02,5,3.2909658744737698,3.6543719347325925 289 | 287,U2OS-02,6,16.674154216592964,7.8475601620377935 290 | 288,U2OS-03,1,9.0928936438127,12.093452228697604 291 | 289,U2OS-03,2,12.956782345647936,10.024108975816217 292 | 290,U2OS-03,3,11.595745063447334,5.889229961396887 293 | 291,U2OS-03,4,9.150984962265213,6.52467881576424 294 | 292,U2OS-03,5,3.0964064071704813,3.9391766404997677 295 | 293,U2OS-03,6,13.740227476342932,7.927517904465902 296 | 294,U2OS-04,1,10.661908519732489,10.850662613634036 297 | 295,U2OS-04,2,55.51065658903741,30.280542933624375 298 | 296,U2OS-04,3,21.704327575572126,7.680837688350505 299 | 297,U2OS-04,4,29.250711693392173,14.84692353010392 300 | 298,U2OS-04,5,28.89898979199397,15.446277984435925 301 | 299,U2OS-04,6,25.74298348364892,11.274324744277402 302 | 300,U2OS-05,1,4.558761799921755,6.626578351767876 303 | 301,U2OS-05,2,10.398600445419062,8.407355144020054 304 | 302,U2OS-05,3,7.812904267232926,4.363187538739127 305 | 303,U2OS-05,4,6.35408998864596,4.809022019706418 306 | 304,U2OS-05,5,3.8135468514239204,4.120718689925231 307 | 305,U2OS-05,6,6.413873881981021,3.389138995490683 308 | -------------------------------------------------------------------------------- /process/utils.py: -------------------------------------------------------------------------------- 1 | from include import * 2 | from torch.autograd import Variable 3 | 4 | def save(list_or_dict,name): 5 | f = open(name, 'w') 6 | f.write(str(list_or_dict)) 7 | f.close() 8 | 9 | def load(name): 10 | f = open(name, 'r') 11 | a = f.read() 12 | tmp = eval(a) 13 | f.close() 14 | return tmp 15 | 16 | def dot_numpy(vector1 , vector2,emb_size = 512): 17 | vector1 = vector1.reshape([-1, emb_size]) 18 | vector2 = vector2.reshape([-1, emb_size]) 19 | vector2 = vector2.transpose(1,0) 20 | 21 | cosV12 = np.dot(vector1, vector2) 22 | return cosV12 23 | 24 | def to_var(x, volatile=False): 25 | if torch.cuda.is_available(): 26 | x = x.cuda() 27 | return Variable(x, volatile=volatile) 28 | 29 | def softmax_cross_entropy_criterion(logit, truth, is_average=True): 30 | loss = F.cross_entropy(logit, truth, reduce=is_average) 31 | return loss 32 | 33 | def softmax_add_newwhale(logit, truth): 34 | indexs_NoNew = (truth != 5004).nonzero().view(-1) 35 | indexs_New = (truth == 5004).nonzero().view(-1) 36 | 37 | logits_NoNew = logit[indexs_NoNew] 38 | truth_NoNew = truth[indexs_NoNew] 39 | 40 | logits_New = logit[indexs_New] 41 | print(logits_New.size()) 42 | 43 | if logits_NoNew.size()[0]>0: 44 | loss = nn.CrossEntropyLoss(reduce=True)(logits_NoNew, truth_NoNew) 45 | else: 46 | loss = 0 47 | 48 | if logits_New.size()[0]>0: 49 | logits_New = torch.softmax(logits_New,1) 50 | logits_New = logits_New.topk(1,1,True,True)[0] 51 | target_New = torch.zeros_like(logits_New).float().cuda() 52 | loss += nn.L1Loss()(logits_New, target_New) 53 | 54 | return loss 55 | 56 | def metric(logit, truth, is_average=True, is_prob = False): 57 | if is_prob: 58 | prob = logit 59 | else: 60 | prob = F.softmax(logit, 1) 61 | 62 | value, top = prob.topk(5, dim=1, largest=True, sorted=True) 63 | correct = top.eq(truth.view(-1, 1).expand_as(top)) 64 | 65 | if is_average==True: 66 | # top-3 accuracy 67 | correct = correct.float().sum(0, keepdim=False) 68 | correct = correct/len(truth) 69 | 70 | top = [correct[0], 71 | correct[0] + correct[1], 72 | correct[0] + correct[1] + correct[2], 73 | correct[0] + correct[1] + correct[2] + correct[3], 74 | correct[0] + correct[1] + correct[2] + correct[3] + correct[4]] 75 | 76 | precision = correct[0] / 1 + correct[1] / 2 + correct[2] / 3 + correct[3] / 4 + correct[4] / 5 77 | 78 | return precision, top 79 | else: 80 | return correct 81 | 82 | def metric_top1(logit, truth, is_average=True, is_prob = False): 83 | if is_prob: 84 | prob = logit 85 | else: 86 | prob = F.softmax(logit, 1) 87 | 88 | value, top = prob.topk(1, dim=1, largest=True, sorted=True) 89 | correct = top.eq(truth.view(-1, 1).expand_as(top)) 90 | 91 | if is_average==True: 92 | # top-3 accuracy 93 | correct = correct.float().sum(0, keepdim=False) 94 | correct = correct/len(truth) 95 | top = [correct[0]] 96 | precision = correct[0] 97 | 98 | return precision, top 99 | else: 100 | return correct 101 | 102 | 103 | def top_n_np(preds, labels): 104 | n = 5 105 | predicted = np.fliplr(preds.argsort(axis=1)[:, -n:]) 106 | top5 = [] 107 | 108 | re = 0 109 | for i in range(len(preds)): 110 | predicted_tmp = predicted[i] 111 | labels_tmp = labels[i] 112 | for n_ in range(5): 113 | re += np.sum(labels_tmp == predicted_tmp[n_]) / (n_ + 1.0) 114 | 115 | re = re / len(preds) 116 | 117 | for i in range(n): 118 | top5.append(np.sum(labels == predicted[:, i])/ (1.0*len(labels))) 119 | 120 | return re, top5 121 | 122 | def metric_binary(logit, truth): 123 | prob = F.softmax(logit, 1) 124 | value, top = prob.topk(2, dim=1, largest=True, sorted=True) 125 | correct = top.eq(truth.view(-1, 1).expand_as(top)) 126 | correct = correct.float().sum(0, keepdim=False) 127 | correct = correct / len(truth) 128 | return correct[0] 129 | 130 | def metric_bce(logit, truth): 131 | prob = F.sigmoid(logit) 132 | prob[prob > 0.5] = 1 133 | prob[prob < 0.5] = 0 134 | correct = prob.eq(truth.view(-1, 1).expand_as(prob)) 135 | correct = correct.float().sum(0, keepdim=False) 136 | correct = correct/len(truth) 137 | return correct 138 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | path_data = './process/folds/' 2 | mean_std_path = r'./process/' 3 | model_save_path = r'./models' 4 | data_dir = r'./NIPS2019/recursion-cellular-image-classification' 5 | 6 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import os 7 | import math 8 | import argparse 9 | import pprint 10 | import tqdm 11 | from collections import defaultdict 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | import numpy as np 16 | 17 | from torch.utils.data import DataLoader 18 | from prefetch_generator import BackgroundGenerator 19 | 20 | class DataLoaderX(DataLoader): 21 | 22 | def __iter__(self): 23 | return BackgroundGenerator(super().__iter__()) 24 | 25 | class CrossEntropyLabelSmooth(nn.Module): 26 | """Cross entropy loss with label smoothing regularizer. 27 | 28 | Reference: 29 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 30 | Equation: y = (1 - epsilon) * y + epsilon / K. 31 | 32 | Args: 33 | num_classes (int): number of classes. 34 | epsilon (float): weight. 35 | """ 36 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 37 | super(CrossEntropyLabelSmooth, self).__init__() 38 | self.num_classes = num_classes 39 | self.epsilon = epsilon 40 | self.use_gpu = use_gpu 41 | self.logsoftmax = nn.LogSoftmax(dim=1) 42 | 43 | def forward(self, inputs, targets): 44 | """ 45 | Args: 46 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 47 | targets: ground truth labels with shape (num_classes) 48 | """ 49 | log_probs = self.logsoftmax(inputs) 50 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 51 | if self.use_gpu: targets = targets.cuda() 52 | targets = (1 - self.epsilon) * targets + self.epsilon / (self.num_classes*1.0) 53 | loss = (- targets * log_probs).mean(0).sum() 54 | return loss 55 | 56 | def split_weights(net): 57 | """split network weights into to categlories, 58 | one are weights in conv layer and linear layer, 59 | others are other learnable paramters(conv bias, 60 | bn weights, bn bias, linear bias) 61 | Args: 62 | net: network architecture 63 | 64 | Returns: 65 | a dictionary of params splite into to categlories 66 | """ 67 | 68 | decay = [] 69 | no_decay = [] 70 | 71 | for m in net.modules(): 72 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 73 | decay.append(m.weight) 74 | 75 | if m.bias is not None: 76 | no_decay.append(m.bias) 77 | 78 | else: 79 | if hasattr(m, 'weight'): 80 | no_decay.append(m.weight) 81 | if hasattr(m, 'bias'): 82 | no_decay.append(m.bias) 83 | 84 | assert len(list(net.parameters())) == len(decay) + len(no_decay) 85 | 86 | return [dict(params=decay), dict(params=no_decay, weight_decay=0)] 87 | 88 | 89 | def mixup_data(x, y, alpha=1.0, use_cuda=True): 90 | 91 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 92 | if alpha > 0.: 93 | lam = np.random.beta(alpha, alpha) 94 | else: 95 | lam = 1. 96 | 97 | batch_size = x.size()[0] 98 | 99 | if use_cuda: 100 | index = torch.randperm(batch_size).cuda() 101 | else: 102 | index = torch.randperm(batch_size) 103 | 104 | mixed_x = lam * x + (1 - lam) * x[index,:] 105 | y_a, y_b = y, y[index] 106 | return mixed_x, y_a, y_b, lam 107 | 108 | def mixup_criterion(y_a, y_b, lam): 109 | return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 110 | 111 | def save(list_or_dict,name): 112 | f = open(name, 'w') 113 | f.write(str(list_or_dict)) 114 | f.close() 115 | 116 | def load(name): 117 | f = open(name, 'r') 118 | a = f.read() 119 | tmp = eval(a) 120 | f.close() 121 | return tmp 122 | 123 | def dot_numpy(vector1 , vector2,emb_size = 512): 124 | vector1 = vector1.reshape([-1, emb_size]) 125 | vector2 = vector2.reshape([-1, emb_size]) 126 | vector2 = vector2.transpose(1,0) 127 | 128 | cosV12 = np.dot(vector1, vector2) 129 | return cosV12 130 | 131 | def to_var(x, volatile=False): 132 | if torch.cuda.is_available(): 133 | x = x.cuda() 134 | return Variable(x, volatile=volatile) 135 | 136 | def metric(logit, truth, is_average=True, is_prob = False, topn =5): 137 | 138 | if is_prob: 139 | prob = logit 140 | else: 141 | prob = F.softmax(logit, 1) 142 | 143 | value, top = prob.topk(topn, dim=1, largest=True, sorted=True) 144 | correct = top.eq(truth.view(-1, 1).expand_as(top)) 145 | 146 | if is_average==True: 147 | # top-3 accuracy 148 | correct = correct.float().sum(0, keepdim=False) 149 | correct = correct/len(truth) 150 | 151 | top = [correct[0], 152 | correct[0] + correct[1], 153 | correct[0] + correct[1] + correct[2], 154 | correct[0] + correct[1] + correct[2] + correct[3], 155 | correct[0] + correct[1] + correct[2] + correct[3] + correct[4]] 156 | 157 | precision = correct[0] / 1 + correct[1] / 2 + correct[2] / 3 + correct[3] / 4 + correct[4] / 5 158 | 159 | return precision, top 160 | else: 161 | correct = correct.float().sum(0, keepdim=False) 162 | correct = correct/len(truth) 163 | return correct 164 | 165 | def top_n_np(preds, labels): 166 | n = 5 167 | predicted = np.fliplr(preds.argsort(axis=1)[:, -n:]) 168 | top5 = [] 169 | 170 | re = 0 171 | for i in range(len(preds)): 172 | predicted_tmp = predicted[i] 173 | labels_tmp = labels[i] 174 | for n_ in range(5): 175 | re += np.sum(labels_tmp == predicted_tmp[n_]) / (n_ + 1.0) 176 | 177 | re = re / len(preds) 178 | 179 | for i in range(n): 180 | top5.append(np.sum(labels == predicted[:, i])/ (1.0*len(labels))) 181 | 182 | return re, top5 183 | 184 | def get_center(vectors): 185 | avg = np.mean(vectors, axis=0) 186 | if avg.ndim == 1: 187 | avg = avg / np.linalg.norm(avg) 188 | elif avg.ndim == 2: 189 | assert avg.shape[1] == 512 190 | avg = avg / np.linalg.norm(avg, axis=1, keepdims=True) 191 | else: 192 | assert False, avg.shape 193 | return avg 194 | 195 | def average_features(features, id_list): 196 | averaged_features = [] 197 | averaged_id_list = [] 198 | 199 | unique_ids = set(id_list) 200 | unique_ids = list(sorted(list(unique_ids))) 201 | 202 | for unique_id in unique_ids: 203 | assert unique_id != 'new_whale' 204 | cur_features = [feature for feature, Id 205 | in zip(features, id_list) if Id == unique_id] 206 | 207 | cur_features = np.stack(cur_features, axis=0) 208 | 209 | if len(cur_features) == 1: 210 | averaged_features.append(cur_features[0]) 211 | averaged_id_list.append(unique_id) 212 | else: 213 | averaged_feature = get_center(cur_features) 214 | averaged_features.append(averaged_feature) 215 | averaged_id_list.append(unique_id) 216 | 217 | averaged_features = np.stack(averaged_features, axis=0) 218 | assert averaged_features.shape[0] == len(averaged_id_list) 219 | 220 | return averaged_features, averaged_id_list 221 | 222 | def get_nearest_k(center, features, k, threshold): 223 | feature_with_dis = [(feature, np.dot(center, feature)) for feature in features] 224 | if len(feature_with_dis) > 10: 225 | distances = np.array([dis for _, dis in feature_with_dis]) 226 | 227 | filtered = [feature for feature, dis in feature_with_dis if dis > 0.5] 228 | if len(filtered) < len(feature_with_dis): 229 | distances = np.array([feature for feature, dis in feature_with_dis if dis <= 0.5]) 230 | if len(filtered) > k: 231 | return filtered 232 | feature_with_dis = [feature for feature, dis in sorted(feature_with_dis, key=lambda v: v[1], reverse=True)] 233 | return feature_with_dis[:k] 234 | 235 | def get_image_center(features): 236 | if len(features) < 4: 237 | return get_center(features) 238 | 239 | for _ in range(2): 240 | center = get_center(features) 241 | features = get_nearest_k(center, features, int(len(features) * 3 / 4), 0.5) 242 | 243 | if len(features) < 4: 244 | break 245 | 246 | return get_center(features) 247 | 248 | 249 | def sigmoid_rampup(current, rampup_length): 250 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 251 | if rampup_length == 0: 252 | return 1.0 253 | else: 254 | current = np.clip(current, 0.0, rampup_length) 255 | phase = 1.0 - current / rampup_length 256 | return float(np.exp(-5.0 * phase * phase)) 257 | 258 | 259 | def linear_rampup(current, rampup_length): 260 | """Linear rampup""" 261 | assert current >= 0 and rampup_length >= 0 262 | if current >= rampup_length: 263 | return 1.0 264 | else: 265 | return current / rampup_length 266 | 267 | 268 | def cosine_rampdown(current, rampdown_length): 269 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 270 | assert 0 <= current <= rampdown_length 271 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 272 | --------------------------------------------------------------------------------