├── .gitignore
├── README.md
├── assets
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   └── 7.png
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── test_fcn16s.yml
│   ├── test_fcn32s.yml
│   ├── test_fcn8s.yml
│   ├── test_fcn8s_atonce.yml
│   ├── train_fcn16s.yml
│   ├── train_fcn32s.yml
│   ├── train_fcn8s.yml
│   └── train_fcn8s_atonce.yml
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── voc.py
│   └── transforms
│       ├── __init__.py
│       ├── build.py
│       └── transforms.py
├── engine
│   ├── inference.py
│   └── trainer.py
├── get_data.sh
├── layers
│   ├── bilinear_upsample.py
│   ├── conv_layer.py
│   └── cross_entropy2d.py
├── modeling
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── vgg.py
│   ├── fcn16s.py
│   ├── fcn32s.py
│   └── fcn8s.py
├── solver
│   ├── __init__.py
│   └── build.py
├── tests
│   ├── __init__.py
│   ├── test_dataset.py
│   └── test_model.py
├── tools
│   ├── __init__.py
│   ├── test_fcn.py
│   └── train_fcn.py
└── utils
    ├── __init__.py
    ├── logger.py
    └── metric.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fcn.pytorch
2 | 
3 | PyTorch implementation of [Fully Convolutional Networks](https://github.com/shelhamer/fcn.berkeleyvision.org); the main code is adapted from [pytorch-fcn](https://github.com/wkentaro/pytorch-fcn).
4 | 
5 | ### Requirements
6 | - pytorch
7 | - torchvision
8 | - [ignite](https://github.com/pytorch/ignite)
9 | - [yacs](https://github.com/rbgirshick/yacs)
10 | - [tensorboardX](https://github.com/lanpa/tensorboardX)
11 | - tensorflow (for tensorboard)
12 | 
13 | ### Get Started
14 | The project layout follows the guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template); see that guide for the purpose of each folder.
15 | 
16 | #### Prepare Dataset
17 | Open a terminal and run the following command to download the VOC2012 dataset:
18 | 
19 | ```bash
20 | bash get_data.sh
21 | ```
22 | 
23 | or copy this URL and download it yourself:
24 | 
25 | ```bash
26 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
27 | ```
28 | 
29 | ### Training
30 | The configuration files we provide are in the `configs` folder. You just need to modify the dataset root, the VGG model weight path and the output directory. There are a few possibilities:
31 | 
32 | #### 1. Modify the configuration file and run
33 | Modify `train_fcn32s.yml` first, then run the following command:
34 | 
35 | ```bash
36 | python3 tools/train_fcn.py --config_file='configs/train_fcn32s.yml'
37 | ```
38 | 
39 | #### 2. Modify the cfg parameters
40 | You can also override configuration parameters such as the learning rate or the number of epochs on the command line:
41 | 
42 | ```bash
43 | python3 tools/train_fcn.py --config_file='configs/train_fcn32s.yml' SOLVER.BASE_LR 0.0025 SOLVER.MAX_EPOCHS 8
44 | ```
45 | 
46 | ### Results
47 | These models are trained on the VOC2012 `train.txt` split and evaluated on `val.txt`, using the torchvision pretrained VGG16 rather than the Caffe pretrained weights, so the results may differ from the original paper.
48 | 
49 | | Model | Epoch | Mean IU |
50 | |-|-|-|
51 | | FCN32s | 13 | 55.1 |
52 | | FCN16s | 8 | 54.8 |
53 | | FCN8s | 7 | 55.7 |
54 | | FCN8sAtOnce | 11 | 53.6 |
55 | 
56 | <div align='center'>
57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |
65 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/1.png -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/2.png -------------------------------------------------------------------------------- /assets/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/3.png -------------------------------------------------------------------------------- /assets/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/4.png -------------------------------------------------------------------------------- /assets/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/5.png -------------------------------------------------------------------------------- /assets/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/6.png -------------------------------------------------------------------------------- /assets/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/7.png -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 8 | # or _TEST for a test-specific parameter. 
9 | # For example, the number of images during training will be
10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
11 | # IMAGES_PER_BATCH_TEST
12 | 
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 | 
17 | _C = CN()
18 | 
19 | _C.MODEL = CN()
20 | _C.MODEL.DEVICE = "cuda"
21 | _C.MODEL.NUM_CLASSES = 21
22 | 
23 | _C.MODEL.META_ARCHITECTURE = "fcn32s"
24 | 
25 | _C.MODEL.BACKBONE = CN()
26 | _C.MODEL.BACKBONE.NAME = "vgg16"
27 | _C.MODEL.BACKBONE.PRETRAINED = False
28 | _C.MODEL.BACKBONE.WEIGHT = ""
29 | 
30 | _C.MODEL.REFINEMENT = CN()
31 | _C.MODEL.REFINEMENT.NAME = ''
32 | _C.MODEL.REFINEMENT.WEIGHT = ''
33 | 
34 | # -----------------------------------------------------------------------------
35 | # INPUT
36 | # -----------------------------------------------------------------------------
37 | _C.INPUT = CN()
38 | # Random probability for image horizontal flip
39 | _C.INPUT.PROB = 0.5
40 | # Values to be used for image normalization
41 | # _C.INPUT.PIXEL_MEAN = [104.00698793, 116.66876762, 122.67891434]
42 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
43 | # Values to be used for image normalization
44 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
45 | 
46 | # -----------------------------------------------------------------------------
47 | # Dataset
48 | # -----------------------------------------------------------------------------
49 | _C.DATASETS = CN()
50 | # Dataset root path
51 | _C.DATASETS.ROOT = ''
52 | # -----------------------------------------------------------------------------
53 | # DataLoader
54 | # -----------------------------------------------------------------------------
55 | _C.DATALOADER = CN()
56 | # Number of data loading threads
57 | _C.DATALOADER.NUM_WORKERS = 8
58 | 
59 | # ---------------------------------------------------------------------------- #
60 | # Solver
61 | # ---------------------------------------------------------------------------- #
62 | _C.SOLVER = CN()
63 | _C.SOLVER.OPTIMIZER_NAME = "SGD"
64 | 
65 | _C.SOLVER.MAX_EPOCHS = 11
66 | 
67 | _C.SOLVER.BASE_LR = 1.0e-4
68 | _C.SOLVER.BIAS_LR_FACTOR = 2
69 | 
70 | _C.SOLVER.MOMENTUM = 0.99
71 | 
72 | _C.SOLVER.WEIGHT_DECAY = 0.0005
73 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0
74 | 
75 | _C.SOLVER.CHECKPOINT_PERIOD = 10
76 | _C.SOLVER.LOG_PERIOD = 400
77 | 
78 | # Number of images per batch
79 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
80 | # see 2 images per batch
81 | _C.SOLVER.IMS_PER_BATCH = 1
82 | 
83 | # Test-time options; TEST.IMS_PER_BATCH follows the same global-batch
84 | # convention as SOLVER.IMS_PER_BATCH above
85 | _C.TEST = CN()
86 | _C.TEST.IMS_PER_BATCH = 1
87 | _C.TEST.WEIGHT = ""
88 | 
89 | # ---------------------------------------------------------------------------- #
90 | # Misc options
91 | # ---------------------------------------------------------------------------- #
92 | _C.OUTPUT_DIR = ""
93 | 
--------------------------------------------------------------------------------
/configs/test_fcn16s.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   META_ARCHITECTURE: "fcn16s"
3 | 
4 |   BACKBONE:
5 |     PRETRAINED: False
6 | 
7 |   REFINEMENT:
8 |     NAME: ''
9 | 
10 | DATASETS:
11 |   ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
12 | 
13 | 
14 | TEST:
15 |   WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn16s/fcn_model_8.pth'
16 | 
17 | 
OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn16s" 18 | 19 | -------------------------------------------------------------------------------- /configs/test_fcn32s.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "fcn32s" 3 | 4 | BACKBONE: 5 | PRETRAINED: False 6 | 7 | REFINEMENT: 8 | NAME: '' 9 | 10 | DATASETS: 11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 12 | 13 | 14 | TEST: 15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn32s/fcn_model_13.pth' 16 | 17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn32s" 18 | 19 | -------------------------------------------------------------------------------- /configs/test_fcn8s.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "fcn8s" 3 | 4 | BACKBONE: 5 | PRETRAINED: False 6 | 7 | REFINEMENT: 8 | NAME: '' 9 | 10 | DATASETS: 11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 12 | 13 | 14 | TEST: 15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s/fcn_model_7.pth' 16 | 17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn8s" 18 | 19 | -------------------------------------------------------------------------------- /configs/test_fcn8s_atonce.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "fcn8s" 3 | 4 | BACKBONE: 5 | PRETRAINED: False 6 | 7 | REFINEMENT: 8 | NAME: '' 9 | 10 | DATASETS: 11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 12 | 13 | 14 | TEST: 15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s_atonce/fcn_model_13.pth' 16 | 17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn8s_atonce" 18 | 19 | -------------------------------------------------------------------------------- /configs/train_fcn16s.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "fcn16s" 3 | 4 | BACKBONE: 5 | PRETRAINED: False 6 | 7 | REFINEMENT: 8 | NAME: 'fcn32s' 9 | WEIGHT: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn32s/fcn_model_13.pth" 10 | 11 | DATASETS: 12 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 13 | 14 | SOLVER: 15 | MAX_EPOCHS: 8 16 | CHECKPOINT_PERIOD: 8 17 | 18 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn16s" 19 | -------------------------------------------------------------------------------- /configs/train_fcn32s.yml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 3 | META_ARCHITECTURE: "fcn32s" 4 | 5 | BACKBONE: 6 | PRETRAINED: True 7 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/model_zoo/vgg16-397923af.pth' 8 | 9 | REFINEMENT: 10 | NAME: '' 11 | 12 | DATASETS: 13 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 14 | 15 | SOLVER: 16 | MAX_EPOCHS: 13 17 | CHECKPOINT_PERIOD: 13 18 | 19 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn32s" 20 | -------------------------------------------------------------------------------- /configs/train_fcn8s.yml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 3 | META_ARCHITECTURE: "fcn8s" 4 | 5 | BACKBONE: 6 | PRETRAINED: False 7 | 8 | REFINEMENT: 9 | NAME: 'fcn16s' 10 | WEIGHT: 
"/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn16s/fcn_model_8.pth" 11 | 12 | DATASETS: 13 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 14 | 15 | SOLVER: 16 | MAX_EPOCHS: 7 17 | CHECKPOINT_PERIOD: 7 18 | 19 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s" 20 | -------------------------------------------------------------------------------- /configs/train_fcn8s_atonce.yml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 3 | META_ARCHITECTURE: "fcn8s" 4 | 5 | BACKBONE: 6 | PRETRAINED: True 7 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/model_zoo/vgg16-397923af.pth' 8 | 9 | REFINEMENT: 10 | NAME: '' 11 | 12 | DATASETS: 13 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012' 14 | 15 | SOLVER: 16 | MAX_EPOCHS: 13 17 | CHECKPOINT_PERIOD: 13 18 | 19 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s_atonce" 20 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_data_loader 8 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch.utils import data 8 | 9 | from .datasets.voc import VocSegDataset 10 | from .transforms import build_transforms 11 | 12 | 13 | def build_dataset(cfg, transforms, is_train=True): 14 | datasets = VocSegDataset(cfg, is_train, transforms) 15 | return datasets 16 | 17 | 18 | def make_data_loader(cfg, is_train=True): 19 | if is_train: 20 | batch_size = cfg.SOLVER.IMS_PER_BATCH 21 | shuffle = True 22 | else: 23 | batch_size = cfg.TEST.IMS_PER_BATCH 24 | shuffle = False 25 | 26 | transforms = build_transforms(cfg, is_train) 27 | datasets = build_dataset(cfg, transforms, is_train) 28 | 29 | num_workers = cfg.DATALOADER.NUM_WORKERS 30 | data_loader = data.DataLoader( 31 | datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True 32 | ) 33 | 34 | return data_loader 35 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /data/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import os 8 | 9 | import numpy as np 10 | from PIL import Image 11 | from torch.utils import data 12 | 13 | 14 | def read_images(root, train): 15 | txt_fname = os.path.join(root, 'ImageSets/Segmentation/') + ('train.txt' if train else 'val.txt') 16 | with open(txt_fname, 'r') as f: 17 | images = f.read().split() 18 | data = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in images] 19 | label = [os.path.join(root, 'SegmentationClass', i + '.png') for i in images] 20 | return data, label 21 | 22 | 23 | class VocSegDataset(data.Dataset): 24 | 25 | def __init__(self, cfg, train, 
transforms=None): 26 | self.cfg = cfg 27 | self.train = train 28 | self.transforms = transforms 29 | self.data_list, self.label_list = read_images(self.cfg.DATASETS.ROOT, train) 30 | 31 | def __getitem__(self, item): 32 | img = self.data_list[item] 33 | label = self.label_list[item] 34 | img = Image.open(img) 35 | # load label 36 | label = Image.open(label) 37 | img, label = self.transforms(img, label) 38 | return img, label 39 | 40 | def __len__(self): 41 | return len(self.data_list) 42 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_transforms, build_untransform 8 | -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms as T 10 | 11 | from .transforms import RandomHorizontalFlip 12 | 13 | 14 | def build_transforms(cfg, is_train=True): 15 | normalize = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 16 | if is_train: 17 | def transform(img, target): 18 | img, target = RandomHorizontalFlip(cfg.INPUT.PROB)(img, target) 19 | img = T.ToTensor()(img) 20 | img = normalize(img) 21 | # label = image2label(target) 22 | label = np.array(target, dtype=np.int64) 23 | # remove boundary 24 | label[label == 255] = -1 25 | label = torch.from_numpy(label) 26 | return img, label 27 | 28 | return transform 29 | else: 30 | def transform(img, target): 31 | img = T.ToTensor()(img) 32 | img = normalize(img) 33 | # label = image2label(target) 34 | label = np.array(target, dtype=np.int64) 35 | # remove boundary 36 | label[label == 255] = -1 37 | label = torch.from_numpy(label) 38 | return img, label 39 | 40 | return transform 41 | 42 | 43 | def build_untransform(cfg): 44 | def untransform(img, target): 45 | img = img * torch.FloatTensor(cfg.INPUT.PIXEL_STD)[:, None, None] \ 46 | + torch.FloatTensor(cfg.INPUT.PIXEL_MEAN)[:, None, None] 47 | origin_img = torch.clamp(img, min=0, max=1) * 255 48 | origin_img = origin_img.permute(1, 2, 0).numpy() 49 | origin_img = origin_img.astype(np.uint8) 50 | 51 | label = target.numpy() 52 | label[label == -1] = 0 53 | return origin_img, label 54 | 55 | return untransform 56 | -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import random 8 | 9 | import numpy as np 10 | import torchvision.transforms.functional as F 11 | 12 | 13 | class RandomHorizontalFlip(object): 14 | """Horizontally flip the given PIL Image randomly with a given probability. 15 | 16 | Args: 17 | p (float): probability of the image being flipped. Default value is 0.5 18 | """ 19 | 20 | def __init__(self, p=0.5): 21 | self.p = p 22 | 23 | def __call__(self, img, target): 24 | """ 25 | Args: 26 | img (PIL Image): Image to be flipped. 27 | 28 | Returns: 29 | PIL Image: Randomly flipped image. 
30 |         """
31 |         if random.random() < self.p:
32 |             return F.hflip(img), F.hflip(target)
33 |         return img, target
34 | 
35 |     def __repr__(self):
36 |         return self.__class__.__name__ + '(p={})'.format(self.p)
37 | 
38 | 
39 | def image2label(img):
40 |     cm2lbl = np.zeros(256 ** 3)
41 |     for i, cm in enumerate(COLORMAP):
42 |         cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
43 | 
44 |     data = np.array(img, dtype=np.int32)
45 |     # pack RGB into a single integer index, matching the cm2lbl packing above
46 |     idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
47 |     return np.array(cm2lbl[idx], dtype=np.int64)
48 | 
49 | 
50 | CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
51 |            'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
52 |            'dog', 'horse', 'motorbike', 'person', 'potted plant',
53 |            'sheep', 'sofa', 'train', 'tv/monitor']
54 | 
55 | # RGB color for each class.
56 | COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
57 |             [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
58 |             [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
59 |             [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
60 |             [0, 192, 0], [128, 192, 0], [0, 64, 128]]
--------------------------------------------------------------------------------
/engine/inference.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import logging
7 | 
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import torch
11 | from ignite.engine import Engine, Events
12 | from tensorboardX import SummaryWriter
13 | 
14 | from data.transforms import build_untransform
15 | from data.transforms.transforms import COLORMAP
16 | from utils.metric import Label_Accuracy
17 | 
18 | plt.switch_backend('agg')
19 | 
20 | 
21 | def create_evaluator(model, metrics={}, device=None):
22 |     if device:
23 |         model.to(device)
24 | 
25 |     def _inference(engine, batch):
26 |         model.eval()
27 |         with torch.no_grad():
28 |             x, y = batch
29 |             x = x.to(device)
30 |             y_pred = model(x)
31 |         return y_pred, y
32 | 
33 |     engine = Engine(_inference)
34 | 
35 |     for name, metric in metrics.items():
36 |         metric.attach(engine, name)
37 | 
38 |     return engine
39 | 
40 | 
41 | def inference(
42 |         cfg,
43 |         model,
44 |         val_loader
45 | ):
46 |     cm = np.array(COLORMAP).astype(np.uint8)
47 |     untransform = build_untransform(cfg)
48 | 
49 |     device = cfg.MODEL.DEVICE
50 |     output_dir = cfg.OUTPUT_DIR
51 | 
52 |     logger = logging.getLogger("FCN_Model.inference")
53 |     logger.info("Start inference")
54 |     evaluator = create_evaluator(model, metrics={'mean_iu': Label_Accuracy(cfg.MODEL.NUM_CLASSES)}, device=device)
55 | 
56 |     writer = SummaryWriter(output_dir + '/board')
57 | 
58 |     # adding handlers using `evaluator.on` decorator API
59 |     @evaluator.on(Events.EPOCH_COMPLETED)
60 |     def print_validation_results(engine):
61 |         metrics = evaluator.state.metrics
62 |         mean_iu = metrics['mean_iu']
63 |         logger.info("Validation Results - Mean IU: {:.3f}".format(mean_iu))
64 | 
65 |     @evaluator.on(Events.EPOCH_STARTED)
66 |     def plot_output(engine):
67 |         model.eval()
68 |         for i, batch in enumerate(val_loader):
69 |             if i > 9:
70 |                 break
71 |             val_x, val_y = batch
72 |             val_x = val_x.to(device)
73 |             with torch.no_grad():
74 |                 pred_y = model(val_x)
75 | 
76 |             orig_img, val_y = untransform(val_x.cpu().data[0], val_y[0])
77 |             pred_y = pred_y.max(1)[1].cpu().data[0].numpy()
78 |             pred_val = cm[pred_y]
79 |             seg_val = cm[val_y]
80 | 
81 |             # matplotlib
82 |             fig = plt.figure(figsize=(9, 3))
83 
| plt.subplot(131) 84 | plt.imshow(orig_img) 85 | plt.axis("off") 86 | 87 | plt.subplot(132) 88 | plt.imshow(seg_val) 89 | plt.axis("off") 90 | 91 | plt.subplot(133) 92 | plt.imshow(pred_val) 93 | plt.axis("off") 94 | writer.add_figure('show_result', fig, i) 95 | 96 | evaluator.run(val_loader) 97 | writer.close() 98 | -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 13 | from ignite.handlers import ModelCheckpoint, Timer 14 | from ignite.metrics import Loss, RunningAverage 15 | from tensorboardX import SummaryWriter 16 | 17 | from data.transforms import build_untransform 18 | from data.transforms.transforms import COLORMAP 19 | from utils.metric import Label_Accuracy 20 | 21 | plt.switch_backend('agg') 22 | 23 | 24 | def do_train( 25 | cfg, 26 | model, 27 | train_loader, 28 | val_loader, 29 | optimizer, 30 | loss_fn 31 | ): 32 | cm = np.array(COLORMAP).astype(np.uint8) 33 | untransform = build_untransform(cfg) 34 | 35 | log_period = cfg.SOLVER.LOG_PERIOD 36 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 37 | epochs = cfg.SOLVER.MAX_EPOCHS 38 | device = cfg.MODEL.DEVICE 39 | output_dir = cfg.OUTPUT_DIR 40 | 41 | logger = logging.getLogger("FCN_Model.train") 42 | logger.info("Start training") 43 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) 44 | evaluator = create_supervised_evaluator(model, metrics={'mean_iu': Label_Accuracy(cfg.MODEL.NUM_CLASSES), 45 | 'loss': Loss(loss_fn)}, device=device) 46 | checkpointer = ModelCheckpoint(output_dir, 'fcn', checkpoint_period, n_saved=10, require_empty=False) 47 | timer = Timer(average=True) 48 | writer = SummaryWriter(output_dir + '/board') 49 | 50 | # automatically adding handlers via a special `attach` method of `RunningAverage` handler 51 | RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss') 52 | 53 | # automatically adding handlers via a special `attach` method of `Checkpointer` handler 54 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(), 55 | 'optimizer': optimizer.state_dict()}) 56 | 57 | # automatically adding handlers via a special `attach` method of `Timer` handler 58 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 59 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 60 | 61 | # adding handlers using `trainer.on` decorator API 62 | @trainer.on(Events.ITERATION_COMPLETED) 63 | def log_training_loss(engine): 64 | iter = (engine.state.iteration - 1) % len(train_loader) + 1 65 | 66 | if iter % log_period == 0: 67 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}" 68 | .format(engine.state.epoch, iter, len(train_loader), engine.state.metrics['avg_loss'])) 69 | writer.add_scalars("loss", {'train': engine.state.metrics['avg_loss']}, engine.state.iteration) 70 | 71 | # adding handlers using `trainer.on` decorator API 72 | @trainer.on(Events.EPOCH_COMPLETED) 73 | def log_training_results(engine): 74 | evaluator.run(train_loader) 75 | metrics = evaluator.state.metrics 76 | mean_iu = metrics['mean_iu'] 77 | avg_loss = metrics['loss'] 78 | logger.info("Training Results 
- Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}" 79 | .format(engine.state.epoch, mean_iu, avg_loss)) 80 | writer.add_scalars("mean_iu", {'train': mean_iu}, engine.state.epoch) 81 | 82 | if val_loader is not None: 83 | # adding handlers using `trainer.on` decorator API 84 | @trainer.on(Events.EPOCH_COMPLETED) 85 | def log_validation_results(engine): 86 | evaluator.run(val_loader) 87 | metrics = evaluator.state.metrics 88 | mean_iu = metrics['mean_iu'] 89 | avg_loss = metrics['loss'] 90 | logger.info("Validation Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}" 91 | .format(engine.state.epoch, mean_iu, avg_loss) 92 | ) 93 | writer.add_scalars("loss", {'validation': avg_loss}, engine.state.iteration) 94 | writer.add_scalars("mean_iu", {'validation': mean_iu}, engine.state.epoch) 95 | 96 | # adding handlers using `trainer.on` decorator API 97 | @trainer.on(Events.EPOCH_COMPLETED) 98 | def print_times(engine): 99 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' 100 | .format(engine.state.epoch, timer.value() * timer.step_count, 101 | train_loader.batch_size / timer.value())) 102 | timer.reset() 103 | 104 | @trainer.on(Events.EPOCH_COMPLETED) 105 | def plot_output(engine): 106 | model.eval() 107 | dataset = val_loader.dataset 108 | idx = np.random.choice(np.arange(len(dataset)), size=1).item() 109 | val_x, val_y = dataset[idx] 110 | val_x = val_x.to(device) 111 | with torch.no_grad(): 112 | pred_y = model(val_x.unsqueeze(0)) 113 | 114 | orig_img, val_y = untransform(val_x.cpu().data, val_y) 115 | pred_y = pred_y.max(1)[1].cpu().data[0].numpy() 116 | pred_val = cm[pred_y] 117 | seg_val = cm[val_y] 118 | 119 | # matplotlib 120 | fig = plt.figure(figsize=(9, 3)) 121 | plt.subplot(131) 122 | plt.imshow(orig_img) 123 | plt.axis("off") 124 | 125 | plt.subplot(132) 126 | plt.imshow(seg_val) 127 | plt.axis("off") 128 | 129 | plt.subplot(133) 130 | plt.imshow(pred_val) 131 | plt.axis("off") 132 | writer.add_figure('show_result', fig, engine.state.iteration) 133 | 134 | trainer.run(train_loader, max_epochs=epochs) 135 | writer.close() 136 | -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 2 | 3 | if [ ! 
-e ./dataset ]; then
4 |     mkdir ./dataset
5 | fi
6 | 
7 | tar -xf VOCtrainval_11-May-2012.tar -C ./dataset
8 | rm VOCtrainval_11-May-2012.tar
--------------------------------------------------------------------------------
/layers/bilinear_upsample.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 | 
11 | 
12 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
13 |     """
14 |     Make a 2D bilinear kernel suitable for upsampling
15 |     """
16 |     factor = (kernel_size + 1) // 2
17 |     if kernel_size % 2 == 1:
18 |         center = factor - 1
19 |     else:
20 |         center = factor - 0.5
21 |     og = np.ogrid[:kernel_size, :kernel_size]
22 |     bilinear_filter = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
23 |     # e.g. kernel_size=4 gives the 1D profile [0.25, 0.75, 0.75, 0.25];
24 |     # the 2D kernel is the outer product of the two 1D profiles
25 |     weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)
26 |     weight[range(in_channels), range(out_channels), :, :] = bilinear_filter
27 |     return torch.from_numpy(weight).float()
28 | 
29 | 
30 | def bilinear_upsampling(in_channels, out_channels, kernel_size, stride, bias=False):
31 |     initial_weight = get_upsampling_weight(in_channels, out_channels, kernel_size)
32 |     layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, bias=bias)
33 |     layer.weight.data.copy_(initial_weight)
34 |     # weight is frozen because it's just a bilinear upsampling
35 |     layer.weight.requires_grad = False
36 |     return layer
--------------------------------------------------------------------------------
/layers/conv_layer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from torch import nn
7 | 
8 | 
9 | def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
10 |     layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
11 |     layer.weight.data.zero_()
12 |     if bias:
13 |         layer.bias.data.zero_()
14 |     return layer
--------------------------------------------------------------------------------
/layers/cross_entropy2d.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch.nn.functional as F
7 | 
8 | 
9 | def cross_entropy2d(input, target, weight=None, size_average=True):
10 |     # input: (n, c, h, w), target: (n, h, w)
11 |     n, c, h, w = input.size()
12 |     # log_p: (n, c, h, w)
13 |     log_p = F.log_softmax(input, dim=1)
14 |     # log_p: (n*h*w, c)
15 |     log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
16 |     log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
17 |     log_p = log_p.view(-1, c)
18 |     # target: (n*h*w,)
19 |     mask = target >= 0
20 |     target = target[mask]
21 |     loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
22 |     if size_average:
23 |         loss /= mask.data.sum()
24 |     return loss
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import torch
8 | 
9 | from .backbone.vgg import pretrained_vgg
10 | from .fcn16s import FCN16s
11 | from .fcn32s 
import FCN32s 12 | from .fcn8s import FCN8s 13 | 14 | _FCN_META_ARCHITECTURE = {'fcn32s': FCN32s, 15 | 'fcn16s': FCN16s, 16 | 'fcn8s': FCN8s} 17 | 18 | 19 | def build_fcn_model(cfg): 20 | meta_arch = _FCN_META_ARCHITECTURE[cfg.MODEL.META_ARCHITECTURE] 21 | model = meta_arch(cfg) 22 | if cfg.MODEL.BACKBONE.PRETRAINED: 23 | vgg16 = pretrained_vgg(cfg) 24 | model.copy_params_from_vgg16(vgg16) 25 | if cfg.MODEL.REFINEMENT.NAME == 'fcn32s': 26 | fcn32s = FCN32s(cfg) 27 | fcn32s.load_state_dict(torch.load(cfg.MODEL.REFINEMENT.WEIGHT)) 28 | model.copy_params_from_fcn32s(fcn32s) 29 | elif cfg.MODEL.REFINEMENT.NAME == 'fcn16s': 30 | fcn16s = FCN16s(cfg) 31 | fcn16s.load_state_dict(torch.load(cfg.MODEL.REFINEMENT.WEIGHT)) 32 | model.copy_params_from_fcn16s(fcn16s) 33 | return model 34 | -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from .vgg import VGG16 7 | 8 | 9 | def build_backbone(cfg): 10 | if cfg.MODEL.BACKBONE.NAME == 'vgg16': 11 | backbone = VGG16() 12 | return backbone 13 | -------------------------------------------------------------------------------- /modeling/backbone/vgg.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | import torchvision 8 | from torch import nn 9 | 10 | 11 | class VGG16(nn.Module): 12 | def __init__(self): 13 | super(VGG16, self).__init__() 14 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 15 | self.relu1_1 = nn.ReLU(inplace=True) 16 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 17 | self.relu1_2 = nn.ReLU(inplace=True) 18 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 19 | 20 | # conv2 21 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 22 | self.relu2_1 = nn.ReLU(inplace=True) 23 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 24 | self.relu2_2 = nn.ReLU(inplace=True) 25 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 26 | 27 | # conv3 28 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 29 | self.relu3_1 = nn.ReLU(inplace=True) 30 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 31 | self.relu3_2 = nn.ReLU(inplace=True) 32 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 33 | self.relu3_3 = nn.ReLU(inplace=True) 34 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 35 | 36 | # conv4 37 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 38 | self.relu4_1 = nn.ReLU(inplace=True) 39 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 40 | self.relu4_2 = nn.ReLU(inplace=True) 41 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 42 | self.relu4_3 = nn.ReLU(inplace=True) 43 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 44 | 45 | # conv5 46 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 47 | self.relu5_1 = nn.ReLU(inplace=True) 48 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 49 | self.relu5_2 = nn.ReLU(inplace=True) 50 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 51 | self.relu5_3 = nn.ReLU(inplace=True) 52 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 53 | 54 | def forward(self, x): 55 | x = self.relu1_1(self.conv1_1(x)) 56 | x = self.relu1_2(self.conv1_2(x)) 57 | x = self.pool1(x) 58 | 59 | x = self.relu2_1(self.conv2_1(x)) 60 | x = self.relu2_2(self.conv2_2(x)) 
61 | x = self.pool2(x) 62 | 63 | x = self.relu3_1(self.conv3_1(x)) 64 | x = self.relu3_2(self.conv3_2(x)) 65 | x = self.relu3_3(self.conv3_3(x)) 66 | x = self.pool3(x) 67 | 68 | x = self.relu4_1(self.conv4_1(x)) 69 | x = self.relu4_2(self.conv4_2(x)) 70 | x = self.relu4_3(self.conv4_3(x)) 71 | x = self.pool4(x) 72 | 73 | x = self.relu5_1(self.conv5_1(x)) 74 | x = self.relu5_2(self.conv5_2(x)) 75 | x = self.relu5_3(self.conv5_3(x)) 76 | x = self.pool5(x) 77 | return x 78 | 79 | 80 | def pretrained_vgg(cfg): 81 | model = torchvision.models.vgg16(pretrained=False) 82 | model.load_state_dict(torch.load(cfg.MODEL.BACKBONE.WEIGHT)) 83 | return model 84 | -------------------------------------------------------------------------------- /modeling/fcn16s.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch import nn 8 | 9 | from layers.bilinear_upsample import bilinear_upsampling 10 | from layers.conv_layer import conv_layer 11 | from .backbone import build_backbone 12 | 13 | 14 | class FCN16s(nn.Module): 15 | def __init__(self, cfg): 16 | super(FCN16s, self).__init__() 17 | self.backbone = build_backbone(cfg) 18 | num_classes = cfg.MODEL.NUM_CLASSES 19 | 20 | # fc1 21 | self.fc1 = conv_layer(512, 4096, 7) 22 | self.relu1 = nn.ReLU(inplace=True) 23 | self.drop1 = nn.Dropout2d() 24 | 25 | # fc2 26 | self.fc2 = conv_layer(4096, 4096, 1) 27 | self.relu2 = nn.ReLU(inplace=True) 28 | self.drop2 = nn.Dropout2d() 29 | 30 | self.score_fr = conv_layer(4096, num_classes, 1) 31 | self.score_pool4 = conv_layer(512, num_classes, 1) 32 | 33 | self.upscore2 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False) 34 | self.upscore16 = bilinear_upsampling(num_classes, num_classes, 32, stride=16, bias=False) 35 | 36 | def forward(self, x): 37 | _, _, h, w = x.size() 38 | x = self.backbone.conv1_1(x) 39 | x = self.backbone.relu1_1(x) 40 | x = self.backbone.conv1_2(x) 41 | x = self.backbone.relu1_2(x) 42 | x = self.backbone.pool1(x) 43 | 44 | x = self.backbone.conv2_1(x) 45 | x = self.backbone.relu2_1(x) 46 | x = self.backbone.conv2_2(x) 47 | x = self.backbone.relu2_2(x) 48 | x = self.backbone.pool2(x) 49 | 50 | x = self.backbone.conv3_1(x) 51 | x = self.backbone.relu3_1(x) 52 | x = self.backbone.conv3_2(x) 53 | x = self.backbone.relu3_2(x) 54 | x = self.backbone.conv3_3(x) 55 | x = self.backbone.relu3_3(x) 56 | x = self.backbone.pool3(x) 57 | 58 | x = self.backbone.conv4_1(x) 59 | x = self.backbone.relu4_1(x) 60 | x = self.backbone.conv4_2(x) 61 | x = self.backbone.relu4_2(x) 62 | x = self.backbone.conv4_3(x) 63 | x = self.backbone.relu4_3(x) 64 | x = self.backbone.pool4(x) 65 | pool4 = x # 1/16 66 | 67 | x = self.backbone.conv5_1(x) 68 | x = self.backbone.relu5_1(x) 69 | x = self.backbone.conv5_2(x) 70 | x = self.backbone.relu5_2(x) 71 | x = self.backbone.conv5_3(x) 72 | x = self.backbone.relu5_3(x) 73 | x = self.backbone.pool5(x) 74 | 75 | x = self.relu1(self.fc1(x)) 76 | x = self.drop1(x) 77 | 78 | x = self.relu2(self.fc2(x)) 79 | x = self.drop2(x) 80 | 81 | x = self.score_fr(x) 82 | x = self.upscore2(x) 83 | upscore2 = x 84 | 85 | x = self.score_pool4(pool4) 86 | x = x[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 87 | score_pool4c = x # 1/16 88 | 89 | x = upscore2 + score_pool4c 90 | 91 | x = self.upscore16(x) 92 | x = x[:, :, 27:27 + h, 27:27 + w].contiguous() 93 | return x 94 | 95 | def copy_params_from_fcn32s(self, fcn32s): 96 | # 
load backbone 97 | self.backbone.load_state_dict(fcn32s.backbone.state_dict()) 98 | for name, l1 in fcn32s.named_children(): 99 | try: 100 | l2 = getattr(self, name) 101 | l2.weight # skip ReLU / Dropout 102 | except AttributeError: 103 | continue 104 | assert l1.weight.size() == l2.weight.size() 105 | l2.weight.data.copy_(l1.weight.data) 106 | if l1.bias is not None: 107 | assert l1.bias.size() == l2.bias.size() 108 | l2.bias.data.copy_(l1.bias.data) 109 | 110 | 111 | -------------------------------------------------------------------------------- /modeling/fcn32s.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch import nn 8 | 9 | from layers.bilinear_upsample import bilinear_upsampling 10 | from layers.conv_layer import conv_layer 11 | from .backbone import build_backbone 12 | 13 | 14 | class FCN32s(nn.Module): 15 | def __init__(self, cfg): 16 | super(FCN32s, self).__init__() 17 | self.backbone = build_backbone(cfg) 18 | num_classes = cfg.MODEL.NUM_CLASSES 19 | 20 | self.fc1 = conv_layer(512, 4096, 7) 21 | self.relu1 = nn.ReLU(inplace=True) 22 | self.drop1 = nn.Dropout2d() 23 | 24 | self.fc2 = conv_layer(4096, 4096, 1) 25 | self.relu2 = nn.ReLU(inplace=True) 26 | self.drop2 = nn.Dropout2d() 27 | 28 | self.score_fr = conv_layer(4096, num_classes, 1) 29 | self.upscore = bilinear_upsampling(num_classes, num_classes, 64, stride=32, 30 | bias=False) 31 | 32 | def forward(self, x): 33 | _, _, h, w = x.size() 34 | x = self.backbone(x) 35 | x = self.relu1(self.fc1(x)) 36 | x = self.drop1(x) 37 | 38 | x = self.relu2(self.fc2(x)) 39 | x = self.drop2(x) 40 | 41 | x = self.score_fr(x) 42 | x = self.upscore(x) 43 | x = x[:, :, 19:19 + h, 19:19 + w].contiguous() 44 | return x 45 | 46 | def copy_params_from_vgg16(self, vgg16): 47 | feat = self.backbone 48 | features = [ 49 | feat.conv1_1, feat.relu1_1, 50 | feat.conv1_2, feat.relu1_2, 51 | feat.pool1, 52 | feat.conv2_1, feat.relu2_1, 53 | feat.conv2_2, feat.relu2_2, 54 | feat.pool2, 55 | feat.conv3_1, feat.relu3_1, 56 | feat.conv3_2, feat.relu3_2, 57 | feat.conv3_3, feat.relu3_3, 58 | feat.pool3, 59 | feat.conv4_1, feat.relu4_1, 60 | feat.conv4_2, feat.relu4_2, 61 | feat.conv4_3, feat.relu4_3, 62 | feat.pool4, 63 | feat.conv5_1, feat.relu5_1, 64 | feat.conv5_2, feat.relu5_2, 65 | feat.conv5_3, feat.relu5_3, 66 | feat.pool5 67 | ] 68 | 69 | for l1, l2 in zip(vgg16.features, features): 70 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 71 | assert l1.weight.size() == l2.weight.size() 72 | assert l1.bias.size() == l2.bias.size() 73 | l2.weight.data.copy_(l1.weight.data) 74 | l2.bias.data.copy_(l1.bias.data) 75 | for i, name in zip([0, 3], ['fc1', 'fc2']): 76 | l1 = vgg16.classifier[i] 77 | l2 = getattr(self, name) 78 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 79 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 80 | -------------------------------------------------------------------------------- /modeling/fcn8s.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch import nn 8 | 9 | from layers.bilinear_upsample import bilinear_upsampling 10 | from layers.conv_layer import conv_layer 11 | from .backbone import build_backbone 12 | 13 | 14 | class FCN8s(nn.Module): 15 | def __init__(self, cfg): 16 | super(FCN8s, 
self).__init__() 17 | self.backbone = build_backbone(cfg) 18 | num_classes = cfg.MODEL.NUM_CLASSES 19 | 20 | # fc1 21 | self.fc1 = conv_layer(512, 4096, 7) 22 | self.relu1 = nn.ReLU(inplace=True) 23 | self.drop1 = nn.Dropout2d() 24 | 25 | # fc2 26 | self.fc2 = conv_layer(4096, 4096, 1) 27 | self.relu2 = nn.ReLU(inplace=True) 28 | self.drop2 = nn.Dropout2d() 29 | 30 | self.score_fr = conv_layer(4096, num_classes, 1) 31 | self.score_pool3 = conv_layer(256, num_classes, 1) 32 | self.score_pool4 = conv_layer(512, num_classes, 1) 33 | 34 | self.upscore2 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False) 35 | self.upscore8 = bilinear_upsampling(num_classes, num_classes, 16, stride=8, bias=False) 36 | self.upscore_pool4 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False) 37 | 38 | def forward(self, x): 39 | _, _, h, w = x.size() 40 | x = self.backbone.conv1_1(x) 41 | x = self.backbone.relu1_1(x) 42 | x = self.backbone.conv1_2(x) 43 | x = self.backbone.relu1_2(x) 44 | x = self.backbone.pool1(x) 45 | 46 | x = self.backbone.conv2_1(x) 47 | x = self.backbone.relu2_1(x) 48 | x = self.backbone.conv2_2(x) 49 | x = self.backbone.relu2_2(x) 50 | x = self.backbone.pool2(x) 51 | 52 | x = self.backbone.conv3_1(x) 53 | x = self.backbone.relu3_1(x) 54 | x = self.backbone.conv3_2(x) 55 | x = self.backbone.relu3_2(x) 56 | x = self.backbone.conv3_3(x) 57 | x = self.backbone.relu3_3(x) 58 | x = self.backbone.pool3(x) 59 | pool3 = x # 1/8 60 | 61 | x = self.backbone.conv4_1(x) 62 | x = self.backbone.relu4_1(x) 63 | x = self.backbone.conv4_2(x) 64 | x = self.backbone.relu4_2(x) 65 | x = self.backbone.conv4_3(x) 66 | x = self.backbone.relu4_3(x) 67 | x = self.backbone.pool4(x) 68 | pool4 = x # 1/16 69 | 70 | x = self.backbone.conv5_1(x) 71 | x = self.backbone.relu5_1(x) 72 | x = self.backbone.conv5_2(x) 73 | x = self.backbone.relu5_2(x) 74 | x = self.backbone.conv5_3(x) 75 | x = self.backbone.relu5_3(x) 76 | x = self.backbone.pool5(x) 77 | 78 | x = self.relu1(self.fc1(x)) 79 | x = self.drop1(x) 80 | 81 | x = self.relu2(self.fc2(x)) 82 | x = self.drop2(x) 83 | 84 | x = self.score_fr(x) 85 | x = self.upscore2(x) 86 | upscore2 = x # 1/16 87 | 88 | x = self.score_pool4(pool4) 89 | x = x[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 90 | score_pool4c = x # 1/16 91 | 92 | x = upscore2 + score_pool4c 93 | x = self.upscore_pool4(x) 94 | upscore_pool4 = x # 1/8 95 | 96 | x = self.score_pool3(pool3) 97 | x = x[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]].contiguous() 98 | score_pool3c = x # 1/8 99 | 100 | x = upscore_pool4 + score_pool3c # 1/8 101 | 102 | x = self.upscore8(x) 103 | x = x[:, :, 31:31 + h, 31:31 + w].contiguous() 104 | return x 105 | 106 | def copy_params_from_fcn16s(self, fcn16s): 107 | self.backbone.load_state_dict(fcn16s.backbone.state_dict()) 108 | for name, l1 in fcn16s.named_children(): 109 | try: 110 | l2 = getattr(self, name) 111 | l2.weight # skip ReLU / Dropout 112 | except AttributeError: 113 | continue 114 | assert l1.weight.size() == l2.weight.size() 115 | l2.weight.data.copy_(l1.weight.data) 116 | if l1.bias is not None: 117 | assert l1.bias.size() == l2.bias.size() 118 | l2.bias.data.copy_(l1.bias.data) 119 | 120 | def copy_params_from_vgg16(self, vgg16): 121 | feat = self.backbone 122 | features = [ 123 | feat.conv1_1, feat.relu1_1, 124 | feat.conv1_2, feat.relu1_2, 125 | feat.pool1, 126 | feat.conv2_1, feat.relu2_1, 127 | feat.conv2_2, feat.relu2_2, 128 | feat.pool2, 129 | feat.conv3_1, feat.relu3_1, 130 | 
feat.conv3_2, feat.relu3_2, 131 | feat.conv3_3, feat.relu3_3, 132 | feat.pool3, 133 | feat.conv4_1, feat.relu4_1, 134 | feat.conv4_2, feat.relu4_2, 135 | feat.conv4_3, feat.relu4_3, 136 | feat.pool4, 137 | feat.conv5_1, feat.relu5_1, 138 | feat.conv5_2, feat.relu5_2, 139 | feat.conv5_3, feat.relu5_3, 140 | feat.pool5 141 | ] 142 | 143 | for l1, l2 in zip(vgg16.features, features): 144 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 145 | assert l1.weight.size() == l2.weight.size() 146 | assert l1.bias.size() == l2.bias.size() 147 | l2.weight.data.copy_(l1.weight.data) 148 | l2.bias.data.copy_(l1.bias.data) 149 | for i, name in zip([0, 3], ['fc1', 'fc2']): 150 | l1 = vgg16.classifier[i] 151 | l2 = getattr(self, name) 152 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 153 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 154 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_optimizer 8 | -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def make_optimizer(cfg, model): 11 | params = [] 12 | for key, value in model.named_parameters(): 13 | if not value.requires_grad: 14 | continue 15 | lr = cfg.SOLVER.BASE_LR 16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 17 | if "bias" in key: 18 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = torch.optim.SGD(params, lr, momentum=cfg.SOLVER.MOMENTUM) 23 | return optimizer 24 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import sys 8 | import unittest 9 | 10 | sys.path.append('.') 11 | from config import cfg 12 | from data.transforms import build_transforms 13 | from data.build import build_dataset 14 | from solver.build import make_optimizer 15 | 16 | 17 | class TestDataSet(unittest.TestCase): 18 | def test_dataset(self): 19 | train_transform = build_transforms(cfg, True) 20 | val_transform = build_transforms(cfg, False) 21 | train_set = build_dataset(cfg, train_transform, True) 22 | val_test = build_dataset(cfg, val_transform, False) 23 | from IPython import embed; 24 | embed() 25 | 26 | 27 | if __name__ == '__main__': 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | sys.path.append('.') 5 | 
from modeling.backbone.vgg import VGG16 6 | from config import cfg 7 | from modeling import build_fcn_model 8 | from modeling.backbone import build_backbone 9 | import torch 10 | 11 | 12 | class MyTestCase(unittest.TestCase): 13 | def test_vgg(self): 14 | vgg = build_backbone(cfg) 15 | model = build_fcn_model(cfg) 16 | print(model.backbone.conv1_1.weight[0, 0, 0, 0]) 17 | # x = torch.randn(5, 3, 224, 224) 18 | # y = model(x) 19 | from IPython import embed; 20 | embed() 21 | 22 | 23 | if __name__ == '__main__': 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tools/test_fcn.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | from os import mkdir 11 | 12 | import torch 13 | 14 | sys.path.append('.') 15 | from config import cfg 16 | from data import make_data_loader 17 | from engine.inference import inference 18 | from modeling import build_fcn_model 19 | from utils.logger import setup_logger 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description="PyTorch FCN Inference") 24 | parser.add_argument( 25 | "--config_file", default="", help="path to config file", type=str 26 | ) 27 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 28 | nargs=argparse.REMAINDER) 29 | 30 | args = parser.parse_args() 31 | 32 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 33 | 34 | if args.config_file != "": 35 | cfg.merge_from_file(args.config_file) 36 | cfg.merge_from_list(args.opts) 37 | cfg.freeze() 38 | 39 | output_dir = cfg.OUTPUT_DIR 40 | if output_dir and not os.path.exists(output_dir): 41 | mkdir(output_dir) 42 | 43 | logger = setup_logger("FCN_Model", output_dir, 0) 44 | logger.info("Using {} GPUS".format(num_gpus)) 45 | logger.info(args) 46 | 47 | if args.config_file != "": 48 | logger.info("Loaded configuration file {}".format(args.config_file)) 49 | with open(args.config_file, 'r') as cf: 50 | config_str = "\n" + cf.read() 51 | logger.info(config_str) 52 | logger.info("Running with config:\n{}".format(cfg)) 53 | 54 | model = build_fcn_model(cfg) 55 | model.load_state_dict(torch.load(cfg.TEST.WEIGHT)) 56 | val_loader = make_data_loader(cfg, is_train=False) 57 | 58 | inference(cfg, model, val_loader) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /tools/train_fcn.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | from os import mkdir 11 | 12 | sys.path.append('.') 13 | from config import cfg 14 | from data import make_data_loader 15 | from engine.trainer import do_train 16 | from modeling import build_fcn_model 17 | from solver import make_optimizer 18 | from utils.logger import setup_logger 19 | from layers.cross_entropy2d import cross_entropy2d 20 | 21 | 22 | def train(cfg): 23 | model = 
build_fcn_model(cfg)
24 | 
25 |     optimizer = make_optimizer(cfg, model)
26 | 
27 |     arguments = {}
28 | 
29 |     data_loader = make_data_loader(cfg, is_train=True)
30 |     val_loader = make_data_loader(cfg, is_train=False)
31 | 
32 |     do_train(
33 |         cfg,
34 |         model,
35 |         data_loader,
36 |         val_loader,
37 |         optimizer,
38 |         cross_entropy2d,
39 |     )
40 | 
41 | 
42 | def main():
43 |     parser = argparse.ArgumentParser(description="PyTorch FCN Training")
44 |     parser.add_argument(
45 |         "--config_file", default="", help="path to config file", type=str
46 |     )
47 |     parser.add_argument("opts", help="Modify config options using the command-line", default=None,
48 |                         nargs=argparse.REMAINDER)
49 | 
50 |     args = parser.parse_args()
51 | 
52 |     num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
53 | 
54 |     if args.config_file != "":
55 |         cfg.merge_from_file(args.config_file)
56 |     cfg.merge_from_list(args.opts)
57 |     cfg.freeze()
58 | 
59 |     output_dir = cfg.OUTPUT_DIR
60 |     if output_dir and not os.path.exists(output_dir):
61 |         mkdir(output_dir)
62 | 
63 |     logger = setup_logger("FCN_Model", output_dir, 0)
64 |     logger.info("Using {} GPUS".format(num_gpus))
65 |     logger.info(args)
66 | 
67 |     if args.config_file != "":
68 |         logger.info("Loaded configuration file {}".format(args.config_file))
69 |         with open(args.config_file, 'r') as cf:
70 |             config_str = "\n" + cf.read()
71 |             logger.info(config_str)
72 |     logger.info("Running with config:\n{}".format(cfg))
73 | 
74 |     train(cfg)
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     main()
79 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | 
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import logging
8 | import os
9 | import sys
10 | 
11 | 
12 | def setup_logger(name, save_dir, distributed_rank):
13 |     logger = logging.getLogger(name)
14 |     logger.setLevel(logging.DEBUG)
15 |     # don't log results for the non-master process
16 |     if distributed_rank > 0:
17 |         return logger
18 |     ch = logging.StreamHandler(stream=sys.stdout)
19 |     ch.setLevel(logging.DEBUG)
20 |     formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
21 |     ch.setFormatter(formatter)
22 |     logger.addHandler(ch)
23 | 
24 |     if save_dir:
25 |         fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
26 |         fh.setLevel(logging.DEBUG)
27 |         fh.setFormatter(formatter)
28 |         logger.addHandler(fh)
29 | 
30 |     return logger
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import numpy as np
7 | from ignite.metrics import Metric
8 | 
9 | 
10 | def _fast_hist(label_true, label_pred, n_class):
11 |     mask = (label_true >= 0) & (label_true < n_class)
12 |     hist = np.bincount(
13 |         n_class * label_true[mask].astype(int) +
14 |         label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
15 |     return hist
16 | 
17 | 
18 | class Label_Accuracy(Metric):
19 |     """
20 |     Calculates the mean intersection-over-union (mean IU).
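21 |     For each class c, IU_c = TP_c / (TP_c + FP_c + FN_c); from the confusion
22 |     matrix `hist` below this is diag(hist) / (hist.sum(1) + hist.sum(0) - diag(hist)).
23 |     The reported value is the mean of IU_c over classes, averaged across `update` calls.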
24 | 
25 |     - `update` must receive output of the form `(y_pred, y)`.
26 |     - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...)
27 |     - `y` must be in the following shape (batch_size, ...)
28 |     """
29 | 
30 |     def __init__(self, n_class):
31 |         super(Label_Accuracy, self).__init__()
32 |         self.n_class = n_class
33 | 
34 |     def reset(self):
35 |         self.step = 0
36 |         self.mean_iu = 0
37 | 
38 |     def update(self, output):
39 |         label_preds, label_trues = output
40 |         label_preds = label_preds.max(dim=1)[1].data.cpu().numpy()
41 |         label_preds = [i for i in label_preds]
42 | 
43 |         label_trues = label_trues.data.cpu().numpy()
44 |         label_trues = [i for i in label_trues]
45 | 
46 |         hist = np.zeros((self.n_class, self.n_class))
47 |         for lt, lp in zip(label_trues, label_preds):
48 |             hist += _fast_hist(lt.flatten(), lp.flatten(), self.n_class)
49 |         with np.errstate(divide='ignore', invalid='ignore'):
50 |             iu = np.diag(hist) / (
51 |                     hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
52 |             )
53 |         mean_iu = np.nanmean(iu)
54 |         self.mean_iu += mean_iu
55 |         self.step += 1
56 | 
57 |     def compute(self):
58 |         return self.mean_iu / self.step
59 | 
--------------------------------------------------------------------------------