├── output
│   ├── log
│   ├── result_json
│   └── runs
│       └── result_weight
├── model
│   ├── submission_model_weight
│   └── preprocessing
│       └── preprocessing_model_weight
├── DATA
│   └── Final_DATA
│       ├── task02_test
│       │   └── test_path
│       └── task02_train
│           └── train_path
├── .gitattributes
├── Image
│   ├── task.png
│   ├── score.png
│   └── summary.png
├── solution.pdf
├── train.py
├── modules
│   ├── utils.py
│   ├── dataset.py
│   ├── scheduler.py
│   ├── transform.py
│   ├── models.py
│   ├── preprocessing.py
│   └── solver.py
├── preprocess.py
├── README.md
├── config
│   └── config.yaml
└── predict.py
/output/log:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/output/result_json:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/output/runs/result_weight:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/submission_model_weight:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/DATA/Final_DATA/task02_test/test_path:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/DATA/Final_DATA/task02_train/train_path:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/preprocessing/preprocessing_model_weight:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/Image/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choco9966/Korean-Hair-Segmentation/HEAD/Image/task.png
--------------------------------------------------------------------------------
/solution.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choco9966/Korean-Hair-Segmentation/HEAD/solution.pdf
--------------------------------------------------------------------------------
/Image/score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choco9966/Korean-Hair-Segmentation/HEAD/Image/score.png
--------------------------------------------------------------------------------
/Image/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choco9966/Korean-Hair-Segmentation/HEAD/Image/summary.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
5 | from modules.utils import *
6 | from modules.preprocessing import *
7 | from 
modules.solver import train_loop, train_pre_loop 8 | 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | 12 | CONFIG_PATH = './config/config.yaml' 13 | config = load_yaml(CONFIG_PATH) 14 | 15 | # TRAIN 16 | TRAIN = config['TRAIN']['train'] 17 | N_FOLD = config['TRAIN']['n_fold'] 18 | TRN_FOLD = config['TRAIN']['trn_fold'] 19 | 20 | dataset_path = './DATA/Final_DATA/task02_train' 21 | test_dataset_path = './DATA/Final_DATA/task02_test' 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | def main(LOGGER): 26 | if TRAIN: 27 | data = iou_preprocessing(dataset_path) 28 | for fold in range(N_FOLD): 29 | if fold in TRN_FOLD: 30 | train_loop(data, fold, LOGGER) 31 | 32 | 33 | if __name__ == '__main__': 34 | seed_torch() 35 | LOGGER = init_logger() 36 | main(LOGGER) 37 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import yaml 4 | import numpy as np 5 | 6 | import torch 7 | from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler 8 | 9 | 10 | OUTPUT_DIR = './output' 11 | if not os.path.exists(OUTPUT_DIR): 12 | os.makedirs(OUTPUT_DIR) 13 | 14 | def load_yaml(path): 15 | with open(path, 'r') as f: 16 | return yaml.load(f, Loader=yaml.FullLoader) 17 | 18 | 19 | def init_logger(log_file=OUTPUT_DIR+'/train.log'): 20 | logger = getLogger(__name__) 21 | logger.setLevel(INFO) 22 | handler1 = StreamHandler() 23 | handler1.setFormatter(Formatter('%(message)s')) 24 | handler2 = FileHandler(filename=log_file) 25 | handler2.setFormatter(Formatter('%(message)s')) 26 | logger.addHandler(handler1) 27 | logger.addHandler(handler2) 28 | return logger 29 | 30 | 31 | def seed_torch(seed=42): 32 | random.seed(seed) 33 | os.environ['PYTHONHASHSEED'] = str(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed(seed) 37 | torch.cuda.manual_seed_all(seed) 38 | torch.backends.cudnn.deterministic = True 39 | torch.backends.cudnn.benchmark = False 40 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler 5 | from modules.utils import * 6 | from modules.preprocessing import * 7 | from modules.solver import train_loop, train_pre_loop 8 | 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | 12 | CONFIG_PATH = './config/config.yaml' 13 | config = load_yaml(CONFIG_PATH) 14 | 15 | # PREPROCESSING 16 | PREPROCESSING = config['PREPROCESSING']['preprocessing'] 17 | PRE_N_FOLD = config['PREPROCESSING']['n_fold'] 18 | PRE_TRN_FOLD = config['PREPROCESSING']['trn_fold'] 19 | 20 | dataset_path = './DATA/Final_DATA/task02_train' 21 | 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | 25 | def main(LOGGER): 26 | if PREPROCESSING: 27 | make_masks(dataset_path) 28 | make_npixels_data(dataset_path) 29 | data_pre = make_pre_data() 30 | print('data loaded') 31 | print('Complete preprocessing') 32 | 33 | # train for preprocessing 34 | for fold in range(PRE_N_FOLD): 35 | if fold in PRE_TRN_FOLD: 36 | train_pre_loop(data_pre, fold, LOGGER) 37 | 38 | calculate_iou(data_pre, device) 39 | print('Complete calculate iou') 40 | 41 | 42 | if __name__ == '__main__': 43 | seed_torch() 44 | LOGGER = init_logger() 
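    # A sketch of the flow main() drives below (names and thresholds come from
    # this repo's config.yaml): two encoders are co-trained on the polygon
    # masks (train_pre_loop), their ensembled predictions are scored against
    # the provided labels (calculate_iou), and iou_preprocessing() in
    # modules/preprocessing.py later combines that per-image score
    #
    #     iou = intersection / union
    #     keep_iou = iou > config['PREPROCESSING']['iou_threshold']   # 0.80
    #
    # with a polygon-area z-score filter, so images with noisy annotations
    # are dropped before the main model in train.py is trained.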
45 |     main(LOGGER)
46 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :haircut: 2nd Place Solution for the Korean Hair Segmentation Competition
2 | 
3 | # Contents
4 | 
5 | #### Solution Description
6 | 
7 | #### Command Line Interface
8 | 
9 | - **Preprocess**
10 | - **Train**
11 | - **Predict**
12 | 
13 | 
14 | 
15 | # Solution Description
16 | 
17 | ## Subject
18 | 
19 | The task in this competition was to develop an AI model that performs hair segmentation on images of Korean people. As the figure below shows, the given images required segmenting a wide variety of hairstyles.
20 | 
21 | 
22 | 
23 | ![task](./Image/task.png)
24 | 
25 | ## Data
26 | 
27 | - Training data: about 200,000 hairstyle images at 512x512 resolution
28 | - Because the competition data cannot be shared for copyright reasons, the example images shown here are taken from the Korean hairstyle dataset provided by AI Hub.
29 | - Image data: https://aihub.or.kr/aidata/30752
30 | 
31 | 
32 | 
33 | ## Summary
34 | 
35 | ![summary](./Image/summary.png)
36 | 
37 | ## Score
38 | 
39 | ![image-20210914131303674](./Image/score.png)
40 | 
41 | 
42 | 
43 | ## Code Structure
44 | 
45 | ```bash
46 | ├── DATA
47 | │   ├── Final_DATA
48 | │   │   ├── task02_test
49 | │   │   └── task02_train
50 | │   ├── polygon_iou.json
51 | │   └── data.csv
52 | ├── config
53 | │   └── config.yaml
54 | ├── model
55 | │   ├── submission_model_weight
56 | │   └── preprocessing
57 | │       └── preprocessing_model_weight
58 | ├── modules
59 | │   ├── dataset.py
60 | │   ├── models.py
61 | │   ├── preprocessing.py
62 | │   ├── scheduler.py
63 | │   ├── solver.py
64 | │   ├── transform.py
65 | │   └── utils.py
66 | ├── output
67 | │   ├── log
68 | │   ├── runs
69 | │   │   └── output_model_weight
70 | │   └── result.json
71 | ├── README.md
72 | ├── predict.py
73 | └── train.py
74 | ```
75 | 
76 | # Command Line Interface
77 | 
78 | 
79 | 
80 | ## Preprocess
81 | 
82 | ```console
83 | python preprocess.py
84 | ```
85 | 
86 | 
87 | 
88 | ## Train
89 | 
90 | ```console
91 | python train.py
92 | ```
93 | 
94 | 
95 | 
96 | ## Predict
97 | 
98 | ```console
99 | python predict.py
100 | ```
101 | 
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | SEED:
2 |     random_seed: 42
3 | 
4 | DATALOADER:
5 |     num_workers: 0
6 | 
7 | DEBUG:
8 |     debug: False
9 | 
10 | PREPROCESSING:
11 |     img_size: 512
12 |     max_len: 275
13 |     print_freq: 1000
14 |     encoder_type: 'efficientnet-b0'
15 |     decoder_type: 'Unet'
16 |     size: 512
17 |     freeze_epo: 0
18 |     warmup_epo: 1
19 |     cosine_epo: 19
20 |     warmup_factor: 10
21 |     scheduler: 'GradualWarmupSchedulerV2'
22 |     factor: 0.2
23 |     patience: 4
24 |     eps: 1e-6
25 |     T_max: 4
26 |     T_0: 4
27 |     encoder_lr: 0.00003
28 |     min_lr: 0.000001
29 |     batch_size: 16
30 |     weight_decay: 0.000001
31 |     gradient_accumulation_steps: 1
32 |     max_grad_norm: 5
33 |     dropout: 0.5
34 |     n_fold: 2
35 |     trn_fold: [0, 1]
36 |     preprocessing: True
37 |     apex: False
38 |     load_state: False
39 |     npixel_threshold: 20000
40 |     npixel_for_iou: 2000
41 |     iou_threshold: 0.80
42 | 
43 | 
44 | TRAIN:
45 |     img_size: 512
46 |     max_len: 275
47 |     print_freq: 100
48 |     encoder_type: 'timm-efficientnet-b2'
49 |     decoder_type: 'UnetPlusPlus'
50 |     size: 512
51 |     freeze_epo: 0
52 |     warmup_epo: 1
53 |     cosine_epo: 19
54 |     warmup_factor: 10
55 |     scheduler: 'GradualWarmupSchedulerV2'
56 |     factor: 0.2
57 |     patience: 4
58 |     eps: 1e-6
59 |     T_max: 4
60 |     T_0: 4
61 |     encoder_lr: 0.00003
62 |     min_lr: 0.000001
63 |     batch_size: 32
64 |     weight_decay: 0.000001
65 |     gradient_accumulation_steps: 1
66 |     max_grad_norm: 5
67 |     dropout: 0.5
68 |     n_fold: 1
69 |     trn_fold: [0]
70 |     train: True
71 | 
apex: True 72 | load_state: False 73 | self_cutmix: True 74 | cutmix_threshold: 0.1 75 | coloring: True 76 | coloring_threshold: 0.1 77 | loss_smooth_factor: 0.01 78 | pretrained : True 79 | prospective_filtering : False 80 | augmentation : True 81 | break_epoch: 14 82 | 83 | INFERENCE: 84 | epoch_list: [13, 11, 12] 85 | mask_threshold: 14000 86 | epsilon: 0.555 87 | polygon_approx_threshold: 50 88 | 89 | LOG: 90 | log_day: '0630' 91 | version: 'v1-1' 92 | light: 'light' 93 | data_type: 'hair' 94 | comment: 'final' 95 | -------------------------------------------------------------------------------- /modules/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | class TrainDataset(Dataset): 6 | def __init__(self, data, transform=None): 7 | self.data = data 8 | self.transform = transform 9 | 10 | 11 | def __getitem__(self, idx): 12 | images = cv2.imread(self.data.loc[idx]['images']) 13 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB) 14 | masks = cv2.imread(self.data.loc[idx]['masks'])[:,:,0] 15 | masks = masks.astype(float) 16 | masks = np.expand_dims(masks, axis=2) 17 | 18 | if self.transform is not None: 19 | transformed = self.transform(image=images, mask=masks) 20 | images = transformed['image'] 21 | masks = transformed['mask'] 22 | 23 | return images, masks 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | 29 | class ValidDataset(Dataset): 30 | def __init__(self, data, transform=None): 31 | self.data = data 32 | self.transform = transform 33 | 34 | def __getitem__(self, idx): 35 | ids = self.data.loc[idx]['images'] 36 | images = cv2.imread(self.data.loc[idx]['images']) 37 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB) 38 | masks = cv2.imread(self.data.loc[idx]['masks'])[:,:,0] 39 | masks = masks.astype(float) 40 | masks = np.expand_dims(masks, axis=2) 41 | 42 | if self.transform is not None: 43 | transformed = self.transform(image=images, mask=masks) 44 | images = transformed["image"] 45 | masks = transformed["mask"] 46 | 47 | return ids, images, masks 48 | 49 | def __len__(self): 50 | return len(self.data) 51 | 52 | 53 | class TestDataset(Dataset): 54 | def __init__(self, data, transform=None): 55 | self.data = data 56 | self.transform = transform 57 | 58 | 59 | def __getitem__(self, idx): 60 | ids = self.data.loc[idx]['ids'] 61 | images = cv2.imread(self.data.loc[idx]['images']) 62 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB) 63 | 64 | if self.transform is not None: 65 | transformed = self.transform(image=images) 66 | images = transformed["image"] 67 | 68 | return ids, images 69 | 70 | def __len__(self): 71 | return len(self.data) 72 | -------------------------------------------------------------------------------- /modules/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from warmup_scheduler import GradualWarmupScheduler 5 | 6 | class GradualWarmupSchedulerV2(GradualWarmupScheduler): 7 | 8 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 9 | super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler) 10 | 11 | def get_lr(self): 12 | if self.last_epoch > self.total_epoch: 13 | if self.after_scheduler: 14 | if not self.finished: 15 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 16 | 
self.finished = True 17 | return self.after_scheduler.get_lr() 18 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 19 | 20 | if self.multiplier == 1.0: 21 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 22 | else: 23 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 24 | 25 | 26 | class DiceLoss(nn.Module): 27 | def __init__(self, weight=None, size_average=True): 28 | super(DiceLoss, self).__init__() 29 | 30 | def forward(self, inputs, targets, smooth=1): 31 | 32 | #comment out if your model contains a sigmoid or equivalent activation layer 33 | inputs = F.sigmoid(inputs) 34 | 35 | #flatten label and prediction tensors 36 | inputs = inputs.view(-1) 37 | targets = targets.view(-1) 38 | 39 | intersection = (inputs * targets).sum() 40 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 41 | 42 | return 1 - dice 43 | 44 | 45 | class DiceBCELoss(nn.Module): 46 | # Formula Given above. 47 | def __init__(self, weight=None, size_average=True): 48 | super(DiceBCELoss, self).__init__() 49 | 50 | def forward(self, inputs, targets, smooth=1): 51 | 52 | #comment out if your model contains a sigmoid or equivalent activation layer 53 | BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean') 54 | inputs = F.sigmoid(inputs) 55 | 56 | #flatten label and prediction tensors 57 | inputs = inputs.view(-1) 58 | targets = targets.view(-1) 59 | 60 | intersection = (inputs * targets).mean() 61 | dice_loss = 1 - (2.*intersection + smooth)/(inputs.mean() + targets.mean() + smooth) 62 | Dice_BCE = 0.9*BCE + 0.1*dice_loss 63 | 64 | return Dice_BCE.mean() 65 | 66 | 67 | def get_dice_coeff_ori(pred, targs, eps = 1e-9): 68 | ''' 69 | Calculates the dice coeff of a single or batch of predicted mask and true masks. 70 | 71 | Args: 72 | pred : Batch of Predicted masks (b, w, h) or single predicted mask (w, h) 73 | targs : Batch of true masks (b, w, h) or single true mask (w, h) 74 | 75 | Returns: Dice coeff over a batch or over a single pair. 76 | ''' 77 | p = (pred.view(-1) > 0).float() 78 | t = (targs.view(-1) > 0.5).float() 79 | dice = (2.0 * (p * t).sum() + eps)/ (p.sum() + t.sum() + eps) 80 | return dice 81 | 82 | 83 | def get_dice_coeff(pred, targs, eps = 1e-9): 84 | ''' 85 | Calculates the dice coeff of a single or batch of predicted mask and true masks. 86 | 87 | Args: 88 | pred : Batch of Predicted masks (b, w, h) or single predicted mask (w, h) 89 | targs : Batch of true masks (b, w, h) or single true mask (w, h) 90 | 91 | Returns: Dice coeff over a batch or over a single pair. 92 | ''' 93 | 94 | 95 | pred = (pred>0).float() 96 | return 2.0 * (pred*targs).sum() / ((pred+targs).sum() + eps) 97 | 98 | 99 | def reduce(values): 100 | ''' 101 | Returns the average of the values. 
102 |     Args:
103 |         values : list of any value which is calculated on each core
104 |     '''
105 |     return sum(values) / len(values)
106 | 
107 | # NOTE: lovasz_hinge is not defined or imported anywhere in this repo;
108 | # using symmetric_lovasz requires an external Lovasz-hinge loss implementation.
109 | def symmetric_lovasz(outputs, targets):
110 |     return 0.5*(lovasz_hinge(outputs, targets) + lovasz_hinge(-outputs, 1.0 - targets))
111 | 
112 | 
--------------------------------------------------------------------------------
/modules/transform.py:
--------------------------------------------------------------------------------
1 | import albumentations as A
2 | from albumentations.pytorch import ToTensorV2
3 | 
4 | from modules.utils import load_yaml
5 | 
6 | CONFIG_PATH = './config/config.yaml'
7 | config = load_yaml(CONFIG_PATH)
8 | 
9 | # TRAIN
10 | AUGMENTATION = config['TRAIN']['augmentation']
11 | 
12 | def get_transforms_preprocessing(*, data):
13 |     if data == 'train':
14 |         return A.Compose([
15 |             A.HorizontalFlip(p=0.5),
16 |             A.OneOf([
17 |                 A.RandomContrast(),
18 |                 A.RandomGamma(),
19 |                 A.RandomBrightness(),
20 |             ], p=0.3),
21 |             A.OneOf([
22 |                 A.GridDistortion(),
23 |                 A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
24 |             ], p=0.3),
25 |             A.GridDropout(p=0.1),
26 |             A.Normalize(
27 |                 mean=(0.485, 0.456, 0.406),
28 |                 std=(0.229, 0.224, 0.225)
29 |             ),
30 |             ToTensorV2(transpose_mask=True)
31 |         ],p=1.)
32 |     elif data == 'valid':
33 |         return A.Compose([
34 |             A.Normalize(
35 |                 mean=(0.485, 0.456, 0.406),
36 |                 std=(0.229, 0.224, 0.225)
37 |             ),
38 |             ToTensorV2(transpose_mask=True)
39 |         ],p=1.)
40 | 
41 | def get_transforms_train(*, data):
42 |     if data == 'train':
43 |         if AUGMENTATION == True:
44 |             return A.Compose([
45 |                 A.OneOf([
46 |                     A.HueSaturationValue(15, 25, 0),
47 |                 ],p=0.1),
48 |                 A.OneOf([
49 |                     A.RandomContrast(),
50 |                     A.RandomGamma(),
51 |                     A.RandomBrightness(),
52 |                 ], p=0.1),
53 |                 A.OneOf([
54 |                     A.GridDistortion(),
55 |                     A.OpticalDistortion(),
56 |                     A.GaussNoise(),
57 |                 ], p=0.1),
58 |                 A.HorizontalFlip(p=0.1),
59 |                 A.Cutout(),
60 |                 A.Normalize(
61 |                     mean=(0.485, 0.456, 0.406),
62 |                     std=(0.229, 0.224, 0.225)
63 |                 ),
64 |                 ToTensorV2(transpose_mask=True)
65 |             ],p=1.)
66 |         else:
67 |             return A.Compose([
68 |                 A.Normalize(
69 |                     mean=(0.485, 0.456, 0.406),
70 |                     std=(0.229, 0.224, 0.225)
71 |                 ),
72 |                 ToTensorV2(transpose_mask=True)
73 |             ],p=1.)
74 |     elif data == 'valid':
75 |         return A.Compose([
76 |             A.Normalize(
77 |                 mean=(0.485, 0.456, 0.406),
78 |                 std=(0.229, 0.224, 0.225)
79 |             ),
80 |             ToTensorV2(transpose_mask=True)
81 |         ],p=1.)
82 | 
83 | def get_transforms_inference(*, data):
84 |     if data == 'train':
85 |         return A.Compose([
86 |             A.Resize(512, 512,always_apply=True),
87 |             A.OneOf([
88 |                 A.RandomContrast(),
89 |                 A.RandomGamma(),
90 |                 A.RandomBrightness(),
91 |             ], p=0.3),
92 |             A.OneOf([
93 |                 A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
94 |                 A.GridDistortion(),
95 |                 A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
96 |             ], p=0.3),
97 |             A.ShiftScaleRotate(p=0.2),
98 |             A.GridDropout(p=0.1),
99 |             A.Resize(512,512,always_apply=True),
100 |             A.Normalize(
101 |                 mean=(0.485, 0.456, 0.406),
102 |                 std=(0.229, 0.224, 0.225)
103 |             ),
104 |             ToTensorV2(transpose_mask=True)
105 |         ],p=1.)
106 |     elif data == 'valid':
107 |         return A.Compose([
108 |             A.Normalize(
109 |                 mean=(0.485, 0.456, 0.406),
110 |                 std=(0.229, 0.224, 0.225)
111 |             ),
112 |             ToTensorV2(transpose_mask=True)
113 |         ],p=1.)
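# A minimal smoke test for the pipelines above (not part of the original
# project; the random image/mask below are stand-ins for real data, and it
# assumes it is run from the repo root, e.g. `python -m modules.transform`):
if __name__ == '__main__':
    import numpy as np
    image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
    mask = np.random.randint(0, 2, (512, 512, 1)).astype(np.float32)
    out = get_transforms_train(data='train')(image=image, mask=mask)
    # Normalize + ToTensorV2(transpose_mask=True) yield CHW tensors:
    # image -> (3, 512, 512), mask -> (1, 512, 512)
    print(out['image'].shape, out['mask'].shape)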
114 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import segmentation_models_pytorch as smp 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, encoder_name='timm-efficientnet-b3', decoder_name='Unet' , pretrained=False): 7 | super().__init__() 8 | if encoder_name in ['se_resnext50_32x4d', 'se_resnext101_32x4d']: 9 | encoder_weights = 'imagenet' 10 | else: 11 | encoder_weights = 'noisy-student' 12 | 13 | if pretrained == False: 14 | encoder_weights = None 15 | 16 | if decoder_name == 'Unet': 17 | self.encoder = smp.Unet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 18 | elif decoder_name == 'UnetPlusPlus': 19 | self.encoder = smp.UnetPlusPlus(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 20 | elif decoder_name == 'MAnet': 21 | self.encoder = smp.MAnet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 22 | elif decoder_name == 'Linknet': 23 | self.encoder = smp.Linknet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 24 | elif decoder_name == 'FPN': 25 | self.encoder = smp.FPN(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 26 | elif decoder_name == 'PSPNet': 27 | self.encoder = smp.PSPNet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 28 | elif decoder_name == 'PAN': 29 | self.encoder = smp.PAN(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 30 | elif decoder_name == 'DeepLabV3': 31 | self.encoder = smp.DeepLabV3(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 32 | elif decoder_name == 'DeepLabV3Plus': 33 | self.encoder = smp.DeepLabV3Plus(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 34 | else: 35 | raise ValueError(f"decoder_type : {decoder_name} is not exist") 36 | 37 | 38 | #@autocast() 39 | def forward(self, x): 40 | x = self.encoder(x) 41 | return x 42 | 43 | 44 | class InferenceEncoder(nn.Module): 45 | def __init__(self, encoder_name='timm-efficientnet-b3', decoder_name='Unet' , pretrained=False): 46 | super().__init__() 47 | if encoder_name in ['se_resnext50_32x4d', 'se_resnext101_32x4d']: 48 | encoder_weights = 'imagenet' 49 | else: 50 | encoder_weights = 'noisy-student' 51 | 52 | if decoder_name == 'Unet': 53 | self.encoder = smp.Unet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 54 | elif decoder_name == 'UnetPlusPlus': 55 | self.encoder = smp.UnetPlusPlus(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 56 | elif decoder_name == 'MAnet': 57 | self.encoder = smp.MAnet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 58 | elif decoder_name == 'Linknet': 59 | self.encoder = smp.Linknet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 60 | elif decoder_name == 'FPN': 61 | self.encoder = smp.FPN(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 62 | elif decoder_name == 'PSPNet': 63 | 
self.encoder = smp.PSPNet(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 64 | elif decoder_name == 'PAN': 65 | self.encoder = smp.PAN(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 66 | elif decoder_name == 'DeepLabV3': 67 | self.encoder = smp.DeepLabV3(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 68 | elif decoder_name == 'DeepLabV3Plus': 69 | self.encoder = smp.DeepLabV3Plus(encoder_name=encoder_name, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 70 | else: 71 | raise ValueError(f"decoder_type : {decoder_name} is not exist") 72 | 73 | 74 | #@autocast() 75 | def forward(self, x): 76 | x = self.encoder(x) 77 | return x 78 | -------------------------------------------------------------------------------- /modules/preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import cv2 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from shapely.geometry import Polygon 9 | from scipy import stats 10 | from tqdm import tqdm 11 | from PIL import Image, ImageDraw 12 | from torch.utils.data import DataLoader 13 | from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold 14 | 15 | from modules.utils import load_yaml 16 | from modules.dataset import ValidDataset 17 | from modules.transform import get_transforms_preprocessing 18 | from modules.scheduler import get_dice_coeff 19 | from modules.models import Encoder 20 | 21 | CONFIG_PATH = './config/config.yaml' 22 | config = load_yaml(CONFIG_PATH) 23 | 24 | # PREPROCESSING 25 | ENCODER_TYPE = config['PREPROCESSING']['encoder_type'] 26 | DECODER_TYPE = config['PREPROCESSING']['decoder_type'] 27 | PRE_N_FOLD = config['PREPROCESSING']['n_fold'] 28 | PRE_TRN_FOLD = config['PREPROCESSING']['trn_fold'] 29 | NPIXEL_THRESHOLD = config['PREPROCESSING']['npixel_threshold'] 30 | NPIXEL_FOR_IOU = config['PREPROCESSING']['npixel_for_iou'] 31 | IOU_THRESHOLD = config['PREPROCESSING']['iou_threshold'] 32 | 33 | # SEED 34 | RANDOM_SEED = config['SEED']['random_seed'] 35 | 36 | # DATALOADER 37 | NUM_WORKERS = config['DATALOADER']['num_workers'] 38 | 39 | # LOG 40 | VERSION = config['LOG']['version'] 41 | 42 | def make_masks(dataset_path): 43 | if not os.path.exists(dataset_path+'/masks'): 44 | os.makedirs(dataset_path+'/masks') 45 | 46 | with open(dataset_path+"/labels.json", "r") as l: 47 | data = json.load(l) 48 | 49 | entries = {} 50 | for idx, files in enumerate(data['annotations']): 51 | entries[files['file_name']]=[] 52 | for polygon in files['polygon1']: 53 | entries[files['file_name']].append(tuple(polygon.values())) 54 | 55 | width = height = 512 56 | for name, data in tqdm(entries.items()): 57 | img = Image.new('L', (width, height), 0) 58 | ImageDraw.Draw(img).polygon(data, outline=1, fill=1) 59 | img.save(os.path.join(dataset_path+'/masks',os.path.splitext(name)[0])+'.png') 60 | 61 | 62 | def make_pre_data(): 63 | data_pre = pd.read_csv('./DATA/data.csv') 64 | data_pre = data_pre.sort_values(by='npixels', ascending=True) 65 | data_pre = data_pre.reset_index(drop=True) 66 | data_pre = data_pre[data_pre['npixels'] > NPIXEL_FOR_IOU].reset_index(drop=True) 67 | 68 | data_pre['npixels_bins'] = pd.qcut(data_pre['npixels'], q=10, retbins=True, labels=False)[0].values 69 | skf = StratifiedKFold(PRE_N_FOLD, random_state = RANDOM_SEED, shuffle=True) 70 | 71 | data_pre['fold'] = 
0 72 | for fold, (tr_idx, val_idx) in enumerate(skf.split(data_pre, y=data_pre['npixels_bins'])): 73 | data_pre.loc[val_idx, 'fold'] = fold 74 | 75 | return data_pre 76 | 77 | 78 | def make_npixels_data(dataset_path): 79 | data_npixels = pd.DataFrame() 80 | data_npixels['images'] = [dataset_path + '/images/' + c.split('.')[0] + '.jpg' for c in sorted(os.listdir(dataset_path + '/masks'))] 81 | data_npixels['masks'] = [dataset_path + '/masks/' + c.split('.')[0] + '.png' for c in sorted(os.listdir(dataset_path + '/masks'))] 82 | data_npixels['npixels'] = data_npixels['masks'].apply(lambda x: np.count_nonzero(cv2.imread(x))/3) 83 | # data_npixels = data_npixels.sort_values(by='npixels', ascending=True) 84 | data_npixels.to_csv('./DATA/data.csv', index=False) 85 | 86 | 87 | def make_data(): 88 | data = pd.read_csv('./DATA/data.csv') 89 | data = data.sort_values(by='npixels',ascending=True) 90 | data = data.loc[data.npixels >= NPIXEL_THRESHOLD] 91 | #data = data.iloc[20000:-13000] 92 | data = data.reset_index(drop=True) 93 | #data['images'] = data['images'].str.replace('data/train','DATA/Final_DATA/task02_train') 94 | #data['masks'] = data['masks'].str.replace('data/train','DATA/Final_DATA/task02_train') 95 | 96 | return data 97 | 98 | def calculate_iou(data, device): 99 | target_dir = './DATA/iou_result/preds' 100 | if not os.path.exists(target_dir): 101 | os.makedirs(target_dir) 102 | 103 | image_ids = [] 104 | dice_score = [] 105 | for fold in PRE_TRN_FOLD: 106 | valids = data[data['fold'] == fold].reset_index(drop=True) 107 | valid_dataset = ValidDataset(valids, transform=get_transforms_preprocessing(data='valid')) 108 | valid_loader = DataLoader(valid_dataset, 109 | batch_size=128, 110 | shuffle=False, 111 | num_workers=NUM_WORKERS, 112 | pin_memory=True, 113 | drop_last=False) 114 | 115 | model_path1 = f'./model/preprocessing/{ENCODER_TYPE}_{DECODER_TYPE}_fold{fold}_{VERSION}_1.pth' 116 | encoder1 = Encoder(ENCODER_TYPE, DECODER_TYPE, pretrained=False) 117 | 118 | checkpoint = torch.load(model_path1, map_location=device) 119 | encoder1.load_state_dict(checkpoint['encoder']) 120 | encoder1.to(device) 121 | 122 | model_path2 = f'./model/preprocessing/{ENCODER_TYPE}_{DECODER_TYPE}_fold{fold}_{VERSION}_2.pth' 123 | encoder2 = Encoder(ENCODER_TYPE, DECODER_TYPE, pretrained=False) 124 | 125 | checkpoint = torch.load(model_path2, map_location=device) 126 | encoder2.load_state_dict(checkpoint['encoder']) 127 | encoder2.to(device) 128 | 129 | encoder1.eval() 130 | encoder2.eval() 131 | 132 | for step, (file_name, images, targets) in tqdm(enumerate(valid_loader)): 133 | images = images.to(device) 134 | targets = targets.to(device) 135 | with torch.no_grad(): 136 | y_preds1 = encoder1(images) 137 | y_preds2 = encoder2(images) 138 | y_preds = (y_preds1 + y_preds2)/2 139 | 140 | # prepare mask 141 | mask = y_preds 142 | mask[mask >= 0] = 255 143 | mask[mask<0] = 0 144 | mask = mask.cpu() 145 | 146 | 147 | for j, m in enumerate(mask): 148 | path = os.path.join(target_dir, file_name[j].split('/')[-1].split('.')[0]+'.png') 149 | cv2.imwrite(path,np.array(m[0],dtype=np.uint8)) 150 | 151 | ref_dir = './DATA/Final_DATA/task02_train/masks' 152 | file_list = os.listdir(target_dir) 153 | iou = [0]*len(file_list) 154 | for i, file_name in enumerate(tqdm(file_list)): 155 | target_mask = cv2.imread(os.path.join(target_dir,file_name), cv2.IMREAD_GRAYSCALE) 156 | target_mask[target_mask>0] = 1 157 | ref_mask = cv2.imread(os.path.join(ref_dir,file_name), cv2.IMREAD_GRAYSCALE) 158 | ref_mask[ref_mask>0] = 1 159 | 160 
| intersection = np.sum(target_mask*ref_mask) 161 | union = np.sum(target_mask) + np.sum(ref_mask) - intersection 162 | iou[i] = intersection/union 163 | 164 | sorted_index = list(range(len(iou))) 165 | sorted_index = sorted(sorted_index, key=lambda x: iou[x]) 166 | 167 | polygon_iou = dict() 168 | for i in sorted_index: 169 | polygon_iou[file_list[i]] = iou[i] 170 | 171 | with open('./DATA/polygon_iou.json','w') as json_file: 172 | json.dump(polygon_iou, json_file) 173 | 174 | def iou_preprocessing(dataset_path): 175 | 176 | iou_threshold = IOU_THRESHOLD 177 | 178 | polygon_iou = json.load(open('./DATA/polygon_iou.json','rb')) 179 | iou_dic ={} 180 | for i, k in enumerate(polygon_iou): 181 | each ={} 182 | each['file_id'] = k.split('.')[0] 183 | each['iou'] = polygon_iou[k] 184 | iou_dic[i] = each 185 | iou_pd = pd.DataFrame(iou_dic.values()) 186 | iou_pd = iou_pd[iou_pd.iou>iou_threshold] 187 | 188 | train_label = json.load(open(dataset_path + '/labels.json','rb')) 189 | annotations = train_label['annotations'] 190 | 191 | lst = [] 192 | new_dic ={} 193 | for i,v in enumerate(annotations): 194 | x,y = [x['x'] for x in v['polygon1']],[x['y'] for x in v['polygon1']] 195 | each ={} 196 | pgon = Polygon(zip(x, y)) 197 | each['counts'] = len(v['polygon1']) 198 | each['area'] = pgon.area 199 | each['file_id'] = v['file_name'].split('.')[0] 200 | new_dic[i] = each 201 | lst.append(len(v['polygon1'])) 202 | 203 | area_pd = pd.DataFrame(new_dic.values()) 204 | area_pd['zscore'] = stats.zscore(area_pd.area) 205 | area_pd = area_pd[np.abs(area_pd.zscore)<1.2] 206 | target = pd.merge(iou_pd,area_pd,on='file_id',how='outer').reset_index(drop=True) 207 | 208 | data = pd.DataFrame() 209 | data['images'] = sorted(['./DATA/Final_DATA/task02_train/images/'+x+'.jpg' for x in target.file_id.values]) 210 | data['masks'] = sorted(['./DATA/Final_DATA/task02_train/masks/'+x+'.png' for x in target.file_id.values]) 211 | 212 | FOLDS = 3 # use only 20% data 213 | kf = KFold(FOLDS, random_state=RANDOM_SEED, shuffle=True) 214 | 215 | data['fold'] = 0 216 | for fold, (tr_idx, val_idx) in enumerate(kf.split(data)): 217 | data.loc[val_idx, 'fold'] = fold 218 | 219 | return data 220 | 221 | def PolyArea(x,y): 222 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import sys 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch 6 | import torch.nn as nn 7 | import os 8 | from albumentations.pytorch import ToTensorV2 9 | import segmentation_models_pytorch as smp 10 | import albumentations as A 11 | import argparse 12 | from tqdm import tqdm 13 | import torch.nn.functional as F 14 | import datetime 15 | import cv2 16 | from PIL import Image 17 | import numpy as np 18 | def get_args(): 19 | parser = argparse.ArgumentParser(description='Hair Segmentation') 20 | parser.add_argument('--testset_path', type=str, default='testset_path') # korean, celeb 21 | parser.add_argument('--encoder', type=str, default='encoder') 22 | parser.add_argument('--decoder', type=str, default='decoder') 23 | args = parser.parse_args() 24 | return args 25 | 26 | # args = get_args() 27 | 28 | testset_path = './DATA/Final_DATA/task02_test' 29 | encoder = 'timm-efficientnet-b2' 30 | decoder = 'UnetPlusPlus' 31 | 32 | 33 | 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | 
#device = torch.device('cpu') 36 | def get_transforms(*, data): 37 | if data == 'train': 38 | return A.Compose([ 39 | A.Resize(512, 512,always_apply=True), 40 | A.OneOf([ 41 | A.RandomContrast(), 42 | A.RandomGamma(), 43 | A.RandomBrightness(), 44 | ], p=0.3), 45 | A.OneOf([ 46 | A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03), 47 | A.GridDistortion(), 48 | A.OpticalDistortion(distort_limit=2, shift_limit=0.5), 49 | ], p=0.3), 50 | A.ShiftScaleRotate(p=0.2), 51 | A.GridDropout(p=0.1), 52 | A.Resize(512,512,always_apply=True), 53 | A.Normalize( 54 | mean=(0.485, 0.456, 0.406), 55 | std=(0.229, 0.224, 0.225) 56 | ), 57 | ToTensorV2(transpose_mask=True) 58 | ],p=1.) 59 | elif data == 'valid': 60 | return A.Compose([ 61 | A.Normalize( 62 | mean=(0.485, 0.456, 0.406), 63 | std=(0.229, 0.224, 0.225) 64 | ), 65 | ToTensorV2(transpose_mask=True) 66 | ],p=1.) 67 | 68 | 69 | 70 | class CFG: 71 | encoder_type= encoder 72 | decoder_type= decoder 73 | num_workers = 0 74 | class Encoder(nn.Module): 75 | def __init__(self, encoder_name='timm-efficientnet-b3', decoder_name='Unet' , pretrained=False): 76 | super().__init__() 77 | if CFG.encoder_type in ['se_resnext50_32x4d', 'se_resnext101_32x4d']: 78 | encoder_weights = 'imagenet' 79 | else: 80 | encoder_weights = 'noisy-student' 81 | 82 | if CFG.decoder_type == 'Unet': 83 | self.encoder = smp.Unet(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 84 | elif CFG.decoder_type == 'UnetPlusPlus': 85 | self.encoder = smp.UnetPlusPlus(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 86 | elif CFG.decoder_type == 'MAnet': 87 | self.encoder = smp.MAnet(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 88 | elif CFG.decoder_type == 'Linknet': 89 | self.encoder = smp.Linknet(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 90 | elif CFG.decoder_type == 'FPN': 91 | self.encoder = smp.FPN(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 92 | elif CFG.decoder_type == 'PSPNet': 93 | self.encoder = smp.PSPNet(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 94 | elif CFG.decoder_type == 'PAN': 95 | self.encoder = smp.PAN(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 96 | elif CFG.decoder_type == 'DeepLabV3': 97 | self.encoder = smp.DeepLabV3(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 98 | elif CFG.decoder_type == 'DeepLabV3Plus': 99 | self.encoder = smp.DeepLabV3Plus(encoder_name=CFG.encoder_type, encoder_weights=encoder_weights, classes=1) # [imagenet, noisy-student] 100 | else: 101 | raise ValueError(f"decoder_type : {CFG.decoder_type} is not exist") 102 | 103 | 104 | #@autocast() 105 | def forward(self, x): 106 | x = self.encoder(x) 107 | return x 108 | 109 | 110 | 111 | 112 | 113 | test_df = pd.DataFrame() 114 | test_df['images'] = [testset_path + '/images/' + c.split('.')[0] + '.jpg' for c in sorted(os.listdir(testset_path + '/images'))] 115 | test_df['ids'] = test_df['images'].apply(lambda x: x.split('/')[-1]) 116 | 117 | class TestDataset(Dataset): 118 | def __init__(self, data, transform=None): 119 | self.data = data 120 | self.transform = transform 121 | 122 | 123 | def __getitem__(self, idx): 124 | ids = 
self.data.loc[idx]['ids'] 125 | images = cv2.imread(self.data.loc[idx]['images']) 126 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB) 127 | 128 | if self.transform is not None: 129 | transformed = self.transform(image=images) 130 | images = transformed["image"] 131 | 132 | return ids, images 133 | 134 | def __len__(self): 135 | return len(self.data) 136 | 137 | # ==================================================== 138 | # loader 139 | # ==================================================== 140 | test_dataset = TestDataset(test_df, transform=get_transforms(data='valid')) 141 | 142 | 143 | test_loader = DataLoader(test_dataset, 144 | batch_size=32, 145 | shuffle=False, 146 | num_workers=CFG.num_workers, 147 | pin_memory=True, 148 | drop_last=False) 149 | 150 | collect = {"annotations":[]} 151 | 152 | def mergeContours(cnt1, cnt2): 153 | minDist1 = np.inf 154 | minInd1 = None 155 | for ind, point in enumerate(cnt1): 156 | dist = cv2.pointPolygonTest(cnt2, (int(point[0][0]), int(point[0][1])), True) 157 | dist = abs(dist) 158 | if dist < minDist1: 159 | minDist1 = dist 160 | minInd1 = ind 161 | cnt1 = np.roll(cnt1, -minInd1, axis=0) 162 | 163 | minDist2 = np.inf 164 | minInd2 = None 165 | for ind, point in enumerate(cnt2): 166 | dist = cv2.pointPolygonTest(cnt1, (int(point[0][0]), int(point[0][1])), True) 167 | dist = abs(dist) 168 | if dist < minDist2: 169 | minDist2 = dist 170 | minInd2 = ind 171 | cnt2 = np.roll(cnt2, -minInd2, axis=0) 172 | mreged_cnt = np.concatenate([cnt1,cnt2],axis=0) 173 | return mreged_cnt 174 | 175 | import ttach as tta 176 | transforms = tta.Compose( 177 | [ 178 | tta.HorizontalFlip() 179 | ] 180 | ) 181 | 182 | model_path = f'./output/runs/timm-efficientnet-b2_UnetPlusPlus_fold0___hair_66%_image512_iou80_pixel12_DiceBCE_cutmixColor01_epoch13.pth' 183 | encoder1 = Encoder(CFG.encoder_type, CFG.decoder_type, pretrained=False) 184 | checkpoint = torch.load(model_path, map_location=device) 185 | encoder1.load_state_dict(checkpoint['encoder']) 186 | encoder1.to(device) 187 | model_path = f'./output/runs/timm-efficientnet-b2_UnetPlusPlus_fold0___hair_66%_image512_iou80_pixel12_DiceBCE_cutmixColor01_epoch11.pth' 188 | encoder2 = Encoder(CFG.encoder_type, CFG.decoder_type, pretrained=False) 189 | checkpoint = torch.load(model_path, map_location=device) 190 | encoder2.load_state_dict(checkpoint['encoder']) 191 | encoder2.to(device) 192 | model_path = f'./output/runs/timm-efficientnet-b2_UnetPlusPlus_fold0___hair_66%_image512_iou80_pixel12_DiceBCE_cutmixColor01_epoch12.pth' 193 | encoder3 = Encoder(CFG.encoder_type, CFG.decoder_type, pretrained=False) 194 | checkpoint = torch.load(model_path, map_location=device) 195 | encoder3.load_state_dict(checkpoint['encoder']) 196 | encoder3.to(device) 197 | tta_model1 = tta.SegmentationTTAWrapper(encoder1, transforms) 198 | tta_model2 = tta.SegmentationTTAWrapper(encoder2, transforms) 199 | tta_model3 = tta.SegmentationTTAWrapper(encoder3, transforms) 200 | tta_model1.eval() 201 | encoder1.eval() 202 | tta_model2.eval() 203 | encoder2.eval() 204 | tta_model3.eval() 205 | encoder3.eval() 206 | for step, (ids, images) in enumerate(tqdm(test_loader)): 207 | images = images.to(device) 208 | with torch.no_grad(): 209 | y_preds = (1/6)*(encoder1(images)+tta_model1(images)+encoder2(images)+tta_model2(images) 210 | + encoder3(images)+tta_model3(images)) 211 | 212 | y_preds = F.sigmoid(y_preds) 213 | y_preds = y_preds > 0.5 214 | 215 | # n_mask = torch.sum(y_preds > 0.5, dim=(1,2,3)) 216 | # mask_threshold = 14000 217 | # small_mask_ind = 
torch.where(n_mask < mask_threshold)
218 |     # y_preds[small_mask_ind] = torch.where(y_preds[small_mask_ind] > 0.55, 1.0, 0.0)
219 |     # y_preds = torch.where(y_preds > 0.49, 1.0, 0.0)
220 | 
221 |     y_preds = y_preds.detach().cpu().numpy()
222 |     y_preds = np.uint8(y_preds*255)
223 | 
224 |     for idx, y_pred in enumerate(y_preds):
225 |         y_pred = y_pred.transpose(1, 2, 0)
226 |         contours, hierarchy = cv2.findContours(y_pred, cv2.RETR_LIST ,cv2.CHAIN_APPROX_NONE)
227 |         contours = sorted(contours, key= lambda x: cv2.contourArea(x), reverse=True)
228 |         if len(contours) == 0:
229 |             cnt = np.array([[[0,0]],[[0,0]],[[0,0]]])
230 |         else:
231 |             cnt = contours[0]
232 |             for i in range(1,len(contours)):
233 |                 if cv2.contourArea(contours[i]) > 75:
234 |                     cnt = mergeContours(cnt, contours[i])
235 |         # The larger this value, the fewer points the approximated polygon keeps
236 |         epsilon = 0.555
237 |         approx = cv2.approxPolyDP(cnt, epsilon, True)
238 |         if len(approx) < 50:
239 |             polygons = cnt
240 |         else:
241 |             polygons = approx
242 |         x_and_y = [{"x":int(a[0][0]), "y": int(a[0][1])} for a in polygons]
243 |         collect["annotations"].append({"file_name" : ids[idx],
244 |                                        "polygon1" : x_and_y})
245 | with open('result.json','w') as json_file:
246 |     json.dump(collect, json_file)
--------------------------------------------------------------------------------
/modules/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 | from torch.cuda.amp import autocast
11 | from torch.optim import Adam
12 | from torch import nn, Tensor
13 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
14 | from typing import Optional
15 | from torch.utils.data import Dataset, DataLoader
16 | from warmup_scheduler import GradualWarmupScheduler
17 | from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
18 | 
19 | from modules.utils import load_yaml, init_logger
20 | from modules.scheduler import *
21 | from modules.models import Encoder
22 | from modules.dataset import *
23 | from modules.transform import get_transforms_train, get_transforms_preprocessing
24 | 
25 | CONFIG_PATH = './config/config.yaml'
26 | config = load_yaml(CONFIG_PATH)
27 | 
28 | # LOG
29 | VERSION = config['LOG']['version']
30 | 
31 | # DATALOADER
32 | NUM_WORKERS = config['DATALOADER']['num_workers']
33 | 
34 | # PREPROCESSING
35 | PRE_GRADIENT_ACCUMULATION_STEPS = config['PREPROCESSING']['gradient_accumulation_steps']
36 | PRE_MAX_GRAD_NORM = config['PREPROCESSING']['max_grad_norm']
37 | PRE_APEX = config['PREPROCESSING']['apex']
38 | PRE_PRINT_FREQ = config['PREPROCESSING']['print_freq']
39 | PRE_SCHEDULER = config['PREPROCESSING']['scheduler']
40 | PRE_PATIENCE = config['PREPROCESSING']['patience']
41 | PRE_EPS = config['PREPROCESSING']['eps']
42 | PRE_T_MAX = config['PREPROCESSING']['T_max']
43 | PRE_MIN_LR = config['PREPROCESSING']['min_lr']
44 | PRE_T_0 = config['PREPROCESSING']['T_0']
45 | PRE_COSINE_EPO = config['PREPROCESSING']['cosine_epo']
46 | PRE_WARMUP_EPO = config['PREPROCESSING']['warmup_epo']
47 | PRE_FREEZE_EPO = config['PREPROCESSING']['freeze_epo']
48 | PRE_WARMUP_FACTOR = config['PREPROCESSING']['warmup_factor']
49 | PRE_ENCODER_TYPE = config['PREPROCESSING']['encoder_type']
50 | PRE_DECODER_TYPE = config['PREPROCESSING']['decoder_type']
51 | PRE_WEIGHT_DECAY = config['PREPROCESSING']['weight_decay']
52 | PRE_ENCODER_LR 
= config['PREPROCESSING']['encoder_lr'] 53 | PRE_BATCH_SIZE = config['PREPROCESSING']['batch_size'] 54 | PRE_EPOCHS = PRE_COSINE_EPO + PRE_WARMUP_EPO + PRE_FREEZE_EPO 55 | 56 | # TRAIN 57 | GRADIENT_ACCUMULATION_STEPS = config['TRAIN']['gradient_accumulation_steps'] 58 | MAX_GRAD_NORM = config['TRAIN']['max_grad_norm'] 59 | APEX = config['TRAIN']['apex'] 60 | PRINT_FREQ = config['TRAIN']['print_freq'] 61 | SCHEDULER = config['TRAIN']['scheduler'] 62 | PATIENCE = config['TRAIN']['patience'] 63 | EPS = config['TRAIN']['eps'] 64 | T_MAX = config['TRAIN']['T_max'] 65 | MIN_LR = config['TRAIN']['min_lr'] 66 | T_0 = config['TRAIN']['T_0'] 67 | COSINE_EPO = config['TRAIN']['cosine_epo'] 68 | WARMUP_EPO = config['TRAIN']['warmup_epo'] 69 | FREEZE_EPO = config['TRAIN']['freeze_epo'] 70 | WARMUP_FACTOR = config['TRAIN']['warmup_factor'] 71 | ENCODER_TYPE = config['TRAIN']['encoder_type'] 72 | DECODER_TYPE = config['TRAIN']['decoder_type'] 73 | WEIGHT_DECAY = config['TRAIN']['weight_decay'] 74 | ENCODER_LR = config['TRAIN']['encoder_lr'] 75 | BATCH_SIZE = config['TRAIN']['batch_size'] 76 | SELF_CUTMIX = config['TRAIN']['self_cutmix'] 77 | CUTMIX_THRESHOLD = config['TRAIN']['cutmix_threshold'] 78 | COLORING = config['TRAIN']['coloring'] 79 | COLORING_THRESHOLD = config['TRAIN']['coloring_threshold'] 80 | LOSS_SMOOTH_FACTOR = config['TRAIN']['loss_smooth_factor'] 81 | PROSPECTIVE_FILTERING = config['TRAIN']['prospective_filtering'] 82 | PRETRAINED = config['TRAIN']['pretrained'] 83 | BREAK_EPOCH = config['TRAIN']['break_epoch'] 84 | 85 | EPOCHS = COSINE_EPO + WARMUP_EPO + FREEZE_EPO 86 | 87 | 88 | 89 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 90 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 91 | OUTPUT_DIR = 'model/preprocessing' 92 | OUTPUT_TRAIN_DIR = 'output/runs' 93 | 94 | 95 | class AverageMeter(object): 96 | """Computes and stores the average and current value""" 97 | def __init__(self): 98 | self.reset() 99 | 100 | def reset(self): 101 | self.val = 0 102 | self.avg = 0 103 | self.sum = 0 104 | self.count = 0 105 | 106 | def update(self, val, n=1): 107 | self.val = val 108 | self.sum += val * n 109 | self.count += n 110 | self.avg = self.sum / self.count 111 | 112 | 113 | def asMinutes(s): 114 | m = math.floor(s / 60) 115 | s -= m * 60 116 | return '%dm %ds' % (m, s) 117 | 118 | 119 | def timeSince(since, percent): 120 | now = time.time() 121 | s = now - since 122 | es = s / (percent) 123 | rs = es - s 124 | return '%s (remain %s)' % (asMinutes(s), asMinutes(rs)) 125 | 126 | def rand_bbox(size, rat): 127 | W = size[2] 128 | H = size[3] 129 | min_rat, max_rat = rat 130 | cut_w = np.random.randint(W*min_rat, W*max_rat) 131 | cut_h = np.random.randint(H*min_rat, H*max_rat) 132 | 133 | bbox_list = [] 134 | for i in range(2): 135 | cx = np.random.randint(0.1*W,0.9*W) 136 | cy = np.random.randint(0.1*H,0.9*H) 137 | bbx1 = cx - cut_w // 2 138 | bby1 = cy - cut_h // 2 139 | bbx2 = cx + cut_w // 2 140 | bby2 = cy + cut_h // 2 141 | bbox_list.append([bbx1, bby1, bbx2, bby2]) 142 | 143 | if bbx1 < 0 or bbx2 >= W: 144 | return [] 145 | if bby1 < 0 or bby2 >= H: 146 | return [] 147 | 148 | return bbox_list 149 | 150 | def get_iou(preds, targets): 151 | preds_mask = preds>0 152 | targets_mask = targets>0 153 | 154 | intersection = torch.sum(preds_mask*targets_mask, axis=(1,2,3)) 155 | union = torch.sum(preds_mask, axis=(1,2,3)) + torch.sum(targets_mask, axis=(1,2,3)) - intersection 156 | return intersection/union 157 | 158 | 159 | 160 | 161 | __all__ = ["SoftBCEWithLogitsLoss"] 162 | 
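# The class below mirrors segmentation_models_pytorch's SoftBCEWithLogitsLoss.
# With smooth_factor s, targets are smoothed symmetrically before BCE:
#
#     soft_target = (1 - y) * s + y * (1 - s)
#
# e.g. s = 0.05 (the value train_pre_loop passes in) turns a positive pixel's
# target into 0.95 and a negative pixel's into 0.05, which keeps the loss from
# saturating on the (possibly noisy) polygon labels during co-teaching.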
163 | 
164 | class SoftBCEWithLogitsLoss(nn.Module):
165 | 
166 |     __constants__ = ["weight", "pos_weight", "reduction", "ignore_index", "smooth_factor"]
167 | 
168 |     def __init__(
169 |         self,
170 |         weight: Optional[torch.Tensor] = None,
171 |         ignore_index: Optional[int] = -100,
172 |         reduction: str = "mean",
173 |         smooth_factor: Optional[float] = None,
174 |         pos_weight: Optional[torch.Tensor] = None,
175 |     ):
176 |         super().__init__()
177 |         self.ignore_index = ignore_index
178 |         self.reduction = reduction
179 |         self.smooth_factor = smooth_factor
180 |         self.register_buffer("weight", weight)
181 |         self.register_buffer("pos_weight", pos_weight)
182 | 
183 |     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
184 | 
185 |         if self.smooth_factor is not None:
186 |             soft_targets = (1 - y_true) * self.smooth_factor + y_true * (1 - self.smooth_factor)
187 |         else:
188 |             soft_targets = y_true
189 | 
190 |         loss = F.binary_cross_entropy_with_logits(
191 |             y_pred, soft_targets, self.weight, pos_weight=self.pos_weight, reduction="none"
192 |         )
193 | 
194 |         if self.ignore_index is not None:
195 |             not_ignored_mask = y_true != self.ignore_index
196 |             loss *= not_ignored_mask.type_as(loss)
197 | 
198 |         if self.reduction == "mean":
199 |             loss = loss.mean()
200 | 
201 |         if self.reduction == "sum":
202 |             loss = loss.sum()
203 | 
204 |         return loss
205 | 
206 | def train_pre_fn(train_loader, encoder1, encoder2, criterion1, criterion2,
207 |                  optimizer1, optimizer2, epoch,
208 |                  scheduler1, scheduler2, device):
209 |     batch_time = AverageMeter()
210 |     data_time = AverageMeter()
211 |     losses = AverageMeter()
212 |     dice_coeffs = AverageMeter()
213 |     # switch to train mode
214 |     encoder1.train()
215 |     encoder2.train()
216 | 
217 |     scaler = torch.cuda.amp.GradScaler()
218 | 
219 |     start = end = time.time()
220 |     global_step = 0
221 |     for step, (images, targets) in enumerate(train_loader):
222 |         # measure data loading time
223 |         data_time.update(time.time() - end)
224 |         images = images.to(device)
225 |         targets = targets.float().to(device)
226 |         batch_size = images.size(0)
227 | 
228 |         # =========================
229 |         # zero_grad()
230 |         # =========================
231 |         optimizer1.zero_grad()
232 |         optimizer2.zero_grad()
233 |         if PRE_APEX:
234 |             # NOTE: mixed precision is not wired up for this two-model
235 |             # co-teaching loop (there is no single encoder/criterion/optimizer
236 |             # here); keep PREPROCESSING.apex set to False in config.yaml.
237 |             raise NotImplementedError('apex is not supported in train_pre_fn')
238 | 
239 |         else:
240 |             y_preds1 = encoder1(images)
241 |             y_preds2 = encoder2(images)
242 | 
243 |             loss1 = criterion1(y_preds1, targets)
244 |             loss2 = criterion1(y_preds2, targets)
245 | 
246 |             ind_1_sorted = np.argsort(np.sum(loss1.data.reshape(len(y_preds1), -1).detach().cpu().numpy(), axis=1))
247 |             loss_1_sorted = loss1[ind_1_sorted]
248 | 
249 |             ind_2_sorted = np.argsort(np.sum(loss2.data.reshape(len(y_preds1), -1).detach().cpu().numpy(), axis=1))
250 |             loss_2_sorted = loss2[ind_2_sorted]
251 | 
252 |             forget_rate = 0.1
253 |             remember_rate = 1 - forget_rate
254 |             num_remember = int(remember_rate * len(loss_1_sorted))
255 | 
256 |             ind_1_update=ind_1_sorted[:num_remember]
257 |             ind_2_update=ind_2_sorted[:num_remember]
258 | 
259 |             # exchange
260 |             loss_1_update = criterion2(y_preds1[ind_2_update], targets[ind_2_update])
261 |             loss_2_update = criterion2(y_preds2[ind_1_update], targets[ind_1_update])
262 | 
263 |             loss_1_update.backward()
264 |             loss_2_update.backward()
265 | 
266 |         # record loss
267 |         losses.update((loss_1_update.item() + loss_2_update.item())/2, batch_size)
268 |         if PRE_GRADIENT_ACCUMULATION_STEPS > 1:  # NOTE: `loss` below is a leftover; keep this set to 1 (as in config.yaml)
269 | 
loss = loss / PRE_GRADIENT_ACCUMULATION_STEPS 270 | 271 | #loss.backward() 272 | encoder_grad_norm1 = torch.nn.utils.clip_grad_norm_(encoder1.parameters(), PRE_MAX_GRAD_NORM) 273 | encoder_grad_norm2 = torch.nn.utils.clip_grad_norm_(encoder2.parameters(), PRE_MAX_GRAD_NORM) 274 | if (step + 1) % PRE_GRADIENT_ACCUMULATION_STEPS == 0: 275 | if PRE_APEX: 276 | scaler.step(optimizer) 277 | scaler.update() 278 | else: 279 | optimizer1.step() 280 | optimizer2.step() 281 | global_step += 1 282 | 283 | # record dice_coeff 284 | dice_coeff = get_dice_coeff((y_preds1+y_preds2)/2, targets) 285 | dice_coeffs.update(dice_coeff, batch_size) 286 | 287 | # measure elapsed time 288 | batch_time.update(time.time() - end) 289 | end = time.time() 290 | if step % PRE_PRINT_FREQ == 0 or step == (len(train_loader)-1): 291 | print('Epoch: [{0}][{1}/{2}] ' 292 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 293 | 'Elapsed {remain:s} ' 294 | 'Loss: {loss.val:.4f}({loss.avg:.4f}) ' 295 | 'Dice_coeff: {dice_coeff.val:.4f}({dice_coeff.avg:.4f}) ' 296 | 'Encoder Grad1: {encoder_grad_norm1:.4f} ' 297 | 'Encoder Grad2: {encoder_grad_norm2:.4f} ' 298 | 'Encoder LR: {encoder_lr:.6f} ' 299 | .format( 300 | epoch+1, step, len(train_loader), batch_time=batch_time, 301 | data_time=data_time, loss=losses, dice_coeff=dice_coeffs, 302 | remain=timeSince(start, float(step+1)/len(train_loader)), 303 | encoder_grad_norm1=encoder_grad_norm1, 304 | encoder_grad_norm2=encoder_grad_norm2, 305 | encoder_lr=scheduler1.get_lr()[0], 306 | )) 307 | return losses.avg, dice_coeffs.avg 308 | 309 | 310 | def train_pre_loop(folds, fold, LOGGER): 311 | LOGGER.info(f"========== fold: {fold} training ==========") 312 | 313 | # ==================================================== 314 | # loader 315 | # ==================================================== 316 | trn_idx = folds[folds['fold'] != fold].index 317 | train_folds = folds.loc[trn_idx].reset_index(drop=True) 318 | 319 | # ==================================================== 320 | # scheduler 321 | # ==================================================== 322 | def get_scheduler(optimizer): 323 | if PRE_SCHEDULER=='ReduceLROnPlateau': 324 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=PRE_FACTOR, patience=PRE_PATIENCE, verbose=True, eps=PRE_EPS) 325 | elif PRE_SCHEDULER=='CosineAnnealingLR': 326 | scheduler = CosineAnnealingLR(optimizer, T_max=PRE_T_MAX, eta_min=PRE_MIN_LR, last_epoch=-1) 327 | elif PRE_SCHEDULER=='CosineAnnealingWarmRestarts': 328 | scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=PRE_T_0, T_mult=1, eta_min=PRE_MIN_LR, last_epoch=-1) 329 | elif PRE_SCHEDULER=='GradualWarmupSchedulerV2': 330 | scheduler_cosine=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, PRE_COSINE_EPO) 331 | scheduler_warmup=GradualWarmupSchedulerV2(optimizer, multiplier=PRE_WARMUP_FACTOR, total_epoch=PRE_WARMUP_EPO, after_scheduler=scheduler_cosine) 332 | scheduler=scheduler_warmup 333 | return scheduler 334 | 335 | # ==================================================== 336 | # model & optimizer 337 | # ==================================================== 338 | encoder1 = Encoder(PRE_ENCODER_TYPE, PRE_DECODER_TYPE, pretrained=False) 339 | encoder1.to(device) 340 | 341 | encoder2 = Encoder(PRE_ENCODER_TYPE, PRE_DECODER_TYPE, pretrained=False) 342 | encoder2.to(device) 343 | 344 | if len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) > 1: 345 | encoder1 = nn.DataParallel(encoder1) 346 | encoder2 = nn.DataParallel(encoder2) 347 | 348 | 349 | optimizer1 = 
Adam(encoder1.parameters(), lr=PRE_ENCODER_LR, weight_decay=PRE_WEIGHT_DECAY, amsgrad=False) 350 | scheduler1 = get_scheduler(optimizer1) 351 | 352 | optimizer2 = Adam(encoder2.parameters(), lr=PRE_ENCODER_LR, weight_decay=PRE_WEIGHT_DECAY, amsgrad=False) 353 | scheduler2 = get_scheduler(optimizer2) 354 | # Log the network weight histograms (optional) 355 | 356 | # ==================================================== 357 | # loop 358 | # ==================================================== 359 | criterion1 = SoftBCEWithLogitsLoss(smooth_factor=0.05, reduction="none") 360 | criterion2 = SoftBCEWithLogitsLoss(smooth_factor=0.05) 361 | 362 | best_score = 0 363 | best_loss = np.inf 364 | 365 | for epoch in range(PRE_EPOCHS): 366 | if epoch >= 1: 367 | break 368 | start_time = time.time() 369 | train_folds_sample = train_folds.sample(frac=0.2, random_state=epoch).reset_index(drop=True) 370 | train_dataset = TrainDataset(train_folds_sample, transform=get_transforms_preprocessing(data='train')) 371 | 372 | train_loader = DataLoader(train_dataset, 373 | batch_size=PRE_BATCH_SIZE, 374 | shuffle=True, 375 | num_workers=NUM_WORKERS, 376 | pin_memory=True, 377 | drop_last=True) 378 | 379 | # train 380 | avg_loss, avg_tr_dice_coeff = train_pre_fn(train_loader, encoder1, encoder2, criterion1, criterion2, optimizer1, optimizer2, epoch, scheduler1, scheduler2, device) 381 | 382 | # scoring 383 | #score = get_score(valid_labels, text_preds) 384 | score = avg_tr_dice_coeff 385 | 386 | if isinstance(scheduler1, ReduceLROnPlateau): 387 | scheduler1.step(score) 388 | scheduler2.step(score) 389 | elif isinstance(scheduler1, CosineAnnealingLR): 390 | scheduler1.step() 391 | scheduler2.step() 392 | elif isinstance(scheduler1, CosineAnnealingWarmRestarts): 393 | scheduler1.step() 394 | scheduler2.step() 395 | elif isinstance(scheduler1, GradualWarmupSchedulerV2): 396 | scheduler1.step(epoch) 397 | scheduler2.step(epoch) 398 | 399 | elapsed = time.time() - start_time 400 | 401 | LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} time: {elapsed:.0f}s') 402 | LOGGER.info(f'Epoch {epoch+1} - Score: {avg_tr_dice_coeff:.4f}') 403 | 404 | model_to_save1 = encoder1.module if hasattr(encoder1, 'module') else encoder1 405 | model_to_save2 = encoder2.module if hasattr(encoder2, 'module') else encoder2 406 | 407 | if score > best_score: 408 | best_score = score 409 | LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model') 410 | torch.save({'encoder': model_to_save1.state_dict(), 411 | 'optimizer': optimizer1.state_dict(), 412 | 'scheduler': scheduler1.state_dict(), 413 | }, 414 | OUTPUT_DIR+f'/{PRE_ENCODER_TYPE}_{PRE_DECODER_TYPE}_fold{fold}_{VERSION}_1.pth') 415 | #f'./{PRE_ENCODER_TYPE}_{PRE_DECODER_TYPE}_fold{fold}_{VERSION}_1.pth') 416 | torch.save({'encoder': model_to_save2.state_dict(), 417 | 'optimizer': optimizer2.state_dict(), 418 | 'scheduler': scheduler2.state_dict(), 419 | }, 420 | OUTPUT_DIR+f'/{PRE_ENCODER_TYPE}_{PRE_DECODER_TYPE}_fold{fold}_{VERSION}_2.pth') 421 | 422 | 423 | def train_fn(train_loader, encoder, criterion, 424 | optimizer, epoch, 425 | scheduler, device): 426 | batch_time = AverageMeter() 427 | data_time = AverageMeter() 428 | losses = AverageMeter() 429 | dice_coeffs = AverageMeter() 430 | real_batch = AverageMeter() 431 | cutmix_counter = AverageMeter() 432 | coloring_counter = AverageMeter() 433 | # switch to train mode 434 | encoder.train() 435 | 436 | scaler = torch.cuda.amp.GradScaler() 437 | 438 | start = end = time.time() 439 | global_step = 0 440 | for 
440 |     for step, (images, targets) in enumerate(train_loader):
441 |         # measure data loading time
442 |         data_time.update(time.time() - end)
443 |         images = images.to(device)
444 |         targets = targets.float().to(device)
445 |         batch_size = images.size(0)
446 | 
447 |         if COLORING and np.random.random() < COLORING_THRESHOLD:
448 |             coloring_counter.update(1)
449 |             rgb_gain = np.random.random((batch_size, 3))
450 |             rgb_gain[:, 0] = 0  # leave the red channel untouched so the tint shifts toward green/blue
451 |             masks = targets
452 |             # undo the ImageNet normalization, apply a per-image channel gain inside the hair mask, then re-normalize
453 |             images = images * torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1).to(device)
454 |             images = images + torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1).to(device)
455 |             images = images + images * (torch.tensor(rgb_gain).reshape(batch_size, 3, 1, 1).to(device) * masks.repeat(1, 3, 1, 1))
456 |             images = images - torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1).to(device)
457 |             images = images / torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1).to(device)
458 |             images = images.float()
459 |         else:
460 |             coloring_counter.update(0)
461 | 
462 |         if SELF_CUTMIX and np.random.random() < CUTMIX_THRESHOLD:
463 |             bbox_list = rand_bbox(images.size(), (0.1, 0.4))
464 |             if len(bbox_list) == 2:
465 |                 bbx1_1, bby1_1, bbx2_1, bby2_1 = bbox_list[0]
466 |                 bbx1_2, bby1_2, bbx2_2, bby2_2 = bbox_list[1]
467 |                 images[:, :, bbx1_1:bbx2_1, bby1_1:bby2_1] = images[:, :, bbx1_2:bbx2_2, bby1_2:bby2_2]  # paste one patch of the batch over another
468 |                 targets[:, :, bbx1_1:bbx2_1, bby1_1:bby2_1] = targets[:, :, bbx1_2:bbx2_2, bby1_2:bby2_2]  # keep the mask consistent with the pasted patch
469 |                 cutmix_counter.update(1)
470 |             else:
471 |                 cutmix_counter.update(0)
472 |         else:
473 |             cutmix_counter.update(0)
474 | 
475 |         # =========================
476 |         # zero_grad()
477 |         # =========================
478 |         if step % GRADIENT_ACCUMULATION_STEPS == 0: optimizer.zero_grad()  # zero only at the start of each accumulation window
479 |         if APEX:
480 |             with autocast():
481 |                 y_preds = encoder(images)
482 |                 loss = criterion(y_preds, targets)
483 |             if GRADIENT_ACCUMULATION_STEPS > 1: loss = loss / GRADIENT_ACCUMULATION_STEPS  # scale before backward so accumulated gradients match a full batch
484 |             scaler.scale(loss).backward(); scaler.unscale_(optimizer)
485 |         else:
486 |             y_preds = encoder(images)
487 |             if PROSPECTIVE_FILTERING and epoch >= 1:
488 |                 threshold = 0.5
489 |                 iou = get_iou(y_preds, targets)
490 |                 train_filter = iou > threshold
491 |                 batch_size = int(torch.sum(train_filter))
492 |                 if batch_size == 0: continue  # every sample fell below the IoU threshold; skip this step
493 |                 y_preds = y_preds[train_filter]
494 |                 targets = targets[train_filter]
495 |             loss = criterion(y_preds, targets)
496 |             if GRADIENT_ACCUMULATION_STEPS > 1: loss = loss / GRADIENT_ACCUMULATION_STEPS  # scale before backward, not after
497 |             loss.backward()
498 |         real_batch.update(batch_size)
499 |         # record loss
500 |         losses.update(loss.item(), batch_size)
501 |         encoder_grad_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), MAX_GRAD_NORM)
502 |         if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
503 |             if APEX:
504 |                 scaler.step(optimizer)
505 |                 scaler.update()
506 |             else:
507 |                 optimizer.step()
508 |             global_step += 1
509 | 
510 |         # record dice_coeff
511 |         dice_coeff = get_dice_coeff(y_preds, targets)
512 | 
513 |         dice_coeffs.update(dice_coeff, batch_size)
514 | 
515 |         # measure elapsed time
516 |         batch_time.update(time.time() - end)
517 |         end = time.time()
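        # The progress line reports each metric as current(running average); the
        # gradient norm is the pre-clipping total norm returned by clip_grad_norm_.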
518 |         if step % PRINT_FREQ == 0 or step == (len(train_loader)-1):
519 |             print('Epoch: [{0}][{1}/{2}] '
520 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
521 |                   'Elapsed {remain:s} '
522 |                   'Loss: {loss.val:.4f}({loss.avg:.4f}) '
523 |                   'Dice_coeff: {dice_coeff.val:.4f}({dice_coeff.avg:.4f}) '
524 |                   'Encoder Grad: {encoder_grad_norm:.4f} '
525 |                   'Encoder LR: {encoder_lr:.6f} '
526 |                   .format(
527 |                   epoch+1, step, len(train_loader), batch_time=batch_time,
528 |                   data_time=data_time, loss=losses, dice_coeff=dice_coeffs,
529 |                   remain=timeSince(start, float(step+1)/len(train_loader)),
530 |                   encoder_grad_norm=encoder_grad_norm,
531 |                   encoder_lr=optimizer.param_groups[0]['lr'],
532 |                   ))
533 |     return losses.avg, dice_coeffs.avg
534 | 
535 | 
536 | def train_loop(folds, fold, LOGGER):
537 |     LOGGER.info(f"========== fold: {fold} training ==========")
538 | 
539 |     # ====================================================
540 |     # loader
541 |     # ====================================================
542 |     trn_idx = folds[folds['fold'] != fold].index
543 |     train_folds = folds.loc[trn_idx].reset_index(drop=True)
544 | 
545 |     train_dataset = TrainDataset(train_folds, transform=get_transforms_train(data='train'))
546 |     train_loader = DataLoader(train_dataset,
547 |                               batch_size=BATCH_SIZE,
548 |                               shuffle=True,
549 |                               num_workers=NUM_WORKERS,
550 |                               pin_memory=True,
551 |                               drop_last=True)
552 | 
553 | 
554 |     # ====================================================
555 |     # scheduler
556 |     # ====================================================
557 |     def get_scheduler(optimizer):
558 |         if SCHEDULER == 'ReduceLROnPlateau':
559 |             scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=FACTOR, patience=PATIENCE, verbose=True, eps=EPS)
560 |         elif SCHEDULER == 'CosineAnnealingLR':
561 |             scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=MIN_LR, last_epoch=-1)
562 |         elif SCHEDULER == 'CosineAnnealingWarmRestarts':
563 |             scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=1, eta_min=MIN_LR, last_epoch=-1)
564 |         elif SCHEDULER == 'GradualWarmupSchedulerV2':
565 |             scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, COSINE_EPO)
566 |             scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=WARMUP_FACTOR, total_epoch=WARMUP_EPO, after_scheduler=scheduler_cosine)
567 |             scheduler = scheduler_warmup
568 |         return scheduler
569 | 
570 |     # ====================================================
571 |     # model & optimizer
572 |     # ====================================================
573 |     encoder = Encoder(ENCODER_TYPE, DECODER_TYPE, pretrained=PRETRAINED)
574 |     encoder.to(device)
575 | 
576 |     if len(os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',')) > 1:  # .get avoids a KeyError when the env var is unset
577 |         encoder = nn.DataParallel(encoder)
578 | 
579 |     optimizer = Adam(encoder.parameters(), lr=ENCODER_LR, weight_decay=WEIGHT_DECAY, amsgrad=False)
580 |     scheduler = get_scheduler(optimizer)
581 | 
582 |     # ====================================================
583 |     # loop
584 |     # ====================================================
585 |     criterion = DiceBCELoss()
586 | 
587 |     best_score = 0
588 |     best_loss = np.inf
589 | 
590 |     for epoch in range(EPOCHS):
591 |         if epoch >= BREAK_EPOCH:  # stop after BREAK_EPOCH epochs
592 |             break
593 |         start_time = time.time()
594 | 
595 |         # train
596 |         avg_loss, avg_tr_dice_coeff = train_fn(train_loader, encoder, criterion, optimizer, epoch, scheduler, device)
597 | 
598 |         if isinstance(scheduler, ReduceLROnPlateau):
599 |             scheduler.step(avg_loss)  # mode='min' monitors the average training loss
600 |         elif isinstance(scheduler, CosineAnnealingLR):
601 |             scheduler.step()
602 |         elif isinstance(scheduler, CosineAnnealingWarmRestarts):
603 |             scheduler.step()
604 |         elif isinstance(scheduler, GradualWarmupSchedulerV2):
605 |             scheduler.step(epoch)
606 | 
607 |         elapsed = time.time() - start_time
608 | 
609 |         LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} time: {elapsed:.0f}s')
610 |         LOGGER.info(f'Epoch {epoch+1} - Score: {avg_tr_dice_coeff:.4f}')
611 | 
612 |         model_to_save = encoder.module if hasattr(encoder, 'module') else encoder
613 |         LOGGER.info(f'Epoch {epoch+1} - Save Model')
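        # Unlike train_pre_loop, which keeps only the best-scoring weights, this
        # loop checkpoints model, optimizer, and scheduler state after every epoch.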
614 |         torch.save({'encoder': model_to_save.state_dict(),
615 |                     'optimizer': optimizer.state_dict(),
616 |                     'scheduler': scheduler.state_dict(),
617 |                     },
618 |                    OUTPUT_TRAIN_DIR + f'/{ENCODER_TYPE}_{DECODER_TYPE}_{VERSION}_epoch{epoch}.pth')
619 | 
620 | 
--------------------------------------------------------------------------------