├── README.md
├── builders
│   ├── dataset_builder.py
│   ├── loss_builder.py
│   ├── model_builder.py
│   └── validation_builder.py
├── dataset
│   ├── README.md
│   ├── calculate_class_weight.py
│   ├── cityscapes
│   │   ├── Percent.png
│   │   ├── cityscape_scripts
│   │   │   ├── __init__.py
│   │   │   ├── download_cityscapes.sh
│   │   │   └── process_cityscapes.py
│   │   ├── cityscapes.py
│   │   ├── cityscapes_test_list.txt
│   │   ├── cityscapes_train_list.txt
│   │   ├── cityscapes_trainval_list.txt
│   │   ├── cityscapes_val_list.txt
│   │   └── class_map.csv
│   ├── generate_txt.py
│   └── inform
│       └── cityscapes_inform.pkl
├── example
│   ├── aachen_000000_000019_gtFine_color.png
│   ├── aachen_000000_000019_gtFine_labelTrainIds.png
│   ├── aachen_000000_000019_leftImg8bit.png
│   ├── average_results.png
│   ├── class_results1.png
│   ├── class_results2.png
│   ├── class_results3.png
│   ├── lindau_000000_000019_leftImg8bit.png
│   ├── lindau_000000_000019_leftImg8bit_color.png
│   └── lindau_000000_000019_leftImg8bit_gt.png
├── model
│   ├── BiSeNet.py
│   ├── BiSeNetV2.py
│   ├── DDRNet.py
│   ├── DeeplabV3Plus.py
│   ├── FCN8s.py
│   ├── FCN_ResNet.py
│   ├── HRNet.py
│   ├── PSPNet
│   │   ├── psanet.py
│   │   ├── pspnet.py
│   │   └── resnet.py
│   ├── SegNet.py
│   ├── UNet.py
│   ├── base_model
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   └── xception.py
│   └── sync_batchnorm
│       ├── __init__.py
│       ├── batchnorm.py
│       ├── batchnorm_reimpl.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── predict.py
├── predict.sh
├── requirements.txt
├── train.py
├── train.sh
└── utils
    ├── colorize_mask.py
    ├── distributed.py
    ├── earlyStopping.py
    ├── flops_counter
    │   ├── CHANGELOG.md
    │   ├── Flops_test.py
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── ptflops
    │   │   ├── __init__.py
    │   │   └── flops_counter.py
    │   └── setup.py
    ├── fps_test
    │   └── eval_forward_time.py
    ├── image_transform.py
    ├── losses
    │   ├── __init__.py
    │   ├── loss.py
    │   └── lovasz_losses.py
    ├── metric
    │   ├── SegmentationMetric.py
    │   └── __init__.py
    ├── optim
    │   ├── AdamW.py
    │   ├── Lookahead.py
    │   ├── RAdam.py
    │   ├── Ranger.py
    │   └── __init__.py
    ├── plot_log.py
    ├── record_log.py
    ├── scheduler
    │   ├── __init__.py
    │   └── lr_scheduler.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ## :rocket: If it helps you, click a star! :star: ##
2 | ## Update log
3 | - 2020.12.10 Project structure adjusted; the previous code was deleted and will be re-uploaded after the restructuring
4 | - 2021.04.09 Code re-uploaded, "V1 Commit"
5 | - 2021.04.22 Updated torch distributed training
6 | - Ongoing updates .....
7 |
8 | # 1. Display (Cityscapes)
9 | - Using the DDRNet model on the 1525 test-set images, the official mIoU = 78.4069%
10 |
11 |
12 |
13 | (Result screenshots: average results, class results 1, 2 and 3)
17 |
18 |
19 |
20 | - Comparison of the original and predicted images
21 |
22 |
23 | (Images left to right: origin, label, predict)
26 |
27 |
28 |
29 | # 2. Install
30 | ```pip install -r requirements.txt```
31 | 
32 | Experimental environment:
33 | - Ubuntu 16.04, Nvidia GPUs >= 1
34 | - python==3.6.5
35 | - See requirements.txt for the full list of dependencies
35 |
36 | # 3. Model
37 | All models are constructed in `builders/model_builder.py` (a usage sketch follows the table below).
38 | - [x] FCN
39 | - [x] FCN_ResNet
40 | - [x] SegNet
41 | - [x] UNet
42 | - [x] BiSeNet
43 | - [x] BiSeNetV2
44 | - [x] PSPNet
45 | - [x] DeepLabv3_plus
46 | - [x] HRNet
47 | - [x] DDRNet
48 |
49 | | Model| Backbone| Val mIoU | Test mIoU | Imagenet Pretrain| Pretrained Model |
50 | | :--- | :---: |:---: |:---:|:---:|:---:|
51 | | PSPNet | ResNet 50 | 76.54% | - | √ | [PSPNet](https://drive.google.com/file/d/10T321s62xDZQJUR3k0H-l64smYW0QAxN/view?usp=sharing) |
52 | | DeeplabV3+ | ResNet 50 | 77.78% | - | √ | [DeeplabV3+](https://drive.google.com/file/d/1xP7HQwFcXAPuoL_BCYdghOBnEJIxNE-T/view?usp=sharing) |
53 | | DDRNet23_slim | - | | | [DDRNet23_slim_imagenet](https://drive.google.com/file/d/1mg5tMX7TJ9ZVcAiGSB4PEihPtrJyalB4/view) | |
54 | | DDRNet23 | - | | | [DDRNet23_imagenet](https://drive.google.com/file/d/1VoUsERBeuCaiuQJufu8PqpKKtGvCTdug/view) | |
55 | | DDRNet39 | - | 79.63% | - | [DDRNet39_imagenet](https://drive.google.com/file/d/122CMx6DZBaRRf-dOHYwuDY9vG0_UQ10i/view) | [DDRNet39](https://drive.google.com/file/d/1-poQsQzXqGl2d2ILXRhWgQH452MUTX5y/view?usp=sharing) |
56 | More models are being added .......
57 |
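A minimal usage sketch of the model factory in `builders/model_builder.py` (the model name and arguments below are just an example; any of the names handled by `build_model` can be used):

```
# Minimal sketch: build a supported model through builders/model_builder.py.
# 19 classes corresponds to Cityscapes; run this from the repository root.
from builders.model_builder import build_model

model = build_model('UNet', num_classes=19)
print(sum(p.numel() for p in model.parameters()))  # quick parameter-count check
```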
58 | # 4. Data preprocessing
59 | This project currently supports these datasets: `Cityscapes`, `ISPRS`
60 | The remaining dataset preparation steps will be uploaded later .....
61 | **Cityscapes dataset preparation is shown here:**
62 |
63 | ## 4.1 Download the dataset
64 | Download the dataset from the official website. The original images, with the `*leftImg8bit.png` suffix, are under the folder `leftImg8bit`;
65 | the fine-labeled images, with the suffixes `a) *color.png`, `b) *labelIds.png` and `c) *instanceIds.png`, are under the folder `gtFine`.
66 | ```
67 | *leftImg8bit.png : the original image
68 | a) *color.png : the class is encoded by its color
69 | b) *labelIds.png : the class is encoded by its ID
70 | c) *instanceIds.png : the class and the instance are encoded by an instance ID
71 | ```
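Because the gtFine names mirror the leftImg8bit names, a label path can be derived from an image path with simple string substitutions; a small sketch of the same replacements that `dataset/generate_txt.py` performs:

```
# Derive the trainIds label path from an image path
# (the same substitution dataset/generate_txt.py uses).
img = 'leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png'
label = img.replace('leftImg8bit', 'gtFine').replace('.png', '_labelTrainIds.png')
print(label)  # gtFine/train/aachen/aachen_000000_000019_gtFine_labelTrainIds.png
```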
72 | ## 4.2 Onehot encoding of label image
73 | The semantic segmentation task trains on gray-scale label images whose class IDs range from 0 to 18, so the raw labels need to be re-encoded.
74 | Use the script `dataset/cityscapes/cityscape_scripts/process_cityscapes.py`
75 | to process the images and obtain the `*labelTrainIds.png` results.
76 | `process_cityscapes.py` usage: modify line 486 so that `Cityscapes_path` is the path where your own data is stored.
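For intuition only, a minimal sketch of the id-to-trainId re-encoding the script performs; the mapping follows the official Cityscapes label definitions, and the real preprocessing should still be done with `process_cityscapes.py`:

```
# Sketch of the labelIds -> trainIds re-encoding (official Cityscapes mapping);
# ids not listed below are mapped to the ignore label 255.
import numpy as np
from PIL import Image

ID_TO_TRAINID = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
                 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

def encode_label(label_ids_png, out_png, ignore_label=255):
    label = np.array(Image.open(label_ids_png), dtype=np.uint8)
    train_ids = np.full_like(label, ignore_label)
    for label_id, train_id in ID_TO_TRAINID.items():
        train_ids[label == label_id] = train_id
    Image.fromarray(train_ids).save(out_png)
```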
77 |
78 | - Comparison of original image, color label image and gray label image (0-18)
79 |
80 |
81 | (Images left to right: ***_leftImg8bit, ***_gtFine_color, ***_gtFine_labelTrainIds)
84 |
85 |
86 |
87 | - Local storage layout under `/data/open_data/cityscapes/`:
88 | ```
89 | data
90 | |--open_data
91 | |--cityscapes
92 | |--leftImg8bit
93 | |--train
94 | |--cologne
95 | |--*******
96 | |--val
97 | |--*******
98 | |--test
99 | |--*******
100 | |--gtFine
101 | |--train
102 | |--cologne
103 | |--*******
104 | |--val
105 | |--*******
106 | |--test
107 | |--*******
108 | ```
109 |
110 | ## 4.3 Generate image path
111 | - Generate txt files containing the image paths
112 | Use the script `dataset/generate_txt.py` to generate the `txt` files that list the original images and their labels.
113 | A total of 3 `txt` files will be generated: `cityscapes_train_list.txt`, `cityscapes_val_list.txt` and
114 | `cityscapes_test_list.txt`; copy the three files to the dataset root directory.
115 | ```
116 | data
117 | |--open_data
118 | |--cityscapes
119 | |--cityscapes_train_list.txt
120 | |--cityscapes_val_list.txt
121 | |--cityscapes_test_list.txt
122 | |--leftImg8bit
123 | |--train
124 | |--cologne
125 | |--*******
126 | |--val
127 | |--*******
128 | |--test
129 | |--*******
130 | |--gtFine
131 | |--train
132 | |--cologne
133 | |--*******
134 | |--val
135 | |--*******
136 | |--test
137 | |--*******
138 | ```
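To produce all three lists, the helper in `dataset/generate_txt.py` can be called once per split; a small sketch (the dataset root is hard-coded inside the script, and its `__main__` only runs the `train` split):

```
# Sketch: generate the train/val/test list files, assuming the repository root
# is the working directory and the root_dir inside generate_txt.py has been set.
from dataset.generate_txt import generate_txt

for mode in ('train', 'val', 'test'):
    generate_txt(mode)
```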
139 |
140 | - The contents of the `txt` files are shown as follows:
141 | ```
142 | leftImg8bit/train/cologne/cologne_000000_000019_leftImg8bit.png gtFine/train/cologne/cologne_000000_000019_gtFine_labelTrainIds.png
143 | leftImg8bit/train/cologne/cologne_000001_000019_leftImg8bit.png gtFine/train/cologne/cologne_000001_000019_gtFine_labelTrainIds.png
144 | ..............
145 | ```
146 |
147 | - The format of each line in the `txt` files is as follows:
148 | ```
149 | original image path + the separator '\t' + label path + the separator '\n'
150 | ```
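Reading such a list back is straightforward; a minimal sketch that splits each line on whitespace, mirroring the loaders in `dataset/cityscapes/cityscapes.py`:

```
# Sketch: parse cityscapes_train_list.txt into (image, label) path pairs.
import os

root = '/data/open_data/cityscapes'  # dataset root used throughout this README
pairs = []
with open(os.path.join(root, 'cityscapes_train_list.txt')) as f:
    for line in f:
        if not line.strip():
            continue
        img_rel, label_rel = line.split()
        pairs.append((os.path.join(root, img_rel), os.path.join(root, label_rel)))
print(len(pairs))
```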
151 |
152 |
153 | # TODO.....
154 | # 5. How to train
155 | ```
156 | sh train.sh
157 | ```
158 | ## 5.1 Parameters
159 | ```
160 | python -m torch.distributed.launch --nproc_per_node=2 \
161 | train.py --model PSPNet_res50 --out_stride 8 \
162 | --max_epochs 200 --val_epochs 20 --batch_size 4 --lr 0.01 --optim sgd --loss ProbOhemCrossEntropy2d \
163 | --base_size 768 --crop_size 768 --tile_hw_size 768,768 \
164 | --root '/data/open_data' --dataset cityscapes --gpus_id 1,2
165 | ```
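As a concrete example of how these flags interact, `builders/loss_builder.py` derives the OHEM pixel budget `min_kept` for `ProbOhemCrossEntropy2d` from `--batch_size`, `--gpus_id` and `--base_size`; with the values from the command above:

```
# min_kept as computed in builders/loss_builder.py for ProbOhemCrossEntropy2d,
# using batch_size=4, two GPU ids, and base_size=768 from the command above.
batch_size, num_gpus, base_size = 4, 2, 768
h = w = base_size
min_kept = int(batch_size // num_gpus * h * w // 16)
print(min_kept)  # 73728 hardest pixels are kept per batch
```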
166 | # 6. How to validate
167 | ```
168 | sh predict.sh
169 | ```
--------------------------------------------------------------------------------
/builders/dataset_builder.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 17:02
4 | Author: Ding Cheng(Deeachain)
5 | File: dataset_builder.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import os
10 | import pickle
11 | import pandas as pd
12 | from dataset.cityscapes.cityscapes import CityscapesTrainDataSet, CityscapesTrainInform, CityscapesValDataSet, \
13 | CityscapesTestDataSet
14 |
15 | def build_dataset_train(root, dataset, base_size, crop_size):
16 | data_dir = os.path.join(root, dataset)
17 | train_data_list = os.path.join(data_dir, dataset + '_' + 'train_list.txt')
18 | inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')
19 |
20 |     # inform_data_file collects the mean, std and class-weight information
21 | if not os.path.isfile(inform_data_file):
22 | print("%s is not found" % (inform_data_file))
23 | if dataset == "cityscapes":
24 | dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list,
25 | inform_data_file=inform_data_file)
26 | else:
27 | raise NotImplementedError(
28 | "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)
29 |
30 | datas = dataCollect.collectDataAndSave()
31 | if datas is None:
32 | print("error while pickling data. Please check.")
33 | exit(-1)
34 | else:
35 | datas = pickle.load(open(inform_data_file, "rb"))
36 |
37 | if dataset == "cityscapes":
38 | TrainDataSet = CityscapesTrainDataSet(data_dir, train_data_list, base_size=base_size, crop_size=crop_size,
39 | mean=datas['mean'], std=datas['std'], ignore_label=255)
40 | return datas, TrainDataSet
41 |
42 |
43 | def build_dataset_test(root, dataset, crop_size, mode='whole', gt=False):
44 | data_dir = os.path.join(root, dataset)
45 | inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')
46 | train_data_list = os.path.join(data_dir, dataset + '_train_list.txt')
47 | if mode == 'whole':
48 | test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
49 | else:
50 | test_data_list = os.path.join(data_dir, dataset + '_test_sliding' + '_list.txt')
51 |
52 |     # inform_data_file collects the mean, std and class-weight information
53 | if not os.path.isfile(inform_data_file):
54 | print("%s is not found" % (inform_data_file))
55 | if dataset == "cityscapes":
56 | dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list,
57 | inform_data_file=inform_data_file)
58 | else:
59 | raise NotImplementedError(
60 | "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)
61 |
62 | datas = dataCollect.collectDataAndSave()
63 | if datas is None:
64 | print("error while pickling data. Please check.")
65 | exit(-1)
66 | else:
67 | datas = pickle.load(open(inform_data_file, "rb"))
68 |
69 | class_dict_df = pd.read_csv(os.path.join('./dataset', dataset, 'class_map.csv'))
70 | if dataset == "cityscapes":
71 | # for cityscapes, if test on validation set, set none_gt to False
72 | # if test on the test set, set none_gt to True
73 | if gt:
74 | test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
75 | testdataset = CityscapesValDataSet(data_dir, test_data_list, crop_size=crop_size, mean=datas['mean'],
76 | std=datas['std'], ignore_label=255)
77 | else:
78 | test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
79 | testdataset = CityscapesTestDataSet(data_dir, test_data_list, crop_size=crop_size, mean=datas['mean'],
80 | std=datas['std'], ignore_label=255)
81 | return testdataset, class_dict_df
82 |
--------------------------------------------------------------------------------
/builders/loss_builder.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 17:02
4 | Author: Ding Cheng(Deeachain)
5 | File: loss_builder.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch
10 | from utils.losses.loss import LovaszSoftmax, CrossEntropyLoss2d, CrossEntropyLoss2dLabelSmooth, \
11 | ProbOhemCrossEntropy2d, FocalLoss2d, LabelSmoothing
12 |
13 |
14 | def build_loss(args, datas, ignore_label):
15 | if args.dataset == 'cityscapes':
16 | weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
17 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
18 | 1.0865, 1.1529, 1.0507])
19 | elif datas != None:
20 | weight = torch.from_numpy(datas['classWeights'])
21 | else:
22 | weight = None
23 |
24 |     # Default is the cross-entropy loss function
25 | criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
26 | if args.loss == 'ProbOhemCrossEntropy2d':
27 | h, w = args.base_size, args.base_size
28 | min_kept = int(args.batch_size // len(args.gpus_id) * h * w // 16)
29 | criteria = ProbOhemCrossEntropy2d(weight=weight, ignore_label=ignore_label, thresh=0.7, min_kept=min_kept)
30 | elif args.loss == 'CrossEntropyLoss2dLabelSmooth':
31 | criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=ignore_label)
32 | # criteria = LabelSmoothing()
33 | elif args.loss == 'LovaszSoftmax':
34 | criteria = LovaszSoftmax(ignore_index=ignore_label)
35 | elif args.loss == 'FocalLoss2d':
36 | criteria = FocalLoss2d(weight=weight, ignore_index=ignore_label)
37 |
38 | return criteria
39 |
--------------------------------------------------------------------------------
/builders/model_builder.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 17:02
4 | Author: Ding Cheng(Deeachain)
5 | File: model_builder.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch.nn as nn
10 | import torch.utils.model_zoo as model_zoo
11 | from collections import OrderedDict
12 | from model.UNet import UNet
13 | from model.SegNet import SegNet
14 | from model.FCN8s import FCN
15 | from model.BiSeNet import BiSeNet
16 | from model.BiSeNetV2 import BiSeNetV2
17 | from model.PSPNet.pspnet import PSPNet
18 | from model.DeeplabV3Plus import Deeplabv3plus_res50
19 | from model.FCN_ResNet import FCN_ResNet
20 | from model.DDRNet import DDRNet
21 | from model.HRNet import HighResolutionNet
22 |
23 | model_urls = {
24 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
25 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
26 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
27 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
28 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
29 | }
30 |
31 |
32 | def build_model(model_name, num_classes, backbone='resnet18', pretrained=False, out_stride=32, mult_grid=False):
33 | if model_name == 'FCN':
34 | model = FCN(num_classes=num_classes)
35 | elif model_name == 'FCN_ResNet':
36 | model = FCN_ResNet(num_classes=num_classes, backbone=backbone, out_stride=out_stride, mult_grid=mult_grid)
37 | elif model_name == 'SegNet':
38 | model = SegNet(classes=num_classes)
39 | elif model_name == 'UNet':
40 | model = UNet(num_classes=num_classes)
41 | elif model_name == 'BiSeNet':
42 | model = BiSeNet(num_classes=num_classes, backbone=backbone)
43 | elif model_name == 'BiSeNetV2':
44 | model = BiSeNetV2(num_classes=num_classes)
45 | elif model_name == 'HRNet':
46 | model = HighResolutionNet(num_classes=num_classes)
47 | elif model_name == 'Deeplabv3plus_res50':
48 | model = Deeplabv3plus_res50(num_classes=num_classes, os=out_stride, pretrained=True)
49 | elif model_name == "DDRNet":
50 | model = DDRNet(pretrained=True, num_classes=num_classes)
51 | elif model_name == 'PSPNet_res50':
52 | model = PSPNet(layers=50, bins=(1, 2, 3, 6), dropout=0.1, num_classes=num_classes, zoom_factor=8, use_ppm=True,
53 | pretrained=True)
54 | elif model_name == 'PSPNet_res101':
55 | model = PSPNet(layers=101, bins=(1, 2, 3, 6), dropout=0.1, num_classes=num_classes, zoom_factor=8, use_ppm=True,
56 | pretrained=True)
57 | # elif model_name == 'PSANet50':
58 | # return PSANet(layers=50, dropout=0.1, classes=num_classes, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
59 | # shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True, pretrained=True)
60 |
61 | if pretrained:
62 | checkpoint = model_zoo.load_url(model_urls[backbone])
63 | model_dict = model.state_dict()
64 | # print(model_dict)
65 | # Screen out layers that are not loaded
66 | pretrained_dict = {'backbone.' + k: v for k, v in checkpoint.items() if 'backbone.' + k in model_dict}
67 | # Update the structure dictionary for the current network
68 | model_dict.update(pretrained_dict)
69 | model.load_state_dict(model_dict)
70 |
71 | return model
72 |
73 |
74 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | # Supported datasets
2 |
3 | - CityScapes
4 |
5 | Note: When referring to the number of classes, the void/unlabeled class is excluded.
6 |
7 |
8 | ## Cityscapes
9 |
10 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated.
11 |
12 | The version supported here is the finely annotated one with 19 classes.
13 |
14 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts).
15 |
16 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data.
--------------------------------------------------------------------------------
/dataset/calculate_class_weight.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from tqdm import tqdm
4 | from glob import glob
5 |
6 |
7 | class DatasetsInformation:
8 | """ To get statistical information about the train set, such as mean, std, class distribution.
9 |         The class is employed to tackle class imbalance.
10 | """
11 |
12 | def __init__(self, classes=32, normVal=1.10):
13 | """
14 | Args:
15 |             classes: number of classes in the dataset
18 | normVal: normalization value, as defined in ERFNet paper
19 | """
20 | self.classes = classes
21 | self.classWeights = np.ones(self.classes, dtype=np.float32)
22 | self.normVal = normVal
23 | self.mean = np.zeros(3, dtype=np.float32)
24 | self.std = np.zeros(3, dtype=np.float32)
25 |
26 | def compute_class_weights(self, histogram):
27 | """to compute the class weights
28 | Args:
29 | histogram: distribution of class samples
30 | """
31 | normHist = histogram / np.sum(histogram)
32 |
33 | for i in range(self.classes):
34 | self.classWeights[i] = 0.1 / (np.log(self.normVal + normHist[i]))
35 | return histogram, normHist * 100, self.classWeights
36 |
37 | def compute_class_weight_with_media_freq(self, n_classes, labels):
38 | count = np.zeros(n_classes)
39 | image_count = np.zeros(n_classes)
40 | example = cv2.imread(labels[0])
41 | h, w, c = example.shape
42 |
43 |         for label in tqdm(labels, desc="Calculating Class Weight"):
44 | data = cv2.imdecode(np.fromfile(label, dtype=np.uint8), -1)
45 |
46 | for c in range(n_classes):
47 |                 c_sum = np.sum(data == c)  # count the pixels belonging to class c
48 | count[c] += c_sum
49 |                 if np.sum(data == c) != 0:  # if class c appears in this image, count this image for class c
50 | image_count[c] += 1
51 |
52 | frequency = count / (image_count * h * w)
53 | median = np.median(frequency)
54 | weight = median / frequency
55 |
56 | return frequency, weight
57 |
58 | def readWholeTrainSet(self, image, label, train_flag=True):
59 | """to read the whole train set of current dataset.
60 | Args:
61 |             image: list of image file paths
62 |             label: list of label file paths
63 |             train_flag: whether the training set is being processed
64 |         return: mean, std, class statistics and median-frequency statistics
65 | """
66 | global_hist = np.zeros(self.classes, dtype=np.float32)
67 |
68 | no_files = 0
69 | min_val_al = 0
70 | max_val_al = 0
71 |
72 | for index, label_file in tqdm(iterable=enumerate(label), desc='Calculate_ClassWeight', total=len(label)):
73 | img_file = image[index]
74 |
75 | label_img = cv2.imread(label_file, 0)
76 | unique_values = np.unique(label_img)
77 | max_val = max(unique_values)
78 | min_val = min(unique_values)
79 |
80 | max_val_al = max(max_val, max_val_al)
81 | min_val_al = min(min_val, min_val_al)
82 |
83 | if train_flag == True:
84 | hist = np.histogram(label_img, self.classes, range=(0, self.classes - 1))
85 | global_hist += hist[0]
86 |
87 | rgb_img = cv2.imread(img_file)
88 | self.mean[0] += np.mean(rgb_img[:, :, 0])
89 | self.mean[1] += np.mean(rgb_img[:, :, 1])
90 | self.mean[2] += np.mean(rgb_img[:, :, 2])
91 |
92 | self.std[0] += np.std(rgb_img[:, :, 0])
93 | self.std[1] += np.std(rgb_img[:, :, 1])
94 | self.std[2] += np.std(rgb_img[:, :, 2])
95 |
96 | else:
97 | print("we can only collect statistical information of train set, please check")
98 |
99 | if max_val > (self.classes - 1) or min_val < 0:
100 | print('Labels can take value between 0 and number of classes.')
101 | print('Some problem with labels. Please check. label_set:', unique_values)
102 | print('Label Image ID: ' + label_file)
103 | no_files += 1
104 |
105 | # divide the mean and std values by the sample space size
106 | self.mean /= no_files
107 | self.std /= no_files
108 |
109 | # compute the class imbalance information
110 | self.compute_class_weights(global_hist)
111 | self.compute_class_weight_with_media_freq(self.classes, label)
112 | return self.mean, self.std, self.compute_class_weights(global_hist), self.compute_class_weight_with_media_freq(
113 | self.classes, label)
114 |
115 |
116 | if __name__ == '__main__':
117 | image = glob('/media/ding/Data/datasets/paris/512_image_625/crop_files/512_image/*.png')
118 | label = glob('/media/ding/Data/datasets/paris/512_image_625/crop_files/512_label/*.png')
119 | # image = glob('/media/ding/Data/datasets/paris/paris_origin/*_image.png')
120 | # label = glob('/media/ding/Data/datasets/paris/paris_origin/*labels_gray.png')
121 | class_num = 3
122 |
123 | info = DatasetsInformation(classes=class_num)
124 | out = info.readWholeTrainSet(image=image, label=label, train_flag=True)
125 |
126 | np.set_printoptions(suppress=True)
127 |     print('Mean is \n', out[0])  # mean of the dataset images
128 |     print('\nStd is \n', out[1])  # std of the dataset images
129 |     print('\nPerClass Count is \n', np.array(out[2][0], dtype=int))  # pixel count per class; a list whose length equals the number of classes
130 |     print('\nPercentPerClass is (%)\n', out[2][1])  # percentage of pixels per class
131 |     print('\nCalcClassWeight is \n', out[2][2])  # class weights
132 |     print('\nMedian class Freq and weight', out[3][0], out[3][1])
133 |
134 | import matplotlib.pyplot as plt
135 |
136 | plt.figure(figsize=(10, 8))
137 | # plt.rcParams['font.sans-serif'] = ['SimSun']
138 | # plt.rcParams['axes.unicode_minus'] = False
139 | X = np.arange(0, class_num)
140 | Y = out[2][1]
141 | print(Y)
142 | plt.bar(x=X, height=Y, color="c", width=0.8)
143 |     plt.xticks(X, fontsize=20)  # label the class indices on the x axis
144 |     plt.yticks(fontsize=20)  # label the y axis
145 |     for x, y in zip(X, Y):
146 |         plt.text(x, y, '%.2f' % y, ha='center', va='bottom', fontsize=10)  # annotate each bar with its percentage
147 | plt.xlabel("Class", fontsize=12)
148 | plt.ylabel("Percent(%)", fontsize=12)
149 | plt.title("Category scale distribution", fontsize=16)
150 | plt.savefig('Paris Percent.png')
151 | plt.show()
152 |
--------------------------------------------------------------------------------
/dataset/cityscapes/Percent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/dataset/cityscapes/Percent.png
--------------------------------------------------------------------------------
/dataset/cityscapes/cityscape_scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/dataset/cityscapes/cityscape_scripts/__init__.py
--------------------------------------------------------------------------------
/dataset/cityscapes/cityscape_scripts/download_cityscapes.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | global_path='../../../vision_datasets'
3 | data_dir=$global_path'/cityscapes'
4 |
5 | mkdir -p $data_dir
6 | cd $data_dir
7 |
8 | # enter user details
9 | uname='' #
10 | pass=''
11 |
12 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username='$uname'&password='$pass'&submit=Login' https://www.cityscapes-dataset.com/login/
13 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
14 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3
15 | # Uncomment if you want to download coarse
16 | #wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=4
17 | #wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=2
18 |
19 |
20 | #unzip -q -o gtCoarse.zip
21 | unzip -q -o gtFine_trainvaltest.zip
22 | #unzip -q -o leftImg8bit_trainextra.zip
23 | unzip -q -o leftImg8bit_trainvaltest.zip
24 |
25 | #rm -rf gtCoarse.zip
26 | rm -rf gtFine_trainvaltest.zip
27 | #rm -rf leftImg8bit_trainextra.zip
28 | rm -rf leftImg8bit_trainvaltest.zip
--------------------------------------------------------------------------------
/dataset/cityscapes/cityscapes.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2021/3/12 10:28
4 | Author: Ding Cheng(Deeachain)
5 | File: cityscapes.py
6 | Github: https://github.com/Deeachain
7 | """
8 | import os.path as osp
9 | import numpy as np
10 | import cv2
11 | from torch.utils import data
12 | import pickle
13 | from PIL import Image
14 | from torchvision import transforms
15 | from utils import image_transform as tr
16 |
17 |
18 | class CityscapesTrainDataSet(data.Dataset):
19 | """
20 | CityscapesTrainDataSet is employed to load train set
21 | Args:
22 | root: the Cityscapes dataset path,
23 | cityscapes
24 | ├── gtFine
25 | ├── leftImg8bit
26 | list_path: cityscapes_train_list.txt, include partial path
27 | mean: bgr_mean (73.15835921, 82.90891754, 72.39239876)
28 |
29 | """
30 |
31 | def __init__(self, root='', list_path='', max_iters=None, base_size=513, crop_size=513, mean=(128, 128, 128),
32 | std=(128, 128, 128), ignore_label=255):
33 | self.root = root
34 | self.list_path = list_path
35 | self.base_size = base_size
36 | self.crop_size = crop_size
37 | self.mean = mean
38 | self.std = std
39 | self.ignore_label = ignore_label
40 | self.img_ids = [i_id.strip() for i_id in open(list_path)]
41 | if not max_iters == None:
42 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
43 | self.files = []
44 |
45 | for name in self.img_ids:
46 | img_file = osp.join(self.root, name.split()[0])
47 | label_file = osp.join(self.root, name.split()[1])
48 | name = name.strip().split()[0].strip().split('/', 3)[3].split('.')[0]
49 | self.files.append({"img": img_file, "label": label_file, "name": name})
50 |
51 | print("length of train dataset: ", len(self.files))
52 |
53 | def __len__(self):
54 | return len(self.files)
55 |
56 | def __getitem__(self, index):
57 | datafiles = self.files[index]
58 | image = Image.open(datafiles["img"]).convert('RGB')
59 | label = Image.open(datafiles["label"])
60 | size = np.asarray(image).shape
61 | name = datafiles["name"]
62 |
63 | composed_transforms = transforms.Compose([
64 | tr.RandomHorizontalFlip(),
65 | # tr.RandomRotate(180),
66 | tr.RandomScaleCrop(base_size=self.base_size, crop_size=self.crop_size, fill=255),
67 | tr.RandomGaussianBlur(),
68 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
69 | tr.ToTensor()])
70 | sample = {'image': image, 'label': label}
71 | sampled = composed_transforms(sample)
72 | image, label = sampled['image'], sampled['label']
73 | return image, label, np.array(size), name
74 |
75 |
76 | class CityscapesValDataSet(data.Dataset):
77 | """
78 | CityscapesDataSet is employed to load val set
79 | Args:
80 | root: the Cityscapes dataset path,
81 | cityscapes
82 | ├── gtFine
83 | ├── leftImg8bit
84 | list_path: cityscapes_val_list.txt, include partial path
85 |
86 | """
87 |
88 | def __init__(self, root='', list_path='', crop_size=513, mean=(128, 128, 128), std=(128, 128, 128),
89 | ignore_label=255):
90 | self.root = root
91 | self.list_path = list_path
92 | self.crop_size = crop_size
93 | self.mean = mean
94 | self.std = std
95 | self.ignore_label = ignore_label
96 | self.img_ids = [i_id.strip() for i_id in open(list_path)]
97 | self.files = []
98 | for name in self.img_ids:
99 | img_file = osp.join(self.root, name.split()[0])
100 | label_file = osp.join(self.root, name.split()[1])
101 | name = name.strip().split()[0].strip().split('/', 3)[3].split('.')[0]
102 | self.files.append({"img": img_file, "label": label_file, "name": name})
103 |
104 | print("length of validation dataset: ", len(self.files))
105 |
106 | def __len__(self):
107 | return len(self.files)
108 |
109 | def __getitem__(self, index):
110 | datafiles = self.files[index]
111 | image = Image.open(datafiles["img"]).convert('RGB')
112 | label = Image.open(datafiles["label"])
113 | size = np.asarray(image).shape
114 | name = datafiles["name"]
115 | composed_transforms = transforms.Compose([
116 | tr.FixScaleCrop(crop_size=self.crop_size),
117 | # tr.FixedResize(size=(1024,512)),
118 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
119 | tr.ToTensor()])
120 | sample = {'image': image, 'label': label}
121 | sampled = composed_transforms(sample)
122 | image, label = sampled['image'], sampled['label']
123 |
124 | return image, label, np.array(size), name
125 |
126 |
127 | class CityscapesTestDataSet(data.Dataset):
128 | """
129 | CityscapesDataSet is employed to load test set
130 | Args:
131 | root: the Cityscapes dataset path,
132 | list_path: cityscapes_test_list.txt, include partial path
133 |
134 | """
135 |
136 | def __init__(self, root='', list_path='', crop_size=513, mean=(128, 128, 128), std=(128, 128, 128),
137 | ignore_label=255):
138 | self.root = root
139 | self.list_path = list_path
140 | self.crop_size = crop_size
141 | self.mean = mean
142 | self.std = std
143 | self.ignore_label = ignore_label
144 | self.img_ids = [i_id.strip() for i_id in open(list_path)]
145 | self.files = []
146 | for name in self.img_ids:
147 | img_file = osp.join(self.root, name.split()[0])
148 | name = name.strip().split()[0].strip().split('/', 3)[3].split('.')[0]
149 | self.files.append({"img": img_file, "name": name})
150 |
151 | print("length of validation dataset: ", len(self.files))
152 |
153 | def __len__(self):
154 | return len(self.files)
155 |
156 | def __getitem__(self, index):
157 | datafiles = self.files[index]
158 | image = Image.open(datafiles["img"]).convert('RGB')
159 | size = np.asarray(image).shape
160 | name = datafiles["name"]
161 | composed_transforms = transforms.Compose([
162 | # tr.FixedResize(size=(1024,512)),
163 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
164 | tr.ToTensor()])
165 | sample = {'image': image}
166 | sampled = composed_transforms(sample)
167 | image = sampled['image']
168 |
169 | return image, np.array(size), name
170 |
171 |
172 | class CityscapesTrainInform:
173 | """ To get statistical information about the train set, such as mean, std, class distribution.
174 | The class is employed for tackle class imbalance.
175 | """
176 |
177 | def __init__(self, data_dir='', classes=19,
178 | train_set_file="", inform_data_file="", normVal=1.10):
179 | """
180 | Args:
181 | data_dir: directory where the dataset is kept
182 | classes: number of classes in the dataset
183 | inform_data_file: location where cached file has to be stored
184 | normVal: normalization value, as defined in ERFNet paper
185 | """
186 | self.data_dir = data_dir
187 | self.classes = classes
188 | self.classWeights = np.ones(self.classes, dtype=np.float32)
189 | self.normVal = normVal
190 | self.mean = np.zeros(3, dtype=np.float32)
191 | self.std = np.zeros(3, dtype=np.float32)
192 | self.train_set_file = train_set_file
193 | self.inform_data_file = inform_data_file
194 |
195 | def compute_class_weights(self, histogram):
196 | """to compute the class weights
197 | Args:
198 | histogram: distribution of class samples
199 | """
200 | normHist = histogram / np.sum(histogram)
201 | for i in range(self.classes):
202 | self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i]))
203 |
204 | def readWholeTrainSet(self, fileName, train_flag=True):
205 | """to read the whole train set of current dataset.
206 | Args:
207 | fileName: train set file that stores the image locations
208 |             train_flag: if processing training or validation data
209 |
210 | return: 0 if successful
211 | """
212 | global_hist = np.zeros(self.classes, dtype=np.float32)
213 |
214 | no_files = 0
215 | min_val_al = 0
216 | max_val_al = 0
217 | with open(self.data_dir + '/' + fileName, 'r') as textFile:
218 | for line in textFile:
219 |                 # we expect the text file to contain the data in the following format:
220 |                 # <image path>\t<label path>
221 | line_arr = line.split()
222 | img_file = ((self.data_dir).strip() + '/' + line_arr[0].strip()).strip()
223 | label_file = ((self.data_dir).strip() + '/' + line_arr[1].strip()).strip()
224 |
225 | label_img = cv2.imread(label_file, 0)
226 | unique_values = np.unique(label_img)
227 | max_val = max(unique_values)
228 | min_val = min(unique_values)
229 |
230 | max_val_al = max(max_val, max_val_al)
231 | min_val_al = min(min_val, min_val_al)
232 |
233 | if train_flag == True:
234 | hist = np.histogram(label_img, self.classes, range=(0, 18))
235 | global_hist += hist[0]
236 |
237 | rgb_img = cv2.imread(img_file)
238 | self.mean[0] += np.mean(rgb_img[:, :, 0])
239 | self.mean[1] += np.mean(rgb_img[:, :, 1])
240 | self.mean[2] += np.mean(rgb_img[:, :, 2])
241 |
242 | self.std[0] += np.std(rgb_img[:, :, 0])
243 | self.std[1] += np.std(rgb_img[:, :, 1])
244 | self.std[2] += np.std(rgb_img[:, :, 2])
245 |
246 | else:
247 | print("we can only collect statistical information of train set, please check")
248 |
249 | if max_val > (self.classes - 1) or min_val < 0:
250 | print('Labels can take value between 0 and number of classes.')
251 | print('Some problem with labels. Please check. label_set:', unique_values)
252 | print('Label Image ID: ' + label_file)
253 | no_files += 1
254 |
255 | # divide the mean and std values by the sample space size
256 | self.mean /= no_files * 255
257 | self.std /= no_files
258 |
259 | # compute the class imbalance information
260 | self.compute_class_weights(global_hist)
261 | return 0
262 |
263 | def collectDataAndSave(self):
264 | """ To collect statistical information of train set and then save it.
265 | The file train.txt should be inside the data directory.
266 | """
267 | print('Processing training data')
268 | return_val = self.readWholeTrainSet(fileName=self.train_set_file)
269 |
270 | print('Pickling data')
271 | if return_val == 0:
272 | data_dict = dict()
273 | data_dict['mean'] = self.mean
274 | data_dict['std'] = self.std
275 | data_dict['classWeights'] = self.classWeights
276 | pickle.dump(data_dict, open(self.inform_data_file, "wb"))
277 | return data_dict
278 | return None
279 |
--------------------------------------------------------------------------------
/dataset/cityscapes/class_map.csv:
--------------------------------------------------------------------------------
1 | label_index,class_name,rgb,gray
2 | 0,road,128 64 128,0
3 | 1,sidewalk,244 35 232,0
4 | 2,building,70 70 70,0
5 | 3,wall,102 102 156,0
6 | 4,fence,190 153 153,0
7 | 5,pole,153 153 153,0
8 | 6,traffic light,250 170 30,0
9 | 7,traffic sign,220 220 0,0
10 | 8,vegetation,107 142 35,0
11 | 9,terrain,152 251 152,0
12 | 10,sky,70 130 180,0
13 | 11,person,220 20 60,0
14 | 12,rider,255 0 0,0
15 | 13,car,0 0 142,0
16 | 14,truck,0 0 70,0
17 | 15,bus,0 60 100,0
18 | 16,train,0 80 100,0
19 | 17,motorcycle,0 0 230,0
20 | 18,bicycle,119 11 32,0
--------------------------------------------------------------------------------
/dataset/generate_txt.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 17:02
4 | Author: Ding Cheng(Deeachain)
5 | File: generate_txt.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import glob
10 | import random
11 | import os
12 |
13 |
14 | def generate_txt(mode='train'):
15 | root_dir = '/data/open_data/cityscapes/leftImg8bit/'
16 | dir_list = os.listdir(os.path.join(root_dir, mode))
17 | filename_list = []
18 | for dir in dir_list:
19 | filename = glob.glob(os.path.join(root_dir, mode, dir) + '/'+'*.png')
20 | filename.sort()
21 |
22 | random.shuffle(filename)
23 | filename_list.extend(filename)
24 | with open('cityscapes_{}_list.txt'.format(mode), 'w+') as f:
25 | for filename in filename_list[:]:
26 | filename_gt = filename.replace('leftImg8bit', 'gtFine').replace('.png', '_labelTrainIds.png')
27 | print(filename, filename_gt)
28 | f.write('{}/{}/{}/{}\t{}/{}/{}/{}\n'.format(filename.split('/')[-4], filename.split('/')[-3],
29 | filename.split('/')[-2], filename.split('/')[-1],
30 | filename_gt.split('/')[-4], filename_gt.split('/')[-3],
31 | filename_gt.split('/')[-2], filename_gt.split('/')[-1]))
32 |
33 |
34 | if __name__ == '__main__':
35 | generate_txt('train')
36 |     print('Finish!')
37 |
--------------------------------------------------------------------------------
/dataset/inform/cityscapes_inform.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/dataset/inform/cityscapes_inform.pkl
--------------------------------------------------------------------------------
/example/aachen_000000_000019_gtFine_color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/aachen_000000_000019_gtFine_color.png
--------------------------------------------------------------------------------
/example/aachen_000000_000019_gtFine_labelTrainIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/aachen_000000_000019_gtFine_labelTrainIds.png
--------------------------------------------------------------------------------
/example/aachen_000000_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/aachen_000000_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/example/average_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/average_results.png
--------------------------------------------------------------------------------
/example/class_results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/class_results1.png
--------------------------------------------------------------------------------
/example/class_results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/class_results2.png
--------------------------------------------------------------------------------
/example/class_results3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/class_results3.png
--------------------------------------------------------------------------------
/example/lindau_000000_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/lindau_000000_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/example/lindau_000000_000019_leftImg8bit_color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/lindau_000000_000019_leftImg8bit_color.png
--------------------------------------------------------------------------------
/example/lindau_000000_000019_leftImg8bit_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/example/lindau_000000_000019_leftImg8bit_gt.png
--------------------------------------------------------------------------------
/model/BiSeNet.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 19:27
4 | Author: Ding Cheng(Deeachain)
5 | File: BiSeNet.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchsummary import summary
13 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
14 | from model.base_model import build_backbone
15 |
16 |
17 | class ConvBnRelu(nn.Module):
18 | def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
19 | groups=1, has_bn=True, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
20 | has_relu=True, inplace=True, has_bias=False):
21 | super(ConvBnRelu, self).__init__()
22 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
23 | stride=stride, padding=pad,
24 | dilation=dilation, groups=groups, bias=has_bias)
25 | self.has_bn = has_bn
26 | if self.has_bn:
27 | self.bn = norm_layer(out_planes, eps=bn_eps)
28 | self.has_relu = has_relu
29 | if self.has_relu:
30 | self.relu = nn.ReLU(inplace=inplace)
31 |
32 | def forward(self, x):
33 | x = self.conv(x)
34 | if self.has_bn:
35 | x = self.bn(x)
36 | if self.has_relu:
37 | x = self.relu(x)
38 |
39 | return x
40 |
41 |
42 | class AttentionRefinement(nn.Module):
43 | def __init__(self, in_planes, out_planes,
44 | norm_layer=nn.BatchNorm2d):
45 | super(AttentionRefinement, self).__init__()
46 | self.conv_3x3 = ConvBnRelu(in_planes, out_planes, 3, 1, 1,
47 | has_bn=True, norm_layer=norm_layer,
48 | has_relu=True, has_bias=False)
49 | self.channel_attention = nn.Sequential(
50 | nn.AdaptiveAvgPool2d(1),
51 | ConvBnRelu(out_planes, out_planes, 1, 1, 0,
52 | has_bn=True, norm_layer=norm_layer,
53 | has_relu=False, has_bias=False),
54 | nn.Sigmoid()
55 | )
56 |
57 | def forward(self, x):
58 | fm = self.conv_3x3(x)
59 | fm_se = self.channel_attention(fm)
60 | fm = fm * fm_se
61 |
62 | return fm
63 |
64 |
65 | class FeatureFusion(nn.Module):
66 | def __init__(self, in_planes, out_planes,
67 | reduction=1, norm_layer=nn.BatchNorm2d):
68 | super(FeatureFusion, self).__init__()
69 | self.conv_1x1 = ConvBnRelu(in_planes, out_planes, 1, 1, 0,
70 | has_bn=True, norm_layer=norm_layer,
71 | has_relu=True, has_bias=False)
72 | self.channel_attention = nn.Sequential(
73 | nn.AdaptiveAvgPool2d(1),
74 | ConvBnRelu(out_planes, out_planes // reduction, 1, 1, 0,
75 | has_bn=False, norm_layer=norm_layer,
76 | has_relu=True, has_bias=False),
77 | ConvBnRelu(out_planes // reduction, out_planes, 1, 1, 0,
78 | has_bn=False, norm_layer=norm_layer,
79 | has_relu=False, has_bias=False),
80 | nn.Sigmoid()
81 | )
82 |
83 | def forward(self, x1, x2):
84 | fm = torch.cat([x1, x2], dim=1)
85 | fm = self.conv_1x1(fm)
86 | fm_se = self.channel_attention(fm)
87 | output = fm + fm * fm_se
88 | return output
89 |
90 |
91 | class BiSeNetHead(nn.Module):
92 | def __init__(self, in_planes, out_planes, is_aux=False, norm_layer=nn.BatchNorm2d):
93 | super(BiSeNetHead, self).__init__()
94 | if is_aux:
95 | self.conv_3x3 = ConvBnRelu(in_planes, 256, 3, 1, 1, has_bn=True, norm_layer=norm_layer, has_relu=True,
96 | has_bias=False)
97 | else:
98 | self.conv_3x3 = ConvBnRelu(in_planes, 64, 3, 1, 1, has_bn=True, norm_layer=norm_layer, has_relu=True,
99 | has_bias=False)
100 | if is_aux:
101 | self.conv_1x1 = nn.Conv2d(256, out_planes, kernel_size=1, stride=1, padding=0)
102 | else:
103 | self.conv_1x1 = nn.Conv2d(64, out_planes, kernel_size=1, stride=1, padding=0)
104 |
105 | def forward(self, x):
106 | fm = self.conv_3x3(x)
107 | output = self.conv_1x1(fm)
108 | return output
109 |
110 |
111 | class SpatialPath(nn.Module):
112 | def __init__(self, in_planes, out_planes, norm_layer=nn.BatchNorm2d):
113 | super(SpatialPath, self).__init__()
114 | inner_channel = 64
115 | self.conv_7x7 = ConvBnRelu(in_planes, inner_channel, 7, 2, 3,
116 | has_bn=True, norm_layer=norm_layer,
117 | has_relu=True, has_bias=False)
118 | self.conv_3x3_1 = ConvBnRelu(inner_channel, inner_channel, 3, 2, 1,
119 | has_bn=True, norm_layer=norm_layer,
120 | has_relu=True, has_bias=False)
121 | self.conv_3x3_2 = ConvBnRelu(inner_channel, inner_channel, 3, 2, 1,
122 | has_bn=True, norm_layer=norm_layer,
123 | has_relu=True, has_bias=False)
124 | self.conv_1x1 = ConvBnRelu(inner_channel, out_planes, 1, 1, 0,
125 | has_bn=True, norm_layer=norm_layer,
126 | has_relu=True, has_bias=False)
127 |
128 | def forward(self, x):
129 | x = self.conv_7x7(x)
130 | x = self.conv_3x3_1(x)
131 | x = self.conv_3x3_2(x)
132 | output = self.conv_1x1(x)
133 |
134 | return output
135 |
136 |
137 | class BiSeNet(nn.Module):
138 | def __init__(self, num_classes, norm_layer=nn.BatchNorm2d, backbone='resnet18'):
139 | super(BiSeNet, self).__init__()
140 | conv_channel = 128
141 | self.spatial_path = SpatialPath(3, conv_channel, norm_layer)
142 |
143 | if backbone == 'resnet18' or backbone == 'resnet34':
144 | expansion = 1
145 | elif backbone == 'resnet50' or backbone == 'resnet101':
146 | expansion = 4
147 |
148 | self.backbone = build_backbone(backbone)
149 |
150 | # resnet layers < 50 stage = [512, 256, 128, 64]; resnet layers > 50 stage = [2048, 1024, 512, 256]
151 | self.global_context = nn.Sequential(nn.AdaptiveAvgPool2d(1), ConvBnRelu(512 * expansion, conv_channel, 1, 1, 0,
152 | has_bn=True, has_relu=True,
153 | has_bias=False, norm_layer=norm_layer))
154 |
155 | self.arms1 = AttentionRefinement(512 * expansion, conv_channel, norm_layer)
156 | self.arms2 = AttentionRefinement(256 * expansion, conv_channel, norm_layer)
157 | self.refines = ConvBnRelu(conv_channel, conv_channel, 3, 1, 1,
158 | has_bn=True, norm_layer=norm_layer,
159 | has_relu=True, has_bias=False)
160 |
161 | self.heads1 = BiSeNetHead(conv_channel * 2, num_classes, False, norm_layer)
162 | self.heads2 = BiSeNetHead(conv_channel, num_classes, True, norm_layer)
163 | self.heads3 = BiSeNetHead(conv_channel, num_classes, True, norm_layer)
164 |
165 | self.ffm = FeatureFusion(conv_channel * 2, conv_channel * 2, 1, norm_layer)
166 |
167 | self._init_weight()
168 |
169 | def forward(self, x):
170 | size = x.shape
171 | spatial_out = self.spatial_path(x)
172 |
173 | context_blocks = self.backbone(x)
174 |
175 | global_context = self.global_context(context_blocks[-1]) # change channel
176 | global_context = F.interpolate(global_context, size=context_blocks[-1].size()[2:], mode='bilinear',
177 | align_corners=True)
178 | last_fm = global_context
179 |
180 | arm1 = self.arms1(context_blocks[-1])
181 | arm1 += last_fm
182 | arm1 = F.interpolate(arm1, size=(context_blocks[-2].size()[2:]), mode='bilinear', align_corners=True)
183 | arm1 = self.refines(arm1)
184 | last_fm = arm1
185 |
186 | arm2 = self.arms2(context_blocks[-2])
187 | arm2 += last_fm
188 | arm2 = F.interpolate(arm2, size=(context_blocks[-3].size()[2:]), mode='bilinear', align_corners=True)
189 | arm2 = self.refines(arm2)
190 | context_out = arm2
191 |
192 | concate_fm = self.ffm(spatial_out, context_out)
193 |
194 | main = self.heads1(concate_fm)
195 | aux_0 = self.heads2(arm2)
196 | aux_1 = self.heads3(arm1)
197 | main = F.interpolate(main, size=size[2:], mode='bilinear', align_corners=True)
198 | aux_0 = F.interpolate(aux_0, size=size[2:], mode='bilinear', align_corners=True)
199 | aux_1 = F.interpolate(aux_1, size=size[2:], mode='bilinear', align_corners=True)
200 |
201 | return main, aux_0, aux_1
202 |
203 | def _init_weight(self):
204 | for m in self.modules():
205 | if isinstance(m, nn.Conv2d):
206 | torch.nn.init.kaiming_normal_(m.weight)
207 | elif isinstance(m, SynchronizedBatchNorm2d):
208 | m.weight.data.fill_(1)
209 | m.bias.data.zero_()
210 | elif isinstance(m, nn.BatchNorm2d):
211 | m.weight.data.fill_(1)
212 | m.bias.data.zero_()
213 |
214 | def freeze_bn(self):
215 | for m in self.modules():
216 | if isinstance(m, SynchronizedBatchNorm2d):
217 | m.eval()
218 | elif isinstance(m, nn.BatchNorm2d):
219 | m.eval()
220 |
221 | def get_1x_lr_params(self):
222 | modules = [self.backbone]
223 | for i in range(len(modules)):
224 | for m in modules[i].named_modules():
225 | if self.freeze_bn:
226 | if isinstance(m[1], nn.Conv2d):
227 | for p in m[1].parameters():
228 | if p.requires_grad:
229 | yield p
230 | else:
231 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
232 | or isinstance(m[1], nn.BatchNorm2d):
233 | for p in m[1].parameters():
234 | if p.requires_grad:
235 | yield p
236 |
237 | def get_10x_lr_params(self):
238 | modules = [self.global_context, self.arms1, self.arms2, self.refines, self.heads1, self.heads2, self.heads3, self.ffm]
239 | for i in range(len(modules)):
240 | for m in modules[i].named_modules():
241 | if self.freeze_bn:
242 | if isinstance(m[1], nn.Conv2d):
243 | for p in m[1].parameters():
244 | if p.requires_grad:
245 | yield p
246 | else:
247 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
248 | or isinstance(m[1], nn.BatchNorm2d):
249 | for p in m[1].parameters():
250 | if p.requires_grad:
251 | yield p
252 |
253 |
254 | """print layers and params of network"""
255 | if __name__ == '__main__':
256 | model = BiSeNet(num_classes=3, backbone='resnet18')
257 | print(model)
258 | # summary(model, (3, 512, 512), device="cpu")
259 |
--------------------------------------------------------------------------------
/model/BiSeNetV2.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Reference from: https://github.com/CoinCheung/BiSeNet
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torchsummary import summary
9 |
10 |
11 | class ConvBNReLU(nn.Module):
12 |
13 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
14 | dilation=1, groups=1, bias=False):
15 | super(ConvBNReLU, self).__init__()
16 | self.conv = nn.Conv2d(
17 | in_chan, out_chan, kernel_size=ks, stride=stride,
18 | padding=padding, dilation=dilation,
19 | groups=groups, bias=bias)
20 | self.bn = nn.BatchNorm2d(out_chan)
21 | self.relu = nn.ReLU(inplace=True)
22 |
23 | def forward(self, x):
24 | feat = self.conv(x)
25 | feat = self.bn(feat)
26 | feat = self.relu(feat)
27 | return feat
28 |
29 |
30 | class DetailBranch(nn.Module):
31 |
32 | def __init__(self):
33 | super(DetailBranch, self).__init__()
34 | self.S1 = nn.Sequential(
35 | ConvBNReLU(3, 64, 3, stride=2),
36 | ConvBNReLU(64, 64, 3, stride=1),
37 | )
38 | self.S2 = nn.Sequential(
39 | ConvBNReLU(64, 64, 3, stride=2),
40 | ConvBNReLU(64, 64, 3, stride=1),
41 | ConvBNReLU(64, 64, 3, stride=1),
42 | )
43 | self.S3 = nn.Sequential(
44 | ConvBNReLU(64, 128, 3, stride=2),
45 | ConvBNReLU(128, 128, 3, stride=1),
46 | ConvBNReLU(128, 128, 3, stride=1),
47 | )
48 |
49 | def forward(self, x):
50 | feat = self.S1(x)
51 | feat = self.S2(feat)
52 | feat = self.S3(feat)
53 | return feat
54 |
55 |
56 | class StemBlock(nn.Module):
57 |
58 | def __init__(self):
59 | super(StemBlock, self).__init__()
60 | self.conv = ConvBNReLU(3, 16, 3, stride=2)
61 | self.left = nn.Sequential(
62 | ConvBNReLU(16, 8, 1, stride=1, padding=0),
63 | ConvBNReLU(8, 16, 3, stride=2),
64 | )
65 | self.right = nn.MaxPool2d(
66 | kernel_size=3, stride=2, padding=1, ceil_mode=False)
67 | self.fuse = ConvBNReLU(32, 16, 3, stride=1)
68 |
69 | def forward(self, x):
70 | feat = self.conv(x)
71 | feat_left = self.left(feat)
72 | feat_right = self.right(feat)
73 | feat = torch.cat([feat_left, feat_right], dim=1)
74 | feat = self.fuse(feat)
75 | return feat
76 |
77 |
78 | class CEBlock(nn.Module):
79 |
80 | def __init__(self):
81 | super(CEBlock, self).__init__()
82 | self.bn = nn.BatchNorm2d(128)
83 | self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
84 |         # TODO: in the paper this is a plain conv2d, with no bn-relu
85 | self.conv_last = ConvBNReLU(128, 128, 3, stride=1)
86 |
87 | def forward(self, x):
88 | feat = torch.mean(x, dim=(2, 3), keepdim=True)
89 | feat = self.bn(feat)
90 | feat = self.conv_gap(feat)
91 | feat = feat + x
92 | feat = self.conv_last(feat)
93 | return feat
94 |
95 |
96 | class GELayerS1(nn.Module):
97 |
98 | def __init__(self, in_chan, out_chan, exp_ratio=6):
99 | super(GELayerS1, self).__init__()
100 | mid_chan = in_chan * exp_ratio
101 | self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
102 | self.dwconv = nn.Sequential(
103 | nn.Conv2d(
104 | in_chan, mid_chan, kernel_size=3, stride=1,
105 | padding=1, groups=in_chan, bias=False),
106 | nn.BatchNorm2d(mid_chan),
107 | nn.ReLU(inplace=True), # not shown in paper
108 | )
109 | self.conv2 = nn.Sequential(
110 | nn.Conv2d(
111 | mid_chan, out_chan, kernel_size=1, stride=1,
112 | padding=0, bias=False),
113 | nn.BatchNorm2d(out_chan),
114 | )
115 | self.conv2[1].last_bn = True
116 | self.relu = nn.ReLU(inplace=True)
117 |
118 | def forward(self, x):
119 | feat = self.conv1(x)
120 | feat = self.dwconv(feat)
121 | feat = self.conv2(feat)
122 | feat = feat + x
123 | feat = self.relu(feat)
124 | return feat
125 |
126 |
127 | class GELayerS2(nn.Module):
128 |
129 | def __init__(self, in_chan, out_chan, exp_ratio=6):
130 | super(GELayerS2, self).__init__()
131 | mid_chan = in_chan * exp_ratio
132 | self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
133 | self.dwconv1 = nn.Sequential(
134 | nn.Conv2d(
135 | in_chan, mid_chan, kernel_size=3, stride=2,
136 | padding=1, groups=in_chan, bias=False),
137 | nn.BatchNorm2d(mid_chan),
138 | )
139 | self.dwconv2 = nn.Sequential(
140 | nn.Conv2d(
141 | mid_chan, mid_chan, kernel_size=3, stride=1,
142 | padding=1, groups=mid_chan, bias=False),
143 | nn.BatchNorm2d(mid_chan),
144 | nn.ReLU(inplace=True), # not shown in paper
145 | )
146 | self.conv2 = nn.Sequential(
147 | nn.Conv2d(
148 | mid_chan, out_chan, kernel_size=1, stride=1,
149 | padding=0, bias=False),
150 | nn.BatchNorm2d(out_chan),
151 | )
152 | self.conv2[1].last_bn = True
153 | self.shortcut = nn.Sequential(
154 | nn.Conv2d(
155 | in_chan, in_chan, kernel_size=3, stride=2,
156 | padding=1, groups=in_chan, bias=False),
157 | nn.BatchNorm2d(in_chan),
158 | nn.Conv2d(
159 | in_chan, out_chan, kernel_size=1, stride=1,
160 | padding=0, bias=False),
161 | nn.BatchNorm2d(out_chan),
162 | )
163 | self.relu = nn.ReLU(inplace=True)
164 |
165 | def forward(self, x):
166 | feat = self.conv1(x)
167 | feat = self.dwconv1(feat)
168 | feat = self.dwconv2(feat)
169 | feat = self.conv2(feat)
170 | shortcut = self.shortcut(x)
171 | feat = feat + shortcut
172 | feat = self.relu(feat)
173 | return feat
174 |
175 |
176 | class SegmentBranch(nn.Module):
177 |
178 | def __init__(self):
179 | super(SegmentBranch, self).__init__()
180 | self.S1S2 = StemBlock()
181 | self.S3 = nn.Sequential(
182 | GELayerS2(16, 32),
183 | GELayerS1(32, 32),
184 | )
185 | self.S4 = nn.Sequential(
186 | GELayerS2(32, 64),
187 | GELayerS1(64, 64),
188 | )
189 | self.S5_4 = nn.Sequential(
190 | GELayerS2(64, 128),
191 | GELayerS1(128, 128),
192 | GELayerS1(128, 128),
193 | GELayerS1(128, 128),
194 | )
195 | self.S5_5 = CEBlock()
196 |
197 | def forward(self, x):
198 | feat2 = self.S1S2(x)
199 | feat3 = self.S3(feat2)
200 | feat4 = self.S4(feat3)
201 | feat5_4 = self.S5_4(feat4)
202 | feat5_5 = self.S5_5(feat5_4)
203 | return feat2, feat3, feat4, feat5_4, feat5_5
204 |
205 |
206 | class BGALayer(nn.Module):
207 |
208 | def __init__(self):
209 | super(BGALayer, self).__init__()
210 | self.left1 = nn.Sequential(
211 | nn.Conv2d(
212 | 128, 128, kernel_size=3, stride=1,
213 | padding=1, groups=128, bias=False),
214 | nn.BatchNorm2d(128),
215 | nn.Conv2d(
216 | 128, 128, kernel_size=1, stride=1,
217 | padding=0, bias=False),
218 | )
219 | self.left2 = nn.Sequential(
220 | nn.Conv2d(
221 | 128, 128, kernel_size=3, stride=2,
222 | padding=1, bias=False),
223 | nn.BatchNorm2d(128),
224 | nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
225 | )
226 | self.right1 = nn.Sequential(
227 | nn.Conv2d(
228 | 128, 128, kernel_size=3, stride=1,
229 | padding=1, bias=False),
230 | nn.BatchNorm2d(128),
231 | )
232 | self.right2 = nn.Sequential(
233 | nn.Conv2d(
234 | 128, 128, kernel_size=3, stride=1,
235 | padding=1, groups=128, bias=False),
236 | nn.BatchNorm2d(128),
237 | nn.Conv2d(
238 | 128, 128, kernel_size=1, stride=1,
239 | padding=0, bias=False),
240 | )
241 |         ## TODO: does this really have no relu?
242 | self.conv = nn.Sequential(
243 | nn.Conv2d(
244 | 128, 128, kernel_size=3, stride=1,
245 | padding=1, bias=False),
246 | nn.BatchNorm2d(128),
247 | nn.ReLU(inplace=True), # not shown in paper
248 | )
249 |
250 | def forward(self, x_d, x_s):
251 | dsize = x_d.size()[2:]
252 | left1 = self.left1(x_d)
253 | left2 = self.left2(x_d)
254 | right1 = self.right1(x_s)
255 | right2 = self.right2(x_s)
256 | right1 = F.interpolate(
257 | right1, size=dsize, mode='bilinear', align_corners=True)
258 | left = left1 * torch.sigmoid(right1)
259 | right = left2 * torch.sigmoid(right2)
260 | right = F.interpolate(
261 | right, size=dsize, mode='bilinear', align_corners=True)
262 | out = self.conv(left + right)
263 | return out
264 |
265 |
266 |
267 | class SegmentHead(nn.Module):
268 |
269 | def __init__(self, in_chan, mid_chan, num_classes):
270 | super(SegmentHead, self).__init__()
271 | self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
272 | self.drop = nn.Dropout(0.1)
273 | self.conv_out = nn.Conv2d(
274 | mid_chan, num_classes, kernel_size=1, stride=1,
275 | padding=0, bias=True)
276 |
277 | def forward(self, x, size=None):
278 | feat = self.conv(x)
279 | feat = self.drop(feat)
280 | feat = self.conv_out(feat)
281 |         if size is not None:
282 | feat = F.interpolate(feat, size=size,
283 | mode='bilinear', align_corners=True)
284 | return feat
285 |
286 |
287 | class BiSeNetV2(nn.Module):
288 |
289 | def __init__(self, num_classes):
290 | super(BiSeNetV2, self).__init__()
291 | self.detail = DetailBranch()
292 | self.segment = SegmentBranch()
293 | self.bga = BGALayer()
294 |
295 | ## TODO: what is the number of mid chan ?
296 | self.head = SegmentHead(128, 1024, num_classes)
297 | self.aux2 = SegmentHead(16, 128, num_classes)
298 | self.aux3 = SegmentHead(32, 128, num_classes)
299 | self.aux4 = SegmentHead(64, 128, num_classes)
300 | self.aux5_4 = SegmentHead(128, 128, num_classes)
301 |
302 | # self.init_weights()
303 |
304 | def forward(self, x):
305 | size = x.size()[2:]
306 | feat_d = self.detail(x)
307 | feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
308 | feat_head = self.bga(feat_d, feat_s)
309 |
310 | logits = self.head(feat_head, size)
311 | logits_aux2 = self.aux2(feat2, size)
312 | logits_aux3 = self.aux3(feat3, size)
313 | logits_aux4 = self.aux4(feat4, size)
314 | logits_aux5_4 = self.aux5_4(feat5_4, size)
315 | return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
316 |
317 | # def init_weights(self):
318 | # for name, module in self.named_modules():
319 | # if isinstance(module, (nn.Conv2d, nn.Linear)):
320 | # nn.init.kaiming_normal_(module.weight, mode='fan_out')
321 | # if not module.bias is None: nn.init.constant_(module.bias, 0)
322 | # elif isinstance(module, nn.modules.batchnorm._BatchNorm):
323 | # if hasattr(module, 'last_bn') and module.last_bn:
324 | # nn.init.zeros_(module.weight)
325 | # else:
326 | # nn.init.ones_(module.weight)
327 | # nn.init.zeros_(module.bias)
328 |
329 |
330 | """print layers and params of network"""
331 | if __name__ == '__main__':
332 | model = BiSeNetV2(num_classes=3)
333 | summary(model, (3, 512, 512), device="cpu")
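334 |
335 |     # A minimal training-loss sketch (an assumption, not the repository's
336 |     # train.py): forward() returns the main logits plus four auxiliary
337 |     # logits, all already upsampled to the input size, so a simple booster
338 |     # loss can sum cross-entropy over every head.
339 |     x = torch.randn(2, 3, 512, 512)
340 |     target = torch.randint(0, 3, (2, 512, 512))
341 |     criterion = nn.CrossEntropyLoss()
342 |     loss = sum(criterion(out, target) for out in model(x))
343 |     print('total loss over main + aux heads:', loss.item())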
--------------------------------------------------------------------------------
/model/FCN8s.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/22 19:06
4 | Author: Cheng Ding(Deeachain)
5 | Version: V 0.1
6 | File: FCN8s.py
7 | Describe: Written during my study at Nanjing University of Information Science and Technology
8 | Github: https://github.com/Deeachain
9 | """
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torchsummary import summary
14 |
15 | ######################################################################################
16 | # FCN: Fully Convolutional Networks for Semantic Segmentation
17 | # Paper-Link: https://arxiv.org/abs/1411.4038
18 | ######################################################################################
19 |
20 | __all__ = ["FCN"]
21 |
22 |
23 | def conv1x1(in_planes, out_planes, stride=1):
24 | """1x1 convolution"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class conv3x3_block_x1(nn.Module):
29 | '''(conv => BN => ReLU) * 1'''
30 |
31 | def __init__(self, in_ch, out_ch):
32 | super(conv3x3_block_x1, self).__init__()
33 | self.conv = nn.Sequential(
34 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
35 | nn.BatchNorm2d(out_ch),
36 | nn.ReLU(inplace=True)
37 | )
38 |
39 | def forward(self, x):
40 | x = self.conv(x)
41 | return x
42 |
43 |
44 | class conv3x3_block_x2(nn.Module):
45 |     '''(conv => BN => ReLU) * 2 => MaxPool'''
46 |
47 | def __init__(self, in_ch, out_ch):
48 | super(conv3x3_block_x2, self).__init__()
49 | self.conv = nn.Sequential(
50 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
51 | nn.BatchNorm2d(out_ch),
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
54 | nn.BatchNorm2d(out_ch),
55 | nn.ReLU(inplace=True),
56 | nn.MaxPool2d(2)
57 | )
58 |
59 | def forward(self, x):
60 | x = self.conv(x)
61 | return x
62 |
63 |
64 | class conv3x3_block_x3(nn.Module):
65 |     '''(conv => BN => ReLU) * 3 => MaxPool'''
66 |
67 | def __init__(self, in_ch, out_ch):
68 | super(conv3x3_block_x3, self).__init__()
69 | self.conv = nn.Sequential(
70 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
71 | nn.BatchNorm2d(out_ch),
72 | nn.ReLU(inplace=True),
73 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
74 | nn.BatchNorm2d(out_ch),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
77 | nn.BatchNorm2d(out_ch),
78 | nn.ReLU(inplace=True),
79 | nn.MaxPool2d(2)
80 | )
81 |
82 | def forward(self, x):
83 | x = self.conv(x)
84 | return x
85 |
86 |
87 | class upsample(nn.Module):
88 | def __init__(self, in_ch, out_ch, scale_factor=2):
89 | super(upsample, self).__init__()
90 | self.conv1x1 = conv1x1(in_ch, out_ch)
91 | self.scale_factor = scale_factor
92 |
93 | def forward(self, H):
94 | """
95 | H: High level feature map, upsample
96 | """
97 | H = self.conv1x1(H)
98 | H = F.interpolate(H, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
99 | return H
100 |
101 |
102 | class FCN(nn.Module):
103 | def __init__(self, num_classes):
104 | super(FCN, self).__init__()
105 | self.maxpool = nn.MaxPool2d(2)
106 | self.block1 = conv3x3_block_x2(3, 64)
107 | self.block2 = conv3x3_block_x2(64, 128)
108 | self.block3 = conv3x3_block_x3(128, 256)
109 | self.block4 = conv3x3_block_x3(256, 512)
110 | self.block5 = conv3x3_block_x3(512, 512)
111 | self.upsample1 = upsample(512, 512, 2)
112 | self.upsample2 = upsample(512, 256, 2)
113 | self.upsample3 = upsample(256, num_classes, 8)
114 |
115 | def forward(self, x):
116 | block1_x = self.block1(x)
117 | block2_x = self.block2(block1_x)
118 | block3_x = self.block3(block2_x)
119 | block4_x = self.block4(block3_x)
120 | block5_x = self.block5(block4_x)
121 | upsample1 = self.upsample1(block5_x)
122 | x = torch.add(upsample1, block4_x)
123 | upsample2 = self.upsample2(x)
124 | x = torch.add(upsample2, block3_x)
125 | x = self.upsample3(x)
126 |
127 | return x
128 |
129 |
130 | if __name__ == '__main__':
131 | model = FCN(num_classes=3)
132 | summary(model, (3, 512, 512), device="cpu")
133 |
--------------------------------------------------------------------------------
/model/FCN_ResNet.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/12/1 18:23
4 | Author: Ding Cheng(Deeachain)
5 | File: FCN_ResNet.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchsummary import summary
13 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
14 | from model.base_model import build_backbone
15 |
16 |
17 | class FCN_ResNet(nn.Module):
18 | def __init__(self, num_classes, backbone='resnet18', out_stride=32, mult_grid=False):
19 | super(FCN_ResNet, self).__init__()
20 |
21 | if backbone == 'resnet18' or backbone == 'resnet34':
22 | expansion = 1
23 | elif backbone == 'resnet50' or backbone == 'resnet101':
24 | expansion = 4
25 | self.backbone = build_backbone(backbone, out_stride, mult_grid)
26 |
27 | self.conv_1 = nn.Conv2d(in_channels=512 * expansion, out_channels=num_classes, kernel_size=1)
28 | self.conv_2 = nn.Conv2d(in_channels=256 * expansion, out_channels=num_classes, kernel_size=1)
29 | self.conv_3 = nn.Conv2d(in_channels=128 * expansion, out_channels=num_classes, kernel_size=1)
30 | self.conv_4 = nn.Conv2d(in_channels=64 * expansion, out_channels=num_classes, kernel_size=1)
31 |
32 | self._init_weight()
33 |
34 | def forward(self, x):
35 | layers = self.backbone(x) # resnet 4 layers
36 |
37 | layers3 = self.conv_1(layers[3])
38 | layers3 = F.interpolate(layers3, layers[2].size()[2:], mode="bilinear", align_corners=True)
39 | layers2 = self.conv_2(layers[2])
40 |
41 | output = layers2 + layers3
42 | output = F.interpolate(output, layers[1].size()[2:], mode="bilinear", align_corners=True)
43 | layers1 = self.conv_3(layers[1])
44 |
45 | output = output + layers1
46 | output = F.interpolate(output, layers[0].size()[2:], mode="bilinear", align_corners=True)
47 | layers0 = self.conv_4(layers[0])
48 |
49 | output = output + layers0
50 | output = F.interpolate(output, x.size()[2:], mode="bilinear", align_corners=True)
51 |         # aux1 = F.interpolate(layers2, x.size()[2:], mode="bilinear", align_corners=True)  # unused auxiliary output
52 |
53 | return output
54 |
55 | def _init_weight(self):
56 | for m in self.modules():
57 | if isinstance(m, nn.Conv2d):
58 | torch.nn.init.kaiming_normal_(m.weight)
59 | elif isinstance(m, SynchronizedBatchNorm2d):
60 | m.weight.data.fill_(1)
61 | m.bias.data.zero_()
62 | elif isinstance(m, nn.BatchNorm2d):
63 | m.weight.data.fill_(1)
64 | m.bias.data.zero_()
65 |
66 | def freeze_bn(self):
67 | for m in self.modules():
68 | if isinstance(m, SynchronizedBatchNorm2d):
69 | m.eval()
70 | elif isinstance(m, nn.BatchNorm2d):
71 | m.eval()
72 |
73 | def get_1x_lr_params(self):
74 | modules = [self.backbone]
75 | for i in range(len(modules)):
76 | for m in modules[i].named_modules():
77 | if self.freeze_bn:
78 | if isinstance(m[1], nn.Conv2d):
79 | for p in m[1].parameters():
80 | if p.requires_grad:
81 | yield p
82 | else:
83 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
84 | or isinstance(m[1], nn.BatchNorm2d):
85 | for p in m[1].parameters():
86 | if p.requires_grad:
87 | yield p
88 |
89 | def get_10x_lr_params(self):
90 | modules = [self.conv_1, self.conv_2, self.conv_3, self.conv_4]
91 | for i in range(len(modules)):
92 | for m in modules[i].named_modules():
93 | if self.freeze_bn:
94 | if isinstance(m[1], nn.Conv2d):
95 | for p in m[1].parameters():
96 | if p.requires_grad:
97 | yield p
98 | else:
99 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
100 | or isinstance(m[1], nn.BatchNorm2d):
101 | for p in m[1].parameters():
102 | if p.requires_grad:
103 | yield p
104 |
105 |
106 | """print layers and params of network"""
107 | if __name__ == '__main__':
108 | model = FCN_ResNet(num_classes=3, backbone='resnet18')
109 | summary(model, (3, 512, 512), device="cpu")
110 |
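111 |     # A minimal optimizer sketch (an assumption, not the repository's
112 |     # train.py): the get_1x_lr_params / get_10x_lr_params generators let the
113 |     # pretrained backbone train with a smaller learning rate than the newly
114 |     # initialized 1x1 classifier convolutions.
115 |     optimizer = torch.optim.SGD([
116 |         {'params': model.get_1x_lr_params(), 'lr': 0.01},
117 |         {'params': model.get_10x_lr_params(), 'lr': 0.1},
118 |     ], momentum=0.9, weight_decay=1e-4)
119 |     print(optimizer)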
--------------------------------------------------------------------------------
/model/PSPNet/psanet.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference from source code by author: https://github.com/hszhao/semseg
3 | """
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from .lib.psa import functional as PF
8 | from model.PSPNet import resnet as models
9 | # import .lib.psa.functional as PF
10 |
11 |
12 | class PSA(nn.Module):
13 | def __init__(self, in_channels=2048, mid_channels=512, psa_type=2, compact=False, shrink_factor=2, mask_h=59,
14 | mask_w=59, normalization_factor=1.0, psa_softmax=True):
15 | super(PSA, self).__init__()
16 | assert psa_type in [0, 1, 2]
17 | self.psa_type = psa_type
18 | self.compact = compact
19 | self.shrink_factor = shrink_factor
20 | self.mask_h = mask_h
21 | self.mask_w = mask_w
22 | self.psa_softmax = psa_softmax
23 | if normalization_factor is None:
24 | normalization_factor = mask_h * mask_w
25 | self.normalization_factor = normalization_factor
26 |
27 | self.reduce = nn.Sequential(
28 | nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
29 | nn.BatchNorm2d(mid_channels),
30 | nn.ReLU(inplace=True)
31 | )
32 | self.attention = nn.Sequential(
33 | nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
34 | nn.BatchNorm2d(mid_channels),
35 | nn.ReLU(inplace=True),
36 | nn.Conv2d(mid_channels, mask_h*mask_w, kernel_size=1, bias=False),
37 | )
38 | if psa_type == 2:
39 | self.reduce_p = nn.Sequential(
40 | nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
41 | nn.BatchNorm2d(mid_channels),
42 | nn.ReLU(inplace=True)
43 | )
44 | self.attention_p = nn.Sequential(
45 | nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
46 | nn.BatchNorm2d(mid_channels),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(mid_channels, mask_h*mask_w, kernel_size=1, bias=False),
49 | )
50 | self.proj = nn.Sequential(
51 | nn.Conv2d(mid_channels * (2 if psa_type == 2 else 1), in_channels, kernel_size=1, bias=False),
52 | nn.BatchNorm2d(in_channels),
53 | nn.ReLU(inplace=True)
54 | )
55 |
56 | def forward(self, x):
57 | out = x
58 | if self.psa_type in [0, 1]:
59 | x = self.reduce(x)
60 | n, c, h, w = x.size()
61 | if self.shrink_factor != 1:
62 | h = (h - 1) // self.shrink_factor + 1
63 | w = (w - 1) // self.shrink_factor + 1
64 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
65 | y = self.attention(x)
66 | if self.compact:
67 | if self.psa_type == 1:
68 | y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
69 | else:
70 | y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
71 | if self.psa_softmax:
72 | y = F.softmax(y, dim=1)
73 | x = torch.bmm(x.view(n, c, h * w), y.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
74 | elif self.psa_type == 2:
75 | x_col = self.reduce(x)
76 | x_dis = self.reduce_p(x)
77 | n, c, h, w = x_col.size()
78 | if self.shrink_factor != 1:
79 | h = (h - 1) // self.shrink_factor + 1
80 | w = (w - 1) // self.shrink_factor + 1
81 | x_col = F.interpolate(x_col, size=(h, w), mode='bilinear', align_corners=True)
82 | x_dis = F.interpolate(x_dis, size=(h, w), mode='bilinear', align_corners=True)
83 | y_col = self.attention(x_col)
84 | y_dis = self.attention_p(x_dis)
85 | if self.compact:
86 | y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
87 | else:
88 | y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
89 | y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
90 | if self.psa_softmax:
91 | y_col = F.softmax(y_col, dim=1)
92 | y_dis = F.softmax(y_dis, dim=1)
93 | x_col = torch.bmm(x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
94 | x_dis = torch.bmm(x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
95 | x = torch.cat([x_col, x_dis], 1)
96 | x = self.proj(x)
97 | if self.shrink_factor != 1:
98 | h = (h - 1) * self.shrink_factor + 1
99 | w = (w - 1) * self.shrink_factor + 1
100 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
101 | return torch.cat((out, x), 1)
102 |
103 |
104 | class PSANet(nn.Module):
105 | def __init__(self, layers=50, dropout=0.1, classes=2, zoom_factor=8, use_psa=True, psa_type=2, compact=False,
106 | shrink_factor=2, mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True,
107 | criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True):
108 | super(PSANet, self).__init__()
109 | assert layers in [50, 101, 152]
110 | assert classes > 1
111 | assert zoom_factor in [1, 2, 4, 8]
112 | assert psa_type in [0, 1, 2]
113 | self.zoom_factor = zoom_factor
114 | self.use_psa = use_psa
115 | self.criterion = criterion
116 |
117 | if layers == 50:
118 | resnet = models.resnet50(pretrained=pretrained)
119 | elif layers == 101:
120 | resnet = models.resnet101(pretrained=pretrained)
121 | else:
122 | resnet = models.resnet152(pretrained=pretrained)
123 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
124 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
125 |
126 | for n, m in self.layer3.named_modules():
127 | if 'conv2' in n:
128 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
129 | elif 'downsample.0' in n:
130 | m.stride = (1, 1)
131 | for n, m in self.layer4.named_modules():
132 | if 'conv2' in n:
133 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
134 | elif 'downsample.0' in n:
135 | m.stride = (1, 1)
136 |
137 | fea_dim = 2048
138 | if use_psa:
139 | self.psa = PSA(fea_dim, 512, psa_type, compact, shrink_factor, mask_h, mask_w, normalization_factor, psa_softmax)
140 | fea_dim *= 2
141 | self.cls = nn.Sequential(
142 | nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
143 | nn.BatchNorm2d(512),
144 | nn.ReLU(inplace=True),
145 | nn.Dropout2d(p=dropout),
146 | nn.Conv2d(512, classes, kernel_size=1)
147 | )
148 | if self.training:
149 | self.aux = nn.Sequential(
150 | nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
151 | nn.BatchNorm2d(256),
152 | nn.ReLU(inplace=True),
153 | nn.Dropout2d(p=dropout),
154 | nn.Conv2d(256, classes, kernel_size=1)
155 | )
156 |
157 | def forward(self, x, y=None):
158 | x_size = x.size()
159 | assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
160 | h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
161 | w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)
162 |
163 | x = self.layer0(x)
164 | x = self.layer1(x)
165 | x = self.layer2(x)
166 | x_tmp = self.layer3(x)
167 | x = self.layer4(x_tmp)
168 | if self.use_psa:
169 | x = self.psa(x)
170 | x = self.cls(x)
171 | if self.zoom_factor != 1:
172 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
173 |
174 | if self.training:
175 | aux = self.aux(x_tmp)
176 | if self.zoom_factor != 1:
177 | aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
178 | main_loss = self.criterion(x, y)
179 | aux_loss = self.criterion(aux, y)
180 | return x.max(1)[1], main_loss, aux_loss
181 | else:
182 | return x
183 |
184 |
185 | if __name__ == '__main__':
186 | import os
187 | os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'
188 | crop_h = crop_w = 465
189 | input = torch.rand(4, 3, crop_h, crop_w).cuda()
190 | compact = False
191 | mask_h, mask_w = None, None
192 | shrink_factor = 2
193 | if compact:
194 | mask_h = (crop_h - 1) // (8 * shrink_factor) + 1
195 | mask_w = (crop_w - 1) // (8 * shrink_factor) + 1
196 | else:
197 | assert (mask_h is None and mask_w is None) or (mask_h is not None and mask_w is not None)
198 | if mask_h is None and mask_w is None:
199 | mask_h = 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1
200 | mask_w = 2 * ((crop_w - 1) // (8 * shrink_factor) + 1) - 1
201 | else:
202 | assert (mask_h % 2 == 1) and (mask_h >= 3) and (mask_h <= 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1)
203 | assert (mask_w % 2 == 1) and (mask_w >= 3) and (mask_w <= 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1)
204 |
205 | model = PSANet(layers=50, dropout=0.1, classes=21, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
206 | shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True, pretrained=True).cuda()
207 | print(model)
208 | model.eval()
209 | output = model(input)
210 | print('PSANet', output.size())
211 |
--------------------------------------------------------------------------------
/model/PSPNet/pspnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference from source code by author: https://github.com/hszhao/semseg
3 | """
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from model.PSPNet import resnet as models
8 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
9 |
10 |
11 | class PPM(nn.Module):
12 | def __init__(self, in_dim, reduction_dim, bins):
13 | super(PPM, self).__init__()
14 | self.features = []
15 | for bin in bins:
16 | self.features.append(nn.Sequential(
17 | nn.AdaptiveAvgPool2d(bin),
18 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
19 | nn.BatchNorm2d(reduction_dim),
20 | nn.ReLU(inplace=True)
21 | ))
22 | self.features = nn.ModuleList(self.features)
23 |
24 | def forward(self, x):
25 | x_size = x.size()
26 | out = [x]
27 | for f in self.features:
28 | out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
29 | return torch.cat(out, 1)
30 |
31 |
32 | class PSPNet(nn.Module):
33 | def __init__(self, layers=50, bins=(1, 2, 3, 6), dropout=0.1, num_classes=2, zoom_factor=8, use_ppm=True,
34 | criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True):
35 | super(PSPNet, self).__init__()
36 | assert layers in [50, 101, 152]
37 | assert 2048 % len(bins) == 0
38 | assert num_classes > 1
39 | assert zoom_factor in [1, 2, 4, 8]
40 | self.zoom_factor = zoom_factor
41 | self.use_ppm = use_ppm
42 | self.criterion = criterion
43 |
44 | if layers == 50:
45 | resnet = models.resnet50(pretrained=pretrained)
46 | elif layers == 101:
47 | resnet = models.resnet101(pretrained=pretrained)
48 | else:
49 | resnet = models.resnet152(pretrained=pretrained)
50 | # self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu,
51 | # resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
52 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
53 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
54 |
55 | for n, m in self.layer3.named_modules():
56 | if 'conv2' in n:
57 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
58 | elif 'downsample.0' in n:
59 | m.stride = (1, 1)
60 | for n, m in self.layer4.named_modules():
61 | if 'conv2' in n:
62 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
63 | elif 'downsample.0' in n:
64 | m.stride = (1, 1)
65 |
66 | fea_dim = 2048
67 | if use_ppm:
68 | self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins)
69 | fea_dim *= 2
70 | self.cls = nn.Sequential(
71 | nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
72 | nn.BatchNorm2d(512),
73 | nn.ReLU(inplace=True),
74 | nn.Dropout2d(p=dropout),
75 | nn.Conv2d(512, num_classes, kernel_size=1)
76 | )
77 | if self.training:
78 | self.aux = nn.Sequential(
79 | nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
80 | nn.BatchNorm2d(256),
81 | nn.ReLU(inplace=True),
82 | nn.Dropout2d(p=dropout),
83 | nn.Conv2d(256, num_classes, kernel_size=1)
84 | )
85 |
86 | def forward(self, x, y=None):
87 | x_size = x.size()
88 | # assert (x_size[2]) % 8 == 0 and (x_size[3]) % 8 == 0
89 | # h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
90 | # w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)
91 | h, w = x_size[2], x_size[3]
92 |
93 | x = self.layer0(x)
94 | x = self.layer1(x)
95 | x = self.layer2(x)
96 | x_tmp = self.layer3(x)
97 | x = self.layer4(x_tmp)
98 | if self.use_ppm:
99 | x = self.ppm(x)
100 | x = self.cls(x)
101 |
102 | if self.zoom_factor != 1:
103 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
104 | if self.training:
105 | aux = self.aux(x_tmp)
106 | if self.zoom_factor != 1:
107 | aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
108 | return x, aux
109 | else:
110 | return x
111 |
112 |
113 | if __name__ == '__main__':
114 | import os
115 |
116 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
117 | input = torch.rand(4, 3, 473, 473).cuda()
118 | model = PSPNet(layers=50, bins=(1, 2, 3, 6), dropout=0.1, num_classes=21, zoom_factor=1, use_ppm=True,
119 | pretrained=True).cuda()
120 | model.eval()
121 | output = model(input)
122 | print('PSPNet', output.size())
123 |
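124 |     # A minimal training-mode sketch (an assumption, not the repository's
125 |     # train.py): in train() mode forward returns (main, aux) logits and the
126 |     # auxiliary loss is usually down-weighted (0.4 in the PSPNet paper).
127 |     # With zoom_factor=1 the logits stay at 1/8 resolution, so upsample
128 |     # before computing the loss.
129 |     model.train()
130 |     target = torch.randint(0, 21, (4, 473, 473)).cuda()
131 |     main_out, aux_out = model(input)
132 |     main_out = F.interpolate(main_out, size=(473, 473), mode='bilinear', align_corners=True)
133 |     aux_out = F.interpolate(aux_out, size=(473, 473), mode='bilinear', align_corners=True)
134 |     criterion = nn.CrossEntropyLoss(ignore_index=255)
135 |     loss = criterion(main_out, target) + 0.4 * criterion(aux_out, target)
136 |     print('PSPNet training loss:', loss.item())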
--------------------------------------------------------------------------------
/model/PSPNet/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference from source code by author: https://github.com/hszhao/semseg
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import torch.utils.model_zoo as model_zoo
8 |
9 |
10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | 'resnet152']
12 |
13 |
14 | model_urls = {
15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
20 | }
21 |
22 |
23 | def conv3x3(in_planes, out_planes, stride=1):
24 | """3x3 convolution with padding"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26 | padding=1, bias=False)
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | expansion = 1
31 |
32 | def __init__(self, inplanes, planes, stride=1, downsample=None):
33 | super(BasicBlock, self).__init__()
34 | self.conv1 = conv3x3(inplanes, planes, stride)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 |
42 | def forward(self, x):
43 | residual = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 |
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None):
65 | super(Bottleneck, self).__init__()
66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
69 | padding=1, bias=False)
70 | self.bn2 = nn.BatchNorm2d(planes)
71 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
72 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | residual = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | residual = self.downsample(x)
93 |
94 | out += residual
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 |
102 | def __init__(self, block, layers, num_classes=1000, deep_base=True):
103 | super(ResNet, self).__init__()
104 | self.deep_base = deep_base
105 | if not self.deep_base:
106 | self.inplanes = 64
107 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
108 | self.bn1 = nn.BatchNorm2d(64)
109 | else:
110 | self.inplanes = 128
111 | self.conv1 = conv3x3(3, 64, stride=2)
112 | self.bn1 = nn.BatchNorm2d(64)
113 | self.conv2 = conv3x3(64, 64)
114 | self.bn2 = nn.BatchNorm2d(64)
115 | self.conv3 = conv3x3(64, 128)
116 | self.bn3 = nn.BatchNorm2d(128)
117 | self.relu = nn.ReLU(inplace=True)
118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
119 | self.layer1 = self._make_layer(block, 64, layers[0])
120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
122 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
123 | self.avgpool = nn.AvgPool2d(7, stride=1)
124 | self.fc = nn.Linear(512 * block.expansion, num_classes)
125 |
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
129 | elif isinstance(m, nn.BatchNorm2d):
130 | nn.init.constant_(m.weight, 1)
131 | nn.init.constant_(m.bias, 0)
132 |
133 | def _make_layer(self, block, planes, blocks, stride=1):
134 | downsample = None
135 | if stride != 1 or self.inplanes != planes * block.expansion:
136 | downsample = nn.Sequential(
137 | nn.Conv2d(self.inplanes, planes * block.expansion,
138 | kernel_size=1, stride=stride, bias=False),
139 | nn.BatchNorm2d(planes * block.expansion),
140 | )
141 |
142 | layers = []
143 | layers.append(block(self.inplanes, planes, stride, downsample))
144 | self.inplanes = planes * block.expansion
145 | for i in range(1, blocks):
146 | layers.append(block(self.inplanes, planes))
147 |
148 | return nn.Sequential(*layers)
149 |
150 | def forward(self, x):
151 | x = self.relu(self.bn1(self.conv1(x)))
152 | if self.deep_base:
153 | x = self.relu(self.bn2(self.conv2(x)))
154 | x = self.relu(self.bn3(self.conv3(x)))
155 | x = self.maxpool(x)
156 |
157 | x = self.layer1(x)
158 | x = self.layer2(x)
159 | x = self.layer3(x)
160 | x = self.layer4(x)
161 |
162 | x = self.avgpool(x)
163 | x = x.view(x.size(0), -1)
164 | x = self.fc(x)
165 |
166 | return x
167 |
168 |
169 | def resnet18(pretrained=False, **kwargs):
170 | """Constructs a ResNet-18 model.
171 |
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | """
175 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
176 | if pretrained:
177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
178 | return model
179 |
180 |
181 | def resnet34(pretrained=False, **kwargs):
182 | """Constructs a ResNet-34 model.
183 |
184 | Args:
185 | pretrained (bool): If True, returns a model pre-trained on ImageNet
186 | """
187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
188 | if pretrained:
189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
190 | return model
191 |
192 |
193 | def resnet50(pretrained=False, **kwargs):
194 | """Constructs a ResNet-50 model.
195 |
196 | Args:
197 | pretrained (bool): If True, returns a model pre-trained on ImageNet
198 | """
199 | model = ResNet(Bottleneck, [3, 4, 6, 3], deep_base=False, **kwargs)
200 | if pretrained:
201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
202 | # model_path = './initmodel/resnet50_v2.pth'
203 | # model.load_state_dict(torch.load(model_path), strict=False)
204 | return model
205 |
206 |
207 | def resnet101(pretrained=False, **kwargs):
208 | """Constructs a ResNet-101 model.
209 |
210 | Args:
211 | pretrained (bool): If True, returns a model pre-trained on ImageNet
212 | """
213 | model = ResNet(Bottleneck, [3, 4, 23, 3], deep_base=False, **kwargs)
214 | if pretrained:
215 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
216 | # model_path = './initmodel/resnet101_v2.pth'
217 | # model.load_state_dict(torch.load(model_path), strict=False)
218 | return model
219 |
220 |
221 | def resnet152(pretrained=False, **kwargs):
222 | """Constructs a ResNet-152 model.
223 |
224 | Args:
225 | pretrained (bool): If True, returns a model pre-trained on ImageNet
226 | """
227 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
228 | if pretrained:
229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
230 | # model_path = './initmodel/resnet152_v2.pth'
231 | # model.load_state_dict(torch.load(model_path), strict=False)
232 | return model
233 |
--------------------------------------------------------------------------------
/model/SegNet.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 19:27
4 | Author: Ding Cheng(Deeachain)
5 | File: SegNet.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchsummary import summary
13 |
14 |
15 | __all__ = ["SegNet"]
16 |
17 | class SegNet(nn.Module):
18 | def __init__(self, classes=19):
19 | super(SegNet, self).__init__()
20 |
21 | batchNorm_momentum = 0.1
22 |
23 | self.conv11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
24 | self.bn11 = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
25 | self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
26 | self.bn12 = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
27 |
28 | self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
29 | self.bn21 = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
30 | self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
31 | self.bn22 = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
32 |
33 | self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
34 | self.bn31 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
35 | self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
37 | self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
38 | self.bn33 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
39 |
40 | self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
41 | self.bn41 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
42 | self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
43 | self.bn42 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
44 | self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
45 | self.bn43 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
46 |
47 | self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
48 | self.bn51 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
49 | self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
50 | self.bn52 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
51 | self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
52 | self.bn53 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
53 |
54 | self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
55 | self.bn53d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
56 | self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
57 | self.bn52d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
58 | self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
59 | self.bn51d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
60 |
61 | self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
62 | self.bn43d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
63 | self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
64 | self.bn42d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
65 | self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
66 | self.bn41d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
67 |
68 | self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
69 | self.bn33d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
70 | self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
71 | self.bn32d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
72 | self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
74 |
75 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
76 | self.bn22d = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
77 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
78 | self.bn21d = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
79 |
80 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
81 | self.bn12d = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
82 | self.conv11d = nn.Conv2d(64, classes, kernel_size=3, padding=1)
83 |
84 | def forward(self, x):
85 | # Stage 1
86 | x11 = F.relu(self.bn11(self.conv11(x)))
87 | x12 = F.relu(self.bn12(self.conv12(x11)))
88 | x1_size = x12.size()
89 | x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)
90 |
91 | # Stage 2
92 | x21 = F.relu(self.bn21(self.conv21(x1p)))
93 | x22 = F.relu(self.bn22(self.conv22(x21)))
94 | x2_size = x22.size()
95 | x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)
96 |
97 | # Stage 3
98 | x31 = F.relu(self.bn31(self.conv31(x2p)))
99 | x32 = F.relu(self.bn32(self.conv32(x31)))
100 | x33 = F.relu(self.bn33(self.conv33(x32)))
101 | x3_size = x33.size()
102 | x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)
103 |
104 | # Stage 4
105 | x41 = F.relu(self.bn41(self.conv41(x3p)))
106 | x42 = F.relu(self.bn42(self.conv42(x41)))
107 | x43 = F.relu(self.bn43(self.conv43(x42)))
108 | x4_size = x43.size()
109 | x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)
110 |
111 | # Stage 5
112 | x51 = F.relu(self.bn51(self.conv51(x4p)))
113 | x52 = F.relu(self.bn52(self.conv52(x51)))
114 | x53 = F.relu(self.bn53(self.conv53(x52)))
115 | x5_size = x53.size()
116 | x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)
117 |
118 | # Stage 5d
119 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2, output_size=x5_size)
120 | x53d = F.relu(self.bn53d(self.conv53d(x5d)))
121 | x52d = F.relu(self.bn52d(self.conv52d(x53d)))
122 | x51d = F.relu(self.bn51d(self.conv51d(x52d)))
123 |
124 | # Stage 4d
125 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2, output_size=x4_size)
126 | x43d = F.relu(self.bn43d(self.conv43d(x4d)))
127 | x42d = F.relu(self.bn42d(self.conv42d(x43d)))
128 | x41d = F.relu(self.bn41d(self.conv41d(x42d)))
129 |
130 | # Stage 3d
131 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2, output_size=x3_size)
132 | x33d = F.relu(self.bn33d(self.conv33d(x3d)))
133 | x32d = F.relu(self.bn32d(self.conv32d(x33d)))
134 | x31d = F.relu(self.bn31d(self.conv31d(x32d)))
135 |
136 | # Stage 2d
137 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2, output_size=x2_size)
138 | x22d = F.relu(self.bn22d(self.conv22d(x2d)))
139 | x21d = F.relu(self.bn21d(self.conv21d(x22d)))
140 |
141 | # Stage 1d
142 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2, output_size=x1_size)
143 | x12d = F.relu(self.bn12d(self.conv12d(x1d)))
144 | x11d = self.conv11d(x12d)
145 |
146 | return x11d
147 |
148 |
149 | """print layers and params of network"""
150 | if __name__ == '__main__':
151 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152 | model = SegNet(classes=19).to(device)
153 | summary(model, (3, 512, 1024))
154 |
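155 |     # A minimal sketch (an assumption, not part of the original file) of the
156 |     # pool/unpool pairing SegNet relies on: max_pool2d with
157 |     # return_indices=True records the argmax locations and max_unpool2d
158 |     # scatters the pooled values back to exactly those locations, which is
159 |     # how the decoder recovers spatial detail without learned upsampling.
160 |     t = torch.randn(1, 1, 4, 4)
161 |     p, idx = F.max_pool2d(t, kernel_size=2, stride=2, return_indices=True)
162 |     u = F.max_unpool2d(p, idx, kernel_size=2, stride=2, output_size=t.size())
163 |     print('unpooled shape:', u.shape)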
--------------------------------------------------------------------------------
/model/UNet.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/22 3:25 PM
4 | Author: Cheng Ding(Deeachain)
5 | File: UNet.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchsummary import summary
13 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
14 |
15 |
16 | __all__ = ["UNet"]
17 |
18 | def conv1x1(in_planes, out_planes, stride=1):
19 | """1x1 convolution"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
21 |
22 |
23 | class conv3x3_block_x1(nn.Module):
24 | '''(conv => BN => ReLU) * 1'''
25 |
26 | def __init__(self, in_ch, out_ch):
27 | super(conv3x3_block_x1, self).__init__()
28 | self.conv = nn.Sequential(
29 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
30 | nn.BatchNorm2d(out_ch),
31 | nn.ReLU(inplace=True)
32 | )
33 |
34 | def forward(self, x):
35 | x = self.conv(x)
36 | return x
37 |
38 |
39 | class conv3x3_block_x2(nn.Module):
40 | '''(conv => BN => ReLU) * 2'''
41 |
42 | def __init__(self, in_ch, out_ch):
43 | super(conv3x3_block_x2, self).__init__()
44 | self.conv = nn.Sequential(
45 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
46 | nn.BatchNorm2d(out_ch),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
49 | nn.BatchNorm2d(out_ch),
50 | nn.ReLU(inplace=True)
51 | )
52 |
53 | def forward(self, x):
54 | x = self.conv(x)
55 | return x
56 |
57 |
58 | class upsample(nn.Module):
59 | def __init__(self, in_ch, out_ch):
60 | super(upsample, self).__init__()
61 | self.conv1x1 = conv1x1(in_ch, out_ch)
62 | self.conv = conv3x3_block_x2(in_ch, out_ch)
63 |
64 | def forward(self, H, L):
65 | """
66 | H: High level feature map, upsample
67 | L: Low level feature map, block output
68 | """
69 | H = F.interpolate(H, scale_factor=2, mode='bilinear', align_corners=False)
70 | H = self.conv1x1(H)
71 | x = torch.cat([H, L], dim=1)
72 | x = self.conv(x)
73 | return x
74 |
75 |
76 | class UNet(nn.Module):
77 | def __init__(self, num_classes):
78 | super(UNet, self).__init__()
79 | self.maxpool = nn.MaxPool2d(2)
80 | self.block1 = conv3x3_block_x2(3, 64)
81 | self.block2 = conv3x3_block_x2(64, 128)
82 | self.block3 = conv3x3_block_x2(128, 256)
83 | self.block4 = conv3x3_block_x2(256, 512)
84 | self.block_out = conv3x3_block_x1(512, 1024)
85 | self.upsample1 = upsample(1024, 512)
86 | self.upsample2 = upsample(512, 256)
87 | self.upsample3 = upsample(256, 128)
88 | self.upsample4 = upsample(128, 64)
89 | self.upsample_out = conv3x3_block_x2(64, num_classes)
90 |
91 | self._init_weight()
92 |
93 | def forward(self, x):
94 | block1_x = self.block1(x)
95 | x = self.maxpool(block1_x)
96 | block2_x = self.block2(x)
97 | x = self.maxpool(block2_x)
98 | block3_x = self.block3(x)
99 | x = self.maxpool(block3_x)
100 | block4_x = self.block4(x)
101 | x = self.maxpool(block4_x)
102 | x = self.block_out(x)
103 | x = self.upsample1(x, block4_x)
104 | x = self.upsample2(x, block3_x)
105 | x = self.upsample3(x, block2_x)
106 | x = self.upsample4(x, block1_x)
107 | x = self.upsample_out(x)
108 |
109 | return x
110 |
111 | def _init_weight(self):
112 | for m in self.modules():
113 | if isinstance(m, nn.Conv2d):
114 | torch.nn.init.kaiming_normal_(m.weight)
115 | elif isinstance(m, SynchronizedBatchNorm2d):
116 | m.weight.data.fill_(1)
117 | m.bias.data.zero_()
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 |
122 |
123 | """print layers and params of network"""
124 | if __name__ == '__main__':
125 | model = UNet(num_classes=3)
126 | print(model.modules())
127 | summary(model, (3, 512, 512), device="cpu")
128 |
--------------------------------------------------------------------------------
/model/base_model/__init__.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/27 10:23
4 | Author: Cheng Ding(Deeachain)
5 | File: __init__.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | from .resnet import ResNet, resnet18, resnet34, resnet50, resnet101, resnet152
10 | from .xception import Xception, xception39
11 |
12 |
13 | def build_backbone(backbone, out_stride=32, mult_grid=False):
14 | if backbone == 'resnet18':
15 | return resnet18(out_stride, mult_grid)
16 | elif backbone == 'resnet34':
17 | return resnet34(out_stride, mult_grid)
18 | elif backbone == 'resnet50':
19 | return resnet50(out_stride, mult_grid)
20 | elif backbone == 'resnet101':
21 | return resnet101(out_stride, mult_grid)
22 | else:
23 | raise NotImplementedError
24 |
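25 |
26 | # A minimal usage sketch (an assumption, not part of the original file):
27 | # build_backbone returns a ResNet whose forward pass yields the four stage
28 | # feature maps that decoders such as FCN_ResNet consume.
29 | if __name__ == '__main__':
30 |     import torch
31 |     backbone = build_backbone('resnet18', out_stride=32)
32 |     feats = backbone(torch.randn(1, 3, 224, 224))
33 |     print([tuple(f.shape) for f in feats])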
--------------------------------------------------------------------------------
/model/base_model/resnet.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/27 10:23
4 | Author: Cheng Ding(Deeachain)
5 | File: resnet.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import torch.nn as nn
10 | import math
11 | import torch.utils.model_zoo as model_zoo
12 | from torchsummary import summary
13 |
14 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
15 | 'resnet152']
16 |
17 | class BasicBlock(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, inplanes, planes, dilation=1, stride=1, downsample=None):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
23 | padding=dilation, dilation=dilation, bias=False)
24 | self.bn1 = nn.BatchNorm2d(planes)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
27 | padding=dilation, dilation=dilation, bias=False)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 | out = self.conv2(out)
39 | out = self.bn2(out)
40 |
41 | if self.downsample is not None:
42 | residual = self.downsample(x)
43 |
44 | out += residual
45 | out = self.relu(out)
46 |
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, inplanes, planes, dilation=1, stride=1, downsample=None):
54 | super(Bottleneck, self).__init__()
55 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
56 | self.bn1 = nn.BatchNorm2d(planes)
57 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
58 | padding=dilation, dilation=dilation, bias=False)
59 | self.bn2 = nn.BatchNorm2d(planes)
60 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
61 | self.bn3 = nn.BatchNorm2d(planes * 4)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv3(out)
78 | out = self.bn3(out)
79 |
80 | if self.downsample is not None:
81 | residual = self.downsample(x)
82 |
83 | out += residual
84 | out = self.relu(out)
85 |
86 | return out
87 |
88 |
89 | class ResNet(nn.Module):
90 |
91 | def __init__(self, block, layers, out_stride, mult_grid):
92 | self.inplanes = 64
93 | super(ResNet, self).__init__()
94 | if out_stride == 8:
95 | stride = [2, 1, 1]
96 | elif out_stride == 16:
97 | stride = [2, 2, 1]
98 | elif out_stride == 32:
99 | stride = [2, 2, 2]
100 | # setting resnet last layer with dilation
101 | if mult_grid:
102 |             if layers[-1] == 3:  # layer4 has 3 blocks (all ResNets except resnet18)
103 |                 # mult_grid = [2, 4, 6]  # alternative rates, overridden below
104 |                 mult_grid = [4, 8, 16]
105 |             else:
106 |                 # mult_grid = [2, 4]  # alternative rates, overridden below
107 |                 mult_grid = [4, 8]
108 | else:
109 | mult_grid = []
110 |
111 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
112 | self.bn1 = nn.BatchNorm2d(64)
113 | self.relu = nn.ReLU(inplace=True)
114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
115 | self.layer1 = self._make_layer(block, 64, layers[0])
116 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[0])
117 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[1])
118 | self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[2], dilation=mult_grid)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2. / n))
124 | elif isinstance(m, nn.BatchNorm2d):
125 | m.weight.data.fill_(1)
126 | m.bias.data.zero_()
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1, dilation=[]):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
133 | nn.BatchNorm2d(planes * block.expansion),
134 | )
135 |
136 | layers = []
137 | if dilation != []:
138 | layers.append(block(self.inplanes, planes, dilation[0], stride, downsample))
139 | else:
140 | layers.append(block(self.inplanes, planes, 1, stride, downsample))
141 | self.inplanes = planes * block.expansion
142 | for i in range(1, blocks):
143 | if dilation != []:
144 | layers.append(block(self.inplanes, planes, dilation[i]))
145 | else:
146 | layers.append(block(self.inplanes, planes))
147 | return nn.Sequential(*layers)
148 |
149 |
150 | def forward(self, x):
151 | blocks = []
152 |
153 | x = self.conv1(x)
154 | x = self.bn1(x)
155 | x = self.relu(x)
156 | x = self.maxpool(x)
157 | x = self.layer1(x)
158 | blocks.append(x)
159 | x = self.layer2(x)
160 | blocks.append(x)
161 | x = self.layer3(x)
162 | blocks.append(x)
163 | x = self.layer4(x)
164 | blocks.append(x)
165 |
166 | return blocks
167 |
168 |
169 | def resnet18(out_stride=32, mult_grid=False):
170 | """Constructs a ResNet-18 model."""
171 | model = ResNet(BasicBlock, [2, 2, 2, 2], out_stride, mult_grid)
172 |
173 | return model
174 |
175 |
176 | def resnet34(out_stride=32, mult_grid=False):
177 | """Constructs a ResNet-34 model."""
178 | model = ResNet(BasicBlock, [3, 4, 6, 3], out_stride, mult_grid)
179 |
180 | return model
181 |
182 |
183 | def resnet50(out_stride=32, mult_grid=False):
184 | """Constructs a ResNet-50 model."""
185 | model = ResNet(Bottleneck, [3, 4, 6, 3], out_stride, mult_grid)
186 |
187 | return model
188 |
189 |
190 | def resnet101(out_stride=32, mult_grid=False):
191 | """Constructs a ResNet-101 model."""
192 | model = ResNet(Bottleneck, [3, 4, 23, 3], out_stride, mult_grid)
193 | return model
194 |
195 |
196 | def resnet152(out_stride=32, mult_grid=False):
197 | """Constructs a ResNet-152 model."""
198 | model = ResNet(Bottleneck, [3, 8, 36, 3], out_stride, mult_grid)
199 | return model
200 |
201 |
202 | """print layers and params of network"""
203 | if __name__ == '__main__':
204 |     model = resnet18()  # this local resnet18 takes no pretrained argument
205 | model_dict = model.state_dict()
206 | # for k, v in model_dict.items():
207 | # print(k, v)
208 | print(model)
209 | # summary(model, (3, 512, 512), device="cpu")
210 |
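211 |     # A minimal sketch (an assumption): out_stride only changes how much
212 |     # layer2-layer4 downsample, so it controls the resolution of the
213 |     # returned feature maps.
214 |     import torch
215 |     for out_stride in (8, 16, 32):
216 |         feats = resnet18(out_stride=out_stride)(torch.randn(1, 3, 256, 256))
217 |         print(out_stride, [tuple(f.shape[2:]) for f in feats])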
--------------------------------------------------------------------------------
/model/base_model/xception.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 | import torch
3 | import torch.nn as nn
4 | from collections import OrderedDict, defaultdict
5 |
6 | __all__ = ['Xception', 'xception39']
7 |
8 |
9 |
10 | def load_model(model, model_file, is_restore=False):
11 | if isinstance(model_file, str):
12 | state_dict = torch.load(model_file, map_location=torch.device('cpu'))
13 | if 'model' in state_dict.keys():
14 | state_dict = state_dict['model']
15 | else:
16 | state_dict = model_file
17 |
18 | if is_restore:
19 | new_state_dict = OrderedDict()
20 | for k, v in state_dict.items():
21 | name = 'module.' + k
22 | new_state_dict[name] = v
23 | state_dict = new_state_dict
24 |
25 | model.load_state_dict(state_dict, strict=False)
26 | ckpt_keys = set(state_dict.keys())
27 | own_keys = set(model.state_dict().keys())
28 | missing_keys = own_keys - ckpt_keys
29 | unexpected_keys = ckpt_keys - own_keys
30 |
31 | # if len(missing_keys) > 0:
32 | # logger.warning('Missing key(s) in state_dict: {}'.format(
33 | # ', '.join('{}'.format(k) for k in missing_keys)))
34 | #
35 | # if len(unexpected_keys) > 0:
36 | # logger.warning('Unexpected key(s) in state_dict: {}'.format(
37 | # ', '.join('{}'.format(k) for k in unexpected_keys)))
38 |
39 | del state_dict
40 |
41 | return model
42 |
43 |
44 | class ConvBnRelu(nn.Module):
45 | def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
46 | groups=1, has_bn=True, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
47 | has_relu=True, inplace=True, has_bias=False):
48 | super(ConvBnRelu, self).__init__()
49 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
50 | stride=stride, padding=pad,
51 | dilation=dilation, groups=groups, bias=has_bias)
52 | self.has_bn = has_bn
53 | if self.has_bn:
54 | self.bn = norm_layer(out_planes, eps=bn_eps)
55 | self.has_relu = has_relu
56 | if self.has_relu:
57 | self.relu = nn.ReLU(inplace=inplace)
58 |
59 | def forward(self, x):
60 | x = self.conv(x)
61 | if self.has_bn:
62 | x = self.bn(x)
63 | if self.has_relu:
64 | x = self.relu(x)
65 |
66 | return x
67 |
68 |
69 | class SeparableConvBnRelu(nn.Module):
70 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
71 | padding=0, dilation=1,
72 | has_relu=True, norm_layer=nn.BatchNorm2d):
73 | super(SeparableConvBnRelu, self).__init__()
74 |
75 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
76 | padding, dilation, groups=in_channels,
77 | bias=False)
78 | self.point_wise_cbr = ConvBnRelu(in_channels, out_channels, 1, 1, 0,
79 | has_bn=True, norm_layer=norm_layer,
80 | has_relu=has_relu, has_bias=False)
81 |
82 | def forward(self, x):
83 | x = self.conv1(x)
84 | x = self.point_wise_cbr(x)
85 | return x
86 |
87 |
88 | class Block(nn.Module):
89 | expansion = 4
90 |
91 | def __init__(self, in_channels, mid_out_channels, has_proj, stride,
92 | dilation=1, norm_layer=nn.BatchNorm2d):
93 | super(Block, self).__init__()
94 | self.has_proj = has_proj
95 |
96 | if has_proj:
97 | self.proj = SeparableConvBnRelu(in_channels,
98 | mid_out_channels * self.expansion,
99 | 3, stride, 1,
100 | has_relu=False,
101 | norm_layer=norm_layer)
102 |
103 | self.residual_branch = nn.Sequential(
104 | SeparableConvBnRelu(in_channels, mid_out_channels,
105 | 3, stride, dilation, dilation,
106 | has_relu=True, norm_layer=norm_layer),
107 | SeparableConvBnRelu(mid_out_channels, mid_out_channels, 3, 1, 1,
108 | has_relu=True, norm_layer=norm_layer),
109 | SeparableConvBnRelu(mid_out_channels,
110 | mid_out_channels * self.expansion, 3, 1, 1,
111 | has_relu=False, norm_layer=norm_layer))
112 | self.relu = nn.ReLU(inplace=True)
113 |
114 | def forward(self, x):
115 | shortcut = x
116 | if self.has_proj:
117 | shortcut = self.proj(x)
118 |
119 | residual = self.residual_branch(x)
120 | output = self.relu(shortcut + residual)
121 |
122 | return output
123 |
124 |
125 | class Xception(nn.Module):
126 | def __init__(self, block, layers, channels, norm_layer=nn.BatchNorm2d):
127 | super(Xception, self).__init__()
128 |
129 | self.in_channels = 8
130 | self.conv1 = ConvBnRelu(3, self.in_channels, 3, 2, 1,
131 | has_bn=True, norm_layer=norm_layer,
132 | has_relu=True, has_bias=False)
133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
134 |
135 | self.layer1 = self._make_layer(block, norm_layer,
136 | layers[0], channels[0], stride=2)
137 | self.layer2 = self._make_layer(block, norm_layer,
138 | layers[1], channels[1], stride=2)
139 | self.layer3 = self._make_layer(block, norm_layer,
140 | layers[2], channels[2], stride=2)
141 |
142 | def _make_layer(self, block, norm_layer, blocks,
143 | mid_out_channels, stride=1):
144 | layers = []
145 | has_proj = True if stride > 1 else False
146 | layers.append(block(self.in_channels, mid_out_channels, has_proj,
147 | stride=stride, norm_layer=norm_layer))
148 | self.in_channels = mid_out_channels * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.in_channels, mid_out_channels,
151 | has_proj=False, stride=1,
152 | norm_layer=norm_layer))
153 |
154 | return nn.Sequential(*layers)
155 |
156 | def forward(self, x):
157 | x = self.conv1(x)
158 | x = self.maxpool(x)
159 |
160 | blocks = []
161 |         x = self.layer1(x)
162 |         blocks.append(x)
163 |         x = self.layer2(x)
164 |         blocks.append(x)
165 |         x = self.layer3(x)
166 |         blocks.append(x)
167 |
168 | return blocks
169 |
170 |
171 | def xception39(pretrained_model=None, **kwargs):
172 | model = Xception(Block, [4, 8, 4], [16, 32, 64], **kwargs)
173 |
174 | if pretrained_model is not None:
175 | model = load_model(model, pretrained_model)
176 | return model
177 |
--------------------------------------------------------------------------------
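As a quick sanity check of the Xception-39 backbone defined above, here is a minimal sketch that builds `xception39` and prints the shapes of the three feature maps it returns. The import path is assumed from this repository's layout (`model/base_model/xception.py`); no pretrained weights are loaded.

```python
# Hedged sketch: run the Xception-39 backbone above on a dummy input.
import torch
from model.base_model.xception import xception39  # path assumed from the repo layout

backbone = xception39(pretrained_model=None).eval()
dummy = torch.randn(1, 3, 512, 512)               # NCHW input
with torch.no_grad():
    feats = backbone(dummy)                       # list of feature maps from layer1..layer3
for i, f in enumerate(feats, 1):
    print('layer{}: {}'.format(i, tuple(f.shape)))
# Expected: stride 1/8 with 64 channels, 1/16 with 128, 1/32 with 256.
```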
/model/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
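The package exported above is intended as a drop-in replacement for `nn.BatchNorm*d` when training with multi-GPU `DataParallel`. A minimal sketch of how it could be wired in (layer sizes and device ids are illustrative):

```python
import torch.nn as nn
from model.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# SynchronizedBatchNorm2d is used like nn.BatchNorm2d, but its statistics are
# reduced across all replicas instead of being computed per GPU.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(inplace=True),
)
# Wrapping with DataParallelWithCallback lets the replicas register with the
# master copy so the synchronization actually happens (needs >= 2 visible GPUs).
# block = DataParallelWithCallback(block, device_ids=[0, 1])
```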
/model/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
--------------------------------------------------------------------------------
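As the docstring says, this re-implementation exists to check numerical behaviour. A hedged sketch of such a check against `nn.BatchNorm2d` (the import path is assumed from the repo layout; both modules are left in their default training mode):

```python
import torch
import torch.nn as nn
from model.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl  # path assumed

torch.manual_seed(0)
x = torch.randn(8, 16, 4, 4)

ref = nn.BatchNorm2d(16, eps=1e-5, momentum=0.1)
reimpl = BatchNorm2dReimpl(16, eps=1e-5, momentum=0.1)
# start from identical affine parameters so only the normalization math differs
reimpl.weight.data.copy_(ref.weight.data)
reimpl.bias.data.copy_(ref.bias.data)

print(torch.allclose(ref(x), reimpl(x), atol=1e-5))  # expected: True
```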
/model/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, data parallel triggers a callback on each module, so every slave device should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
62 |     collected and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be
64 |     passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 |             identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 | (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
--------------------------------------------------------------------------------
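The `FutureResult` class above is the basic one-to-one pipe the whole master/slave protocol is built on. A tiny illustration of its blocking behaviour (the import path is assumed from the repo layout):

```python
import threading
from model.sync_batchnorm.comm import FutureResult  # path assumed from the repo layout

future = FutureResult()

def producer():
    # hands one value to whoever is blocked in future.get()
    future.put(42)

threading.Thread(target=producer).start()
print(future.get())  # blocks until the producer calls put(), then prints 42
```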
/model/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all replicas are isomorphic, we assign each sub-module a context
34 |     (shared among the copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 |     We guarantee that the callback on the master copy (the first copy) will be called before the callback
38 |     of any slave copy.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by the
55 |     original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/model/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2020/11/30 17:02
4 | Author: Ding Cheng(Deeachain)
5 | File: predict.py
6 | Describe: Written during my study at Nanjing University of Information Science and Technology
7 | Github: https://github.com/Deeachain
8 | """
9 | import os
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torch.utils import data
13 | from argparse import ArgumentParser
14 | from prettytable import PrettyTable
15 | from builders.model_builder import build_model
16 | from builders.dataset_builder import build_dataset_test
17 | from builders.loss_builder import build_loss
18 | from builders.validation_builder import predict_multiscale_sliding
19 |
20 |
21 | def main(args):
22 | """
23 | main function for testing
24 | param args: global arguments
25 | return: None
26 | """
27 | t = PrettyTable(['args_name', 'args_value'])
28 | for k in list(vars(args).keys()):
29 | t.add_row([k, vars(args)[k]])
30 | print(t.get_string(title="Predict Arguments"))
31 |
32 | # build the model
33 | model = build_model(args.model, args.classes, args.backbone, args.pretrained, args.out_stride, args.mult_grid)
34 |
35 | # load the test set
36 | if args.predict_type == 'validation':
37 | testdataset, class_dict_df = build_dataset_test(args.root, args.dataset, args.crop_size,
38 | mode=args.predict_mode, gt=True)
39 | else:
40 | testdataset, class_dict_df = build_dataset_test(args.root, args.dataset, args.crop_size,
41 | mode=args.predict_mode, gt=False)
42 | DataLoader = data.DataLoader(testdataset, batch_size=args.batch_size,
43 |                                  shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False)
44 |
45 | if args.cuda:
46 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
47 | model = model.cuda()
48 | cudnn.benchmark = True
49 | if not torch.cuda.is_available():
50 | raise Exception("no GPU found or wrong gpu id, please run without --cuda")
51 |
52 | if not os.path.exists(args.save_seg_dir):
53 | os.makedirs(args.save_seg_dir)
54 |
55 | if args.checkpoint:
56 | if os.path.isfile(args.checkpoint):
57 | checkpoint = torch.load(args.checkpoint)['model']
58 | check_list = [i for i in checkpoint.items()]
59 |             # Checkpoint was saved from multi-GPU training; strip the 'module.' prefix for single-GPU prediction
60 |             if 'module.' in check_list[0][0]:
61 | new_stat_dict = {}
62 | for k, v in checkpoint.items():
63 | new_stat_dict[k[7:]] = v
64 | model.load_state_dict(new_stat_dict, strict=True)
65 |             # Checkpoint was saved from single-GPU training; load it directly
66 | else:
67 | model.load_state_dict(checkpoint)
68 | else:
69 | print("no checkpoint found at '{}'".format(args.checkpoint))
70 | raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))
71 |
72 | # define loss function
73 | criterion = build_loss(args, None, 255)
74 | print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
75 | ">>>>>>>>>>> beginning testing >>>>>>>>>>>>\n"
76 | ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
77 | predict_multiscale_sliding(args=args, model=model, testLoader=DataLoader, class_dict_df=class_dict_df,
78 | scales=args.scales, overlap=args.overlap, criterion=criterion,
79 | mode=args.predict_type, save_result=True)
80 |
81 |
82 | if __name__ == '__main__':
83 | parser = ArgumentParser()
84 | parser.add_argument('--model', default="UNet", help="model name")
85 | parser.add_argument('--backbone', type=str, default="resnet18", help="backbone name")
86 | parser.add_argument('--pretrained', action='store_true',
87 |                         help="whether to use an ImageNet-pretrained backbone")
88 | parser.add_argument('--out_stride', type=int, default=32, help="output stride of backbone")
89 | parser.add_argument('--mult_grid', action='store_true',
90 |                         help="whether to use multi-grid dilation in the backbone's last layer")
91 | parser.add_argument('--root', type=str, default="", help="path of datasets")
92 | parser.add_argument('--predict_mode', default="sliding", choices=["sliding", "whole"],
93 |                         help="prediction mode: sliding-window or whole-image (default: sliding)")
94 | parser.add_argument('--predict_type', default="validation", choices=["validation", "predict"],
95 |                         help="prediction type (default: validation)")
96 |     parser.add_argument('--flip_merge', action='store_true', help="merge predictions of flipped inputs (default: off)")
97 | parser.add_argument('--scales', type=float, nargs='+', default=[1.0], help="predict with multi_scales")
98 | parser.add_argument('--overlap', type=float, default=0.0, help="sliding predict overlap rate")
99 |     parser.add_argument('--dataset', default="paris", help="dataset name (only cityscapes is currently supported)")
100 | parser.add_argument('--num_workers', type=int, default=4, help="the number of parallel threads")
101 | parser.add_argument('--batch_size', type=int, default=1,
102 |                         help="batch size; set to 1 when evaluating or testing (NOTE: image size must be fixed)")
103 | parser.add_argument('--tile_hw_size', type=str, default='512, 512',
104 |                         help="tile size (h, w) used when evaluating or testing")
105 | parser.add_argument('--crop_size', type=int, default=769, help="crop size of image")
106 | parser.add_argument('--input_size', type=str, default=(769, 769),
107 |                         help="input size used to build the ProbOhemCrossEntropy2d loss")
108 | parser.add_argument('--checkpoint', type=str, default='',
109 |                         help="checkpoint file to load for evaluating or testing")
110 | parser.add_argument('--save_seg_dir', type=str, default="./outputs/",
111 | help="saving path of prediction result")
112 | parser.add_argument('--loss', type=str, default="CrossEntropyLoss2d",
113 | choices=['CrossEntropyLoss2d', 'ProbOhemCrossEntropy2d', 'CrossEntropyLoss2dLabelSmooth',
114 | 'LovaszSoftmax', 'FocalLoss2d'], help="choice loss for train or val in list")
115 | parser.add_argument('--cuda', default=True, help="run on CPU or GPU")
116 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
117 | parser.add_argument('--local_rank', type=int, default=0)
118 | args = parser.parse_args()
119 |
120 | save_dirname = args.checkpoint.split('/')[-2] + '_' + args.checkpoint.split('/')[-1].split('.')[0]
121 |
122 | args.save_seg_dir = os.path.join(args.save_seg_dir, args.dataset, args.predict_mode, save_dirname)
123 |
124 | if args.dataset == 'cityscapes':
125 | args.classes = 19
126 | else:
127 | raise NotImplementedError(
128 |             "dataset '%s' is not supported by this repository yet" % args.dataset)
129 |
130 | main(args)
131 |
--------------------------------------------------------------------------------
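The checkpoint-loading branch in `predict.py` above handles the `'module.'` prefix that `DataParallel`/`DistributedDataParallel` adds to parameter names when saving. The same idea as a standalone sketch (the checkpoint path is illustrative):

```python
import torch

# Weights saved from a (Distributed)DataParallel model carry a 'module.' prefix
# that a plain single-GPU model does not expect, so strip it before loading.
state = torch.load('best_model.pth', map_location='cpu')['model']  # illustrative path
if next(iter(state)).startswith('module.'):
    state = {k[len('module.'):]: v for k, v in state.items()}
# model.load_state_dict(state, strict=True)
```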
/predict.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python predict.py --model PSPNet_res50 \
3 | --checkpoint /data/dingcheng/segment/checkpoint/cityscapes/PSPNet_res50/best_model.pth \
4 | --out_stride 8 \
5 | --root '/data/open_data' \
6 | --dataset cityscapes \
7 | --predict_type validation \
8 | --predict_mode whole \
9 | --crop_size 768 \
10 | --tile_hw_size '768,768' \
11 | --batch_size 2 \
12 | --gpus 3 \
13 | --overlap 0.3 \
14 | --scales 0.5 0.75 1.0 1.25
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.1.0
3 | matplotlib==3.1.1
4 | Pillow>=6.2.2
5 | pyparsing==2.4.2
6 | python-dateutil==2.8.1
7 | pytz==2018.4
8 | six==1.12.0
9 | torch==1.2.0
10 | #torchvision==0.2.0
11 | torchsummary==1.5.1
12 | opencv-python==4.1.0.25
13 | tqdm==4.52.0
14 | numpy==1.15.1
15 | pandas==0.25.0
16 | scipy==1.5.1
17 | prettytable
18 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Deeplabv3plus_res101 PSPNet_res101 DualSeg_res101 BiSeNet BiSeNetV2 DDRNet
4 | # FCN_ResNet SegTrans
5 |
6 | python -m torch.distributed.launch --nproc_per_node=2 \
7 | train.py --model PSPNet_res50 --out_stride 8 \
8 | --max_epochs 200 --val_epochs 20 --batch_size 4 --lr 0.01 --optim sgd --loss ProbOhemCrossEntropy2d \
9 | --base_size 768 --crop_size 768 --tile_hw_size 768,768 \
10 | --root '/data/open_data' --dataset cityscapes --gpus_id 1,2
11 |
--------------------------------------------------------------------------------
/utils/colorize_mask.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import numpy as np
4 |
5 | cityscapes_palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
6 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0,
7 | 70, 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
8 |
9 | paris_palette = [255, 255, 255, 0, 0, 255, 255, 0, 0]
10 |
11 | road_palette = [0, 0, 0, 255, 255, 255]
12 |
13 | austin_palette = [0, 0, 0, 255, 255, 255]
14 |
15 | isprs_palette = [255, 255, 255, 0, 0, 255, 0, 255, 255, 0, 255, 0, 255, 255, 0, 255, 0, 0]
16 | # isprs_palette = [255, 255, 255, 255, 0, 0, 255, 255, 0, 0, 255, 0, 0, 255, 255, 0, 0, 255]
17 | zero_pad = 256 * 3 - len(cityscapes_palette)
18 | for i in range(zero_pad):
19 | cityscapes_palette.append(0)
20 |
21 |
22 | # zero_pad = 256 * 3 - len(camvid_palette)
23 | # for i in range(zero_pad):
24 | # camvid_palette.append(0)
25 |
26 | def cityscapes_colorize_mask(mask):
27 | # mask: numpy array of the mask
28 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
29 | new_mask.putpalette(cityscapes_palette)
30 |
31 | return new_mask
32 |
33 | def paris_colorize_mask(mask):
34 | # mask: numpy array of the mask
35 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
36 | new_mask.putpalette(paris_palette)
37 |
38 | return new_mask
39 |
40 | def road_colorize_mask(mask):
41 | # mask: numpy array of the mask
42 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
43 | new_mask.putpalette(road_palette)
44 |
45 | return new_mask
46 |
47 | def isprs_colorize_mask(mask):
48 | # mask: numpy array of the mask
49 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
50 | new_mask.putpalette(isprs_palette)
51 |
52 | return new_mask
53 |
54 |
55 | def austin_colorize_mask(mask):
56 | # mask: numpy array of the mask
57 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
58 | new_mask.putpalette(austin_palette)
59 |
60 | return new_mask
61 |
62 |
63 | class VOCColorize(object):
64 | def __init__(self, n=22):
65 | self.cmap = voc_color_map(22)
66 | self.cmap = torch.from_numpy(self.cmap[:n])
67 |
68 | def __call__(self, gray_image):
69 | size = gray_image.shape
70 | color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8)
71 |
72 | for label in range(0, len(self.cmap)):
73 | mask = (label == gray_image)
74 | color_image[0][mask] = self.cmap[label][0]
75 | color_image[1][mask] = self.cmap[label][1]
76 | color_image[2][mask] = self.cmap[label][2]
77 |
78 | # handle void
79 | mask = (255 == gray_image)
80 | color_image[0][mask] = color_image[1][mask] = color_image[2][mask] = 255
81 |
82 | return color_image
83 |
84 |
85 | def voc_color_map(N=256, normalized=False):
86 | def bitget(byteval, idx):
87 | return ((byteval & (1 << idx)) != 0)
88 |
89 | dtype = 'float32' if normalized else 'uint8'
90 | cmap = np.zeros((N, 3), dtype=dtype)
91 | for i in range(N):
92 | r = g = b = 0
93 | c = i
94 | for j in range(8):
95 | r = r | (bitget(c, 0) << 7 - j)
96 | g = g | (bitget(c, 1) << 7 - j)
97 | b = b | (bitget(c, 2) << 7 - j)
98 | c = c >> 3
99 |
100 | cmap[i] = np.array([r, g, b])
101 |
102 | cmap = cmap / 255 if normalized else cmap
103 | return cmap
104 |
--------------------------------------------------------------------------------
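A minimal sketch of colorizing a predicted label map with the Cityscapes palette above (the import path is assumed from the repo layout and the input array is random, purely for illustration):

```python
import numpy as np
from utils.colorize_mask import cityscapes_colorize_mask  # path assumed from the repo layout

pred = np.random.randint(0, 19, size=(64, 128))   # fake trainId prediction map
color = cityscapes_colorize_mask(pred)            # PIL 'P'-mode image with the palette applied
color.save('pred_color.png')
```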
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # _*_ coding: utf-8 _*_
2 | """
3 | Time: 2021/4/22 14:47
4 | Author: Cheng Ding(Deeachain)
5 | Version: V 0.1
6 | File: distributed.py
7 | """
8 | import os
9 | import torch
10 | import torch.distributed as dist
11 | import torch.nn as nn
12 | from torch.utils import data
13 |
14 |
15 | def Distribute(args, traindataset, model, criterion, device, gpus):
16 | # process_group = torch.distributed.new_group(list(range(gpus)))
17 | # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)
18 |
19 | train_sampler = data.distributed.DistributedSampler(traindataset)
20 | DataLoader = data.DataLoader(traindataset, batch_size=args.batch_size//gpus, sampler=train_sampler,
21 | shuffle=False, num_workers=args.batch_size, pin_memory=True, drop_last=True)
22 |
23 | model.to(device)
24 | criterion = criterion.to(device)
25 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
26 | output_device=args.local_rank,
27 | find_unused_parameters=True)
28 | return DataLoader, model, criterion
--------------------------------------------------------------------------------
/utils/earlyStopping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class EarlyStopping:
4 |     """Stops training early if the validation score doesn't improve within a given patience."""
5 | def __init__(self, patience=10, delta=0):
6 | """
7 | Args:
8 |             patience (int): How many validation checks to wait after the last time the validation score improved.
9 | Default: 10
10 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
11 | Default: 0
12 | """
13 | self.patience = patience
14 | self.counter = 0
15 | self.best_score = None
16 | self.early_stop = False
17 | self.delta = delta
18 |
19 |
20 | def monitor(self, monitor):
21 |
22 | score = monitor
23 |
24 | if self.best_score is None:
25 | self.best_score = score
26 | elif score <= self.best_score + self.delta:
27 | self.counter += 1
28 | if self.counter >= self.patience:
29 | self.early_stop = True
30 | else:
31 | self.best_score = score
32 | self.counter = 0
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
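A short sketch of how this helper could be driven from a validation loop; the mIoU values below are made up purely for illustration, and the import path is assumed from the repo layout:

```python
from utils.earlyStopping import EarlyStopping  # path assumed from the repo layout

stopper = EarlyStopping(patience=3, delta=0.0)
for epoch, val_miou in enumerate([0.60, 0.62, 0.61, 0.61, 0.61, 0.61]):
    stopper.monitor(val_miou)                   # higher score = better
    print(epoch, 'best:', stopper.best_score, 'counter:', stopper.counter)
    if stopper.early_stop:
        print('no improvement for {} checks, stopping'.format(stopper.patience))
        break
```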
/utils/flops_counter/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # ptflops versions log
2 |
3 | ## v 0.3
4 | - Add 1d operators: batch norm, poolings, convolution.
5 | - Add ability to output extended report to any output stream.
6 |
7 | ## v 0.2
8 | - Add new operations: Conv3d, BatchNorm3d, MaxPool3d, AvgPool3d, ConvTranspose2d.
9 | - Add some results on widespread models to the README.
10 | - Minor bugfixes.
11 |
12 | ## v 0.1
13 | - Initial release with basic functionality
14 |
--------------------------------------------------------------------------------
/utils/flops_counter/Flops_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import torch
4 | from utils.flops_counter.ptflops import get_model_complexity_info
5 |
6 | from model.UNet import UNet
7 | from model.FCN8s import FCN
8 | from model.ENet import ENet
9 | from model.SegNet import SegNet
10 | from model.ERFNet import ERFNet
11 | from model.ESPNet import ESPNet
12 | from model.ESPNet_v2.SegmentationModel import EESPNet_Seg
13 | from model.DABNet import DABNet
14 | from model.BiSeNet import BiSeNet
15 | from model.BiSeNetV2 import BiSeNetV2
16 | from model.PSPNet.pspnet import PSPNet
17 | # from model.PSPNet.psanet import PSANet
18 | from model.DeeplabV3Plus import Deeplabv3plus_res50
19 | from model.DualGCNNet import DualSeg_res50
20 | from model.MyNet import MyNet
21 | from model.MyNet_trans import MyNet_trans
22 | from model.NFSNet import NFSNet
23 |
24 | models = {
25 | 'ENet': ENet,
26 | 'FCN': FCN,
27 | 'UNet': UNet,
28 | 'BiSeNet': BiSeNet,
29 | 'BiSeNetV2': BiSeNetV2,
30 | 'PSPNet': PSPNet,
31 | 'DeeplabV3Plus': Deeplabv3plus_res50,
32 | 'DualGCNNet': DualSeg_res50,
33 | 'MyNet': MyNet,
34 | 'MyNet_trans': MyNet_trans,
35 | 'NFSNet': NFSNet
36 | }
37 |
38 | if __name__ == '__main__':
39 | parser = argparse.ArgumentParser(description='ptflops sample script')
40 | parser.add_argument('--device', type=int, default=0, help='Device to store the model.')
41 | parser.add_argument('--model', choices=list(models.keys()), type=str, default='NFSNet')
42 | parser.add_argument('--result', type=str, default=None)
43 | args = parser.parse_args()
44 |
45 | if args.result is None:
46 | ost = sys.stdout
47 | else:
48 | ost = open(args.result, 'w')
49 |
50 | net = models[args.model](num_classes=3)
51 |
52 | flops, params = get_model_complexity_info(net, (3, 512, 512), as_strings=True, print_per_layer_stat=True, ost=ost)
53 |
54 | print('Flops: ' + flops)
55 | print('Params: ' + params)
56 |
--------------------------------------------------------------------------------
/utils/flops_counter/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Vladislav Sovrasov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/flops_counter/README.md:
--------------------------------------------------------------------------------
1 | # Flops counter for convolutional networks in pytorch framework
2 | [PyPI](https://pypi.org/project/ptflops/)
3 |
4 | This script is designed to compute the theoretical amount of multiply-add operations
5 | in convolutional neural networks. It also can compute the number of parameters and
6 | print per-layer computational cost of a given network.
7 |
8 | Supported layers:
9 | - Conv1d/2d/3d (including grouping)
10 | - ConvTranspose2d (including grouping)
11 | - BatchNorm1d/2d/3d
12 | - Activations (ReLU, PReLU, ELU, ReLU6, LeakyReLU)
13 | - Linear
14 | - Upsample
15 | - Poolings (AvgPool1d/2d/3d, MaxPool1d/2d/3d and adaptive ones)
16 |
17 | Requirements: Pytorch >= 0.4.1, torchvision >= 0.2.1
18 |
19 | Thanks to @warmspringwinds for the initial version of script.
20 |
21 | ## Usage tips
22 |
23 | - This script doesn't take `torch.nn.functional.*` operations into account. For instance, if a semantic segmentation model uses `torch.nn.functional.interpolate` to upscale features, these operations won't contribute to the overall amount of flops. To avoid that, one can use `torch.nn.Upsample` instead of `torch.nn.functional.interpolate`.
24 | - `ptflops` launches a given model on a random tensor and estimates the amount of computation during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input, one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. This dict is then passed to the model as keyword arguments.
25 |
26 | ## Install the latest version
27 | ```bash
28 | pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
29 | ```
30 |
31 | ## Example
32 | ```python
33 | import torchvision.models as models
34 | import torch
35 | from ptflops import get_model_complexity_info
36 |
37 | with torch.cuda.device(0):
38 | net = models.densenet161()
39 | flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
40 | print('Flops: ' + flops)
41 | print('Params: ' + params)
42 | ```
43 |
44 | ## Benchmark
45 |
46 | ### [torchvision](https://pytorch.org/docs/1.0.0/torchvision/models.html)
47 |
48 | Model | Input Resolution | Params(M) | MACs(G) | Top-1 error | Top-5 error
49 | --- |--- |--- |--- |--- |---
50 | alexnet |224x224 | 61.1 | 0.72 | 43.45 | 20.91
51 | vgg11 |224x224 | 132.86 | 7.63 | 30.98 | 11.37
52 | vgg13 |224x224 | 133.05 | 11.34 | 30.07 | 10.75
53 | vgg16 |224x224 | 138.36 | 15.5 | 28.41 | 9.62
54 | vgg19 |224x224 | 143.67 | 19.67 | 27.62 | 9.12
55 | vgg11_bn |224x224 | 132.87 | 7.64 | 29.62 | 10.19
56 | vgg13_bn |224x224 | 133.05 | 11.36 | 28.45 | 9.63
57 | vgg16_bn |224x224 | 138.37 | 15.53 | 26.63 | 8.50
58 | vgg19_bn |224x224 | 143.68 | 19.7 | 25.76 | 8.15
59 | resnet18 |224x224 | 11.69 | 1.82 | 30.24 | 10.92
60 | resnet34 |224x224 | 21.8 | 3.68 | 26.70 | 8.58
61 | resnet50 |224x224 | 25.56 | 4.12 | 23.85 | 7.13
62 | resnet101 |224x224 | 44.55 | 7.85 | 22.63 | 6.44
63 | resnet152 |224x224 | 60.19 | 11.58 | 21.69 | 5.94
64 | squeezenet1_0 |224x224 | 1.25 | 0.83 | 41.90 | 19.58
65 | squeezenet1_1 |224x224 | 1.24 | 0.36 | 41.81 | 19.38
66 | densenet121 |224x224 | 7.98 | 2.88 | 25.35 | 7.83
67 | densenet169 |224x224 | 14.15 | 3.42 | 24.00 | 7.00
68 | densenet201 |224x224 | 20.01 | 4.37 | 22.80 | 6.43
69 | densenet161 |224x224 | 28.68 | 7.82 | 22.35 | 6.20
70 | inception_v3 |224x224 | 27.16 | 2.85 | 22.55 | 6.44
71 |
72 | * Top-1 error - ImageNet single-crop top-1 error (224x224)
73 | * Top-5 error - ImageNet single-crop top-5 error (224x224)
74 |
75 | ### [Cadene/pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch)
76 |
77 | Model | Input Resolution | Params(M) | MACs(G) | Acc@1 | Acc@5
78 | --- |--- |--- |--- |--- |---
79 | alexnet | 224x224 | 61.1 | 0.72 | 56.432 | 79.194
80 | bninception | 224x224 | 11.3 | 2.05 | 73.524 | 91.562
81 | cafferesnet101 | 224x224 | 44.55 | 7.62 | 76.2 | 92.766
82 | densenet121 | 224x224 | 7.98 | 2.88 | 74.646 | 92.136
83 | densenet161 | 224x224 | 28.68 | 7.82 | 77.56 | 93.798
84 | densenet169 | 224x224 | 14.15 | 3.42 | 76.026 | 92.992
85 | densenet201 | 224x224 | 20.01 | 4.37 | 77.152 | 93.548
86 | dpn107 | 224x224 | 86.92 | 18.42 | 79.746 | 94.684
87 | dpn131 | 224x224 | 79.25 | 16.13 | 79.432 | 94.574
88 | dpn68 | 224x224 | 12.61 | 2.36 | 75.868 | 92.774
89 | dpn68b | 224x224 | 12.61 | 2.36 | 77.034 | 93.59
90 | dpn92 | 224x224 | 37.67 | 6.56 | 79.4 | 94.62
91 | dpn98 | 224x224 | 61.57 | 11.76 | 79.224 | 94.488
92 | fbresnet152 | 224x224 | 60.27 | 11.6 | 77.386 | 93.594
93 | inceptionresnetv2 | 299x299 | 55.84 | 13.22 | 80.17 | 95.234
94 | inceptionv3 | 299x299 | 27.16 | 5.73 | 77.294 | 93.454
95 | inceptionv4 | 299x299 | 42.68 | 12.31 | 80.062 | 94.926
96 | nasnetalarge | 331x331 | 88.75 | 24.04 | 82.566 | 96.086
97 | nasnetamobile | 224x224 | 5.29 | 0.59 | 74.08 | 91.74
98 | pnasnet5large | 331x331 | 86.06 | 25.21 | 82.736 | 95.992
99 | polynet | 331x331 | 95.37 | 34.9 | 81.002 | 95.624
100 | resnet101 | 224x224 | 44.55 | 7.85 | 77.438 | 93.672
101 | resnet152 | 224x224 | 60.19 | 11.58 | 78.428 | 94.11
102 | resnet18 | 224x224 | 11.69 | 1.82 | 70.142 | 89.274
103 | resnet34 | 224x224 | 21.8 | 3.68 | 73.554 | 91.456
104 | resnet50 | 224x224 | 25.56 | 4.12 | 76.002 | 92.98
105 | resnext101_32x4d | 224x224 | 44.18 | 8.03 | 78.188 | 93.886
106 | resnext101_64x4d | 224x224 | 83.46 | 15.55 | 78.956 | 94.252
107 | se_resnet101 | 224x224 | 49.33 | 7.63 | 78.396 | 94.258
108 | se_resnet152 | 224x224 | 66.82 | 11.37 | 78.658 | 94.374
109 | se_resnet50 | 224x224 | 28.09 | 3.9 | 77.636 | 93.752
110 | se_resnext101_32x4d | 224x224 | 48.96 | 8.05 | 80.236 | 95.028
111 | se_resnext50_32x4d | 224x224 | 27.56 | 4.28 | 79.076 | 94.434
112 | senet154 | 224x224 | 115.09 | 20.82 | 81.304 | 95.498
113 | squeezenet1_0 | 224x224 | 1.25 | 0.83 | 58.108 | 80.428
114 | squeezenet1_1 | 224x224 | 1.24 | 0.36 | 58.25 | 80.8
115 | vgg11 | 224x224 | 132.86 | 7.63 | 68.97 | 88.746
116 | vgg11_bn | 224x224 | 132.87 | 7.64 | 70.452 | 89.818
117 | vgg13 | 224x224 | 133.05 | 11.34 | 69.662 | 89.264
118 | vgg13_bn | 224x224 | 133.05 | 11.36 | 71.508 | 90.494
119 | vgg16 | 224x224 | 138.36 | 15.5 | 71.636 | 90.354
120 | vgg16_bn | 224x224 | 138.37 | 15.53 | 73.518 | 91.608
121 | vgg19 | 224x224 | 143.67 | 19.67 | 72.08 | 90.822
122 | vgg19_bn | 224x224 | 143.68 | 19.7 | 74.266 | 92.066
123 | xception | 299x299 | 22.86 | 8.42 | 78.888 | 94.292
124 |
125 | * Acc@1 - ImageNet single-crop top-1 accuracy on validation images of the same size used during the training process.
126 | * Acc@5 - ImageNet single-crop top-5 accuracy on validation images of the same size used during the training process.
127 |
--------------------------------------------------------------------------------
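The `input_constructor` hook mentioned in the usage tips of the README above can be sketched as follows for a model that takes a named keyword argument in addition to the image tensor. The two-input model here is purely illustrative; the import uses this repository's local copy of ptflops.

```python
import torch
import torch.nn as nn
from utils.flops_counter.ptflops import get_model_complexity_info  # repo-local ptflops

class TwoInputNet(nn.Module):
    """Toy model with a named auxiliary input, just to exercise input_constructor."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x, scale=1.0):
        return self.conv(x) * scale

def input_constructor(input_res):
    # input_res is the (C, H, W) tuple passed to get_model_complexity_info;
    # return a dict of keyword arguments for the model's forward().
    return {'x': torch.randn(1, *input_res), 'scale': 0.5}

flops, params = get_model_complexity_info(TwoInputNet(), (3, 64, 64), as_strings=True,
                                          print_per_layer_stat=False,
                                          input_constructor=input_constructor)
print('Flops:', flops, 'Params:', params)
```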
/utils/flops_counter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/utils/flops_counter/__init__.py
--------------------------------------------------------------------------------
/utils/flops_counter/ptflops/__init__.py:
--------------------------------------------------------------------------------
1 | from .flops_counter import get_model_complexity_info
2 |
--------------------------------------------------------------------------------
/utils/flops_counter/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import sys
4 | from setuptools import setup, find_packages
5 |
6 | readme = open('README.md').read()
7 |
8 | VERSION = '0.3'
9 |
10 | requirements = [
11 | 'torch',
12 | ]
13 |
14 | setup(
15 | # Metadata
16 | name='ptflops',
17 | version=VERSION,
18 | author='Vladislav Sovrasov',
19 | author_email='sovrasov.vlad@gmail.com',
20 | url='https://github.com/sovrasov/flops-counter.pytorch',
21 | description='Flops counter for convolutional networks in pytorch framework',
22 | long_description=readme,
23 | long_description_content_type='text/markdown',
24 | license='MIT',
25 |
26 | # Package info
27 | packages=find_packages(exclude=('*test*',)),
28 |
29 | #
30 | zip_safe=True,
31 | install_requires=requirements,
32 |
33 | # Classifiers
34 | classifiers=[
35 | 'Programming Language :: Python :: 3',
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/utils/fps_test/eval_forward_time.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.backends.cudnn as cudnn
4 |
5 | from argparse import ArgumentParser
6 | from builders.model_builder import build_model
7 |
8 |
9 | def compute_speed(model, input_size, device, iteration=100):
10 | torch.cuda.set_device(device)
11 | cudnn.benchmark = True
12 |
13 | model.eval()
14 | model = model.cuda()
15 |
16 | input = torch.randn(*input_size, device=device)
17 |
18 | for _ in range(50):
19 | model(input)
20 |
21 | print('=========Eval Forward Time=========')
22 | torch.cuda.synchronize()
23 | t_start = time.time()
24 | for _ in range(iteration):
25 | model(input)
26 | torch.cuda.synchronize()
27 | elapsed_time = time.time() - t_start
28 |
29 | speed_time = elapsed_time / iteration * 1000
30 | fps = iteration / elapsed_time
31 |
32 | print('Elapsed Time: [%.2f s / %d iter]' % (elapsed_time, iteration))
33 | print('Speed Time: %.2f ms / iter FPS: %.2f' % (speed_time, fps))
34 | return speed_time, fps
35 |
36 |
37 | if __name__ == '__main__':
38 | parser = ArgumentParser()
39 |
40 | parser.add_argument("--size", type=str, default="512,512", help="input size of model")
41 | parser.add_argument('--num-channels', type=int, default=3)
42 | parser.add_argument('--batch-size', type=int, default=1)
43 | parser.add_argument('--classes', type=int, default=3)
44 | parser.add_argument('--iter', type=int, default=100)
45 | parser.add_argument('--model', type=str, default='NFSNet')
46 | parser.add_argument("--gpus", type=str, default="0", help="gpu ids (default: 0)")
47 | args = parser.parse_args()
48 |
49 | h, w = map(int, args.size.split(','))
50 | model = build_model(args.model, num_classes=args.classes)
51 | compute_speed(model, (args.batch_size, args.num_channels, h, w), int(args.gpus), iteration=args.iter)
52 |
--------------------------------------------------------------------------------
/utils/image_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from PIL import Image, ImageOps, ImageFilter
5 |
6 | class Normalize(object):
7 | """Normalize a tensor image with mean and standard deviation.
8 | Args:
9 | mean (tuple): means for each channel.
10 | std (tuple): standard deviations for each channel.
11 | """
12 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
13 | self.mean = mean
14 | self.std = std
15 |
16 | def __call__(self, sample):
17 | if len(sample) != 1:
18 | img = sample['image']
19 | mask = sample['label']
20 | img = np.array(img).astype(np.float32)
21 | mask = np.array(mask).astype(np.float32)
22 | img /= 255.0
23 | img -= self.mean
24 | img /= self.std
25 | return {'image': img, 'label': mask}
26 | else:
27 | img = sample['image']
28 | img = np.array(img).astype(np.float32)
29 | img /= 255.0
30 | img -= self.mean
31 | img /= self.std
32 | return {'image': img}
33 |
34 |
35 | class ToTensor(object):
36 | """Convert ndarrays in sample to Tensors."""
37 |
38 | def __call__(self, sample):
39 | # swap color axis because
40 | # numpy image: H x W x C
41 | # torch image: C X H X W
42 | if len(sample) != 1:
43 | img = sample['image']
44 | mask = sample['label']
45 | img = np.array(img).astype(np.float32).transpose((2, 0, 1))
46 | mask = np.array(mask).astype(np.float32)
47 | img = torch.from_numpy(img).float()
48 | mask = torch.from_numpy(mask).float()
49 | return {'image': img, 'label': mask}
50 | else:
51 | img = sample['image']
52 | img = np.array(img).astype(np.float32).transpose((2, 0, 1))
53 | img = torch.from_numpy(img).float()
54 | return {'image': img}
55 |
56 |
57 | class RandomHorizontalFlip(object):
58 | def __call__(self, sample):
59 | if len(sample) != 1:
60 | img = sample['image']
61 | mask = sample['label']
62 | if random.random() < 0.5:
63 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
64 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
65 | return {'image': img, 'label': mask}
66 | else:
67 | img = sample['image']
68 | if random.random() < 0.5:
69 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
70 | return {'image': img}
71 |
72 |
73 | class RandomRotate(object):
74 | def __init__(self, degree):
75 | self.degree = degree
76 |
77 | def __call__(self, sample):
78 | if len(sample) != 1:
79 | img = sample['image']
80 | mask = sample['label']
81 | rotate_degree = random.uniform(-1*self.degree, self.degree)
82 | img = img.rotate(rotate_degree, Image.BILINEAR)
83 | mask = mask.rotate(rotate_degree, Image.NEAREST)
84 | return {'image': img, 'label': mask}
85 | else:
86 | img = sample['image']
87 | rotate_degree = random.uniform(-1*self.degree, self.degree)
88 | img = img.rotate(rotate_degree, Image.BILINEAR)
89 | return {'image': img}
90 |
91 |
92 | class RandomGaussianBlur(object):
93 | def __call__(self, sample):
94 | if len(sample) != 1:
95 | img = sample['image']
96 | mask = sample['label']
97 | if random.random() < 0.5:
98 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
99 | return {'image': img, 'label': mask}
100 | else:
101 | img = sample['image']
102 | if random.random() < 0.5:
103 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
104 | return {'image': img}
105 |
106 |
107 | class RandomScaleCrop(object):
108 | def __init__(self, base_size, crop_size, fill=0):
109 | self.base_size = base_size
110 | self.crop_size = crop_size
111 | self.fill = fill
112 |
113 | def __call__(self, sample):
114 | if len(sample) != 1:
115 | img = sample['image']
116 | mask = sample['label']
117 | # random scale (short edge)
118 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
119 | w, h = img.size
120 | if h > w:
121 | ow = short_size
122 | oh = int(1.0 * h * ow / w)
123 | else:
124 | oh = short_size
125 | ow = int(1.0 * w * oh / h)
126 | img = img.resize((ow, oh), Image.BILINEAR)
127 | mask = mask.resize((ow, oh), Image.NEAREST)
128 | # pad crop
129 | if short_size < self.crop_size:
130 | padh = self.crop_size - oh if oh < self.crop_size else 0
131 | padw = self.crop_size - ow if ow < self.crop_size else 0
132 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
133 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
134 | # random crop crop_size
135 | w, h = img.size
136 | x1 = random.randint(0, w - self.crop_size)
137 | y1 = random.randint(0, h - self.crop_size)
138 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
139 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
140 | return {'image': img, 'label': mask}
141 | else:
142 | img = sample['image']
143 | # random scale (short edge)
144 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
145 | w, h = img.size
146 | if h > w:
147 | ow = short_size
148 | oh = int(1.0 * h * ow / w)
149 | else:
150 | oh = short_size
151 | ow = int(1.0 * w * oh / h)
152 | img = img.resize((ow, oh), Image.BILINEAR)
153 | # pad crop
154 | if short_size < self.crop_size:
155 | padh = self.crop_size - oh if oh < self.crop_size else 0
156 | padw = self.crop_size - ow if ow < self.crop_size else 0
157 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
158 | # random crop crop_size
159 | w, h = img.size
160 | x1 = random.randint(0, w - self.crop_size)
161 | y1 = random.randint(0, h - self.crop_size)
162 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
163 | return {'image': img}
164 |
165 | class FixScaleCrop(object):
166 | def __init__(self, crop_size):
167 | self.crop_size = crop_size
168 |
169 | def __call__(self, sample):
170 | if len(sample) != 1:
171 | img = sample['image']
172 | mask = sample['label']
173 | w, h = img.size
174 | if w > h:
175 | oh = self.crop_size
176 | ow = int(1.0 * w * oh / h)
177 | else:
178 | ow = self.crop_size
179 | oh = int(1.0 * h * ow / w)
180 | img = img.resize((ow, oh), Image.BILINEAR)
181 | mask = mask.resize((ow, oh), Image.NEAREST)
182 | # center crop
183 | w, h = img.size
184 | x1 = int(round((w - self.crop_size) / 2.))
185 | y1 = int(round((h - self.crop_size) / 2.))
186 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
187 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
188 | return {'image': img, 'label': mask}
189 | else:
190 | img = sample['image']
191 | w, h = img.size
192 | if w > h:
193 | oh = self.crop_size
194 | ow = int(1.0 * w * oh / h)
195 | else:
196 | ow = self.crop_size
197 | oh = int(1.0 * h * ow / w)
198 | img = img.resize((ow, oh), Image.BILINEAR)
199 | # center crop
200 | w, h = img.size
201 | x1 = int(round((w - self.crop_size) / 2.))
202 | y1 = int(round((h - self.crop_size) / 2.))
203 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
204 | return {'image': img}
205 |
206 |
207 | class FixedResize(object):
208 | def __init__(self, size):
209 | self.size = size # size: (h, w)
210 |
211 | def __call__(self, sample):
212 | if len(sample) != 1:
213 | img = sample['image']
214 | mask = sample['label']
215 | assert img.size == mask.size
216 | img = img.resize(self.size, Image.BILINEAR)
217 | mask = mask.resize(self.size, Image.NEAREST)
218 | return {'image': img,
219 | 'label': mask}
220 | else:
221 | img = sample['image']
222 | img = img.resize(self.size, Image.BILINEAR)
223 | return {'image': img}
--------------------------------------------------------------------------------
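These transforms operate on a `{'image': ..., 'label': ...}` dict, so they compose directly with `torchvision.transforms.Compose`. A hedged sketch of a training pipeline; the file names, normalization constants, and sizes are illustrative, and the import path is assumed from the repo layout:

```python
from PIL import Image
from torchvision import transforms
from utils.image_transform import (RandomHorizontalFlip, RandomScaleCrop,
                                   Normalize, ToTensor)  # path assumed from the repo layout

train_transform = transforms.Compose([
    RandomHorizontalFlip(),
    RandomScaleCrop(base_size=768, crop_size=768, fill=255),   # 255 = ignore label
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensor(),
])

sample = {'image': Image.open('img.png').convert('RGB'),   # illustrative file names
          'label': Image.open('mask.png')}
out = train_transform(sample)
print(out['image'].shape, out['label'].shape)   # e.g. [3, 768, 768] and [768, 768]
```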
/utils/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .loss import *
--------------------------------------------------------------------------------
/utils/losses/lovasz_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
4 | https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
5 | """
6 |
7 | from __future__ import print_function, division
8 |
9 | import torch
10 | from torch.autograd import Variable
11 | import torch.nn.functional as F
12 | import numpy as np
13 | try:
14 | from itertools import ifilterfalse
15 | except ImportError: # py3k
16 | from itertools import filterfalse as ifilterfalse
17 |
18 |
19 | def lovasz_grad(gt_sorted):
20 | """
21 | Computes gradient of the Lovasz extension w.r.t sorted errors
22 | See Alg. 1 in paper
23 | """
24 | p = len(gt_sorted)
25 | gts = gt_sorted.sum()
26 | intersection = gts - gt_sorted.float().cumsum(0)
27 | union = gts + (1 - gt_sorted).float().cumsum(0)
28 | jaccard = 1. - intersection / union
29 | if p > 1: # cover 1-pixel case
30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
31 | return jaccard
32 |
33 |
34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
35 | """
36 | IoU for foreground class
37 | binary: 1 foreground, 0 background
38 | """
39 | if not per_image:
40 | preds, labels = (preds,), (labels,)
41 | ious = []
42 | for pred, label in zip(preds, labels):
43 | intersection = ((label == 1) & (pred == 1)).sum()
44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
45 | if not union:
46 | iou = EMPTY
47 | else:
48 | iou = float(intersection) / float(union)
49 | ious.append(iou)
50 |     iou = mean(ious)    # mean across images if per_image
51 | return 100 * iou
52 |
53 |
54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
55 | """
56 | Array of IoU for each (non ignored) class
57 | """
58 | if not per_image:
59 | preds, labels = (preds,), (labels,)
60 | ious = []
61 | for pred, label in zip(preds, labels):
62 | iou = []
63 | for i in range(C):
64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
65 | intersection = ((label == i) & (pred == i)).sum()
66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum()
67 | if not union:
68 | iou.append(EMPTY)
69 | else:
70 | iou.append(float(intersection) / float(union))
71 | ious.append(iou)
72 |     ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
73 | return 100 * np.array(ious)
74 |
75 |
76 | # --------------------------- BINARY LOSSES ---------------------------
77 |
78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
79 | """
80 | Binary Lovasz hinge loss
81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
83 | per_image: compute the loss per image instead of per batch
84 | ignore: void class id
85 | """
86 | if per_image:
87 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
88 | for log, lab in zip(logits, labels))
89 | else:
90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
91 | return loss
92 |
93 |
94 | def lovasz_hinge_flat(logits, labels):
95 | """
96 | Binary Lovasz hinge loss
97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
98 | labels: [P] Tensor, binary ground truth labels (0 or 1)
99 | ignore: label to ignore
100 | """
101 | if len(labels) == 0:
102 | # only void pixels, the gradients should be 0
103 | return logits.sum() * 0.
104 | signs = 2. * labels.float() - 1.
105 | errors = (1. - logits * Variable(signs))
106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
107 | perm = perm.data
108 | gt_sorted = labels[perm]
109 | grad = lovasz_grad(gt_sorted)
110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
111 | return loss
112 |
113 |
114 | def flatten_binary_scores(scores, labels, ignore=None):
115 | """
116 | Flattens predictions in the batch (binary case)
117 | Remove labels equal to 'ignore'
118 | """
119 | scores = scores.view(-1)
120 | labels = labels.view(-1)
121 | if ignore is None:
122 | return scores, labels
123 | valid = (labels != ignore)
124 | vscores = scores[valid]
125 | vlabels = labels[valid]
126 | return vscores, vlabels
127 |
128 |
129 | class StableBCELoss(torch.nn.modules.Module):
130 | def __init__(self):
131 | super(StableBCELoss, self).__init__()
132 | def forward(self, input, target):
133 | neg_abs = - input.abs()
134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
135 | return loss.mean()
136 |
137 |
138 | def binary_xloss(logits, labels, ignore=None):
139 | """
140 | Binary Cross entropy loss
141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
143 | ignore: void class id
144 | """
145 | logits, labels = flatten_binary_scores(logits, labels, ignore)
146 | loss = StableBCELoss()(logits, Variable(labels.float()))
147 | return loss
148 |
149 |
150 | # --------------------------- MULTICLASS LOSSES ---------------------------
151 |
152 |
153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
154 | """
155 | Multi-class Lovasz-Softmax loss
156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
160 | per_image: compute the loss per image instead of per batch
161 | ignore: void class labels
162 | """
163 | if per_image:
164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
165 | for prob, lab in zip(probas, labels))
166 | else:
167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
168 | return loss
169 |
170 |
171 | def lovasz_softmax_flat(probas, labels, classes='present'):
172 | """
173 | Multi-class Lovasz-Softmax loss
174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
177 | """
178 | if probas.numel() == 0:
179 | # only void pixels, the gradients should be 0
180 | return probas * 0.
181 | C = probas.size(1)
182 | losses = []
183 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
184 | for c in class_to_sum:
185 | fg = (labels == c).float() # foreground for class c
186 |         if (classes == 'present' and fg.sum() == 0):
187 | continue
188 | if C == 1:
189 | if len(classes) > 1:
190 | raise ValueError('Sigmoid output possible only with 1 class')
191 | class_pred = probas[:, 0]
192 | else:
193 | class_pred = probas[:, c]
194 | errors = (Variable(fg) - class_pred).abs()
195 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
196 | perm = perm.data
197 | fg_sorted = fg[perm]
198 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
199 | return mean(losses)
200 |
201 |
202 | def flatten_probas(probas, labels, ignore=None):
203 | """
204 | Flattens predictions in the batch
205 | """
206 | if probas.dim() == 3:
207 | # assumes output of a sigmoid layer
208 | B, H, W = probas.size()
209 | probas = probas.view(B, 1, H, W)
210 | B, C, H, W = probas.size()
211 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
212 | labels = labels.view(-1)
213 | if ignore is None:
214 | return probas, labels
215 | valid = (labels != ignore)
216 | vprobas = probas[valid.nonzero().squeeze()]
217 | vlabels = labels[valid]
218 | return vprobas, vlabels
219 |
220 | def xloss(logits, labels, ignore=None):
221 | """
222 | Cross entropy loss
223 | """
224 |     return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)
225 |
226 |
227 | # --------------------------- HELPER FUNCTIONS ---------------------------
228 | def isnan(x):
229 | return x != x
230 |
231 |
232 | def mean(l, ignore_nan=False, empty=0):
233 | """
234 | nanmean compatible with generators.
235 | """
236 | l = iter(l)
237 | if ignore_nan:
238 | l = ifilterfalse(isnan, l)
239 | try:
240 | n = 1
241 | acc = next(l)
242 | except StopIteration:
243 | if empty == 'raise':
244 | raise ValueError('Empty mean')
245 | return empty
246 | for n, v in enumerate(l, 2):
247 | acc += v
248 | if n == 1:
249 | return acc
250 | return acc / n
--------------------------------------------------------------------------------
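
Usage note for `utils/losses/lovasz_losses.py` above: a minimal sketch of calling `lovasz_softmax` on softmax-normalized scores (hypothetical shapes and class count, assuming the repo root is on `PYTHONPATH`; the loss expects probabilities, not raw logits).

```python
import torch
import torch.nn.functional as F
from utils.losses.lovasz_losses import lovasz_softmax

# fake batch: 2 images, 19 classes (Cityscapes train ids), 64x64 crops
logits = torch.randn(2, 19, 64, 64, requires_grad=True)
labels = torch.randint(0, 19, (2, 64, 64))
labels[:, :8, :8] = 255                      # mark some pixels as void

probas = F.softmax(logits, dim=1)            # lovasz_softmax expects probabilities
loss = lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255)
loss.backward()
print(loss.item())
```
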
/utils/metric/SegmentationMetric.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
3 | Added metrics: Precision, Recall, F1-Score
4 | """
5 | import numpy as np
6 | np.seterr(divide='ignore', invalid='ignore')
7 | __all__ = ['SegmentationMetric']
8 |
9 | """
10 | confusionMetric  # Note: here the horizontal direction represents the predicted class and the vertical direction the ground truth, the opposite of the convention introduced earlier
11 | P\L P N
12 | P TP FP
13 | N FN TN
14 | sum(axis=0) TP+FN
15 | sum(axis=1) TP+FP
16 | np.diag().sum() TP+TN
17 | """
18 | class SegmentationMetric(object):
19 | def __init__(self, numClass):
20 | self.numClass = numClass
21 | self.confusionMatrix = np.zeros((self.numClass,)*2)
22 |
23 | def pixelAccuracy(self):
24 | # return all class overall pixel accuracy
25 | # PA = acc = (TP + TN) / (TP + TN + FP + FN)
26 | acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
27 | return acc
28 |
29 | def meanPixelAccuracy(self):
30 |         # return the pixel accuracy of each class (sometimes also referred to as per-class precision)
31 |         # acc = TP / (TP + FP)
32 |         Cpa = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
33 |         Mpa = np.nanmean(Cpa)  # mean of the per-class accuracies
34 |         return Mpa, Cpa  # Cpa is an array such as [0.90, 0.80, 0.96], one accuracy value per class
35 |
36 |
37 | def meanIntersectionOverUnion(self):
38 | # Intersection = TP ;Union = TP + FP + FN
39 | # Ciou = TP / (TP + FP + FN)
40 |         intersection = np.diag(self.confusionMatrix)  # diagonal elements, one value per class
41 |         union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(self.confusionMatrix)  # axis=1 sums the rows, axis=0 sums the columns
42 | 
43 |         Ciou = (intersection / np.maximum(1.0, union))  # per-class IoU
44 |         mIoU = np.nanmean(Ciou)  # mean of the per-class IoUs
45 | return mIoU, Ciou
46 |
47 | def Frequency_Weighted_Intersection_over_Union(self):
48 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)]
49 | freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
50 | iu = np.diag(self.confusionMatrix) / (
51 | np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
52 | np.diag(self.confusionMatrix))
53 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
54 | return FWIoU
55 |
56 | def precision(self):
57 | # precision = TP / (TP+FP)
58 | precision = np.diag(self.confusionMatrix) / np.sum(self.confusionMatrix, axis=1)
59 | return precision
60 |
61 | def recall(self):
62 | # recall = TP / (TP+FN)
63 | recall = np.diag(self.confusionMatrix) / np.sum(self.confusionMatrix, axis=0)
64 | return recall
65 |
66 |     def genConfusionMatrix(self, imgPredict, imgLabel):  # same as the fast_hist() function in FCN's score.py
67 | # remove classes from unlabeled pixels in gt image and predict
68 | mask = (imgLabel >= 0) & (imgLabel < self.numClass)
69 | label = self.numClass * imgLabel[mask].astype('int') + imgPredict[mask]
70 | count = np.bincount(label, minlength=self.numClass**2)
71 | confusionMatrix = count.reshape(self.numClass, self.numClass)
72 | return confusionMatrix
73 |
74 | def addBatch(self, imgPredict, imgLabel):
75 | assert imgPredict.shape == imgLabel.shape
76 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)
77 |
78 | def reset(self):
79 | self.confusionMatrix = np.zeros((self.numClass, self.numClass))
80 |
81 | if __name__ == '__main__':
82 |     imgPredict = np.array([0, 0, 1, 1, 2, 2])  # can be replaced with a real prediction map
83 |     imgLabel = np.array([0, 0, 1, 1, 1, 2])  # can be replaced with a real label map
84 |     metric = SegmentationMetric(3)  # 3 is the number of classes; set it to however many classes you have
85 | metric.addBatch(imgPredict, imgLabel)
86 | pa = metric.pixelAccuracy()
87 |     mpa, cpa = metric.meanPixelAccuracy()  # returns (mean accuracy, per-class accuracy array)
88 |     mIoU, per = metric.meanIntersectionOverUnion()
89 |     print('pa is : %f' % pa)
90 |     print('cpa is :', cpa)  # per-class accuracy array
91 |     print('mpa is : %f' % mpa)
92 |     print('mIoU is : %f' % mIoU)
93 |     print('per class iou is :', per)
94 |
--------------------------------------------------------------------------------
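
Usage note for `SegmentationMetric` above: a hedged sketch of how it is typically fed during validation (hypothetical tensors and class count). Predictions are argmax'd over the class dimension and converted to flat numpy integer arrays before `addBatch`.

```python
import torch
from utils.metric.SegmentationMetric import SegmentationMetric

num_classes = 19
metric = SegmentationMetric(num_classes)

# pretend these came from a dataloader / model forward pass
logits = torch.randn(2, num_classes, 64, 64)    # [B, C, H, W]
labels = torch.randint(0, num_classes, (2, 64, 64))

preds = logits.argmax(dim=1)                    # [B, H, W] class ids
metric.addBatch(preds.cpu().numpy(), labels.cpu().numpy())

mIoU, per_class_iou = metric.meanIntersectionOverUnion()
pa = metric.pixelAccuracy()
print('mIoU: %.4f, PA: %.4f' % (mIoU, pa))
metric.reset()                                  # clear the confusion matrix before the next epoch
```
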
/utils/metric/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/utils/metric/__init__.py
--------------------------------------------------------------------------------
/utils/optim/AdamW.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 | class AdamW(Optimizer):
6 |     """Implements the AdamW algorithm (Adam with decoupled weight decay).
7 |     The underlying Adam update has been proposed in `Adam: A Method for Stochastic Optimization`_.
8 | Arguments:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float, optional): learning rate (default: 1e-3)
12 | betas (Tuple[float, float], optional): coefficients used for computing
13 | running averages of gradient and its square (default: (0.9, 0.999))
14 | eps (float, optional): term added to the denominator to improve
15 | numerical stability (default: 1e-8)
16 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
17 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
18 | algorithm from the paper `On the Convergence of Adam and Beyond`_
19 | .. _Adam\: A Method for Stochastic Optimization:
20 | https://arxiv.org/abs/1412.6980
21 | .. _On the Convergence of Adam and Beyond:
22 | https://openreview.net/forum?id=ryQu7f-RZ
23 | """
24 |
25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
26 | weight_decay=0, amsgrad=False):
27 | if not 0.0 <= lr:
28 | raise ValueError("Invalid learning rate: {}".format(lr))
29 | if not 0.0 <= eps:
30 | raise ValueError("Invalid epsilon value: {}".format(eps))
31 | if not 0.0 <= betas[0] < 1.0:
32 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
33 | if not 0.0 <= betas[1] < 1.0:
34 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
35 | defaults = dict(lr=lr, betas=betas, eps=eps,
36 | weight_decay=weight_decay, amsgrad=amsgrad)
37 | super(AdamW, self).__init__(params, defaults)
38 |
39 | def __setstate__(self, state):
40 | super(AdamW, self).__setstate__(state)
41 | for group in self.param_groups:
42 | group.setdefault('amsgrad', False)
43 |
44 | def step(self, closure=None):
45 | """Performs a single optimization step.
46 | Arguments:
47 | closure (callable, optional): A closure that reevaluates the model
48 | and returns the loss.
49 | """
50 | loss = None
51 | if closure is not None:
52 | loss = closure()
53 |
54 | for group in self.param_groups:
55 | for p in group['params']:
56 | if p.grad is None:
57 | continue
58 | grad = p.grad.data
59 | if grad.is_sparse:
60 |                     raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')
61 | amsgrad = group['amsgrad']
62 |
63 | state = self.state[p]
64 |
65 | # State initialization
66 | if len(state) == 0:
67 | state['step'] = 0
68 | # Exponential moving average of gradient values
69 | state['exp_avg'] = torch.zeros_like(p.data)
70 | # Exponential moving average of squared gradient values
71 | state['exp_avg_sq'] = torch.zeros_like(p.data)
72 | if amsgrad:
73 | # Maintains max of all exp. moving avg. of sq. grad. values
74 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
75 |
76 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
77 | if amsgrad:
78 | max_exp_avg_sq = state['max_exp_avg_sq']
79 | beta1, beta2 = group['betas']
80 |
81 | state['step'] += 1
82 |
83 | # if group['weight_decay'] != 0:
84 | # grad = grad.add(group['weight_decay'], p.data)
85 |
86 | # Decay the first and second moment running average coefficient
87 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
88 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
89 | if amsgrad:
90 | # Maintains the maximum of all 2nd moment running avg. till now
91 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
92 | # Use the max. for normalizing running avg. of gradient
93 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
94 | else:
95 | denom = exp_avg_sq.sqrt().add_(group['eps'])
96 |
97 | bias_correction1 = 1 - beta1 ** state['step']
98 | bias_correction2 = 1 - beta2 ** state['step']
99 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
100 |
101 | # p.data.addcdiv_(-step_size, exp_avg, denom)
102 | p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) )
103 |
104 | return loss
--------------------------------------------------------------------------------
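
Usage note for the AdamW above: a minimal sketch with hypothetical hyper-parameters. Unlike plain Adam with L2 regularization, the weight decay here is applied directly to the parameters inside the update step (decoupled). The file uses the legacy positional-scalar `add_`/`addcmul_` signatures, so this assumes the older PyTorch version targeted by `requirements.txt`.

```python
import torch
from utils.optim.AdamW import AdamW

model = torch.nn.Conv2d(3, 19, kernel_size=3, padding=1)
optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)

x = torch.randn(2, 3, 64, 64)
loss = model(x).mean()
loss.backward()
optimizer.step()        # Adam moment update + decoupled weight decay
optimizer.zero_grad()
```
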
/utils/optim/Lookahead.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from itertools import chain
3 | from torch.optim import Optimizer
4 | import torch
5 | import warnings
6 |
7 | class Lookahead(Optimizer):
8 | def __init__(self, optimizer, k=5, alpha=0.5):
9 | self.optimizer = optimizer
10 | self.k = k
11 | self.alpha = alpha
12 | self.param_groups = self.optimizer.param_groups
13 | self.state = defaultdict(dict)
14 | self.fast_state = self.optimizer.state
15 | for group in self.param_groups:
16 | group["counter"] = 0
17 |
18 | def update(self, group):
19 | for fast in group["params"]:
20 | param_state = self.state[fast]
21 | if "slow_param" not in param_state:
22 | param_state["slow_param"] = torch.zeros_like(fast.data)
23 | param_state["slow_param"].copy_(fast.data)
24 | slow = param_state["slow_param"]
25 | slow += (fast.data - slow) * self.alpha
26 | fast.data.copy_(slow)
27 |
28 | def update_lookahead(self):
29 | for group in self.param_groups:
30 | self.update(group)
31 |
32 | def step(self, closure=None):
33 | loss = self.optimizer.step(closure)
34 | for group in self.param_groups:
35 | if group["counter"] == 0:
36 | self.update(group)
37 | group["counter"] += 1
38 | if group["counter"] >= self.k:
39 | group["counter"] = 0
40 | return loss
41 |
42 | def state_dict(self):
43 | fast_state_dict = self.optimizer.state_dict()
44 | slow_state = {
45 | (id(k) if isinstance(k, torch.Tensor) else k): v
46 | for k, v in self.state.items()
47 | }
48 | fast_state = fast_state_dict["state"]
49 | param_groups = fast_state_dict["param_groups"]
50 | return {
51 | "fast_state": fast_state,
52 | "slow_state": slow_state,
53 | "param_groups": param_groups,
54 | }
55 |
56 | def load_state_dict(self, state_dict):
57 | slow_state_dict = {
58 | "state": state_dict["slow_state"],
59 | "param_groups": state_dict["param_groups"],
60 | }
61 | fast_state_dict = {
62 | "state": state_dict["fast_state"],
63 | "param_groups": state_dict["param_groups"],
64 | }
65 | super(Lookahead, self).load_state_dict(slow_state_dict)
66 | self.optimizer.load_state_dict(fast_state_dict)
67 | self.fast_state = self.optimizer.state
68 |
69 | def add_param_group(self, param_group):
70 | param_group["counter"] = 0
71 | self.optimizer.add_param_group(param_group)
--------------------------------------------------------------------------------
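
Usage note for `Lookahead` above: it wraps an existing optimizer and interpolates the fast weights back toward a slow copy every `k` steps. A minimal sketch with a hypothetical model and hyper-parameters; gradients are zeroed through the wrapped optimizer.

```python
import torch
from utils.optim.Lookahead import Lookahead

model = torch.nn.Linear(10, 2)
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)   # slow weights sync every k steps

for step in range(10):
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()            # inner Adam step + lookahead interpolation every k steps
    base_optimizer.zero_grad()  # zero grads on the wrapped optimizer
```
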
/utils/optim/RAdam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 | class RAdam(Optimizer):
6 |
7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
9 | self.buffer = [[None, None, None] for ind in range(10)]
10 | super(RAdam, self).__init__(params, defaults)
11 |
12 | def __setstate__(self, state):
13 | super(RAdam, self).__setstate__(state)
14 |
15 | def step(self, closure=None):
16 |
17 | loss = None
18 | if closure is not None:
19 | loss = closure()
20 |
21 | for group in self.param_groups:
22 |
23 | for p in group['params']:
24 | if p.grad is None:
25 | continue
26 | grad = p.grad.data.float()
27 | if grad.is_sparse:
28 | raise RuntimeError('RAdam does not support sparse gradients')
29 |
30 | p_data_fp32 = p.data.float()
31 |
32 | state = self.state[p]
33 |
34 | if len(state) == 0:
35 | state['step'] = 0
36 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
38 | else:
39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
41 |
42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
43 | beta1, beta2 = group['betas']
44 |
45 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
46 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
47 |
48 | state['step'] += 1
49 | buffered = self.buffer[int(state['step'] % 10)]
50 | if state['step'] == buffered[0]:
51 | N_sma, step_size = buffered[1], buffered[2]
52 | else:
53 | buffered[0] = state['step']
54 | beta2_t = beta2 ** state['step']
55 | N_sma_max = 2 / (1 - beta2) - 1
56 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
57 | buffered[1] = N_sma
58 |
59 | # more conservative since it's an approximated value
60 | if N_sma >= 5:
61 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
62 | else:
63 | step_size = group['lr'] / (1 - beta1 ** state['step'])
64 | buffered[2] = step_size
65 |
66 | if group['weight_decay'] != 0:
67 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
68 |
69 | # more conservative since it's an approximated value
70 | if N_sma >= 5:
71 | denom = exp_avg_sq.sqrt().add_(group['eps'])
72 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
73 | else:
74 | p_data_fp32.add_(-step_size, exp_avg)
75 |
76 | p.data.copy_(p_data_fp32)
77 |
78 | return loss
--------------------------------------------------------------------------------
/utils/optim/Ranger.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 | import itertools as it
5 |
6 |
7 |
8 | class Ranger(Optimizer):
9 |
10 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0):
11 | #parameter checks
12 | if not 0.0 <= alpha <= 1.0:
13 | raise ValueError(f'Invalid slow update rate: {alpha}')
14 | if not 1 <= k:
15 | raise ValueError(f'Invalid lookahead steps: {k}')
16 | if not lr > 0:
17 | raise ValueError(f'Invalid Learning Rate: {lr}')
18 | if not eps > 0:
19 | raise ValueError(f'Invalid eps: {eps}')
20 |
21 | #parameter comments:
22 | # beta1 (momentum) of .95 seems to work better than .90...
23 | #N_sma_threshold of 5 seems better in testing than 4.
24 | #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
25 |
26 | #prep defaults and init torch.optim base
27 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
28 | super().__init__(params,defaults)
29 |
30 | #adjustable threshold
31 | self.N_sma_threshhold = N_sma_threshhold
32 |
33 | #now we can get to work...
34 | #removed as we now use step from RAdam...no need for duplicate step counting
35 | #for group in self.param_groups:
36 | # group["step_counter"] = 0
37 | #print("group step counter init")
38 |
39 | #look ahead params
40 | self.alpha = alpha
41 | self.k = k
42 |
43 | #radam buffer for state
44 | self.radam_buffer = [[None,None,None] for ind in range(10)]
45 |
46 | #self.first_run_check=0
47 |
48 | #lookahead weights
49 | #9/2/19 - lookahead param tensors have been moved to state storage.
50 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
51 |
52 | #self.slow_weights = [[p.clone().detach() for p in group['params']]
53 | # for group in self.param_groups]
54 |
55 | #don't use grad for lookahead weights
56 | #for w in it.chain(*self.slow_weights):
57 | # w.requires_grad = False
58 |
59 | def __setstate__(self, state):
60 | print("set state called")
61 | super(Ranger, self).__setstate__(state)
62 |
63 |
64 | def step(self, closure=None):
65 | loss = None
66 | #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
67 | #Uncomment if you need to use the actual closure...
68 |
69 | #if closure is not None:
70 | #loss = closure()
71 |
72 | #Evaluate averages and grad, update param tensors
73 | for group in self.param_groups:
74 |
75 | for p in group['params']:
76 | if p.grad is None:
77 | continue
78 | grad = p.grad.data.float()
79 | if grad.is_sparse:
80 | raise RuntimeError('Ranger optimizer does not support sparse gradients')
81 |
82 | p_data_fp32 = p.data.float()
83 |
84 | state = self.state[p] #get state dict for this param
85 |
86 | if len(state) == 0: #if first time to run...init dictionary with our desired entries
87 | #if self.first_run_check==0:
88 | #self.first_run_check=1
89 | #print("Initializing slow buffer...should not see this at load from saved model!")
90 | state['step'] = 0
91 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
92 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
93 |
94 | #look ahead weight storage now in state dict
95 | state['slow_buffer'] = torch.empty_like(p.data)
96 | state['slow_buffer'].copy_(p.data)
97 |
98 | else:
99 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
100 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
101 |
102 | #begin computations
103 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
104 | beta1, beta2 = group['betas']
105 |
106 | #compute variance mov avg
107 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
108 | #compute mean moving avg
109 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
110 |
111 | state['step'] += 1
112 |
113 |
114 | buffered = self.radam_buffer[int(state['step'] % 10)]
115 | if state['step'] == buffered[0]:
116 | N_sma, step_size = buffered[1], buffered[2]
117 | else:
118 | buffered[0] = state['step']
119 | beta2_t = beta2 ** state['step']
120 | N_sma_max = 2 / (1 - beta2) - 1
121 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
122 | buffered[1] = N_sma
123 | if N_sma > self.N_sma_threshhold:
124 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
125 | else:
126 | step_size = 1.0 / (1 - beta1 ** state['step'])
127 | buffered[2] = step_size
128 |
129 | if group['weight_decay'] != 0:
130 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
131 |
132 | if N_sma > self.N_sma_threshhold:
133 | denom = exp_avg_sq.sqrt().add_(group['eps'])
134 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
135 | else:
136 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
137 |
138 | p.data.copy_(p_data_fp32)
139 |
140 | #integrated look ahead...
141 | #we do it at the param level instead of group level
142 | if state['step'] % group['k'] == 0:
143 | slow_p = state['slow_buffer'] #get access to slow param tensor
144 | slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha
145 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor
146 |
147 | return loss
--------------------------------------------------------------------------------
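
Usage note for `Ranger` above: it fuses the RAdam update with the lookahead interpolation, so it is used as a standalone optimizer rather than as a wrapper. A minimal sketch with hypothetical hyper-parameters, under the same legacy-PyTorch assumption as the other optimizers in this folder; note that `step()` here ignores its `closure` argument.

```python
import torch
from utils.optim.Ranger import Ranger

model = torch.nn.Linear(10, 2)
optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6, weight_decay=1e-4)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()        # RAdam update, then slow/fast interpolation every k steps
model.zero_grad()
```
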
/utils/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .RAdam import *
2 | from .AdamW import *
3 | from .Lookahead import *
4 | from .Ranger import *
--------------------------------------------------------------------------------
/utils/plot_log.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | from matplotlib import pyplot as plt
4 |
5 |
6 | def draw_log(args, epoch, epoch_list, lossTr_list, mIOU_val_list, lossVal_list):
7 | f = open(args.savedir + 'log.txt', 'r')
8 | next(f)
9 | if args.val_epochs == 1:
10 | try:
11 | assert len(range(1, epoch + 1)) == len(lossTr_list)
12 | assert len(range(1, epoch + 1)) == len(lossVal_list)
13 |         except AssertionError:
14 |             print('Plot dimensions are wrong! Please check log.txt!\n')
15 | else:
16 | # plt loss
17 | fig1, ax1 = plt.subplots(figsize=(11, 8))
18 | ax1.plot(range(1, epoch + 1), lossTr_list, label='Train_loss')
19 | ax1.plot(range(1, epoch + 1), lossVal_list, label='Val_loss')
20 | ax1.set_title("Average training loss vs epochs")
21 | ax1.set_xlabel("Epochs")
22 | ax1.set_ylabel("Current loss")
23 | ax1.legend()
24 | plt.savefig(args.savedir + "loss.png")
25 | plt.close('all')
26 | plt.clf()
27 | # plt Miou
28 | fig2, ax2 = plt.subplots(figsize=(11, 8))
29 | ax2.plot(range(1, epoch + 1), mIOU_val_list, label="Val IoU")
30 | ax2.set_title("Average IoU vs epochs")
31 | ax2.set_xlabel("Epochs")
32 | ax2.set_ylabel("Current IoU")
33 | ax2.legend()
34 | plt.savefig(args.savedir + "mIou.png")
35 | plt.close('all')
36 | else:
37 | # plt loss
38 | fig1, ax1 = plt.subplots(figsize=(11, 8))
39 | try:
40 | assert len(epoch_list) == len(lossVal_list)
41 | assert len(range(1, epoch + 1)) == len(lossTr_list)
42 |         except AssertionError:
43 |             print('Plot dimensions are wrong! Please check log.txt!\n')
44 | else:
45 | ax1.plot(range(1, epoch + 1), lossTr_list, label='Train_loss')
46 | ax1.plot(epoch_list, lossVal_list, label='Val_loss')
47 | ax1.set_title("Average loss vs epochs")
48 | ax1.set_xlabel("Epochs")
49 | ax1.set_ylabel("Current loss")
50 | ax1.legend()
51 | plt.savefig(args.savedir + "loss.png")
52 | plt.clf()
53 | # plt Miou
54 | fig2, ax2 = plt.subplots(figsize=(11, 8))
55 | ax2.plot(epoch_list, mIOU_val_list, label="Val IoU")
56 | ax2.set_title("Average IoU vs epochs")
57 | ax2.set_xlabel("Epochs")
58 | ax2.set_ylabel("Current IoU")
59 | ax2.legend()
60 | plt.savefig(args.savedir + "mIou.png")
61 | plt.close('all')
--------------------------------------------------------------------------------
/utils/record_log.py:
--------------------------------------------------------------------------------
1 | from prettytable import PrettyTable
2 |
3 |
4 | class record_log():
5 | def __init__(self, args):
6 | self.args = args
7 |
8 | def record_args(self, datas, total_paramters, GLOBAL_SEED):
9 | with open(self.args.savedir + 'args.txt', 'w') as f:
10 | t = PrettyTable(['args_name', 'args_value'])
11 | for k in list(vars(self.args).keys()):
12 | t.add_row([k, vars(self.args)[k]])
13 | t.add_row(['seed', GLOBAL_SEED])
14 | t.add_row(['parameters', total_paramters])
15 | t.add_row(['mean', datas['mean']])
16 | t.add_row(['std', datas['std']])
17 | t.add_row(['classWeights', datas['classWeights']])
18 | print(t.get_string(title="Train Arguments"))
19 | f.write(str(t))
20 |
21 | def record_best_epoch(self, epoch, Best_Miou, Pa):
22 | with open(self.args.savedir + 'args.txt', 'a+') as f:
23 | f.write('\nBest Validation Epoch {} Best_Miou is {} OA is {}'.format(epoch, Best_Miou, Pa))
24 |
25 | def initial_logfile(self):
26 | logFileLoc = self.args.savedir + self.args.logFile
27 | logger = open(logFileLoc, 'w')
28 | logger.write(("{}\t{}\t\t{}\t{}\t{}\t{}\t{}\t{}\t\t{}\t\t{}\n".format(
29 | 'Epoch', ' lr', 'Loss(Tr)', 'Loss(Val)', 'FWIOU(Val)', 'mIOU(Val)', 'Pa(Val)', ' Mpa(Val)',
30 | 'PerMiou_set(Val)',' Cpa_set(Val)')))
31 | return logger
32 |
33 | def resume_logfile(self):
34 | logFileLoc = self.args.savedir + self.args.logFile
35 | logger_recored = open(logFileLoc, 'r')
36 | next(logger_recored)
37 | lines = logger_recored.readlines()
38 | logger_recored.close()
39 | logger = open(logFileLoc, 'a+')
40 | return logger, lines
41 |
42 | def record_trainVal_log(self, logger, epoch, lr, lossTr, val_loss, FWIoU, mIOU_val, MIoU, PerMiou_set, Pa_Val, Mpa_Val,
43 | Cpa_set, MF, F_set, F1_avg, class_dict_df):
44 | logger.write("{}\t{:.6f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t {:.4f} {} {}\n".format(
45 | epoch, lr, lossTr, val_loss, FWIoU, mIOU_val, MIoU, Pa_Val, Mpa_Val, PerMiou_set, Cpa_set))
46 | logger.flush()
47 | print("Epoch {} lr={:.6f} Train Loss={:.4f} Val Loss={:.4f}".format(epoch, lr, lossTr, val_loss))
48 |
49 |
50 | def record_train_log(self, logger, epoch, lr, lossTr):
51 | logger.write("{}\t{:.6f}\t{:.4f}\n".format(epoch, lr, lossTr))
52 | logger.flush()
53 | print("Epoch {} lr={:.6f} Train Loss={:.4f}".format(epoch, lr, lossTr))
54 |
--------------------------------------------------------------------------------
/utils/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import *
2 |
--------------------------------------------------------------------------------
/utils/scheduler/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
3 |
4 |
5 | class WarmupMultiStepLR(MultiStepLR):
6 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
7 | warmup_iters=500, last_epoch=-1):
8 | self.warmup_factor = warmup_factor
9 | self.warmup_iters = warmup_iters
10 | super().__init__(optimizer, milestones, gamma, last_epoch)
11 |
12 | def get_lr(self):
13 | if self.last_epoch <= self.warmup_iters:
14 | alpha = self.last_epoch / self.warmup_iters
15 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
16 | # print(self.base_lrs[0]*warmup_factor)
17 | return [lr * warmup_factor for lr in self.base_lrs]
18 | else:
19 | lr = super().get_lr()
20 | return lr
21 |
22 |
23 | class WarmupCosineLR(_LRScheduler):
24 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500,
25 | eta_min=0, last_epoch=-1):
26 | self.warmup_factor = warmup_factor
27 | self.warmup_iters = warmup_iters
28 | self.T_max, self.eta_min = T_max, eta_min
29 | super().__init__(optimizer, last_epoch)
30 |
31 | def get_lr(self):
32 | if self.last_epoch <= self.warmup_iters:
33 | alpha = self.last_epoch / self.warmup_iters
34 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
35 | # print(self.base_lrs[0]*warmup_factor)
36 | return [lr * warmup_factor for lr in self.base_lrs]
37 | else:
38 | return [self.eta_min + (base_lr - self.eta_min) *
39 | (1 + math.cos(
40 | math.pi * (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters))) / 2
41 | for base_lr in self.base_lrs]
42 |
43 |
44 |
45 | class WarmupPolyLR(_LRScheduler):
46 | def __init__(self, optimizer, T_max, cur_iter, warmup_factor=1.0 / 3, warmup_iters=500,
47 | eta_min=0, power=0.9):
48 | self.warmup_factor = warmup_factor
49 | self.warmup_iters = warmup_iters
50 | self.power = power
51 | self.T_max, self.eta_min = T_max, eta_min
52 | self.cur_iter = cur_iter
53 | super().__init__(optimizer)
54 |
55 | def get_lr(self):
56 | if self.cur_iter <= self.warmup_iters:
57 | alpha = self.cur_iter / self.warmup_iters
58 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
59 | # print(self.base_lrs[0]*warmup_factor)
60 | return [lr * warmup_factor for lr in self.base_lrs]
61 | else:
62 | return [self.eta_min + (base_lr - self.eta_min) *
63 | math.pow(1 - (self.cur_iter - self.warmup_iters) / (self.T_max - self.warmup_iters),
64 | self.power) for base_lr in self.base_lrs]
65 |
66 |
67 | class PolyLR(_LRScheduler):
68 | def __init__(self, optimizer, max_iter, cur_iter, power=0.9):
69 | self.power = power
70 | self.max_iter = max_iter
71 | self.cur_iter = cur_iter
72 | super().__init__(optimizer)
73 |
74 | def get_lr(self):
75 |
76 | return [base_lr * math.pow(1 - (self.cur_iter / self.max_iter), self.power) for base_lr in self.base_lrs]
77 |
78 |
79 |
80 | class GradualWarmupScheduler(_LRScheduler):
81 | """ Gradually warm-up(increasing) learning rate in optimizer.
82 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
83 | Args:
84 | optimizer (Optimizer): Wrapped optimizer.
85 | min_lr_mul: target learning rate = base lr * min_lr_mul
86 | total_epoch: target learning rate is reached at total_epoch, gradually
87 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
88 | """
89 |
90 | def __init__(self, optimizer, total_epoch, min_lr_mul=0.1, after_scheduler=None):
91 | self.min_lr_mul = min_lr_mul
92 | if self.min_lr_mul > 1. or self.min_lr_mul < 0.:
93 |             raise ValueError('min_lr_mul should be in the range [0., 1.]')
94 | self.total_epoch = total_epoch
95 | self.after_scheduler = after_scheduler
96 | self.finished = False
97 | super(GradualWarmupScheduler, self).__init__(optimizer)
98 |
99 | def get_lr(self):
100 | if self.last_epoch > self.total_epoch:
101 | if self.after_scheduler:
102 | if not self.finished:
103 | self.after_scheduler.base_lrs = self.base_lrs
104 | self.finished = True
105 | return self.after_scheduler.get_lr()
106 | else:
107 | return self.base_lrs
108 | else:
109 | return [base_lr * (self.min_lr_mul + (1. - self.min_lr_mul) * (self.last_epoch / float(self.total_epoch))) for base_lr in self.base_lrs]
110 |
111 | def step(self, epoch=None):
112 | if self.finished and self.after_scheduler:
113 | return self.after_scheduler.step(epoch - self.total_epoch)
114 | else:
115 | return super(GradualWarmupScheduler, self).step(epoch)
116 |
117 |
118 |
119 |
120 | if __name__ == '__main__':
121 |     import torch
122 |     model = torch.nn.Linear(2, 2)
123 |     optim = WarmupPolyLR(torch.optim.SGD(model.parameters(), lr=0.01), T_max=100, cur_iter=0)
124 |     print(optim.get_lr())
--------------------------------------------------------------------------------
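
Usage note for the poly schedulers above: `WarmupPolyLR` and `PolyLR` take the current iteration in the constructor instead of tracking it with `step()`. One way to drive them (a sketch only, not necessarily how `train.py` does it) is to rebuild the scheduler every iteration; constructing it writes the warmed-up or decayed learning rate into the optimizer's param groups, and `initial_lr` keeps the base learning rate from compounding across rebuilds.

```python
import torch
from utils.scheduler.lr_scheduler import WarmupPolyLR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

max_iter = 1000          # e.g. epochs * batches_per_epoch
for cur_iter in range(max_iter):
    # rebuilding the scheduler lets it see the current step; construction
    # already writes the new lr into optimizer.param_groups
    scheduler = WarmupPolyLR(optimizer, T_max=max_iter, cur_iter=cur_iter,
                             warmup_factor=1.0 / 3, warmup_iters=500, power=0.9)
    lr = optimizer.param_groups[0]['lr']
    # ... forward / backward / optimizer.step() would go here ...
```
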
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Time: 2020/11/30 5:02 PM
3 | Author: Ding Cheng(Deeachain)
4 | File: utils.py
5 | Describe: Written during my studies at Nanjing University of Information Science and Technology
6 | Github: https://github.com/Deeachain
7 | """
8 | import os
9 | import random
10 | import numpy as np
11 | from PIL import Image
12 | import torch
13 | import torch.nn as nn
14 | from utils.colorize_mask import cityscapes_colorize_mask, paris_colorize_mask, road_colorize_mask, \
15 | austin_colorize_mask, isprs_colorize_mask
16 |
17 |
18 | def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
19 | **kwargs):
20 | for name, m in feature.named_modules():
21 | if isinstance(m, (nn.Conv2d, nn.Conv3d)):
22 | conv_init(m.weight, **kwargs)
23 | elif isinstance(m, norm_layer):
24 | m.eps = bn_eps
25 | m.momentum = bn_momentum
26 | nn.init.constant_(m.weight, 1)
27 | nn.init.constant_(m.bias, 0)
28 |
29 |
30 | def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
31 | **kwargs):
32 | if isinstance(module_list, list):
33 | for feature in module_list:
34 | __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
35 | **kwargs)
36 | else:
37 | __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
38 | **kwargs)
39 |
40 |
41 | def setup_seed(seed):
42 | torch.manual_seed(seed)
43 | torch.cuda.manual_seed_all(seed)
44 | np.random.seed(seed)
45 | random.seed(seed)
46 |
47 |
48 | def save_predict(output, gt, img_name, dataset, save_path, output_grey=False, output_color=True, gt_color=False):
49 |
50 | if output_grey:
51 | if dataset == 'cityscapes':
52 | output[np.where(output==18)] = 33
53 | output[np.where(output==17)] = 32
54 | output[np.where(output==16)] = 31
55 | output[np.where(output==15)] = 28
56 | output[np.where(output==14)] = 27
57 | output[np.where(output==13)] = 26
58 | output[np.where(output==12)] = 25
59 | output[np.where(output==11)] = 24
60 | output[np.where(output==10)] = 23
61 | output[np.where(output==9)] = 22
62 | output[np.where(output==8)] = 21
63 | output[np.where(output==7)] = 20
64 | output[np.where(output==6)] = 19
65 | output[np.where(output==5)] = 17
66 | output[np.where(output==4)] = 13
67 | output[np.where(output==3)] = 12
68 | output[np.where(output==2)] = 11
69 | output[np.where(output==1)] = 8
70 | output[np.where(output==0)] = 7
71 | output_grey = Image.fromarray(output)
72 | output_grey.save(os.path.join(save_path, img_name + '.png'))
73 |
74 | if output_color:
75 | if dataset == 'cityscapes':
76 | output_color = cityscapes_colorize_mask(output)
77 | elif dataset == 'paris':
78 | output_color = paris_colorize_mask(output)
79 | elif dataset == 'road':
80 | output_color = road_colorize_mask(output)
81 | elif dataset == 'austin':
82 | output_color = austin_colorize_mask(output)
83 | elif dataset == 'postdam' or dataset == 'vaihingen':
84 | output_color = isprs_colorize_mask(output)
85 | output_color.save(os.path.join(save_path, img_name + '_color.png'))
86 |
87 | if gt_color:
88 | if dataset == 'cityscapes':
89 | gt_color = cityscapes_colorize_mask(gt)
90 | elif dataset == 'paris':
91 | gt_color = paris_colorize_mask(gt)
92 | elif dataset == 'road':
93 | gt_color = road_colorize_mask(gt)
94 | elif dataset == 'austin':
95 | gt_color = austin_colorize_mask(gt)
96 | elif dataset == 'postdam' or dataset == 'vaihingen':
97 | gt_color = isprs_colorize_mask(gt)
98 | gt_color.save(os.path.join(save_path, img_name + '_gt.png'))
99 |
100 |
101 | def netParams(model):
102 | """
103 | computing total network parameters
104 | args:
105 | model: model
106 | return: the number of parameters
107 | """
108 | total_paramters = 0
109 | for parameter in model.parameters():
110 | i = len(parameter.size())
111 | p = 1
112 | for j in range(i):
113 | p *= parameter.size(j)
114 | total_paramters += p
115 |
116 | return total_paramters
117 |
--------------------------------------------------------------------------------
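
Usage note for the helpers above (a minimal sketch, assuming the repo root is on `PYTHONPATH`): `setup_seed` fixes the random seeds for reproducible runs, and `netParams` counts parameters, equivalent to summing `p.numel()` over `model.parameters()`.

```python
import torch
from utils.utils import setup_seed, netParams

setup_seed(2021)                       # make runs reproducible

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.Conv2d(16, 19, 1),
)
total = netParams(model)
assert total == sum(p.numel() for p in model.parameters())
print('total parameters: %d' % total)
```
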