├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── figaro.py ├── figaro.sh ├── lfw.py └── lfw.sh ├── demo.py ├── docker └── DockerFile ├── evaluate.py ├── main.py ├── markdowns ├── README.md ├── about_utils.md ├── deeplabv3plus.md ├── figaro.md ├── mobile_net.md ├── pspnet.md ├── semantic_segmentation.md └── visdom.md ├── networks ├── __init__.py ├── deeplab_v3_plus.py ├── mobile_hair.py └── pspnet.py ├── notebooks └── TrainingExample.ipynb ├── requirements.txt └── utils ├── __init__.py ├── joint_transforms.py ├── metrics.py └── trainer_verbose.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre 115 | 116 | # MAC OS 117 | .DS_Store 118 | .AppleDouble 119 | .LSOverride 120 | 121 | # Pycharm 122 | # User-specific stuff 123 | .idea/**/workspace.xml 124 | .idea/**/tasks.xml 125 | .idea/**/usage.statistics.xml 126 | .idea/**/dictionaries 127 | .idea/**/shelf 128 | 129 | # Generated files 130 | .idea/**/contentModel.xml 131 | 132 | # Sensitive or high-churn files 133 | .idea/**/dataSources/ 134 | .idea/**/dataSources.ids 135 | .idea/**/dataSources.local.xml 136 | .idea/**/sqlDataSources.xml 137 | .idea/**/dynamic.xml 138 | .idea/**/uiDesigner.xml 139 | .idea/**/dbnavigator.xml 140 | 141 | # Gradle 142 | .idea/**/gradle.xml 143 | .idea/**/libraries 144 | 145 | # Dataset 146 | Figaro1k/ 147 | Lfw/ 148 | 149 | # log and ckpt 150 | logs/ 151 | ckpt/ 152 | 153 | .vscode/ 154 | .vscode/* 155 | !.vscode/settings.json 156 | !.vscode/tasks.json 157 | !.vscode/launch.json 158 | !.vscode/extensions.json 159 | 160 | *.zip -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 YBIGTA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-hair-segmentation 2 | Implementation of pytorch semantic segmentation with [figaro-1k](http://projects.i-ctm.eu/it/progetto/figaro-1k). 3 | 4 | - tutorial document : https://pytorchhair.gitbook.io/project/ (kor) 5 | 6 | ### Prerequisites 7 | ``` 8 | opencv-contrib-python 3.4.4 9 | pytorch 0.4.1 10 | torchvision 0.2.1 11 | numpy 1.14.5 12 | ``` 13 | 14 | 15 | ### Downloading dataset 16 | ```bash 17 | # specify a directory for dataset to be downloaded into, else default is ./data/ 18 | sh data/figaro.sh # 19 | ``` 20 | ### Running trainer 21 | 22 | ```bash 23 | # sample execution 24 | 25 | python3 main.py \ 26 | --networks mobilenet \ 27 | --dataset figaro \ 28 | --data_dir ./data/Figaro1k \ 29 | --scheduler ReduceLROnPlateau \ 30 | --batch_size 4 \ 31 | --epochs 5 \ 32 | --lr 1e-3 \ 33 | --num_workers 2 \ 34 | --optimizer adam \ 35 | --img_size 256 \ 36 | --momentum 0.5 \ 37 | --ignite True 38 | ``` 39 | 40 | * You should add your own model script in `networks` and make it avaliable in `get_network` in `./networks/__init__.py` 41 | 42 | ### Running docker & train 43 | 44 | > with ignite 45 | 46 | `docker run davinnovation/pytorch-hairsegment:cpu python main.py` 47 | 48 | > with no-ignite 49 | 50 | `docker run -p davinnovation/pytorch-hairsegment:cpu python main.py --ignite False` 51 | 52 | ### Evaluating model 53 | 54 | ```bash 55 | # sample execution 56 | 57 | python3 evaluate.py \ 58 | --networks pspnet_resnet101 \ 59 | --ckpt_dir [path to checkpoint] \ 60 | --dataset figaro \ 61 | --data_dir ./data/Figaro1k \ 62 | --save_dir ./overlay/ \ 63 | --use_gpu True 64 | ``` 65 | 66 | ### Evaluation result on figaro testset 67 | 68 | | Model | IoU | F1-score | Checkpoint | 69 | | --- | --- | --- | --- | 70 | | pspnet_resnet101 | 0.92| 0.96 | [link](https://drive.google.com/file/d/1w7oMuxckqEClImjLFTH7xBCpm1wg7Eg4/view?usp=sharing) 71 | | pspnet_squeezenet| 0.88| 0.91 | [link](https://drive.google.com/file/d/1ieKvsK3uoDZN0vA5MenQphca4AZZuaPk/view?usp=sharing) | 72 | | deeplabv3plus | 0.80| 0.89 | - | 73 | 74 | 75 | ### Sample visualization 76 | * Red: GT / Blue: Segmentation Map 77 | 78 | 
![sample_0](https://user-images.githubusercontent.com/19547969/227229779-28b42d02-efad-4b7b-be65-3cf3a1a7bfef.png) 79 | ![sample_1](https://user-images.githubusercontent.com/19547969/227229796-5de39ea1-73fe-4be8-9ef7-2857df54b94c.png) 80 | ![sample_2](https://user-images.githubusercontent.com/19547969/227229856-e224b91c-6fb2-4aa8-a93f-6ab1edfe568b.png) 81 | ![sample_3](https://user-images.githubusercontent.com/19547969/227229883-ff4b05e7-ba23-42c9-9dec-a431bf0715f1.png) 82 | ![sample_4](https://user-images.githubusercontent.com/19547969/227229909-68b6cdf1-6f89-4cf9-a2f8-01be251dd140.png) 83 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .figaro import FigaroDataset 2 | from .lfw import LfwDataset 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def get_loader(dataset, data_dir='./data/Figaro1k', train=True, batch_size=64, shuffle=True, 7 | joint_transforms=None, image_transforms=None, mask_transforms=None, num_workers=0, gray_image=False): 8 | """ 9 | Args: 10 | dataset (string): name of dataset to use 11 | data_dir (string): directory to dataset 12 | train (bool): whether training or not 13 | batch_size (int): batch size 14 | joint_transforms (Compose): list of joint transforms both on images and masks 15 | image_transforms (Compose): list of transforms only on images 16 | mask_transforms (Compose): list of transforms only on targets (masks) 17 | """ 18 | 19 | if dataset.lower() == 'figaro': 20 | dset = FigaroDataset(root_dir=data_dir, 21 | train=train, 22 | joint_transforms=joint_transforms, 23 | image_transforms=image_transforms, 24 | mask_transforms=mask_transforms, 25 | gray_image=gray_image) 26 | 27 | elif dataset.lower() == 'lfw': 28 | dset = LfwDataset(root_dir=data_dir, 29 | train=train, 30 | joint_transforms=joint_transforms, 31 | image_transforms=image_transforms, 32 | mask_transforms=mask_transforms) 33 | else: 34 | raise ValueError 35 | 36 | loader = DataLoader(dset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 37 | 38 | return loader 39 | -------------------------------------------------------------------------------- /data/figaro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class FigaroDataset(Dataset): 8 | def __init__(self, root_dir, train=True, joint_transforms=None, 9 | image_transforms=None, mask_transforms=None, gray_image=False): 10 | """ 11 | Args: 12 | root_dir (str): root directory of dataset 13 | joint_transforms (torchvision.transforms.Compose): tranformation on both data and target 14 | image_transforms (torchvision.transforms.Compose): tranformation only on data 15 | mask_transforms (torchvision.transforms.Compose): tranformation only on target 16 | gray_image (bool): whether to return gray image image or not. 17 | If True, returns img, mask, gray. 
18 | """ 19 | mode = 'Training' if train else 'Testing' 20 | img_dir = os.path.join(root_dir, 'Original', mode) 21 | mask_dir = os.path.join(root_dir, 'GT', mode) 22 | 23 | self.img_path_list = [os.path.join(img_dir, img) for img in sorted(os.listdir(img_dir))] 24 | self.mask_path_list = [os.path.join(mask_dir, mask) for mask in sorted(os.listdir(mask_dir))] 25 | self.joint_transforms = joint_transforms 26 | self.image_transforms = image_transforms 27 | self.mask_transforms = mask_transforms 28 | self.gray_image = gray_image 29 | 30 | def __getitem__(self,idx): 31 | img_path = self.img_path_list[idx] 32 | img = Image.open(img_path) 33 | 34 | mask_path = self.mask_path_list[idx] 35 | mask = Image.open(mask_path) 36 | 37 | if self.joint_transforms is not None: 38 | img, mask = self.joint_transforms(img, mask) 39 | 40 | if self.gray_image: 41 | gray = img.convert('L') 42 | gray = np.array(gray,dtype=np.float32)[np.newaxis,]/255 43 | 44 | if self.image_transforms is not None: 45 | img = self.image_transforms(img) 46 | 47 | if self.mask_transforms is not None: 48 | mask = self.mask_transforms(mask) 49 | 50 | if self.gray_image: 51 | return img, mask, gray 52 | else: 53 | return img, mask 54 | 55 | def __len__(self): 56 | return len(self.mask_path_list) 57 | 58 | def get_class_label(self, filename): 59 | """ 60 | 0: straight: frame00001-00150 61 | 1: wavy: frame00151-00300 62 | 2: curly: frame00301-00450 63 | 3: kinky: frame00451-00600 64 | 4: braids: frame00601-00750 65 | 5: dreadlocks: frame00751-00900 66 | 6: short-men: frame00901-01050 67 | """ 68 | idx = int(filename.strip('Frame').strip('-gt.pbm')) 69 | 70 | if 0 < idx <= 150: 71 | return 0 72 | elif 150 < idx <= 300: 73 | return 1 74 | elif 300 < idx <= 450: 75 | return 2 76 | elif 450 < idx <= 600: 77 | return 3 78 | elif 600 < idx <= 750: 79 | return 4 80 | elif 750 < idx <= 900: 81 | return 5 82 | elif 900 < idx <= 1050: 83 | return 6 84 | raise ValueError 85 | -------------------------------------------------------------------------------- /data/figaro.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ] 3 | then 4 | # navigate to ~/data 5 | echo "navigating to ./data/ ..." 6 | cd ./data/ 7 | else 8 | # check if is valid directory 9 | if [ ! -d $1 ]; then 10 | echo $1 "is not a valid directory" 11 | exit 0 12 | fi 13 | echo "navigating to" $1 "..." 14 | cd $1 15 | fi 16 | 17 | echo "Now downloading Figaro1k.zip ..." 18 | 19 | # wget http://projects.i-ctm.eu/sites/default/files/AltroMateriale/207_Michele%20Svanera/Figaro1k.zip 20 | # The official link is not working for some reason, so temporarily use dropbox instead. 21 | 22 | wget https://www.dropbox.com/s/35momrh68zuhkei/Figaro1k.zip 23 | 24 | echo "Unzip Figaro1k.zip ..." 25 | 26 | unzip Figaro1k.zip 27 | 28 | echo "Removing unnecessary files ..." 29 | 30 | rm -f Figaro1k.zip 31 | rm -f Figaro1k/GT/Training/*'(1).pbm' 32 | rm -f Figaro1k/.DS_Store 33 | rm -rf __MACOSX 34 | 35 | echo "Finished!" 
36 | -------------------------------------------------------------------------------- /data/lfw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class LfwDataset(Dataset): 8 | def __init__(self, root_dir, train=True, joint_transforms=None, 9 | image_transforms=None, mask_transforms=None, gray_image=False): 10 | """ 11 | Args: 12 | root_dir (str): root directory of dataset 13 | joint_transforms (torchvision.transforms.Compose): tranformation on both data and target 14 | image_transforms (torchvision.transforms.Compose): tranformation only on data 15 | mask_transforms (torchvision.transforms.Compose): tranformation only on target 16 | gray_image (bool): True if to add gray images 17 | """ 18 | 19 | txt_file = 'parts_train_val.txt' if train else 'parts_test.txt' 20 | txt_dir = os.path.join(root_dir, txt_file) 21 | name_list = LfwDataset.parse_name_list(txt_dir) 22 | img_dir = os.path.join(root_dir, 'lfw_funneled') 23 | mask_dir = os.path.join(root_dir, 'parts_lfw_funneled_gt_images') 24 | 25 | self.img_path_list = [os.path.join(img_dir, elem[0], elem[1]+'.jpg') for elem in name_list] 26 | self.mask_path_list = [os.path.join(mask_dir, elem[1]+'.ppm') for elem in name_list] 27 | self.joint_transforms = joint_transforms 28 | self.image_transforms = image_transforms 29 | self.mask_transforms = mask_transforms 30 | self.gray_image = gray_image 31 | 32 | def __getitem__(self, idx): 33 | img_path = self.img_path_list[idx] 34 | img = Image.open(img_path) 35 | 36 | mask_path = self.mask_path_list[idx] 37 | mask = Image.open(mask_path) 38 | mask = LfwDataset.rgb2binary(mask) 39 | 40 | if self.joint_transforms is not None: 41 | img, mask = self.joint_transforms(img, mask) 42 | 43 | if self.image_transforms is not None: 44 | img = self.image_transforms(img) 45 | 46 | if self.mask_transforms is not None: 47 | mask = self.mask_transforms(mask) 48 | 49 | if self.gray_image: 50 | gray = img.convert('L') 51 | gray = np.array(gray,dtype=np.float32)[np.newaxis,]/255 52 | return img, mask, gray 53 | 54 | return img, mask 55 | 56 | def __len__(self): 57 | return len(self.mask_path_list) 58 | 59 | @staticmethod 60 | def rgb2binary(mask): 61 | """transforms RGB mask image to binary hair mask image. 62 | """ 63 | mask_arr = np.array(mask) 64 | mask_map = mask_arr == np.array([255, 0, 0]) 65 | mask_map = np.all(mask_map, axis=2).astype(np.float32) 66 | return Image.fromarray(mask_map) 67 | 68 | @staticmethod 69 | def parse_name_list(fp): 70 | with open(fp, 'r') as fin: 71 | lines = fin.readlines() 72 | parsed = list() 73 | for line in lines: 74 | name, num = line.strip().split(' ') 75 | num = format(num, '0>4') 76 | filename = '{}_{}'.format(name, num) 77 | parsed.append((name, filename)) 78 | return parsed 79 | -------------------------------------------------------------------------------- /data/lfw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ] 3 | then 4 | # navigate to ~/data 5 | echo "navigating to ./data/ ..." 6 | cd ./data/ 7 | else 8 | # check if is valid directory 9 | if [ ! -d $1 ]; then 10 | echo $1 "is not a valid directory" 11 | exit 0 12 | fi 13 | echo "navigating to" $1 "..." 14 | cd $1 15 | fi 16 | 17 | mkdir Lfw 18 | cd Lfw 19 | 20 | echo "Now downloading Figaro1k.zip ..." 21 | 22 | wget http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz 23 | 24 | echo "Unzip lfw-funneled.tgz ..." 
25 | 26 | tar -xvf lfw-funneled.tgz 27 | 28 | echo "Now downloading GT Images ... " 29 | wget wget http://vis-www.cs.umass.edu/lfw/part_labels/parts_lfw_funneled_gt_images.tgz 30 | 31 | echo "Unzip parts_lfw_funneled_gt_images.tgz ..." 32 | tar -xvf parts_lfw_funneled_gt_images.tgz 33 | 34 | echo "Now downloading txt files" 35 | wget http://vis-www.cs.umass.edu/lfw/part_labels/parts_train.txt 36 | wget http://vis-www.cs.umass.edu/lfw/part_labels/parts_validation.txt 37 | wget http://vis-www.cs.umass.edu/lfw/part_labels/parts_test.txt 38 | 39 | echo "Making parts_train_val.txt ..." 40 | cat parts_train.txt parts_validation.txt > parts_train_val.txt 41 | 42 | echo "Finished!" 43 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import time 5 | import os 6 | import sys 7 | import argparse 8 | from PIL import Image 9 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 10 | 11 | from networks import get_network 12 | from data import get_loader 13 | import torchvision.transforms as std_trnsf 14 | from utils import joint_transforms as jnt_trnsf 15 | from utils.metrics import MultiThresholdMeasures 16 | 17 | def str2bool(s): 18 | return s.lower() in ('t', 'true', 1) 19 | 20 | def has_img_ext(fname): 21 | ext = os.path.splitext(fname)[1] 22 | return ext in ('.jpg', '.jpeg', '.png') 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--ckpt_dir', help='path to ckpt file',type=str, 27 | default='./models/pspnet_resnet101_sgd_lr_0.002_epoch_100_test_iou_0.918.pth') 28 | parser.add_argument('--img_dir', help='path to image files', type=str, default='./data/Figaro1k') 29 | parser.add_argument('--networks', help='name of neural network', type=str, default='pspnet_resnet101') 30 | parser.add_argument('--save_dir', default='./overlay', 31 | help='path to save overlay images') 32 | parser.add_argument('--use_gpu', type=str2bool, default=True, 33 | help='True if using gpu during inference') 34 | 35 | args = parser.parse_args() 36 | 37 | ckpt_dir = args.ckpt_dir 38 | img_dir = args.img_dir 39 | network = args.networks.lower() 40 | save_dir = args.save_dir 41 | device = 'cuda' if args.use_gpu else 'cpu' 42 | 43 | assert os.path.exists(ckpt_dir) 44 | assert os.path.exists(img_dir) 45 | assert os.path.exists(os.path.split(save_dir)[0]) 46 | 47 | os.makedirs(save_dir, exist_ok=True) 48 | 49 | # prepare network with trained parameters 50 | net = get_network(network).to(device) 51 | state = torch.load(ckpt_dir) 52 | net.load_state_dict(state['weight']) 53 | 54 | 55 | test_image_transforms = std_trnsf.Compose([ 56 | std_trnsf.ToTensor(), 57 | std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 58 | ]) 59 | 60 | 61 | durations = list() 62 | 63 | # prepare images 64 | img_paths = [os.path.join(img_dir, k) for k in sorted(os.listdir(img_dir)) if has_img_ext(k)] 65 | with torch.no_grad(): 66 | for i, img_path in enumerate(img_paths, 1): 67 | print('[{:3d}/{:3d}] processing image... 
'.format(i, len(img_paths))) 68 | img = Image.open(img_path) 69 | data = test_image_transforms(img) 70 | data = torch.unsqueeze(data, dim=0) 71 | net.eval() 72 | data = data.to(device) 73 | 74 | # inference 75 | start = time.time() 76 | logit = net(data) 77 | duration = time.time() - start 78 | 79 | # prepare mask 80 | pred = torch.sigmoid(logit.cpu())[0][0].data.numpy() 81 | mh, mw = data.size(2), data.size(3) 82 | mask = pred >= 0.5 83 | 84 | mask_n = np.zeros((mh, mw, 3)) 85 | mask_n[:,:,0] = 255 86 | mask_n[:,:,0] *= mask 87 | 88 | path = os.path.join(save_dir, os.path.basename(img_path)+'.png') 89 | image_n = np.array(img) 90 | image_n = cv2.cvtColor(image_n, cv2.COLOR_RGB2BGR) 91 | # discard padded area 92 | ih, iw, _ = image_n.shape 93 | 94 | delta_h = mh - ih 95 | delta_w = mw - iw 96 | 97 | top = delta_h // 2 98 | bottom = mh - (delta_h - top) 99 | left = delta_w // 2 100 | right = mw - (delta_w - left) 101 | 102 | mask_n = mask_n[top:bottom, left:right, :] 103 | 104 | # addWeighted 105 | image_n = image_n * 0.5 + mask_n * 0.5 106 | 107 | # log measurements 108 | durations.append(duration) 109 | 110 | # write overlay image 111 | cv2.imwrite(path,image_n) 112 | 113 | 114 | avg_fps = sum(durations)/len(durations) 115 | print('Avg-FPS:', avg_fps) 116 | -------------------------------------------------------------------------------- /docker/DockerFile: -------------------------------------------------------------------------------- 1 | FROM nsml/default_ml:tf-gpu-1.11.0torch-0.4.1opencv-3.4.3 2 | 3 | EXPOSE 8097 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y git 7 | 8 | RUN pip install visdom 9 | 10 | RUN pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl 11 | RUN pip install torchvision 12 | RUN pip install opencv-python 13 | 14 | RUN git clone https://github.com/YBIGTA/pytorch-hair-segmentation.git 15 | 16 | WORKDIR pytorch-hair-segmentation 17 | 18 | RUN pip install -r requirements.txt 19 | 20 | RUN sh data/figaro.sh -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import time 5 | import os 6 | import sys 7 | import argparse 8 | from PIL import Image 9 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 10 | 11 | from networks import get_network 12 | from data import get_loader 13 | import torchvision.transforms as std_trnsf 14 | from utils import joint_transforms as jnt_trnsf 15 | from utils.metrics import MultiThresholdMeasures 16 | 17 | def str2bool(s): 18 | return s.lower() in ('t', 'true', 1) 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--ckpt_dir', help='path to ckpt file',type=str, 23 | default='./models/pspnet_resnet101_sgd_lr_0.002_epoch_100_test_iou_0.918.pth') 24 | parser.add_argument('--dataset', type=str, default='figaro', 25 | help='Name of dataset you want to use default is "figaro"') 26 | parser.add_argument('--data_dir', help='path to Figaro1k folder', type=str, default='./data/Figaro1k') 27 | parser.add_argument('--networks', help='name of neural network', type=str, default='pspnet_resnet101') 28 | parser.add_argument('--save_dir', default='./overlay', 29 | help='path to save overlay images, default=None and do not save images in this case') 30 | parser.add_argument('--use_gpu', type=str2bool, default=True, 31 | help='True if using gpu during inference') 32 | 33 | args 
= parser.parse_args() 34 | 35 | ckpt_dir = args.ckpt_dir 36 | data_dir = args.data_dir 37 | img_dir = os.path.join(data_dir, 'Original', 'Testing') 38 | network = args.networks.lower() 39 | save_dir = args.save_dir 40 | device = 'cuda' if args.use_gpu else 'cpu' 41 | 42 | assert os.path.exists(ckpt_dir) 43 | assert os.path.exists(data_dir) 44 | assert os.path.exists(os.path.split(save_dir)[0]) 45 | 46 | if not os.path.exists(save_dir): 47 | os.mkdir(save_dir) 48 | 49 | # prepare network with trained parameters 50 | net = get_network(network).to(device) 51 | state = torch.load(ckpt_dir) 52 | net.load_state_dict(state['weight']) 53 | 54 | # this is the default setting for train_verbose.py 55 | test_joint_transforms = jnt_trnsf.Compose([ 56 | jnt_trnsf.Safe32Padding() 57 | ]) 58 | 59 | test_image_transforms = std_trnsf.Compose([ 60 | std_trnsf.ToTensor(), 61 | std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 62 | ]) 63 | 64 | # transforms only on mask 65 | mask_transforms = std_trnsf.Compose([ 66 | std_trnsf.ToTensor() 67 | ]) 68 | 69 | test_loader = get_loader(dataset=args.dataset, 70 | data_dir=data_dir, 71 | train=False, 72 | joint_transforms=test_joint_transforms, 73 | image_transforms=test_image_transforms, 74 | mask_transforms=mask_transforms, 75 | batch_size=1, 76 | shuffle=False, 77 | num_workers=4) 78 | 79 | # prepare measurements 80 | metric = MultiThresholdMeasures() 81 | metric.reset() 82 | durations = list() 83 | 84 | # prepare images 85 | imgs = [os.path.join(img_dir, k) for k in sorted(os.listdir(img_dir)) if k.endswith('.jpg')] 86 | with torch.no_grad(): 87 | for i, (data, label) in enumerate(test_loader): 88 | print('[{:3d}/{:3d}] processing image... '.format(i+1, len(test_loader))) 89 | net.eval() 90 | data, label = data.to(device), label.to(device) 91 | 92 | # inference 93 | start = time.time() 94 | logit = net(data) 95 | duration = time.time() - start 96 | 97 | # prepare mask 98 | pred = torch.sigmoid(logit.cpu())[0][0].data.numpy() 99 | mh, mw = data.size(2), data.size(3) 100 | mask = pred >= 0.5 101 | 102 | mask_n = np.zeros((mh, mw, 3)) 103 | mask_n[:,:,0] = 255 104 | mask_n[:,:,0] *= mask 105 | 106 | path = os.path.join(save_dir, "figaro_img_%04d.png" % i) 107 | image_n = cv2.imread(imgs[i]) 108 | 109 | # discard padded area 110 | ih, iw, _ = image_n.shape 111 | 112 | delta_h = mh - ih 113 | delta_w = mw - iw 114 | 115 | top = delta_h // 2 116 | bottom = mh - (delta_h - top) 117 | left = delta_w // 2 118 | right = mw - (delta_w - left) 119 | 120 | mask_n = mask_n[top:bottom, left:right, :] 121 | 122 | # addWeighted 123 | image_n = image_n * 0.5 + mask_n * 0.5 124 | 125 | # log measurements 126 | metric.update((logit, label)) 127 | durations.append(duration) 128 | 129 | # write overlay image 130 | cv2.imwrite(path,image_n) 131 | 132 | 133 | # compute measurements 134 | iou = metric.compute_iou() 135 | f = metric.compute_f1() 136 | acc = metric.compute_accuracy() 137 | avg_fps = sum(durations)/len(durations) 138 | 139 | print('Avg-FPS:', avg_fps) 140 | print('Pixel-acc:', acc) 141 | print('F1-score:', f) 142 | print('IOU:', iou) 143 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import gc 4 | import sys 5 | import copy 6 | import pickle 7 | import logging 8 | import argparse 9 | 10 | from utils.trainer_verbose import train_with_ignite, train_without_ignite, get_optimizer 11 | from 
utils import check_mkdir 12 | 13 | import torch 14 | 15 | from networks import mobile_hair 16 | 17 | logger = logging.getLogger('hair segmentation project') 18 | 19 | def str2bool(s): 20 | return s.lower() in ('t', 'true', '1') 21 | 22 | 23 | def get_args(): 24 | parser = argparse.ArgumentParser(description='Hair Segmentation') 25 | parser.add_argument('--networks', default='mobilenet') 26 | parser.add_argument('--scheduler', default='ReduceLROnPlateau') 27 | parser.add_argument('--dataset', default='figaro') 28 | parser.add_argument('--data_dir', default='./data/Figaro1k') 29 | parser.add_argument('--batch_size', type=int, default=4) 30 | parser.add_argument('--epochs', default=5, type=int) 31 | parser.add_argument('--lr', default=0.0001, type=float) 32 | parser.add_argument('--num_workers', type=int, default=2) 33 | parser.add_argument('--img_size',type=int, default=256) 34 | parser.add_argument('--use_pretrained', type=str, default='ImageNet') 35 | parser.add_argument('--ignite', type=str2bool, default=True) 36 | parser.add_argument('--visdom', type=str2bool, default=False) 37 | parser.add_argument('--optimizer', type=str, default='adam') 38 | parser.add_argument('--momentum', type=float, default=0.9) 39 | 40 | args = parser.parse_args() 41 | 42 | return args 43 | 44 | 45 | def main(): 46 | args = get_args() 47 | 48 | check_mkdir('./logs') 49 | 50 | logging_name = './logs/{}_{}_lr_{}.txt'.format(args.networks, 51 | args.optimizer, 52 | args.lr) 53 | 54 | logger.setLevel(logging.DEBUG) 55 | formatter = logging.Formatter( 56 | '[%(asctime)10s][%(levelname)s] %(message)s', 57 | datefmt='%Y/%m/%d %H:%M:%S' 58 | ) 59 | 60 | stream_handler = logging.StreamHandler() 61 | stream_handler.setFormatter(formatter) 62 | 63 | file_handler = logging.FileHandler(logging_name) 64 | file_handler.setFormatter(formatter) 65 | 66 | logger.addHandler(stream_handler) 67 | logger.addHandler(file_handler) 68 | logger.info('arguments:{}'.format(" ".join(sys.argv))) 69 | if args.ignite is False: 70 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 71 | 72 | model = mobile_hair.MobileMattingFCN() 73 | 74 | if torch.cuda.is_available(): 75 | if torch.cuda.device_count() > 1: 76 | print('multi gpu') 77 | model = torch.nn.DataParallel(model) 78 | 79 | model.to(device) 80 | 81 | loss = mobile_hair.HairMattingLoss() 82 | 83 | optimizer = get_optimizer(args.optimizer, model, args.lr, args.momentum) 84 | # torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()), lr=0.0001, betas=(0.9, 0.999)) 85 | 86 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') 87 | 88 | train_without_ignite(model, 89 | loss, 90 | batch_size=args.batch_size, 91 | img_size=args.img_size, 92 | epochs=args.epochs, 93 | lr=args.lr, 94 | num_workers=args.num_workers, 95 | optimizer=optimizer, 96 | logger=logger, 97 | gray_image=True, 98 | scheduler=scheduler, 99 | viz=args.visdom) 100 | 101 | else: train_with_ignite(networks=args.networks, 102 | dataset=args.dataset, 103 | data_dir=args.data_dir, 104 | batch_size=args.batch_size, 105 | epochs=args.epochs, 106 | lr=args.lr, 107 | num_workers=args.num_workers, 108 | optimizer=args.optimizer, 109 | momentum=args.momentum, 110 | img_size=args.img_size, 111 | logger=logger) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /markdowns/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-hair-segmentation 2 | 
Implementation of pytorch semantic segmentation with [figaro-1k](http://projects.i-ctm.eu/it/progetto/figaro-1k). 3 | 4 | ### Downloading dataset 5 | ```bash 6 | # specify a directory for dataset to be downloaded into, else default is ./data/ 7 | sh data/figaro.sh # 8 | ``` 9 | ### Running trainer 10 | 11 | ```bash 12 | # sample execution 13 | 14 | python3 main.py \ 15 | --networks segnet \ 16 | --dataset figaro \ 17 | --data_dir ./data/Figaro1k \ 18 | --scheduler ReduceLROnPlateau \ 19 | --batch_size 4 \ 20 | --epochs 100 \ 21 | --lr 1e-3 \ 22 | --num_workers 4 \ 23 | --optimizer adam \ 24 | --img_size 256 \ 25 | --momentum 0.5 26 | ``` 27 | 28 | * You should add your own model script in `networks` and make it avaliable in `get_network` in `./networks/__init__.py` 29 | 30 | ### Current Project Tree 31 | 32 | ```python 33 | pytorch-hair-segmentation/ 34 | data/ # includes data script 35 | docker/ # includes dockerfile 36 | logs/ # log file for train-test 37 | markdowns/ # documentation 38 | models/ # saved model 39 | networks/ # pytorch model 40 | deeplab_v3_plus 41 | mobile_hair 42 | pspnet 43 | segnet 44 | ternausnet 45 | unet 46 | notebooks/ # notebook example for using network code 47 | utils/ # util function for training 48 | ``` 49 | 50 | ### RUN with Docker 51 | -------------------------------------------------------------------------------- /markdowns/about_utils.md: -------------------------------------------------------------------------------- 1 | ## Ignite 2 | 학습 시 [pytorch-ignite](https://pytorch.org/ignite/)를 활용해보았습니다. training 시 에폭과 배치 루프를 돌면서 여러 반복작업을 수행하는데요. 예를 들어 각각의 배치마다 loss를 역전파시키고, 각각의 에폭마다 validation set에 대해서 여러 metric을 뽑아내고 일정 주기마다 모델을 저장합니다. ignite는 이와 관련된 메소드를 제공하여 보다 깔끔한 코드를 작성할 수 있도록 도와주는 라이브러리입니다. ignite 페이지에 업로드된 아래 이미지는 같은 작업에 대해서 ignite를 사용한 경우와 사용하지 않은 경우를 비교합니다. 3 | ![](https://raw.githubusercontent.com/pytorch/ignite/master/assets/ignite_vs_bare_pytorch.png) 4 | 5 | 저희는 아래와 같은 형태로 코드를 작성했습니다. 편의를 위해 중간중간 생략된 부분이 있습니다. 6 | 1. trainer / evaluater 생성 7 | - ignite의 큰 뼈대는 `ignite.engine.Engine`으로 이루어져있습니다. Engine은 입력 받은 연산을 반복 수행하는 역할을 합니다. 8 | - `ignite.engine.create_supervised_trainer` 메소드는 loss를 계산하고 이를 역전파하는 Engine을 리턴합니다. 9 | - `ignite.engine.create_supervised_trainer` 메소드는 모델의 output을 산출하고 입력받은 metric들을 연산하는 Engine을 리턴합니다. 10 | 11 | 2. logging 함수 구현 12 | - Engine.on() 메소드는 decorator를 이용하여 특정 event가 발생했을 때 작성된 함수를 실행시키도록 합니다. 이를 이용해 아래와 같은 함수를 작성했습니다. 13 | - `log_training_loss`: 각각의 iteration마다 (Events.ITERATION_COMPLETED) 발생한 loss를 로깅하는 함수입니다. 14 | - `log_training_results`: 각각의 iteration마다 트레이닝 셋 전체에 대해 여러 metric 값을 로깅하는 함수입니다. 15 | - `log_validation_results`: 각각의 iteration마다 테스트(혹은 validation) 셋 전체에 대해 여러 metric 값을 로깅하는 함수입니다. 16 | 17 | 18 | 3. train.run() 19 | - Engine.run()은 입력받은 max_epochs만큼 위의 작업들을 수행합니다. 
20 | 21 | 22 | ```python 23 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 24 | from ignite.metrics import Loss 25 | 26 | 27 | trainer = create_supervised_trainer(model, model_optimizer, loss, device=device) 28 | evaluator = create_supervised_evaluator(model, 29 | metrics={ 30 | 'pix-acc': Accuracy(), 31 | 'iou': IoU(0.5), 32 | 'loss': Loss(loss), 33 | 'f1': F1score() 34 | }, 35 | device=device) 36 | 37 | # execution after every training iteration 38 | @trainer.on(Events.ITERATION_COMPLETED) 39 | def log_training_loss(trainer): 40 | num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1 41 | if num_iter % 20 == 0: 42 | logger.info("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format( 43 | trainer.state.epoch, num_iter, trainer.state.output)) 44 | 45 | # execution after every training epoch 46 | @trainer.on(Events.EPOCH_COMPLETED) 47 | def log_training_results(trainer): 48 | # evaluate on training set 49 | evaluator.run(train_loader) 50 | metrics = evaluator.state.metrics 51 | logger.info("Training Results - Epoch: {} Avg-loss: {:.3f} Pix-acc: {:.3f} IoU: {:.3f} F1: {}".format( 52 | trainer.state.epoch, metrics['loss'], metrics['pix-acc'], metrics['iou'], str(metrics['f1']))) 53 | 54 | # execution after every epoch 55 | @trainer.on(Events.EPOCH_COMPLETED) 56 | def log_validation_results(trainer): 57 | # evaluate test(validation) set 58 | evaluator.run(test_loader) 59 | metrics = evaluator.state.metrics 60 | logger.info("Validation Results - Epoch: {} Avg-loss: {:.2f} Pix-acc: {:.2f} IoU: {:.3f} F1: {}".format( 61 | trainer.state.epoch, metrics['loss'], metrics['pix-acc'], metrics['iou'], str(metrics['f1']))) 62 | 63 | trainer.run(train_loader, max_epochs=epochs) 64 | ``` 65 | 66 | 67 | 현재 `ignite.metrics` 에서는 Precision, Recall, Accuracy 등의 제한된 metric만을 제공하고 있습니다. 새로운 metric은 `ignite.metrics.MetricsLambda`를 활용하여 구현 가능하다고 합니다. 저희는 `ignite.metrics`의 다른 구현체들을 참고하여 클래스를 새로 작성했습니다. `ignite.metrics.Metric`을 상속받은 후 __ init __, reset, update, compute 등의 메소드를 오버라이딩했습니다. (Pixel-level) IoU를 예로 들어보겠습니다. 68 | 69 | 1. __ init __ : 다른 메소드 실행에 필요한 멤버 변수를 저장합니다. 확률값에 대한 threshold 값을 저장해두었습니다. 70 | 2. reset : 연산을 시작하기 전 멤버변수를 초기화하는 함수입니다. 71 | 3. update : metric 계산을 위해 각각의 iteration에서 연산에 필요한 값을 업데이트하는 함수입니다. IoU의 경우 모델 결과와 GT 사이의 union과 intersection 값을 더해줍니다. 72 | 4. compute : 정해진 주기가 끝나면 (데이터 셋 전체를 다 훑어본 후) metric을 계산하는 함수입니다. 73 | 74 | ```python 75 | class IoU(Metric): 76 | """ 77 | Calculates intersection over union for only foreground (hair) 78 | """ 79 | def __init__(self, thrs=0.5): 80 | super(IoU, self).__init__() 81 | self._thrs = thrs 82 | self.reset() 83 | 84 | def reset(self): 85 | self._num_intersect = 0 86 | self._num_union = 0 87 | 88 | def update(self, output): 89 | logit, y = output 90 | 91 | y_pred = torch.sigmoid(logit) >= self._thrs 92 | y = y.byte() 93 | 94 | intersect = y_pred * y == 1 95 | union = y_pred + y > 0 96 | 97 | self._num_intersect += torch.sum(intersect).item() 98 | self._num_union += torch.sum(union).item() 99 | 100 | def compute(self): 101 | if self._num_union == 0: 102 | raise ValueError('IoU must have at least one example before it can be computed') 103 | return self._num_intersect / self._num_union 104 | ``` -------------------------------------------------------------------------------- /markdowns/deeplabv3plus.md: -------------------------------------------------------------------------------- 1 |

2 | Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation, DeepLab V3+ (2018.08.22) 3 | 4 | **DeepLab**은 v1부터 가장 최신 버전인 v3+까지 총 4개의 버전이 있습니다. 5 | 6 | > 1\. [DeepLabv1](https://arxiv.org/abs/1412.7062) (2015) : Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs  7 | >
2\. [DeepLabv2](https://arxiv.org/abs/1606.00915) (2017) : DeepLab : Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs 8 | >
3\. [DeepLabv3](https://arxiv.org/abs/1706.05587) (2017) : Rethinking Atrous Convolution for Semantic Image Segmentation 9 | >
4\. [DeepLabv3+](https://arxiv.org/abs/1802.02611) (2018) : Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation 10 | 11 | 이 중 가장 최신의 논문이며 뛰어난 성능을 보이고 있는 DeepLabv3+에 대해서 다루도록 하겠습니다. 12 | 13 |
14 | 15 | 16 | # Abstract 17 | 18 | 1\. **Spatial pyramid pooling module**과 **Encoder-decoder structure** 각각의 장점을 합쳐 더욱 더 좋은 성능을 보이는 DeepLabv3+를 제시합니다. 19 |
- Spatial pyramid pooling module : Rate가 다른 여러 개의 atrous convolution을 통해 다양한 크기(multi-scale)의 물체 정보를 encode할 수 있습니다. 20 |
- Encoder-decoder structure : 공간 정보(spatial information)를 점진적으로 회복함으로써 더욱 정확하게 물체의 바운더리를 잡아낼 수 있습니다. 21 | 22 | 2\. **Xception model**을 backbone network로 사용합니다. 23 | 24 | 3\. **Depthwise separable convolution**을 통해 encoder-decoder를 더욱 빠르고 정확하게 합니다. 25 | 26 | 4\. 후처리(post-processing) 작업 없이 PASCAL VOC 2012와 Cityscapes의 test data에 대해 각각 89.0%와 82.1%의 성능을 보입니다. 27 |
- v2까지는 모델의 output에 CRF를 거치는 post-processing을 통해 더욱 정확한 바운더리를 잡아내었습니다. 28 | 29 |
30 |
31 | 32 | ## 1\. Introduction 33 | 34 | - Atrous convolution을 통해 extracted encoder features의 resolution을 임의로 할 수 있습니다. 이는 기존의 Encoder-decoder model에서는 불가능했습니다. 35 | 나머지 내용은 모두 Abstract에서 언급한 것과 동일하기 때문에 생략하도록 하겠습니다. 36 | 37 |
38 |
39 | 40 | ## 2\. Related Work 41 | 42 | ### Atrous Convolution 43 |

44 |
※ 사진 출처 : https://www.mdpi.com/2072-4292/9/5/498/htm 45 |

46 | Atrous에서 trous는 구멍(hole)을 의미합니다. 즉, convolution에서 중간중간을 비워두고 특정 간격의 픽셀에 대해서 합성곱을 한다고 생각할 수 있습니다. Atrous convolution에는 rate라는 parameter가 존재하는데 사진에서 r이 rate에 해당합니다. r이 커질수록 필터는 더 넓은 영역을 담을 수 있게 됩니다. 47 | 48 |
장점 : 동일한 수의 파라미터를 이용함에도 불구하고 필터의 receptive field를 키울 수 있습니다. 따라서 semantic segmentation에서는 디테일한 정보가 중요하므로 여러 convolution과 pooling 과정에서 디테일한 정보가 줄어들고 특성이 점점 추상화되는 것을 어느정도 방지할 수 있습니다. 49 | 50 | 51 |
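아래는 위 설명을 코드로 확인해 보기 위한 간단한 PyTorch 스케치입니다(저장소의 실제 구현과는 별개의 예시입니다). rate(=`dilation`)를 키워도 파라미터 수는 그대로이고 receptive field만 넓어지며, padding을 rate만큼 함께 늘려주면 출력 크기도 유지됩니다.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)  # (batch, channel, height, width)

# 일반적인 3x3 convolution (rate = 1)
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, dilation=1)

# atrous(dilated) convolution (rate = 2): padding을 rate와 같게 주면 출력 크기가 유지됩니다.
atrous = nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2)

print(conv(x).shape, atrous(x).shape)
# torch.Size([1, 64, 32, 32]) torch.Size([1, 64, 32, 32])

def num_params(m):
    return sum(p.numel() for p in m.parameters())

print(num_params(conv), num_params(atrous))  # 36928 36928 -> 파라미터 수는 동일
```

두 layer 모두 채널당 3x3(가중치 9개) 필터를 사용하지만, rate=2인 atrous convolution이 실제로 바라보는 영역은 5x5로 넓어집니다.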

52 | 53 | > (a) Spatial Pyramid Pooling 54 | >
- 다른 rates의 atrous convolution을 parallel하게 적용하여 다양한 크기(multi-scale)의 정보를 담을 수 있습니다. 55 | >
56 | >
(b) Encoder-Decoder 57 | >
- Encoder-decoder는 human pose estimation, object detection, 그리고 semantic segmentation을 포함해 다양한 컴퓨터 비전 영역에서 좋은 성능을 보여왔습니다. Encoder 부분은 점점 feature maps을 줄여 higher semantic information을 잡아낼 수 있습니다. 또한 Decoder 부분은 점진적으로 공간 정보(spatial information)를 회복합니다. 58 | 59 | 이 논문에서는 encoder module로 이전 버전인 DeepLabv3을 사용할 것을 제시하고 있습니다. 60 | 61 |
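(a)에서 설명한 "rate가 다른 atrous convolution을 parallel하게 적용"하는 아이디어를 아주 단순화하면 아래와 같은 형태가 됩니다. 실제 DeepLabv3의 ASPP에는 1x1 convolution branch, image-level pooling branch, BatchNorm 등이 추가되며, 아래의 rate 값(1, 6, 12, 18)과 채널 수는 설명을 위한 가정입니다.

```python
import torch
import torch.nn as nn


class SimpleASPP(nn.Module):
    """rate가 다른 atrous convolution들을 병렬로 적용한 뒤 concat하는 단순화된 모듈."""

    def __init__(self, in_channels, out_channels, rates=(1, 6, 12, 18)):
        super(SimpleASPP, self).__init__()
        self.branches = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      padding=r, dilation=r, bias=False)
            for r in rates
        ])
        # concat으로 늘어난 채널을 1x1 convolution으로 다시 줄여줍니다.
        self.project = nn.Conv2d(out_channels * len(rates), out_channels, kernel_size=1)

    def forward(self, x):
        feats = [branch(x) for branch in self.branches]  # 모든 branch가 같은 공간 크기를 유지
        return self.project(torch.cat(feats, dim=1))


x = torch.randn(2, 256, 32, 32)
print(SimpleASPP(256, 64)(x).shape)  # torch.Size([2, 64, 32, 32])
```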
62 |
63 |
64 | ### Depthwise separable convolution 65 | 66 | Depthwise separable convolution은 [Mobilenet](https://arxiv.org/abs/1704.04861) 논문에 담긴 사진과 [링크](http://machinethink.net/blog/googles-mobile-net-architecture-on-iphone/)에서 사용한 사진을 통해 정리하도록 하겠습니다. 67 | 68 |

69 | 위의 그림은 일반적으로 사용되는 convolution을 나타낸 사진입니다. 70 | 71 | 하나의 3x3 filter가 input channel과 동일한 channel을 가지고 합성곱을 하면 channel이 1인 output이 나오게 됩니다. 72 | 73 | 이러한 filter를 원하는 output channel의 개수만큼 사용합니다. 74 | 75 | 76 |

77 | 이를 일반화하여 나타내면 위와 같습니다. 78 | 79 | Dk는 filter의 크기이고 M은 input channel, 그리고 N은 output channel (= filter의 개수)를 의미합니다. 80 | 81 |

82 | 83 |

84 | 85 | Depth separable convolution은 두 가지 과정을 거칩니다. 86 | >1\. Depth Convolution : filter의 channel이 input channel과 같은 것이 아니라 1로 두어서 convolution 결과 channel의 변화가 생기지 않도록 한다. 87 | >
88 | >2\. Pointwise Convolution : 1x1 filter를 원하는 output channel의 개수만큼 사용하여 channel의 개수를 맞춰준다. 89 | 90 |
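위의 두 단계를 PyTorch로 옮기면 아래와 같은 스케치가 됩니다. `groups=입력 채널 수`로 두면 depthwise convolution이 되고, 그 뒤에 1x1(pointwise) convolution으로 채널 수를 맞춥니다. 가중치 수가 일반 convolution의 Dk·Dk·M·N에서 Dk·Dk·M + M·N으로 줄어드는 것도 함께 확인할 수 있습니다. (이 저장소의 `networks/deeplab_v3_plus.py`에 있는 `SeparableConv2d`도 같은 패턴으로, depthwise convolution 뒤에 BatchNorm과 pointwise convolution이 이어집니다.)

```python
import torch
import torch.nn as nn

M, N, Dk = 64, 128, 3              # M: input channel, N: output channel, Dk: filter 크기
x = torch.randn(1, M, 32, 32)

# 일반적인 convolution: 가중치 수 = Dk * Dk * M * N
standard = nn.Conv2d(M, N, kernel_size=Dk, padding=1, bias=False)

# depthwise separable convolution: 가중치 수 = Dk * Dk * M + M * N
depthwise = nn.Conv2d(M, M, kernel_size=Dk, padding=1, groups=M, bias=False)  # 채널별 Dk x Dk 필터 M개
pointwise = nn.Conv2d(M, N, kernel_size=1, bias=False)                        # 1x1 필터 N개
separable = nn.Sequential(depthwise, pointwise)

print(standard(x).shape, separable(x).shape)
# torch.Size([1, 128, 32, 32]) torch.Size([1, 128, 32, 32])

def num_params(m):
    return sum(p.numel() for p in m.parameters())

print(num_params(standard), num_params(separable))  # 73728 vs 8768
```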

91 |

92 | 93 | 94 | Depthwise convolution 과정에서 사용되는 filter는 channel이 1이고 filter의 개수가 input channel인 M과 같습니다. 95 | 그리고 Pointwise convolution에서 사용되는 filter는 channel이 input channel인 M이고 96 | filter의 개수가 원하는 output channel과 같습니다. 97 | 98 | Depth Separable Convolution의 장점으로는 99 | 일반적인 Convolution과 비슷한 과정을 거치지만 사용되는 parameter의 수를 획기적으로 줄여 100 | 모델의 속도를 향상시킬 수 있습니다. 101 | 102 | ※ 잘 이해가 되지 않으시는 분은 [PR-044: MobileNet](https://www.youtube.com/watch?v=7UoOFKcyIvM&feature=youtu.be)을 참고하시면 좋을 것 같습니다! 103 | 104 | ## 3. Methods 105 | 106 |

107 | 108 | ### 3.1 Encoder-Decoder with Atrous Convolution 109 | 110 | DeepLabv3+에서는 Encoder로 DeepLabv3을 사용하고 Decoder로 bilinear upsampling 대신에 U-net과 유사하게 concat해주는 방법을 사용합니다. 111 | 112 | 1) **Encoder (DeepLabv3)** : DCNN(Deep Convolutional Neural Network)에서 Atrous convolution을 통해 임의의 resolution으로 특징을 뽑아낼 수 있도록 합니다. 여기서 output stride의 개념이 쓰이는데 'input image의 resolution과 최종 output의 resolution의 비'로 생각하시면 됩니다. (The ratio of input image spatial resolution to the final output resolution) 즉, 최종 feature maps이 input image에 비해 32배 줄어들었다면 output stride는 32가 됩니다. Semantic segmentation에서는 더욱 디테일한 정보를 얻어내기 위해서 마지막 부분의 block을 1개 혹은 2개를 삭제하고 atrous convolution을 해줌으로써 output stride를 16혹은 8로 줄입니다.  113 | 114 |  그리고 다양한 크기의 물체 정보를 잡아내기 위해 다양한 rates의 atrous convolution을 사용하는 ASPP(Atrous Spatial Pyramid Pooling)를 사용합니다. 115 | 116 |
2) **Decoder** : 이전의 DeepLabv3에서는 decoder 부분을 단순히 bilinear upsampling해주었으나 v3+에서는 encoder의 최종 output에 1x1 convolution을 하여 channel을 줄이고 bilinear upsampling을 해준 후 concat하는 과정이 추가되었습니다. 이를 통해 decoder 과정에서 효과적으로 object segmentation details을 유지할 수 있게 되었습니다. 117 | 118 | 119 |
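위 decoder 설명(1x1 convolution으로 low-level feature의 채널을 줄이고, encoder output을 bilinear upsampling한 뒤 concat)을 그대로 코드로 옮기면 대략 아래와 같은 모양이 됩니다. feature map 크기와 채널 수(256, 48 등)는 설명을 위한 가정이며, 실제 구현에서는 concat 이후의 refinement convolution과 최종 upsampling까지 포함됩니다.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 가정: 512x512 입력 기준, encoder(ASPP) 출력은 1/16 크기, low-level feature는 1/4 크기
encoder_out = torch.randn(1, 256, 32, 32)
low_level = torch.randn(1, 256, 128, 128)

# 1) low-level feature의 채널 수를 1x1 convolution으로 줄입니다.
reduce_conv = nn.Conv2d(256, 48, kernel_size=1)
low = reduce_conv(low_level)                                   # [1, 48, 128, 128]

# 2) encoder 출력을 low-level feature 크기로 bilinear upsampling 합니다.
up = F.interpolate(encoder_out, size=low.shape[2:], mode='bilinear', align_corners=False)

# 3) 두 feature를 concat한 뒤 3x3 convolution으로 다듬고, 클래스별 logit을 만듭니다.
refine = nn.Sequential(
    nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 1, kernel_size=1),                          # hair segmentation이면 클래스 1개
)
out = refine(torch.cat([up, low], dim=1))
print(out.shape)                                               # torch.Size([1, 1, 128, 128])
```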
120 | ### 3.2 Modified Aligned Xception 121 | 122 |

123 |  DeepLabv3+에서는 Xception을 backbone으로 사용하지만 MSRA의 Aligned Xception과 다른 3가지 변화를 주었습니다. 124 | 125 | >1) 빠른 연산과 메모리의 효율을 위해 entry flow network structure를 수정하지 않았습니다. 126 | >2) Atrous separable convolution을 적용하기 위해 모든 pooling operation을 depthwise separable convolution으로 대체했습니다. 127 | >3) 각각의 3 x 3 depthwise convolution 이후에 추가적으로 batch normalization과 ReLU 활성화 함수를 추가해주었습니다. 128 | 129 | 130 | ## 4\. Experimental Evaluation 131 | 132 |

133 | 134 | DeepLabv3에서 사용했던 ResNet대신에 Xception을 backbone으로 사용하였을 때 135 | Error rate가 더 낮음을 확인할 수 있습니다. 136 | 137 |

138 |

139 | 140 | 141 | 표를 통해 Decoder를 추가한 것이 훨씬 높은 mIOU를 보인다는 것을 확인할 수 있습니다. 142 | 143 | 또한 예시를 통해 살펴보면 중간 사진은 DeepLabv3과 동일하게 단순히 BU(Bilinear Upsampling)만 사용한 것이고 오른쪽 사진은 U-net과 유사하게 concat 과정이 있는 decoder를 사용한 것입니다. 144 | 이를 통해 Decoder 과정에서 boundary information을 잘 잡아내었음을 확인할 수 있습니다. 145 | 146 | ## 5\. Conclusion 147 | 148 |  DeepLab은 segmentation task의 문제점에 대한 해결책을 찾으며 v1부터 꾸준하게 발전해온 모델입니다. 또한 Atrous convolution과 decoder를 통해 segmentation을 더욱 더 정확하게 할 수 있음을 보여주었습니다. 149 | 150 | ## 6\. Reference 151 | 152 |  1. DeepLabv3+ 리뷰 : [https://blog.lunit.io/2018/07/02/deeplab-v3-encoder-decoder-with-atrous-separable-convolution-for-semantic-image-segmentation/](https://blog.lunit.io/2018/07/02/deeplab-v3-encoder-decoder-with-atrous-separable-convolution-for-semantic-image-segmentation/) 153 | 154 |  2\. Atrous Convolution : [https://www.mdpi.com/2072-4292/9/5/498/html](https://www.mdpi.com/2072-4292/9/5/498/html) 155 | 156 |  3\. MobileNet : [https://arxiv.org/abs/1704.04861](https://arxiv.org/abs/1704.04861) (Paper) 157 |
[http://machinethink.net/blog/googles-mobile-net-architecture-on-iphone/](http://machinethink.net/blog/googles-mobile-net-architecture-on-iphone/) 158 | 159 | [https://www.youtube.com/watch?v=7UoOFKcyIvM&feature=youtu.be](https://www.youtube.com/watch?v=7UoOFKcyIvM&feature=youtu.be) (PR-044) -------------------------------------------------------------------------------- /markdowns/figaro.md: -------------------------------------------------------------------------------- 1 | ## Figaro 데이터셋 소개 2 | 3 | 저희가 semantic segmentation을 적용하고자 하는 분야는 hair segmentation입니다. 이미지에서 사람의 머리카락 부분을 탐지하는 작업인데요. 이와 관련된 데이터셋 중 하나가 [Figaro-1k](http://projects.i-ctm.eu/it/progetto/figaro-1k) 입니다. strait, wavy, curly 등 7가지 헤어스타일의 클래스가 있습니다. 각각 클래스당 150개로 총 1050개의 데이터셋입니다. 4:1 비율로 Training 840, Testing 210 개로 나뉘어 있습니다. 샘플 이미지는 아래와 같습니다. 4 | 5 | ![](http://projects.i-ctm.eu/sites/default/files/Images/207_Michele%20Svanera/database.jpg) 6 | 7 | Patch1k도 제공하고 있습니다. 머리카락 부분만 줌인하여 편집한 패치 1050장과 non-hair 패치 1050장이 있습니다. [논문](http://www.eecs.qmul.ac.uk/~urm30/Doc/Publication/2018/IVC2018.pdf)에 따르면 hair에 대한 feature를 학습하기 위한 보조 데이터셋으로 사용 가능하다고 합니다. 저희는 보조 데이터셋을 사용하지 않고 Figaro-1k 데이터만 사용하여 학습했습니다. 또한 헤어스타일 클래스에 상관 없이 머리카락인지 아닌지에 대한 binary classification을 진행하였습니다. 프로젝트의 root에서 아래와 같은 명령어로 데이터셋을 다운 받을 수 있습니다. 8 | 9 | ```bash 10 | # 특정 디렉토리에 다운로드를 원하는 경우 argument로 명시. 그렇지 않을 시 ./data/ 에 다운로드 11 | sh data/figaro.sh # 12 | ``` 13 | 14 | ## Pytorch FigaroDataSet 구현 15 | 파이토치에서는 torch.utils.data.Dataset을 상속받아 손쉽게 데이터셋을 구현할 수 있으며, `torch.utils.data.dataloader`를 통해 쉽게 불러올 수 있습니다. 데이터셋 구현 시 __init__, __getitem__, __len__ 세 가지 메소드를 오버라이딩하면 됩니다. 16 | 17 | 18 | #### __ init __ 19 | 다른 메소드 실해엥 필요한 멤버변수를 설정합니다. Figaro1k/ 폴더의 경로를 root_dir로 넣어주면 이미지와 마스크(GT) 파일들의 경로를 저장합니다. 또한 이미지 혹은 GT 마스크에 적용할 transforms들을 인자로 받아 멤버변수로 저장합니다. transforms에 대해서는 아래서 설명하겠습니다. 20 | 21 | ```python 22 | class FigaroDataset(Dataset): 23 | def __init__(self, root_dir, train=True, joint_transforms=None, 24 | image_transforms=None, mask_transforms=None, gray_image=False): 25 | """ 26 | Args: 27 | root_dir (str): root directory of dataset 28 | joint_transforms (torchvision.transforms.Compose): tranformation on both data and target 29 | image_transforms (torchvision.transforms.Compose): tranformation only on data 30 | mask_transforms (torchvision.transforms.Compose): tranformation only on target 31 | gray_image (bool): whether to return gray image image or not. 32 | If True, returns img, mask, gray. 33 | """ 34 | mode = 'Training' if train else 'Testing' 35 | img_dir = os.path.join(root_dir, 'Original', mode) 36 | mask_dir = os.path.join(root_dir, 'GT', mode) 37 | 38 | self.img_path_list = [os.path.join(img_dir, img) for img in sorted(os.listdir(img_dir))] 39 | self.mask_path_list = [os.path.join(mask_dir, mask) for mask in sorted(os.listdir(mask_dir))] 40 | self.joint_transforms = joint_transforms 41 | self.image_transforms = image_transforms 42 | self.mask_transforms = mask_transforms 43 | self.gray_image = gray_image 44 | ``` 45 | 46 | 47 | #### __ getitem __ 48 | 인덱스를 통해 이미지 / GT 파일에 접근한 후 pytorch tensor 형태로 반환하는 함수입니다. 이미지 / GT 파일을 PIL Image로 읽어온 후 세 가지 transforming 과정을 거칩니다. 49 | 50 | 1. joint_transforms: 좌우변환 등 geometric한 변경이 필요한 경우 이미지(데이터)와 마스트(타겟)에 모두 적용시킵니다. 51 | 2. image_transforms: 색상 변환 등 이미지(데이터)에만 적용되는 변환입니다. 52 | 3. mask_transoforms: 마스크(타겟)에만 적용되는 변환입니다. PIL.Image를 Tensor로 변환하는 데만 사용했습니다. 53 | 54 | FigaroDataSet 인스턴스 생성 시 image / mask tranforms의 경우 `torchvision.transforms`의 클래스들을 argument로 넘겨주었습니다. 
joint transforms의 경우, [pytorch-semantic-segmentation](https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py)의 구현체를 이용했습니다. 또한 grayscale의 이미지가 필요한 경우 함께 리턴하도록 구현했습니다. 55 | 56 | ``` python 57 | def __getitem__(self,idx): 58 | img_path = self.img_path_list[idx] 59 | img = Image.open(img_path) 60 | 61 | if self.gray_image: 62 | gray = img.convert('LA') 63 | 64 | mask_path = self.mask_path_list[idx] 65 | mask = Image.open(mask_path) 66 | 67 | if self.joint_transforms is not None: 68 | img, mask = self.joint_transforms(img, mask) 69 | 70 | if self.image_transforms is not None: 71 | img = self.image_transforms(img) 72 | 73 | if self.mask_transforms is not None: 74 | mask = self.mask_transforms(mask) 75 | 76 | if self.gray_image: 77 | return img, mask, gray 78 | else: 79 | return img, mask 80 | ``` 81 | 82 | #### __ len __ 83 | 데이터셋의 크기를 반환하는 함수로, 멤버변수에 저장된 마스크 파일 갯수를 반환시켰습니다. 84 | ```python 85 | def __len__(self): 86 | return len(self.mask_path_list) 87 | ``` 88 | 89 | -------------------------------------------------------------------------------- /markdowns/mobile_net.md: -------------------------------------------------------------------------------- 1 | ## MobileNet 2 | 3 | [mobilenet](https://arxiv.org/abs/1704.04861)은 적은 파라미터로 효율적인 모델을 만들기 위해 노력한 논문입니다. 4 | 5 | -------------------------------------------------------------------------------- /markdowns/pspnet.md: -------------------------------------------------------------------------------- 1 | ## [Pyramid Scene Parsing Network](https://arxiv.org/pdf/1612.01105.pdf) 2 | 3 | 4 | ![](https://img-blog.csdnimg.cn/2018110821502579.jpg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3UwMTQzODAxNjU=,size_16,color_FFFFFF,t_70) 5 | 6 | 앞서 FCN에 대해서 살펴보았습니다. FCN 계열의 모델은 괜찮은 성능을 보여왔지만, 다양한 풍경에서의 일반화 능력이 떨어진다는 단점이 있었습니다. 이미지 전체에서 컨텍스트 정보를 뽑아내는 능력이 부족했기 때문인데요. 예를 들어 위 이미지 가장 윗부분에서 기존의 FCN은 보트를 자동차로 탐지합니다. 이미지 전반에 걸쳐서 정보를 뽑아낼 수 있다면, '물 위에 있기 때문에 이 물체는 보트다.'는 추론이 가능하겠죠. 이를 극복하기 위해 여러 노력이 있었는데요, 대표적으로 dilated convolution을 사용하는 방법이 있습니다. 기존의 convolution에서 커널 사이에 일정한 간격을 두어 필터가 보다 넓은 영역의 정보를 담을 수 있게 합니다. 본 논문 Pyramid Scene Parsing Network 역시 dilated convolution을 사용합니다. 이에 덧붙여 보다 직접적으로 이미지 전역에 걸친 feature를 뽑아내는 모듈을 추가했는데요, 이에 대해서는 아래에서 보다 자세히 설명하겠습니다. 7 | 8 | 9 | ![Dilated convolution](https://www.researchgate.net/profile/Yizhou_Yu/publication/316899215/figure/fig7/AS:668532287238147@1536401926298/Dilated-convolution-Top-The-spatial-resolution-of-feature-maps-is-reduced-by-half-after.png) 10 | 11 | 12 | 저자는 논문의 기여를 아래와 같이 서술합니다. 13 | * FCN 계열의 픽셀 단위 예측 모델에 비해 보다 어려운 풍경 정보를 담을 수 있는 pyramid scene parsing network 제안. 14 | * deeply supervised loss를 활용한 효과적인 최적화 전략 개발. 15 | * 최고 성능을 보이는 semantic segmentation 시스템 개발 및 구현 코드 공개 16 | 17 | 18 | ### Pyramid Pooling Module 19 | 기존의 FCN과 가장 차별화되는 부분입니다. 일반적인 FCN은 convolutional layer를 통해 인코딩한 정보를 점진적으로 보간(interpolation)해나갑니다. 보간된 정보는 다시 convolutional layer를 거치며 보다 풍부(dense)해집니다. 이런 과정을 거쳐 입력 이미지에 상응하는 사이즈의 score map을 출력합니다. 따라서 각 픽셀별로 어떤 class에 속하는 지 score를 나타내게 됩니다. 본 논문은 이 중간 과정에 Pyramid Pooling Layer를 추가합니다. 그 구조는 아래와 같습니다. 20 | ![](https://hszhao.github.io/projects/pspnet/figures/pspnet.png) 21 | 22 | 1. 서로 다른 사이즈가 되도록 여러 차례 pooling을 합니다. 논문에서는 1x1, 2x2, 3x3, 6x6 사이즈로 만들었는데요. 1x1 사이즈의 feature map은 가장 조악한 정보를, 하지만 가장 넓은 범위의 정보를 담습니다. 각각 다른 사이즈의 feature map은 서로 다른 부분 영역들의 정보를 담게 됩니다. 23 | 2. 이후 1x1 convolution을 통해 channel 수를 조정합니다. pooling layer의 개수를 N이라고 할 때, `출력 channel 수 = 입력 채널 수 / N` 이 됩니다. 24 | 3. 
이후 이 모듈의 input size에 맞게끔 feature map을 upsample합니다. 이 과정에서 bilinear interpolation이 사용됩니다. 25 | 4. 원래의 feature map과 위 과정을 통해 생성한 새로운 feature map들을 이어붙여(concatenate)줍니다. 26 | 27 | 본 논문에서는 위와 같은 4개의 사이즈를 사용했지만, 구현에 따라서 다르게 설정할 수 있다고 합니다. pooling의 경우 max pooling과 average pooling을 모두 사용해본 결과, average pooling이 일반적으로 좋은 성능을 보였다고 합니다. 28 | 29 | ### deeply supervised loss 30 | ![](https://tangzhenyu.github.io/assets/paper_notes/pspnet/image3.jpg) 31 | 본 논문에서는 dilated convolution을 사용한 resnet 50, resnet101 등 깊은 모델을 사용했습니다. 깊은 모델이 잘 학습될 경우 더 좋은 정확도를 보이지만, 최적화가 어려운데요. 본 논문에서는 보조적인 loss를 사용하여 이 문제를 해결합니다. 모델을 끝단에서 loss1을 산출하기에 앞서, 앞단의 레이어 (resnet4b22)에 보조적인 classifier를 달아 보조 loss2를 산출합니다. 트레이닝 과정에서 이 두 가지 loss를 가중합계 내어 최종적인 loss를 산출합니다. Inference단에서는 loss2를 계산하는 보조 classifier는 사용하지 않는다고 합니다. 본 논문에선는 여러 실험을 통해 이 보조 loss의 존재가 학습에 도움이 된다고 주장합니다. 32 | 33 | 34 | ### 구현 35 | 설명에 앞서 기존의 [pytorch pspnet 구현체](https://github.com/Lextal/pspnet-pytorch)를 참고했음을 밝힙니다. 36 | 37 | 1. Base Network 38 | - 저희가 사용하는 Figaro-1k은 데이터셋 크기가 작기 때문에 최대한 얕은 모델을 사용하고자 했습니다. 39 | - 때문에 SqueezeNet을 사용했습니다. `torchvision.models`에서 pretrained model을 불러와 classifier 전단계까지의 레이어를 사용했습니다. 40 | 41 | ```python 42 | class SqueezeNetExtractor(nn.Module): 43 | def __init__(self): 44 | super(SqueezeNetExtractor, self).__init__() 45 | model = squeezenet1_1(pretrained=True) 46 | features = model.features 47 | self.feature1 = features[:2] 48 | self.feature2 = features[2:5] 49 | self.feature3 = features[5:8] 50 | self.feature4 = features[8:] 51 | 52 | def forward(self, x): 53 | f1 = self.feature1(x) 54 | f2 = self.feature2(f1) 55 | f3 = self.feature3(f2) 56 | f4 = self.feature4(f3) 57 | return f4 58 | ``` 59 | 2. Pyramid Pooling Layer 60 | - 논문에 나온 4 가지 사이즈 (1x1, 2x2, 3x3, 6x6)를 사용했습니다. 61 | - 이 과정에서는 `torch.nn.AdaptiveAvgPool2d`를 사용했습니다. 62 | ```python 63 | class PyramidPoolingModule(nn.Module): 64 | def __init__(self, in_channels, sizes=(1, 2, 3, 6)): 65 | super(PyramidPoolingModule, self).__init__() 66 | pyramid_levels = len(sizes) 67 | out_channels = in_channels // pyramid_levels 68 | 69 | pooling_layers = nn.ModuleList() 70 | for size in sizes: 71 | layers = [nn.AdaptiveAvgPool2d(size), nn.Conv2d(in_channels, out_channels, kernel_size=1)] 72 | pyramid_layer = nn.Sequential(*layers) 73 | pooling_layers.append(pyramid_layer) 74 | 75 | self.pooling_layers = pooling_layers 76 | 77 | def forward(self, x): 78 | h, w = x.size(2), x.size(3) 79 | features = [x] 80 | for pooling_layer in self.pooling_layers: 81 | # pool with different sizes 82 | pooled = pooling_layer(x) 83 | 84 | # upsample to original size 85 | upsampled = F.upsample(pooled, size=(h, w), mode='bilinear') 86 | 87 | features.append(upsampled) 88 | 89 | return torch.cat(features, dim=1) 90 | ``` 91 | 3. deeply supervised loss 92 | - SqueezeNet을 사용했기 때문에 보조적인 loss가 필요 없다고 판단하여 구현하지 않았습니다. 93 | 94 | 95 | ### 성능 평가 96 | 97 | pyramid pooling 모듈의 사용 여부가 정확도에 미치는 영향을 파악하기 위해 pyramid pooling 모듈을 제외한 FCN 형태의 네트워크도 학습을 진행해보았습니다. 98 | 99 | 1. 양적 평가 100 | 101 | | pyramid pooling | threshold | Validation-loss | Pixcel Accuracy | IoU | F1-score| 102 | | --- | --- | --- | --- | --- | --- | 103 | | o | 0.50 | 0.11 | 0.96 | 0.861 | 0.925 | 104 | | x | 0.50 | 0.12 | 0.95 | 0.843 | 0.915 | 105 | 106 | * Pixel Accuracy = TP + TN / (TP + FP + FN + TN) 107 | * IoU = TP / (TP + FP + FN) 108 | * F1-score = 2 * precision * recall / () precision + recall ) 109 | * 위 결과는 각각의 네트워크에서 가장 IoU가 가장 높은 pyramid pooling module을 사용하는 경우 대부분의 measure에서 보다 좋은 정확도를 보였습니다. 110 | 111 | 112 | 2. 
질적 평가 113 | 114 | * Best Cases 115 | 116 | Commercial Photography 117 | Commercial Photography 118 | Commercial Photography 119 | SAMPLE-5 120 | 121 | * Failure Cases 122 | 123 | Commercial Photography 124 | Commercial Photography 125 | * 흑백 이미지에서 머리카락이 아닌 다른 부분이 탐지되는 경우가 있습니다. 126 | * 수염 등 머리카락과 헷갈릴만한 부분에서 잘못 탐지되는 경우가 있습니다. 127 | * 보다 많은 샘플을 확보한다면 일반화 능력이 더 좋아질 것으로 예상됩니다. 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /markdowns/semantic_segmentation.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation 2 | **Semantic Segmentation**이란, 이미지를 픽셀별(pixel-wise)로 분류(Classification)하는 것입니다. 3 |
아래의 그림은 Semantic Segmentation의 예시 중 하나입니다. 4 | 5 | ![FCN](https://i.ibb.co/YhdKdd5/ss.png) 6 |
※ 사진 출처 : https://www.jeremyjordan.me/semantic-segmentation/ 7 | 8 | 오른쪽 사진을 보면, 모든 픽셀이 3가지 클래스(사람, 자전거, 배경) 중 하나로 분류된 것을 확인할 수 있습니다. 9 | 10 | ## 활용 11 | 12 | Semantic Segmentation은 매우 다양한 분야에서 사용되고 있습니다. 대표적인 예로 **자율 주행 자동차**에서 Semantic Segmentation은 핵심적인 역할을 합니다. 13 | 14 | ![Drive](https://i.ibb.co/TkZYvrD/figure9.png) 15 |
※ 사진 출처 : https://devblogs.nvidia.com/image-segmentation-using-digits-5/ 16 | 17 | 자율 주행 자동차가 정면에 위치한 대상이 사람인지, 자동차인지, 횡단보도인지 신호등인지 정확하게 구분하지 못하면 상황에 따른 적절한 판단을 내릴 수 없습니다. 따라서 Semantic Segmentation의 정확도와 속도를 모두 높이기 위해 많은 연구가 이루어지고 있습니다. 18 | 19 | ## Semantic Segmentation VS Instance Segmentation 20 | 21 | ![comparison](https://i.ibb.co/0yL6Yjf/is.png) 22 |
※ 사진 출처 : http://slazebni.cs.illinois.edu/spring18/lec25_deep_segmentation.pdf 23 | 24 | Semantic Segmentation과 Instance Segmentation의 차이를 잘 보여주고 있는 예시입니다. 위의 그림에서 중간에 위치한 **Semantic Segmentation**의 경우, 각 픽셀을 사람(핑크색)과 배경(검은색) 중에 어떤 클래스에 속하는지 분류하고 있습니다. 이와 달리, 오른쪽에 위치한 **Instance Segmentation**은 사람과 배경을 구분해줄 뿐만 아니라 사람끼리도 구분해주고 있는 것을 확인할 수 있습니다. 25 | 즉, Semantic Segmentation은 단순히 각각의 픽셀이 어떤 클래스에 속하는지 분류하는 것에 그치는 반면에 Instance Segmentation은 동일한 클래스에 속하더라도 각각의 사물을 개별적으로 구분해줍니다. 26 | 27 | ## 대표적인 논문 28 | 29 | **Semantic Segmentation** 분야의 대표적은 논문들은 아래와 같습니다. 30 | 31 | 1. FCN (2014) : https://arxiv.org/abs/1411.4038 32 | 2. U-Net (2015) : https://arxiv.org/abs/1505.04597 33 | 3. SegNet (2015) : https://arxiv.org/abs/1511.00561 34 | 4. PSPNet (2016) : https://arxiv.org/abs/1612.01105 35 | 5. DeepLab V3+ (2018) : https://arxiv.org/abs/1802.02611 36 | -------------------------------------------------------------------------------- /markdowns/visdom.md: -------------------------------------------------------------------------------- 1 | ## Visdom 2 | 3 | [Visdom](https://github.com/facebookresearch/visdom)은 데이터를 시각화하기 위한 툴입니다. torch, numpy를 지원하고 있습니다. 4 | 5 | ### 쉬운 시작 6 | 7 | 1. Visdom server 켜기 8 | 9 | ``` 10 | pip install visdom 11 | python -m visdom.server 12 | # localhost:8097에 접속합니다. 13 | ``` 14 | 15 | 2. Visdom에 로그 남기기 16 | ```python 17 | DEFAULT_PORT = 8097 18 | DEFAULT_HOSTNAME = "http://localhost" 19 | 20 | vis = visdom.Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME) 21 | 22 | vis.text('텍스트 써보기') 23 | vis.images(img:numpy_type) # numpy type의 이미지를 변수에 할당 24 | vis.matplot(plt) # matplotlib의 plot type의 변수에 할당 25 | ``` 26 | 27 | 자세한 사항은 [Visdom 문서](https://github.com/facebookresearch/visdom)를 참조해주세요 -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplab_v3_plus import DeepLab 2 | from .pspnet import PSPNet 3 | from .mobile_hair import MobileMattingFCN 4 | 5 | 6 | 7 | def get_network(name): 8 | name = name.lower() 9 | if name == 'deeplabv3plus': 10 | return DeepLab(return_with_logits = True) 11 | elif name == 'pspnet_squeezenet': 12 | return PSPNet(num_class=1, base_network='squeezenet') 13 | elif name == 'pspnet_resnet101': 14 | return PSPNet(num_class=1, base_network='resnet101') 15 | elif name == 'mobilenet': 16 | return MobileMattingFCN() 17 | raise ValueError 18 | -------------------------------------------------------------------------------- /networks/deeplab_v3_plus.py: -------------------------------------------------------------------------------- 1 | # Reference : https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/xception.py 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | ''' 10 | @ inplanes = the number of input channels 11 | @ planes = the number of output channels 12 | @ start_with_relu = except for Block1 of Entry flow, every block starts with ReLU. 
13 | @ output_stride = the ratio of input image spatial resolution to the final output resolution (before global pooling or fc layer) 14 | ''' 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | class SeparableConv2d(nn.Module): 25 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False): 26 | super(SeparableConv2d, self).__init__() 27 | 28 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 29 | groups=inplanes, bias=bias) 30 | self.bn = nn.BatchNorm2d(inplanes) 31 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 32 | 33 | def forward(self, x): 34 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 35 | x = self.conv1(x) 36 | x = self.bn(x) 37 | x = self.pointwise(x) 38 | return x 39 | 40 | class Block(nn.Module): 41 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, 42 | start_with_relu=True, grow_first=True, is_last=False): 43 | super(Block, self).__init__() 44 | 45 | if planes != inplanes or stride != 1: 46 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 47 | self.skipbn = nn.BatchNorm2d(planes) 48 | else: 49 | self.skip = None 50 | 51 | self.relu = nn.ReLU(inplace=True) 52 | rep = [] 53 | 54 | filters = inplanes 55 | if grow_first: 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation)) 58 | rep.append(nn.BatchNorm2d(planes)) 59 | filters = planes 60 | 61 | for i in range(reps - 1): 62 | rep.append(self.relu) 63 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation)) 64 | rep.append(nn.BatchNorm2d(filters)) 65 | 66 | if not grow_first: 67 | rep.append(self.relu) 68 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation)) 69 | rep.append(nn.BatchNorm2d(planes)) 70 | 71 | if stride != 1: 72 | rep.append(self.relu) 73 | rep.append(SeparableConv2d(planes, planes, 3, 2)) 74 | rep.append(nn.BatchNorm2d(planes)) 75 | 76 | if stride == 1 and is_last: 77 | rep.append(self.relu) 78 | rep.append(SeparableConv2d(planes, planes, 3, 1)) 79 | rep.append(nn.BatchNorm2d(planes)) 80 | 81 | if not start_with_relu: 82 | rep = rep[1:] 83 | 84 | self.rep = nn.Sequential(*rep) 85 | 86 | 87 | def forward(self, inp): 88 | x = self.rep(inp) 89 | 90 | if self.skip is not None: 91 | skip = self.skip(inp) 92 | skip = self.skipbn(skip) 93 | else: 94 | skip = inp 95 | 96 | x = x + skip 97 | 98 | return x 99 | 100 | class ModifiedAlignedXception(nn.Module): 101 | """ 102 | Modified Alighed Xception 103 | """ 104 | def __init__(self, output_stride, 105 | pretrained=True): 106 | super(ModifiedAlignedXception, self).__init__() 107 | 108 | if output_stride == 16: 109 | entry_block3_stride = 2 110 | middle_block_dilation = 1 111 | exit_block_dilations = (1, 2) 112 | elif output_stride == 8: 113 | entry_block3_stride = 1 114 | middle_block_dilation = 2 115 | exit_block_dilations = (2, 4) 116 | else: 117 | raise NotImplementedError 118 | 119 | 120 | # Entry flow 121 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 122 | self.bn1 = nn.BatchNorm2d(32) 123 | self.relu = nn.ReLU(inplace=True) 124 | 125 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 126 | self.bn2 = nn.BatchNorm2d(64) 127 
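# Entry flow stem: two 3x3 convolutions (stride 2, then 1) halve the input resolution and raise
# the channel count to 64; block1-block3 below complete the Entry flow up to 728 channels.
# entry_block3_stride, middle_block_dilation and exit_block_dilations, chosen above from
# output_stride, trade further downsampling for dilation in the later blocks.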
| 128 | self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False) 129 | self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=False, 130 | grow_first=True) 131 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, 132 | start_with_relu=True, grow_first=True, is_last=True) 133 | 134 | # Middle flow 135 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | start_with_relu=True, grow_first=True) 137 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | start_with_relu=True, grow_first=True) 139 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | start_with_relu=True, grow_first=True) 141 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | start_with_relu=True, grow_first=True) 143 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | start_with_relu=True, grow_first=True) 145 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | start_with_relu=True, grow_first=True) 147 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | start_with_relu=True, grow_first=True) 149 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | start_with_relu=True, grow_first=True) 151 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | start_with_relu=True, grow_first=True) 153 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | start_with_relu=True, grow_first=True) 155 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | start_with_relu=True, grow_first=True) 157 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | start_with_relu=True, grow_first=True) 159 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | start_with_relu=True, grow_first=True) 161 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 162 | start_with_relu=True, grow_first=True) 163 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 164 | start_with_relu=True, grow_first=True) 165 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 166 | start_with_relu=True, grow_first=True) 167 | 168 | # Exit flow 169 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 170 | start_with_relu=True, grow_first=False, is_last=True) 171 | 172 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1]) 173 | self.bn3 = nn.BatchNorm2d(1536) 174 | 175 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1]) 176 | self.bn4 = nn.BatchNorm2d(1536) 177 | 178 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1]) 179 | self.bn5 = nn.BatchNorm2d(2048) 180 | 181 | # Load pretrained model 182 | if pretrained: 183 | self._load_pretrained_model() 184 | 185 | def forward(self, x): 186 | # Entry flow 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | 191 | x = self.conv2(x) 192 | x = self.bn2(x) 193 | x = self.relu(x) 194 | 195 | x = self.block1(x) 196 | # add relu here 197 | x = self.relu(x) 198 | low_level_feat = x 199 | x = self.block2(x) 200 | x = self.block3(x) 201 | 202 | # Middle flow 203 | x = self.block4(x) 204 | x = self.block5(x) 205 | x 
= self.block6(x) 206 | x = self.block7(x) 207 | x = self.block8(x) 208 | x = self.block9(x) 209 | x = self.block10(x) 210 | x = self.block11(x) 211 | x = self.block12(x) 212 | x = self.block13(x) 213 | x = self.block14(x) 214 | x = self.block15(x) 215 | x = self.block16(x) 216 | x = self.block17(x) 217 | x = self.block18(x) 218 | x = self.block19(x) 219 | 220 | # Exit flow 221 | x = self.block20(x) 222 | x = self.relu(x) 223 | x = self.conv3(x) 224 | x = self.bn3(x) 225 | x = self.relu(x) 226 | 227 | x = self.conv4(x) 228 | x = self.bn4(x) 229 | x = self.relu(x) 230 | 231 | x = self.conv5(x) 232 | x = self.bn5(x) 233 | x = self.relu(x) 234 | 235 | return x, low_level_feat 236 | 237 | def _load_pretrained_model(self): 238 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 239 | model_dict = {} 240 | state_dict = self.state_dict() 241 | 242 | for k, v in pretrain_dict.items(): 243 | if k in model_dict: 244 | if 'pointwise' in k: 245 | v = v.unsqueeze(-1).unsqueeze(-1) 246 | if k.startswith('block11'): 247 | model_dict[k] = v 248 | model_dict[k.replace('block11', 'block12')] = v 249 | model_dict[k.replace('block11', 'block13')] = v 250 | model_dict[k.replace('block11', 'block14')] = v 251 | model_dict[k.replace('block11', 'block15')] = v 252 | model_dict[k.replace('block11', 'block16')] = v 253 | model_dict[k.replace('block11', 'block17')] = v 254 | model_dict[k.replace('block11', 'block18')] = v 255 | model_dict[k.replace('block11', 'block19')] = v 256 | elif k.startswith('block12'): 257 | model_dict[k.replace('block12', 'block20')] = v 258 | elif k.startswith('bn3'): 259 | model_dict[k] = v 260 | model_dict[k.replace('bn3', 'bn4')] = v 261 | elif k.startswith('conv4'): 262 | model_dict[k.replace('conv4', 'conv5')] = v 263 | elif k.startswith('bn4'): 264 | model_dict[k.replace('bn4', 'bn5')] = v 265 | else: 266 | model_dict[k] = v 267 | state_dict.update(model_dict) 268 | self.load_state_dict(state_dict) 269 | 270 | class _ASPPModule(nn.Module): 271 | def __init__(self, inplanes, planes, kernel_size, padding, dilation): 272 | super(_ASPPModule, self).__init__() 273 | 274 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size = kernel_size, 275 | stride = 1, padding = padding, dilation = dilation, bias = False) 276 | self.bn = nn.BatchNorm2d(planes) 277 | self.relu = nn.ReLU() 278 | 279 | def forward(self, x): 280 | x = self.atrous_conv(x) 281 | x = self.bn(x) 282 | 283 | return self.relu(x) 284 | 285 | class ASPP(nn.Module): 286 | def __init__(self, output_stride): 287 | super(ASPP, self).__init__() 288 | 289 | inplanes = 2048 290 | 291 | if output_stride == 16: 292 | dilations = [1, 6, 12, 18] 293 | elif output_stride == 8: 294 | dilations = [1, 12, 24, 36] 295 | else: 296 | raise NotImplementedError 297 | 298 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0]) 299 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1]) 300 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2]) 301 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3]) 302 | 303 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 304 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 305 | nn.BatchNorm2d(256), 306 | nn.ReLU()) 307 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 308 | self.bn1 = nn.BatchNorm2d(256) 309 | self.relu = nn.ReLU() 310 | self.dropout = nn.Dropout(0.5) 311 | 312 | def forward(self, x): 
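# Four parallel atrous convolutions plus an image-level pooling branch are applied to the
# backbone features; the pooled branch is upsampled back to the feature-map size, all five
# branches are concatenated (5 x 256 = 1280 channels) and fused with a 1x1 convolution,
# BatchNorm, ReLU and dropout.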
313 | x1 = self.aspp1(x) 314 | x2 = self.aspp2(x) 315 | x3 = self.aspp3(x) 316 | x4 = self.aspp4(x) 317 | x5 = self.global_avg_pool(x) 318 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 319 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 320 | 321 | x = self.conv1(x) 322 | x = self.bn1(x) 323 | x = self.relu(x) 324 | 325 | return self.dropout(x) 326 | 327 | class Decoder(nn.Module): 328 | def __init__(self, return_with_logits, low_level_inplanes = 128): 329 | super(Decoder, self).__init__() 330 | 331 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias = False) 332 | self.bn1 = nn.BatchNorm2d(48) 333 | self.relu = nn.ReLU() 334 | 335 | layers = [nn.Conv2d(304, 256, kernel_size = 3, stride = 1, padding = 1, bias = False), 336 | nn.BatchNorm2d(256), 337 | nn.ReLU(), 338 | nn.Dropout(0.5), 339 | nn.Conv2d(256, 256, kernel_size = 3, stride = 1, padding = 1, bias = False), 340 | nn.BatchNorm2d(256), 341 | nn.ReLU(), 342 | nn.Dropout(0.1), 343 | nn.Conv2d(256, 1, kernel_size=1, stride=1)] 344 | 345 | if not return_with_logits: 346 | layers.append(nn.Sigmoid()) 347 | 348 | self.last_conv = nn.Sequential(*layers) 349 | 350 | def forward(self, x, low_level_feat): 351 | low_level_feat = self.conv1(low_level_feat) 352 | low_level_feat = self.bn1(low_level_feat) 353 | low_level_feat = self.relu(low_level_feat) 354 | 355 | x = F.interpolate(x, size = low_level_feat.size()[2:], mode = 'bilinear', align_corners = True) 356 | x = torch.cat((x, low_level_feat), dim = 1) 357 | x = self.last_conv(x) 358 | 359 | return x 360 | 361 | class DeepLab(nn.Module): 362 | def __init__(self, return_with_logits = True, output_stride=16): 363 | super(DeepLab, self).__init__() 364 | 365 | self.backbone = ModifiedAlignedXception(output_stride) 366 | self.aspp = ASPP(output_stride) 367 | self.decoder = Decoder(return_with_logits) 368 | 369 | 370 | def forward(self, input): 371 | x, low_level_feat = self.backbone(input) 372 | x = self.aspp(x) 373 | x = self.decoder(x, low_level_feat) 374 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 375 | 376 | return x 377 | -------------------------------------------------------------------------------- /networks/mobile_hair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Real-time deep hair matting on mobile devices(2018)" 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | import numpy as np 10 | import cv2 11 | 12 | def fixed_padding(inputs, kernel_size, rate): 13 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 14 | pad_total = kernel_size_effective - 1 15 | pad_beg = pad_total // 2 16 | pad_end = pad_total - pad_beg 17 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 18 | return padded_inputs 19 | 20 | class SeparableConv2d(nn.Module): 21 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 22 | super(SeparableConv2d,self).__init__() 23 | 24 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) 25 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 26 | 27 | def forward(self,x): 28 | x = fixed_padding(x, self.conv1.kernel_size[0], rate=self.conv1.dilation[0]) 29 | x = self.conv1(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | class GreenBlock(nn.Module): 34 | def __init__(self, in_channel, out_channel): 35 | 
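# Decoder block: a depthwise separable 3x3 convolution with BatchNorm and ReLU (self.dconv),
# followed by a 1x1 convolution with BatchNorm and ReLU (self.conv).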
super(GreenBlock,self).__init__() 36 | 37 | self.dconv = nn.Sequential( 38 | SeparableConv2d(in_channel), 39 | nn.BatchNorm2d(in_channel), 40 | nn.ReLU() 41 | ) 42 | 43 | self.conv = nn.Sequential( 44 | nn.Conv2d(in_channel, out_channel, kernel_size=1), 45 | nn.BatchNorm2d(out_channel), 46 | nn.ReLU() 47 | ) 48 | 49 | def forward(self, input): 50 | x = self.dconv(input) 51 | x = self.conv(x) 52 | 53 | return x 54 | 55 | class YellowBlock(nn.Module): 56 | def __init__(self): 57 | super(YellowBlock,self).__init__() 58 | 59 | def forward(self, input): 60 | return F.interpolate(input, scale_factor=2) 61 | 62 | class OrangeBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels,kernel_size=3,stride=1,padding=0,dilation=1, bias=False): 64 | super(OrangeBlock,self).__init__() 65 | self.conv = nn.Sequential( 66 | SeparableConv2d(in_channels, out_channels, kernel_size), 67 | nn.ReLU() 68 | ) 69 | 70 | def forward(self, input): 71 | return self.conv(input) 72 | 73 | 74 | class MobileMattingFCN(nn.Module): 75 | # https://github.com/marvis/pytorch-mobilenet modified 76 | def __init__(self): 77 | super(MobileMattingFCN, self).__init__() 78 | 79 | def conv_bn(inp, oup, stride): 80 | return nn.Sequential( 81 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 82 | nn.BatchNorm2d(oup), 83 | nn.ReLU(inplace=True) 84 | ) 85 | 86 | def conv_dw(inp, oup, stride): 87 | return nn.Sequential( 88 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 89 | nn.BatchNorm2d(inp), 90 | nn.ReLU(inplace=True), 91 | 92 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 93 | nn.BatchNorm2d(oup), 94 | nn.ReLU(inplace=True), 95 | ) 96 | 97 | self.model = nn.Sequential( 98 | conv_bn( 3, 32, 2), 99 | conv_dw( 32, 64, 1), # skip1 1 100 | conv_dw( 64, 128, 2), 101 | conv_dw(128, 128, 1), # skip2 3 102 | conv_dw(128, 256, 2), 103 | conv_dw(256, 256, 1), # skip3 5 104 | conv_dw(256, 512, 2), 105 | conv_dw(512, 512, 1), 106 | conv_dw(512, 512, 1), 107 | conv_dw(512, 512, 1), 108 | conv_dw(512, 512, 1), 109 | conv_dw(512, 512, 1), # skip4 11 110 | conv_dw(512, 1024, 2), 111 | conv_dw(1024, 1024, 1), 112 | 113 | conv_dw(1024, 1024, 1), # added 114 | #nn.AvgPool2d(7), 115 | ) 116 | 117 | self.upsample0 = YellowBlock() 118 | self.o0 = OrangeBlock(1024+512, 64) 119 | 120 | self.upsample1 = YellowBlock() 121 | self.o1 = OrangeBlock(64+256, 64) 122 | self.upsample2 = YellowBlock() 123 | self.o2 = OrangeBlock(64+128, 64) 124 | self.upsample3 = YellowBlock() 125 | self.o3 = OrangeBlock(64+64, 64) 126 | self.upsample4 = YellowBlock() 127 | self.o4 = OrangeBlock(64, 64) 128 | 129 | self.red = nn.Sequential( 130 | nn.Conv2d(64, 1, 1) 131 | ) 132 | 133 | #self.fc = nn.Linear(1024, 1000) 134 | 135 | def forward(self, x): 136 | skips = [] 137 | #x = self.model(x) 138 | 139 | for i, model in enumerate(self.model): 140 | x = model(x) 141 | if i in {1,3,5,11}: 142 | skips.append(x) 143 | 144 | x = self.upsample0(x) 145 | x = torch.cat((x, skips[-1]), dim=1) 146 | x = self.o0(x) 147 | 148 | x = self.upsample1(x) 149 | x = torch.cat((x, skips[-2]), dim=1) 150 | x = self.o1(x) 151 | 152 | x = self.upsample2(x) 153 | x = torch.cat((x, skips[-3]), dim=1) 154 | x = self.o2(x) 155 | 156 | x = self.upsample3(x) 157 | x = torch.cat((x, skips[-4]), dim=1) 158 | x = self.o3(x) 159 | x = self.upsample4(x) 160 | x = self.o4(x) 161 | 162 | #x = self.fc(x) 163 | return self.red(x) 164 | 165 | def load_pretrained_model(self): 166 | pass 167 | # hell baidu - https://github.com/marvis/pytorch-mobilenet 168 | 169 | class 
HairMattingLoss(nn.modules.loss._Loss): 170 | def __init__(self, ratio_of_Gradient=0.0, add_gradient=False): 171 | super(HairMattingLoss, self).__init__() 172 | self.ratio_of_gradient = ratio_of_Gradient 173 | self.add_gradient = add_gradient 174 | self.bce_loss = nn.BCEWithLogitsLoss() 175 | 176 | def forward(self, pred, true, image): 177 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 178 | loss2 = None 179 | if self.ratio_of_gradient > 0: 180 | sobel_kernel_x = torch.Tensor( 181 | [[1.0, 0.0, -1.0], 182 | [2.0, 0.0, -2.0], 183 | [1.0, 0.0, -1.0]]).to(device) 184 | sobel_kernel_x = sobel_kernel_x.view((1,1,3,3)) 185 | 186 | I_x = F.conv2d(image, sobel_kernel_x) 187 | G_x = F.conv2d(pred, sobel_kernel_x) 188 | 189 | sobel_kernel_y = torch.Tensor( 190 | [[1.0, 2.0, 1.0], 191 | [0.0, 0.0, 0.0], 192 | [-1.0, -2.0, -1.0]]).to(device) 193 | sobel_kernel_y = sobel_kernel_y.view((1,1,3,3)) 194 | 195 | I_y = F.conv2d(image, sobel_kernel_y) 196 | G_y = F.conv2d(pred, sobel_kernel_y) 197 | 198 | G = torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2)) 199 | 200 | rang_grad = 1 - torch.pow(I_x*G_x + I_y*G_y,2) 201 | rang_grad = range_grad if rang_grad > 0 else 0 202 | 203 | loss2 = torch.sum(torch.mul(G, rang_grad))/torch.sum(G) + 1e-6 204 | 205 | if self.add_gradient: 206 | loss = (1-self.ratio_of_gradient)*self.bce_loss(pred, true) + loss2*self.ratio_of_gradient 207 | else: 208 | loss = self.bce_loss(pred, true) 209 | 210 | return loss 211 | -------------------------------------------------------------------------------- /networks/pspnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision.models import squeezenet1_1, resnet101 5 | from torch.nn.init import xavier_normal_ 6 | 7 | """ 8 | Referenced from https://github.com/Lextal/pspnet-pytorch/blob/master/pspnet.py 9 | """ 10 | 11 | 12 | class ResNet101Extractor(nn.Module): 13 | def __init__(self): 14 | super(ResNet101Extractor, self).__init__() 15 | model = resnet101(pretrained=True) 16 | self.features = nn.Sequential(*list(model.children())[:7]) 17 | def forward(self, x): 18 | return self.features(x) 19 | 20 | class SqueezeNetExtractor(nn.Module): 21 | def __init__(self): 22 | super(SqueezeNetExtractor, self).__init__() 23 | model = squeezenet1_1(pretrained=True) 24 | features = model.features 25 | self.feature1 = features[:2] 26 | self.feature2 = features[2:5] 27 | self.feature3 = features[5:8] 28 | self.feature4 = features[8:] 29 | 30 | def forward(self, x): 31 | f1 = self.feature1(x) 32 | f2 = self.feature2(f1) 33 | f3 = self.feature3(f2) 34 | f4 = self.feature4(f3) 35 | return f4 36 | 37 | 38 | class PyramidPoolingModule(nn.Module): 39 | def __init__(self, in_channels, sizes=(1, 2, 3, 6)): 40 | super(PyramidPoolingModule, self).__init__() 41 | pyramid_levels = len(sizes) 42 | out_channels = in_channels // pyramid_levels 43 | 44 | pooling_layers = nn.ModuleList() 45 | for size in sizes: 46 | layers = [nn.AdaptiveAvgPool2d(size), nn.Conv2d(in_channels, out_channels, kernel_size=1)] 47 | pyramid_layer = nn.Sequential(*layers) 48 | pooling_layers.append(pyramid_layer) 49 | 50 | self.pooling_layers = pooling_layers 51 | 52 | def forward(self, x): 53 | h, w = x.size(2), x.size(3) 54 | features = [x] 55 | for pooling_layer in self.pooling_layers: 56 | # pool with different sizes 57 | pooled = pooling_layer(x) 58 | 59 | # upsample to original size 60 | upsampled = F.upsample(pooled, size=(h, w), mode='bilinear') 61 | 
62 | features.append(upsampled) 63 | 64 | return torch.cat(features, dim=1) 65 | 66 | 67 | class UpsampleLayer(nn.Module): 68 | def __init__(self, in_channels, out_channels, upsample_size=None): 69 | super().__init__() 70 | self.upsample_size = upsample_size 71 | 72 | self.conv = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 74 | nn.BatchNorm2d(out_channels), 75 | nn.ReLU() 76 | ) 77 | 78 | def forward(self, x): 79 | size = 2 * x.size(2), 2 * x.size(3) 80 | f = F.upsample(x, size=size, mode='bilinear') 81 | return self.conv(f) 82 | 83 | 84 | class PSPNet(nn.Module): 85 | def __init__(self, num_class=1, sizes=(1, 2, 3, 6), base_network='resnet101'): 86 | super(PSPNet, self).__init__() 87 | base_network = base_network.lower() 88 | if base_network == 'resnet101': 89 | self.base_network = ResNet101Extractor() 90 | feature_dim = 1024 91 | elif base_network == 'squeezenet': 92 | self.base_network = SqueezeNetExtractor() 93 | feature_dim = 512 94 | else: 95 | raise ValueError 96 | self.psp = PyramidPoolingModule(in_channels=feature_dim, sizes=sizes) 97 | self.drop_1 = nn.Dropout2d(p=0.3) 98 | 99 | self.up_1 = UpsampleLayer(2*feature_dim, 256) 100 | self.up_2 = UpsampleLayer(256, 64) 101 | self.up_3 = UpsampleLayer(64, 64) 102 | 103 | self.drop_2 = nn.Dropout2d(p=0.15) 104 | self.final = nn.Sequential( 105 | nn.Conv2d(64, num_class, kernel_size=1) 106 | ) 107 | 108 | self._init_weight() 109 | 110 | def forward(self, x): 111 | h, w = x.size(2), x.size(3) 112 | f = self.base_network(x) 113 | p = self.psp(f) 114 | p = self.drop_1(p) 115 | p = self.up_1(p) 116 | p = self.drop_2(p) 117 | 118 | p = self.up_2(p) 119 | p = self.drop_2(p) 120 | 121 | p = self.up_3(p) 122 | 123 | if (p.size(2) != h) or (p.size(3) != w): 124 | p = F.interpolate(p, size=(h, w), mode='bilinear') 125 | 126 | p = self.drop_2(p) 127 | 128 | return self.final(p) 129 | 130 | def _init_weight(self): 131 | layers = [self.up_1, self.up_2, self.up_3, self.final] 132 | for layer in layers: 133 | if isinstance(layer, nn.Conv2d): 134 | xavier_normal_(layer.weight.data) 135 | 136 | elif isinstance(layer, nn.BatchNorm2d): 137 | layer.weight.data.normal_(1.0, 0.02) 138 | layer.bias.data.fill_(0) 139 | 140 | -------------------------------------------------------------------------------- /notebooks/TrainingExample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Example training notebook file" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# add work directory\n", 17 | "import os\n", 18 | "import sys\n", 19 | "import torch\n", 20 | "\n", 21 | "# you should add root directory\n", 22 | "sys.path.append(os.path.dirname(\"../\"))" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Loading Figaro dataset using get_loader" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# importing dataloader\n", 39 | "\n", 40 | "from data import get_loader\n", 41 | "\n", 42 | "# you have to predefine transforms to load dataset\n", 43 | "# this transforms images and masks while loading\n", 44 | "# example transforms\n", 45 | "\n", 46 | "from utils import joint_transforms as jnt_trnsf\n", 47 | "import torchvision.transforms as std_trnsf\n", 48 | "\n", 49 | 
"\n", 50 | "# transforms that are applied to both images and masks\n", 51 | "# includes geometrical changes like flip\n", 52 | "# implemented in ./utils/joint_transforms.py\n", 53 | "joint_transforms = jnt_trnsf.Compose([\n", 54 | " jnt_trnsf.Resize(256),\n", 55 | " jnt_trnsf.RandomRotate(5),\n", 56 | " jnt_trnsf.CenterCrop(224),\n", 57 | " jnt_trnsf.RandomHorizontallyFlip()\n", 58 | "])\n", 59 | "\n", 60 | "\n", 61 | "# transforms that are applied to only images\n", 62 | "# this includes color jittering, normalizing, blurring, etc\n", 63 | "# use torchvision.transforms, or implement additional transforms in 'utils'\n", 64 | "train_image_transforms = std_trnsf.Compose([\n", 65 | " std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),\n", 66 | " std_trnsf.ToTensor(),\n", 67 | " std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 68 | " ])\n", 69 | "\n", 70 | "\n", 71 | "test_image_transforms = std_trnsf.Compose([\n", 72 | " std_trnsf.ToTensor(),\n", 73 | " std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 74 | " ])\n", 75 | "\n", 76 | "# transforms that are applied to only masks\n", 77 | "mask_transforms = std_trnsf.Compose([\n", 78 | " std_trnsf.ToTensor()\n", 79 | " ])\n", 80 | "\n", 81 | "# predifine other needed arguments\n", 82 | "batch_size = 4\n", 83 | "num_workers = 1\n", 84 | "data_dir = '../data/Figaro1k/'" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "train_loader = get_loader(dataset='figaro',\n", 94 | " data_dir=data_dir,\n", 95 | " train=True,\n", 96 | " joint_transforms=joint_transforms,\n", 97 | " image_transforms=train_image_transforms,\n", 98 | " mask_transforms=mask_transforms,\n", 99 | " batch_size=batch_size,\n", 100 | " shuffle=False,\n", 101 | " num_workers=num_workers)\n", 102 | "\n", 103 | "test_loader = get_loader(dataset='figaro',\n", 104 | " data_dir=data_dir,\n", 105 | " train=False,\n", 106 | " joint_transforms=joint_transforms,\n", 107 | " image_transforms=test_image_transforms,\n", 108 | " mask_transforms=mask_transforms,\n", 109 | " batch_size=1,\n", 110 | " shuffle=False,\n", 111 | " num_workers=num_workers)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 5, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "(0, torch.Size([4, 3, 224, 224]), torch.Size([4, 1, 224, 224]))" 123 | ] 124 | }, 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "# two ways of iterating dataloader\n", 132 | "\n", 133 | "# 1. 
using for loop\n", 134 | "\n", 135 | "for step, (data, target) in enumerate(train_loader):\n", 136 | " break\n", 137 | "step, data.size(), target.size() \n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "(torch.Size([4, 3, 224, 224]), torch.Size([4, 1, 224, 224]))" 149 | ] 150 | }, 151 | "execution_count": 6, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | }, 155 | { 156 | "name": "stderr", 157 | "output_type": "stream", 158 | "text": [ 159 | "Process Process-2:\n", 160 | "Traceback (most recent call last):\n", 161 | " File \"/opt/conda/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", 162 | " self.run()\n", 163 | " File \"/opt/conda/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", 164 | " self._target(*self._args, **self._kwargs)\n", 165 | " File \"/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py\", line 96, in _worker_loop\n", 166 | " r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)\n", 167 | " File \"/opt/conda/lib/python3.6/multiprocessing/queues.py\", line 104, in get\n", 168 | " if not self._poll(timeout):\n", 169 | " File \"/opt/conda/lib/python3.6/multiprocessing/connection.py\", line 257, in poll\n", 170 | " return self._poll(timeout)\n", 171 | " File \"/opt/conda/lib/python3.6/multiprocessing/connection.py\", line 414, in _poll\n", 172 | " r = wait([self], timeout)\n", 173 | " File \"/opt/conda/lib/python3.6/multiprocessing/connection.py\", line 911, in wait\n", 174 | " ready = selector.select(timeout)\n", 175 | " File \"/opt/conda/lib/python3.6/selectors.py\", line 376, in select\n", 176 | " fd_event_list = self._poll.poll(timeout)\n", 177 | "KeyboardInterrupt\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "# 2. 
using iterator\n", 183 | "batch_iterator = iter(train_loader)\n", 184 | "\n", 185 | "for _ in range(10):\n", 186 | " data, target = batch_iterator.next()\n", 187 | "data.size(), target.size()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Importing model" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "from networks import get_network\n", 204 | "\n", 205 | "# you can add your own model in get_network fuction in ./networks/__init__.py \n", 206 | "model = get_network(name='SegNet', num_class = 1)\n", 207 | "\n", 208 | "# or just import directly\n", 209 | "from networks.segnet import SegNet\n", 210 | "model = SegNet(num_class = 1)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "## Defining Optimizer & Scheduler & loss & device" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "# torch.optim\n", 227 | "optimizer = torch.optim.Adam(model.parameters(), \n", 228 | " lr = 0.001, \n", 229 | " betas=(0.5, 0.999), # beta1 acts like 'momentum' in SGD\n", 230 | " )\n", 231 | "\n", 232 | "# torch.\n", 233 | "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)\n", 234 | "\n", 235 | "# torch.nn\n", 236 | "loss = torch.nn.BCEWithLogitsLoss()\n", 237 | "\n", 238 | "# flag to use gpu or not\n", 239 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## Using Pytorch Ignite" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 9, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# ignite moduels\n", 256 | "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n", 257 | "from ignite.metrics import Loss\n", 258 | "\n", 259 | "# custom modules\n", 260 | "from utils.metrics import Accuracy, MeanIU" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 10, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "# trainer and evaluator\n", 270 | "trainer = create_supervised_trainer(model, optimizer, loss, device=device)\n", 271 | "evaluator = create_supervised_evaluator(model,\n", 272 | " metrics={\n", 273 | " 'pix-acc': Accuracy(),\n", 274 | " 'mean-iu': MeanIU(0.5),\n", 275 | " 'loss': Loss(loss)\n", 276 | " },\n", 277 | " device=device)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 11, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "# saving training state if you want\n", 287 | "from utils import update_state\n", 288 | "state = update_state(model.state_dict(), 0, 0, 0, 0)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "name": "stderr", 298 | "output_type": "stream", 299 | "text": [ 300 | "Process Process-3:\n", 301 | "Traceback (most recent call last):\n", 302 | " File \"/opt/conda/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", 303 | " self.run()\n", 304 | " File \"/opt/conda/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", 305 | " self._target(*self._args, **self._kwargs)\n", 306 | " File \"/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py\", line 96, in _worker_loop\n", 
307 | " r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)\n", 308 | " File \"/opt/conda/lib/python3.6/multiprocessing/queues.py\", line 104, in get\n", 309 | " if not self._poll(timeout):\n", 310 | " File \"/opt/conda/lib/python3.6/multiprocessing/connection.py\", line 257, in poll\n", 311 | " return self._poll(timeout)\n", 312 | " File \"/opt/conda/lib/python3.6/multiprocessing/connection.py\", line 414, in _poll\n", 313 | " r = wait([self], timeout)\n", 314 | " File \"/opt/conda/lib/python3.6/multiprocessing/connection.py\", line 911, in wait\n", 315 | " ready = selector.select(timeout)\n", 316 | " File \"/opt/conda/lib/python3.6/selectors.py\", line 376, in select\n", 317 | " fd_event_list = self._poll.poll(timeout)\n", 318 | "KeyboardInterrupt\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "@trainer.on(Events.ITERATION_COMPLETED)\n", 324 | "def log_training_loss(trainer):\n", 325 | " num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1\n", 326 | " if num_iter % 20 == 0:\n", 327 | " print(\"Epoch[{}] Iter[{:03d}] Loss: {:.2f}\".format(\n", 328 | " trainer.state.epoch, num_iter, trainer.state.output))\n", 329 | "\n", 330 | "@trainer.on(Events.EPOCH_COMPLETED)\n", 331 | "def log_training_results(trainer):\n", 332 | " # evaluate training set\n", 333 | " evaluator.run(train_loader)\n", 334 | " metrics = evaluator.state.metrics\n", 335 | " print(\"Training Results - Epoch: {} Pix-acc: {:.3f} MeanIU: {:.3f} Avg-loss: {:.3f}\".format(\n", 336 | " trainer.state.epoch, metrics['pix-acc'], metrics['mean-iu'], metrics['loss']))\n", 337 | "\n", 338 | " # update state\n", 339 | " update_state(model.state_dict(), metrics['loss'], state['val_loss'], state['val_pix_acc'], state['val_miu'])\n", 340 | "\n", 341 | "@trainer.on(Events.EPOCH_COMPLETED)\n", 342 | "def log_validation_results(trainer):\n", 343 | " # evaluate test(validation) set\n", 344 | " evaluator.run(test_loader)\n", 345 | " metrics = evaluator.state.metrics\n", 346 | " print(\"Validation Results - Epoch: {} Pix-acc: {:.2f} MeanIU: {:.3f} Avg-loss: {:.2f}\".format(\n", 347 | " trainer.state.epoch, metrics['pix-acc'], metrics['mean-iu'], metrics['loss']))\n", 348 | "\n", 349 | " # update scheduler\n", 350 | " scheduler.step(metrics['loss'])\n", 351 | "\n", 352 | " # update and save state\n", 353 | " update_state(model.state_dict(), state['train_loss'], metrics['loss'], metrics['pix-acc'], metrics['mean-iu'])\n", 354 | " save_ckpt_file(\n", 355 | " ckpt_path.format(network=networks, epoch=trainer.state.epoch),\n", 356 | " state)\n", 357 | "\n", 358 | "trainer.run(train_loader, max_epochs=100)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "## To do this in one-queue" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "```bash\n", 373 | "# run this in root\n", 374 | "\n", 375 | "python3 main.py \\\n", 376 | " --networks segnet \\\n", 377 | " --scheduler ReduceLROnPlateau \\\n", 378 | " --batch_size 4 \\\n", 379 | " --epochs 100 \\\n", 380 | " --lr 1e-3 \\\n", 381 | " --num_workers 4 \\\n", 382 | " --optimizer adam \\\n", 383 | " --momentum 0.5\n", 384 | "```" 385 | ] 386 | } 387 | ], 388 | "metadata": { 389 | "kernelspec": { 390 | "display_name": "Python 3", 391 | "language": "python", 392 | "name": "python3" 393 | }, 394 | "language_info": { 395 | "codemirror_mode": { 396 | "name": "ipython", 397 | "version": 3 398 | }, 399 | "file_extension": ".py", 400 | "mimetype": "text/x-python", 401 | "name": "python", 
402 | "nbconvert_exporter": "python", 403 | "pygments_lexer": "ipython3", 404 | "version": "3.6.5" 405 | } 406 | }, 407 | "nbformat": 4, 408 | "nbformat_minor": 2 409 | } 410 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | pytorch-ignite==0.1.0 3 | torchsummary 4 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from collections import OrderedDict 6 | import numpy as np 7 | 8 | 9 | def check_mkdir(path): 10 | if not os.path.exists(path): 11 | os.mkdir(path) 12 | 13 | 14 | def update_state(weight, train_loss, val_pix_acc, val_loss, val_iou, val_f1): 15 | state = { 16 | 'weight': weight, 17 | 'train_loss': train_loss, 18 | 'val_loss': val_loss, 19 | 'val_pix_acc': val_pix_acc, 20 | 'val_iou': val_iou, 21 | 'val_f1': val_f1 22 | } 23 | return state 24 | 25 | 26 | def save_ckpt_file(ckpt_path, state): 27 | 28 | check_mkdir(os.path.split(ckpt_path)[0]) 29 | with open(ckpt_path, 'wb') as fout: 30 | torch.save(state, fout) 31 | 32 | 33 | def summarize_model(model, input_size, logger, batch_size=-1, device="cuda"): 34 | """ 35 | hard copied from https://github.com/sksq96/pytorch-summary/blob/master/torchsummary/torchsummary.py#L9 36 | """ 37 | def register_hook(module): 38 | 39 | def hook(module, input, output): 40 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 41 | module_idx = len(summary) 42 | 43 | m_key = "%s-%i" % (class_name, module_idx + 1) 44 | summary[m_key] = OrderedDict() 45 | summary[m_key]["input_shape"] = list(input[0].size()) 46 | summary[m_key]["input_shape"][0] = batch_size 47 | if isinstance(output, (list, tuple)): 48 | summary[m_key]["output_shape"] = [ 49 | [-1] + list(o.size())[1:] for o in output 50 | ] 51 | else: 52 | summary[m_key]["output_shape"] = list(output.size()) 53 | summary[m_key]["output_shape"][0] = batch_size 54 | 55 | params = 0 56 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 57 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 58 | summary[m_key]["trainable"] = module.weight.requires_grad 59 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 60 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 61 | summary[m_key]["nb_params"] = params 62 | 63 | if ( 64 | not isinstance(module, nn.Sequential) 65 | and not isinstance(module, nn.ModuleList) 66 | and not (module == model) 67 | ): 68 | hooks.append(module.register_forward_hook(hook)) 69 | 70 | device = device.lower() 71 | assert device in [ 72 | "cuda", 73 | "cpu", 74 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 75 | 76 | if device == "cuda" and torch.cuda.is_available(): 77 | dtype = torch.cuda.FloatTensor 78 | else: 79 | dtype = torch.FloatTensor 80 | 81 | # multiple inputs to the network 82 | if isinstance(input_size, tuple): 83 | input_size = [input_size] 84 | 85 | # batch_size of 2 for batchnorm 86 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 87 | # print(type(x[0])) 88 | 89 | # create properties 90 | summary = OrderedDict() 91 | hooks = [] 92 | 93 | # register hook 94 | model.apply(register_hook) 95 | 96 | # make a forward pass 97 | # print(x.shape) 98 | model(*x) 99 | 100 | # remove these hooks 101 | for h in hooks: 102 | h.remove() 103 
| 104 | logger.info("----------------------------------------------------------------") 105 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 106 | logger.info(line_new) 107 | logger.info("================================================================") 108 | total_params = 0 109 | total_output = 0 110 | trainable_params = 0 111 | for layer in summary: 112 | # input_shape, output_shape, trainable, nb_params 113 | line_new = "{:>20} {:>25} {:>15}".format( 114 | layer, 115 | str(summary[layer]["output_shape"]), 116 | "{0:,}".format(summary[layer]["nb_params"]), 117 | ) 118 | total_params += summary[layer]["nb_params"] 119 | total_output += np.prod(summary[layer]["output_shape"]) 120 | if "trainable" in summary[layer]: 121 | if summary[layer]["trainable"] == True: 122 | trainable_params += summary[layer]["nb_params"] 123 | logger.info(line_new) 124 | 125 | # assume 4 bytes/number (float on cuda). 126 | total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) 127 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 128 | total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) 129 | total_size = total_params_size + total_output_size + total_input_size 130 | 131 | logger.info("================================================================") 132 | logger.info("Total params: {0:,}".format(total_params)) 133 | logger.info("Trainable params: {0:,}".format(trainable_params)) 134 | logger.info("Non-trainable params: {0:,}".format(total_params - trainable_params)) 135 | logger.info("----------------------------------------------------------------") 136 | logger.info("Input size (MB): %0.2f" % total_input_size) 137 | logger.info("Forward/backward pass size (MB): %0.2f" % total_output_size) 138 | logger.info("Params size (MB): %0.2f" % total_params_size) 139 | logger.info("Estimated Total Size (MB): %0.2f" % total_size) 140 | logger.info("----------------------------------------------------------------") 141 | -------------------------------------------------------------------------------- /utils/joint_transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | 8 | 9 | """ 10 | Most of codes here are from 11 | https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py 12 | """ 13 | 14 | def pad_to_target(img, target_height, target_width, label=0): 15 | # Pad image with zeros to the specified height and width if needed 16 | # This op does nothing if the image already has size bigger than target_height and target_width. 
17 | w, h = img.size 18 | left = top = right = bottom = 0 19 | doit = False 20 | if target_width > w: 21 | delta = target_width - w 22 | left = delta // 2 23 | right = delta - left 24 | doit = True 25 | if target_height > h: 26 | delta = target_height - h 27 | top = delta // 2 28 | bottom = delta - top 29 | doit = True 30 | if doit: 31 | img = ImageOps.expand(img, border=(left, top, right, bottom), fill=label) 32 | assert img.size[0] >= target_width 33 | assert img.size[1] >= target_height 34 | return img 35 | 36 | class Compose(object): 37 | def __init__(self, transforms): 38 | self.transforms = transforms 39 | 40 | def __call__(self, img, mask): 41 | assert img.size == mask.size 42 | for t in self.transforms: 43 | img, mask = t(img, mask) 44 | return img, mask 45 | 46 | class Safe32Padding(object): 47 | def __call__(self, img, mask=None): 48 | width, height = img.size 49 | 50 | if (height % 32) != 0: height += 32 - (height % 32) 51 | if (width % 32) != 0: width += 32 - (width % 32) 52 | 53 | if mask: 54 | return pad_to_target(img, height, width), pad_to_target(mask, height, width) 55 | else: 56 | return pad_to_target(img, height, width) 57 | 58 | class Resize(object): 59 | def __init__(self, size): 60 | self.w = 0 61 | self.h = 0 62 | if isinstance(size, int): 63 | self.w = size 64 | self.h = size 65 | elif isinstance(size, tuple) and len(size) == 2: 66 | if isinstance(size[0], int) and isinstance(size[1], int): 67 | self.w = size[0] 68 | self.h = size[1] 69 | else: 70 | raise ValueError 71 | else: 72 | raise ValueError 73 | 74 | def __call__(self, img, mask): 75 | return (img.resize((self.w, self.h), Image.NEAREST), 76 | mask.resize((self.w, self.h), Image.BILINEAR)) 77 | 78 | 79 | class RandomCrop(object): 80 | def __init__(self, size, padding=0): 81 | if isinstance(size, numbers.Number): 82 | self.size = (int(size), int(size)) 83 | else: 84 | self.size = size 85 | self.padding = padding 86 | 87 | def __call__(self, img, mask): 88 | if self.padding > 0: 89 | img = ImageOps.expand(img, border=self.padding, fill=0) 90 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 91 | 92 | assert img.size == mask.size 93 | w, h = img.size 94 | th, tw = self.size 95 | if w == tw and h == th: 96 | return img, mask 97 | if w < tw or h < th: 98 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 99 | 100 | x1 = random.randint(0, w - tw) 101 | y1 = random.randint(0, h - th) 102 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 103 | 104 | 105 | class CenterCrop(object): 106 | def __init__(self, size): 107 | if isinstance(size, numbers.Number): 108 | self.size = (int(size), int(size)) 109 | else: 110 | self.size = size 111 | 112 | def __call__(self, img, mask): 113 | assert img.size == mask.size 114 | w, h = img.size 115 | th, tw = self.size 116 | x1 = int(round((w - tw) / 2.)) 117 | y1 = int(round((h - th) / 2.)) 118 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 119 | 120 | 121 | class RandomHorizontallyFlip(object): 122 | def __call__(self, img, mask): 123 | if random.random() < 0.5: 124 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 125 | return img, mask 126 | 127 | 128 | class FreeScale(object): 129 | def __init__(self, size): 130 | self.size = tuple(reversed(size)) # size: (h, w) 131 | 132 | def __call__(self, img, mask): 133 | assert img.size == mask.size 134 | return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST) 135 
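# Usage sketch (illustrative, not part of this module): these transforms are chained with
# Compose and called with an (image, mask) pair of PIL images so that both receive the same
# geometric change, e.g.:
#
#     jt = Compose([Resize(256), RandomRotate(5), CenterCrop(224), RandomHorizontallyFlip()])
#     img, mask = jt(img, mask)   # img and mask are PIL.Image objects of identical size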
| 136 | 137 | class RandomSizedCrop(object): 138 | def __init__(self, size): 139 | self.size = size 140 | 141 | def __call__(self, img, mask): 142 | assert img.size == mask.size 143 | for attempt in range(10): 144 | area = img.size[0] * img.size[1] 145 | target_area = random.uniform(0.45, 1.0) * area 146 | aspect_ratio = random.uniform(0.5, 2) 147 | 148 | w = int(round(math.sqrt(target_area * aspect_ratio))) 149 | h = int(round(math.sqrt(target_area / aspect_ratio))) 150 | 151 | if random.random() < 0.5: 152 | w, h = h, w 153 | 154 | if w <= img.size[0] and h <= img.size[1]: 155 | x1 = random.randint(0, img.size[0] - w) 156 | y1 = random.randint(0, img.size[1] - h) 157 | 158 | img = img.crop((x1, y1, x1 + w, y1 + h)) 159 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 160 | assert (img.size == (w, h)) 161 | 162 | return img.resize((self.size, self.size), Image.BILINEAR), mask.resize((self.size, self.size), 163 | Image.NEAREST) 164 | 165 | # Fallback 166 | resize = Resize(self.size) 167 | crop = CenterCrop(self.size) 168 | return crop(*resize(img, mask)) 169 | 170 | 171 | class RandomRotate(object): 172 | def __init__(self, degree): 173 | self.degree = degree 174 | 175 | def __call__(self, img, mask): 176 | rotate_degree = random.random() * 2 * self.degree - self.degree 177 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 178 | 179 | 180 | class RandomSized(object): 181 | def __init__(self, size): 182 | self.size = size 183 | self.scale = Scale(self.size) 184 | self.crop = RandomCrop(self.size) 185 | 186 | def __call__(self, img, mask): 187 | assert img.size == mask.size 188 | 189 | w = int(random.uniform(0.5, 2) * img.size[0]) 190 | h = int(random.uniform(0.5, 2) * img.size[1]) 191 | 192 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 193 | 194 | return self.crop(*self.scale(img, mask)) 195 | 196 | 197 | class SlidingCropOld(object): 198 | def __init__(self, crop_size, stride_rate, ignore_label): 199 | self.crop_size = crop_size 200 | self.stride_rate = stride_rate 201 | self.ignore_label = ignore_label 202 | 203 | def _pad(self, img, mask): 204 | h, w = img.shape[: 2] 205 | pad_h = max(self.crop_size - h, 0) 206 | pad_w = max(self.crop_size - w, 0) 207 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 208 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 209 | return img, mask 210 | 211 | def __call__(self, img, mask): 212 | assert img.size == mask.size 213 | 214 | w, h = img.size 215 | long_size = max(h, w) 216 | 217 | img = np.array(img) 218 | mask = np.array(mask) 219 | 220 | if long_size > self.crop_size: 221 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 222 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 223 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 224 | img_sublist, mask_sublist = [], [] 225 | for yy in xrange(h_step_num): 226 | for xx in xrange(w_step_num): 227 | sy, sx = yy * stride, xx * stride 228 | ey, ex = sy + self.crop_size, sx + self.crop_size 229 | img_sub = img[sy: ey, sx: ex, :] 230 | mask_sub = mask[sy: ey, sx: ex] 231 | img_sub, mask_sub = self._pad(img_sub, mask_sub) 232 | img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 233 | mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 234 | return img_sublist, mask_sublist 235 | else: 236 | img, mask = self._pad(img, mask) 237 | img = 
Image.fromarray(img.astype(np.uint8)).convert('RGB') 238 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 239 | return img, mask 240 | 241 | 242 | class SlidingCrop(object): 243 | def __init__(self, crop_size, stride_rate, ignore_label): 244 | self.crop_size = crop_size 245 | self.stride_rate = stride_rate 246 | self.ignore_label = ignore_label 247 | 248 | def _pad(self, img, mask): 249 | h, w = img.shape[: 2] 250 | pad_h = max(self.crop_size - h, 0) 251 | pad_w = max(self.crop_size - w, 0) 252 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 253 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 254 | return img, mask, h, w 255 | 256 | def __call__(self, img, mask): 257 | assert img.size == mask.size 258 | 259 | w, h = img.size 260 | long_size = max(h, w) 261 | 262 | img = np.array(img) 263 | mask = np.array(mask) 264 | 265 | if long_size > self.crop_size: 266 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 267 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 268 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 269 | img_slices, mask_slices, slices_info = [], [], [] 270 | for yy in xrange(h_step_num): 271 | for xx in xrange(w_step_num): 272 | sy, sx = yy * stride, xx * stride 273 | ey, ex = sy + self.crop_size, sx + self.crop_size 274 | img_sub = img[sy: ey, sx: ex, :] 275 | mask_sub = mask[sy: ey, sx: ex] 276 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub) 277 | img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 278 | mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 279 | slices_info.append([sy, ey, sx, ex, sub_h, sub_w]) 280 | return img_slices, mask_slices, slices_info 281 | else: 282 | img, mask, sub_h, sub_w = self._pad(img, mask) 283 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 284 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 285 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]] 286 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ignite.metrics.metric import Metric 3 | 4 | class MultiThresholdMeasures(Metric): 5 | """ 6 | Calculates Accuracy, IoU, F1-score (Dice Coefficient) within thresholds [0.0, 0.1, ..., 1.0] 7 | """ 8 | def __init__(self): 9 | super(MultiThresholdMeasures, self).__init__() 10 | self.reset() 11 | self._thrs = torch.FloatTensor([i/10 for i in range(11)]).to(self._device) 12 | 13 | def reset(self): 14 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 15 | self._tp = torch.zeros(11).to(device) 16 | self._fp = torch.zeros(11).to(device) 17 | self._fn = torch.zeros(11).to(device) 18 | self._tn = torch.zeros(11).to(device) 19 | self._device = device 20 | 21 | def update(self, output): 22 | logit, y = output 23 | n = y.size(0) 24 | 25 | y_pred = torch.sigmoid(logit) 26 | y_pred = y_pred.view(n, -1, 1).repeat(1, 1, 11) > self._thrs 27 | y = y.byte().view(n, -1, 1).repeat(1, 1, 11) 28 | 29 | tp = y_pred * y == 1 30 | tn = y_pred + y == 0 31 | fp = y_pred - y == 1 32 | fn = y - y_pred == 1 33 | 34 | self._tp += torch.sum(tp, dim=[0,1]).float() 35 | self._tn += torch.sum(tn, dim=[0,1]).float() 36 | self._fp += torch.sum(fp, dim=[0,1]).float() 37 | self._fn += torch.sum(fn, dim=[0,1]).float() 38 | 39 | def compute(self): 40 | return 41 | 42 | def compute_iou(self): 43 
| intersect = self._tp 44 | union = self._tp + self._fp + self._fn 45 | iou = intersect / union 46 | return [round(i.item(), 3) for i in iou] 47 | 48 | def compute_f1(self): 49 | pr = self._tp / (self._tp + self._fp) 50 | re = self._tp / (self._tp + self._fn) 51 | f1 = 2 * pr * re / (pr + re) 52 | return [round(f.item(), 3) for f in f1] 53 | 54 | def compute_accuracy(self): 55 | acc = (self._tp + self._tn) / (self._tp + self._tn + self._fp + self._fn) 56 | return [round(a.item(), 3) for a in acc] 57 | 58 | 59 | class Accuracy(Metric): 60 | def __init__(self, multi_thrs_measure): 61 | super(Accuracy, self).__init__() 62 | self.multi_thrs_measure = multi_thrs_measure 63 | 64 | def update(self, output): 65 | return 66 | 67 | def compute(self): 68 | return self.multi_thrs_measure.compute_accuracy() 69 | 70 | 71 | class IoU(Metric): 72 | def __init__(self, multi_thrs_measure): 73 | super(IoU, self).__init__() 74 | self.multi_thrs_measure = multi_thrs_measure 75 | 76 | def update(self, output): 77 | return 78 | 79 | def compute(self): 80 | return self.multi_thrs_measure.compute_iou() 81 | 82 | 83 | class F1score(Metric): 84 | def __init__(self, multi_thrs_measure): 85 | super(F1score, self).__init__() 86 | self.multi_thrs_measure = multi_thrs_measure 87 | 88 | def update(self, output): 89 | return 90 | 91 | def compute(self): 92 | return self.multi_thrs_measure.compute_f1() 93 | 94 | -------------------------------------------------------------------------------- /utils/trainer_verbose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 4 | 5 | import numpy as np 6 | 7 | from data import get_loader 8 | from utils import update_state, save_ckpt_file 9 | from utils import joint_transforms as jnt_trnsf 10 | from utils import summarize_model 11 | from networks import get_network 12 | 13 | import torch 14 | import torchvision.transforms as std_trnsf 15 | 16 | from tqdm import tqdm 17 | 18 | def get_optimizer(string, model, lr, momentum): 19 | string = string.lower() 20 | if string == 'adam': 21 | return torch.optim.Adam(model.parameters(), lr=lr, betas=(momentum, 0.999)) 22 | elif string == 'sgd': 23 | return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum) 24 | raise ValueError 25 | 26 | 27 | def train_with_ignite(networks, dataset, data_dir, batch_size, img_size, 28 | epochs, lr, momentum, num_workers, optimizer, logger): 29 | 30 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 31 | from ignite.metrics import Loss 32 | from utils.metrics import MultiThresholdMeasures, Accuracy, IoU, F1score 33 | 34 | # device 35 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | 37 | # build model 38 | model = get_network(networks) 39 | 40 | # log model summary 41 | input_size = (3, img_size, img_size) 42 | summarize_model(model.to(device), input_size, logger, batch_size, device) 43 | 44 | # build loss 45 | loss = torch.nn.BCEWithLogitsLoss() 46 | 47 | # build optimizer and scheduler 48 | model_optimizer = get_optimizer(optimizer, model, lr, momentum) 49 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model_optimizer) 50 | 51 | # transforms on both image and mask 52 | train_joint_transforms = jnt_trnsf.Compose([ 53 | jnt_trnsf.RandomCrop(img_size), 54 | jnt_trnsf.RandomRotate(5), 55 | jnt_trnsf.RandomHorizontallyFlip() 56 | ]) 57 | 58 | # transforms only on images 59 | train_image_transforms = std_trnsf.Compose([ 
60 | std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05), 61 | std_trnsf.ToTensor(), 62 | std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 63 | ]) 64 | 65 | test_joint_transforms = jnt_trnsf.Compose([ 66 | jnt_trnsf.Safe32Padding() 67 | ]) 68 | 69 | test_image_transforms = std_trnsf.Compose([ 70 | std_trnsf.ToTensor(), 71 | std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 72 | ]) 73 | 74 | # transforms only on mask 75 | mask_transforms = std_trnsf.Compose([ 76 | std_trnsf.ToTensor() 77 | ]) 78 | 79 | # build train / test loader 80 | train_loader = get_loader(dataset=dataset, 81 | data_dir=data_dir, 82 | train=True, 83 | joint_transforms=train_joint_transforms, 84 | image_transforms=train_image_transforms, 85 | mask_transforms=mask_transforms, 86 | batch_size=batch_size, 87 | shuffle=False, 88 | num_workers=num_workers) 89 | 90 | test_loader = get_loader(dataset=dataset, 91 | data_dir=data_dir, 92 | train=False, 93 | joint_transforms=test_joint_transforms, 94 | image_transforms=test_image_transforms, 95 | mask_transforms=mask_transforms, 96 | batch_size=1, 97 | shuffle=False, 98 | num_workers=num_workers) 99 | 100 | # build trainer / evaluator with ignite 101 | trainer = create_supervised_trainer(model, model_optimizer, loss, device=device) 102 | measure = MultiThresholdMeasures() 103 | evaluator = create_supervised_evaluator(model, 104 | metrics={ 105 | '': measure, 106 | 'pix-acc': Accuracy(measure), 107 | 'iou': IoU(measure), 108 | 'loss': Loss(loss), 109 | 'f1': F1score(measure), 110 | }, 111 | device=device) 112 | 113 | # initialize state variable for checkpoint 114 | state = update_state(model.state_dict(), 0, 0, 0, 0, 0) 115 | 116 | # make ckpt path 117 | ckpt_root = './ckpt/' 118 | filename = '{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth' 119 | ckpt_path = os.path.join(ckpt_root, filename) 120 | 121 | # execution after every training iteration 122 | @trainer.on(Events.ITERATION_COMPLETED) 123 | def log_training_loss(trainer): 124 | num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1 125 | if num_iter % 20 == 0: 126 | logger.info("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format( 127 | trainer.state.epoch, num_iter, trainer.state.output)) 128 | 129 | # execution after every training epoch 130 | @trainer.on(Events.EPOCH_COMPLETED) 131 | def log_training_results(trainer): 132 | # evaluate on training set 133 | evaluator.run(train_loader) 134 | metrics = evaluator.state.metrics 135 | logger.info("Training Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n".format( 136 | trainer.state.epoch, metrics['loss'], str(metrics['pix-acc']), str(metrics['iou']), str(metrics['f1']))) 137 | 138 | # update state 139 | update_state(weight=model.state_dict(), 140 | train_loss=metrics['loss'], 141 | val_loss=state['val_loss'], 142 | val_pix_acc=state['val_pix_acc'], 143 | val_iou=state['val_iou'], 144 | val_f1=state['val_f1']) 145 | 146 | # execution after every epoch 147 | @trainer.on(Events.EPOCH_COMPLETED) 148 | def log_validation_results(trainer): 149 | # evaluate test(validation) set 150 | evaluator.run(test_loader) 151 | metrics = evaluator.state.metrics 152 | logger.info("Validation Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n".format( 153 | trainer.state.epoch, metrics['loss'], str(metrics['pix-acc']), str(metrics['iou']), str(metrics['f1']))) 154 | 155 | # update scheduler 156 | lr_scheduler.step(metrics['loss']) 157 | 158 | # update and save state 159 | update_state(weight=model.state_dict(), 160 | 
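# Fold this epoch's validation metrics into the shared checkpoint state
# (the training loss recorded in log_training_results is carried over
# unchanged) before the .pth file is written below.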
train_loss=state['train_loss'], 161 | val_loss=metrics['loss'], 162 | val_pix_acc=metrics['pix-acc'], 163 | val_iou=metrics['iou'], 164 | val_f1=metrics['f1']) 165 | 166 | path = ckpt_path.format(network=networks, 167 | optimizer=optimizer, 168 | lr=lr, 169 | epoch=trainer.state.epoch) 170 | save_ckpt_file(path, state) 171 | 172 | trainer.run(train_loader, max_epochs=epochs) 173 | 174 | def train_without_ignite(model, loss, batch_size, img_size, 175 | epochs, lr, num_workers, optimizer, logger, gray_image=False, scheduler=None, viz=True): 176 | import visdom 177 | from utils.metrics import Accuracy, IoU 178 | 179 | DEFAULT_PORT = 8097 180 | DEFAULT_HOSTNAME = "http://localhost" 181 | 182 | if viz: 183 | vis = visdom.Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME) 184 | 185 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 186 | 187 | data_loader = {} 188 | 189 | joint_transforms = jnt_trnsf.Compose([ 190 | jnt_trnsf.RandomCrop(img_size), 191 | jnt_trnsf.RandomRotate(5), 192 | jnt_trnsf.RandomHorizontallyFlip() 193 | ]) 194 | 195 | train_image_transforms = std_trnsf.Compose([ 196 | std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05), 197 | std_trnsf.ToTensor(), 198 | std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 199 | ]) 200 | 201 | test_joint_transforms = jnt_trnsf.Compose([ 202 | jnt_trnsf.Safe32Padding() 203 | ]) 204 | 205 | test_image_transforms = std_trnsf.Compose([ 206 | std_trnsf.ToTensor(), 207 | std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 208 | ]) 209 | 210 | mask_transforms = std_trnsf.Compose([ 211 | std_trnsf.ToTensor() 212 | ]) 213 | 214 | data_loader['train'] = get_loader(dataset='figaro', 215 | train=True, 216 | joint_transforms=joint_transforms, 217 | image_transforms=train_image_transforms, 218 | mask_transforms=mask_transforms, 219 | batch_size=batch_size, 220 | shuffle=True, 221 | num_workers=num_workers, 222 | gray_image=gray_image) 223 | 224 | data_loader['test'] = get_loader(dataset='figaro', 225 | train=False, 226 | joint_transforms=test_joint_transforms, 227 | image_transforms=test_image_transforms, 228 | mask_transforms=mask_transforms, 229 | batch_size=1, 230 | shuffle=True, 231 | num_workers=num_workers, 232 | gray_image=gray_image) 233 | 234 | for epoch in range(epochs): 235 | for phase in ['train', 'test']: 236 | if phase == 'train': 237 | model.train(True) 238 | else: 239 | prev_grad_state = torch.is_grad_enabled() 240 | torch.set_grad_enabled(False) 241 | model.train(False) 242 | 243 | running_loss = 0.0 244 | 245 | for i, data in enumerate(tqdm(data_loader[phase], file=sys.stdout)): 246 | if i == len(data_loader[phase]) - 1: break 247 | data_ = [t.to(device) if isinstance(t, torch.Tensor) else t for t in data] 248 | 249 | if gray_image: 250 | img, mask, gray = data_ 251 | else: 252 | img, mask = data_ 253 | 254 | model.zero_grad() 255 | 256 | pred_mask = model(img) 257 | 258 | if gray_image: 259 | l = loss(pred_mask, mask, gray) 260 | else: 261 | l = loss(pred_mask, mask) 262 | 263 | if phase == 'train': 264 | l.backward() 265 | optimizer.step() 266 | 267 | running_loss += l.item() 268 | 269 | epoch_loss = running_loss / len(data_loader[phase]) 270 | 271 | if phase == 'train': 272 | logger.info(f"Training Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}") 273 | if viz: 274 | vis.images([ 275 | np.clip(pred_mask.detach().cpu().numpy()[0],0,1), 276 | mask.detach().cpu().numpy()[0] 277 | ], opts=dict(title=f'pred img for {epoch}-th iter')) 278 | 279 | if phase == 'test': 280 | if viz: 281 | vis.images([ 282 | 
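# Show the most recent test prediction of this epoch next to its
# ground-truth mask in visdom; the raw network output is clipped to
# [0, 1] so it renders as a valid image.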
np.clip(pred_mask.detach().cpu().numpy()[0],0,1), 283 | mask.detach().cpu().numpy()[0] 284 | ], opts=dict(title=f'pred img for {epoch}-th iter')) 285 | logger.info(f"Test Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}") 286 | 287 | if scheduler: scheduler.step(epoch_loss) 288 | 289 | torch.set_grad_enabled(prev_grad_state) 290 | 291 | --------------------------------------------------------------------------------
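For readers wiring this up by hand: `train_with_ignite` builds its model, loss, optimizer and scheduler internally from string arguments, whereas `train_without_ignite` expects those objects from the caller, hardcodes the Figaro dataset through `get_loader`'s defaults, and imports `visdom` unconditionally (so the package must be installed even with `viz=False`). Its `lr` argument is accepted but never referenced in the body above; the effective learning rate comes from the optimizer you pass in. The sketch below is one possible driver, not part of the repository: the `'mobilenet'` network key is taken from the README's sample run, and the hyper-parameter values are illustrative assumptions.

```python
# Hypothetical driver for train_without_ignite (illustration only).
# Assumes: repository root as working directory, Figaro1k downloaded via
# data/figaro.sh, and 'mobilenet' registered in networks.get_network().
import logging

import torch

from networks import get_network
from utils.trainer_verbose import train_without_ignite

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('hair-seg')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Only the data is moved to the device inside the training loop,
# so the model has to be placed there by the caller.
model = get_network('mobilenet').to(device)
loss = torch.nn.BCEWithLogitsLoss()   # same criterion train_with_ignite builds
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

train_without_ignite(model,
                     loss,
                     batch_size=4,
                     img_size=256,
                     epochs=5,
                     lr=1e-3,           # accepted but unused inside the function shown above
                     num_workers=2,
                     optimizer=optimizer,
                     logger=logger,
                     gray_image=False,
                     scheduler=scheduler,
                     viz=False)         # True requires a visdom server on http://localhost:8097
```

An analogous call to `train_with_ignite(...)` would instead take the network and optimizer names as strings together with `dataset` and `data_dir`, mirroring the README's sample execution.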