├── .gitignore
├── README.md
├── assets
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   └── 7.png
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── test_fcn16s.yml
│   ├── test_fcn32s.yml
│   ├── test_fcn8s.yml
│   ├── test_fcn8s_atonce.yml
│   ├── train_fcn16s.yml
│   ├── train_fcn32s.yml
│   ├── train_fcn8s.yml
│   └── train_fcn8s_atonce.yml
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── voc.py
│   └── transforms
│       ├── __init__.py
│       ├── build.py
│       └── transforms.py
├── engine
│   ├── inference.py
│   └── trainer.py
├── get_data.sh
├── layers
│   ├── bilinear_upsample.py
│   ├── conv_layer.py
│   └── cross_entropy2d.py
├── modeling
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── vgg.py
│   ├── fcn16s.py
│   ├── fcn32s.py
│   └── fcn8s.py
├── solver
│   ├── __init__.py
│   └── build.py
├── tests
│   ├── __init__.py
│   ├── test_dataset.py
│   └── test_model.py
├── tools
│   ├── __init__.py
│   ├── test_fcn.py
│   └── train_fcn.py
└── utils
    ├── __init__.py
    ├── logger.py
    └── metric.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fcn.pytorch
2 |
3 | PyTorch implementation of [Fully Convolutional Networks](https://github.com/shelhamer/fcn.berkeleyvision.org), with the main code adapted from [pytorch-fcn](https://github.com/wkentaro/pytorch-fcn).
4 |
5 | ### Requirements
6 | - pytorch
7 | - torchvision
8 | - [ignite](https://github.com/pytorch/ignite)
9 | - [yacs](https://github.com/rbgirshick/yacs)
10 | - [tensorboardX](https://github.com/lanpa/tensorboardX)
11 | - tensorflow (for tensorboard)
12 |
13 | ### Get Started
14 | The project structure follows the [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template) guide; refer to that guide for each folder's purpose.
15 |
16 | #### Prepare Dataset
17 | Open a terminal and run the following command to download the VOC2012 dataset:
18 |
19 | ```bash
20 | bash get_data.sh
21 | ```
22 |
23 | or download the dataset yourself from this URL:
24 |
25 | ```bash
26 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
27 | ```
28 |
29 | ### Training
30 | The configuration files we provide live in the `configs` folder; you only need to modify the dataset root, the VGG weight path, and the output directory. There are two ways to launch training:
31 |
32 | #### 1. Modify configuration file and run
33 | Modify `train_fcn32s.yml` first, then run:
34 |
35 | ```bash
36 | python3 tools/train_fcn.py --config_file='configs/train_fcn32s.yml'
37 | ```
38 |
39 | #### 2. Modify the cfg parameters
40 | You can override configuration parameters such as the learning rate or the maximum number of epochs directly on the command line:
41 |
42 | ```bash
43 | python3 tools/train_fcn.py --config_file='configs/train_fcn32s.yml' SOLVER.BASE_LR 0.0025 SOLVER.MAX_EPOCHS 8
44 | ```
45 |
46 | ### Results
47 | These models are trained on the VOC2012 `train.txt` split and evaluated on `val.txt`, using the torchvision-pretrained VGG16 rather than the Caffe-pretrained weights, so the results may differ from the original paper.
48 |
49 | |Model| Epoch | Mean IU |
50 | |-|-|-|
51 | | FCN32s| 13 | 55.1|
52 | | FCN16s| 8 | 54.8|
53 | | FCN8s | 7 | 55.7 |
54 | | FCN8sAtOnce | 11 | 53.6 |
55 |
--------------------------------------------------------------------------------
/assets/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/1.png
--------------------------------------------------------------------------------
/assets/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/2.png
--------------------------------------------------------------------------------
/assets/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/3.png
--------------------------------------------------------------------------------
/assets/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/4.png
--------------------------------------------------------------------------------
/assets/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/5.png
--------------------------------------------------------------------------------
/assets/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/6.png
--------------------------------------------------------------------------------
/assets/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/L1aoXingyu/fcn.pytorch/7f592407f41325375baff5b3514567fb4c5f9a62/assets/7.png
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images per batch during training is
10 | # SOLVER.IMS_PER_BATCH, while the number of images per batch during
11 | # testing is TEST.IMS_PER_BATCH.
12 |
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 |
17 | _C = CN()
18 |
19 | _C.MODEL = CN()
20 | _C.MODEL.DEVICE = "cuda"
21 | _C.MODEL.NUM_CLASSES = 21
22 |
23 | _C.MODEL.META_ARCHITECTURE = "fcn32s"
24 |
25 | _C.MODEL.BACKBONE = CN()
26 | _C.MODEL.BACKBONE.NAME = "vgg16"
27 | _C.MODEL.BACKBONE.PRETRAINED = False
28 | _C.MODEL.BACKBONE.WEIGHT = ""
29 |
30 | _C.MODEL.REFINEMENT = CN()
31 | _C.MODEL.REFINEMENT.NAME = ''
32 | _C.MODEL.REFINEMENT.WEIGHT = ''
33 |
34 | # -----------------------------------------------------------------------------
35 | # INPUT
36 | # -----------------------------------------------------------------------------
37 | _C.INPUT = CN()
38 | # Random probability for image horizontal flip
39 | _C.INPUT.PROB = 0.5
40 | # Values to be used for image normalization
41 | # _C.INPUT.PIXEL_MEAN = [104.00698793, 116.66876762, 122.67891434]
42 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
43 | # Values to be used for image normalization
44 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
45 |
46 | # -----------------------------------------------------------------------------
47 | # Dataset
48 | # -----------------------------------------------------------------------------
49 | _C.DATASETS = CN()
50 | # Dataset root path
51 | _C.DATASETS.ROOT = ''
52 | # -----------------------------------------------------------------------------
53 | # DataLoader
54 | # -----------------------------------------------------------------------------
55 | _C.DATALOADER = CN()
56 | # Number of data loading threads
57 | _C.DATALOADER.NUM_WORKERS = 8
58 |
59 | # ---------------------------------------------------------------------------- #
60 | # Solver
61 | # ---------------------------------------------------------------------------- #
62 | _C.SOLVER = CN()
63 | _C.SOLVER.OPTIMIZER_NAME = "SGD"
64 |
65 | _C.SOLVER.MAX_EPOCHS = 11
66 |
67 | _C.SOLVER.BASE_LR = 1.0e-4
68 | _C.SOLVER.BIAS_LR_FACTOR = 2
69 |
70 | _C.SOLVER.MOMENTUM = 0.99
71 |
72 | _C.SOLVER.WEIGHT_DECAY = 0.0005
73 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0
74 |
75 | _C.SOLVER.CHECKPOINT_PERIOD = 10
76 | _C.SOLVER.LOG_PERIOD = 400
77 |
78 | # Number of images per batch
79 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
80 | # see 2 images per batch
81 | _C.SOLVER.IMS_PER_BATCH = 1
82 |
83 | # ---------------------------------------------------------------------------- #
84 | # Test options
85 | _C.TEST = CN()
86 | _C.TEST.IMS_PER_BATCH = 1
87 | _C.TEST.WEIGHT = ""
88 |
89 | # ---------------------------------------------------------------------------- #
90 | # Misc options
91 | # ---------------------------------------------------------------------------- #
92 | _C.OUTPUT_DIR = ""
93 |
--------------------------------------------------------------------------------
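
A minimal sketch of how these defaults are consumed (this mirrors what `tools/train_fcn.py` does; it assumes you run from the repository root so that `config` is importable):

```python
# yacs override order: defaults -> YAML file -> command-line style pairs
from config import cfg

cfg.merge_from_file('configs/train_fcn32s.yml')   # YAML overrides the defaults
cfg.merge_from_list(['SOLVER.BASE_LR', 0.0025])   # pairs override the YAML
cfg.freeze()                                      # lock against accidental edits

print(cfg.MODEL.META_ARCHITECTURE)  # fcn32s
print(cfg.SOLVER.BASE_LR)           # 0.0025
```
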
/configs/test_fcn16s.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "fcn16s"
3 |
4 | BACKBONE:
5 | PRETRAINED: False
6 |
7 | REFINEMENT:
8 | NAME: ''
9 |
10 | DATASETS:
11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
12 |
13 |
14 | TEST:
15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn16s/fcn_model_8.pth'
16 |
17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn16s"
18 |
19 |
--------------------------------------------------------------------------------
/configs/test_fcn32s.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "fcn32s"
3 |
4 | BACKBONE:
5 | PRETRAINED: False
6 |
7 | REFINEMENT:
8 | NAME: ''
9 |
10 | DATASETS:
11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
12 |
13 |
14 | TEST:
15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn32s/fcn_model_13.pth'
16 |
17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn32s"
18 |
19 |
--------------------------------------------------------------------------------
/configs/test_fcn8s.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "fcn8s"
3 |
4 | BACKBONE:
5 | PRETRAINED: False
6 |
7 | REFINEMENT:
8 | NAME: ''
9 |
10 | DATASETS:
11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
12 |
13 |
14 | TEST:
15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s/fcn_model_7.pth'
16 |
17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn8s"
18 |
19 |
--------------------------------------------------------------------------------
/configs/test_fcn8s_atonce.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "fcn8s"
3 |
4 | BACKBONE:
5 | PRETRAINED: False
6 |
7 | REFINEMENT:
8 | NAME: ''
9 |
10 | DATASETS:
11 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
12 |
13 |
14 | TEST:
15 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s_atonce/fcn_model_13.pth'
16 |
17 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/inference_fcn8s_atonce"
18 |
19 |
--------------------------------------------------------------------------------
/configs/train_fcn16s.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "fcn16s"
3 |
4 | BACKBONE:
5 | PRETRAINED: False
6 |
7 | REFINEMENT:
8 | NAME: 'fcn32s'
9 | WEIGHT: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn32s/fcn_model_13.pth"
10 |
11 | DATASETS:
12 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
13 |
14 | SOLVER:
15 | MAX_EPOCHS: 8
16 | CHECKPOINT_PERIOD: 8
17 |
18 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn16s"
19 |
--------------------------------------------------------------------------------
/configs/train_fcn32s.yml:
--------------------------------------------------------------------------------
1 |
2 | MODEL:
3 | META_ARCHITECTURE: "fcn32s"
4 |
5 | BACKBONE:
6 | PRETRAINED: True
7 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/model_zoo/vgg16-397923af.pth'
8 |
9 | REFINEMENT:
10 | NAME: ''
11 |
12 | DATASETS:
13 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
14 |
15 | SOLVER:
16 | MAX_EPOCHS: 13
17 | CHECKPOINT_PERIOD: 13
18 |
19 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn32s"
20 |
--------------------------------------------------------------------------------
/configs/train_fcn8s.yml:
--------------------------------------------------------------------------------
1 |
2 | MODEL:
3 | META_ARCHITECTURE: "fcn8s"
4 |
5 | BACKBONE:
6 | PRETRAINED: False
7 |
8 | REFINEMENT:
9 | NAME: 'fcn16s'
10 | WEIGHT: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn16s/fcn_model_8.pth"
11 |
12 | DATASETS:
13 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
14 |
15 | SOLVER:
16 | MAX_EPOCHS: 7
17 | CHECKPOINT_PERIOD: 7
18 |
19 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s"
20 |
--------------------------------------------------------------------------------
/configs/train_fcn8s_atonce.yml:
--------------------------------------------------------------------------------
1 |
2 | MODEL:
3 | META_ARCHITECTURE: "fcn8s"
4 |
5 | BACKBONE:
6 | PRETRAINED: True
7 | WEIGHT: '/mnt/truenas/scratch/xingyu.liao/model_zoo/vgg16-397923af.pth'
8 |
9 | REFINEMENT:
10 | NAME: ''
11 |
12 | DATASETS:
13 | ROOT: '/mnt/truenas/scratch/xingyu.liao/DATA/VOCdevkit/VOC2012'
14 |
15 | SOLVER:
16 | MAX_EPOCHS: 13
17 | CHECKPOINT_PERIOD: 13
18 |
19 | OUTPUT_DIR: "/mnt/truenas/scratch/xingyu.liao/checkpoints/train_fcn8s_atonce"
20 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_data_loader
8 |
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch.utils import data
8 |
9 | from .datasets.voc import VocSegDataset
10 | from .transforms import build_transforms
11 |
12 |
13 | def build_dataset(cfg, transforms, is_train=True):
14 | datasets = VocSegDataset(cfg, is_train, transforms)
15 | return datasets
16 |
17 |
18 | def make_data_loader(cfg, is_train=True):
19 | if is_train:
20 | batch_size = cfg.SOLVER.IMS_PER_BATCH
21 | shuffle = True
22 | else:
23 | batch_size = cfg.TEST.IMS_PER_BATCH
24 | shuffle = False
25 |
26 | transforms = build_transforms(cfg, is_train)
27 | datasets = build_dataset(cfg, transforms, is_train)
28 |
29 | num_workers = cfg.DATALOADER.NUM_WORKERS
30 | data_loader = data.DataLoader(
31 | datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True
32 | )
33 |
34 | return data_loader
35 |
--------------------------------------------------------------------------------
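
A hedged usage sketch of the loader, run from the repository root; the `dataset/VOCdevkit/VOC2012` path is what `get_data.sh` produces, so substitute your own copy if it lives elsewhere:

```python
from config import cfg
from data import make_data_loader

cfg.merge_from_list(['DATASETS.ROOT', 'dataset/VOCdevkit/VOC2012'])
train_loader = make_data_loader(cfg, is_train=True)

# IMS_PER_BATCH defaults to 1 because VOC images come in varying sizes
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. [1, 3, H, W] float32 and [1, H, W] int64
```
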
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/data/datasets/voc.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import os
8 |
9 | import numpy as np
10 | from PIL import Image
11 | from torch.utils import data
12 |
13 |
14 | def read_images(root, train):
15 | txt_fname = os.path.join(root, 'ImageSets/Segmentation/') + ('train.txt' if train else 'val.txt')
16 | with open(txt_fname, 'r') as f:
17 | images = f.read().split()
18 | data = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in images]
19 | label = [os.path.join(root, 'SegmentationClass', i + '.png') for i in images]
20 | return data, label
21 |
22 |
23 | class VocSegDataset(data.Dataset):
24 |
25 | def __init__(self, cfg, train, transforms=None):
26 | self.cfg = cfg
27 | self.train = train
28 | self.transforms = transforms
29 | self.data_list, self.label_list = read_images(self.cfg.DATASETS.ROOT, train)
30 |
31 | def __getitem__(self, item):
32 | img = self.data_list[item]
33 | label = self.label_list[item]
34 | img = Image.open(img)
35 | # load label
36 | label = Image.open(label)
37 | img, label = self.transforms(img, label)
38 | return img, label
39 |
40 | def __len__(self):
41 | return len(self.data_list)
42 |
--------------------------------------------------------------------------------
/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import build_transforms, build_untransform
8 |
--------------------------------------------------------------------------------
/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | import torchvision.transforms as T
10 |
11 | from .transforms import RandomHorizontalFlip
12 |
13 |
14 | def build_transforms(cfg, is_train=True):
15 | normalize = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
16 | if is_train:
17 | def transform(img, target):
18 | img, target = RandomHorizontalFlip(cfg.INPUT.PROB)(img, target)
19 | img = T.ToTensor()(img)
20 | img = normalize(img)
21 | # label = image2label(target)
22 | label = np.array(target, dtype=np.int64)
23 | # remove boundary
24 | label[label == 255] = -1
25 | label = torch.from_numpy(label)
26 | return img, label
27 |
28 | return transform
29 | else:
30 | def transform(img, target):
31 | img = T.ToTensor()(img)
32 | img = normalize(img)
33 | # label = image2label(target)
34 | label = np.array(target, dtype=np.int64)
35 | # remove boundary
36 | label[label == 255] = -1
37 | label = torch.from_numpy(label)
38 | return img, label
39 |
40 | return transform
41 |
42 |
43 | def build_untransform(cfg):
44 | def untransform(img, target):
45 | img = img * torch.FloatTensor(cfg.INPUT.PIXEL_STD)[:, None, None] \
46 | + torch.FloatTensor(cfg.INPUT.PIXEL_MEAN)[:, None, None]
47 | origin_img = torch.clamp(img, min=0, max=1) * 255
48 | origin_img = origin_img.permute(1, 2, 0).numpy()
49 | origin_img = origin_img.astype(np.uint8)
50 |
51 | label = target.numpy()
52 | label[label == -1] = 0
53 | return origin_img, label
54 |
55 | return untransform
56 |
--------------------------------------------------------------------------------
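
A round-trip sketch on synthetic PIL inputs (a demonstration, not repo code): the train transform maps boundary pixels (255) to the ignore index -1, and `build_untransform` inverts the normalization and maps ignored pixels back to background:

```python
import numpy as np
from PIL import Image

from config import cfg
from data.transforms import build_transforms, build_untransform

transform = build_transforms(cfg, is_train=True)
untransform = build_untransform(cfg)

img = Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8))      # black RGB image
target = Image.fromarray(np.full((16, 16), 255, dtype=np.uint8))  # all-boundary label

x, y = transform(img, target)
print(x.shape, y.unique())          # torch.Size([3, 16, 16]) tensor([-1])

orig_img, label = untransform(x, y)
print(orig_img.shape, label.max())  # (16, 16, 3) 0
```
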
/data/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import random
8 |
9 | import numpy as np
10 | import torchvision.transforms.functional as F
11 |
12 |
13 | class RandomHorizontalFlip(object):
14 | """Horizontally flip the given PIL Image randomly with a given probability.
15 |
16 | Args:
17 | p (float): probability of the image being flipped. Default value is 0.5
18 | """
19 |
20 | def __init__(self, p=0.5):
21 | self.p = p
22 |
23 | def __call__(self, img, target):
24 | """
25 | Args:
26 | img (PIL Image): Image to be flipped.
27 |
28 | Returns:
29 | PIL Image: Randomly flipped image.
30 | """
31 | if random.random() < self.p:
32 | return F.hflip(img), F.hflip(target)
33 | return img, target
34 |
35 | def __repr__(self):
36 | return self.__class__.__name__ + '(p={})'.format(self.p)
37 |
38 |
39 | def image2label(img):
40 | cm2lbl = np.zeros(256 ** 3)
41 | for i, cm in enumerate(COLORMAP):
42 | cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
43 |
44 | data = np.array(img, dtype=np.int32)
45 |     idx = ((data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2])
46 | return np.array(cm2lbl[idx], dtype=np.int64)
47 |
48 |
49 | CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
50 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
51 | 'dog', 'horse', 'motorbike', 'person', 'potted plant',
52 | 'sheep', 'sofa', 'train', 'tv/monitor']
53 |
54 | # RGB color for each class.
55 | COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
56 | [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
57 | [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
58 | [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
59 | [0, 192, 0], [128, 192, 0], [0, 64, 128]]
60 |
--------------------------------------------------------------------------------
/engine/inference.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import logging
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import torch
11 | from ignite.engine import Engine, Events
12 | from tensorboardX import SummaryWriter
13 |
14 | from data.transforms import build_untransform
15 | from data.transforms.transforms import COLORMAP
16 | from utils.metric import Label_Accuracy
17 |
18 | plt.switch_backend('agg')
19 |
20 |
21 | def create_evaluator(model, metrics={}, device=None):
22 | if device:
23 | model.to(device)
24 |
25 | def _inference(engine, batch):
26 | model.eval()
27 | with torch.no_grad():
28 | x, y = batch
29 | x = x.to(device)
30 | y_pred = model(x)
31 | return y_pred, y
32 |
33 | engine = Engine(_inference)
34 |
35 | for name, metric in metrics.items():
36 | metric.attach(engine, name)
37 |
38 | return engine
39 |
40 |
41 | def inference(
42 | cfg,
43 | model,
44 | val_loader
45 | ):
46 | cm = np.array(COLORMAP).astype(np.uint8)
47 | untransform = build_untransform(cfg)
48 |
49 | device = cfg.MODEL.DEVICE
50 | output_dir = cfg.OUTPUT_DIR
51 |
52 | logger = logging.getLogger("FCN_Model.inference")
53 | logger.info("Start inferencing")
54 | evaluator = create_evaluator(model, metrics={'mean_iu': Label_Accuracy(cfg.MODEL.NUM_CLASSES)}, device=device)
55 |
56 | writer = SummaryWriter(output_dir + '/board')
57 |
58 | # adding handlers using `evaluator.on` decorator API
59 | @evaluator.on(Events.EPOCH_COMPLETED)
60 | def print_validation_results(engine):
61 | metrics = evaluator.state.metrics
62 | mean_iu = metrics['mean_iu']
63 | logger.info("Validation Results - Mean IU: {:.3f}".format(mean_iu))
64 |
65 | @evaluator.on(Events.EPOCH_STARTED)
66 | def plot_output(engine):
67 | model.eval()
68 | for i, batch in enumerate(val_loader):
69 | if i > 9:
70 | break
71 | val_x, val_y = batch
72 | val_x = val_x.to(device)
73 | with torch.no_grad():
74 | pred_y = model(val_x)
75 |
76 | orig_img, val_y = untransform(val_x.cpu().data[0], val_y[0])
77 | pred_y = pred_y.max(1)[1].cpu().data[0].numpy()
78 | pred_val = cm[pred_y]
79 | seg_val = cm[val_y]
80 |
81 | # matplotlib
82 | fig = plt.figure(figsize=(9, 3))
83 | plt.subplot(131)
84 | plt.imshow(orig_img)
85 | plt.axis("off")
86 |
87 | plt.subplot(132)
88 | plt.imshow(seg_val)
89 | plt.axis("off")
90 |
91 | plt.subplot(133)
92 | plt.imshow(pred_val)
93 | plt.axis("off")
94 | writer.add_figure('show_result', fig, i)
95 |
96 | evaluator.run(val_loader)
97 | writer.close()
98 |
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
13 | from ignite.handlers import ModelCheckpoint, Timer
14 | from ignite.metrics import Loss, RunningAverage
15 | from tensorboardX import SummaryWriter
16 |
17 | from data.transforms import build_untransform
18 | from data.transforms.transforms import COLORMAP
19 | from utils.metric import Label_Accuracy
20 |
21 | plt.switch_backend('agg')
22 |
23 |
24 | def do_train(
25 | cfg,
26 | model,
27 | train_loader,
28 | val_loader,
29 | optimizer,
30 | loss_fn
31 | ):
32 | cm = np.array(COLORMAP).astype(np.uint8)
33 | untransform = build_untransform(cfg)
34 |
35 | log_period = cfg.SOLVER.LOG_PERIOD
36 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
37 | epochs = cfg.SOLVER.MAX_EPOCHS
38 | device = cfg.MODEL.DEVICE
39 | output_dir = cfg.OUTPUT_DIR
40 |
41 | logger = logging.getLogger("FCN_Model.train")
42 | logger.info("Start training")
43 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
44 | evaluator = create_supervised_evaluator(model, metrics={'mean_iu': Label_Accuracy(cfg.MODEL.NUM_CLASSES),
45 | 'loss': Loss(loss_fn)}, device=device)
46 | checkpointer = ModelCheckpoint(output_dir, 'fcn', checkpoint_period, n_saved=10, require_empty=False)
47 | timer = Timer(average=True)
48 | writer = SummaryWriter(output_dir + '/board')
49 |
50 | # automatically adding handlers via a special `attach` method of `RunningAverage` handler
51 | RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')
52 |
53 |     # pass the live objects so ModelCheckpoint snapshots their state at save time
54 |     trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
55 |                               {'model': model, 'optimizer': optimizer})
56 |
57 | # automatically adding handlers via a special `attach` method of `Timer` handler
58 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
59 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
60 |
61 | # adding handlers using `trainer.on` decorator API
62 | @trainer.on(Events.ITERATION_COMPLETED)
63 | def log_training_loss(engine):
64 |         iteration = (engine.state.iteration - 1) % len(train_loader) + 1
65 | 
66 |         if iteration % log_period == 0:
67 |             logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}"
68 |                         .format(engine.state.epoch, iteration, len(train_loader), engine.state.metrics['avg_loss']))
69 | writer.add_scalars("loss", {'train': engine.state.metrics['avg_loss']}, engine.state.iteration)
70 |
71 | # adding handlers using `trainer.on` decorator API
72 | @trainer.on(Events.EPOCH_COMPLETED)
73 | def log_training_results(engine):
74 | evaluator.run(train_loader)
75 | metrics = evaluator.state.metrics
76 | mean_iu = metrics['mean_iu']
77 | avg_loss = metrics['loss']
78 | logger.info("Training Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}"
79 | .format(engine.state.epoch, mean_iu, avg_loss))
80 | writer.add_scalars("mean_iu", {'train': mean_iu}, engine.state.epoch)
81 |
82 | if val_loader is not None:
83 | # adding handlers using `trainer.on` decorator API
84 | @trainer.on(Events.EPOCH_COMPLETED)
85 | def log_validation_results(engine):
86 | evaluator.run(val_loader)
87 | metrics = evaluator.state.metrics
88 | mean_iu = metrics['mean_iu']
89 | avg_loss = metrics['loss']
90 | logger.info("Validation Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}"
91 | .format(engine.state.epoch, mean_iu, avg_loss)
92 | )
93 | writer.add_scalars("loss", {'validation': avg_loss}, engine.state.iteration)
94 | writer.add_scalars("mean_iu", {'validation': mean_iu}, engine.state.epoch)
95 |
96 | # adding handlers using `trainer.on` decorator API
97 | @trainer.on(Events.EPOCH_COMPLETED)
98 | def print_times(engine):
99 |         logger.info('Epoch {} done. Avg time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
100 |                     .format(engine.state.epoch, timer.value(),
101 |                             train_loader.batch_size / timer.value()))
102 | timer.reset()
103 |
104 | @trainer.on(Events.EPOCH_COMPLETED)
105 | def plot_output(engine):
106 | model.eval()
107 | dataset = val_loader.dataset
108 | idx = np.random.choice(np.arange(len(dataset)), size=1).item()
109 | val_x, val_y = dataset[idx]
110 | val_x = val_x.to(device)
111 | with torch.no_grad():
112 | pred_y = model(val_x.unsqueeze(0))
113 |
114 | orig_img, val_y = untransform(val_x.cpu().data, val_y)
115 | pred_y = pred_y.max(1)[1].cpu().data[0].numpy()
116 | pred_val = cm[pred_y]
117 | seg_val = cm[val_y]
118 |
119 | # matplotlib
120 | fig = plt.figure(figsize=(9, 3))
121 | plt.subplot(131)
122 | plt.imshow(orig_img)
123 | plt.axis("off")
124 |
125 | plt.subplot(132)
126 | plt.imshow(seg_val)
127 | plt.axis("off")
128 |
129 | plt.subplot(133)
130 | plt.imshow(pred_val)
131 | plt.axis("off")
132 | writer.add_figure('show_result', fig, engine.state.iteration)
133 |
134 | trainer.run(train_loader, max_epochs=epochs)
135 | writer.close()
136 |
--------------------------------------------------------------------------------
/get_data.sh:
--------------------------------------------------------------------------------
1 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2 |
3 | if [ ! -e ./dataset ]; then
4 | mkdir ./dataset
5 | fi
6 |
7 | tar -xf VOCtrainval_11-May-2012.tar -C ./dataset
8 | rm VOCtrainval_11-May-2012.tar
--------------------------------------------------------------------------------
/layers/bilinear_upsample.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 |
12 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
13 | """
14 |     Make a 2D bilinear kernel suitable for upsampling
15 | """
16 | factor = (kernel_size + 1) // 2
17 | if kernel_size % 2 == 1:
18 | center = factor - 1
19 | else:
20 | center = factor - 0.5
21 | og = np.ogrid[:kernel_size, :kernel_size]
22 | bilinear_filter = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
23 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)
24 | weight[range(in_channels), range(out_channels), :, :] = bilinear_filter
25 | return torch.from_numpy(weight).float()
26 |
27 |
28 | def bilinear_upsampling(in_channels, out_channels, kernel_size, stride, bias=False):
29 | initial_weight = get_upsampling_weight(in_channels, out_channels, kernel_size)
30 | layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, bias=bias)
31 | layer.weight.data.copy_(initial_weight)
32 | # weight is frozen because it's just a bilinear upsampling
33 | layer.weight.requires_grad = False
34 | return layer
35 |
--------------------------------------------------------------------------------
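
A quick sanity check of the kernel above (a sketch, not part of the repo): the transposed convolution returned by `bilinear_upsampling` roughly doubles the spatial size and keeps a constant image exactly constant away from the borders, as true bilinear interpolation would:

```python
import torch

from layers.bilinear_upsample import bilinear_upsampling

up = bilinear_upsampling(1, 1, kernel_size=4, stride=2)
x = torch.ones(1, 1, 8, 8)
with torch.no_grad():
    y = up(x)

print(y.shape)                                 # torch.Size([1, 1, 18, 18])
print(bool((y[0, 0, 2:-2, 2:-2] == 1).all()))  # True: interior stays constant
```
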
/layers/conv_layer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from torch import nn
7 |
8 |
9 | def conv_layer(in_channels, out_channles, kernel_size, stride=1, padding=0, bias=True):
10 | layer = nn.Conv2d(in_channels, out_channles, kernel_size, stride, padding, bias=bias)
11 | layer.weight.data.zero_()
12 | if bias:
13 | layer.bias.data.zero_()
14 | return layer
15 |
--------------------------------------------------------------------------------
/layers/cross_entropy2d.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch.nn.functional as F
7 |
8 |
9 | def cross_entropy2d(input, target, weight=None, size_average=True):
10 | # input: (n, c, h, w), target: (n, h, w)
11 | n, c, h, w = input.size()
12 | # log_p: (n, c, h, w)
13 | log_p = F.log_softmax(input, dim=1)
14 | # log_p: (n*h*w, c)
15 | log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
16 | log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
17 | log_p = log_p.view(-1, c)
18 | # target: (n*h*w,)
19 | mask = target >= 0
20 | target = target[mask]
21 | loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
22 | if size_average:
23 | loss /= mask.data.sum()
24 | return loss
25 |
--------------------------------------------------------------------------------
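
A hedged usage sketch with random logits; the -1 entries play the role of the boundary pixels that the data transforms map from 255:

```python
import torch

from layers.cross_entropy2d import cross_entropy2d

logits = torch.randn(2, 21, 8, 8, requires_grad=True)  # (n, c, h, w) scores
target = torch.randint(0, 21, (2, 8, 8))               # (n, h, w) class ids
target[:, 0, :] = -1                                   # mark ignored pixels

loss = cross_entropy2d(logits, target)  # averaged over the non-ignored pixels
loss.backward()
print(loss.item())
```
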
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 |
9 | from .backbone.vgg import pretrained_vgg
10 | from .fcn16s import FCN16s
11 | from .fcn32s import FCN32s
12 | from .fcn8s import FCN8s
13 |
14 | _FCN_META_ARCHITECTURE = {'fcn32s': FCN32s,
15 | 'fcn16s': FCN16s,
16 | 'fcn8s': FCN8s}
17 |
18 |
19 | def build_fcn_model(cfg):
20 | meta_arch = _FCN_META_ARCHITECTURE[cfg.MODEL.META_ARCHITECTURE]
21 | model = meta_arch(cfg)
22 | if cfg.MODEL.BACKBONE.PRETRAINED:
23 | vgg16 = pretrained_vgg(cfg)
24 | model.copy_params_from_vgg16(vgg16)
25 | if cfg.MODEL.REFINEMENT.NAME == 'fcn32s':
26 | fcn32s = FCN32s(cfg)
27 | fcn32s.load_state_dict(torch.load(cfg.MODEL.REFINEMENT.WEIGHT))
28 | model.copy_params_from_fcn32s(fcn32s)
29 | elif cfg.MODEL.REFINEMENT.NAME == 'fcn16s':
30 | fcn16s = FCN16s(cfg)
31 | fcn16s.load_state_dict(torch.load(cfg.MODEL.REFINEMENT.WEIGHT))
32 | model.copy_params_from_fcn16s(fcn16s)
33 | return model
34 |
--------------------------------------------------------------------------------
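
A minimal sketch of the factory with the default config (it selects `fcn32s`, and since `PRETRAINED` defaults to `False` no weight file is needed):

```python
import torch

from config import cfg
from modeling import build_fcn_model

model = build_fcn_model(cfg).eval()  # cfg.MODEL.META_ARCHITECTURE == 'fcn32s'
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    scores = model(x)
print(scores.shape)  # torch.Size([1, 21, 224, 224]): per-pixel class scores
```
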
/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from .vgg import VGG16
7 |
8 |
9 | def build_backbone(cfg):
10 |     if cfg.MODEL.BACKBONE.NAME == 'vgg16':
11 |         return VGG16()
12 |     raise ValueError("Unknown backbone: {}".format(cfg.MODEL.BACKBONE.NAME))
13 |
--------------------------------------------------------------------------------
/modeling/backbone/vgg.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch
7 | import torchvision
8 | from torch import nn
9 |
10 |
11 | class VGG16(nn.Module):
12 | def __init__(self):
13 | super(VGG16, self).__init__()
14 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
15 | self.relu1_1 = nn.ReLU(inplace=True)
16 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
17 | self.relu1_2 = nn.ReLU(inplace=True)
18 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2
19 |
20 | # conv2
21 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
22 | self.relu2_1 = nn.ReLU(inplace=True)
23 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
24 | self.relu2_2 = nn.ReLU(inplace=True)
25 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4
26 |
27 | # conv3
28 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
29 | self.relu3_1 = nn.ReLU(inplace=True)
30 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
31 | self.relu3_2 = nn.ReLU(inplace=True)
32 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
33 | self.relu3_3 = nn.ReLU(inplace=True)
34 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8
35 |
36 | # conv4
37 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
38 | self.relu4_1 = nn.ReLU(inplace=True)
39 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
40 | self.relu4_2 = nn.ReLU(inplace=True)
41 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
42 | self.relu4_3 = nn.ReLU(inplace=True)
43 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16
44 |
45 | # conv5
46 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
47 | self.relu5_1 = nn.ReLU(inplace=True)
48 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
49 | self.relu5_2 = nn.ReLU(inplace=True)
50 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
51 | self.relu5_3 = nn.ReLU(inplace=True)
52 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32
53 |
54 | def forward(self, x):
55 | x = self.relu1_1(self.conv1_1(x))
56 | x = self.relu1_2(self.conv1_2(x))
57 | x = self.pool1(x)
58 |
59 | x = self.relu2_1(self.conv2_1(x))
60 | x = self.relu2_2(self.conv2_2(x))
61 | x = self.pool2(x)
62 |
63 | x = self.relu3_1(self.conv3_1(x))
64 | x = self.relu3_2(self.conv3_2(x))
65 | x = self.relu3_3(self.conv3_3(x))
66 | x = self.pool3(x)
67 |
68 | x = self.relu4_1(self.conv4_1(x))
69 | x = self.relu4_2(self.conv4_2(x))
70 | x = self.relu4_3(self.conv4_3(x))
71 | x = self.pool4(x)
72 |
73 | x = self.relu5_1(self.conv5_1(x))
74 | x = self.relu5_2(self.conv5_2(x))
75 | x = self.relu5_3(self.conv5_3(x))
76 | x = self.pool5(x)
77 | return x
78 |
79 |
80 | def pretrained_vgg(cfg):
81 | model = torchvision.models.vgg16(pretrained=False)
82 | model.load_state_dict(torch.load(cfg.MODEL.BACKBONE.WEIGHT))
83 | return model
84 |
--------------------------------------------------------------------------------
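
A small sketch of what the unusual `padding=100` in `conv1_1` buys: it inflates the feature map so the FCN heads can later crop back to the exact input size, which is why the `pool5` output is larger than the naive `H/32`:

```python
import torch

from modeling.backbone.vgg import VGG16

backbone = VGG16().eval()
with torch.no_grad():
    feat = backbone(torch.randn(1, 3, 224, 224))
print(feat.shape)  # torch.Size([1, 512, 14, 14]), not 224 / 32 == 7, due to padding=100
```
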
/modeling/fcn16s.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch import nn
8 |
9 | from layers.bilinear_upsample import bilinear_upsampling
10 | from layers.conv_layer import conv_layer
11 | from .backbone import build_backbone
12 |
13 |
14 | class FCN16s(nn.Module):
15 | def __init__(self, cfg):
16 | super(FCN16s, self).__init__()
17 | self.backbone = build_backbone(cfg)
18 | num_classes = cfg.MODEL.NUM_CLASSES
19 |
20 | # fc1
21 | self.fc1 = conv_layer(512, 4096, 7)
22 | self.relu1 = nn.ReLU(inplace=True)
23 | self.drop1 = nn.Dropout2d()
24 |
25 | # fc2
26 | self.fc2 = conv_layer(4096, 4096, 1)
27 | self.relu2 = nn.ReLU(inplace=True)
28 | self.drop2 = nn.Dropout2d()
29 |
30 | self.score_fr = conv_layer(4096, num_classes, 1)
31 | self.score_pool4 = conv_layer(512, num_classes, 1)
32 |
33 | self.upscore2 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False)
34 | self.upscore16 = bilinear_upsampling(num_classes, num_classes, 32, stride=16, bias=False)
35 |
36 | def forward(self, x):
37 | _, _, h, w = x.size()
38 | x = self.backbone.conv1_1(x)
39 | x = self.backbone.relu1_1(x)
40 | x = self.backbone.conv1_2(x)
41 | x = self.backbone.relu1_2(x)
42 | x = self.backbone.pool1(x)
43 |
44 | x = self.backbone.conv2_1(x)
45 | x = self.backbone.relu2_1(x)
46 | x = self.backbone.conv2_2(x)
47 | x = self.backbone.relu2_2(x)
48 | x = self.backbone.pool2(x)
49 |
50 | x = self.backbone.conv3_1(x)
51 | x = self.backbone.relu3_1(x)
52 | x = self.backbone.conv3_2(x)
53 | x = self.backbone.relu3_2(x)
54 | x = self.backbone.conv3_3(x)
55 | x = self.backbone.relu3_3(x)
56 | x = self.backbone.pool3(x)
57 |
58 | x = self.backbone.conv4_1(x)
59 | x = self.backbone.relu4_1(x)
60 | x = self.backbone.conv4_2(x)
61 | x = self.backbone.relu4_2(x)
62 | x = self.backbone.conv4_3(x)
63 | x = self.backbone.relu4_3(x)
64 | x = self.backbone.pool4(x)
65 | pool4 = x # 1/16
66 |
67 | x = self.backbone.conv5_1(x)
68 | x = self.backbone.relu5_1(x)
69 | x = self.backbone.conv5_2(x)
70 | x = self.backbone.relu5_2(x)
71 | x = self.backbone.conv5_3(x)
72 | x = self.backbone.relu5_3(x)
73 | x = self.backbone.pool5(x)
74 |
75 | x = self.relu1(self.fc1(x))
76 | x = self.drop1(x)
77 |
78 | x = self.relu2(self.fc2(x))
79 | x = self.drop2(x)
80 |
81 | x = self.score_fr(x)
82 | x = self.upscore2(x)
83 | upscore2 = x
84 |
85 | x = self.score_pool4(pool4)
86 | x = x[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
87 | score_pool4c = x # 1/16
88 |
89 | x = upscore2 + score_pool4c
90 |
91 | x = self.upscore16(x)
92 | x = x[:, :, 27:27 + h, 27:27 + w].contiguous()
93 | return x
94 |
95 | def copy_params_from_fcn32s(self, fcn32s):
96 | # load backbone
97 | self.backbone.load_state_dict(fcn32s.backbone.state_dict())
98 | for name, l1 in fcn32s.named_children():
99 | try:
100 | l2 = getattr(self, name)
101 | l2.weight # skip ReLU / Dropout
102 | except AttributeError:
103 | continue
104 | assert l1.weight.size() == l2.weight.size()
105 | l2.weight.data.copy_(l1.weight.data)
106 | if l1.bias is not None:
107 | assert l1.bias.size() == l2.bias.size()
108 | l2.bias.data.copy_(l1.bias.data)
109 |
110 |
111 |
--------------------------------------------------------------------------------
/modeling/fcn32s.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch import nn
8 |
9 | from layers.bilinear_upsample import bilinear_upsampling
10 | from layers.conv_layer import conv_layer
11 | from .backbone import build_backbone
12 |
13 |
14 | class FCN32s(nn.Module):
15 | def __init__(self, cfg):
16 | super(FCN32s, self).__init__()
17 | self.backbone = build_backbone(cfg)
18 | num_classes = cfg.MODEL.NUM_CLASSES
19 |
20 | self.fc1 = conv_layer(512, 4096, 7)
21 | self.relu1 = nn.ReLU(inplace=True)
22 | self.drop1 = nn.Dropout2d()
23 |
24 | self.fc2 = conv_layer(4096, 4096, 1)
25 | self.relu2 = nn.ReLU(inplace=True)
26 | self.drop2 = nn.Dropout2d()
27 |
28 | self.score_fr = conv_layer(4096, num_classes, 1)
29 | self.upscore = bilinear_upsampling(num_classes, num_classes, 64, stride=32,
30 | bias=False)
31 |
32 | def forward(self, x):
33 | _, _, h, w = x.size()
34 | x = self.backbone(x)
35 | x = self.relu1(self.fc1(x))
36 | x = self.drop1(x)
37 |
38 | x = self.relu2(self.fc2(x))
39 | x = self.drop2(x)
40 |
41 | x = self.score_fr(x)
42 | x = self.upscore(x)
43 | x = x[:, :, 19:19 + h, 19:19 + w].contiguous()
44 | return x
45 |
46 | def copy_params_from_vgg16(self, vgg16):
47 | feat = self.backbone
48 | features = [
49 | feat.conv1_1, feat.relu1_1,
50 | feat.conv1_2, feat.relu1_2,
51 | feat.pool1,
52 | feat.conv2_1, feat.relu2_1,
53 | feat.conv2_2, feat.relu2_2,
54 | feat.pool2,
55 | feat.conv3_1, feat.relu3_1,
56 | feat.conv3_2, feat.relu3_2,
57 | feat.conv3_3, feat.relu3_3,
58 | feat.pool3,
59 | feat.conv4_1, feat.relu4_1,
60 | feat.conv4_2, feat.relu4_2,
61 | feat.conv4_3, feat.relu4_3,
62 | feat.pool4,
63 | feat.conv5_1, feat.relu5_1,
64 | feat.conv5_2, feat.relu5_2,
65 | feat.conv5_3, feat.relu5_3,
66 | feat.pool5
67 | ]
68 |
69 | for l1, l2 in zip(vgg16.features, features):
70 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
71 | assert l1.weight.size() == l2.weight.size()
72 | assert l1.bias.size() == l2.bias.size()
73 | l2.weight.data.copy_(l1.weight.data)
74 | l2.bias.data.copy_(l1.bias.data)
75 | for i, name in zip([0, 3], ['fc1', 'fc2']):
76 | l1 = vgg16.classifier[i]
77 | l2 = getattr(self, name)
78 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
79 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))
80 |
--------------------------------------------------------------------------------
/modeling/fcn8s.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch import nn
8 |
9 | from layers.bilinear_upsample import bilinear_upsampling
10 | from layers.conv_layer import conv_layer
11 | from .backbone import build_backbone
12 |
13 |
14 | class FCN8s(nn.Module):
15 | def __init__(self, cfg):
16 | super(FCN8s, self).__init__()
17 | self.backbone = build_backbone(cfg)
18 | num_classes = cfg.MODEL.NUM_CLASSES
19 |
20 | # fc1
21 | self.fc1 = conv_layer(512, 4096, 7)
22 | self.relu1 = nn.ReLU(inplace=True)
23 | self.drop1 = nn.Dropout2d()
24 |
25 | # fc2
26 | self.fc2 = conv_layer(4096, 4096, 1)
27 | self.relu2 = nn.ReLU(inplace=True)
28 | self.drop2 = nn.Dropout2d()
29 |
30 | self.score_fr = conv_layer(4096, num_classes, 1)
31 | self.score_pool3 = conv_layer(256, num_classes, 1)
32 | self.score_pool4 = conv_layer(512, num_classes, 1)
33 |
34 | self.upscore2 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False)
35 | self.upscore8 = bilinear_upsampling(num_classes, num_classes, 16, stride=8, bias=False)
36 | self.upscore_pool4 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False)
37 |
38 | def forward(self, x):
39 | _, _, h, w = x.size()
40 | x = self.backbone.conv1_1(x)
41 | x = self.backbone.relu1_1(x)
42 | x = self.backbone.conv1_2(x)
43 | x = self.backbone.relu1_2(x)
44 | x = self.backbone.pool1(x)
45 |
46 | x = self.backbone.conv2_1(x)
47 | x = self.backbone.relu2_1(x)
48 | x = self.backbone.conv2_2(x)
49 | x = self.backbone.relu2_2(x)
50 | x = self.backbone.pool2(x)
51 |
52 | x = self.backbone.conv3_1(x)
53 | x = self.backbone.relu3_1(x)
54 | x = self.backbone.conv3_2(x)
55 | x = self.backbone.relu3_2(x)
56 | x = self.backbone.conv3_3(x)
57 | x = self.backbone.relu3_3(x)
58 | x = self.backbone.pool3(x)
59 | pool3 = x # 1/8
60 |
61 | x = self.backbone.conv4_1(x)
62 | x = self.backbone.relu4_1(x)
63 | x = self.backbone.conv4_2(x)
64 | x = self.backbone.relu4_2(x)
65 | x = self.backbone.conv4_3(x)
66 | x = self.backbone.relu4_3(x)
67 | x = self.backbone.pool4(x)
68 | pool4 = x # 1/16
69 |
70 | x = self.backbone.conv5_1(x)
71 | x = self.backbone.relu5_1(x)
72 | x = self.backbone.conv5_2(x)
73 | x = self.backbone.relu5_2(x)
74 | x = self.backbone.conv5_3(x)
75 | x = self.backbone.relu5_3(x)
76 | x = self.backbone.pool5(x)
77 |
78 | x = self.relu1(self.fc1(x))
79 | x = self.drop1(x)
80 |
81 | x = self.relu2(self.fc2(x))
82 | x = self.drop2(x)
83 |
84 | x = self.score_fr(x)
85 | x = self.upscore2(x)
86 | upscore2 = x # 1/16
87 |
88 | x = self.score_pool4(pool4)
89 | x = x[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
90 | score_pool4c = x # 1/16
91 |
92 | x = upscore2 + score_pool4c
93 | x = self.upscore_pool4(x)
94 | upscore_pool4 = x # 1/8
95 |
96 | x = self.score_pool3(pool3)
97 | x = x[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]].contiguous()
98 | score_pool3c = x # 1/8
99 |
100 | x = upscore_pool4 + score_pool3c # 1/8
101 |
102 | x = self.upscore8(x)
103 | x = x[:, :, 31:31 + h, 31:31 + w].contiguous()
104 | return x
105 |
106 | def copy_params_from_fcn16s(self, fcn16s):
107 | self.backbone.load_state_dict(fcn16s.backbone.state_dict())
108 | for name, l1 in fcn16s.named_children():
109 | try:
110 | l2 = getattr(self, name)
111 | l2.weight # skip ReLU / Dropout
112 | except AttributeError:
113 | continue
114 | assert l1.weight.size() == l2.weight.size()
115 | l2.weight.data.copy_(l1.weight.data)
116 | if l1.bias is not None:
117 | assert l1.bias.size() == l2.bias.size()
118 | l2.bias.data.copy_(l1.bias.data)
119 |
120 | def copy_params_from_vgg16(self, vgg16):
121 | feat = self.backbone
122 | features = [
123 | feat.conv1_1, feat.relu1_1,
124 | feat.conv1_2, feat.relu1_2,
125 | feat.pool1,
126 | feat.conv2_1, feat.relu2_1,
127 | feat.conv2_2, feat.relu2_2,
128 | feat.pool2,
129 | feat.conv3_1, feat.relu3_1,
130 | feat.conv3_2, feat.relu3_2,
131 | feat.conv3_3, feat.relu3_3,
132 | feat.pool3,
133 | feat.conv4_1, feat.relu4_1,
134 | feat.conv4_2, feat.relu4_2,
135 | feat.conv4_3, feat.relu4_3,
136 | feat.pool4,
137 | feat.conv5_1, feat.relu5_1,
138 | feat.conv5_2, feat.relu5_2,
139 | feat.conv5_3, feat.relu5_3,
140 | feat.pool5
141 | ]
142 |
143 | for l1, l2 in zip(vgg16.features, features):
144 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
145 | assert l1.weight.size() == l2.weight.size()
146 | assert l1.bias.size() == l2.bias.size()
147 | l2.weight.data.copy_(l1.weight.data)
148 | l2.bias.data.copy_(l1.bias.data)
149 | for i, name in zip([0, 3], ['fc1', 'fc2']):
150 | l1 = vgg16.classifier[i]
151 | l2 = getattr(self, name)
152 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
153 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))
154 |
--------------------------------------------------------------------------------
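
A hedged shape check: the `padding=100` backbone plus the fixed crop offsets above (5, 9, 31) return scores at the exact input resolution, even for non-square inputs whose sides are not multiples of 32:

```python
import torch

from config import cfg
from modeling.fcn8s import FCN8s

model = FCN8s(cfg).eval()
x = torch.randn(1, 3, 180, 300)  # deliberately awkward size
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 21, 180, 300])
```
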
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_optimizer
8 |
--------------------------------------------------------------------------------
/solver/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 |
9 |
10 | def make_optimizer(cfg, model):
11 | params = []
12 | for key, value in model.named_parameters():
13 | if not value.requires_grad:
14 | continue
15 | lr = cfg.SOLVER.BASE_LR
16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
17 | if "bias" in key:
18 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
22 | optimizer = torch.optim.SGD(params, lr, momentum=cfg.SOLVER.MOMENTUM)
23 | return optimizer
24 |
--------------------------------------------------------------------------------
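
A sketch of the resulting parameter groups (repository root on `sys.path` assumed): biases get `BASE_LR * BIAS_LR_FACTOR` with zero weight decay, and the frozen bilinear upsampling weights are skipped entirely:

```python
from config import cfg
from modeling import build_fcn_model
from solver import make_optimizer

model = build_fcn_model(cfg)
optimizer = make_optimizer(cfg, model)

lrs = sorted({group['lr'] for group in optimizer.param_groups})
print(lrs)  # [0.0001, 0.0002]: weight groups vs. bias groups
```
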
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import sys
8 | import unittest
9 |
10 | sys.path.append('.')
11 | from config import cfg
12 | from data.transforms import build_transforms
13 | from data.build import build_dataset
14 | from solver.build import make_optimizer
15 |
16 |
17 | class TestDataSet(unittest.TestCase):
18 | def test_dataset(self):
19 | train_transform = build_transforms(cfg, True)
20 | val_transform = build_transforms(cfg, False)
21 | train_set = build_dataset(cfg, train_transform, True)
22 |         val_set = build_dataset(cfg, val_transform, False)
23 |         self.assertGreater(len(train_set), 0)
24 |         self.assertGreater(len(val_set), 0)
25 |
26 |
27 | if __name__ == '__main__':
28 | unittest.main()
29 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import unittest
3 |
4 | sys.path.append('.')
5 | from modeling.backbone.vgg import VGG16
6 | from config import cfg
7 | from modeling import build_fcn_model
8 | from modeling.backbone import build_backbone
9 | import torch
10 |
11 |
12 | class MyTestCase(unittest.TestCase):
13 | def test_vgg(self):
14 | vgg = build_backbone(cfg)
15 | model = build_fcn_model(cfg)
16 |         print(model.backbone.conv1_1.weight[0, 0, 0, 0])
17 |         x = torch.randn(1, 3, 224, 224)
18 |         y = model(x)
19 |         self.assertEqual(tuple(y.shape),
20 |                          (1, cfg.MODEL.NUM_CLASSES, 224, 224))
21 |
22 |
23 | if __name__ == '__main__':
24 | unittest.main()
25 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/tools/test_fcn.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | from os import mkdir
11 |
12 | import torch
13 |
14 | sys.path.append('.')
15 | from config import cfg
16 | from data import make_data_loader
17 | from engine.inference import inference
18 | from modeling import build_fcn_model
19 | from utils.logger import setup_logger
20 |
21 |
22 | def main():
23 | parser = argparse.ArgumentParser(description="PyTorch FCN Inference")
24 | parser.add_argument(
25 | "--config_file", default="", help="path to config file", type=str
26 | )
27 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
28 | nargs=argparse.REMAINDER)
29 |
30 | args = parser.parse_args()
31 |
32 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
33 |
34 | if args.config_file != "":
35 | cfg.merge_from_file(args.config_file)
36 | cfg.merge_from_list(args.opts)
37 | cfg.freeze()
38 |
39 | output_dir = cfg.OUTPUT_DIR
40 | if output_dir and not os.path.exists(output_dir):
41 | mkdir(output_dir)
42 |
43 | logger = setup_logger("FCN_Model", output_dir, 0)
44 | logger.info("Using {} GPUS".format(num_gpus))
45 | logger.info(args)
46 |
47 | if args.config_file != "":
48 | logger.info("Loaded configuration file {}".format(args.config_file))
49 | with open(args.config_file, 'r') as cf:
50 | config_str = "\n" + cf.read()
51 | logger.info(config_str)
52 | logger.info("Running with config:\n{}".format(cfg))
53 |
54 | model = build_fcn_model(cfg)
55 | model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
56 | val_loader = make_data_loader(cfg, is_train=False)
57 |
58 | inference(cfg, model, val_loader)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/tools/train_fcn.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | from os import mkdir
11 |
12 | sys.path.append('.')
13 | from config import cfg
14 | from data import make_data_loader
15 | from engine.trainer import do_train
16 | from modeling import build_fcn_model
17 | from solver import make_optimizer
18 | from utils.logger import setup_logger
19 | from layers.cross_entropy2d import cross_entropy2d
20 |
21 |
22 | def train(cfg):
23 | model = build_fcn_model(cfg)
24 |
25 | optimizer = make_optimizer(cfg, model)
26 |
27 | arguments = {}
28 |
29 | data_loader = make_data_loader(cfg, is_train=True)
30 | val_loader = make_data_loader(cfg, is_train=False)
31 |
32 | do_train(
33 | cfg,
34 | model,
35 | data_loader,
36 | val_loader,
37 | optimizer,
38 | cross_entropy2d,
39 | )
40 |
41 |
42 | def main():
43 | parser = argparse.ArgumentParser(description="PyTorch FCN Training")
44 | parser.add_argument(
45 | "--config_file", default="", help="path to config file", type=str
46 | )
47 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
48 | nargs=argparse.REMAINDER)
49 |
50 | args = parser.parse_args()
51 |
52 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
53 |
54 | if args.config_file != "":
55 | cfg.merge_from_file(args.config_file)
56 | cfg.merge_from_list(args.opts)
57 | cfg.freeze()
58 |
59 | output_dir = cfg.OUTPUT_DIR
60 | if output_dir and not os.path.exists(output_dir):
61 | mkdir(output_dir)
62 |
63 | logger = setup_logger("FCN_Model", output_dir, 0)
64 | logger.info("Using {} GPUS".format(num_gpus))
65 | logger.info(args)
66 |
67 | if args.config_file != "":
68 | logger.info("Loaded configuration file {}".format(args.config_file))
69 | with open(args.config_file, 'r') as cf:
70 | config_str = "\n" + cf.read()
71 | logger.info(config_str)
72 | logger.info("Running with config:\n{}".format(cfg))
73 |
74 | train(cfg)
75 |
76 |
77 | if __name__ == '__main__':
78 | main()
79 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 |
12 | def setup_logger(name, save_dir, distributed_rank):
13 | logger = logging.getLogger(name)
14 | logger.setLevel(logging.DEBUG)
15 | # don't log results for the non-master process
16 | if distributed_rank > 0:
17 | return logger
18 | ch = logging.StreamHandler(stream=sys.stdout)
19 | ch.setLevel(logging.DEBUG)
20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
21 | ch.setFormatter(formatter)
22 | logger.addHandler(ch)
23 |
24 | if save_dir:
25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import numpy as np
7 | from ignite.metrics import Metric
8 |
9 |
10 | def _fast_hist(label_true, label_pred, n_class):
11 | mask = (label_true >= 0) & (label_true < n_class)
12 | hist = np.bincount(
13 | n_class * label_true[mask].astype(int) +
14 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
15 | return hist
16 |
17 |
18 | class Label_Accuracy(Metric):
19 | """
20 |     Calculates the mean intersection-over-union (mean IU).
21 | 
22 |     - `update` must receive output of the form `(y_pred, y)`.
23 |     - `y_pred` must be in the following shape (batch_size, num_categories, ...)
24 | - `y` must be in the following shape (batch_size, ...)
25 | """
26 |
27 | def __init__(self, n_class):
28 | super(Label_Accuracy, self).__init__()
29 | self.n_class = n_class
30 |
31 | def reset(self):
32 | self.step = 0
33 | self.mean_iu = 0
34 |
35 | def update(self, output):
36 | label_preds, label_trues = output
37 | label_preds = label_preds.max(dim=1)[1].data.cpu().numpy()
38 | label_preds = [i for i in label_preds]
39 |
40 | label_trues = label_trues.data.cpu().numpy()
41 | label_trues = [i for i in label_trues]
42 |
43 | hist = np.zeros((self.n_class, self.n_class))
44 | for lt, lp in zip(label_trues, label_preds):
45 | hist += _fast_hist(lt.flatten(), lp.flatten(), self.n_class)
46 | with np.errstate(divide='ignore', invalid='ignore'):
47 | iu = np.diag(hist) / (
48 | hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
49 | )
50 | mean_iu = np.nanmean(iu)
51 | self.mean_iu += mean_iu
52 | self.step += 1
53 |
54 | def compute(self):
55 | return self.mean_iu / self.step
56 |
--------------------------------------------------------------------------------
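
A toy sanity check (a sketch, not repo code): one-hot "logits" built from the ground truth must give a mean IU of exactly 1.0:

```python
import torch
import torch.nn.functional as F

from utils.metric import Label_Accuracy

metric = Label_Accuracy(n_class=21)
metric.reset()  # normally called by ignite when the metric is attached

y = torch.randint(0, 21, (2, 8, 8))
y_pred = F.one_hot(y, num_classes=21).permute(0, 3, 1, 2).float()  # perfect scores

metric.update((y_pred, y))
print(metric.compute())  # 1.0
```
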