├── README.md
├── augseg
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── augs_ALIA.py
│   │   ├── augs_TIBA.py
│   │   ├── base.py
│   │   ├── builder.py
│   │   ├── cityscapes.py
│   │   └── pascal_voc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── decoder.py
│   │   ├── model_helper.py
│   │   └── resnet.py
│   └── utils
│       ├── __init__.py
│       ├── dist_helper.py
│       ├── loss_helper.py
│       ├── lr_helper.py
│       └── utils.py
├── data
│   └── splitsall.tar.gz
├── docs
│   ├── Augseg-diagram.png
│   ├── augs-cutmix.png
│   └── augs-intensity.png
├── exps
│   ├── zrun_citys
│   │   ├── citys_semi744
│   │   │   └── config_semi.yaml
│   │   └── r50_citys_semi744
│   │       └── config_semi.yaml
│   ├── zrun_vocs
│   │   ├── r50_voc_semi662
│   │   │   └── config_semi.yaml
│   │   └── voc_semi_fine92_r50
│   │       └── config_semi.yaml
│   └── zrun_vocs_u2pl
│       ├── r50_voc_semi662
│       │   └── config_semi.yaml
│       └── voc_semi662
│           └── config_semi.yaml
├── requirements.txt
├── scripts
│   ├── bashtorch
│   ├── zsing_run_citys.sh
│   └── zsing_run_voc.sh
├── single_run.sh
├── train_semi.py
└── training-logs
    ├── 1 Sup-Voc
    │   ├── blender-u2pl
    │   │   ├── R101-1323
    │   │   │   └── seg_2022-09-24_10_58_06.log
    │   │   ├── R101-2646
    │   │   │   └── seg_2022-09-24_21_10_14.log
    │   │   ├── R101-662
    │   │   │   └── seg_2022-09-24_18_50_02.log
    │   │   ├── R50-1323
    │   │   │   └── seg_2022-09-24_10_51_42.log
    │   │   ├── R50-2646
    │   │   │   └── seg_2022-09-24_13_13_23.log
    │   │   └── R50-662
    │   │       └── seg_2022-09-23_22_55_28.log
    │   ├── blender
    │   │   ├── r101-1323
    │   │   │   └── seg_2022-09-24_21_07_36.log
    │   │   ├── r101-2646
    │   │   │   └── seg_2022-09-24_23_51_05.log
    │   │   ├── r101-662
    │   │   │   └── seg_2022-09-24_10_56_07.log
    │   │   ├── r50-1323
    │   │   │   └── seg_2022-09-24_12_20_19.log
    │   │   ├── r50-2646
    │   │   │   └── seg_2022-09-24_14_22_47.log
    │   │   └── r50-662
    │   │       └── seg_2022-09-24_10_45_00.log
    │   └── fine
    │       ├── R101-1464
    │       │   └── seg_2022-09-25_16_59_09.log
    │       ├── R101-183
    │       │   └── seg_2022-09-25_17_24_49.log
    │       ├── R101-366
    │       │   └── seg_2022-09-25_18_46_02.log
    │       ├── R101-732
    │       │   └── seg_2022-09-25_19_56_21.log
    │       ├── R101-92
    │       │   └── seg_2022-09-25_15_53_41.log
    │       ├── R50-1464
    │       │   └── seg_2022-09-25_00_02_47.log
    │       ├── R50-183
    │       │   └── seg_2022-09-25_10_35_51.log
    │       ├── R50-366
    │       │   └── seg_2022-09-25_12_25_53.log
    │       ├── R50-732
    │       │   └── seg_2022-09-25_13_38_56.log
    │       └── R50-92
    │           └── seg_2022-09-25_09_42_12.log
    ├── 2 Sup-Citys
    │   ├── r101-1488
    │   │   └── seg_2022-09-26_12_47_19.log
    │   ├── r101-186
    │   │   └── seg_2022-09-24_19_05_36.log
    │   ├── r101-372
    │   │   └── seg_2022-09-25_09_50_10.log
    │   ├── r101-744
    │   │   └── seg_2022-09-26_19_25_39.log
    │   ├── r50-1488
    │   │   └── seg_2022-09-25_12_18_19.log
    │   ├── r50-186
    │   │   └── seg_2022-09-23_22_52_31.log
    │   ├── r50-372
    │   │   └── seg_2022-09-24_11_06_55.log
    │   └── r50-744
    │       └── seg_2022-09-24_22_04_08.log
    ├── 3 Semi-Voc-fine
    │   ├── 92
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-05_18:19:47.log
    │   ├── 183
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-05_18:17:48.log
    │   ├── 366
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-05_17_56_47.log
    │   ├── 732
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-05_18:16:33.log
    │   └── 1464
    │       ├── config_semi.yaml
    │       └── seg_2022-11-05_18:14:50.log
    ├── 4 Semi-Voc-blender
    │   ├── r101-1323
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_02_44_22.log
    │   ├── r101-2646
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_23_43_51.log
    │   ├── r101-662
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_23_45_28.log
    │   ├── r50-1323
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-03_12_42_19.log
    │   ├── r50-2646
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-03_13_58_37.log
    │   └── r50-662
    │       ├── config_semi.yaml
    │       └── seg_2022-11-02_20_35_40.log
    ├── 5 Semi-Voc-blender-split-u2pl
    │   ├── r101-1323
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_02_19_35.log
    │   ├── r101-2646
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_15_38_27.log
    │   ├── r101-662
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_02_23_01.log
    │   ├── r50-1323
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_02_15_18.log
    │   ├── r50-2646
    │   │   ├── config_semi.yaml
    │   │   └── seg_2022-11-04_02_15_17.log
    │   └── r50-662
    │       ├── config_semi.yaml
    │       └── seg_2022-11-04_02_23_48.log
    └── 6 Semi-citys
        ├── R101-1488
        │   ├── config_semi.yaml
        │   └── seg_2022-11-08_23_12_47.log
        ├── R101-186
        │   ├── config_semi.yaml
        │   └── seg_2022-11-06_15_37_41.log
        ├── R101-372
        │   ├── config_semi.yaml
        │   └── seg_2022-11-06_22_15_34.log
        ├── R101-744
        │   ├── config_semi.yaml
        │   └── seg_2022-11-06_22:34:45.log
        ├── R50-1488
        │   ├── config_semi.yaml
        │   └── seg_2022-11-03_14_17_21.log
        ├── R50-186
        │   ├── config_semi.yaml
        │   └── seg_2022-11-03_14_38_38.log
        ├── R50-372
        │   ├── config_semi.yaml
        │   └── seg_2022-11-06_13_28_34.log
        └── R50-744
            ├── config_semi.yaml
            └── seg_2022-11-05_18_28_53.log
/README.md:
--------------------------------------------------------------------------------
1 | # AugSeg
2 |
3 | > "Augmentation Matters: A Simple-yet-Effective Approach to Semi-supervised Semantic Segmentation".
4 |
5 |
6 |
7 | ### Introduction
8 |
9 |
10 | - Recent studies on semi-supervised semantic segmentation (SSS) have made fast progress. Despite their promising performance, current state-of-the-art methods tend toward increasingly complex designs, at the cost of introducing more network components and additional training procedures.
11 |
12 | - In contrast, we follow a standard teacher-student framework and propose AugSeg, a simple and clean approach that focuses mainly on data perturbations to boost SSS performance. We argue that data augmentations should be adjusted to the semi-supervised scenario instead of being applied directly as in supervised learning. Specifically, we adopt a simplified intensity-based augmentation that selects a random number of data transformations, with distortion strengths sampled uniformly from a continuous space. Based on the model's estimated confidence on each unlabeled sample, we also randomly inject labeled information to augment the unlabeled samples in an adaptive manner.
13 |
14 | - Without bells and whistles, our simple AugSeg can readily achieve new state-of-the-art performance on SSS benchmarks under different partition protocols.
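The two ideas above can be written out as follows. This is our reading of the code further down (`strong_img_aug` in `augs_TIBA.py` and `cut_mix_label_adaptive` in `augs_ALIA.py`), given as a sketch. The symbols are:

- $N$: the `num_augs` setting.
- $\mathcal{T}$: the op list returned by `get_augment_list`.
- $x^{w}$: the weakly augmented image that the strong transform receives.
- $\rho_i$: the entry of `lst_confidences` for unlabeled sample $u_i$, assumed to lie in $[0,1]$.

$k$ is drawn at random only when `flag_use_random_num_sampling` is set (`build_additional_strong_transform` defaults it to `True`); otherwise $k=N$. Ops without a strength range (identity, autocontrast, equalize) take no $v_j$, and posterize and solarize round their strength to an integer.

```latex
\begin{aligned}
k &\sim \mathcal{U}\{1,\dots,N\}, \qquad (T_1,[a_1,b_1]),\dots,(T_k,[a_k,b_k]) \text{ drawn with replacement from } \mathcal{T},\\
v_j &\sim \mathcal{U}[a_j,b_j], \qquad x^{s} = T_k^{v_k}\circ\cdots\circ T_1^{v_1}\bigl(x^{w}\bigr),\\
\Pr\bigl[\text{labeled patch pasted into } u_i\bigr] &= \Pr[U > \rho_i] = 1-\rho_i, \qquad U \sim \mathcal{U}[0,1).
\end{aligned}
```

In words, a sample the model is confident about is rarely mixed with labeled content, while an uncertain one is mixed more often.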
15 |
16 |
17 |
18 |
19 | ### Diagram
20 |
21 | 
22 |
23 | > Without any complicated designs, AugSeg readily obtains new SOTA performance on popular SSS benchmarks under different partition protocols. We hope AugSeg can inspire future studies and serve as a strong baseline for SSS.
24 |
25 |
26 |
27 | ### Performance
28 |
29 | On Pascal VOC 2012, labeled images are sampled from the **original high-quality** training set (1,464 images). Results are obtained by DeepLabv3+ with a ResNet-101 backbone and a training size of 512.
30 |
31 | | Method | 1/115 (92)| 1/57 (183)| 1/28 (366)| 1/14 (732)| 1/7 (1464) |
32 | | :-------------------------: | :-------: | :-------: | :-------: | :-------: | :---------: |
33 | | SupOnly | 43.92 | 59.10| 65.88 | 70.87 | 74.97 |
34 | | ST++ | 65.2 | 71.0 | 74.6 | 77.3 | 79.1 |
35 | | PS-MT | 65.80 | 69.58 | 76.57 | 78.42 | 80.01 |
36 | | U2PL | 67.98 | 69.15 | 73.66 | 76.16 | 79.49 |
37 | | **AugSeg** | **71.09** | **75.45** | **78.80** | **80.33** | **81.36** |
38 |
39 |
40 | On Cityscapes, results are obtained by DeepLabv3+ with ResNet-50/101 backbones. The U2PL results on ResNet-50 are our reproduction.
41 |
42 | | R50 | 1/16 (186) | 1/8 (372) | 1/4 (744) | 1/2 (1488) | R101 | 1/16 (186) | 1/8 (372) | 1/4 (744) | 1/2 (1488) |
43 | | :-------------------------: | :-------: | :-------: | :-------: | :-------: | :---------: | :---------: | :---------: | :---------: | :---------: |
44 | | SupOnly | 63.34 | 68.73 | 74.14 | 76.62 | SupOnly | 64.77 | 71.64 | 75.24 | 78.03 |
45 | | U2PL | 69.03 | 73.02 | 76.31 | 78.64 | U2PL | 70.30 | 74.37 | 76.47 | 79.05 |
46 | | PS-MT | - | 75.76 | 76.92 | 77.64 | PS-MT | - | 76.89 | 77.60 | 79.09 |
47 | | **AugSeg** | **73.73** | **76.49** | **78.76** | **79.33** | **AugSeg** | **75.22** | **77.82** | **79.56** | **80.43** |
48 |
49 | > All training logs of AugSeg and of our reproduced SupOnly baselines are included in the [training-logs](./training-logs) directory.
50 |
51 |
52 |
53 | ## Running AugSeg
54 |
55 | ### Prepare datasets
56 |
57 | Please download the Pascal VOC 2012 and Cityscapes datasets, and set their paths in the configuration files.
58 |
59 | - Pascal: [JPEGImages](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) | [SegmentationClass](https://drive.google.com/file/d/1ikrDlsai5QSf2GiSUR3f8PZUzyTubcuF/view?usp=sharing)
60 | - Cityscapes: [leftImg8bit](https://www.cityscapes-dataset.com/file-handling/?packageID=3) | [gtFine](https://drive.google.com/file/d/1E_27g9tuHm6baBqcA7jct_jqcGA89QPm/view?usp=sharing)
61 |
62 | - Data splits (`splitsall`): included in this repository as `data/splitsall.tar.gz`.
63 |
64 | We organize the data as follows:
65 | ```
66 | ./data
67 | ├── splitsall
68 | │   ├── cityscapes
69 | │   ├── pascal
70 | │   └── pascal_u2pl
71 | ├── VOC2012
72 | │   ├── JPEGImages
73 | │   ├── SegmentationClass
74 | │   └── SegmentationClassAug
75 | └── cityscapes
76 |     ├── gtFine
77 |     └── leftImg8bit
78 | ```
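A small check like the one below can catch layout mistakes before a long run. It is a hypothetical helper, not part of this repository. It only checks that the directories listed above exist under `./data`, not that they contain the right files. If your `config_semi.yaml` points to different paths, those paths take precedence.

```python
import os
from typing import List

# Sub-directories from the layout above, relative to the data root.
EXPECTED_DIRS = [
    "splitsall/cityscapes",
    "splitsall/pascal",
    "splitsall/pascal_u2pl",
    "VOC2012/JPEGImages",
    "VOC2012/SegmentationClass",
    "VOC2012/SegmentationClassAug",
    "cityscapes/gtFine",
    "cityscapes/leftImg8bit",
]


def missing_dirs(data_root: str = "./data") -> List[str]:
    """Return the expected sub-directories that do not exist under data_root."""
    return [
        d for d in EXPECTED_DIRS
        if not os.path.isdir(os.path.join(data_root, *d.split("/")))
    ]


if __name__ == "__main__":
    missing = missing_dirs()
    print("data layout OK" if not missing else "missing: " + ", ".join(missing))
```

Run it from the repository root. It prints any missing sub-directories, in the order listed above.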
79 |
80 |
81 |
82 | ### Prepare pre-trained encoder
83 |
84 | Please download the pretrained models and set their paths in the corresponding `config_xxx.yaml` file.
85 |
86 | ~~[ResNet-50](https://drive.google.com/file/d/1AuyE_rCUSwDpjMJHMPklXeKdZpdH1-6F/view?usp=sharing) | [ResNet-101](https://drive.google.com/file/d/13jNMOEYkqBC3CimlSSw-sWRHVZEeROmK/view?usp=sharing)~~
87 |
88 | [ResNet-50](https://drive.google.com/file/d/1mqUrqFvTQ0k5QEotk4oiOFyP6B9dVZXS/view?usp=sharing) | [ResNet-101](https://drive.google.com/file/d/1Rx0legsMolCWENpfvE2jUScT3ogalMO8/view?usp=sharing)
89 |
90 | We place them as follows:
91 |
92 | ```
93 | ./pretrained
94 | ├── resnet50.pth
95 | └── resnet101.pth
96 | ```
97 |
98 |
99 |
100 | ### Prepare the running environment
101 |
102 | Nothing special is required; we use:
103 | - python: 3.7.13
104 | - pytorch: 1.7.1
105 | - cuda: 11.0.221, cudnn: 8.0.5 (build string `cuda11.0.221_cudnn8.0.5_0`)
106 | - torchvision: 0.8.2
107 |
108 |
109 |
110 | ### Ready to Run
111 |
112 | We recommend configuring each experiment in a `.yaml` file first.
113 | Various configuration files are provided in the `exps` directory.
114 |
115 |
116 | ```bash
117 | # 1) configure your yaml file in a running script
118 | vim ./scripts/run_abls_citys.sh
119 |
120 | # 2) run directly
121 | sh ./scripts/run_abls_citys.sh
122 |
123 | ```
124 |
125 | ## Citation
126 |
127 | If you find this project useful, please consider citing:
128 |
129 | ```bibtex
130 | @inproceedings{zhao2023augmentation,
131 | title={Augmentation Matters: A Simple-yet-Effective Approach to Semi-supervised Semantic Segmentation},
132 | author={Zhao, Zhen and Yang, Lihe and Long, Sifan and Pi, Jimin and Zhou, Luping and Wang, Jingdong},
133 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
134 | pages={11350--11359},
135 | year={2023}
136 | }
137 | ```
138 |
139 | Our other related semi-supervised semantic segmentation projects:
140 | - [ST++](https://github.com/LiheYoung/ST-PlusPlus)
141 | - [iMAS](https://github.com/ZhenZHAO/iMAS)
142 | - [UniMatch](https://github.com/LiheYoung/UniMatch)
143 |
144 |
145 | ## Acknowledgement
146 |
147 | We thank [ST++](https://github.com/LiheYoung/ST-PlusPlus), [CPS](https://github.com/charlesCXK/TorchSemiSeg), and [U2PL](https://github.com/Haochen-Wang409/U2PL) for parts of their code, processed datasets, data partitions, and pretrained models.
148 |
--------------------------------------------------------------------------------
/augseg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/augseg/__init__.py
--------------------------------------------------------------------------------
/augseg/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/augseg/dataset/__init__.py
--------------------------------------------------------------------------------
/augseg/dataset/augs_ALIA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | import scipy.stats as stats
5 |
6 |
7 | # # # # # # # # # # # # # # # # # # # # #
8 | # # 0 random box
9 | # # # # # # # # # # # # # # # # # # # # #
10 | def rand_bbox(size, lam=None):
11 | # past implementation
12 | if len(size) == 4:
13 | W = size[2]
14 | H = size[3]
15 | elif len(size) == 3:
16 | W = size[1]
17 | H = size[2]
18 | else:
19 |         raise ValueError("rand_bbox expects a 3-D or 4-D size, got {}".format(size))
20 | B = size[0]
21 |
22 | cut_rat = np.sqrt(1. - lam)
23 | cut_w = int(W * cut_rat)
24 | cut_h = int(H * cut_rat)
25 |
26 | cx = np.random.randint(size=[B, ], low=int(W/8), high=W)
27 | cy = np.random.randint(size=[B, ], low=int(H/8), high=H)
28 |
29 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
30 | bby1 = np.clip(cy - cut_h // 2, 0, H)
31 |
32 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
33 | bby2 = np.clip(cy + cut_h // 2, 0, H)
34 |
35 |
36 | return bbx1, bby1, bbx2, bby2
37 |
38 |
39 | # # # # # # # # # # # # # # # # # # # # #
40 | # # 1 cutmix label-adaptive
41 | # # # # # # # # # # # # # # # # # # # # #
42 | def cut_mix_label_adaptive(unlabeled_image, unlabeled_mask, unlabeled_logits,
43 | labeled_image, labeled_mask, lst_confidences):
44 | assert len(lst_confidences) == len(unlabeled_image), "Ensure the confidence is properly obtained"
45 | assert labeled_image.shape == unlabeled_image.shape, "Ensure shape match between lb and unlb"
46 | mix_unlabeled_image = unlabeled_image.clone()
47 | mix_unlabeled_target = unlabeled_mask.clone()
48 | mix_unlabeled_logits = unlabeled_logits.clone()
49 | labeled_logits = torch.ones_like(labeled_mask)
50 |
51 | # 1) get the random mixing objects
52 | u_rand_index = torch.randperm(unlabeled_image.size()[0])[:unlabeled_image.size()[0]]
53 |
54 | # 2) get box
55 | l_bbx1, l_bby1, l_bbx2, l_bby2 = rand_bbox(unlabeled_image.size(), lam=np.random.beta(8, 2))
56 | u_bbx1, u_bby1, u_bbx2, u_bby2 = rand_bbox(unlabeled_image.size(), lam=np.random.beta(4, 4))
57 |
58 | # 3) labeled adaptive
59 | for i in range(0, mix_unlabeled_image.shape[0]):
60 | if np.random.random() > lst_confidences[i]:
61 | mix_unlabeled_image[i, :, l_bbx1[i]:l_bbx2[i], l_bby1[i]:l_bby2[i]] = \
62 | labeled_image[u_rand_index[i], :, l_bbx1[i]:l_bbx2[i], l_bby1[i]:l_bby2[i]]
63 |
64 | mix_unlabeled_target[i, l_bbx1[i]:l_bbx2[i], l_bby1[i]:l_bby2[i]] = \
65 | labeled_mask[u_rand_index[i], l_bbx1[i]:l_bbx2[i], l_bby1[i]:l_bby2[i]]
66 |
67 | mix_unlabeled_logits[i, l_bbx1[i]:l_bbx2[i], l_bby1[i]:l_bby2[i]] = \
68 | labeled_logits[u_rand_index[i], l_bbx1[i]:l_bbx2[i], l_bby1[i]:l_bby2[i]]
69 |
70 | # 4) copy and paste
71 | for i in range(0, unlabeled_image.shape[0]):
72 | unlabeled_image[i, :, u_bbx1[i]:u_bbx2[i], u_bby1[i]:u_bby2[i]] = \
73 | mix_unlabeled_image[u_rand_index[i], :, u_bbx1[i]:u_bbx2[i], u_bby1[i]:u_bby2[i]]
74 |
75 | unlabeled_mask[i, u_bbx1[i]:u_bbx2[i], u_bby1[i]:u_bby2[i]] = \
76 | mix_unlabeled_target[u_rand_index[i], u_bbx1[i]:u_bbx2[i], u_bby1[i]:u_bby2[i]]
77 |
78 | unlabeled_logits[i, u_bbx1[i]:u_bbx2[i], u_bby1[i]:u_bby2[i]] = \
79 | mix_unlabeled_logits[u_rand_index[i], u_bbx1[i]:u_bbx2[i], u_bby1[i]:u_bby2[i]]
80 |
81 | del mix_unlabeled_image, mix_unlabeled_target, mix_unlabeled_logits, labeled_logits
82 |
83 | return unlabeled_image, unlabeled_mask, unlabeled_logits
84 |
--------------------------------------------------------------------------------
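Step 3 of `cut_mix_label_adaptive` is the label-adaptive part. To make it easier to follow, here is a dependency-free sketch of that step on toy single-channel grids (lists of lists instead of tensors). It is our simplification, not the repository's code, and `rand_box`, `paste` and `label_adaptive_mix` are hypothetical names. Two simplifications apply:

- Only the image is mixed. The original applies the same box to the mask and to the logits, and gives pasted labeled pixels a confidence of 1.
- Step 4, the unlabeled-to-unlabeled CutMix, is omitted.

```python
import math
import random
from typing import List, Tuple

Grid = List[List[int]]           # one single-channel image, rows x cols
Box = Tuple[int, int, int, int]  # (r1, c1, r2, c2) with half-open ranges


def rand_box(height: int, width: int, lam: float, rng: random.Random) -> Box:
    """Sample one box the way rand_bbox does for a single image.

    Each side is scaled by sqrt(1 - lam), the centre is drawn from
    [side // 8, side), and the box is clipped to the image bounds.
    """
    ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * ratio), int(width * ratio)
    cy = rng.randrange(int(height / 8), height)
    cx = rng.randrange(int(width / 8), width)
    r1 = min(max(cy - cut_h // 2, 0), height)
    r2 = min(max(cy + cut_h // 2, 0), height)
    c1 = min(max(cx - cut_w // 2, 0), width)
    c2 = min(max(cx + cut_w // 2, 0), width)
    return r1, c1, r2, c2


def paste(dst: Grid, src: Grid, box: Box) -> None:
    """Copy the boxed region of src into dst, in place."""
    r1, c1, r2, c2 = box
    for r in range(r1, r2):
        dst[r][c1:c2] = src[r][c1:c2]


def label_adaptive_mix(unlabeled: List[Grid], labeled: List[Grid],
                       confidences: List[float], lam: float,
                       rng: random.Random) -> List[bool]:
    """Paste a labeled patch into unlabeled[i] with probability 1 - confidences[i].

    Modifies `unlabeled` in place and returns which samples were mixed.
    """
    assert len(unlabeled) == len(labeled) == len(confidences)
    partner = list(range(len(labeled)))
    rng.shuffle(partner)  # plays the role of torch.randperm
    mixed = []
    for i, image in enumerate(unlabeled):
        box = rand_box(len(image), len(image[0]), lam, rng)
        do_mix = rng.random() > confidences[i]  # low confidence -> mixed more often
        if do_mix:
            paste(image, labeled[partner[i]], box)
        mixed.append(do_mix)
    return mixed
```

Consider a call with `confidences=[1.0, 0.0]`. The first sample is never mixed, because `random()` is always below 1.0, and the second is mixed in practice every time. In the original, `lam` for the labeled boxes is drawn once per batch with `np.random.beta(8, 2)`, and `rng.betavariate(8, 2)` can stand in for it here.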
/augseg/dataset/augs_TIBA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as stats
3 | from PIL import Image, ImageOps, ImageFilter, ImageEnhance
4 | import random
5 | import collections.abc  # provides collections.abc.Iterable, checked in Resize
6 | import cv2
7 | import torch
8 | from torchvision import transforms
9 |
10 |
11 | # # # # # # # # # # # # # # # # # # # # # # # #
12 | # # # 1. Augmentation for image and labels
13 | # # # # # # # # # # # # # # # # # # # # # # # #
14 | class Compose(object):
15 | def __init__(self, segtransforms):
16 | self.segtransforms = segtransforms
17 |
18 | def __call__(self, image, label):
19 | for idx, t in enumerate(self.segtransforms):
20 | if isinstance(t, strong_img_aug):
21 | image = t(image)
22 | else:
23 | image, label = t(image, label)
24 | return image, label
25 |
26 |
27 | class ToTensorAndNormalize(object):
28 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
29 | assert len(mean) == len(std)
30 | assert len(mean) == 3
31 | self.normalize = transforms.Normalize(mean, std)
32 | self.to_tensor = transforms.ToTensor()
33 |
34 | def __call__(self, in_image, in_label):
35 | in_image = Image.fromarray(np.uint8(in_image))
36 | image = self.normalize(self.to_tensor(in_image))
37 | label = torch.from_numpy(np.array(in_label, dtype=np.int32)).long()
38 |
39 | return image, label
40 |
41 |
42 | class Resize(object):
43 | def __init__(self, base_size, ratio_range, scale=True, bigger_side_to_base_size=True):
44 | # assert isinstance(ratio_range, collections.Iterable) and len(ratio_range) == 2
45 | assert isinstance(ratio_range, collections.abc.Iterable) and len(ratio_range) == 2 # for recent python version
46 | self.base_size = base_size
47 | self.ratio_range = ratio_range
48 | self.scale = scale
49 | self.bigger_side_to_base_size = bigger_side_to_base_size
50 |
51 | def __call__(self, in_image, in_label):
52 | w, h = in_image.size
53 |
54 | if isinstance(self.base_size, int):
55 | # obtain long_side
56 | if self.scale:
57 | long_side = random.randint(int(self.base_size * self.ratio_range[0]),
58 | int(self.base_size * self.ratio_range[1]))
59 | else:
60 | long_side = self.base_size
61 |
62 | # obtain new oh, ow
63 | if self.bigger_side_to_base_size:
64 | if h > w:
65 | oh = long_side
66 | ow = int(1.0 * long_side * w / h + 0.5)
67 | else:
68 | oh = int(1.0 * long_side * h / w + 0.5)
69 | ow = long_side
70 | else:
71 | oh, ow = (long_side, int(1.0 * long_side * w / h + 0.5)) if h < w else (
72 | int(1.0 * long_side * h / w + 0.5), long_side)
73 |
74 | image = in_image.resize((ow, oh), Image.BILINEAR)
75 | label = in_label.resize((ow, oh), Image.NEAREST)
76 | return image, label
77 | elif (isinstance(self.base_size, list) or isinstance(self.base_size, tuple)) and len(self.base_size) == 2:
78 | if self.scale:
79 | # scale = random.random() * 1.5 + 0.5 # Scaling between [0.5, 2]
80 | scale = self.ratio_range[0] + random.random() * (self.ratio_range[1] - self.ratio_range[0])
81 | # print("="*100, h, self.base_size[0])
82 | # print("="*100, w, self.base_size[1])
83 | oh, ow = int(self.base_size[0] * scale), int(self.base_size[1] * scale)
84 | else:
85 | oh, ow = self.base_size
86 | image = in_image.resize((ow, oh), Image.BILINEAR)
87 | label = in_label.resize((ow, oh), Image.NEAREST)
88 | # print("="*100, in_image.size, image.size)
89 | return image, label
90 |
91 | else:
92 | raise ValueError
93 |
94 |
95 | class Crop(object):
96 | def __init__(self, crop_size, crop_type="rand", mean=[0.485, 0.456, 0.406], ignore_value=255):
97 | if (isinstance(crop_size, list) or isinstance(crop_size, tuple)) and len(crop_size) == 2:
98 | self.crop_h, self.crop_w = crop_size
99 | elif isinstance(crop_size, int):
100 | self.crop_h, self.crop_w = crop_size, crop_size
101 | else:
102 | raise ValueError
103 |
104 | self.crop_type = crop_type
105 | self.image_padding = (np.array(mean) * 255.).tolist()
106 | self.ignore_value = ignore_value
107 |
108 | def __call__(self, in_image, in_label):
109 | # Padding to return the correct crop size
110 | w, h = in_image.size
111 | pad_h = max(self.crop_h - h, 0)
112 | pad_w = max(self.crop_w - w, 0)
113 | pad_kwargs = {
114 | "top": 0,
115 | "bottom": pad_h,
116 | "left": 0,
117 | "right": pad_w,
118 | "borderType": cv2.BORDER_CONSTANT,
119 | }
120 | if pad_h > 0 or pad_w > 0:
121 | image = cv2.copyMakeBorder(np.asarray(in_image, dtype=np.float32),
122 | value=self.image_padding, **pad_kwargs)
123 | label = cv2.copyMakeBorder(np.asarray(in_label, dtype=np.int32),
124 | value=self.ignore_value, **pad_kwargs)
125 | image = Image.fromarray(np.uint8(image))
126 | label = Image.fromarray(np.uint8(label))
127 | else:
128 | image = in_image
129 | label = in_label
130 |
131 | # cropping
132 | w, h = image.size
133 | if self.crop_type == "rand":
134 | x = random.randint(0, w - self.crop_w)
135 | y = random.randint(0, h - self.crop_h)
136 | else:
137 | x = (w - self.crop_w) // 2
138 | y = (h - self.crop_h) // 2
139 | image = image.crop((x, y, x + self.crop_w, y + self.crop_h))
140 | label = label.crop((x, y, x + self.crop_w, y + self.crop_h))
141 | return image, label
142 |
143 |
144 | class RandomFlip(object):
145 | def __init__(self, prob=0.5, flag_hflip=True,):
146 | self.prob = prob
147 | if flag_hflip:
148 | self.type_flip = Image.FLIP_LEFT_RIGHT
149 | else:
150 | self.type_flip = Image.FLIP_TOP_BOTTOM
151 |
152 | def __call__(self, in_image, in_label):
153 | if random.random() < self.prob:
154 | in_image = in_image.transpose(self.type_flip)
155 | in_label = in_label.transpose(self.type_flip)
156 | return in_image, in_label
157 |
158 |
159 | # # # # # # # # # # # # # # # # # # # # # # # #
160 | # # # 2. Strong Augmentation for image only
161 | # # # # # # # # # # # # # # # # # # # # # # # #
162 |
163 | def img_aug_identity(img, scale=None):
164 | return img
165 |
166 |
167 | def img_aug_autocontrast(img, scale=None):
168 | return ImageOps.autocontrast(img)
169 |
170 |
171 | def img_aug_equalize(img, scale=None):
172 | return ImageOps.equalize(img)
173 |
174 |
175 | def img_aug_invert(img, scale=None):
176 | return ImageOps.invert(img)
177 |
178 |
179 | def img_aug_blur(img, scale=[0.1, 2.0]):
180 | assert scale[0] < scale[1]
181 | sigma = np.random.uniform(scale[0], scale[1])
182 | # print(f"sigma:{sigma}")
183 | return img.filter(ImageFilter.GaussianBlur(radius=sigma))
184 |
185 |
186 | def img_aug_contrast(img, scale=[0.05, 0.95]):
187 | min_v, max_v = min(scale), max(scale)
188 | v = float(max_v - min_v)*random.random()
189 | v = max_v - v
190 | # # print(f"final:{v}")
191 | # v = np.random.uniform(scale[0], scale[1])
192 | return ImageEnhance.Contrast(img).enhance(v)
193 |
194 |
195 | def img_aug_brightness(img, scale=[0.05, 0.95]):
196 | min_v, max_v = min(scale), max(scale)
197 | v = float(max_v - min_v)*random.random()
198 | v = max_v - v
199 | # print(f"final:{v}")
200 | return ImageEnhance.Brightness(img).enhance(v)
201 |
202 |
203 | def img_aug_color(img, scale=[0.05, 0.95]):
204 | min_v, max_v = min(scale), max(scale)
205 | v = float(max_v - min_v)*random.random()
206 | v = max_v - v
207 | # print(f"final:{v}")
208 | return ImageEnhance.Color(img).enhance(v)
209 |
210 |
211 | def img_aug_sharpness(img, scale=[0.05, 0.95]):
212 | min_v, max_v = min(scale), max(scale)
213 | v = float(max_v - min_v)*random.random()
214 | v = max_v - v
215 | # print(f"final:{v}")
216 | return ImageEnhance.Sharpness(img).enhance(v)
217 |
218 |
219 | def img_aug_hue(img, scale=[0, 0.5]):
220 | min_v, max_v = min(scale), max(scale)
221 | v = float(max_v - min_v)*random.random()
222 | v += min_v
223 | if np.random.random() < 0.5:
224 | hue_factor = -v
225 | else:
226 | hue_factor = v
227 | # print(f"Final-V:{hue_factor}")
228 | input_mode = img.mode
229 | if input_mode in {"L", "1", "I", "F"}:
230 | return img
231 | h, s, v = img.convert("HSV").split()
232 | np_h = np.array(h, dtype=np.uint8)
233 | # uint8 addition take cares of rotation across boundaries
234 | with np.errstate(over="ignore"):
235 | np_h += np.uint8(hue_factor * 255)
236 | h = Image.fromarray(np_h, "L")
237 | img = Image.merge("HSV", (h, s, v)).convert(input_mode)
238 | return img
239 |
240 |
241 | def img_aug_posterize(img, scale=[4, 8]):
242 | min_v, max_v = min(scale), max(scale)
243 | v = float(max_v - min_v)*random.random()
244 | # print(min_v, max_v, v)
245 | v = int(np.ceil(v))
246 | v = max(1, v)
247 | v = max_v - v
248 | # print(f"final:{v}")
249 | return ImageOps.posterize(img, v)
250 |
251 |
252 | def img_aug_solarize(img, scale=[1, 256]):
253 | min_v, max_v = min(scale), max(scale)
254 | v = float(max_v - min_v)*random.random()
255 | # print(min_v, max_v, v)
256 | v = int(np.ceil(v))
257 | v = max(1, v)
258 | v = max_v - v
259 | # print(f"final:{v}")
260 | return ImageOps.solarize(img, v)
261 |
262 | def get_augment_list(flag_using_wide=False):
263 | if flag_using_wide:
264 | l = [
265 | (img_aug_identity, None),
266 | (img_aug_autocontrast, None),
267 | (img_aug_equalize, None),
268 | (img_aug_blur, [0.1, 2.0]),
269 | (img_aug_contrast, [0.1, 1.8]),
270 | (img_aug_brightness, [0.1, 1.8]),
271 | (img_aug_color, [0.1, 1.8]),
272 | (img_aug_sharpness, [0.1, 1.8]),
273 | (img_aug_posterize, [2, 8]),
274 | (img_aug_solarize, [1, 256]),
275 | (img_aug_hue, [0, 0.5])
276 | ]
277 | else:
278 | l = [
279 | (img_aug_identity, None),
280 | (img_aug_autocontrast, None),
281 | (img_aug_equalize, None),
282 | (img_aug_blur, [0.1, 2.0]),
283 | (img_aug_contrast, [0.05, 0.95]),
284 | (img_aug_brightness, [0.05, 0.95]),
285 | (img_aug_color, [0.05, 0.95]),
286 | (img_aug_sharpness, [0.05, 0.95]),
287 | (img_aug_posterize, [4, 8]),
288 | (img_aug_solarize, [1, 256]),
289 | (img_aug_hue, [0, 0.5])
290 | ]
291 | return l
292 |
293 |
294 | class strong_img_aug:
295 | def __init__(self, num_augs, flag_using_random_num=False):
296 |         assert 1 <= num_augs <= 11
297 | self.n = num_augs
298 | self.augment_list = get_augment_list(flag_using_wide=False)
299 | self.flag_using_random_num = flag_using_random_num
300 |
301 | def __call__(self, img):
302 | if self.flag_using_random_num:
303 | max_num = np.random.randint(1, high=self.n + 1)
304 | else:
305 |             max_num = self.n
306 | ops = random.choices(self.augment_list, k=max_num)
307 | for op, scales in ops:
308 | # print("="*20, str(op))
309 | img = op(img, scales)
310 | return img
311 |
--------------------------------------------------------------------------------
/augseg/dataset/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class BaseDataset(Dataset):
8 | def __init__(self, d_list, **kwargs):
9 | # parse the input list
10 | self.parse_input_list(d_list, **kwargs)
11 |
12 | def parse_input_list(self, d_list, max_sample=-1, start_idx=-1, end_idx=-1):
13 | logger = logging.getLogger("global")
14 | assert isinstance(d_list, str)
15 | if "cityscapes" in d_list:
16 | self.list_sample = [
17 | [
18 | line.strip(),
19 | "gtFine/" + line.strip()[12:-15] + "gtFine_labelTrainIds.png",
20 | ]
21 | for line in open(d_list, "r")
22 | ]
23 | elif "pascal" in d_list or "VOC" in d_list:
24 | self.list_sample = [
25 | [
26 | "JPEGImages/{}.jpg".format(line.strip()),
27 | "SegmentationClassAug/{}.png".format(line.strip()),
28 | ]
29 | for line in open(d_list, "r")
30 | ]
31 | else:
32 |             raise ValueError("unknown dataset: {}".format(d_list))
33 |
34 | if max_sample > 0:
35 | self.list_sample = self.list_sample[0:max_sample]
36 | if start_idx >= 0 and end_idx >= 0:
37 | self.list_sample = self.list_sample[start_idx:end_idx]
38 |
39 | self.num_sample = len(self.list_sample)
40 | assert self.num_sample > 0
41 | logger.info("# samples: {}".format(self.num_sample))
42 |
43 | def img_loader(self, path, mode):
44 | with open(path, "rb") as f:
45 | img = Image.open(f)
46 | return img.convert(mode)
47 |
48 | def __len__(self):
49 | return self.num_sample
50 |
--------------------------------------------------------------------------------
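`parse_input_list` derives the Cityscapes label path by slicing the list entry. This silently assumes each line looks like `leftImg8bit/<split>/<city>/<name>_leftImg8bit.png`. The sketch below reproduces the mapping as standalone functions with hypothetical names. With `strict=True`, it checks that assumption instead of producing a wrong path. The file names in the checks are illustrative.

```python
from typing import List, Sequence


def cityscapes_pair(line: str, strict: bool = True) -> List[str]:
    """Map one Cityscapes list entry to [image_path, label_path].

    The slice [12:-15] drops a leading "leftImg8bit/" (12 chars) and a
    trailing "leftImg8bit.png" (15 chars); strict=True verifies both.
    """
    name = line.strip()
    if strict and not (name.startswith("leftImg8bit/") and name.endswith("leftImg8bit.png")):
        raise ValueError("unexpected Cityscapes list entry: {!r}".format(name))
    return [name, "gtFine/" + name[12:-15] + "gtFine_labelTrainIds.png"]


def voc_pair(line: str) -> List[str]:
    """Map one Pascal VOC image id to [image_path, label_path]."""
    name = line.strip()
    return ["JPEGImages/{}.jpg".format(name), "SegmentationClassAug/{}.png".format(name)]


def select(samples: Sequence, max_sample: int = -1,
           start_idx: int = -1, end_idx: int = -1) -> list:
    """Apply the max_sample and [start_idx:end_idx] truncation from parse_input_list."""
    samples = list(samples)
    if max_sample > 0:
        samples = samples[:max_sample]
    if start_idx >= 0 and end_idx >= 0:
        samples = samples[start_idx:end_idx]
    return samples
```

Note that `max_sample` is applied before the index window, so the window indexes into the already-truncated list.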
/augseg/dataset/builder.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .cityscapes import build_city_semi_loader, build_cityloader
4 | from .pascal_voc import build_voc_semi_loader, build_vocloader
5 |
6 | logger = logging.getLogger("global")
7 |
8 |
9 | def get_loader(cfg, seed=0):
10 | cfg_dataset = cfg["dataset"]
11 |
12 | if cfg_dataset["type"] == "cityscapes_semi":
13 | train_loader_sup, train_loader_unsup = build_city_semi_loader(
14 | "train", cfg, seed=seed
15 | )
16 | val_loader = build_cityloader("val", cfg)
17 | logger.info("Get loader Done...")
18 | return train_loader_sup, train_loader_unsup, val_loader
19 |
20 | elif cfg_dataset["type"] == "cityscapes":
21 | train_loader_sup = build_cityloader("train", cfg, seed=seed)
22 | val_loader = build_cityloader("val", cfg)
23 | logger.info("Get loader Done...")
24 | return train_loader_sup, val_loader
25 |
26 | elif cfg_dataset["type"] == "pascal_semi":
27 | train_loader_sup, train_loader_unsup = build_voc_semi_loader(
28 | "train", cfg, seed=seed
29 | )
30 | val_loader = build_vocloader("val", cfg)
31 | logger.info("Get loader Done...")
32 | return train_loader_sup, train_loader_unsup, val_loader
33 |
34 | elif cfg_dataset["type"] == "pascal":
35 | train_loader_sup = build_vocloader("train", cfg, seed=seed)
36 | val_loader = build_vocloader("val", cfg)
37 | logger.info("Get loader Done...")
38 | return train_loader_sup, val_loader
39 |
40 | else:
41 | raise NotImplementedError(
42 |             "dataset type {} is not supported".format(cfg_dataset["type"])
43 | )
44 |
--------------------------------------------------------------------------------
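`get_loader` dispatches on `cfg["dataset"]["type"]`. `build_cityloader` in `cityscapes.py` then lifts the block for the current split (`train` or `val`) to the top level, before reading keys such as `batch_size` or `crop`. The sketch below mirrors that merge on a plain dict. The key names come from the code, but every value is a placeholder rather than a setting from `exps`. `split_view` and `num_loaders` are hypothetical helpers.

```python
import copy
from typing import Any, Dict

# Placeholder config: key names are read by get_loader / build_cityloader,
# values are illustrative only.
CFG: Dict[str, Any] = {
    "dataset": {
        "type": "cityscapes_semi",
        "ignore_label": 255,
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
        "train": {
            "data_root": "path/to/cityscapes",
            "data_list": "path/to/labeled.txt",
            "batch_size": 2,
            "flip": True,
            "crop": {"size": [512, 512], "type": "rand"},  # placeholder size
        },
        "val": {
            "data_root": "path/to/cityscapes",
            "data_list": "path/to/val.txt",
        },
    }
}

# Number of loaders get_loader returns for each supported dataset type.
N_LOADERS = {"cityscapes_semi": 3, "pascal_semi": 3, "cityscapes": 2, "pascal": 2}


def split_view(all_cfg: Dict[str, Any], split: str) -> Dict[str, Any]:
    """Lift the split-specific block (e.g. "train") to the top level.

    Deep-copies first so the shared config is not mutated, then lets the
    split's keys override dataset-wide keys of the same name.
    """
    cfg = copy.deepcopy(all_cfg["dataset"])
    cfg.update(cfg.get(split, {}))
    return cfg


def num_loaders(all_cfg: Dict[str, Any]) -> int:
    dtype = all_cfg["dataset"]["type"]
    if dtype not in N_LOADERS:
        raise NotImplementedError("dataset type {} is not supported".format(dtype))
    return N_LOADERS[dtype]
```

The split block is applied with `dict.update`, so any key set under `train` or `val` wins over the dataset-wide value. Keys absent from both fall back to the defaults in `build_cityloader`, for example `batch_size` 1 and `workers` 2.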
/augseg/dataset/cityscapes.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import os
4 | import os.path
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.distributed import DistributedSampler
12 | from torchvision import transforms
13 |
14 | # from . import augmentations as img_trsform
15 | from . import augs_TIBA as img_trsform
16 | from .base import BaseDataset
17 |
18 | # https://pytorch.org/docs/stable/notes/randomness.html
19 | def seed_worker(worker_id):
20 | cur_seed = np.random.get_state()[1][0]
21 | cur_seed += worker_id
22 | np.random.seed(cur_seed)
23 | random.seed(cur_seed)
24 |
25 |
26 | class city_dset(BaseDataset):
27 | def __init__(self, data_root, data_list, trs_form, trs_form_strong=None,
28 | seed=0, n_sup=2975, split="val", flag_semi=False,
29 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
30 | ):
31 | super(city_dset, self).__init__(data_list)
32 | self.data_root = data_root
33 | self.transform_weak = trs_form
34 | self.transform_strong = trs_form_strong
35 | self.flag_semi = flag_semi
36 | self.split = split
37 | # random.seed(seed)
38 |
39 | self.trf_normalize = self._get_to_tensor_and_normalize(mean, std)
40 |
41 |         # for training, resample the list to exactly n_sup entries, repeating it first when it is shorter (oversampling)
42 | if len(self.list_sample) >= n_sup and split == "train":
43 | self.list_sample_new = random.sample(self.list_sample, n_sup)
44 | elif len(self.list_sample) < n_sup and split == "train":
45 | num_repeat = math.ceil(n_sup / len(self.list_sample))
46 | self.list_sample = self.list_sample * num_repeat
47 |
48 | self.list_sample_new = random.sample(self.list_sample, n_sup)
49 | else:
50 | self.list_sample_new = self.list_sample
51 |
52 | @staticmethod
53 | def _get_to_tensor_and_normalize(mean, std):
54 | return img_trsform.ToTensorAndNormalize(mean, std)
55 |
56 | def __getitem__(self, index):
57 | # load image and its label
58 | image_path = os.path.join(self.data_root, self.list_sample_new[index][0])
59 | label_path = os.path.join(self.data_root, self.list_sample_new[index][1])
60 | image = self.img_loader(image_path, "RGB")
61 | label = self.img_loader(label_path, "L")
62 |
63 | if self.transform_strong is None:
64 | image, label = self.transform_weak(image, label)
65 | # print(image.shape, label.shape)
66 | image, label = self.trf_normalize(image, label)
67 | if not self.flag_semi:
68 | return index, image, label
69 | else:
70 | return index, image, image.clone(), label
71 | else:
72 | # apply augmentation
73 | image_weak, label = self.transform_weak(image, label)
74 | image_strong = self.transform_strong(image_weak)
75 | # print("="*100)
76 | # print(index, image_weak.size, image_strong.size, label.size)
77 | # print("="*100)
78 |
79 | image_weak, label = self.trf_normalize(image_weak, label)
80 | image_strong, _ = self.trf_normalize(image_strong, label)
81 | # print(index, image_weak.shape, image_strong.shape,label.shape)
82 |
83 | return index, image_weak, image_strong, label
84 |
85 | # image, label = self.transform(image, label)
86 | # return image[0], label[0, 0].long()
87 |
88 | def __len__(self):
89 | return len(self.list_sample_new)
90 |
91 |
92 | def build_additional_strong_transform(cfg):
93 | assert cfg.get("strong_aug", False) != False
94 | strong_aug_nums = cfg["strong_aug"].get("num_augs", 2)
95 | flag_use_rand_num = cfg["strong_aug"].get("flag_use_random_num_sampling", True)
96 | strong_img_aug = img_trsform.strong_img_aug(strong_aug_nums,
97 | flag_using_random_num=flag_use_rand_num)
98 | return strong_img_aug
99 |
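`strong_img_aug` is defined in `augs_TIBA`, which is not shown here. The sketch below assumes it behaves like a RandAugment-style sampler that draws up to `num_augs` intensity operations from a pool, and that it draws the count itself at random when `flag_use_random_num_sampling` is true. The operation names, the uniform count in `[1, num_augs]`, and the helper `sample_strong_ops` are all assumptions and are not taken from the repository.

```python
import random

# Hypothetical pool of intensity-only operations (placeholder names).
OPS = ["identity", "autocontrast", "equalize", "brightness",
       "contrast", "sharpness", "posterize", "solarize"]


def sample_strong_ops(num_augs=2, flag_using_random_num=True, rng=None):
    """Pick the operations for one strong view (sketch, not augs_TIBA's code)."""
    rng = random.Random() if rng is None else rng
    # Assumption: with random-number sampling, use between 1 and num_augs ops.
    k = rng.randint(1, num_augs) if flag_using_random_num else num_augs
    return rng.sample(OPS, k)  # distinct ops, random order


print(sample_strong_ops(3, flag_using_random_num=False, rng=random.Random(0)))
```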
100 |
101 | def build_basic_transfrom(cfg, split="val", mean=[0.485, 0.456, 0.406]):
102 | ignore_label = cfg["ignore_label"]
103 | trs_form = []
104 | if split != "val":
105 | if cfg.get("rand_resize", False):
106 | trs_form.append(img_trsform.Resize(cfg.get("resize_base_size", [1024, 2048]), cfg["rand_resize"]))
107 |
108 | if cfg.get("flip", False):
109 | trs_form.append(img_trsform.RandomFlip(prob=0.5, flag_hflip=True))
110 |
111 | # crop is outside the split check, so it can also run at validation time
112 | if cfg.get("crop", False):
113 | crop_size, crop_type = cfg["crop"]["size"], cfg["crop"]["type"]
114 | trs_form.append(img_trsform.Crop(crop_size, crop_type=crop_type, mean=mean, ignore_value=ignore_label))
115 |
116 | return img_trsform.Compose(trs_form)
117 |
118 |
119 | def build_cityloader(split, all_cfg, seed=0):
120 | # extract augs config from "train"/"val" into the higher level.
121 | cfg_dset = all_cfg["dataset"]
122 | cfg = copy.deepcopy(cfg_dset)
123 | cfg.update(cfg.get(split, {}))
124 |
125 | # set up workers and batchsize
126 | workers = cfg.get("workers", 2)
127 | batch_size = cfg.get("batch_size", 1)
128 | n_sup = cfg.get("n_sup", 2975)
129 |
130 | # build transform
131 | mean, std = cfg["mean"], cfg["std"]
132 | trs_form = build_basic_transfrom(cfg, split=split, mean=mean)
133 |
134 | # create dataset
135 | dset = city_dset(cfg["data_root"], cfg["data_list"], trs_form, None,
136 | seed, n_sup, mean=mean, std=std)
137 |
138 | # build sampler
139 | sample = DistributedSampler(dset)
140 | loader = DataLoader(
141 | dset,
142 | batch_size=batch_size,
143 | num_workers=workers,
144 | sampler=sample,
145 | shuffle=False,
146 | pin_memory=False,
147 | worker_init_fn=seed_worker,
148 | )
149 | return loader
150 |
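Both loader builders copy the dataset config and then lift the split-specific sub-dictionary (`train` or `val`) to the top level, so a key set under the split overrides the shared default. A small sketch of that merge follows. The helper name `merge_split_cfg` and all config values are made up for illustration.

```python
import copy


def merge_split_cfg(dataset_cfg, split):
    """Return a copy of dataset_cfg with dataset_cfg[split] lifted to top level."""
    cfg = copy.deepcopy(dataset_cfg)  # leave the caller's config untouched
    cfg.update(cfg.get(split, {}))    # split-specific keys override shared ones
    return cfg


# Illustrative values only, not taken from the repository's YAML files.
dataset_cfg = {
    "ignore_label": 255,
    "batch_size": 1,
    "train": {"batch_size": 4, "flip": True, "crop": {"size": [769, 769], "type": "rand"}},
    "val": {"crop": {"size": [769, 769], "type": "center"}},
}
train_cfg = merge_split_cfg(dataset_cfg, "train")
print(train_cfg["batch_size"], train_cfg["flip"])  # 4 True
```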
151 |
152 | def build_city_semi_loader(split, all_cfg, seed=0):
153 | split = "train"
154 | # extract augs config from "train" into the higher level.
155 | cfg_dset = all_cfg["dataset"]
156 | cfg = copy.deepcopy(cfg_dset)
157 | cfg.update(cfg.get(split, {}))
158 |
159 | # set up workers and batchsize
160 | workers = cfg.get("workers", 2)
161 | batch_size = cfg.get("batch_size", 2)
162 | n_sup = 2975 - cfg.get("n_sup", 2975) # oversample labeled data to the amount of unlabeled data
163 |
164 | # build transform
165 | mean, std = cfg["mean"], cfg["std"]
166 | trs_form_weak = build_basic_transfrom(cfg, split=split, mean=mean)
167 | if cfg.get("strong_aug", False):
168 | trs_form_strong = build_additional_strong_transform(cfg)
169 | else:
170 | trs_form_strong = None
171 |
172 | dset = city_dset(cfg["data_root"], cfg["data_list"], trs_form_weak, None,
173 | seed, n_sup, split=split, mean=mean, std=std)
174 | sample_sup = DistributedSampler(dset)
175 |
176 | data_list_unsup = cfg["data_list"].replace("labeled.txt", "unlabeled.txt")
177 | dset_unsup = city_dset(cfg["data_root"], data_list_unsup, trs_form_weak, trs_form_strong,
178 | seed, n_sup, split,
179 | flag_semi=True,
180 | mean=mean, std=std)
181 | sample_unsup = DistributedSampler(dset_unsup)
182 |
183 | # create dataloader
184 | loader_sup = DataLoader(
185 | dset,
186 | batch_size=batch_size,
187 | num_workers=workers,
188 | sampler=sample_sup,
189 | shuffle=False,
190 | pin_memory=True,
191 | drop_last=True,
192 | worker_init_fn=seed_worker,
193 | )
194 | loader_unsup = DataLoader(
195 | dset_unsup,
196 | batch_size=batch_size,
197 | num_workers=workers,
198 | sampler=sample_unsup,
199 | shuffle=False,
200 | pin_memory=True,
201 | drop_last=True,
202 | worker_init_fn=seed_worker,
203 | )
204 | return loader_sup, loader_unsup
205 |
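Both datasets are resampled to `2975 - n_sup` entries. Assuming the unlabeled list holds the remaining images, the labeled and unlabeled loaders therefore have the same length and can be iterated in lockstep. The iterations per epoch follow from `DistributedSampler`, which pads to a multiple of the world size under its default `drop_last=False`, combined with `drop_last=True` on the `DataLoader`. The sketch uses a hypothetical helper, and its world size and batch size are illustrative.

```python
import math


def iters_per_epoch(num_samples, world_size, batch_size):
    """Batches each rank sees per epoch (hypothetical helper).

    DistributedSampler (drop_last=False) gives every rank
    ceil(num_samples / world_size) indices; DataLoader(drop_last=True)
    then discards the incomplete final batch.
    """
    per_rank = math.ceil(num_samples / world_size)
    return per_rank // batch_size


n_labeled = 744
n_sup = 2975 - n_labeled  # 2231, the target length for both loaders
print(iters_per_epoch(n_sup, world_size=4, batch_size=2))  # 279
```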
--------------------------------------------------------------------------------
/augseg/dataset/pascal_voc.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import os
4 | import os.path
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader, Dataset
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torchvision import transforms
12 |
13 | from . import augs_TIBA as img_trsform
14 | from .base import BaseDataset
15 |
16 | # https://pytorch.org/docs/stable/notes/randomness.html
17 | def seed_worker(worker_id):
18 | # torch.initial_seed() is base_seed + worker_id, so it differs per worker and per epoch
19 | cur_seed = torch.initial_seed() % 2**32
20 | np.random.seed(cur_seed)
21 | random.seed(cur_seed)
22 |
23 |
24 | class voc_dset(BaseDataset):
25 | def __init__(
26 | self, data_root, data_list, trs_form, trs_form_strong=None,
27 | seed=0, n_sup=10582, split="val", flag_semi=False,
28 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
29 | ):
30 | super(voc_dset, self).__init__(data_list)
31 | self.data_root = data_root
32 | self.transform_weak = trs_form
33 | self.transform_strong = trs_form_strong
34 | self.flag_semi = flag_semi
35 | self.split = split
36 | # random.seed(seed) # set on the top level
37 |
38 | self.trf_normalize = self._get_to_tensor_and_normalize(mean, std)
39 |
40 | # oversample labeled data for semi-supervised training
41 | if len(self.list_sample) >= n_sup and split == "train":
42 | self.list_sample_new = random.sample(self.list_sample, n_sup)
43 | elif len(self.list_sample) < n_sup and split == "train":
44 | num_repeat = math.ceil(n_sup / len(self.list_sample))
45 | self.list_sample = self.list_sample * num_repeat
46 | self.list_sample_new = random.sample(self.list_sample, n_sup)
47 | else:
48 | self.list_sample_new = self.list_sample
49 |
50 | # # ADD: shuffle the image list ===> not necessary, random.sample also shuffles it.
51 | # if split == "train":
52 | # np.random.shuffle(self.list_sample)
53 |
54 | @staticmethod
55 | def _get_to_tensor_and_normalize(mean, std):
56 | return img_trsform.ToTensorAndNormalize(mean, std)
57 |
58 | def __getitem__(self, index):
59 | # load image and its label
60 | image_path = os.path.join(self.data_root, self.list_sample_new[index][0])
61 | label_path = os.path.join(self.data_root, self.list_sample_new[index][1])
62 | image = self.img_loader(image_path, "RGB")
63 | label = self.img_loader(label_path, "L")
64 |
65 | if self.transform_strong is None:
66 | image, label = self.transform_weak(image, label)
67 | # print(image.shape, label.shape)
68 | image, label = self.trf_normalize(image, label)
69 | if not self.flag_semi:
70 | return index, image, label
71 | else:
72 | return index, image, image.clone(), label
73 | else:
74 | # apply augmentation
75 | image_weak, label = self.transform_weak(image, label)
76 | image_strong = self.transform_strong(image_weak)
77 | # print("="*100)
78 | # print(index, image_weak.size, image_strong.size, label.size)
79 | # print("="*100)
80 |
81 | image_weak, label = self.trf_normalize(image_weak, label)
82 | image_strong, _ = self.trf_normalize(image_strong, label)
83 | # print(index, image_weak.shape, image_strong.shape,label.shape)
84 |
85 | return index, image_weak, image_strong, label
86 |
87 | def __len__(self):
88 | return len(self.list_sample_new)
89 |
90 |
91 | def build_additional_strong_transform(cfg):
92 | assert cfg.get("strong_aug", False) != False
93 | strong_aug_nums = cfg["strong_aug"].get("num_augs", 2)
94 | flag_use_rand_num = cfg["strong_aug"].get("flag_use_random_num_sampling", True)
95 | strong_img_aug = img_trsform.strong_img_aug(strong_aug_nums,
96 | flag_using_random_num=flag_use_rand_num)
97 | return strong_img_aug
98 |
99 |
100 | def build_basic_transfrom(cfg, split="val", mean=[0.485, 0.456, 0.406]):
101 | ignore_label = cfg["ignore_label"]
102 | trs_form = []
103 | if split != "val":
104 | if cfg.get("rand_resize", False):
105 | trs_form.append(img_trsform.Resize(cfg.get("resize_base_size", 600), cfg["rand_resize"]))
106 |
107 | if cfg.get("flip", False):
108 | trs_form.append(img_trsform.RandomFlip(prob=0.5, flag_hflip=True))
109 |
110 | # crop is outside the split check, so it can also run at validation time
111 | if cfg.get("crop", False):
112 | crop_size, crop_type = cfg["crop"]["size"], cfg["crop"]["type"]
113 | trs_form.append(img_trsform.Crop(crop_size, crop_type=crop_type, mean=mean, ignore_value=ignore_label))
114 |
115 | return img_trsform.Compose(trs_form)
116 |
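`img_trsform.Compose` comes from `augs_TIBA`, which is not shown. The property that matters is that every step takes and returns the `(image, label)` pair, so a geometric operation such as the horizontal flip moves image pixels and label pixels together. The framework-free sketch below illustrates that contract. `PairCompose` and `PairHFlip` are hypothetical classes, and images are stored as nested lists.

```python
import random


class PairCompose:
    """Apply (image, label) -> (image, label) transforms in order (hypothetical)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, label):
        for t in self.transforms:
            image, label = t(image, label)
        return image, label


class PairHFlip:
    """Flip image and label left-right with one shared coin toss (hypothetical)."""

    def __init__(self, prob=0.5, rng=None):
        self.prob = prob
        self.rng = random.Random() if rng is None else rng

    def __call__(self, image, label):
        if self.rng.random() < self.prob:
            image = [row[::-1] for row in image]
            label = [row[::-1] for row in label]
        return image, label


image = [[10, 20, 30], [40, 50, 60]]
label = [[0, 1, 2], [3, 4, 5]]
# prob=1.0 forces the flip for this demo
img_out, lbl_out = PairCompose([PairHFlip(prob=1.0)])(image, label)
print(img_out[0], lbl_out[0])  # [30, 20, 10] [2, 1, 0]
```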
117 |
118 | def build_vocloader(split, all_cfg, seed=0):
119 | # extract augs config from "train"/"val" into the higher level.
120 | cfg_dset = all_cfg["dataset"]
121 | cfg = copy.deepcopy(cfg_dset)
122 | cfg.update(cfg.get(split, {}))
123 |
124 | # set up workers and batchsize
125 | workers = cfg.get("workers", 2)
126 | batch_size = cfg.get("batch_size", 1)
127 | n_sup = cfg.get("n_sup", 10582)
128 |
129 | # build transform
130 | mean, std = cfg["mean"], cfg["std"]
131 | trs_form = build_basic_transfrom(cfg, split=split, mean=mean)
132 |
133 | # create dataset
134 | dset = voc_dset(cfg["data_root"], cfg["data_list"], trs_form, None,
135 | seed, n_sup, mean=mean, std=std)
136 |
137 | # build sampler
138 | sample = DistributedSampler(dset)
139 | loader = DataLoader(
140 | dset,
141 | batch_size=batch_size,
142 | num_workers=workers,
143 | sampler=sample,
144 | shuffle=False,
145 | pin_memory=False,
146 | worker_init_fn=seed_worker,
147 | )
148 | return loader
149 |
150 |
151 | def build_voc_semi_loader(split, all_cfg, seed=0):
152 | split = "train"
153 | # extract augs config from "train" into the higher level.
154 | cfg_dset = all_cfg["dataset"]
155 | cfg = copy.deepcopy(cfg_dset)
156 | cfg.update(cfg.get(split, {}))
157 |
158 | # set up workers and batchsize
159 | workers = cfg.get("workers", 2)
160 | batch_size = cfg.get("batch_size", 2)
161 | n_sup = 10582 - cfg.get("n_sup", 10582) # oversample labeled data to the amount of unlabeled data
162 |
163 | # build transform
164 | mean, std = cfg["mean"], cfg["std"]
165 | trs_form_weak = build_basic_transfrom(cfg, split=split, mean=mean)
166 | if cfg.get("strong_aug", False):
167 | trs_form_strong = build_additional_strong_transform(cfg)
168 | else:
169 | trs_form_strong = None
170 |
171 | dset = voc_dset(cfg["data_root"], cfg["data_list"], trs_form_weak, None,
172 | seed, n_sup, split=split, mean=mean, std=std)
173 | sample_sup = DistributedSampler(dset)
174 |
175 | data_list_unsup = cfg["data_list"].replace("labeled.txt", "unlabeled.txt")
176 | dset_unsup = voc_dset(cfg["data_root"], data_list_unsup, trs_form_weak, trs_form_strong,
177 | seed, n_sup, split,
178 | flag_semi=True,
179 | mean=mean, std=std)
180 | sample_unsup = DistributedSampler(dset_unsup)
181 |
182 | # create dataloader
183 | loader_sup = DataLoader(
184 | dset,
185 | batch_size=batch_size,
186 | num_workers=workers,
187 | sampler=sample_sup,
188 | shuffle=False,
189 | pin_memory=True,
190 | drop_last=True,
191 | worker_init_fn=seed_worker,
192 | )
193 | loader_unsup = DataLoader(
194 | dset_unsup,
195 | batch_size=batch_size,
196 | num_workers=workers,
197 | sampler=sample_unsup,
198 | shuffle=False,
199 | pin_memory=True,
200 | drop_last=True,
201 | worker_init_fn=seed_worker,
202 | )
203 | return loader_sup, loader_unsup
204 |
--------------------------------------------------------------------------------
/augseg/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/augseg/models/__init__.py
--------------------------------------------------------------------------------
/augseg/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | def get_syncbn():
6 | # return nn.BatchNorm2d
7 | return nn.SyncBatchNorm
8 |
9 |
10 | class dec_deeplabv3_plus(nn.Module):
11 | def __init__(
12 | self,
13 | in_planes,
14 | num_classes=19,
15 | inner_planes=256,
16 | sync_bn=False,
17 | dilations=(12, 24, 36),
18 | low_conv_planes=48,
19 | ):
20 | super(dec_deeplabv3_plus, self).__init__()
21 |
22 | norm_layer = get_syncbn() if sync_bn else nn.BatchNorm2d
23 |
24 | self.low_conv = nn.Sequential(
25 | nn.Conv2d(256, low_conv_planes, kernel_size=1),
26 | norm_layer(low_conv_planes),
27 | nn.ReLU(inplace=True)
28 | )
29 |
30 | self.aspp = ASPP(
31 | in_planes, inner_planes=inner_planes, sync_bn=sync_bn, dilations=dilations
32 | )
33 |
34 | self.head = nn.Sequential(
35 | nn.Conv2d(self.aspp.get_outplanes(), 256, 1, bias=False),
36 | norm_layer(256),
37 | nn.ReLU(inplace=True),
38 | )
39 |
40 | self.classifier = nn.Sequential(
41 | nn.Conv2d(256+int(low_conv_planes), 256, kernel_size=3, stride=1, padding=1, bias=False),
42 | norm_layer(256),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
45 | norm_layer(256),
46 | nn.ReLU(inplace=True),
47 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0, bias=True),
48 | )
49 |
50 |
51 | def forward(self, x):
52 | x1, x2, x3, x4 = x
53 | low_feat = self.low_conv(x1)
54 | h, w = low_feat.size()[-2:]
55 |
56 | aspp_out = self.aspp(x4)
57 | aspp_out = self.head(aspp_out)
58 | aspp_out = F.interpolate(
59 | aspp_out, size=(h, w), mode="bilinear", align_corners=True
60 | )
61 |
62 | aspp_out = torch.cat((low_feat, aspp_out), dim=1)
63 |
64 | return self.classifier(aspp_out)
65 |
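The hard-coded `256` input of `low_conv` matches the `layer1` output of a Bottleneck ResNet (64 planes times expansion 4). `in_planes` is the `layer4` output that `ModelBuilder` passes in through `get_outplanes()`. The sketch below checks the channel bookkeeping by hand, assuming a ResNet-50/101 encoder and the default decoder arguments. The helper name is hypothetical.

```python
def deeplabv3p_channels(inner_planes=256, dilations=(12, 24, 36),
                        low_conv_planes=48, expansion=4):
    """Channel counts along dec_deeplabv3_plus for a ResNet encoder (sketch)."""
    low_level_in = 64 * expansion                   # layer1 output -> low_conv input
    high_level_in = 512 * expansion                 # layer4 output -> ASPP input
    aspp_out = (len(dilations) + 2) * inner_planes  # pooled + 1x1 + one per dilation
    classifier_in = 256 + low_conv_planes           # head output concat low_feat
    return {
        "low_level_in": low_level_in,
        "high_level_in": high_level_in,
        "aspp_out": aspp_out,
        "classifier_in": classifier_in,
    }


print(deeplabv3p_channels())
# {'low_level_in': 256, 'high_level_in': 2048, 'aspp_out': 1280, 'classifier_in': 304}
```

With `expansion=1` (ResNet-18/34 `BasicBlock`), `layer1` yields 64 channels, which would not match the fixed 256-channel `low_conv` input.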
66 |
67 | class Aux_Module(nn.Module):
68 | def __init__(self, in_planes, num_classes=19, sync_bn=False):
69 | super(Aux_Module, self).__init__()
70 |
71 | norm_layer = get_syncbn() if sync_bn else nn.BatchNorm2d
72 | self.aux = nn.Sequential(
73 | nn.Conv2d(in_planes, 256, kernel_size=3, stride=1, padding=1, bias=False),
74 | norm_layer(256),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0, bias=True),
77 | )
78 |
79 | def forward(self, x):
80 | res = self.aux(x)
81 | return res
82 |
83 |
84 | class ASPP(nn.Module):
85 | """
86 | Reference:
87 | Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
88 | """
89 |
90 | def __init__(
91 | self, in_planes, inner_planes=256, sync_bn=False, dilations=(12, 24, 36)
92 | ):
93 | super(ASPP, self).__init__()
94 |
95 | norm_layer = get_syncbn() if sync_bn else nn.BatchNorm2d
96 | self.conv1 = nn.Sequential(
97 | nn.AdaptiveAvgPool2d((1, 1)),
98 | nn.Conv2d(
99 | in_planes,
100 | inner_planes,
101 | kernel_size=1,
102 | padding=0,
103 | dilation=1,
104 | bias=False,
105 | ),
106 | norm_layer(inner_planes),
107 | nn.ReLU(inplace=True),
108 | )
109 | self.conv2 = nn.Sequential(
110 | nn.Conv2d(
111 | in_planes,
112 | inner_planes,
113 | kernel_size=1,
114 | padding=0,
115 | dilation=1,
116 | bias=False,
117 | ),
118 | norm_layer(inner_planes),
119 | nn.ReLU(inplace=True),
120 | )
121 | self.conv3 = nn.Sequential(
122 | nn.Conv2d(
123 | in_planes,
124 | inner_planes,
125 | kernel_size=3,
126 | padding=dilations[0],
127 | dilation=dilations[0],
128 | bias=False,
129 | ),
130 | norm_layer(inner_planes),
131 | nn.ReLU(inplace=True),
132 | )
133 | self.conv4 = nn.Sequential(
134 | nn.Conv2d(
135 | in_planes,
136 | inner_planes,
137 | kernel_size=3,
138 | padding=dilations[1],
139 | dilation=dilations[1],
140 | bias=False,
141 | ),
142 | norm_layer(inner_planes),
143 | nn.ReLU(inplace=True),
144 | )
145 | self.conv5 = nn.Sequential(
146 | nn.Conv2d(
147 | in_planes,
148 | inner_planes,
149 | kernel_size=3,
150 | padding=dilations[2],
151 | dilation=dilations[2],
152 | bias=False,
153 | ),
154 | norm_layer(inner_planes),
155 | nn.ReLU(inplace=True),
156 | )
157 |
158 | self.out_planes = (len(dilations) + 2) * inner_planes
159 |
160 | def get_outplanes(self):
161 | return self.out_planes
162 |
163 | def forward(self, x):
164 | _, _, h, w = x.size()
165 | feat1 = F.interpolate(
166 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
167 | )
168 | feat2 = self.conv2(x)
169 | feat3 = self.conv3(x)
170 | feat4 = self.conv4(x)
171 | feat5 = self.conv5(x)
172 | aspp_out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
173 | return aspp_out
174 |
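Each atrous branch of `ASPP` uses a 3x3 kernel with stride 1 and `padding=dilation`, so every branch preserves the spatial size and the five outputs can be concatenated along channels. Only the pooled branch needs the explicit `F.interpolate`. The convolution output-size formula from the PyTorch `Conv2d` documentation makes this concrete. The feature-map height below is illustrative.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a Conv2d along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


h = 65  # illustrative feature-map height
for d in (12, 24, 36):
    # padding == dilation keeps a 3x3 atrous conv size-preserving
    assert conv_out_size(h, kernel=3, padding=d, dilation=d) == h
print(conv_out_size(h, kernel=1))  # 65
```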
--------------------------------------------------------------------------------
/augseg/models/model_helper.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from .decoder import Aux_Module
5 |
6 |
7 | class ModelBuilder(nn.Module):
8 | def __init__(self, net_cfg):
9 | super(ModelBuilder, self).__init__()
10 | self._sync_bn = net_cfg["sync_bn"]
11 | self._num_classes = net_cfg["num_classes"]
12 |
13 | self.encoder = self._build_encoder(net_cfg["encoder"])
14 | self.decoder = self._build_decoder(net_cfg["decoder"])
15 |
16 | self._use_auxloss = True if net_cfg.get("aux_loss", False) else False
17 | if self._use_auxloss:
18 | cfg_aux = net_cfg["aux_loss"]
19 | self.loss_weight = cfg_aux["loss_weight"]
20 | self.auxor = Aux_Module(
21 | cfg_aux["aux_plane"], self._num_classes, self._sync_bn
22 | )
23 |
24 | def _build_encoder(self, enc_cfg):
25 | enc_cfg["kwargs"].update({"sync_bn": self._sync_bn})
26 | pretrained_model_url = enc_cfg["pretrain"]
27 | encoder = self._build_module(enc_cfg["type"], enc_cfg["kwargs"], pretrain_model_url=pretrained_model_url)
28 | return encoder
29 |
30 | def _build_decoder(self, dec_cfg):
31 | dec_cfg["kwargs"].update(
32 | {
33 | "in_planes": self.encoder.get_outplanes(),
34 | "sync_bn": self._sync_bn,
35 | "num_classes": self._num_classes,
36 | }
37 | )
38 | decoder = self._build_module(dec_cfg["type"], dec_cfg["kwargs"])
39 | return decoder
40 |
41 | def _build_module(self, mtype, kwargs, pretrain_model_url=None):
42 | module_name, class_name = mtype.rsplit(".", 1)
43 | module = importlib.import_module(module_name)
44 | cls = getattr(module, class_name)
45 | if pretrain_model_url is None:
46 | return cls(**kwargs)
47 | else:
48 | return cls(pretrain_model_url=pretrain_model_url, **kwargs)
49 |
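`_build_module` uses `importlib` to resolve a dotted `type` string from the config, of the form `package.module.Name`, into a class or factory function. It forwards `pretrain_model_url` only for the encoder. The sketch below shows the same mechanism with a hypothetical helper name and runs it on standard-library targets so it works without the repository installed.

```python
import importlib


def build_from_string(mtype, kwargs, pretrain_model_url=None):
    """Instantiate 'package.module.Name' with kwargs (hypothetical helper)."""
    module_name, class_name = mtype.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    if pretrain_model_url is None:
        return cls(**kwargs)
    return cls(pretrain_model_url=pretrain_model_url, **kwargs)


counter = build_from_string("collections.Counter", {"a": 2, "b": 1})
frac = build_from_string("fractions.Fraction", {"numerator": 3, "denominator": 6})
print(counter["a"], frac)  # 2 1/2
```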
50 | def forward(self, x, flag_use_fdrop=False):
51 | h, w = x.shape[-2:]
52 | if self._use_auxloss:
53 | f1, f2, feat1, feat2 = self.encoder(x)
54 | outs = self.decoder([f1, f2, feat1, feat2])
55 | pred_aux = self.auxor(feat1)
56 |
57 | # upsampling
58 | outs = F.interpolate(outs, (h, w), mode="bilinear", align_corners=True)
59 | pred_aux = F.interpolate(pred_aux, (h, w), mode="bilinear", align_corners=True)
60 |
61 | return outs, pred_aux
62 | else:
63 | if flag_use_fdrop:
64 | f1, f2, feat1, feat2 = self.encoder(x)
65 | f1 = nn.Dropout2d(0.5)(f1)
66 | feat2 = nn.Dropout2d(0.5)(feat2)
67 | outs = self.decoder([f1, f2, feat1, feat2])
68 | else:
69 | feat = self.encoder(x)
70 | outs = self.decoder(feat)
71 |
72 | outs = F.interpolate(outs, (h, w), mode="bilinear", align_corners=True)
73 |
74 | return outs, None
75 |
--------------------------------------------------------------------------------
/augseg/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import os, sys
5 |
6 | __all__ = [
7 | "ResNet",
8 | "resnet18",
9 | "resnet34",
10 | "resnet50",
11 | "resnet101",
12 | "resnet152",
13 | ]
14 |
15 | # HOME_NAME = "SSS"
16 | # cur_path_file = os.getcwd()
17 | # lst_dir = cur_path_file.split(os.path.sep)
18 | # CODE_DIR = os.path.sep.join(lst_dir[:lst_dir.index(HOME_NAME)+1])
19 |
20 | model_urls = {
21 | "resnet18": "/path/to/resnet18.pth",
22 | "resnet34": "/path/to/resnet34.pth",
23 | "resnet50": "/path/to/resnet50.pth",
24 | "resnet101": "/path/to/resnet101.pth",
25 | # "resnet50": os.path.join(CODE_DIR, "pretrained", "resnet50.pth"),
26 | # "resnet101": os.path.join(CODE_DIR, "pretrained", "resnet101.pth"),
27 | "resnet152": "/path/to/resnet152.pth",
28 | }
29 |
30 |
31 | def get_syncbn():
32 | # return nn.BatchNorm2d
33 | return nn.SyncBatchNorm
34 |
35 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
36 | """3x3 convolution with padding"""
37 | return nn.Conv2d(
38 | in_planes,
39 | out_planes,
40 | kernel_size=3,
41 | stride=stride,
42 | padding=dilation,
43 | groups=groups,
44 | bias=False,
45 | dilation=dilation,
46 | )
47 |
48 |
49 | def conv1x1(in_planes, out_planes, stride=1):
50 | """1x1 convolution"""
51 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
52 |
53 |
54 | class BasicBlock(nn.Module):
55 | expansion = 1
56 |
57 | def __init__(
58 | self,
59 | inplanes,
60 | planes,
61 | stride=1,
62 | downsample=None,
63 | groups=1,
64 | base_width=64,
65 | dilation=1,
66 | norm_layer=None,
67 | ):
68 | super(BasicBlock, self).__init__()
69 | if norm_layer is None:
70 | norm_layer = nn.BatchNorm2d
71 | if groups != 1 or base_width != 64:
72 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
73 | if dilation > 1:
74 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
75 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
76 | self.conv1 = conv3x3(inplanes, planes, stride)
77 | self.bn1 = norm_layer(planes)
78 | self.relu = nn.ReLU(inplace=True)
79 | self.conv2 = conv3x3(planes, planes)
80 | self.bn2 = norm_layer(planes)
81 | self.downsample = downsample
82 | self.stride = stride
83 |
84 | def forward(self, x):
85 | identity = x
86 |
87 | out = self.conv1(x)
88 | out = self.bn1(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv2(out)
92 | out = self.bn2(out)
93 |
94 | if self.downsample is not None:
95 | identity = self.downsample(x)
96 |
97 | out += identity
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class Bottleneck(nn.Module):
104 | expansion = 4
105 |
106 | def __init__(
107 | self,
108 | inplanes,
109 | planes,
110 | stride=1,
111 | downsample=None,
112 | groups=1,
113 | base_width=64,
114 | dilation=1,
115 | norm_layer=nn.BatchNorm2d,
116 | ):
117 | super(Bottleneck, self).__init__()
118 | width = int(planes * (base_width / 64.0)) * groups
119 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
120 | self.conv1 = conv1x1(inplanes, width)
121 | self.bn1 = norm_layer(width)
122 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
123 | self.bn2 = norm_layer(width)
124 | self.conv3 = conv1x1(width, planes * self.expansion)
125 | self.bn3 = norm_layer(planes * self.expansion)
126 | self.relu = nn.ReLU(inplace=True)
127 | self.downsample = downsample
128 | self.stride = stride
129 |
130 | def forward(self, x):
131 | identity = x
132 |
133 | out = self.conv1(x)
134 | out = self.bn1(out)
135 | out = self.relu(out)
136 |
137 | out = self.conv2(out)
138 | out = self.bn2(out)
139 | out = self.relu(out)
140 |
141 | out = self.conv3(out)
142 | out = self.bn3(out)
143 |
144 | if self.downsample is not None:
145 | identity = self.downsample(x)
146 |
147 | out += identity
148 | out = self.relu(out)
149 |
150 | return out
151 |
152 |
153 | class ResNet(nn.Module):
154 | def __init__(
155 | self,
156 | block,
157 | layers,
158 | zero_init_residual=False,
159 | groups=1,
160 | width_per_group=64,
161 | replace_stride_with_dilation=[False, False, False],
162 | sync_bn=False,
163 | multi_grid=False,
164 | ):
165 | super(ResNet, self).__init__()
166 |
167 | norm_layer = get_syncbn() if sync_bn else nn.BatchNorm2d
168 | self._norm_layer = norm_layer
169 |
170 | self.inplanes = 128
171 | self.dilation = 1
172 |
173 | if replace_stride_with_dilation is None:
174 | # each element in the tuple indicates if we should replace
175 | # the 2x2 stride with a dilated convolution instead
176 | replace_stride_with_dilation = [False, False, False]
177 |
178 | if len(replace_stride_with_dilation) != 3:
179 | raise ValueError(
180 | "replace_stride_with_dilation should be None "
181 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
182 | )
183 |
184 | self.groups = groups
185 | self.base_width = width_per_group
186 | self.conv1 = nn.Sequential(
187 | conv3x3(3, 64, stride=2),
188 | norm_layer(64),
189 | nn.ReLU(inplace=True),
190 | conv3x3(64, 64),
191 | norm_layer(64),
192 | nn.ReLU(inplace=True),
193 | conv3x3(64, self.inplanes),
194 | )
195 | self.bn1 = norm_layer(self.inplanes)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.maxpool = nn.MaxPool2d(
198 | kernel_size=3, stride=2, padding=1, ceil_mode=True
199 | ) # ceil_mode=True deviates from the torchvision ResNet stem
200 |
201 | self.layer1 = self._make_layer(block, 64, layers[0])
202 | self.layer2 = self._make_layer(
203 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
204 | )
205 | self.layer3 = self._make_layer(
206 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
207 | )
208 | self.layer4 = self._make_layer(
209 | block,
210 | 512,
211 | layers[3],
212 | stride=2,
213 | dilate=replace_stride_with_dilation[2],
214 | multi_grid=multi_grid,
215 | )
216 |
217 | for m in self.modules():
218 | if isinstance(m, nn.Conv2d):
219 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
220 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
221 | nn.init.constant_(m.weight, 1)
222 | nn.init.constant_(m.bias, 0)
223 |
224 | # Zero-initialize the last BN in each residual branch,
225 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
226 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
227 | if zero_init_residual:
228 | for m in self.modules():
229 | if isinstance(m, Bottleneck):
230 | nn.init.constant_(m.bn3.weight, 0)
231 | elif isinstance(m, BasicBlock):
232 | nn.init.constant_(m.bn2.weight, 0)
233 |
234 | def get_outplanes(self):
235 | return self.inplanes
236 |
237 | def get_auxplanes(self):
238 | return self.inplanes // 2
239 |
240 | def _make_layer(
241 | self, block, planes, blocks, stride=1, dilate=False, multi_grid=False
242 | ):
243 | norm_layer = self._norm_layer
244 | downsample = None
245 | previous_dilation = self.dilation
246 | if dilate:
247 | self.dilation *= stride
248 | stride = 1
249 | if stride != 1 or self.inplanes != planes * block.expansion:
250 | downsample = nn.Sequential(
251 | conv1x1(self.inplanes, planes * block.expansion, stride),
252 | norm_layer(planes * block.expansion),
253 | )
254 |
255 | grids = [1] * blocks
256 | if multi_grid:
257 | grids = [2, 2, 4]
258 |
259 | layers = []
260 | layers.append(
261 | block(
262 | self.inplanes,
263 | planes,
264 | stride,
265 | downsample,
266 | self.groups,
267 | self.base_width,
268 | previous_dilation * grids[0],
269 | norm_layer,
270 | )
271 | )
272 | self.inplanes = planes * block.expansion
273 | for i in range(1, blocks):
274 | layers.append(
275 | block(
276 | self.inplanes,
277 | planes,
278 | groups=self.groups,
279 | base_width=self.base_width,
280 | dilation=self.dilation * grids[i],
281 | norm_layer=norm_layer,
282 | )
283 | )
284 |
285 | return nn.Sequential(*layers)
286 |
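`replace_stride_with_dilation` lets a stage trade its stride-2 downsampling for a larger dilation. The first block of such a stage keeps the previous dilation, and the remaining blocks use the new one. The sketch below replays that bookkeeping, ignoring `multi_grid`. The configs are not shown here, so the `[False, True, True]` setting (output stride 8) is only an illustration, and the helper name is hypothetical.

```python
def resnet_dilation_plan(replace_stride_with_dilation, blocks=(3, 4, 23, 3)):
    """Replay ResNet._make_layer stride/dilation bookkeeping (multi_grid off).

    Returns the encoder output stride and, per stage, each block's dilation.
    """
    output_stride = 4          # stride-2 stem conv followed by stride-2 max-pool
    dilation = 1               # plays the role of self.dilation
    plan = []
    stage_strides = (1, 2, 2, 2)
    dilate_flags = (False,) + tuple(replace_stride_with_dilation)
    for n_blocks, stride, dilate in zip(blocks, stage_strides, dilate_flags):
        previous = dilation
        if dilate:             # trade this stage's stride for dilation
            dilation *= stride
            stride = 1
        output_stride *= stride
        # the first block keeps the previous dilation; the rest use the new one
        plan.append([previous] + [dilation] * (n_blocks - 1))
    return output_stride, plan


out_stride, plan = resnet_dilation_plan([False, True, True])
print(out_stride, plan[3])  # 8 [2, 4, 4]
```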
287 | def forward(self, x):
288 | x = self.relu(self.bn1(self.conv1(x)))
289 | x = self.maxpool(x)
290 |
291 | x1 = self.layer1(x)
292 | x2 = self.layer2(x1)
293 | x3 = self.layer3(x2)
294 | x4 = self.layer4(x3)
295 | return [x1, x2, x3, x4]
296 |
297 |
298 | def resnet18(pretrained=True, pretrain_model_url=None, **kwargs):
299 | """Constructs a ResNet-18 model.
300 |
301 | Args:
302 | pretrained (bool): If True, returns a model pre-trained on ImageNet
303 | """
304 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
305 | if pretrained:
306 | if pretrain_model_url is None:
307 | model_url = model_urls["resnet18"]
308 | else:
309 | model_url = pretrain_model_url
310 | state_dict = torch.load(model_url)
311 |
312 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
313 | print(
314 | f"[Info] Load ImageNet pretrain from '{model_url}'",
315 | "\nmissing_keys: ",
316 | missing_keys,
317 | "\nunexpected_keys: ",
318 | unexpected_keys,
319 | )
320 | return model
321 |
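Loading with `strict=False` tolerates name mismatches between the checkpoint and this modified ResNet, which has a three-conv stem and no classification head. It reports them as `missing_keys` and `unexpected_keys`, which are simply set differences over parameter names. The dictionary-only sketch below uses a few illustrative key names that were not read from an actual checkpoint, and its helper name is hypothetical. Note that `strict=False` does not suppress shape mismatches for keys present in both.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Mimic the missing/unexpected split reported by load_state_dict(strict=False)."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # keep their init values
    unexpected = sorted(checkpoint_keys - model_keys)  # silently ignored
    return missing, unexpected


model_keys = ["conv1.0.weight", "bn1.weight", "layer1.0.conv1.weight"]
checkpoint_keys = ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight", "fc.weight"]
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print(missing, unexpected)  # ['conv1.0.weight'] ['conv1.weight', 'fc.weight']
```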
322 |
323 | def resnet34(pretrained=True, pretrain_model_url=None, **kwargs):
324 | """Constructs a ResNet-34 model.
325 |
326 | Args:
327 | pretrained (bool): If True, returns a model pre-trained on ImageNet
328 | """
329 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
330 | if pretrained:
331 | if pretrain_model_url is None:
332 | model_url = model_urls["resnet34"]
333 | else:
334 | model_url = pretrain_model_url
335 |
336 | state_dict = torch.load(model_url)
337 |
338 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
339 | print(
340 | f"[Info] Load ImageNet pretrain from '{model_url}'",
341 | "\nmissing_keys: ",
342 | missing_keys,
343 | "\nunexpected_keys: ",
344 | unexpected_keys,
345 | )
346 | return model
347 |
348 |
349 | def resnet50(pretrained=True, pretrain_model_url=None, **kwargs):
350 | """Constructs a ResNet-50 model.
351 |
352 | Args:
353 | pretrained (bool): If True, returns a model pre-trained on ImageNet
354 | """
355 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
356 | if pretrained:
357 | if pretrain_model_url is None:
358 | model_url = model_urls["resnet50"]
359 | else:
360 | model_url = pretrain_model_url
361 |
362 | state_dict = torch.load(model_url)
363 |
364 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
365 | print(
366 | f"[Info] Load ImageNet pretrain from '{model_url}'",
367 | "\nmissing_keys: ",
368 | missing_keys,
369 | "\nunexpected_keys: ",
370 | unexpected_keys,
371 | )
372 | return model
373 |
374 |
375 | def resnet101(pretrained=True, pretrain_model_url=None, **kwargs):
376 | """Constructs a ResNet-101 model.
377 |
378 | Args:
379 | pretrained (bool): If True, returns a model pre-trained on ImageNet
380 | """
381 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
382 | if pretrained:
383 | if pretrain_model_url is None:
384 | model_url = model_urls["resnet101"]
385 | else:
386 | model_url = pretrain_model_url
387 | state_dict = torch.load(model_url)
388 |
389 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
390 | print(
391 | f"[Info] Load ImageNet pretrain from '{model_url}'",
392 | "\nmissing_keys: ",
393 | missing_keys,
394 | "\nunexpected_keys: ",
395 | unexpected_keys,
396 | )
397 | return model
398 |
399 |
400 | def resnet152(pretrained=True, pretrain_model_url=None, **kwargs):
401 | """Constructs a ResNet-152 model.
402 |
403 | Args:
404 | pretrained (bool): If True, returns a model pre-trained on ImageNet
405 | """
406 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
407 | if pretrained:
408 | if pretrain_model_url is None:
409 | model_url = model_urls["resnet152"]
410 | else:
411 | model_url = pretrain_model_url
412 | state_dict = torch.load(model_url)
413 |
414 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
415 | print(
416 | f"[Info] Load ImageNet pretrain from '{model_url}'",
417 | "\nmissing_keys: ",
418 | missing_keys,
419 | "\nunexpected_keys: ",
420 | unexpected_keys,
421 | )
422 | return model
423 |
--------------------------------------------------------------------------------
/augseg/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/augseg/utils/__init__.py
--------------------------------------------------------------------------------
/augseg/utils/dist_helper.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import os
4 | import random
5 | import subprocess
6 |
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | from torch.utils.data.sampler import Sampler
11 |
12 |
13 | def setup_distributed(backend="nccl", port=None):
14 | """AdaHessian Optimizer
15 | Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
16 | Originally licensed MIT, Copyright (c) 2020 Wei Li
17 | """
18 | num_gpus = torch.cuda.device_count()
19 |
20 | if "SLURM_JOB_ID" in os.environ:
21 | rank = int(os.environ["SLURM_PROCID"])
22 | world_size = int(os.environ["SLURM_NTASKS"])
23 | node_list = os.environ["SLURM_NODELIST"]
24 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
25 | # specify master port
26 | if port is not None:
27 | os.environ["MASTER_PORT"] = str(port)
28 | elif "MASTER_PORT" not in os.environ:
29 | os.environ["MASTER_PORT"] = "10685"
30 | if "MASTER_ADDR" not in os.environ:
31 | os.environ["MASTER_ADDR"] = addr
32 | os.environ["WORLD_SIZE"] = str(world_size)
33 | os.environ["LOCAL_RANK"] = str(rank % num_gpus)
34 | os.environ["RANK"] = str(rank)
35 | else:
36 | rank = int(os.environ["RANK"])
37 | world_size = int(os.environ["WORLD_SIZE"])
38 |
39 | torch.cuda.set_device(rank % num_gpus)
40 |
41 | dist.init_process_group(
42 | backend=backend,
43 | world_size=world_size,
44 | rank=rank,
45 | )
46 | return rank, world_size
47 |
48 |
49 | def gather_together(data):
50 | world_size = dist.get_world_size()
51 | gather_data = [torch.zeros_like(data).cuda() for _ in range(world_size)]
52 | dist.all_gather(gather_data, data)
53 | return gather_data
54 |
55 |
56 | class DistributedGivenIterationSampler(Sampler):
57 | def __init__(
58 | self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1
59 | ):
60 | if world_size is None:
61 | world_size = dist.get_world_size()
62 | if rank is None:
63 | rank = dist.get_rank()
64 | assert rank < world_size
65 | self.dataset = dataset
66 | self.total_iter = total_iter
67 | self.batch_size = batch_size
68 | self.world_size = world_size
69 | self.rank = rank
70 | self.last_iter = last_iter
71 |
72 | self.total_size = self.total_iter * self.batch_size
73 |
74 | self.indices = self.gen_new_list()
75 | self.call = 0
76 |
77 | def __iter__(self):
78 | if self.call == 0:
79 | self.call = 1
80 | return iter(self.indices[(self.last_iter + 1) * self.batch_size :])
81 | else:
82 | raise RuntimeError(
83 | "this sampler is not designed to be called more than once!!"
84 | )
85 |
86 | def gen_new_list(self):
87 | # every process shuffles the full list with the same seed, then keeps the slice for its rank
88 | np.random.seed(0)
89 |
90 | all_size = self.total_size * self.world_size
91 | indices = np.arange(len(self.dataset))
92 | indices = indices[:all_size]
93 | num_repeat = (all_size - 1) // indices.shape[0] + 1
94 | indices = np.tile(indices, num_repeat)
95 | indices = indices[:all_size]
96 |
97 | np.random.shuffle(indices)
98 | beg = self.total_size * self.rank
99 | indices = indices[beg : beg + self.total_size]
100 |
101 | assert len(indices) == self.total_size
102 |
103 | return indices
104 |
105 | def __len__(self):
106 | # last_iter is deliberately ignored: __len__ is only used for display, and
107 | # __iter__ already skips the first (last_iter + 1) * batch_size indices, so
108 | # the dataloader still receives the correct number of remaining samples
109 | # return self.total_size - (self.last_iter+1)*self.batch_size
110 | return self.total_size
111 |
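`DistributedGivenIterationSampler.gen_new_list` gives every rank an identically shuffled, tiled index list and lets each rank keep one contiguous slice of `total_iter * batch_size` indices; `__iter__` then skips whatever was consumed before `last_iter`. The sketch below reproduces that plan in plain Python. It shuffles with `random.Random(seed)` instead of `np.random.seed(0)`, so the concrete order differs from the original, and `partition_indices` is a hypothetical helper name:

```python
import random


def partition_indices(dataset_len, total_iter, batch_size, world_size, rank,
                      last_iter=-1, seed=0):
    """Indices one rank would iterate over (hypothetical helper).

    Mirrors gen_new_list() + __iter__ of DistributedGivenIterationSampler,
    but shuffles with random.Random(seed) rather than np.random.
    """
    total_size = total_iter * batch_size          # samples per rank
    all_size = total_size * world_size            # samples across all ranks
    base = list(range(dataset_len))[:all_size]
    num_repeat = (all_size - 1) // len(base) + 1  # tile small datasets
    indices = (base * num_repeat)[:all_size]
    random.Random(seed).shuffle(indices)          # same order on every rank
    beg = total_size * rank
    mine = indices[beg:beg + total_size]          # this rank's contiguous slice
    assert len(mine) == total_size
    # Resuming after iteration last_iter skips the batches already consumed.
    return mine[(last_iter + 1) * batch_size:]


if __name__ == "__main__":
    plans = [partition_indices(10, 3, 2, world_size=2, rank=r) for r in range(2)]
    print([len(p) for p in plans])  # [6, 6]
    print(len(partition_indices(10, 3, 2, world_size=2, rank=0, last_iter=0)))  # 4
```

Because the list is tiled before shuffling, a small labelled split is repeated until it fills `total_iter * batch_size * world_size` slots, so this sampler fixes the training length in iterations rather than in passes over the data.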
--------------------------------------------------------------------------------
/augseg/utils/loss_helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.ndimage as nd
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
9 | # # # # # 1. get training criterion
10 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
11 | def get_criterion(cfg):
12 | cfg_criterion = cfg["criterion"]
13 | aux_weight = (
14 | cfg["net"]["aux_loss"]["loss_weight"]
15 | if cfg["net"].get("aux_loss", False)
16 | else 0
17 | )
18 | ignore_index = cfg["dataset"]["ignore_label"]
19 | if cfg_criterion["type"] == "ohem":
20 | criterion = CriterionOhem(
21 | aux_weight, ignore_index=ignore_index, **cfg_criterion["kwargs"]
22 | )
23 | else:
24 | criterion = Criterion(
25 | aux_weight, ignore_index=ignore_index, **cfg_criterion["kwargs"]
26 | )
27 |
28 | return criterion
29 |
30 |
31 | class Criterion(nn.Module):
32 | def __init__(self, aux_weight, ignore_index=255, use_weight=False):
33 | super(Criterion, self).__init__()
34 | self._aux_weight = aux_weight
35 | self._ignore_index = ignore_index
36 | self.use_weight = use_weight
37 | if not use_weight:
38 | self._criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
39 | else:
40 | weights = torch.FloatTensor(
41 | [
42 | 0.0,
43 | 0.0,
44 | 0.0,
45 | 1.0,
46 | 1.0,
47 | 1.0,
48 | 1.0,
49 | 0.0,
50 | 0.0,
51 | 1.0,
52 | 0.0,
53 | 0.0,
54 | 1.0,
55 | 0.0,
56 | 1.0,
57 | 0.0,
58 | 1.0,
59 | 1.0,
60 | 1.0,
61 | ]
62 | ).cuda()
63 | self._criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
64 | self._criterion1 = nn.CrossEntropyLoss(
65 | ignore_index=ignore_index, weight=weights
66 | )
67 |
68 | def forward(self, preds, target):
69 | h, w = target.size(1), target.size(2)
70 | if self._aux_weight > 0: # require aux loss
71 | main_pred, aux_pred = preds
72 | main_h, main_w = main_pred.size(2), main_pred.size(3)
73 | aux_h, aux_w = aux_pred.size(2), aux_pred.size(3)
74 | assert (
75 | len(preds) == 2
76 | and main_h == aux_h
77 | and main_w == aux_w
78 | and main_h == h
79 | and main_w == w
80 | )
81 | if self.use_weight:
82 | loss1 = self._criterion(main_pred, target) + self._criterion1(
83 | main_pred, target
84 | )
85 | else:
86 | loss1 = self._criterion(main_pred, target)
87 | loss2 = self._criterion(aux_pred, target)
88 | loss = loss1 + self._aux_weight * loss2
89 | else:
90 | pred_h, pred_w = preds.size(2), preds.size(3)
91 | assert pred_h == h and pred_w == w
92 | loss = self._criterion(preds, target)
93 | return loss
94 |
95 |
96 | class CriterionOhem(nn.Module):
97 | def __init__(
98 | self,
99 | aux_weight,
100 | thresh=0.7,
101 | min_kept=100000,
102 | ignore_index=255,
103 | use_weight=False,
104 | ):
105 | super(CriterionOhem, self).__init__()
106 | self._aux_weight = aux_weight
107 | self._criterion1 = OhemCrossEntropy2dTensor(
108 | ignore_index, thresh, min_kept, use_weight
109 | )
110 | self._criterion2 = OhemCrossEntropy2dTensor(ignore_index, thresh, min_kept)
111 |
112 | def forward(self, preds, target):
113 | h, w = target.size(1), target.size(2)
114 | if self._aux_weight > 0: # require aux loss
115 | main_pred, aux_pred = preds
116 | main_h, main_w = main_pred.size(2), main_pred.size(3)
117 | aux_h, aux_w = aux_pred.size(2), aux_pred.size(3)
118 | assert (
119 | len(preds) == 2
120 | and main_h == aux_h
121 | and main_w == aux_w
122 | and main_h == h
123 | and main_w == w
124 | )
125 |
126 | loss1 = self._criterion1(main_pred, target)
127 | loss2 = self._criterion2(aux_pred, target)
128 | loss = loss1 + self._aux_weight * loss2
129 | else:
130 | pred_h, pred_w = preds.size(2), preds.size(3)
131 | assert pred_h == h and pred_w == w
132 | loss = self._criterion1(preds, target)
133 | return loss
134 |
135 |
136 | class OhemCrossEntropy2d(nn.Module):
137 | def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8):
138 | super(OhemCrossEntropy2d, self).__init__()
139 | self.ignore_label = ignore_label
140 | self.thresh = float(thresh)
141 | self.min_kept = int(min_kept)
142 | self.factor = factor
143 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
144 |
145 | def find_threshold(self, np_predict, np_target):
146 | # downsample by 1/factor (1/8 by default) before searching for the threshold
147 | factor = self.factor
148 | predict = nd.zoom(np_predict, (1.0, 1.0, 1.0 / factor, 1.0 / factor), order=1)
149 | target = nd.zoom(np_target, (1.0, 1.0 / factor, 1.0 / factor), order=0)
150 |
151 | n, c, h, w = predict.shape
152 | min_kept = self.min_kept // (
153 | factor * factor
154 | ) # int(self.min_kept_ratio * n * h * w)
155 |
156 | input_label = target.ravel().astype(np.int32)
157 | input_prob = np.rollaxis(predict, 1).reshape((c, -1))
158 |
159 | valid_flag = input_label != self.ignore_label
160 | valid_inds = np.where(valid_flag)[0]
161 | label = input_label[valid_flag]
162 | num_valid = valid_flag.sum()
163 | if min_kept >= num_valid:
164 | threshold = 1.0
165 | elif num_valid > 0:
166 | prob = input_prob[:, valid_flag]
167 | pred = prob[label, np.arange(len(label), dtype=np.int32)]
168 | threshold = self.thresh
169 | if min_kept > 0:
170 | k_th = min(len(pred), min_kept) - 1
171 | new_array = np.partition(pred, k_th)
172 | new_threshold = new_array[k_th]
173 | if new_threshold > self.thresh:
174 | threshold = new_threshold
175 | return threshold
176 |
177 | def generate_new_target(self, predict, target):
178 | np_predict = predict.data.cpu().numpy()
179 | np_target = target.data.cpu().numpy()
180 | n, c, h, w = np_predict.shape
181 |
182 | threshold = self.find_threshold(np_predict, np_target)
183 |
184 | input_label = np_target.ravel().astype(np.int32)
185 | input_prob = np.rollaxis(np_predict, 1).reshape((c, -1))
186 |
187 | valid_flag = input_label != self.ignore_label
188 | valid_inds = np.where(valid_flag)[0]
189 | label = input_label[valid_flag]
190 | num_valid = valid_flag.sum()
191 |
192 | if num_valid > 0:
193 | prob = input_prob[:, valid_flag]
194 | pred = prob[label, np.arange(len(label), dtype=np.int32)]
195 | kept_flag = pred <= threshold
196 | valid_inds = valid_inds[kept_flag]
197 |
198 | label = input_label[valid_inds].copy()
199 | input_label.fill(self.ignore_label)
200 | input_label[valid_inds] = label
201 | new_target = (
202 | torch.from_numpy(input_label.reshape(target.size()))
203 | .long()
204 | .cuda(target.get_device())
205 | )
206 |
207 | return new_target
208 |
209 | def forward(self, predict, target, weight=None):
210 | """
211 | Args:
212 | predict:(n, c, h, w)
213 | target:(n, h, w)
214 | weight (Tensor, optional): a manual rescaling weight given to each class.
215 | If given, has to be a Tensor of size "nclasses"
216 | """
217 | assert not target.requires_grad
218 |
219 | input_prob = F.softmax(predict, 1)
220 | target = self.generate_new_target(input_prob, target)
221 | return self.criterion(predict, target)
222 |
223 |
224 | class OhemCrossEntropy2dTensor(nn.Module):
225 | """
226 | OHEM cross-entropy computed with torch tensor ops (no NumPy round trip, unlike OhemCrossEntropy2d)
227 | """
228 |
229 | def __init__(
230 | self, ignore_index=255, thresh=0.7, min_kept=256, use_weight=False, reduce=False
231 | ):
232 | super(OhemCrossEntropy2dTensor, self).__init__()
233 | self.ignore_index = ignore_index
234 | self.thresh = float(thresh)
235 | self.min_kept = int(min_kept)
236 | if use_weight:
237 | weight = torch.FloatTensor(
238 | [
239 | 0.8373,
240 | 0.918,
241 | 0.866,
242 | 1.0345,
243 | 1.0166,
244 | 0.9969,
245 | 0.9754,
246 | 1.0489,
247 | 0.8786,
248 | 1.0023,
249 | 0.9539,
250 | 0.9843,
251 | 1.1116,
252 | 0.9037,
253 | 1.0865,
254 | 1.0955,
255 | 1.0865,
256 | 1.1529,
257 | 1.0507,
258 | ]
259 | ).cuda()
260 | # weight = torch.FloatTensor(
261 | # [0.4762, 0.5, 0.4762, 1.4286, 1.1111, 0.4762, 0.8333, 0.5, 0.5, 0.8333, 0.5263, 0.5882,
262 | # 1.4286, 0.5, 3.3333,5.0, 10.0, 2.5, 0.8333]).cuda()
263 | self.criterion = torch.nn.CrossEntropyLoss(
264 | reduction="mean", weight=weight, ignore_index=ignore_index
265 | )
266 | elif reduce:
267 | self.criterion = torch.nn.CrossEntropyLoss(
268 | reduction="none", ignore_index=ignore_index
269 | )
270 | else:
271 | self.criterion = torch.nn.CrossEntropyLoss(
272 | reduction="mean", ignore_index=ignore_index
273 | )
274 |
275 | def forward(self, pred, target):
276 | b, c, h, w = pred.size()
277 | target = target.view(-1)
278 | valid_mask = target.ne(self.ignore_index)
279 | target = target * valid_mask.long()
280 | num_valid = valid_mask.sum()
281 |
282 | prob = F.softmax(pred, dim=1)
283 | prob = (prob.transpose(0, 1)).reshape(c, -1)
284 |
285 | if self.min_kept > num_valid:
286 | pass
287 | # print('Labels: {}'.format(num_valid))
288 | elif num_valid > 0:
289 | prob = prob.masked_fill_(~valid_mask, 1)
290 | mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)]
291 | threshold = self.thresh
292 | if self.min_kept > 0:
293 | _, index = mask_prob.sort()
294 | threshold_index = index[min(len(index), self.min_kept) - 1]
295 | if mask_prob[threshold_index] > self.thresh:
296 | threshold = mask_prob[threshold_index]
297 | kept_mask = mask_prob.le(threshold)
298 | target = target * kept_mask.long()
299 | valid_mask = valid_mask * kept_mask
300 |
301 | target = target.masked_fill_(~valid_mask, self.ignore_index)
302 | target = target.view(b, h, w)
303 |
304 | return self.criterion(pred, target)
305 |
306 |
307 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
308 | # # # # # 2. calculate unsupervised loss
309 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
310 | def compute_unsupervised_loss_by_threshold(predict, target, logits, thresh=0.95):
311 | batch_size, num_class, h, w = predict.shape
312 | thresh_mask = logits.ge(thresh).bool() * (target != 255).bool()
313 | target[~thresh_mask] = 255 # in-place: the caller's target tensor is modified too
314 | loss = F.cross_entropy(predict, target, ignore_index=255, reduction="none")
315 | return loss.mean(), thresh_mask.float().mean()
316 |
317 |
318 | def compute_unsupervised_loss_by_threshold_hardness(predict, target, logits, thresh=0.95, hardness_tensor=None):
319 | batch_size, num_class, h, w = predict.shape
320 | thresh_mask = logits.ge(thresh).bool() * (target != 255).bool()
321 | target[~thresh_mask] = 255 # in-place: the caller's target tensor is modified too
322 | loss = F.cross_entropy(predict, target, ignore_index=255, reduction="none")
323 | if hardness_tensor is None:
324 | return loss.mean()
325 | loss = loss.mean(dim=[1,2])
326 | assert loss.shape == hardness_tensor.shape, "wrong hardness calculation!"
327 | loss *= hardness_tensor
328 | return loss.mean()
329 |
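`OhemCrossEntropy2dTensor` keeps only hard pixels, and `compute_unsupervised_loss_by_threshold` keeps only confident pseudo-labels. Both reduce to a per-pixel keep mask. The plain-Python sketch below computes those masks on flat lists (the real code works on `(B, C, H, W)` tensors); both function names are hypothetical:

```python
IGNORE_INDEX = 255  # matches ignore_label in the configs


def ohem_keep_mask(true_probs, labels, thresh=0.7, min_kept=256,
                   ignore_index=IGNORE_INDEX):
    """Which pixels OhemCrossEntropy2dTensor keeps (hypothetical helper).

    true_probs[i] is the softmax probability of labels[i] at pixel i. Hard
    pixels (probability <= threshold) are kept; the threshold starts at
    `thresh` and is raised to the min_kept-th smallest probability if that
    is larger, so at least min_kept pixels survive.
    """
    valid = [lab != ignore_index for lab in labels]
    num_valid = sum(valid)
    if min_kept > num_valid or num_valid == 0:
        return valid  # too few labelled pixels: no hard-example mining
    # Ignored pixels get probability 1 so they sort last, as in the tensor code.
    probs = [p if v else 1.0 for p, v in zip(true_probs, valid)]
    threshold = thresh
    if min_kept > 0:
        kth = sorted(probs)[min(len(probs), min_kept) - 1]
        if kth > thresh:
            threshold = kth
    return [v and p <= threshold for p, v in zip(probs, valid)]


def confidence_mask(confidences, pseudo_labels, thresh=0.95,
                    ignore_index=IGNORE_INDEX):
    """Mask used by compute_unsupervised_loss_by_threshold (hypothetical helper).

    A pixel counts only if its confidence is >= thresh and its pseudo-label
    is not ignore_index. Also returns the kept ratio (the second output).
    """
    mask = [c >= thresh and lab != ignore_index
            for c, lab in zip(confidences, pseudo_labels)]
    return mask, sum(mask) / len(mask)


if __name__ == "__main__":
    probs = [0.9, 0.2, 0.6, 0.95, 0.1]
    labels = [1, 2, 255, 0, 3]
    print(ohem_keep_mask(probs, labels, thresh=0.7, min_kept=2))
    # [False, True, False, False, True]
    print(ohem_keep_mask(probs, labels, thresh=0.7, min_kept=3))
    # [True, True, False, False, True]
```

Two details carry over from the originals. In the unsupervised loss, masked pixels contribute zero but still count in `loss.mean()`, so the loss shrinks as fewer pixels pass the threshold. And a negative `thresh`, such as the Cityscapes configs' `threshold: -0.7` if it is passed through unchanged, lets every non-ignored pixel pass, assuming the confidences are probabilities in [0, 1].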
--------------------------------------------------------------------------------
/augseg/utils/lr_helper.py:
--------------------------------------------------------------------------------
1 | """Learning Rate Schedulers"""
2 | from __future__ import division
3 |
4 | import copy
5 | import logging
6 | import warnings
7 | from math import cos, pi
8 |
9 | import torch.optim as optim
10 |
11 |
12 | def get_optimizer(parms, cfg_optim):
13 | """
14 | Get the optimizer
15 | """
16 | optim_type = cfg_optim["type"]
17 | optim_kwargs = cfg_optim["kwargs"]
18 | if optim_type == "SGD":
19 | optimizer = optim.SGD(parms, **optim_kwargs)
20 | elif optim_type == "adam":
21 | optimizer = optim.Adam(parms, **optim_kwargs)
22 | else:
23 | optimizer = None
24 |
25 | assert optimizer is not None, f"optimizer type '{optim_type}' is not supported"
26 |
27 | return optimizer
28 |
29 |
30 | def get_scheduler(cfg_trainer, len_data, optimizer, start_epoch=0, use_iteration=False):
31 | epochs = (
32 | cfg_trainer["epochs"] if not use_iteration else 1
33 | ) # with use_iteration=True, len_data must be the total iteration count (one long epoch)
34 | lr_mode = cfg_trainer["lr_scheduler"]["mode"]
35 | lr_args = cfg_trainer["lr_scheduler"]["kwargs"]
36 | lr_scheduler = LRScheduler(
37 | lr_mode, lr_args, len_data, optimizer, epochs, start_epoch
38 | )
39 | return lr_scheduler
40 |
41 |
42 | class LRScheduler(object):
43 | def __init__(self, mode, lr_args, data_size, optimizer, num_epochs, start_epochs):
44 | super(LRScheduler, self).__init__()
45 | logger = logging.getLogger("global")
46 |
47 | assert mode in ["multistep", "poly", "cosine"]
48 | self.mode = mode
49 | self.optimizer = optimizer
50 | self.data_size = data_size
51 |
52 | self.cur_iter = start_epochs * data_size
53 | self.max_iter = num_epochs * data_size
54 |
55 | # set learning rate
56 | self.base_lr = [
57 | param_group["lr"] for param_group in self.optimizer.param_groups
58 | ]
59 | self.cur_lr = [lr for lr in self.base_lr]
60 |
61 | # poly kwargs
62 | # TODO
63 | if mode == "poly":
64 | self.power = lr_args["power"] if lr_args.get("power", False) else 0.9
65 | # logger.info("The kwargs for lr scheduler: {}".format(self.power))
66 | if mode == "multistep": # name must match the assert above and _step()
67 | default_mist = list(range(0, num_epochs, max(num_epochs // 3, 1)))[1:]
68 | self.milestones = (
69 | lr_args["milestones"]
70 | if lr_args.get("milestones", False)
71 | else default_mist
72 | )
73 | # logger.info("The kwargs for lr scheduler: {}".format(self.milestones))
74 | if mode == "cosine":
75 | self.targetlr = lr_args["targetlr"]
76 | # logger.info("The kwargs for lr scheduler: {}".format(self.targetlr))
77 |
78 | def step(self):
79 | self._step()
80 | self.update_lr()
81 | self.cur_iter += 1
82 |
83 | def _step(self):
84 | if self.mode == "multistep":
85 | epoch = self.cur_iter // self.data_size
86 | power = sum([1 for s in self.milestones if s <= epoch])
87 | for i, lr in enumerate(self.base_lr):
88 | adj_lr = lr * pow(0.1, power)
89 | self.cur_lr[i] = adj_lr
90 | elif self.mode == "poly":
91 | for i, lr in enumerate(self.base_lr):
92 | adj_lr = lr * (
93 | (1 - float(self.cur_iter) / self.max_iter) ** (self.power)
94 | )
95 | self.cur_lr[i] = adj_lr
96 | elif self.mode == "cosine":
97 | for i, lr in enumerate(self.base_lr):
98 | adj_lr = (
99 | self.targetlr
100 | + (lr - self.targetlr)
101 | * (1 + cos(pi * self.cur_iter / self.max_iter))
102 | / 2
103 | )
104 | self.cur_lr[i] = adj_lr
105 | else:
106 | raise NotImplementedError
107 |
108 | def get_lr(self):
109 | return self.cur_lr
110 |
111 | def update_lr(self):
112 | for param_group, lr in zip(self.optimizer.param_groups, self.cur_lr):
113 | param_group["lr"] = lr
114 |
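`LRScheduler._step` recomputes each parameter group's rate from its base value at every iteration, with `max_iter = num_epochs * data_size`. The same formulas as standalone functions (a sketch; the function names are hypothetical), handy for checking a schedule before launching a run:

```python
from math import cos, pi


def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    """"poly" mode: base_lr * (1 - cur_iter / max_iter) ** power."""
    return base_lr * (1 - float(cur_iter) / max_iter) ** power


def cosine_lr(base_lr, cur_iter, max_iter, target_lr):
    """"cosine" mode: half-cosine anneal from base_lr down to target_lr."""
    return target_lr + (base_lr - target_lr) * (1 + cos(pi * cur_iter / max_iter)) / 2


def step_decay_lr(base_lr, epoch, milestones, gamma=0.1):
    """Milestone mode: multiply by gamma once for each milestone <= epoch."""
    return base_lr * gamma ** sum(1 for m in milestones if m <= epoch)


if __name__ == "__main__":
    # Hypothetical run: 80 epochs x 100 iterations per epoch.
    max_iter = 80 * 100
    for it in (0, max_iter // 2, max_iter - 1):
        print(it, poly_lr(0.001, it, max_iter))
```

For 80 epochs, the default milestones computed in `__init__` are `list(range(0, 80, 26))[1:] == [26, 52, 78]`.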
--------------------------------------------------------------------------------
/data/splitsall.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/data/splitsall.tar.gz
--------------------------------------------------------------------------------
/docs/Augseg-diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/docs/Augseg-diagram.png
--------------------------------------------------------------------------------
/docs/augs-cutmix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/docs/augs-cutmix.png
--------------------------------------------------------------------------------
/docs/augs-intensity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhenZHAO/AugSeg/7ead8705cf6ce6c9234f52c414ec2237ed6743cd/docs/augs-intensity.png
--------------------------------------------------------------------------------
/exps/zrun_citys/citys_semi744/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/744/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True
20 | val:
21 | data_root: ./data/cityscapes
22 | data_list: ./data/splitsall/cityscapes/val.txt
23 | batch_size: 4
24 | crop:
25 | size: [800, 800]
26 | n_sup: 744
27 | workers: 4
28 | mean: [0.485, 0.456, 0.406]
29 | std: [0.229, 0.224, 0.225]
30 | ignore_label: 255
31 |
32 | # # # # # # # # # # # # # #
33 | # 2. training params
34 | # # # # # # # # # # # # # #
35 | trainer: # Required.
36 | epochs: 240
37 | sup_only_epoch: 0 # 0, 1
38 | evaluate_student: False
39 | optimizer:
40 | type: SGD
41 | kwargs:
42 | lr: 0.01 # 4GPUs
43 | momentum: 0.9
44 | weight_decay: 0.0005
45 | lr_scheduler:
46 | mode: poly
47 | kwargs:
48 | power: 0.9
49 | unsupervised:
50 | flag_extra_weak: False
51 | threshold: -0.7 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
52 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
53 | #------ 2)cutmix augs ------#
54 | use_cutmix: True
55 | use_cutmix_adaptive: True
56 | use_cutmix_trigger_prob: 1.0 # wide range, but trigger by 100%
57 |
58 | # # # # # # # # # # # # # #
59 | # 3. output files, and loss
60 | # # # # # # # # # # # # # #
61 | saver:
62 | snapshot_dir: checkpoints
63 | pretrain: ''
64 | use_tb: False
65 |
66 | criterion:
67 | type: ohem
68 | kwargs:
69 | thresh: 0.7
70 | min_kept: 100000
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 19
77 | sync_bn: True
78 | ema_decay: 0.996
79 | # aux_loss:
80 | # aux_plane: 1024
81 | # loss_weight: 0.4
82 | encoder:
83 | type: augseg.models.resnet.resnet101
84 | pretrain: ./pretrained/resnet101.pth
85 | # type: augseg.models.resnet.resnet50
86 | # pretrain: ./pretrained/resnet50.pth
87 | kwargs:
88 | zero_init_residual: True
89 | multi_grid: True
90 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
91 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
92 | decoder:
93 | type: augseg.models.decoder.dec_deeplabv3_plus
94 | kwargs:
95 | inner_planes: 256
96 | low_conv_planes: 48 # 256
97 | dilations: [6, 12, 18]
98 | # dilations: [12, 24, 36] # [output_stride = 8]
99 |
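This config selects `criterion.type: ohem` and leaves `net.aux_loss` commented out. Here is a sketch of how `get_criterion` in `augseg/utils/loss_helper.py` resolves such a config once the YAML has been parsed into a dict (e.g. with PyYAML's `yaml.safe_load`). `describe_criterion` is a hypothetical helper that reports the choice instead of building the module:

```python
def describe_criterion(cfg):
    """Report which loss get_criterion() would build (hypothetical helper).

    Mirrors loss_helper.get_criterion: type "ohem" selects CriterionOhem and
    any other type falls through to Criterion; aux_weight comes from
    net.aux_loss.loss_weight only when that block is present.
    """
    net = cfg["net"]
    aux_weight = net["aux_loss"]["loss_weight"] if net.get("aux_loss", False) else 0
    crit = cfg["criterion"]
    name = "CriterionOhem" if crit["type"] == "ohem" else "Criterion"
    kwargs = dict(crit["kwargs"], ignore_index=cfg["dataset"]["ignore_label"])
    return name, aux_weight, kwargs


if __name__ == "__main__":
    citys = {  # the relevant fields of this config, with aux_loss commented out
        "dataset": {"ignore_label": 255},
        "criterion": {"type": "ohem", "kwargs": {"thresh": 0.7, "min_kept": 100000}},
        "net": {"num_classes": 19},
    }
    print(describe_criterion(citys))
    # ('CriterionOhem', 0, {'thresh': 0.7, 'min_kept': 100000, 'ignore_index': 255})
```

Because `aux_weight` ends up 0, `CriterionOhem.forward` expects a single prediction tensor rather than a `(main, aux)` pair. The VOC configs' `type: CELoss` is not matched by name; it simply falls through to `Criterion`.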
--------------------------------------------------------------------------------
/exps/zrun_citys/r50_citys_semi744/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/744/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True
20 | val:
21 | data_root: ./data/cityscapes
22 | data_list: ./data/splitsall/cityscapes/val.txt
23 | batch_size: 4
24 | crop:
25 | size: [800, 800]
26 | n_sup: 744
27 | workers: 4
28 | mean: [0.485, 0.456, 0.406]
29 | std: [0.229, 0.224, 0.225]
30 | ignore_label: 255
31 |
32 | # # # # # # # # # # # # # #
33 | # 2. training params
34 | # # # # # # # # # # # # # #
35 | trainer: # Required.
36 | epochs: 240
37 | sup_only_epoch: 0 # 0, 1
38 | evaluate_student: False
39 | optimizer:
40 | type: SGD
41 | kwargs:
42 | lr: 0.01 # 4GPUs
43 | momentum: 0.9
44 | weight_decay: 0.0005
45 | lr_scheduler:
46 | mode: poly
47 | kwargs:
48 | power: 0.9
49 | unsupervised:
50 | flag_extra_weak: False
51 | threshold: -0.7 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
52 | loss_weight: 2.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
53 | #------ 2)cutmix augs ------#
54 | use_cutmix: True
55 | use_cutmix_adaptive: True
56 | use_cutmix_trigger_prob: 1.0 # wide range, but trigger by 100%
57 |
58 | # # # # # # # # # # # # # #
59 | # 3. output files, and loss
60 | # # # # # # # # # # # # # #
61 | saver:
62 | snapshot_dir: checkpoints
63 | pretrain: ''
64 | use_tb: False
65 |
66 | criterion:
67 | type: ohem
68 | kwargs:
69 | thresh: 0.7
70 | min_kept: 100000
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 19
77 | sync_bn: True
78 | ema_decay: 0.996
79 | # aux_loss:
80 | # aux_plane: 1024
81 | # loss_weight: 0.4
82 | encoder:
83 | # type: augseg.models.resnet.resnet101
84 | # pretrain: ./pretrained/resnet101.pth
85 | type: augseg.models.resnet.resnet50
86 | pretrain: ./pretrained/resnet50.pth
87 | kwargs:
88 | zero_init_residual: True
89 | multi_grid: True
90 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
91 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
92 | decoder:
93 | type: augseg.models.decoder.dec_deeplabv3_plus
94 | kwargs:
95 | inner_planes: 256
96 | low_conv_planes: 48 # 256
97 | dilations: [6, 12, 18]
98 | # dilations: [12, 24, 36] # [output_stride = 8]
99 |
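The encoder and decoder comments pair `replace_stride_with_dilation: [False, False, True]` with output stride 16 and `dilations: [6, 12, 18]`, and `[False, True, True]` with output stride 8 and `[12, 24, 36]`. The arithmetic behind that pairing can be sketched as follows. It assumes AugSeg's ResNet follows torchvision's convention (each flag swaps the stride-2 of `layer2`, `layer3`, `layer4` for dilation) and the DeepLab practice of scaling ASPP rates inversely with output stride. Both helper names are hypothetical:

```python
def output_stride(replace_stride_with_dilation):
    """Backbone output stride for a torchvision-style ResNet (hypothetical helper).

    The stem (stride-2 conv + stride-2 max-pool) and three stride-2 stages
    give 32; every True flag keeps one of layer2/layer3/layer4 at stride 1.
    """
    stride = 32
    for dilated in replace_stride_with_dilation:
        if dilated:
            stride //= 2
    return stride


def aspp_dilations(out_stride, base=(6, 12, 18), base_stride=16):
    """ASPP rates scaled inversely with output stride (hypothetical helper)."""
    return [rate * base_stride // out_stride for rate in base]


if __name__ == "__main__":
    for flags in ([False, False, True], [False, True, True]):
        stride = output_stride(flags)
        print(flags, stride, aspp_dilations(stride))
    # [False, False, True] 16 [6, 12, 18]
    # [False, True, True] 8 [12, 24, 36]
```

Halving the output stride doubles the feature-map resolution, so the atrous rates must double to keep the same receptive field in input pixels.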
--------------------------------------------------------------------------------
/exps/zrun_vocs/r50_voc_semi662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
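The learning-rate comment `0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus` reads as a total batch of 16 getting 0.001 and a total batch of 32 getting 0.002. Those three data points are consistent with the linear scaling rule. The sketch below encodes that reading (an extrapolation, not something the config states) and interprets `batch_size` as per-GPU:

```python
def scaled_lr(per_gpu_batch, num_gpus, ref_lr=0.001, ref_total_batch=16):
    """Linear-scaling reading of the config comment (hypothetical helper).

    The learning rate grows in proportion to the total batch size,
    anchored at ref_lr = 0.001 for a total batch of 16.
    """
    return ref_lr * (per_gpu_batch * num_gpus) / ref_total_batch


if __name__ == "__main__":
    for per_gpu, gpus in ((8, 2), (4, 4), (8, 4)):
        print(f"{per_gpu}x{gpus} GPUs -> lr {scaled_lr(per_gpu, gpus):g}")
```

Configurations outside the three listed in the comment are untested guesses under this rule.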
--------------------------------------------------------------------------------
/exps/zrun_vocs/voc_semi_fine92_r50/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/92/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 92
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: True
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
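`net.ema_decay` (0.999 here, 0.996 in the Cityscapes configs) presumably sets the mean-teacher update, in which the teacher tracks an exponential moving average of the student's weights. The training script that consumes it is not shown, so this is a generic sketch on plain lists standing in for tensors, with a hypothetical `ema_update` helper:

```python
def ema_update(teacher, student, decay):
    """One mean-teacher step per parameter (hypothetical helper).

    teacher <- decay * teacher + (1 - decay) * student, on dicts mapping
    parameter names to lists of floats. Returns a new dict.
    """
    return {
        name: [decay * t + (1.0 - decay) * s
               for t, s in zip(teacher[name], student[name])]
        for name in teacher
    }


if __name__ == "__main__":
    teacher, student = {"w": [0.0]}, {"w": [1.0]}
    for _ in range(1000):
        teacher = ema_update(teacher, student, decay=0.999)
    print(round(teacher["w"][0], 3))  # 0.632
```

With a fixed student the teacher closes the gap as `1 - decay**n`, so after 1000 steps at 0.999 it has covered about 63% of it. The averaging horizon is roughly `1 / (1 - decay)` iterations: about 1000 for 0.999 and 250 for 0.996.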
--------------------------------------------------------------------------------
/exps/zrun_vocs_u2pl/r50_voc_semi662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
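The `lr` comment in this config implies a linear scaling rule: 0.001 for a total batch of 16 (8 per GPU × 2 GPUs, or 4 × 4) and 0.002 for a total batch of 32 (8 × 4). `train_semi.py` also gives the decoder 10× the encoder learning rate on Pascal (`times = 10`). The sketch below restates both. The polynomial decay is the textbook form of `mode: poly` with `power: 0.9`, and is an *assumption* about what `augseg.utils.lr_helper` does, because that file is not shown. All function names are hypothetical.

```python
def scaled_base_lr(batch_per_gpu: int, num_gpus: int,
                   ref_lr: float = 0.001, ref_total_batch: int = 16) -> float:
    """Linear scaling inferred from the config comment:
    0.001 for a total batch of 16 (8x2 or 4x4 GPUs), 0.002 for 32 (8x4 GPUs)."""
    return ref_lr * (batch_per_gpu * num_gpus) / ref_total_batch


def poly_lr(base_lr: float, cur_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Textbook polynomial decay; ASSUMED to match lr_scheduler mode 'poly'."""
    return base_lr * (1.0 - cur_iter / max_iter) ** power


def param_group_lrs(base_lr: float, dataset_type: str) -> dict:
    """Mirrors train_semi.py: the decoder (head) lr is 10x the encoder lr on Pascal."""
    times = 10 if "pascal" in dataset_type else 1
    return {"encoder": base_lr, "decoder": base_lr * times}


if __name__ == "__main__":
    lr = scaled_base_lr(batch_per_gpu=8, num_gpus=2)  # 2-GPU setting from the launch scripts
    print(lr, param_group_lrs(lr, "pascal_semi"))
    max_iter = 80 * 100  # hypothetical: 80 epochs x 100 iterations per epoch
    for it in (0, max_iter // 2, max_iter - 1):
        print(it, poly_lr(lr, it, max_iter))
```

The training loop calls `lr_scheduler.step()` once per iteration, so under this reading `max_iter` is `epochs * len(train_loader_sup)` rather than the number of epochs.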
--------------------------------------------------------------------------------
/exps/zrun_vocs_u2pl/voc_semi662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
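The encoder and decoder comments in this config pair `replace_stride_with_dilation: [False, False, True]` with output stride 16 and ASPP `dilations: [6, 12, 18]`, and pair `[False, True, True]` with output stride 8 and `[12, 24, 36]`. The sketch below makes that relationship explicit. It assumes the torchvision flag ordering (layer2, layer3, layer4). It also assumes that every stride-2 op in `augseg.models.resnet` uses "same"-style padding, which is not shown here. The helper names are hypothetical.

```python
from typing import List, Sequence


def output_stride(replace_stride_with_dilation: Sequence[bool]) -> int:
    """Total downsampling of a ResNet encoder.

    A plain ResNet downsamples by 32: 4 in the stem (stride-2 conv + max-pool),
    then 2 in each of layer2, layer3 and layer4. Each True flag swaps that
    stage's stride for dilation, halving the total downsampling.
    """
    if len(replace_stride_with_dilation) != 3:
        raise ValueError("expected one flag each for layer2, layer3, layer4")
    return 32 // 2 ** sum(bool(f) for f in replace_stride_with_dilation)


def aspp_dilations(os_: int, rates_at_16: Sequence[int] = (6, 12, 18)) -> List[int]:
    """Scale ASPP rates inversely with output stride, as the commented configs do."""
    return [r * 16 // os_ for r in rates_at_16]


def feature_size(input_size: int, os_: int) -> int:
    """Encoder output size, ASSUMING every stride-2 op is k3/p1 or k7/p3."""
    return (input_size - 1) // os_ + 1


if __name__ == "__main__":
    for flags in ([False, False, True], [False, True, True]):
        os_ = output_stride(flags)
        print(flags, os_, aspp_dilations(os_), feature_size(513, os_))
```

Under these assumptions, the 513×513 crop in `dataset.train.crop.size` yields a 33×33 feature map at output stride 16 and a 65×65 map at output stride 8.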
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml>=5.4.0
2 | easydict
3 | tensorboard
4 | opencv-python
5 | scipy
6 | scikit-learn
7 | scikit-image
8 | einops
9 | h5py
10 | pandas
11 | tqdm
12 | simpleitk
13 | medpy
14 | nibabel
15 |
--------------------------------------------------------------------------------
/scripts/bashtorch:
--------------------------------------------------------------------------------
1 | # bashtorch file
2 | # export TORCH_HOME=./envs/torch
3 | export TORCH_HOME=./envs/a100torch
4 | export PATH=${TORCH_HOME}/bin:$PATH
5 | export LD_LIBRARY_PATH=${TORCH_HOME}/lib/:${LD_LIBRARY_PATH}
6 |
--------------------------------------------------------------------------------
/scripts/zsing_run_citys.sh:
--------------------------------------------------------------------------------
1 | tport=52009
2 | ngpu=4
3 | ROOT=.
4 |
5 | # CUDA_VISIBLE_DEVICES=4,5,6,7 \
6 | python -m torch.distributed.launch \
7 | --nproc_per_node=${ngpu} \
8 | --node_rank=0 \
9 | --master_port=${tport} \
10 | $ROOT/train_semi.py \
11 | --config=$ROOT/exps/zrun_citys/citys_semi744/config_semi.yaml --seed 2 --port ${tport}
12 |
13 | # --- --- ---
14 | # --config=$ROOT/exps/zrun_citys/citys_semi186/config_semi.yaml --seed 2 --port ${tport}
15 | # --config=$ROOT/exps/zrun_citys/citys_semi372/config_semi.yaml --seed 2 --port ${tport}
16 | # --config=$ROOT/exps/zrun_citys/citys_semi744/config_semi.yaml --seed 2 --port ${tport}
17 | # --config=$ROOT/exps/zrun_citys/citys_semi1488/config_semi.yaml --seed 2 --port ${tport}
18 |
19 | # --config=$ROOT/exps/zrun_citys/r50_citys_semi186/config_semi.yaml --seed 2 --port ${tport}
20 | # --config=$ROOT/exps/zrun_citys/r50_citys_semi372/config_semi.yaml --seed 2 --port ${tport}
21 | # --config=$ROOT/exps/zrun_citys/r50_citys_semi744/config_semi.yaml --seed 2 --port ${tport}
22 | # --config=$ROOT/exps/zrun_citys/r50_citys_semi1488/config_semi.yaml --seed 2 --port ${tport}
23 |
--------------------------------------------------------------------------------
/scripts/zsing_run_voc.sh:
--------------------------------------------------------------------------------
1 | tport=53907
2 | ngpu=2
3 | ROOT=.
4 |
5 | # CUDA_VISIBLE_DEVICES=4,5 \
6 | python -m torch.distributed.launch \
7 | --nproc_per_node=${ngpu} \
8 | --node_rank=0 \
9 | --master_port=${tport} \
10 | $ROOT/train_semi.py \
11 | --config=$ROOT/exps/zrun_vocs_u2pl/voc_semi662/config_semi.yaml --seed 2 --port ${tport}
12 |
13 | # ---- -----
14 | # --config=$ROOT/exps/zrun_vocs/voc_semi_fine92/config_semi.yaml --seed 2 --port ${tport}
15 | # --config=$ROOT/exps/zrun_vocs/voc_semi_fine183/config_semi.yaml --seed 2 --port ${tport}
16 | # --config=$ROOT/exps/zrun_vocs/voc_semi_fine366/config_semi.yaml --seed 2 --port ${tport}
17 | # --config=$ROOT/exps/zrun_vocs/voc_semi_fine732/config_semi.yaml --seed 2 --port ${tport}
18 | # --config=$ROOT/exps/zrun_vocs/voc_semi_fine1464/config_semi.yaml --seed 2 --port ${tport}
19 |
20 | # --config=$ROOT/exps/zrun_vocs/voc_semi662/config_semi.yaml --seed 2 --port ${tport}
21 | # --config=$ROOT/exps/zrun_vocs/voc_semi1323/config_semi.yaml --seed 2 --port ${tport}
22 | # --config=$ROOT/exps/zrun_vocs/voc_semi2646/config_semi.yaml --seed 2 --port ${tport}
23 | # --config=$ROOT/exps/zrun_vocs/voc_semi5291/config_semi.yaml --seed 2 --port ${tport}
24 |
25 | # --config=$ROOT/exps/zrun_vocs/r50_voc_semi662/config_semi.yaml --seed 2 --port ${tport}
26 | # --config=$ROOT/exps/zrun_vocs/r50_voc_semi1323/config_semi.yaml --seed 2 --port ${tport}
27 | # --config=$ROOT/exps/zrun_vocs/r50_voc_semi2646/config_semi.yaml --seed 2 --port ${tport}
28 | # --config=$ROOT/exps/zrun_vocs/r50_voc_semi5291/config_semi.yaml --seed 2 --port ${tport}
29 |
30 | # ---- ---- u2pl
31 | # --config=$ROOT/exps/zrun_vocs_u2pl/voc_semi662/config_semi.yaml --seed 2 --port ${tport}
32 | # --config=$ROOT/exps/zrun_vocs_u2pl/voc_semi1323/config_semi.yaml --seed 2 --port ${tport}
33 | # --config=$ROOT/exps/zrun_vocs_u2pl/voc_semi2646/config_semi.yaml --seed 2 --port ${tport}
34 | # --config=$ROOT/exps/zrun_vocs_u2pl/voc_semi5291/config_semi.yaml --seed 2 --port ${tport}
35 |
36 | # --config=$ROOT/exps/zrun_vocs_u2pl/r50_voc_semi662/config_semi.yaml --seed 2 --port ${tport}
37 | # --config=$ROOT/exps/zrun_vocs_u2pl/r50_voc_semi1323/config_semi.yaml --seed 2 --port ${tport}
38 | # --config=$ROOT/exps/zrun_vocs_u2pl/r50_voc_semi2646/config_semi.yaml --seed 2 --port ${tport}
39 | # --config=$ROOT/exps/zrun_vocs_u2pl/r50_voc_semi5291/config_semi.yaml --seed 2 --port ${tport}
--------------------------------------------------------------------------------
/single_run.sh:
--------------------------------------------------------------------------------
1 | tport=53907
2 | ngpu=2
3 | ROOT=.
4 |
5 | CUDA_VISIBLE_DEVICES=0,1 \
6 | python -m torch.distributed.launch \
7 | --nproc_per_node=${ngpu} \
8 | --node_rank=0 \
9 | --master_port=${tport} \
10 | $ROOT/train_semi.py \
11 | --config=$ROOT/exps/zrun_vocs_u2pl/voc_semi662/config_semi.yaml --seed 2 --port ${tport}
12 |
--------------------------------------------------------------------------------
/train_semi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import os, sys
4 | import os.path as osp
5 | import pprint
6 | import time
7 | import pickle
8 |
9 | import random
10 | import numpy as np
11 | import pandas as pd
12 | import torch
13 | import torch.backends.cudnn as cudnn
14 | import torch.distributed as dist
15 | import torch.nn.functional as F
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | from augseg.dataset.augs_ALIA import cut_mix_label_adaptive
19 | from augseg.dataset.builder import get_loader
20 | from augseg.models.model_helper import ModelBuilder
21 | from augseg.utils.dist_helper import setup_distributed
22 | from augseg.utils.loss_helper import get_criterion, compute_unsupervised_loss_by_threshold
23 | from augseg.utils.lr_helper import get_optimizer, get_scheduler
24 | from augseg.utils.utils import AverageMeter, intersectionAndUnion, load_state
25 | from augseg.utils.utils import init_log, get_rank, get_world_size, set_random_seed, setup_default_logging
26 |
27 | import warnings
28 | warnings.filterwarnings('ignore')
29 |
30 |
31 | def main(in_args):
32 | args = in_args
33 | if args.seed is not None:
34 | # print("set random seed to", args.seed)
35 | set_random_seed(args.seed, deterministic=True)
36 | # set_random_seed(args.seed)
37 | cfg = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
38 | rank, world_size = setup_distributed(port=args.port)
39 |
40 | ###########################
41 | # 1. output settings
42 | ###########################
43 | cfg["exp_path"] = osp.dirname(args.config)
44 | cfg["save_path"] = osp.join(cfg["exp_path"], cfg["saver"]["snapshot_dir"])
45 | cfg["log_path"] = osp.join(cfg["exp_path"], "log")
46 | flag_use_tb = cfg["saver"]["use_tb"]
47 |
48 | if not os.path.exists(cfg["log_path"]) and rank == 0:
49 | os.makedirs(cfg["log_path"])
50 | if not osp.exists(cfg["save_path"]) and rank == 0:
51 | os.makedirs(cfg["save_path"])
52 | # logging: only rank 0 creates the logger and the CSV statistics path
53 | if rank == 0:
54 | logger, curr_timestr = setup_default_logging("global", cfg["log_path"])
55 | csv_path = os.path.join(cfg["log_path"], "seg_{}_stat.csv".format(curr_timestr))
56 | else:
57 | logger, curr_timestr = None, ""
58 | csv_path = None
59 | # tensorboard
60 | if rank == 0:
61 | logger.info("{}".format(pprint.pformat(cfg)))
62 | if flag_use_tb:
63 | tb_logger = SummaryWriter(
64 | osp.join(cfg["log_path"], "events_seg",curr_timestr)
65 | )
66 | else:
67 | tb_logger = None
68 | else:
69 | tb_logger = None
70 | # make sure all folders and the CSV path are created on rank 0 before the other ranks continue.
71 | dist.barrier()
72 |
73 | ###########################
74 | # 2. prepare model 1
75 | ###########################
76 | model = ModelBuilder(cfg["net"])
77 | modules_back = [model.encoder]
78 | modules_head = [model.decoder]
79 | if cfg["net"].get("aux_loss", False):
80 | modules_head.append(model.auxor)
81 | if cfg["net"].get("sync_bn", True):
82 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
83 | model.cuda()
84 |
85 | ###########################
86 | # 3. data
87 | ###########################
88 | sup_loss_fn = get_criterion(cfg)
89 | train_loader_sup, train_loader_unsup, val_loader = get_loader(cfg, seed=args.seed)
90 |
91 | ##############################
92 | # 4. optimizer & scheduler
93 | ##############################
94 | cfg_trainer = cfg["trainer"]
95 | cfg_optim = cfg_trainer["optimizer"]
96 | times = 10 if "pascal" in cfg["dataset"]["type"] else 1
97 |
98 | params_list = []
99 | for module in modules_back:
100 | params_list.append(
101 | dict(params=module.parameters(), lr=cfg_optim["kwargs"]["lr"])
102 | )
103 | for module in modules_head:
104 | params_list.append(
105 | dict(params=module.parameters(), lr=cfg_optim["kwargs"]["lr"] * times)
106 | )
107 | optimizer = get_optimizer(params_list, cfg_optim)
108 |
109 | ###########################
110 | # 5. prepare model more
111 | ###########################
112 | local_rank = int(os.environ["LOCAL_RANK"])
113 | model = torch.nn.parallel.DistributedDataParallel(
114 | model,
115 | device_ids=[local_rank],
116 | output_device=local_rank,
117 | find_unused_parameters=False,
118 | )
119 |
120 | # Teacher model -- freeze training
121 | model_teacher = ModelBuilder(cfg["net"])
122 | model_teacher.cuda()
123 | model_teacher = torch.nn.parallel.DistributedDataParallel(
124 | model_teacher,
125 | device_ids=[local_rank],
126 | output_device=local_rank,
127 | find_unused_parameters=False,
128 | )
129 | for p in model_teacher.parameters():
130 | p.requires_grad = False
131 |
132 | # initialize teacher model -- not necessary if using warmup
133 | with torch.no_grad():
134 | for t_params, s_params in zip(model_teacher.parameters(), model.parameters()):
135 | t_params.data = s_params.data
136 |
137 | ######################################
138 | # 6. resume
139 | ######################################
140 | last_epoch = 0
141 | best_prec = 0
142 | best_epoch = -1
143 | best_prec_stu = 0
144 | best_epoch_stu = -1
145 | # auto_resume > pretrain
146 | if cfg["saver"].get("auto_resume", False):
147 | lastest_model = os.path.join(cfg["save_path"], "ckpt.pth")
148 | if not os.path.exists(lastest_model):
149 | print("No checkpoint found in '{}'".format(lastest_model))
150 | else:
151 | print(f"Resume model from: '{lastest_model}'")
152 | best_prec, last_epoch = load_state(
153 | lastest_model, model, optimizer=optimizer, key="model_state"
154 | )
155 | _, _ = load_state(
156 | lastest_model, model_teacher, optimizer=optimizer, key="teacher_state"
157 | )
158 |
159 | optimizer_start = get_optimizer(params_list, cfg_optim)
160 | lr_scheduler = get_scheduler(
161 | cfg_trainer, len(train_loader_sup), optimizer_start, start_epoch=last_epoch
162 | )
163 |
164 | ######################################
165 | # 7. training loop
166 | ######################################
167 | if rank == 0:
168 | logger.info('-------------------------- start training --------------------------')
169 | # Start to train model
170 | for epoch in range(last_epoch, cfg_trainer["epochs"]):
171 | # Training
172 | res_loss_sup, res_loss_unsup = train(
173 | model,
174 | model_teacher,
175 | optimizer,
176 | lr_scheduler,
177 | sup_loss_fn,
178 | train_loader_sup,
179 | train_loader_unsup,
180 | epoch,
181 | tb_logger,
182 | logger,
183 | cfg
184 | )
185 |
186 | # Validation and store checkpoint
187 | if "cityscapes" in cfg["dataset"].get("type", "pascal"):
188 | if epoch % 10 == 0 or epoch > (cfg_trainer["epochs"]-50):
189 | if cfg_trainer.get("evaluate_student", True):
190 | prec_stu = validate_citys(model, val_loader, epoch, logger, cfg)
191 | else:
192 | prec_stu = -1000.0
193 | prec_tea = validate_citys(model_teacher, val_loader, epoch, logger, cfg)
194 | prec = prec_tea
195 | else:
196 | prec_stu = -1000.0
197 | prec_tea = -1000.0
198 | prec = prec_tea
199 | else:
200 | if cfg_trainer.get("evaluate_student", True):
201 | prec_stu = validate(model, val_loader, epoch, logger, cfg)
202 | else:
203 | prec_stu = -1000.0
204 | prec_tea = validate(model_teacher, val_loader, epoch, logger, cfg)
205 | prec = prec_tea
206 |
207 | if rank == 0:
208 | state = {
209 | "epoch": epoch + 1,
210 | "model_state": model.state_dict(),
211 | "optimizer_state": optimizer.state_dict(),
212 | "teacher_state": model_teacher.state_dict(),
213 | "best_miou": best_prec,
214 | }
215 | if prec_stu > best_prec_stu:
216 | best_prec_stu = prec_stu
217 | best_epoch_stu = epoch
218 |
219 | if prec > best_prec:
220 | best_prec = prec
221 | best_epoch = epoch
222 | state["best_miou"] = prec
223 | torch.save(state, osp.join(cfg["save_path"], "ckpt_best.pth"))
224 |
225 | torch.save(state, osp.join(cfg["save_path"], "ckpt.pth"))
226 | # save statistics
227 | tmp_results = {
228 | 'loss_lb': res_loss_sup,
229 | 'loss_ub': res_loss_unsup,
230 | 'miou_stu': prec_stu,
231 | 'miou_tea': prec_tea,
232 | "best": best_prec,
233 | "best-stu":best_prec_stu}
234 | data_frame = pd.DataFrame(data=tmp_results, index=range(epoch, epoch+1))
235 | if epoch > 0 and osp.exists(csv_path):
236 | data_frame.to_csv(csv_path, mode='a', header=None, index_label='epoch')
237 | else:
238 | data_frame.to_csv(csv_path, index_label='epoch')
239 |
240 | logger.info(" <> - Epoch: {}. MIoU: {:.2f}/{:.2f}. \033[34mBest-STU:{:.2f}/{} \033[31mBest-EMA: {:.2f}/{}\033[0m".format(epoch,
241 | prec_stu * 100, prec_tea * 100, best_prec_stu * 100, best_epoch_stu, best_prec * 100, best_epoch))
242 | if tb_logger is not None:
243 | tb_logger.add_scalar("mIoU val", prec, epoch)
244 |
245 |
246 |
247 |
248 |
249 | def train(
250 | model,
251 | model_teacher,
252 | optimizer,
253 | lr_scheduler,
254 | sup_loss_fn,
255 | loader_l,
256 | loader_u,
257 | epoch,
258 | tb_logger,
259 | logger,
260 | cfg,
261 | ):
262 |
263 | ema_decay_origin = cfg["net"]["ema_decay"]
264 | rank, world_size = dist.get_rank(), dist.get_world_size()
265 | flag_extra_weak = cfg["trainer"]["unsupervised"].get("flag_extra_weak", False)
266 | model.train()
267 |
268 | # data loader
269 | loader_l.sampler.set_epoch(epoch)
270 | loader_u.sampler.set_epoch(epoch)
271 | loader_l_iter = iter(loader_l)
272 | loader_u_iter = iter(loader_u)
273 | assert len(loader_l) == len(loader_u), f"labeled data {len(loader_l)} unlabeled data {len(loader_u)}, mismatch!"
274 |
275 | # metric indicators
276 | sup_losses = AverageMeter(20)
277 | uns_losses = AverageMeter(20)
278 | batch_times = AverageMeter(20)
279 | learning_rates = AverageMeter(20)
280 | meter_high_pseudo_ratio = AverageMeter(20)
281 |
282 | # log 8 times per epoch
283 | print_freq = len(loader_u) // 8 # 8 for semi 4 for sup
284 | print_freq_lst = [i * print_freq for i in range(1,8)]
285 | print_freq_lst.append(len(loader_u) -1)
286 |
287 | # start iterations
288 | model.train()
289 | model_teacher.eval()
290 | for step in range(len(loader_l)):
291 | batch_start = time.time()
292 |
293 | i_iter = epoch * len(loader_l) + step # total iters till now
294 | lr = lr_scheduler.get_lr()
295 | learning_rates.update(lr[0])
296 | lr_scheduler.step() # lr is updated at the iteration level
297 |
298 | # obtain labeled and unlabeled data
299 | _, image_l, label_l = next(loader_l_iter)
300 | image_l, label_l = image_l.cuda(), label_l.cuda()
301 | _, image_u_weak, image_u_aug, _ = next(loader_u_iter)
302 | image_u_weak, image_u_aug = image_u_weak.cuda(), image_u_aug.cuda()
303 |
304 | # start the training
305 | if epoch < cfg["trainer"].get("sup_only_epoch", 0):
306 | # forward
307 | pred, aux = model(image_l)
308 | # supervised loss
309 | if "aux_loss" in cfg["net"].keys():
310 | sup_loss = sup_loss_fn([pred, aux], label_l)
311 | del aux
312 | else:
313 | sup_loss = sup_loss_fn(pred, label_l)
314 | del pred
315 |
316 | # no unlabeled data during the warmup period
317 | unsup_loss = torch.tensor(0.0).cuda()
318 | pseduo_high_ratio = torch.tensor(0.0).cuda()
319 |
320 | else:
321 | # 1. generate pseudo labels
322 | p_threshold = cfg["trainer"]["unsupervised"].get("threshold", 0.95)
323 | with torch.no_grad():
324 | model_teacher.eval()
325 | pred_u, _ = model_teacher(image_u_weak.detach())
326 | pred_u = F.softmax(pred_u, dim=1)
327 | # obtain pseudos
328 | logits_u_aug, label_u_aug = torch.max(pred_u, dim=1)
329 |
330 | # obtain confidence
331 | entropy = -torch.sum(pred_u * torch.log(pred_u + 1e-10), dim=1)
332 | entropy /= np.log(cfg["net"]["num_classes"])
333 | confidence = 1.0 - entropy
334 | confidence = confidence * logits_u_aug
335 | confidence = confidence.mean(dim=[1,2]) # per-image score, shape [B]
336 | confidence = confidence.cpu().numpy().tolist()
337 | # confidence = logits_u_aug.ge(p_threshold).float().mean(dim=[1,2]).cpu().numpy().tolist()
338 | del pred_u
339 | model.train()
340 |
341 | # 2. apply cutmix
342 | trigger_prob = cfg["trainer"]["unsupervised"].get("use_cutmix_trigger_prob", 1.0)
343 | if np.random.uniform(0, 1) < trigger_prob and cfg["trainer"]["unsupervised"].get("use_cutmix", False):
344 | if cfg["trainer"]["unsupervised"].get("use_cutmix_adaptive", False):
345 | image_u_aug, label_u_aug, logits_u_aug = cut_mix_label_adaptive(
346 | image_u_aug,
347 | label_u_aug,
348 | logits_u_aug,
349 | image_l,
350 | label_l,
351 | confidence
352 | )
353 |
354 | # 3. forward the concatenated labeled + unlabeled batch through the student network
355 | num_labeled = len(image_l)
356 | if flag_extra_weak:
357 | pred_all, aux_all = model(torch.cat((image_l, image_u_weak, image_u_aug), dim=0))
358 | del image_l, image_u_weak, image_u_aug
359 | pred_l= pred_all[:num_labeled]
360 | _, pred_u_strong = pred_all[num_labeled:].chunk(2)
361 | del pred_all
362 | else:
363 | pred_all, aux_all = model(torch.cat((image_l, image_u_aug), dim=0))
364 | del image_l, image_u_weak, image_u_aug
365 | pred_l= pred_all[:num_labeled]
366 | pred_u_strong = pred_all[num_labeled:]
367 | del pred_all
368 |
369 | # 4. supervised loss
370 | if "aux_loss" in cfg["net"].keys():
371 | aux = aux_all[:num_labeled]
372 | sup_loss = sup_loss_fn([pred_l, aux], label_l)
373 | del aux_all, aux
374 | else:
375 | sup_loss = sup_loss_fn(pred_l, label_l)
376 |
377 | # 5. unsupervised loss
378 | unsup_loss, pseduo_high_ratio = compute_unsupervised_loss_by_threshold(
379 | pred_u_strong, label_u_aug.detach(),
380 | logits_u_aug.detach(), thresh=p_threshold)
381 | unsup_loss *= cfg["trainer"]["unsupervised"].get("loss_weight", 1.0)
382 | del pred_l, pred_u_strong, label_u_aug, logits_u_aug
383 |
384 | loss = sup_loss + unsup_loss
385 |
386 | # update student model
387 | optimizer.zero_grad()
388 | loss.backward()
389 | optimizer.step()
390 |
391 | # update teacher model with EMA
392 | with torch.no_grad():
393 | if epoch > cfg["trainer"].get("sup_only_epoch", 0):
394 | ema_decay = min(
395 | 1
396 | - 1
397 | / (
398 | i_iter
399 | - len(loader_l) * cfg["trainer"].get("sup_only_epoch", 0)
400 | + 1
401 | ),
402 | ema_decay_origin,
403 | )
404 | else:
405 | ema_decay = 0.0
406 | # update weight
407 | for param_train, param_eval in zip(model.parameters(), model_teacher.parameters()):
408 | param_eval.data = param_eval.data * ema_decay + param_train.data * (1 - ema_decay)
409 | # update bn
410 | for buffer_train, buffer_eval in zip(model.buffers(), model_teacher.buffers()):
411 | buffer_eval.data = buffer_eval.data * ema_decay + buffer_train.data * (1 - ema_decay)
412 | # buffer_eval.data = buffer_train.data
413 |
414 | # gather all loss from different gpus
415 | reduced_sup_loss = sup_loss.clone().detach()
416 | dist.all_reduce(reduced_sup_loss)
417 | sup_losses.update(reduced_sup_loss.item() / world_size)
418 |
419 | reduced_uns_loss = unsup_loss.clone().detach()
420 | dist.all_reduce(reduced_uns_loss)
421 | uns_losses.update(reduced_uns_loss.item() / world_size)
422 |
423 | reduced_pseudo_high_ratio = pseduo_high_ratio.clone().detach()
424 | dist.all_reduce(reduced_pseudo_high_ratio)
425 | meter_high_pseudo_ratio.update(reduced_pseudo_high_ratio.item() / world_size)
426 |
427 | # print log information
428 | batch_end = time.time()
429 | batch_times.update(batch_end - batch_start)
430 | # if i_iter % 10 == 0 and rank == 0:
431 | if step in print_freq_lst and rank == 0:
432 | logger.info(
433 | "Epoch/Iter [{}:{:3}/{:3}]. "
434 | "Sup:{sup_loss.val:.3f}({sup_loss.avg:.3f}) "
435 | "Uns:{uns_loss.val:.3f}({uns_loss.avg:.3f}) "
436 | "Pseudo:{high_ratio.val:.3f}({high_ratio.avg:.3f}) "
437 | "Time:{batch_time.avg:.2f} "
438 | "LR:{lr.val:.5f}".format(
439 | cfg["trainer"]["epochs"], epoch, step,
440 | sup_loss=sup_losses,
441 | uns_loss=uns_losses,
442 | high_ratio=meter_high_pseudo_ratio,
443 | batch_time=batch_times,
444 | lr=learning_rates,
445 | )
446 | )
447 | if tb_logger is not None:
448 | tb_logger.add_scalar("lr", learning_rates.avg, i_iter)
449 | tb_logger.add_scalar("Sup Loss", sup_losses.avg, i_iter)
450 | tb_logger.add_scalar("Uns Loss", uns_losses.avg, i_iter)
451 | tb_logger.add_scalar("High ratio", meter_high_pseudo_ratio.avg, i_iter)
452 |
453 | return sup_losses.avg, uns_losses.avg
454 |
455 |
456 | def validate(
457 | model,
458 | data_loader,
459 | epoch,
460 | logger,
461 | cfg
462 | ):
463 | model.eval()
464 | data_loader.sampler.set_epoch(epoch)
465 |
466 | num_classes, ignore_label = (
467 | cfg["net"]["num_classes"],
468 | cfg["dataset"]["ignore_label"],
469 | )
470 | rank, world_size = dist.get_rank(), dist.get_world_size()
471 |
472 | intersection_meter = AverageMeter()
473 | union_meter = AverageMeter()
474 |
475 | for step, batch in enumerate(data_loader):
476 | _, images, labels = batch
477 | images = images.cuda()
478 | labels = labels.long().cuda()
479 |
480 | with torch.no_grad():
481 | output, _ = model(images)
482 |
483 | # per-pixel predicted class from whichever model (student or teacher) is being evaluated
484 | output = output.data.max(1)[1].cpu().numpy()
485 | target_origin = labels.cpu().numpy()
486 |
487 | # start to calculate miou
488 | intersection, union, target = intersectionAndUnion(
489 | output, target_origin, num_classes, ignore_label
490 | )
491 |
492 | # gather all validation information
493 | reduced_intersection = torch.from_numpy(intersection).cuda()
494 | reduced_union = torch.from_numpy(union).cuda()
495 | reduced_target = torch.from_numpy(target).cuda()
496 |
497 | dist.all_reduce(reduced_intersection)
498 | dist.all_reduce(reduced_union)
499 | dist.all_reduce(reduced_target)
500 |
501 | intersection_meter.update(reduced_intersection.cpu().numpy())
502 | union_meter.update(reduced_union.cpu().numpy())
503 |
504 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
505 | mIoU = np.mean(iou_class)
506 |
507 | if rank == 0:
508 | for i, iou in enumerate(iou_class):
509 | logger.info(" [Test] - class [{}] IoU {:.2f}".format(i, iou * 100))
510 |
511 | return mIoU
512 |
513 |
514 | def validate_citys(
515 | model,
516 | data_loader,
517 | epoch,
518 | logger,
519 | cfg
520 | ):
521 | model.eval()
522 | data_loader.sampler.set_epoch(epoch)
523 | rank, world_size = dist.get_rank(), dist.get_world_size()
524 |
525 | num_classes = cfg["net"]["num_classes"]
526 | ignore_label = cfg["dataset"]["ignore_label"]
527 | if cfg["dataset"]["val"].get("crop", False):
528 | crop_size, _ = cfg["dataset"]["val"]["crop"].get("size", [800, 800])
529 | else:
530 | crop_size = 800
531 |
532 | intersection_meter = AverageMeter()
533 | union_meter = AverageMeter()
534 |
535 | for step, batch in enumerate(data_loader):
536 | _, images, labels = batch
537 | images = images.cuda()
538 | labels = labels.long()
539 | batch_size, h, w = labels.shape
540 |
541 | with torch.no_grad():
542 | final = torch.zeros(batch_size, num_classes, h, w).cuda()
543 | row = 0
544 | while row < h:
545 | col = 0
546 | while col < w:
547 | pred, _ = model(images[:, :, row: min(h, row + crop_size), col: min(w, col + crop_size)])
548 | final[:, :, row: min(h, row + crop_size), col: min(w, col + crop_size)] += pred.softmax(dim=1)
549 | col += int(crop_size * 2 / 3)
550 | row += int(crop_size * 2 / 3)
551 | # get the output
552 | output = final.argmax(dim=1).cpu().numpy()
553 | target_origin = labels.numpy()
554 | # print("="*50, output.shape, output.dtype, target_origin.shape, target_origin.dtype)
555 |
556 | # start to calculate miou
557 | intersection, union, target = intersectionAndUnion(
558 | output, target_origin, num_classes, ignore_label
559 | )
560 | # # return ndarray, b*clas
561 | # print("="*20, type(intersection), type(union), type(target), intersection, union, target)
562 |
563 | # gather all validation information
564 | reduced_intersection = torch.from_numpy(intersection).cuda()
565 | reduced_union = torch.from_numpy(union).cuda()
566 | reduced_target = torch.from_numpy(target).cuda()
567 |
568 | dist.all_reduce(reduced_intersection)
569 | dist.all_reduce(reduced_union)
570 | dist.all_reduce(reduced_target)
571 |
572 | intersection_meter.update(reduced_intersection.cpu().numpy())
573 | union_meter.update(reduced_union.cpu().numpy())
574 |
575 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
576 | mIoU = np.mean(iou_class)
577 |
578 | if rank == 0:
579 | for i, iou in enumerate(iou_class):
580 | logger.info(" [Test] - class [{}] IoU {:.2f}".format(i, iou * 100))
581 | return mIoU
582 |
583 |
584 | if __name__ == "__main__":
585 | parser = argparse.ArgumentParser(description="Semi-Supervised Semantic Segmentation")
586 | parser.add_argument("--config", type=str, default="config.yaml")
587 | parser.add_argument("--local_rank", type=int, default=0)
588 | parser.add_argument("--seed", type=int, default=0)
589 | parser.add_argument("--port", default=None, type=int)
590 | args = parser.parse_args()
591 | main(args)
592 |
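In `train()`, the `# obtain confidence` block scores each weakly augmented unlabeled image. The score is the spatial mean of (1 − normalised entropy) × max class probability, computed from the teacher's softmax output. The resulting per-image list is passed to `cut_mix_label_adaptive`, whose body is not shown, so how the scores are consumed is not covered here. Below is a NumPy re-statement of the formula, so it can be checked without a GPU. The function name is hypothetical; the torch version works on `[B, C, H, W]` CUDA tensors.

```python
import numpy as np


def image_confidence(probs: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """NumPy re-statement of the teacher confidence in train().

    probs: softmax output with shape [B, C, H, W].
    Returns one score per image, shape [B].
    """
    num_classes = probs.shape[1]
    max_prob = probs.max(axis=1)                              # [B, H, W]
    entropy = -(probs * np.log(probs + eps)).sum(axis=1)      # [B, H, W]
    entropy = entropy / np.log(num_classes)                   # ~1.0 for a uniform prediction
    return ((1.0 - entropy) * max_prob).mean(axis=(1, 2))     # [B]


if __name__ == "__main__":
    num_classes, h, w = 21, 4, 4
    uniform = np.full((1, num_classes, h, w), 1.0 / num_classes)
    one_hot = np.zeros((1, num_classes, h, w))
    one_hot[:, 0] = 1.0
    print(image_confidence(np.concatenate([uniform, one_hot], axis=0)))
```

A maximally uncertain (uniform) prediction scores close to 0 and a one-hot prediction scores close to 1. Multiplying by the max probability keeps the score low for images whose entropy is modest but whose top class is still weakly supported.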
--------------------------------------------------------------------------------
/training-logs/3 Semi-Voc-fine/1464/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/1464/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 1464
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
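The `lr_scheduler` section above selects `mode: poly` with `power: 0.9`. The sketch below shows the usual polynomial decay. It assumes the schedule is stepped once per iteration and reaches zero at the final iteration. The helper name `poly_lr` and the iteration count are illustrative, not taken from `augseg/utils/lr_helper.py`.

```python
def poly_lr(base_lr: float, cur_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Polynomial decay: lr = base_lr * (1 - cur_iter / max_iter) ** power."""
    if not 0 <= cur_iter <= max_iter:
        raise ValueError("cur_iter must lie in [0, max_iter]")
    return base_lr * (1.0 - cur_iter / max_iter) ** power


max_iter = 1000  # hypothetical total; the real value depends on epochs and loader length
for it in (0, 250, 500, 750, 999):
    print(it, poly_lr(0.001, it, max_iter))
```

With `epochs: 80`, `max_iter` would be 80 times the number of iterations per epoch. How the training script counts iterations is not shown in these configs.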
--------------------------------------------------------------------------------
/training-logs/3 Semi-Voc-fine/183/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/183/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 183
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
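`net.ema_decay: 0.999` indicates a mean-teacher setup, in which a teacher network tracks an exponential moving average (EMA) of the student's weights. The sketch below shows the plain EMA rule, using a dict of floats in place of model parameters. The helper name `ema_update` is hypothetical. Some semi-supervised codebases also ramp the decay up during early iterations; the config does not show that detail, so the sketch omits it.

```python
from typing import Dict


def ema_update(teacher: Dict[str, float], student: Dict[str, float], decay: float = 0.999) -> None:
    """In-place EMA step: teacher <- decay * teacher + (1 - decay) * student."""
    for name, s_val in student.items():
        teacher[name] = decay * teacher[name] + (1.0 - decay) * s_val


# A student weight frozen at 1.0 pulls a teacher starting at 0.0 toward 1.0.
teacher = {"w": 0.0}
student = {"w": 1.0}
ema_update(teacher, student)
print(teacher["w"])
```

With `decay = 0.999`, the teacher effectively averages over roughly the last 1/(1 - 0.999) = 1000 student updates.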
--------------------------------------------------------------------------------
/training-logs/3 Semi-Voc-fine/366/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/366/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 366
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
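`unsupervised.threshold: 0.95`, together with `ignore_label: 255`, suggests FixMatch-style confidence filtering. Under that scheme, a teacher prediction becomes a pseudo-label only when its softmax confidence reaches the threshold; every other pixel is ignored by the loss. The sketch below works per pixel on plain lists and assumes exactly that reading. The helper names are hypothetical.

```python
import math
from typing import List

IGNORE_LABEL = 255  # mirrors dataset.ignore_label in the config


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one pixel's class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_labels(teacher_logits: List[List[float]], threshold: float = 0.95) -> List[int]:
    """Argmax class per pixel, or IGNORE_LABEL when the max probability is below threshold."""
    labels = []
    for pixel in teacher_logits:
        probs = softmax(pixel)
        conf = max(probs)
        labels.append(probs.index(conf) if conf >= threshold else IGNORE_LABEL)
    return labels


# Pixel 0 is confidently class 0; pixel 1 is nearly uniform and is masked out.
print(pseudo_labels([[8.0, 0.0, 0.0], [1.0, 0.9, 0.8]]))
```

A high threshold such as 0.95 keeps fewer but cleaner pseudo-labels. The commented ablation range `[0.7, ..., 0.95]` trades coverage against label noise.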
--------------------------------------------------------------------------------
/training-logs/3 Semi-Voc-fine/732/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/732/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 732
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
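The `strong_aug` block above sets `num_augs: 3` and `flag_use_random_num_sampling: True`. A plausible reading is that, when the flag is on, the number of intensity operations per image is drawn uniformly from 1 to `num_augs`. The sketch below implements that reading. The operation names are hypothetical placeholders; the real pool and magnitude sampling live in `augseg/dataset/augs_ALIA.py` and may differ.

```python
import random
from typing import List, Optional

# Hypothetical operation pool, for illustration only.
AUG_POOL = ["identity", "autocontrast", "equalize", "brightness",
            "contrast", "sharpness", "posterize", "solarize"]


def sample_strong_augs(num_augs: int = 3, random_num: bool = True,
                       rng: Optional[random.Random] = None) -> List[str]:
    """Choose which intensity augmentations to apply to one unlabeled image.

    random_num=True: the count k is drawn uniformly from 1..num_augs.
    random_num=False: exactly num_augs operations are used.
    Operations are drawn without replacement; magnitudes are not modeled here.
    """
    rng = rng or random.Random()
    k = rng.randint(1, num_augs) if random_num else num_augs
    return rng.sample(AUG_POOL, k)


rng = random.Random(0)
print(sample_strong_augs(3, True, rng))
```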
--------------------------------------------------------------------------------
/training-logs/3 Semi-Voc-fine/92/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/92/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 92
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
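Two keys in this config control how the supervised and unsupervised terms combine: `trainer.sup_only_epoch: 0` and `unsupervised.loss_weight: 1.0`. The sketch below assumes the key names mean what they suggest. The unsupervised loss would be skipped before `sup_only_epoch` and scaled by `loss_weight` afterwards. The helper name `total_loss` is hypothetical.

```python
def total_loss(sup_loss: float, unsup_loss: float, epoch: int,
               sup_only_epoch: int = 0, loss_weight: float = 1.0) -> float:
    """Supervised loss, plus the weighted unsupervised loss once the supervised-only warm-up ends."""
    if epoch < sup_only_epoch:
        return sup_loss
    return sup_loss + loss_weight * unsup_loss


# With sup_only_epoch: 0 (as configured), the unsupervised term is active from the first epoch.
print(total_loss(0.5, 0.2, epoch=0))
```

The commented options `# 0, 1` would correspond to either no warm-up or one supervised-only epoch under this reading.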
--------------------------------------------------------------------------------
/training-logs/4 Semi-Voc-blender/r101-1323/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/1323/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 1323
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
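The encoder comments pair `replace_stride_with_dilation: [False, False, True]` with output stride 16 and ASPP `dilations: [6, 12, 18]`. They pair `[False, True, True]` with output stride 8 and `[12, 24, 36]`. The sketch below reproduces that bookkeeping. It assumes torchvision-style semantics: the three flags refer to layer2, layer3 and layer4, and the stem downsamples by 4. The helper names are illustrative.

```python
from typing import List, Sequence


def output_stride(replace_stride_with_dilation: Sequence[bool]) -> int:
    """Total stride after layer4: the stem gives 4, each non-dilated stage doubles it."""
    stride = 4
    for dilate in replace_stride_with_dilation:
        if not dilate:
            stride *= 2
    return stride


def feature_size(crop: int, stride: int) -> int:
    """Encoder output size for a crop, assuming 'same'-style padding in strided convs."""
    return (crop - 1) // stride + 1


def aspp_dilations(out_stride: int, base: Sequence[int] = (6, 12, 18)) -> List[int]:
    """ASPP rates defined for stride 16, scaled up when the feature map is larger."""
    return [d * 16 // out_stride for d in base]


for flags in ([False, False, True], [False, True, True]):
    os_ = output_stride(flags)
    print(flags, os_, feature_size(513, os_), aspp_dilations(os_))
```

This bookkeeping is why the configs crop to 513 rather than 512: with `(513 - 1)` divisible by 16, a 513 crop maps to a 33 x 33 feature map at stride 16.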
--------------------------------------------------------------------------------
/training-logs/4 Semi-Voc-blender/r101-2646/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/2646/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 2646
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
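The optimizer comment `0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.` reads as a linear scaling rule. Under that rule, the learning rate is 0.001 for a total batch of 16 images and doubles when the total batch doubles. The sketch below encodes that inference, with `batch_per_gpu` corresponding to `dataset.train.batch_size`. The helper name is hypothetical.

```python
def scaled_lr(batch_per_gpu: int, num_gpus: int,
              base_lr: float = 0.001, base_batch: int = 16) -> float:
    """Linear scaling: lr proportional to the total batch size (batch_per_gpu * num_gpus)."""
    return base_lr * (batch_per_gpu * num_gpus) / base_batch


for bs, gpus in ((8, 2), (4, 4), (8, 4)):
    print(f"{bs}x{gpus} GPUs -> lr {scaled_lr(bs, gpus)}")
```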
--------------------------------------------------------------------------------
/training-logs/4 Semi-Voc-blender/r101-662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
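`dataset.mean` and `dataset.std` are the standard ImageNet channel statistics. The sketch below applies them to one 8-bit RGB pixel. It assumes pixels are first scaled to [0, 1]; whether the dataset code scales before or after subtracting the mean is not shown here. The helper name is illustrative.

```python
from typing import Sequence, Tuple

MEAN = (0.485, 0.456, 0.406)  # from dataset.mean
STD = (0.229, 0.224, 0.225)   # from dataset.std


def normalize_pixel(rgb: Sequence[int],
                    mean: Sequence[float] = MEAN,
                    std: Sequence[float] = STD) -> Tuple[float, ...]:
    """Map an 8-bit RGB pixel to (x / 255 - mean) / std, channel by channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, mean, std))


print(normalize_pixel((124, 116, 104)))  # close to the dataset mean, so all values are near 0
print(normalize_pixel((255, 255, 255)))
```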
--------------------------------------------------------------------------------
/training-logs/4 Semi-Voc-blender/r50-1323/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/1323/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 1323
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
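The `use_cutmix` options above enable CutMix on unlabeled images, and `use_cutmix_trigger_prob: 1.0` means the mix is always applied. The sketch below shows the original CutMix box sampling, with the mixing ratio drawn uniformly, as with a Beta(1, 1) prior. It does NOT reproduce AugSeg's adaptive variant (`use_cutmix_adaptive`); it only illustrates the underlying box-and-paste step on plain 2-D lists.

```python
import math
import random
from typing import List, Optional

Grid = List[List[int]]


def cutmix_box_mask(height: int, width: int, rng: random.Random,
                    lam: Optional[float] = None) -> Grid:
    """Binary mask whose rectangle of 1s covers roughly (1 - lam) of the image."""
    if lam is None:
        lam = rng.random()  # uniform ratio, equivalent to Beta(1, 1)
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)  # box center
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    mask = [[0] * width for _ in range(height)]
    for y in range(y1, y2):
        for x in range(x1, x2):
            mask[y][x] = 1
    return mask


def maybe_cutmix(trigger_prob: float, rng: random.Random) -> bool:
    """Decide whether to mix this batch; trigger_prob=1.0 always mixes."""
    return rng.random() < trigger_prob


def mix(a: Grid, b: Grid, mask: Grid) -> Grid:
    """Take pixels from b inside the box and from a elsewhere (same for labels)."""
    return [[bv if m else av for av, bv, m in zip(ra, rb, rm)]
            for ra, rb, rm in zip(a, b, mask)]


rng = random.Random(0)
mask = cutmix_box_mask(8, 8, rng, lam=0.75)
mixed = mix([[0] * 8 for _ in range(8)], [[1] * 8 for _ in range(8)], mask)
```

In a mean-teacher pipeline, the same mask is typically applied to the teacher's pseudo-labels so that the mixed image and its target stay aligned.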
--------------------------------------------------------------------------------
/training-logs/4 Semi-Voc-blender/r50-2646/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/2646/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 2646
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
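The training transform in this config combines `rand_resize: [0.5, 2.0]`, `crop: {type: rand, size: [513, 513]}` and `ignore_label: 255`. The sketch below computes the geometry for one sample under these assumptions:

- The scale factor is drawn uniformly and applied to the original image size.
- Images smaller than the crop are padded (labels with 255) before a random crop.

The role of `resize_base_size: 500` is not visible in the config, so the sketch ignores it. The helper name is hypothetical.

```python
import random
from typing import Sequence, Tuple


def rand_resize_crop_params(h: int, w: int, rng: random.Random,
                            scale_range: Sequence[float] = (0.5, 2.0),
                            crop: Sequence[int] = (513, 513)):
    """Return ((new_h, new_w), (pad_h, pad_w), (top, left)) for one training sample."""
    scale = rng.uniform(scale_range[0], scale_range[1])
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    crop_h, crop_w = crop
    # Pad bottom/right up to the crop size; label padding would use ignore_label=255.
    pad_h, pad_w = max(crop_h - new_h, 0), max(crop_w - new_w, 0)
    top = rng.randint(0, new_h + pad_h - crop_h)
    left = rng.randint(0, new_w + pad_w - crop_w)
    return (new_h, new_w), (pad_h, pad_w), (top, left)


rng = random.Random(0)
print(rand_resize_crop_params(375, 500, rng))  # a typical VOC image size
```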
--------------------------------------------------------------------------------
/training-logs/4 Semi-Voc-blender/r50-662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2) cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
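
The `lr_scheduler` block above selects `mode: poly` with `power: 0.9`. Below is a minimal sketch of the usual polynomial decay rule. It assumes the decay is applied per iteration over the whole `epochs` budget. The helper name `poly_lr` and the iteration counts are hypothetical, and `augseg/utils/lr_helper.py` may differ in detail (for example, it might step per epoch instead).

```python
def poly_lr(base_lr: float, cur_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Polynomial decay: lr = base_lr * (1 - cur_iter / max_iter) ** power."""
    if not 0 <= cur_iter <= max_iter:
        raise ValueError("cur_iter must lie in [0, max_iter]")
    return base_lr * (1.0 - cur_iter / max_iter) ** power


if __name__ == "__main__":
    max_iter = 80 * 100  # hypothetical: 80 epochs x 100 iterations per epoch
    for it in (0, max_iter // 2, max_iter):
        print(it, poly_lr(0.001, it, max_iter))
```

The rate starts at `lr`, decays smoothly, and reaches zero at the final iteration. With `power` below 1, the curve stays a little higher than linear decay for most of training.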
--------------------------------------------------------------------------------
/training-logs/5 Semi-Voc-blender-split-u2pl/r101-1323/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/1323/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 1323
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
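
By convention, `ema_decay` (0.999 here, 0.996 in the Cityscapes configs) is the momentum of a mean-teacher update, in which the teacher's weights track an exponential moving average of the student's. The sketch below assumes that convention and uses NumPy arrays in place of model parameters. `ema_update` is a hypothetical helper rather than AugSeg's code, which may handle warm-up or BN buffers differently.

```python
import numpy as np


def ema_update(teacher: dict, student: dict, decay: float = 0.999) -> None:
    """In-place EMA: teacher <- decay * teacher + (1 - decay) * student."""
    for name, s_param in student.items():
        teacher[name] *= decay                     # shrink the old average
        teacher[name] += (1.0 - decay) * s_param   # blend in the new student weights


if __name__ == "__main__":
    teacher = {"w": np.zeros(3)}
    student = {"w": np.ones(3)}
    for _ in range(1000):
        ema_update(teacher, student, decay=0.999)
    print(teacher["w"])  # approaches 1 as updates accumulate
```

For a decay d, the average has an effective horizon of roughly 1/(1 - d) steps. That is about 1000 updates at 0.999 and 250 at 0.996.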
--------------------------------------------------------------------------------
/training-logs/5 Semi-Voc-blender-split-u2pl/r101-2646/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/2646/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 2646
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
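
The `unsupervised.threshold` and `loss_weight` keys control the unsupervised term. This sketch assumes the common FixMatch-style reading. Teacher pixels whose top softmax probability reaches the threshold become hard pseudo-labels. All other pixels are masked out of a cross-entropy loss that is scaled by `loss_weight`. Both helpers are hypothetical, and averaging over the kept pixels is an assumption. Under this sketch's `>=` rule, a negative threshold such as the `-0.1` in the Cityscapes configs would keep every pixel. These configs do not show whether AugSeg's code treats negative values specially.

```python
import numpy as np


def pseudo_label_with_mask(probs: np.ndarray, threshold: float = 0.95):
    """probs: (C, H, W) teacher softmax for one unlabeled image.
    Returns hard labels (H, W) and a boolean confidence mask (H, W)."""
    conf = probs.max(axis=0)
    labels = probs.argmax(axis=0)
    return labels, conf >= threshold


def masked_ce(student_logits: np.ndarray, labels: np.ndarray, mask: np.ndarray,
              loss_weight: float = 1.0) -> float:
    """Cross-entropy of student logits (C, H, W) against pseudo-labels on masked pixels."""
    z = student_logits - student_logits.max(axis=0, keepdims=True)  # stable log-softmax
    log_p = z - np.log(np.exp(z).sum(axis=0, keepdims=True))
    h, w = labels.shape
    nll = -log_p[labels, np.arange(h)[:, None], np.arange(w)[None, :]]
    if not mask.any():
        return 0.0
    return loss_weight * float(nll[mask].mean())
```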
--------------------------------------------------------------------------------
/training-logs/5 Semi-Voc-blender-split-u2pl/r101-662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | # type: augseg.models.resnet.resnet50
81 | # pretrain: ./pretrained/resnet50.pth
82 | type: augseg.models.resnet.resnet101
83 | pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
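
The `strong_aug` block sets `num_augs: 3` and `flag_use_random_num_sampling: True`. Judging only from the key names, this sketch assumes that the flag draws the number of operations uniformly from 1 to `num_augs` for each sample, and that without the flag exactly `num_augs` operations are applied. The operation pool and `sample_strong_augs` are hypothetical. The real pipeline in `augseg/dataset/augs_TIBA.py` / `augs_ALIA.py` may also randomize magnitudes or sample differently.

```python
import random

# Hypothetical intensity-operation pool; not AugSeg's actual list.
OPS = ["identity", "autocontrast", "equalize", "brightness", "contrast",
       "sharpness", "posterize", "solarize", "color"]


def sample_strong_augs(num_augs: int = 3, flag_use_random_num_sampling: bool = True,
                       rng=random) -> list:
    """Pick which distinct operations to apply to one image."""
    if flag_use_random_num_sampling:
        k = rng.randint(1, num_augs)  # inclusive: 1..num_augs
    else:
        k = num_augs
    return rng.sample(OPS, k)        # k distinct operations, random order


if __name__ == "__main__":
    print(sample_strong_augs(3, True, random.Random(0)))
```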
--------------------------------------------------------------------------------
/training-logs/5 Semi-Voc-blender-split-u2pl/r50-1323/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/1323/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 1323
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
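
The encoder and decoder comments pair `replace_stride_with_dilation: [False, False, True]` with output stride 16 and ASPP `dilations: [6, 12, 18]`. They pair `[False, True, True]` with output stride 8 and `[12, 24, 36]`. This sketch reproduces that arithmetic. It assumes the standard ResNet layout (total downsampling of 32, with one ×2 step removed per flag set on layer2 to layer4, as with torchvision's parameter of the same name). Both helper names are hypothetical.

```python
def output_stride(replace_stride_with_dilation) -> int:
    """ResNet downsamples by 32 overall; each dilated layer removes one x2 step."""
    if len(replace_stride_with_dilation) != 3:
        raise ValueError("expected three flags for layer2, layer3, layer4")
    return 32 // (2 ** sum(bool(f) for f in replace_stride_with_dilation))


def aspp_dilations(os: int, base=(6, 12, 18)) -> list:
    """DeepLabv3 convention: rates (6, 12, 18) at stride 16, doubled at stride 8."""
    if os not in (8, 16):
        raise ValueError("only output strides 8 and 16 are covered here")
    return [r * (16 // os) for r in base]


if __name__ == "__main__":
    for flags in ([False, False, True], [False, True, True]):
        os = output_stride(flags)
        print(flags, os, aspp_dilations(os))
```

Halving the output stride doubles the spatial resolution of the deepest features. The ASPP rates are doubled so that each branch keeps the same receptive field in input pixels, at a higher memory cost.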
--------------------------------------------------------------------------------
/training-logs/5 Semi-Voc-blender-split-u2pl/r50-2646/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/2646/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 2646
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
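
The comment on `lr` gives `0.001` for 8 images × 2 GPUs or 4 × 4 GPUs, and `0.002` for 8 × 4 GPUs. This matches the linear scaling rule: the learning rate is proportional to the total batch size, with 0.001 at a total of 16. The sketch assumes `batch_size` is per GPU. `scaled_lr` is a hypothetical helper.

```python
def scaled_lr(per_gpu_batch: int, num_gpus: int,
              ref_lr: float = 0.001, ref_total_batch: int = 16) -> float:
    """Linear scaling rule: lr grows in proportion to the total batch size."""
    return ref_lr * (per_gpu_batch * num_gpus) / ref_total_batch


if __name__ == "__main__":
    for bs, gpus in ((8, 2), (4, 4), (8, 4)):
        print(f"{bs} x {gpus} GPUs -> lr {scaled_lr(bs, gpus):g}")
```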
--------------------------------------------------------------------------------
/training-logs/5 Semi-Voc-blender-split-u2pl/r50-662/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: pascal_semi
6 | train:
7 | data_root: ./data/VOC2012
8 | data_list: ./data/splitsall/pascal_u2pl/662/labeled.txt
9 | batch_size: 8
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: 500
13 | crop:
14 | type: rand
15 | size: [513, 513]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3 # !!!!!!!!!!!!!!!!!!!! key param1 (4): 1, 2, 3, 4 !!!!!!!!!!!!!!!!!!!!
19 | flag_use_random_num_sampling: True # !!!!!!!!!!!!!!!!!!!! key param1 (2): True, False !!!!!!!!!!!!!!!!!!!!
20 | val:
21 | data_root: ./data/VOC2012
22 | data_list: ./data/splitsall/pascal_u2pl/val.txt
23 | batch_size: 1
24 | n_sup: 662
25 | workers: 4
26 | mean: [0.485, 0.456, 0.406]
27 | std: [0.229, 0.224, 0.225]
28 | ignore_label: 255
29 |
30 | # # # # # # # # # # # # # #
31 | # 2. training params
32 | # # # # # # # # # # # # # #
33 | trainer: # Required.
34 | epochs: 80
35 | sup_only_epoch: 0 # 0, 1
36 | evaluate_student: True
37 | optimizer:
38 | type: SGD
39 | kwargs:
40 | lr: 0.001 # 0.001:8*2gpus, 4*4gpus; 0.002:8*4gpus.
41 | momentum: 0.9
42 | weight_decay: 0.0001
43 | lr_scheduler:
44 | mode: poly
45 | kwargs:
46 | power: 0.9
47 | # # # # # # # # # # # # # #
48 | # unsupervised loss
49 | # # # # # # # # # # # # # #
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: 0.95 # ============================================================================= <>: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
53 | loss_weight: 1.0 # ============================================================================= << abl2-weight >> : [1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: CELoss
69 | kwargs:
70 | use_weight: False
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 21
77 | sync_bn: True
78 | ema_decay: 0.999
79 | encoder:
80 | type: augseg.models.resnet.resnet50
81 | pretrain: ./pretrained/resnet50.pth
82 | # type: augseg.models.resnet.resnet101
83 | # pretrain: ./pretrained/resnet101.pth
84 | kwargs:
85 | zero_init_residual: True
86 | multi_grid: True
87 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
88 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
89 | decoder:
90 | type: augseg.models.decoder.dec_deeplabv3_plus
91 | kwargs:
92 | inner_planes: 256
93 | low_conv_planes: 48
94 | dilations: [6, 12, 18] # [output_stride = 16]
95 | # dilations: [12, 24, 36] # [output_stride = 8]
96 |
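
The `mean` and `std` values are the standard ImageNet RGB statistics, which match the ImageNet-pretrained ResNet encoders. This sketch assumes the loader scales images to [0, 1] before normalizing and outputs channels-first arrays, as is usual for these statistics. `normalize` is a hypothetical helper, not the loader in `augseg/dataset/`.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(img_uint8: np.ndarray) -> np.ndarray:
    """(H, W, 3) RGB uint8 in [0, 255] -> (3, H, W) float32, ImageNet-normalized."""
    x = img_uint8.astype(np.float32) / 255.0
    x = (x - MEAN) / STD          # broadcasts over the channel axis
    return x.transpose(2, 0, 1)   # HWC -> CHW
```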
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R101-1488/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/1488/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True # flag_use_wide_range
20 |
21 | val:
22 | data_root: ./data/cityscapes
23 | data_list: ./data/splitsall/cityscapes/val.txt
24 | batch_size: 4
25 | crop:
26 | size: [800, 800]
27 | n_sup: 1488
28 | workers: 4
29 | mean: [0.485, 0.456, 0.406]
30 | std: [0.229, 0.224, 0.225]
31 | ignore_label: 255
32 |
33 | # # # # # # # # # # # # # #
34 | # 2. training params
35 | # # # # # # # # # # # # # #
36 | trainer: # Required.
37 | epochs: 240
38 | sup_only_epoch: 0 # 0, 1
39 | evaluate_student: True
40 | optimizer:
41 | type: SGD
42 | kwargs:
43 | lr: 0.01 # 4GPUs
44 | momentum: 0.9
45 | weight_decay: 0.0005
46 | lr_scheduler:
47 | mode: poly
48 | kwargs:
49 | power: 0.9
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: -0.1 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
53 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0 # wide range, but always triggered (100%)
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: ohem
69 | kwargs:
70 | thresh: 0.7
71 | min_kept: 100000
72 |
73 | # # # # # # # # # # # # # #
74 | # 4. models
75 | # # # # # # # # # # # # # #
76 | net: # Required.
77 | num_classes: 19
78 | sync_bn: True
79 | ema_decay: 0.996
80 | # aux_loss:
81 | # aux_plane: 1024
82 | # loss_weight: 0.4
83 | encoder:
84 | type: augseg.models.resnet.resnet101
85 | pretrain: ./pretrained/resnet101.pth
86 | # type: augseg.models.resnet.resnet50
87 | # pretrain: ./pretrained/resnet50.pth
88 | kwargs:
89 | zero_init_residual: True
90 | multi_grid: True
91 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
92 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
93 | decoder:
94 | type: augseg.models.decoder.dec_deeplabv3_plus
95 | kwargs:
96 | inner_planes: 256
97 | low_conv_planes: 48 # 256
98 | dilations: [6, 12, 18]
99 | # dilations: [12, 24, 36] # [output_stride = 8]
100 |
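
The Cityscapes configs switch `criterion` to `type: ohem` with `thresh: 0.7` and `min_kept: 100000`. The sketch below follows the online hard example mining (OHEM) cross-entropy found in many semantic-segmentation code bases. It is an assumption that `augseg/utils/loss_helper.py` behaves the same way. Pixels whose probability for the ground-truth class exceeds the cut-off count as easy and are dropped. If fewer than `min_kept` pixels would survive, the cut-off is raised to the `min_kept`-th smallest probability. Pixels labelled `ignore_label` (255) never contribute. This is a NumPy sketch over flattened pixels, and the name `ohem_ce` is hypothetical.

```python
import numpy as np


def ohem_ce(logits: np.ndarray, target: np.ndarray, thresh: float = 0.7,
            min_kept: int = 100000, ignore_label: int = 255) -> float:
    """logits: (C, N) per-pixel class scores; target: (N,) integer labels."""
    z = logits - logits.max(axis=0, keepdims=True)            # stable log-softmax
    log_p = z - np.log(np.exp(z).sum(axis=0, keepdims=True))
    idx = np.nonzero(target != ignore_label)[0]
    if idx.size == 0:
        return 0.0
    lp_gt = log_p[target[idx], idx]                            # log-prob of the true class
    p_gt = np.exp(lp_gt)
    # Raise the cut-off if needed so that at least min_kept pixels are kept.
    k = min(min_kept, idx.size)
    cutoff = max(thresh, np.sort(p_gt)[k - 1])
    keep = p_gt <= cutoff                                      # hard pixels only
    return float(-lp_gt[keep].mean())
```

With `min_kept: 100000` and 800 × 800 crops (640,000 pixels each), the floor typically keeps a sizeable fraction of pixels even late in training, when most predictions are confident.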
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R101-186/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/186/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True # flag_use_wide_range
20 |
21 | val:
22 | data_root: ./data/cityscapes
23 | data_list: ./data/splitsall/cityscapes/val.txt
24 | batch_size: 4
25 | crop:
26 | size: [800, 800]
27 | n_sup: 186
28 | workers: 4
29 | mean: [0.485, 0.456, 0.406]
30 | std: [0.229, 0.224, 0.225]
31 | ignore_label: 255
32 |
33 | # # # # # # # # # # # # # #
34 | # 2. training params
35 | # # # # # # # # # # # # # #
36 | trainer: # Required.
37 | epochs: 240
38 | sup_only_epoch: 0 # 0, 1
39 | evaluate_student: False
40 | optimizer:
41 | type: SGD
42 | kwargs:
43 | lr: 0.01 # 4GPUs
44 | momentum: 0.9
45 | weight_decay: 0.0005
46 | lr_scheduler:
47 | mode: poly
48 | kwargs:
49 | power: 0.9
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: -0.1 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
53 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0 # wide range, but always triggered (100%)
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: ohem
69 | kwargs:
70 | thresh: 0.7
71 | min_kept: 100000
72 |
73 | # # # # # # # # # # # # # #
74 | # 4. models
75 | # # # # # # # # # # # # # #
76 | net: # Required.
77 | num_classes: 19
78 | sync_bn: True
79 | ema_decay: 0.996
80 | # aux_loss:
81 | # aux_plane: 1024
82 | # loss_weight: 0.4
83 | encoder:
84 | type: augseg.models.resnet.resnet101
85 | pretrain: ./pretrained/resnet101.pth
86 | # type: augseg.models.resnet.resnet50
87 | # pretrain: ./pretrained/resnet50.pth
88 | kwargs:
89 | zero_init_residual: True
90 | multi_grid: True
91 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
92 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
93 | decoder:
94 | type: augseg.models.decoder.dec_deeplabv3_plus
95 | kwargs:
96 | inner_planes: 256
97 | low_conv_planes: 48 # 256
98 | dilations: [6, 12, 18]
99 | # dilations: [12, 24, 36] # [output_stride = 8]
100 |
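
For Cityscapes, images are rescaled from `resize_base_size: [1024, 2048]` by a random factor in `rand_resize: [0.5, 2.0]` and then randomly cropped to `[800, 800]`. This sketch only computes the geometry for one sample. It assumes that images smaller than the crop are padded first, with labels padded using `ignore_label`. The padding policy and the helper `sample_resize_and_crop` are assumptions, not the code in `augseg/dataset/cityscapes.py`.

```python
import random


def sample_resize_and_crop(base_hw=(1024, 2048), scale_range=(0.5, 2.0),
                           crop_hw=(800, 800), rng=random):
    """Return (resized_hw, crop_box, pad_hw) for one training sample."""
    s = rng.uniform(*scale_range)
    h, w = int(base_hw[0] * s + 0.5), int(base_hw[1] * s + 0.5)  # rounded size
    ch, cw = crop_hw
    # Pad (assumed: image with zeros, label with ignore_label) if smaller than the crop.
    pad_h, pad_w = max(ch - h, 0), max(cw - w, 0)
    ph, pw = h + pad_h, w + pad_w
    top = rng.randint(0, ph - ch)    # inclusive bounds keep the crop inside
    left = rng.randint(0, pw - cw)
    return (h, w), (top, left, top + ch, left + cw), (pad_h, pad_w)


if __name__ == "__main__":
    print(sample_resize_and_crop(rng=random.Random(0)))
```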
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R101-372/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/372/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True
20 | val:
21 | data_root: ./data/cityscapes
22 | data_list: ./data/splitsall/cityscapes/val.txt
23 | batch_size: 4
24 | crop:
25 | size: [800, 800]
26 | n_sup: 372
27 | workers: 4
28 | mean: [0.485, 0.456, 0.406]
29 | std: [0.229, 0.224, 0.225]
30 | ignore_label: 255
31 |
32 | # # # # # # # # # # # # # #
33 | # 2. training params
34 | # # # # # # # # # # # # # #
35 | trainer: # Required.
36 | epochs: 240
37 | sup_only_epoch: 0 # 0, 1
38 | evaluate_student: False
39 | optimizer:
40 | type: SGD
41 | kwargs:
42 | lr: 0.01 # 4GPUs
43 | momentum: 0.9
44 | weight_decay: 0.0005
45 | lr_scheduler:
46 | mode: poly
47 | kwargs:
48 | power: 0.9
49 | unsupervised:
50 | flag_extra_weak: False
51 | threshold: -0.7 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
52 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
53 | #------ 2)cutmix augs ------#
54 | use_cutmix: True
55 | use_cutmix_adaptive: True
56 | use_cutmix_trigger_prob: 1.0 # wide range, but always triggered (100%)
57 |
58 | # # # # # # # # # # # # # #
59 | # 3. output files, and loss
60 | # # # # # # # # # # # # # #
61 | saver:
62 | snapshot_dir: checkpoints
63 | pretrain: ''
64 | use_tb: False
65 |
66 | criterion:
67 | type: ohem
68 | kwargs:
69 | thresh: 0.7
70 | min_kept: 100000
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 19
77 | sync_bn: True
78 | ema_decay: 0.996
79 | # aux_loss:
80 | # aux_plane: 1024
81 | # loss_weight: 0.4
82 | encoder:
83 | type: augseg.models.resnet.resnet101
84 | pretrain: ./pretrained/resnet101.pth
85 | # type: augseg.models.resnet.resnet50
86 | # pretrain: ./pretrained/resnet50.pth
87 | kwargs:
88 | zero_init_residual: True
89 | multi_grid: True
90 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
91 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
92 | decoder:
93 | type: augseg.models.decoder.dec_deeplabv3_plus
94 | kwargs:
95 | inner_planes: 256
96 | low_conv_planes: 48 # 256
97 | dilations: [6, 12, 18]
98 | # dilations: [12, 24, 36] # [output_stride = 8]
99 |
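
`use_cutmix: True` with `use_cutmix_trigger_prob: 1.0` means CutMix is applied to every unlabeled batch. Below is a minimal sketch of box-shaped CutMix for segmentation: one random rectangle from sample b is pasted onto sample a, in both the image and the pseudo-label. The box area and aspect-ratio ranges are assumptions borrowed from common semi-supervised segmentation code. The sketch does not attempt the confidence-based pairing that `use_cutmix_adaptive` presumably enables. Both helpers are hypothetical.

```python
import numpy as np


def cutmix_box_mask(h: int, w: int, rng: np.random.Generator,
                    area_range=(0.02, 0.4), ratio_range=(0.3, 1 / 0.3)) -> np.ndarray:
    """Boolean (h, w) mask with one random rectangle set to True."""
    for _ in range(100):  # retry until the sampled box fits inside the image
        area = rng.uniform(*area_range) * h * w
        ratio = rng.uniform(*ratio_range)
        bh = int(round(np.sqrt(area * ratio)))
        bw = int(round(np.sqrt(area / ratio)))
        if 0 < bh <= h and 0 < bw <= w:
            top = rng.integers(0, h - bh + 1)
            left = rng.integers(0, w - bw + 1)
            mask = np.zeros((h, w), dtype=bool)
            mask[top:top + bh, left:left + bw] = True
            return mask
    return np.zeros((h, w), dtype=bool)


def cutmix(img_a, img_b, lbl_a, lbl_b, mask):
    """Paste the masked region of sample b onto sample a (image and label)."""
    img = np.where(mask[..., None], img_b, img_a)  # (H, W, 3) images
    lbl = np.where(mask, lbl_b, lbl_a)             # (H, W) labels
    return img, lbl
```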
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R101-744/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/744/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True # flag_use_wide_range
20 |
21 | val:
22 | data_root: ./data/cityscapes
23 | data_list: ./data/splitsall/cityscapes/val.txt
24 | batch_size: 4
25 | crop:
26 | size: [800, 800]
27 | n_sup: 744
28 | workers: 4
29 | mean: [0.485, 0.456, 0.406]
30 | std: [0.229, 0.224, 0.225]
31 | ignore_label: 255
32 |
33 | # # # # # # # # # # # # # #
34 | # 2. training params
35 | # # # # # # # # # # # # # #
36 | trainer: # Required.
37 | epochs: 240
38 | sup_only_epoch: 0 # 0, 1
39 | evaluate_student: True
40 | optimizer:
41 | type: SGD
42 | kwargs:
43 | lr: 0.01 # 4GPUs
44 | momentum: 0.9
45 | weight_decay: 0.0005
46 | lr_scheduler:
47 | mode: poly
48 | kwargs:
49 | power: 0.9
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: -0.1 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
53 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0 # wide range, but always triggered (100%)
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: ohem
69 | kwargs:
70 | thresh: 0.7
71 | min_kept: 100000
72 |
73 | # # # # # # # # # # # # # #
74 | # 4. models
75 | # # # # # # # # # # # # # #
76 | net: # Required.
77 | num_classes: 19
78 | sync_bn: True
79 | ema_decay: 0.996
80 | # aux_loss:
81 | # aux_plane: 1024
82 | # loss_weight: 0.4
83 | encoder:
84 | type: augseg.models.resnet.resnet101
85 | pretrain: ./pretrained/resnet101.pth
86 | # type: augseg.models.resnet.resnet50
87 | # pretrain: ./pretrained/resnet50.pth
88 | kwargs:
89 | zero_init_residual: True
90 | multi_grid: True
91 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
92 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
93 | decoder:
94 | type: augseg.models.decoder.dec_deeplabv3_plus
95 | kwargs:
96 | inner_planes: 256
97 | low_conv_planes: 48 # 256
98 | dilations: [6, 12, 18]
99 | # dilations: [12, 24, 36] # [output_stride = 8]
100 |
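
Each config names its labeled split twice: once in the `data_list` path (`.../<n>/labeled.txt`) and once as `n_sup`. It also fixes `num_classes` per dataset: 21 for `pascal_semi` and 19 for `cityscapes_semi`. The sketch below is a sanity check that catches copy-paste mismatches between these keys. It takes the nested dict that `yaml.safe_load` would return and assumes `n_sup` sits directly under `dataset`, as the key order suggests. Indentation is not preserved in this dump. `check_config` is a hypothetical helper.

```python
import re

EXPECTED_CLASSES = {"pascal_semi": 21, "cityscapes_semi": 19}


def check_config(cfg: dict) -> None:
    """Raise ValueError if the split size or class count is inconsistent."""
    ds = cfg["dataset"]
    m = re.search(r"/(\d+)/labeled\.txt$", ds["train"]["data_list"])
    if m is None:
        raise ValueError("data_list does not follow .../<n>/labeled.txt")
    if int(m.group(1)) != ds["n_sup"]:
        raise ValueError(f"split {m.group(1)} != n_sup {ds['n_sup']}")
    expected = EXPECTED_CLASSES.get(ds["type"])
    if expected is not None and cfg["net"]["num_classes"] != expected:
        raise ValueError(f"{ds['type']} expects {expected} classes")
```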
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R50-1488/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/1488/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True
20 | val:
21 | data_root: ./data/cityscapes
22 | data_list: ./data/splitsall/cityscapes/val.txt
23 | batch_size: 4
24 | crop:
25 | size: [800, 800]
26 | n_sup: 1488
27 | workers: 4
28 | mean: [0.485, 0.456, 0.406]
29 | std: [0.229, 0.224, 0.225]
30 | ignore_label: 255
31 |
32 | # # # # # # # # # # # # # #
33 | # 2. training params
34 | # # # # # # # # # # # # # #
35 | trainer: # Required.
36 | epochs: 240
37 | sup_only_epoch: 0 # 0, 1
38 | evaluate_student: True
39 | optimizer:
40 | type: SGD
41 | kwargs:
42 | lr: 0.01 # 4GPUs
43 | momentum: 0.9
44 | weight_decay: 0.0005
45 | lr_scheduler:
46 | mode: poly
47 | kwargs:
48 | power: 0.9
49 | unsupervised:
50 | flag_extra_weak: False
51 | threshold: -0.7 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
52 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
53 | #------ 2)cutmix augs ------#
54 | use_cutmix: True
55 | use_cutmix_adaptive: True
56 | use_cutmix_trigger_prob: 1.0 # wide range, but trigger by 100%
57 |
58 | # # # # # # # # # # # # # #
59 | # 3. output files, and loss
60 | # # # # # # # # # # # # # #
61 | saver:
62 | snapshot_dir: checkpoints
63 | pretrain: ''
64 | use_tb: False
65 |
66 | criterion:
67 | type: ohem
68 | kwargs:
69 | thresh: 0.7
70 | min_kept: 100000
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 19
77 | sync_bn: True
78 | ema_decay: 0.996
79 | # aux_loss:
80 | # aux_plane: 1024
81 | # loss_weight: 0.4
82 | encoder:
83 | # type: augseg.models.resnet.resnet101
84 | # pretrain: ./pretrained/resnet101.pth
85 | type: augseg.models.resnet.resnet50
86 | pretrain: ./pretrained/resnet50.pth
87 | kwargs:
88 | zero_init_residual: True
89 | multi_grid: True
90 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
91 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
92 | decoder:
93 | type: augseg.models.decoder.dec_deeplabv3_plus
94 | kwargs:
95 | inner_planes: 256
96 | low_conv_planes: 48 # 256
97 | dilations: [6, 12, 18]
98 | # dilations: [12, 24, 36] # [output_stride = 8]
99 |
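The `lr_scheduler` entry selects `mode: poly` with `power: 0.9` on top of the SGD base `lr: 0.01`. The sketch below assumes the conventional polynomial rule lr = base_lr * (1 - cur_iter / max_iter) ** power, evaluated per iteration. The step granularity actually used by `augseg/utils/lr_helper.py` is not shown here. `MAX_ITER` is a hypothetical run length and is not derived from `epochs: 240`.

```python
def poly_lr(base_lr: float, cur_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Polynomial decay: base_lr * (1 - cur_iter / max_iter) ** power."""
    if max_iter <= 0 or not 0 <= cur_iter <= max_iter:
        raise ValueError("need max_iter > 0 and 0 <= cur_iter <= max_iter")
    return base_lr * (1.0 - cur_iter / max_iter) ** power


# Hypothetical run length, used only to show the shape of the curve.
MAX_ITER = 1000
for it in (0, 250, 500, 750, 1000):
    print(f"iter {it:4d}: lr = {poly_lr(0.01, it, MAX_ITER):.6f}")
```

Because `power` is below 1, the rate stays above a linear ramp for most of training and then falls steeply near the end.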
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R50-186/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/186/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True # flag_use_wide_range
20 |
21 | val:
22 | data_root: ./data/cityscapes
23 | data_list: ./data/splitsall/cityscapes/val.txt
24 | batch_size: 4
25 | crop:
26 | size: [800, 800]
27 | n_sup: 186
28 | workers: 4
29 | mean: [0.485, 0.456, 0.406]
30 | std: [0.229, 0.224, 0.225]
31 | ignore_label: 255
32 |
33 | # # # # # # # # # # # # # #
34 | # 2. training params
35 | # # # # # # # # # # # # # #
36 | trainer: # Required.
37 | epochs: 240
38 | sup_only_epoch: 0 # 0, 1
39 | evaluate_student: False
40 | optimizer:
41 | type: SGD
42 | kwargs:
43 | lr: 0.01 # 4GPUs
44 | momentum: 0.9
45 | weight_decay: 0.0005
46 | lr_scheduler:
47 | mode: poly
48 | kwargs:
49 | power: 0.9
50 | unsupervised:
51 | flag_extra_weak: False
52 | threshold: -0.1 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
53 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
54 | #------ 2)cutmix augs ------#
55 | use_cutmix: True
56 | use_cutmix_adaptive: True
57 | use_cutmix_trigger_prob: 1.0 # wide range, but trigger by 100%
58 |
59 | # # # # # # # # # # # # # #
60 | # 3. output files, and loss
61 | # # # # # # # # # # # # # #
62 | saver:
63 | snapshot_dir: checkpoints
64 | pretrain: ''
65 | use_tb: False
66 |
67 | criterion:
68 | type: ohem
69 | kwargs:
70 | thresh: 0.7
71 | min_kept: 100000
72 |
73 | # # # # # # # # # # # # # #
74 | # 4. models
75 | # # # # # # # # # # # # # #
76 | net: # Required.
77 | num_classes: 19
78 | sync_bn: True
79 | ema_decay: 0.996
80 | # aux_loss:
81 | # aux_plane: 1024
82 | # loss_weight: 0.4
83 | encoder:
84 | # type: augseg.models.resnet.resnet101
85 | # pretrain: ./pretrained/resnet101.pth
86 | type: augseg.models.resnet.resnet50
87 | pretrain: ./pretrained/resnet50.pth
88 | kwargs:
89 | zero_init_residual: True
90 | multi_grid: True
91 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
92 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
93 | decoder:
94 | type: augseg.models.decoder.dec_deeplabv3_plus
95 | kwargs:
96 | inner_planes: 256
97 | low_conv_planes: 48 # 256
98 | dilations: [6, 12, 18]
99 | # dilations: [12, 24, 36] # [output_stride = 8]
100 |
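`ema_decay: 0.996` points to a mean-teacher setup, in which the teacher's weights track an exponential moving average of the student's. The sketch below uses the usual update rule, teacher = decay * teacher + (1 - decay) * student, with plain floats standing in for parameter tensors. Any warm-up or ramping of the decay inside `train_semi.py` is not visible in this config and is not modeled. `ema_update` is a hypothetical helper name.

```python
def ema_update(teacher: dict[str, float], student: dict[str, float],
               decay: float = 0.996) -> dict[str, float]:
    """Return the EMA-updated teacher: decay * teacher + (1 - decay) * student."""
    if teacher.keys() != student.keys():
        raise KeyError("teacher and student must have the same parameter names")
    return {name: decay * teacher[name] + (1.0 - decay) * student[name]
            for name in teacher}


# Toy example: the student is frozen at 0.0, so the teacher decays as 0.996 ** n.
teacher = {"w": 1.0}
student = {"w": 0.0}
for _ in range(100):
    teacher = ema_update(teacher, student)
print(teacher["w"])  # roughly 0.67
```

With a decay of 0.996, the teacher's effective averaging window is on the order of 1 / (1 - 0.996) = 250 updates.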
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R50-372/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/372/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True
20 | val:
21 | data_root: ./data/cityscapes
22 | data_list: ./data/splitsall/cityscapes/val.txt
23 | batch_size: 4
24 | crop:
25 | size: [800, 800]
26 | n_sup: 372
27 | workers: 4
28 | mean: [0.485, 0.456, 0.406]
29 | std: [0.229, 0.224, 0.225]
30 | ignore_label: 255
31 |
32 | # # # # # # # # # # # # # #
33 | # 2. training params
34 | # # # # # # # # # # # # # #
35 | trainer: # Required.
36 | epochs: 240
37 | sup_only_epoch: 0 # 0, 1
38 | evaluate_student: True
39 | optimizer:
40 | type: SGD
41 | kwargs:
42 | lr: 0.01 # 4GPUs
43 | momentum: 0.9
44 | weight_decay: 0.0005
45 | lr_scheduler:
46 | mode: poly
47 | kwargs:
48 | power: 0.9
49 | unsupervised:
50 | flag_extra_weak: False
51 | threshold: -0.1 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
52 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
53 | #------ 2)cutmix augs ------#
54 | use_cutmix: True
55 | use_cutmix_adaptive: True
56 | use_cutmix_trigger_prob: 1.0 # wide range, but trigger by 100%
57 |
58 | # # # # # # # # # # # # # #
59 | # 3. output files, and loss
60 | # # # # # # # # # # # # # #
61 | saver:
62 | snapshot_dir: checkpoints
63 | pretrain: ''
64 | use_tb: False
65 |
66 | criterion:
67 | type: ohem
68 | kwargs:
69 | thresh: 0.7
70 | min_kept: 100000
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 19
77 | sync_bn: True
78 | ema_decay: 0.996
79 | # aux_loss:
80 | # aux_plane: 1024
81 | # loss_weight: 0.4
82 | encoder:
83 | # type: augseg.models.resnet.resnet101
84 | # pretrain: ./pretrained/resnet101.pth
85 | type: augseg.models.resnet.resnet50
86 | pretrain: ./pretrained/resnet50.pth
87 | kwargs:
88 | zero_init_residual: True
89 | multi_grid: True
90 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
91 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
92 | decoder:
93 | type: augseg.models.decoder.dec_deeplabv3_plus
94 | kwargs:
95 | inner_planes: 256
96 | low_conv_planes: 48 # 256
97 | dilations: [6, 12, 18]
98 | # dilations: [12, 24, 36] # [output_stride = 8]
99 |
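The `criterion` block selects `type: ohem` with `thresh: 0.7` and `min_kept: 100000`. The sketch below follows the pixel-selection rule of the common OHEM cross-entropy found in DeepLab-style segmentation codebases, assuming AugSeg's `loss_helper.py` uses the same convention. A pixel is kept when the predicted probability of its ground-truth class is at most `thresh`. If that leaves fewer than `min_kept` pixels, the cut-off is raised until at least `min_kept` pixels (or all valid ones) are kept. Plain lists stand in for tensors, and `ohem_select` / `ohem_loss` are hypothetical names.

```python
import math


def ohem_select(gt_probs: list[float], thresh: float = 0.7,
                min_kept: int = 100000) -> list[int]:
    """Indices of the hard pixels OHEM keeps, easiest-last.

    gt_probs[i] is the softmax probability assigned to pixel i's ground-truth
    class; ignored pixels (label 255) are assumed to be filtered out already.
    """
    if not gt_probs:
        return []
    order = sorted(range(len(gt_probs)), key=lambda i: gt_probs[i])
    cutoff = thresh
    k = min(min_kept, len(gt_probs))
    if k > 0:
        # Raise the cut-off so that at least k pixels survive.
        cutoff = max(cutoff, gt_probs[order[k - 1]])
    return [i for i in order if gt_probs[i] <= cutoff]


def ohem_loss(gt_probs: list[float], thresh: float = 0.7,
              min_kept: int = 100000) -> float:
    """Mean cross-entropy (-log p) over the kept pixels only."""
    kept = ohem_select(gt_probs, thresh, min_kept)
    if not kept:
        return 0.0
    return sum(-math.log(gt_probs[i]) for i in kept) / len(kept)


probs = [0.95, 0.2, 0.6, 0.9, 0.99]
print(ohem_select(probs, thresh=0.7, min_kept=1))  # [1, 2]
print(ohem_select(probs, thresh=0.7, min_kept=4))  # [1, 2, 3, 0]
```

For scale, one 4-image batch of 800x800 crops holds 2,560,000 pixels, so `min_kept: 100000` is about 3.9% of them. This assumes the count is taken per batch rather than per image.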
--------------------------------------------------------------------------------
/training-logs/6 Semi-citys/R50-744/config_semi.yaml:
--------------------------------------------------------------------------------
1 | # # # # # # # # # # # # # #
2 | # 1. datasets
3 | # # # # # # # # # # # # # #
4 | dataset: # Required.
5 | type: cityscapes_semi
6 | train:
7 | data_root: ./data/cityscapes
8 | data_list: ./data/splitsall/cityscapes/744/labeled.txt
9 | batch_size: 4
10 | flip: True
11 | rand_resize: [0.5, 2.0]
12 | resize_base_size: [1024, 2048]
13 | crop:
14 | type: rand
15 | size: [800, 800]
16 | #---- 1) strong data augs ----#
17 | strong_aug:
18 | num_augs: 3
19 | flag_use_random_num_sampling: True
20 | val:
21 | data_root: ./data/cityscapes
22 | data_list: ./data/splitsall/cityscapes/val.txt
23 | batch_size: 4
24 | crop:
25 | size: [800, 800]
26 | n_sup: 744
27 | workers: 4
28 | mean: [0.485, 0.456, 0.406]
29 | std: [0.229, 0.224, 0.225]
30 | ignore_label: 255
31 |
32 | # # # # # # # # # # # # # #
33 | # 2. training params
34 | # # # # # # # # # # # # # #
35 | trainer: # Required.
36 | epochs: 240
37 | sup_only_epoch: 0 # 0, 1
38 | evaluate_student: False
39 | optimizer:
40 | type: SGD
41 | kwargs:
42 | lr: 0.01 # 4GPUs
43 | momentum: 0.9
44 | weight_decay: 0.0005
45 | lr_scheduler:
46 | mode: poly
47 | kwargs:
48 | power: 0.9
49 | unsupervised:
50 | flag_extra_weak: False
51 | threshold: -0.7 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params1: 0.7, 0.75, 0.8, 0.85, 0.9, 0.95
52 | loss_weight: 1.0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ params2: 0.7, 1.0, 1.5, 2.0, 3.0, 4.0
53 | #------ 2)cutmix augs ------#
54 | use_cutmix: True
55 | use_cutmix_adaptive: True
56 | use_cutmix_trigger_prob: 1.0 # wide range, but trigger by 100%
57 |
58 | # # # # # # # # # # # # # #
59 | # 3. output files, and loss
60 | # # # # # # # # # # # # # #
61 | saver:
62 | snapshot_dir: checkpoints
63 | pretrain: ''
64 | use_tb: False
65 |
66 | criterion:
67 | type: ohem
68 | kwargs:
69 | thresh: 0.7
70 | min_kept: 100000
71 |
72 | # # # # # # # # # # # # # #
73 | # 4. models
74 | # # # # # # # # # # # # # #
75 | net: # Required.
76 | num_classes: 19
77 | sync_bn: True
78 | ema_decay: 0.996
79 | # aux_loss:
80 | # aux_plane: 1024
81 | # loss_weight: 0.4
82 | encoder:
83 | # type: augseg.models.resnet.resnet101
84 | # pretrain: ./pretrained/resnet101.pth
85 | type: augseg.models.resnet.resnet50
86 | pretrain: ./pretrained/resnet50.pth
87 | kwargs:
88 | zero_init_residual: True
89 | multi_grid: True
90 | replace_stride_with_dilation: [False, False, True] # [output_stride = 16]
91 | # replace_stride_with_dilation: [False, True, True] # [output_stride = 8]
92 | decoder:
93 | type: augseg.models.decoder.dec_deeplabv3_plus
94 | kwargs:
95 | inner_planes: 256
96 | low_conv_planes: 48 # 256
97 | dilations: [6, 12, 18]
98 | # dilations: [12, 24, 36] # [output_stride = 8]
99 |
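The commented alternatives in the `net` block tie `replace_stride_with_dilation` to the output stride and to the decoder's ASPP `dilations`. `[False, False, True]` pairs with `[6, 12, 18]` for output stride 16, and `[False, True, True]` pairs with `[12, 24, 36]` for output stride 8. The sketch below reproduces that bookkeeping. It assumes a standard ResNet whose three flags map to its last three stages, each of which would otherwise halve the resolution, for a total stride of 32. The function names are illustrative only.

```python
def output_stride(replace_stride_with_dilation: list[bool]) -> int:
    """Output stride of a ResNet backbone given its three dilation flags."""
    if len(replace_stride_with_dilation) != 3:
        raise ValueError("expected one flag per stage for layer2..layer4")
    stride = 32  # stem + max-pool (4x) and three stride-2 stages (8x)
    for dilate in replace_stride_with_dilation:
        if dilate:
            stride //= 2  # this stage keeps resolution and dilates instead
    return stride


def aspp_rates(out_stride: int, base=(6, 12, 18)) -> list[int]:
    """Scale OS=16 ASPP rates so their field of view in input pixels is unchanged."""
    if 16 % out_stride != 0:
        raise ValueError("only output strides that divide 16 are supported")
    factor = 16 // out_stride
    return [r * factor for r in base]


for flags in ([False, False, True], [False, True, True]):
    os_ = output_stride(flags)
    print(flags, "-> output stride", os_, "-> dilations", aspp_rates(os_))
```

Halving the output stride doubles the side length of the feature maps the decoder works on, which is why the output-stride-8 variant is considerably heavier in memory.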
--------------------------------------------------------------------------------