├── .gitignore ├── MODEL_ZOO.md ├── README.md ├── batch_inference.py ├── config ├── __init__.py ├── class_color.py ├── class_label.py └── config.py ├── datasets └── voc2012 │ ├── __init__.py │ ├── color2cls.py │ ├── dataset.py │ ├── make_data.py │ ├── make_data_aug.py │ └── process_data.py ├── hyparam ├── FCNs │ ├── fcn_mbv2_8s.yaml │ ├── fcn_mbv3_8s.yaml │ ├── fcn_resnet101_8s.yaml │ ├── fcn_resnet152_8s.yaml │ ├── fcn_resnet50_8s.yaml │ ├── fcn_vgg16_16s.yaml │ ├── fcn_vgg16_32s.yaml │ └── fcn_vgg16_8s.yaml ├── HRNet │ └── baseline_320.yaml ├── SegNet │ ├── seg_mbv2.yaml │ ├── seg_r50.yaml │ ├── seg_vgg16.yaml │ └── seg_vgg16_pool.yaml ├── U2Net │ ├── adjust_lr_bs.yaml │ ├── baseline.yaml │ ├── baseline_bce.yaml │ ├── baseline_bce_dice.yaml │ ├── baseline_bce_dice_colorjitter.yaml │ ├── baseline_bce_dice_l1c.yaml │ ├── baseline_bce_dice_l1c_pretrain.yaml │ ├── baseline_bce_dice_pretrain_320_data.yaml │ ├── baseline_bce_dice_pretrain_640.yaml │ ├── baseline_bce_dice_pretrain_768_colorjitter.yaml │ ├── baseline_bce_dice_tv.yaml │ ├── baseline_ce_dice_l1_no_pretrain_640.yaml │ ├── baseline_ce_dice_l1_pretrain.yaml │ └── baseline_ce_pretrain_480_data.yaml ├── UNet │ ├── unet_full.yaml │ ├── unet_full2.yaml │ ├── unet_full3.yaml │ ├── unet_full_adam.yaml │ └── unet_resnet50.yaml └── base.yaml ├── inference_api.py ├── losses ├── __init__.py ├── generatorLoss.py └── loss.py ├── main.py ├── models ├── DeepLab │ ├── __init__.py │ └── deeplab.py ├── FCN │ ├── __init__.py │ ├── fcn.py │ ├── fcn_mobilenetv2.py │ ├── fcn_mobilenetv3.py │ ├── fcn_resnet.py │ └── pretrained │ │ └── mbv3_large.pth.tar ├── HRNet │ ├── __init__.py │ ├── config │ │ ├── default.py │ │ └── seg_hrnet_w48.yaml │ └── hrnet_seg.py ├── SegNet │ ├── __init__.py │ ├── net.py │ ├── seg_mobilenetv2.py │ ├── seg_resnet.py │ └── seg_vgg16.py ├── U2Net │ ├── __init__.py │ ├── u2net.py │ └── u2net_bp.py ├── UNet │ ├── __init__.py │ ├── unet.py │ └── unet_resnet.py ├── __init__.py └── model_factory.py ├── post_process.py ├── script ├── FCNs │ ├── train.sh │ └── train2.sh ├── HRNet │ ├── train.sh │ └── train2.sh ├── U2Net │ ├── train.sh │ ├── train2.sh │ └── trian3.sh └── train_fcn.sh ├── train.py ├── train.sh └── utils ├── DataAugments.py ├── FuseAugments.py ├── Loss.py ├── LrSheduler.py ├── Metirc.py ├── Optim.py ├── Summary.py ├── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | **__pycache__** 3 | exp2 4 | 1_crop.jpg 5 | test_rotate.py 6 | process 7 | fcn_mbv2_8s 8 | tmp 9 | TNN 10 | *.jpg 11 | *.onnx 12 | *.png 13 | torch2onnx.py 14 | inference_cloth.py 15 | inference.py 16 | test_loader.py 17 | *.pth 18 | cloth 19 | cloth2 20 | *log 21 | datasets/voc2012/data/ 22 | makeData.py 23 | merge_data.py 24 | to_srcipt.py 25 | *.DS_Store* 26 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # MODEL ZOO 2 | ### VOC2012 3 | |Models|BatchSize|GPUs|HyParameters|CropSize|GFLOPs|PA|MPA|MIOU|FWIOU|Weights| 4 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 5 | |FCNS_vgg16_32s|8|1|[fcn_vgg16_32s](/data/jiangmingchao/data/code/SegmentationLight/hyparam/FCNs/fcn_vgg16_32s.yaml)|512|-|82.9%|50.6%|39.3%|72.2%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn32_vgg16_gpux1_512/checkpoints/best_ckpt_losses_0.5197840489365242_miou_0.39294091458955355.pth| 6 | 
|FCNS_vgg16_16s|8|1|[fcn_vgg16_16s](/data/jiangmingchao/data/code/SegmentationLight/hyparam/FCNs/fcn_vgg16_16s.yaml)|512|-|86%|57.8%|44.9%|77.1%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn16_vgg16_gpux1_512/checkpoints/best_ckpt_losses_0.4749274262851411_miou_0.449459438309746.pth|
7 | |FCNS_vgg16_8s|16|1|[fcn_vgg16_8s](hyparam/FCNs/fcn_vgg16_8s.yaml)|512|-|87.3%|59.1%|48.1%|78.6%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_vgg16_gpux1_512/checkpoints/best_ckpt_losses_0.43843183134283337_miou_0.48058054398118843.pth|
8 | |FCNS_r50_8s|8|1|[fcn_r50_8s](hyparam/FCNs/fcn_resnet50_8s.yaml)|512|-|90.1%|69.9%|58.7%|82.8%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_r50_gpux1_512/checkpoints/best_ckpt_losses_0.31139873512662375_miou_0.5870208409364805.pth|
9 | |FCNS_r101_8s|8|1|[fcn_r101_8s](hyparam/FCNs/fcn_resnet101_8s.yaml)|512|-|90.8%|71.7%|61.1%|83.9%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_r101_gpux1_512/checkpoints/best_ckpt_losses_0.28888792892570025_miou_0.6112685436040116.pth|
10 | |FCNS_r152_8s|8|1|[fcn_r152_8s](hyparam/FCNs/fcn_resnet152_8s.yaml)|512|-|91.5%|74.1%|63.6%|85.0%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_r152_gpux1_512/checkpoints/best_ckpt_losses_0.262765350823219_miou_0.6356795263547104.pth|
11 | |FCNS_mbv2_8s|8|1|[fcn_mbv2_8s](hyparam/FCNs/fcn_mbv2_8s.yaml)|512|-|87.5%|56.7%|47.6%|78.7%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_mbv2_gpux1_512/checkpoints/best_ckpt_losses_0.47803407727362035_miou_0.47626235483478524.pth|
12 | |FCNS_mbv3_8s|8|1|[fcn_mbv3_8s](hyparam/FCNs/fcn_mbv3_8s.yaml)|512|-|87.4%|55.4%|46.4%|78.5%|/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_mbv3_gpux1_512/checkpoints/best_ckpt_losses_0.4714071523029726_miou_0.4640446924410685.pth|
13 | |SegNet_vgg16_upsample|16|1|[segnet_vgg16_up](hyparam/SegNet/seg_vgg16.yaml)|512|-|87.7%|59.0%|48.6%|79.3%|/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_vgg16_gpux1_512/checkpoints/best_ckpt_losses_0.443921719114859_miou_0.4860254266784411.pth|
14 | |SegNet_vgg16_pool|16|1|[segnet_vgg16_pool](hyparam/SegNet/seg_vgg16_pool.yaml)|512|-|87.0%|54.5%|44.4%|78.2%|/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_vgg16_gpux1_512_pool/checkpoints/best_ckpt_losses_0.5337066383479716_miou_0.4440478477271453.pth|
15 | |SegNet_R50_up|8|1|[segnet_r50_up](hyparam/SegNet/seg_r50.yaml)|512|-|88.5%|60.8%|51.1%|80.2%|/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_r50_gpux1_512/checkpoints/best_ckpt_losses_0.3944489204294079_miou_0.5108976440062539.pth|
16 | |SegNet_mobilenetv2_up|8|1|[segnet_mobilenetv2_up](hyparam/SegNet/seg_mbv2.yaml)|512|-|88.3%|59.5%|49.7%|79.9%|/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_mbv2_gpux1_512/checkpoints/best_ckpt_losses_0.42041630568085137_miou_0.4966990811194377.pth|
17 | |UNet_resnet50|8|2|[unet_resnet50](hyparam/UNet/unet_resnet50.yaml)|512|-|89.6%|64.2%|53.8%|82.0%|/data/jiangmingchao/data/AICKPT/Seg/UNet/unet_resnet50/checkpoints/best_ckpt_losses_0.412014008632728_miou_0.5454743759021599.pth|
18 |
19 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # SegmentationLight
2 |
3 | This repo provides semantic segmentation code that is easy to understand and effective.
4 |
5 | ### Model Zoo
6 | - FCNs
7 | - SegNet
8 | - DeepLab
9 | - HRNet
10 | - Unet
11 | - U2net
12 |
13 | A simple benchmark is provided [here](MODEL_ZOO.md); other models (CNNs & transformers) will be added when time permits.
14 |
15 | ### Code Structure
16 | ```
17 | ├── config ------- config classes and label/color maps
18 | |   ├── class_color.py
19 | │   ├── class_label.py
20 | │   ├── config.py
21 | │   ├── __init__.py
22 | ├── datasets ------- dataset classes
23 | |   ├── coco2017
24 | |   └── voc2012
25 | |       ├── color2cls.py
26 | |       ├── data
27 | |       ├── dataset.py
28 | |       ├── __init__.py
29 | |       ├── make_data_aug.py
30 | |       ├── make_data.py
31 | |       ├── process_data.py
32 | ├── hyparam -------- hyperparameter yamls
33 | |   ├── base.yaml
34 | │   ├── FCNs
35 | │   ├── HRNet
36 | │   ├── SegNet
37 | │   ├── U2Net
38 | │   └── UNet
39 | ├── inference_api.py -------- inference
40 | ├── losses -------- losses
41 | |   ├── generatorLoss.py
42 | │   ├── __init__.py
43 | │   ├── loss.py
44 | ├── main.py -------- main file
45 | ├── models -------- model factory
46 | |   ├── DeepLab
47 | │   ├── FCN
48 | │   ├── HRNet
49 | │   ├── __init__.py
50 | │   ├── model_factory.py
51 | │   ├── SegNet
52 | │   ├── U2Net
53 | │   └── UNet
54 | ├── post_process.py -------- post process
55 | ├── README.md -------- readme
56 | ├── script -------- training scripts
57 | |   ├── FCNs
58 | │   ├── HRNet
59 | │   ├── SegNet
60 | │   └── U2Net
61 | |   └── train.sh
62 | ├── train.py -------- train file
63 | ├── utils -------- utils
64 | │   ├── DataAugments.py
65 | │   ├── FuseAugments.py
66 | │   ├── __init__.py
67 | │   ├── Loss.py
68 | │   ├── LrSheduler.py
69 | │   ├── Metirc.py
70 | │   ├── Optim.py
71 | │   ├── Summary.py
72 | │   └── utils.py
73 | ```
74 |
75 | ### Clothes Segmentation
76 |
77 | - Train
78 | ```bash
79 | #!/bin/bash
80 | OMP_NUM_THREADS=1
81 | MKL_NUM_THREADS=1
82 | export OMP_NUM_THREADS
83 | export MKL_NUM_THREADS
84 | cd /data/jiangmingchao/data/code/SegmentationLight;
85 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch \
86 |     --master_port 2959 \
87 |     --nproc_per_node 8 main.py \
88 |     --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/U2Net/baseline_bce_dice_pretrain_320_data.yaml
89 | ```
90 | - Inference
91 | ```bash
92 | python inference_api.py
93 | ```
94 | - Batch Inference
95 | ```bash
96 | python batch_inference.py
97 | ```
98 |
99 | ### Dataset
100 | #### Custom Dataset
101 |
102 | 1. Prepare the image & mask pair log; each line is a JSON object in the following format (a helper sketch for generating these logs follows this list):
103 | ```
104 | {"image_path":"xxxxx/1.jpg", "label_path": "xxxxxx/1.png"}
105 | {"image_path":"xxxxx/2.jpg", "label_path": "xxxxxx/2.png"}
106 | {"image_path":"xxxxx/3.jpg", "label_path": "xxxxxx/3.png"}
107 | ```
108 | 2. Make the train.log and val.log, then point the config yaml at them:
109 | ```
110 | CUSTOM_DATASET:
111 |     TRAIN_FILE: "train.log"
112 |     VAL_FILE: "val.log"
113 |     NUM_CLASSES: 2 # used for celoss
114 | ```
115 | 3. Modify the other hyparams in the config for your training.
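A minimal sketch for generating these logs from a folder of images and masks. The `images/`/`masks/` layout, the matching file stems, and the 90/10 split are assumptions; adapt them to your own data:

```python
import json
import os
import random

image_dir, mask_dir = "images", "masks"  # hypothetical layout: images/1.jpg pairs with masks/1.png

# collect (image, mask) pairs whose stems match
pairs = []
for name in sorted(os.listdir(image_dir)):
    stem, ext = os.path.splitext(name)
    mask_path = os.path.join(mask_dir, stem + ".png")
    if ext.lower() in (".jpg", ".jpeg", ".png") and os.path.isfile(mask_path):
        pairs.append({"image_path": os.path.join(image_dir, name), "label_path": mask_path})

random.seed(0)
random.shuffle(pairs)
split = int(len(pairs) * 0.9)  # 90% train / 10% val

# write one JSON object per line, as the loader expects
for log_name, subset in (("train.log", pairs[:split]), ("val.log", pairs[split:])):
    with open(log_name, "w") as f:
        for pair in subset:
            f.write(json.dumps(pair) + "\n")
```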
116 |
117 | ### Resume
118 | If you want to resume training, you only need to add the resume fields to the config yaml:
119 | ```
120 | RESUME: True
121 | RESUME_CHECKPOINTS: XXX.pth
122 | ```
123 |
-------------------------------------------------------------------------------- /batch_inference.py: --------------------------------------------------------------------------------
1 | """Batch inference
2 | @author: Flyegle
3 | @datetime: 2022-06-02
4 | """
5 | import os
6 | import cv2
7 | import json
8 | import torch
9 | import random
10 | import numpy as np
11 | import torch.nn.functional as F
12 |
13 | from models.model_factory import ModelFactory
14 | from utils.DataAugments import Normalize, Scale, ToTensor
15 | from torch.cuda.amp import autocast as autocast
16 | from torch.utils.data.dataset import Dataset
17 | from torch.utils.data.dataloader import DataLoader
18 | from PIL import Image
19 |
20 |
21 | def aug(images):  # the pair transforms expect (image, label); the image doubles as a dummy label at inference
22 |     images, _ = Scale((800, 800))(images, images)
23 |     images, _ = Normalize(normalize=True)(images, images)
24 |     images, _ = ToTensor()(images, images)
25 |     return images
26 |
27 |
28 | def load_ckpt(net, model_ckpt):
29 |     state_dict = torch.load(model_ckpt, map_location="cpu")['state_dict']
30 |     net.load_state_dict(state_dict)
31 |     print(f"load the ckpt {model_ckpt}")
32 |     return net
33 |
34 |
35 | class DataSet(Dataset):
36 |     def __init__(self, file):
37 |         super(DataSet, self).__init__()
38 |         self.data_list = []
39 |         self._load_data(file)
40 |         self.data_index = [x for x in range(len(self.data_list))]
41 |
42 |     def _load_data(self, file):
43 |         if os.path.isdir(file):
44 |             self.data_list = [os.path.join(file, x) for x in os.listdir(file)]
45 |         elif os.path.isfile(file):
46 |             data_list = [x.strip() for x in open(file).readlines()]
47 |             if "image_path" in data_list[0]:
48 |                 self.data_list = [json.loads(x)["image_path"] for x in data_list]
49 |             else:
50 |                 self.data_list = data_list
51 |         else:
52 |             raise IOError(f"{file} must be an images folder or a meta log")
53 |
54 |     def __getitem__(self, index):
55 |         for _ in range(10):
56 |             try:
57 |                 image = cv2.imread(self.data_list[index])
58 |                 h, w, _ = image.shape
59 |                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
60 |                 image_tensor = aug(image)
61 |                 image_name = self.data_list[index].split('/')[-1].split('.')[0]
62 |                 return image_tensor, (w, h), image_name
63 |             except Exception as e:
64 |                 print(f"{self.data_list[index]} has some error, need to choose another!!! 
{e}") 65 | index = random.choice(self.data_index) 66 | 67 | def __len__(self): 68 | return len(self.data_list) 69 | 70 | 71 | class SegNet: 72 | def __init__(self, model_name, num_classes, weights): 73 | self.model_name = model_name 74 | self.num_classes = num_classes 75 | self.weights = weights 76 | 77 | # build model 78 | model_factory = ModelFactory() 79 | self.net = model_factory.getattr(model_name)(num_classes=self.num_classes) 80 | load_ckpt(self.net, self.weights) 81 | 82 | # cuda & eval 83 | if torch.cuda.is_available(): 84 | self.net.cuda() 85 | self.net.eval() 86 | 87 | @torch.no_grad() 88 | def infer(self, images): 89 | """images: np.ndarray RGB 90 | Return: 91 | outputs: torch.Tensor 92 | """ 93 | with autocast(): 94 | images = images.cuda() 95 | outputs = self.net(images) 96 | return outputs 97 | 98 | 99 | def post_process(outputs, shape, name, out_folder): 100 | outputs = outputs[0] 101 | b, c, h, w = outputs.shape 102 | for i in range(b): 103 | output = F.interpolate(outputs[i,:,:,:].unsqueeze(0), size=[shape[1][i], shape[0][i]], mode="bilinear", align_corners=True) 104 | output = torch.sigmoid(output) 105 | output = output.permute(0,2,3,1).cpu().numpy() 106 | # b, c, h, w 107 | output[output >= 0.9] = 1 108 | output[output < 0.9] = 0 109 | 110 | output = output * 255 111 | mask = output.astype(np.uint8) 112 | output = np.concatenate((mask[0,:,:,:], mask[0,:,:,:],mask[0,:,:,:]), axis=-1) 113 | 114 | path = os.path.join(out_folder, name[i]+'.png') 115 | cv2.imwrite(path, output) 116 | 117 | 118 | 119 | if __name__ == '__main__': 120 | 121 | model = SegNet(model_name="u2net", num_classes=1, weights="/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_800_crop_0.3_1.2_1E-4_circle_2_27k_copypaste_sgd/checkpoints/best_ckpt_epoch_133_losses_0.21228863086019242_miou_0.9878880708590272.pth") 122 | 123 | file = "/data/jiangmingchao/data/code/cluster/shein_2k_1w.log" 124 | out_file = "/data/jiangmingchao/data/dataset/shein_2k_1w_720_patch/mask_batch" 125 | 126 | if not os.path.exists(out_file): 127 | os.makedirs(out_file) 128 | 129 | dataset = DataSet(file) 130 | 131 | loader = DataLoader( 132 | dataset, 133 | batch_size=32, 134 | shuffle=False, 135 | sampler=None, 136 | num_workers=32 137 | ) 138 | 139 | length = len(loader) 140 | for idx, data in enumerate(loader): 141 | images, shape, name = data 142 | outputs = model.infer(images) 143 | 144 | post_process(outputs, shape, name, out_file) 145 | print(f"process {idx}/{length} mask!!!") 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/config/__init__.py -------------------------------------------------------------------------------- /config/class_color.py: -------------------------------------------------------------------------------- 1 | # voc colors 2 | VOC_2012_CLASSES = { 3 | 0: (0, 0, 0), 4 | 1: (128, 0, 0), 5 | 2: (0, 128, 0), 6 | 3: (128, 128, 0), 7 | 4: (0, 0, 128), 8 | 5: (128, 0, 128), 9 | 6: (0, 128, 128), 10 | 7: (128, 128, 128), 11 | 8: (64, 0, 0), 12 | 9: (192, 0, 0), 13 | 10: (64, 128, 0), 14 | 11: (192, 128, 0), 15 | 12: (64, 0, 128), 16 | 13: (192, 0, 128), 17 | 14: (64, 128, 128), 18 | 15: (192, 128, 128), 19 | 16: (0, 64, 0), 20 | 17: (128, 64, 0), 21 | 18: (0, 192, 0), 22 | 19: (128, 192, 0), 23 | 20: (0,64,128) 24 | } 
-------------------------------------------------------------------------------- /config/class_label.py: --------------------------------------------------------------------------------
1 | # voc labels
2 | VOC_2012_LABELS = {
3 |     0: 'background',
4 |     1: 'aeroplane',
5 |     2: 'bicycle',
6 |     3: 'bird',
7 |     4: 'boat',
8 |     5: 'bottle',
9 |     6: 'bus',
10 |     7: 'car',
11 |     8: 'cat',
12 |     9: 'chair',
13 |     10: 'cow',
14 |     11: 'diningtable',
15 |     12: 'dog',
16 |     13: 'horse',
17 |     14: 'motorbike',
18 |     15: 'person',
19 |     16: 'pottedplant',
20 |     17: 'sheep',
21 |     18: 'sofa',
22 |     19: 'train',
23 |     20: 'tvmonitor',
24 | }
-------------------------------------------------------------------------------- /config/config.py: --------------------------------------------------------------------------------
1 | """make argparse
2 | @author: FlyEgle
3 | @datetime: 2022-01-20
4 | """
5 | import yaml
6 | import argparse
7 | from dotmap import DotMap
8 |
9 |
10 | def build_argparse():
11 |     parser = argparse.ArgumentParser()
12 |     # -------------------------------
13 |     parser.add_argument('--hyp', type=str, default="/data/jiangmingchao/data/code/SegmentationLight/hyparam/base.yaml")
14 |     parser.add_argument('--ngpu', type=int, default=1)
15 |     parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
16 |     parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
17 |     parser.add_argument('--local_rank', default=-1, type=int)
18 |     hyp = parser.parse_args()
19 |     return hyp
20 |
21 | def parse_yaml(yaml_path):
22 |     with open(yaml_path) as f:
23 |         data = yaml.load(f, Loader=yaml.FullLoader)
24 |         data = DotMap(data)
25 |     return data
26 |
27 |
28 | if __name__ == '__main__':
29 |     yaml_file = "/data/jiangmingchao/data/code/SegmentationLight/hyparam/base.yaml"
30 |     with open(yaml_file) as file:
31 |         data = yaml.load(file, Loader=yaml.FullLoader)  # a Loader is required on modern PyYAML
32 |         config = DotMap(data)
33 |     print(config)
34 |     # print(data.keys())
-------------------------------------------------------------------------------- /datasets/voc2012/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/datasets/voc2012/__init__.py -------------------------------------------------------------------------------- /datasets/voc2012/color2cls.py: --------------------------------------------------------------------------------
1 | """Translate the color map to class
2 | @author: FlyEgle
3 | @datetime: 2022-01-19
4 | """
5 | import os
6 | import cv2
7 | import json
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | # bg is not in this dict: key k holds the color of VOC class k+1
12 | _CLASS_LABEL_DICT = {
13 |     0: (128, 0, 0),
14 |     1: (0, 128, 0),
15 |     2: (128, 128, 0),
16 |     3: (0, 0, 128),
17 |     4: (128, 0, 128),
18 |     5: (0, 128, 128),
19 |     6: (128, 128, 128),
20 |     7: (64, 0, 0),
21 |     8: (192, 0, 0),
22 |     9: (64, 128, 0),
23 |     10: (192, 128, 0),
24 |     11: (64, 0, 128),
25 |     12: (192, 0, 128),
26 |     13: (64, 128, 128),
27 |     14: (192, 128, 128),
28 |     15: (0, 64, 0),
29 |     16: (128, 64, 0),
30 |     17: (0, 192, 0),
31 |     18: (128, 192, 0),
32 |     19: (0, 64, 128)
33 | }
34 |
35 | border = (224, 224, 192)
36 |
37 |
38 | def find_contours(images, rgb_value):  # returns the (row, col) indices of pixels matching rgb_value, not real contours
39 |     location = np.all((images == rgb_value), axis=2)
40 |     position = np.where(location == 1)
41 |
42 |     return position
43 |
44 |
45 | def remove_border(images, border=border):  # the VOC "void" border color is mapped back to background
46 |     location = np.all((images == border), axis=2)
47 |     position = np.where(location == 1)
48 |     images[position] = 0
49 |     return images
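# Illustrative helper (an editorial sketch, not part of the original script):
# the same color -> class translation as the loop in __main__ below, collapsed
# into one function. Assumes an RGB label image of shape (H, W, 3); background
# stays 0 and class k in _CLASS_LABEL_DICT becomes k + 1, matching the script.
def color_mask_to_class(color_mask):
    targets = np.zeros(color_mask.shape[:2], dtype=np.uint8)
    for cls, rgb in _CLASS_LABEL_DICT.items():
        targets[np.all(color_mask == rgb, axis=2)] = cls + 1
    return targets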
50 |
51 |
52 | if __name__ == '__main__':
53 |     data_file = "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data_seg_label.log"
54 |     data_list = [json.loads(x.strip()) for x in open(data_file).readlines()]
55 |
56 |     OUTPUT_FOLDER = "/data/jiangmingchao/data/dataset/voc2012/VOCdevkit/VOC2012/SegmentationClassTargets"
57 |
58 |     for data in tqdm(data_list):
59 |         label_path = data["label_path"]
60 |         label_name = label_path.split('/')[-1]
61 |         images = cv2.imread(label_path)
62 |         images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
63 |
64 |         targets = np.zeros((images.shape[0], images.shape[1]))
65 |
66 |         if label_name == "2011_001621.png":  # special case: this annotation appears to paint class 19 with class 14's color
67 |             for obj in data['obj']:
68 |                 cls = int(obj['class'])
69 |                 if cls == 19:
70 |                     color = _CLASS_LABEL_DICT[14]
71 |                 else:
72 |                     color = _CLASS_LABEL_DICT[cls]
73 |
74 |                 pts = find_contours(images, color)
75 |                 # bg is 0
76 |                 targets[pts] = cls+1
77 |
78 |             cv2.imwrite(os.path.join(OUTPUT_FOLDER, label_name), targets)
79 |         else:
80 |             for obj in data['obj']:
81 |                 cls = int(obj['class'])
82 |                 color = _CLASS_LABEL_DICT[cls]
83 |                 pts = find_contours(images, color)
84 |                 # bg is 0
85 |                 targets[pts] = cls+1
86 |                 # break
87 |
88 |             cv2.imwrite(os.path.join(OUTPUT_FOLDER, label_name), targets)
-------------------------------------------------------------------------------- /datasets/voc2012/dataset.py: --------------------------------------------------------------------------------
1 | """VOC2012 Segmentation DataSet
2 | @author: FlyEgle
3 | @datetime: 2022-01-19
4 | """
5 | # TODO: use a LOGGER instead of print for logging
6 | import os
7 | import cv2
8 | import json
9 | import random
10 | import numpy as np
11 | import urllib.request as urt
12 |
13 | from torch.utils.data.dataset import Dataset
14 | from utils.DataAugments import RandomHorizionFlip, RandomRotate, RandomCropScale2, RanomCopyPastePruneBG
15 | from utils.DataAugments import Normalize, ToTensor, Compose, Scale, RandomGaussianBlur
16 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
17 |
18 |
19 | # TODO: control these hyperparameters from the yaml config
20 | # train transformers
21 | def build_transformers(crop_size=(320, 320)):
22 |     if isinstance(crop_size, int):
23 |         crop_size = (crop_size, crop_size)
24 |
25 |     data_aug = [
26 |         RandomCropScale2(scale_size=crop_size, scale=(0.3, 1.2), prob=0.5),
27 |         RandomHorizionFlip(p=0.5),
28 |         RanomCopyPastePruneBG(prob=0.1),  # copy paste for prune bg
29 |         RandomRotate(degree=15, mode=0),
30 |         RandomGaussianBlur(p=0.2),
31 |     ]
32 |
33 |     to_tensor = [
34 |         Normalize(normalize=True, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
35 |         ToTensor(channel_first=True)
36 |     ]
37 |
38 |     final_aug = data_aug + to_tensor
39 |     return Compose(final_aug)
40 |
41 |
42 | # val transformers
43 | def build_val_transformers(crop_size=(320, 320)):
44 |     if isinstance(crop_size, int):
45 |         crop_size = (crop_size, crop_size)
46 |
47 |     data_aug = [
48 |         Scale(scale_size=crop_size)
49 |     ]
50 |
51 |     to_tensor = [
52 |         Normalize(normalize=True, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
53 |         ToTensor(channel_first=True)
54 |     ]
55 |
56 |     final_aug = data_aug + to_tensor
57 |     return Compose(final_aug)
58 |
59 |
60 | class VocSemanticSegDataSet(Dataset):
61 |     """Build the VOC2012 dataset for segmentation
62 |     """
63 |     def __init__(self, data_file, transformers=None, train_phase=True, voc=False):
64 |         super(VocSemanticSegDataSet, self).__init__()
65 |         if not os.path.isfile(data_file):
66 |             raise TypeError(f"{data_file} must be file 
type!!!") 67 | self.data_list = [json.loads(x.strip()) for x in open(data_file).readlines()] 68 | self.data_indices = [x for x in range(len(self.data_list))] 69 | self.train_phase = train_phase 70 | self.voc = voc 71 | if self.train_phase: 72 | random.shuffle(self.data_list) 73 | 74 | if transformers is not None: 75 | self.data_aug = transformers 76 | else: 77 | self.data_aug = None 78 | 79 | def _loadImages(self, line): 80 | img_path = line["image_path"] 81 | lbl_path = line["label_path"] 82 | 83 | if "http" not in img_path: 84 | image = cv2.imread(img_path) 85 | label = cv2.imread(lbl_path) 86 | # rm 255 border 87 | if self.voc: 88 | label = self._rm_border(label) 89 | else: 90 | label = self._make_lbl(label) 91 | 92 | # read oss data 93 | else: 94 | img_context = urt.urlopen(img_path).read() 95 | image = cv2.imdecode(np.asarray(bytearray(img_context), dtype='uint8'), cv2.IMREAD_COLOR) 96 | 97 | lbl_context = urt.urlopen(lbl_path).read() 98 | label = cv2.imdecode(np.asarray(bytearray(lbl_context), dtype='uint8'), cv2.IMREAD_COLOR) 99 | 100 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 101 | return image, label 102 | 103 | # remove the border 255 from label 104 | def _rm_border(self, seg): 105 | pe = np.where(seg==255) 106 | seg[pe] = 0 107 | return seg 108 | 109 | # make value to label 110 | def _make_lbl(self, seg): 111 | pe = np.where(seg==255) 112 | seg[pe] = 1 113 | return seg 114 | 115 | def __getitem__(self, index): 116 | for _ in range(10): 117 | try: 118 | line = self.data_list[index] 119 | img, lbl = self._loadImages(line) 120 | if self.data_aug is not None: 121 | img, lbl = self.data_aug(img, lbl) 122 | return img, lbl 123 | except Exception as e: 124 | print(f"{self.data_list[index]} have {e} exception!!!") 125 | index = random.choice(self.data_indices) 126 | 127 | def __len__(self): 128 | return len(self.data_list) 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /datasets/voc2012/make_data_aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import numpy as np 5 | 6 | 7 | # seg_folder = "/data/jiangmingchao/data/dataset/voc2012/SegmentationClassAug" 8 | # seg_list = [os.path.join(seg_folder, x) for x in os.listdir(seg_folder)] 9 | 10 | # image_folder = "/data/jiangmingchao/data/dataset/voc2012/VOCdevkit/VOC2012/JPEGImages" 11 | # image_list = [os.path.join(image_folder, x) for x in os.listdir(image_folder)] 12 | 13 | # seg_dict = {x.split('/')[-1].split('.')[0]:x for x in seg_list} 14 | 15 | # # image_dict = {} 16 | # with open("/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/voc_aug/seg_train.log", "w") as file: 17 | # for image in image_list: 18 | # image_name = image.split('/')[-1].split('.')[0] 19 | # result = {} 20 | # if image_name in seg_dict: 21 | # result["image_path"] = image 22 | # result["label_path"] = seg_dict[image_name] 23 | 24 | # file.write(json.dumps(result) + '\n') 25 | 26 | # train_file = "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/voc_aug/seg_train.log" 27 | # val_file = "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/voc_aug/seg_val.log" 28 | 29 | # train_list = [x.strip() for x in open(train_file).readlines()] 30 | # val_list = [x.strip() for x in open(val_file).readlines()] 31 | 32 | 33 | # count = 0 34 | # for val in val_list: 35 | # if val in train_list: 36 | # count += 1 37 | 38 | # print(count) 39 | 40 | # image_file = 
"/data/jiangmingchao/data/dataset/voc2012/SegmentationClassAug/2007_000032.png" 41 | # image = cv2.imread(image_file) 42 | # image[np.where(image==255)] = 0 43 | # cv2.imwrite("/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/voc_aug/1.jpg", image) -------------------------------------------------------------------------------- /datasets/voc2012/process_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import random 6 | from tqdm import tqdm 7 | # from torch.utils.data import Dataset, DataLoader 8 | # from dataset import build_val_transformers, VocSemanticSegDataSet 9 | # data_file = "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 10 | # data_list = [x.strip() for x in open(data_file).readlines()] 11 | 12 | 13 | # for data in tqdm(data_list): 14 | # data_json = json.loads(data) 15 | 16 | 17 | # data_file = "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 18 | # data_file = "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/shin_10k.log" 19 | # data_list = open() 20 | # dataset = VocSemanticSegDataSet( 21 | # data_file, build_val_transformers(crop_size=(800, 800)), train_phase=False 22 | # ) 23 | 24 | # dataloader = DataLoader( 25 | # dataset, 26 | # batch_size=10, 27 | # shuffle=True, 28 | # num_workers=32, 29 | # pin_memory=True 30 | # ) 31 | 32 | # for idx, batch in enumerate(dataloader): 33 | # print(idx, batch.shape) 34 | # new_folder = "/data/jiangmingchao/data/dataset/shein/images_copy" 35 | # data_list = [json.loads(x.strip())["image_path"] for x in open(data_file).readlines()][10647:10648] 36 | # for data in tqdm(data_list): 37 | # print(data) 38 | # image_name = data.split('/')[-1] 39 | # image = cv2.imread(data) 40 | # cv2.imwrite(os.path.join(new_folder, image_name), image) 41 | 42 | 43 | data_file = "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/shein_no_use_10k.log" 44 | data_list = [x.strip() for x in open(data_file).readlines()] 45 | 46 | 47 | sample_list = random.sample(data_list, 1000) 48 | 49 | with open("/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/shein_sample_no_use_1k.log", "w") as file: 50 | for data in sample_list: 51 | file.write(data + '\n') -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_mbv2_8s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "fcn_8s_mobilenetv2" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # 
---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_mbv2_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.1 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_mbv3_8s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "fcn_8s_mobilenetv3" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_mbv3_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.1 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_resnet101_8s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_8S_resnet101" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes 
---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_r101_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.001 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 60 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_resnet152_8s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_8S_resnet152" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_r152_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.001 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 60 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_resnet50_8s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_8S_resnet50" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: 
"/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_r50_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.001 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 60 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_vgg16_16s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_16S" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn16_vgg16_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.0001 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 60 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_vgg16_32s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_32S" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: 
"/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn32_vgg16_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.001 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 60 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/FCNs/fcn_vgg16_8s.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_8S" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/FCNs/fcn8_vgg16_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.0005 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 100 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/HRNet/baseline_320.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "hrnet" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | 
NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log" 28 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_12k_filter.log" 29 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_14k_rm_same_folder.log" # best 30 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green/taobao_seg_14k_green_tryon_1k_shuf.txt" 31 | # TRAIN_FILE : "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green/taobao_14k_green_tryon_1k_filter_3w_2k_shuf.log" # NEW 17K 32 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 33 | NUM_CLASSES: 1 # used for celoss 34 | # ----------Transformes ---------------- 35 | DATA: 36 | CROP_SIZE: [800, 800] 37 | 38 | # TODO 39 | # ---------- Others DATASET-------------- 40 | # ---------- Custom DATASET ------------- 41 | 42 | # -----------SUMMARY---------------- 43 | SUMMARY: 44 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/HRNet/baseline_bce_dice_800_150epch_FPN/" 45 | CHECKPOINTS: "checkpoints" 46 | LOG: "log_dir" 47 | 48 | # -----------OPTIMIZER-------------- 49 | OPTIMIZER: 50 | OPTIM_NAME: "ADAMW" # ADAMW 51 | LEARNING_RATE: 0.001 # 0.001 52 | COSINE: 1 53 | FIX: 0 54 | WEIGHT_DECAY: 0.01 # 0.01 55 | MOMENTUM: 0.9 56 | BATCH_SIZE: 8 57 | NUM_WORKERS: 32 58 | 59 | # ------------EPOCHS------------------ 60 | WARMUP_EPOCHS: 0 61 | MAX_EPOCHS: 150 62 | FREQENCE: 1 63 | 64 | SYNCBN: 1 65 | 66 | # -----------LOSSES ----------------- 67 | LOSS: "BCE+DICE" 68 | 69 | # -----------PRETRAIN -------------- 70 | PRETRAIN: True 71 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/HRNet/baseline_bce_dice_320_100epch_FPN/checkpoints/best_ckpt_losses_0.09434028502021517_miou_0.9761377279926882.pth" 72 | 73 | 74 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_320x320_14k_data/checkpoints/best_ckpt_losses_1.1548381788390023_miou_0.9731329759628107.pth 75 | 76 | # best 77 | # "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_800x800_14k_finetune_20epoch/checkpoints/best_ckpt_losses_0.5155016886336463_miou_0.9867006791631263.pth" 78 | 79 | # ------------Batch Aug -------------- 80 | BATCH_AUG: 81 | MIXUP: False 82 | 83 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_768x768_14k_data_320_pretrain_long_epoch/checkpoints/best_ckpt_losses_0.5636463910341263_miou_0.9844545931472927.pth 84 | 85 | 86 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ 87 | 88 | # /data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log 89 | 90 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768/checkpoints/best_ckpt_losses_0.6692883372306824_miou_0.9821306581457698.pth 91 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768_20k_data/checkpoints/best_ckpt_losses_1.0762560623032706_miou_0.9748090077617877.pth 
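For reference, every yaml under `hyparam/` is consumed by `parse_yaml` in `config/config.py`, which wraps the parsed dict in a `DotMap` so nested keys read as attributes. A minimal sketch (run from the repo root; the printed values mirror the baseline_320.yaml above):

```python
from config.config import parse_yaml

cfg = parse_yaml("hyparam/HRNet/baseline_320.yaml")

# nested yaml keys become attribute lookups on the DotMap
print(cfg.MODEL_NAME)                  # "hrnet"
print(cfg.DATA.CROP_SIZE)              # [800, 800]
print(cfg.OPTIMIZER.LEARNING_RATE)     # 0.001
print(cfg.CUSTOM_DATASET.NUM_CLASSES)  # 1
```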
-------------------------------------------------------------------------------- /hyparam/SegNet/seg_mbv2.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "segnet_mobilenetv2" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_mbv2_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.01 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/SegNet/seg_r50.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "segnet_resnet50" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_r50_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.01 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | 
WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 60 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/SegNet/seg_vgg16.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "segnet_vgg16" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_vgg16_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.01 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 100 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/SegNet/seg_vgg16_pool.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "segnet_vgg16" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/SegNet/segnet_vgg16_gpux1_512_pool" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.01 42 | COSINE: 1 43 | FIX: 0 44 | 
WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 300 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/U2Net/adjust_lr_bs.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 2 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [320, 320] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/adjust_bs_lr" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.02 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 32 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 100 57 | FREQENCE: 1 58 | 59 | SYNCBN: 0 60 | 61 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/U2Net/baseline.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: 
"/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 2 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [320, 320] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 16 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 100 57 | FREQENCE: 1 58 | 59 | SYNCBN: 1 60 | 61 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [320, 320] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 16 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 100 57 | FREQENCE: 1 58 | 59 | SYNCBN: 1 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE" -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: 
"CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log" 28 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_12k_filter.log" 29 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 30 | NUM_CLASSES: 1 # used for celoss 31 | # ----------Transformes ---------------- 32 | DATA: 33 | CROP_SIZE: [320, 320] 34 | 35 | # TODO 36 | # ---------- Others DATASET-------------- 37 | # ---------- Custom DATASET ------------- 38 | 39 | # -----------SUMMARY---------------- 40 | SUMMARY: 41 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_320x320_20k_data_no_pretrain/" 42 | CHECKPOINTS: "checkpoints" 43 | LOG: "log_dir" 44 | 45 | # -----------OPTIMIZER-------------- 46 | OPTIMIZER: 47 | OPTIM_NAME: "ADAMW" # ADAMW 48 | LEARNING_RATE: 0.001 # 0.001 49 | COSINE: 1 50 | FIX: 0 51 | WEIGHT_DECAY: 0.01 # 0.01 52 | MOMENTUM: 0.9 53 | BATCH_SIZE: 16 54 | NUM_WORKERS: 32 55 | 56 | # ------------EPOCHS------------------ 57 | WARMUP_EPOCHS: 0 58 | MAX_EPOCHS: 100 59 | FREQENCE: 1 60 | 61 | SYNCBN: 1 62 | 63 | # -----------LOSSES ----------------- 64 | LOSS: "BCE+DICE" 65 | 66 | # -----------PRETRAIN -------------- 67 | PRETRAIN: False 68 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768_20k_data/checkpoints/best_ckpt_losses_1.0762560623032706_miou_0.9748090077617877.pth" 69 | 70 | 71 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ 72 | 73 | # /data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_colorjitter.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: 
"/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [320, 320] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_colorjitter/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 16 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 120 57 | FREQENCE: 1 58 | 59 | SYNCBN: 1 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE" 63 | 64 | # -----------DATAAUG----------------- 65 | -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_l1c.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [640, 640] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_cl1_640/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 12 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 100 57 | FREQENCE: 1 58 | 59 | SYNCBN: 0 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE+L1" -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_l1c_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | 
DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [640, 640] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_cl1_640_pretrain/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 12 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 150 57 | FREQENCE: 1 58 | 59 | SYNCBN: 0 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE+L1" 63 | 64 | # -----------PRETRAIN -------------- 65 | PRETRAIN: True 66 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_cl1/checkpoints/best_ckpt_losses_2.156236001423427_miou_0.9645676702550512.pth" -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_pretrain_320_data.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | # MODEL_NAME: "u2net" 11 | MODEL_NAME: "u2net" 12 | DATASET_TYPE: "CUSTOM" 13 | # ------------VOC DATASET --------------- 14 | VOC_DATASET: 15 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 16 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 17 | NUM_CLASSES: 21 18 | 19 | # ------------VOC_AUG DATASET --------------- 20 | VOC_AUG_DATASET: 21 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 22 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 23 | NUM_CLASSES: 21 24 | 25 | # ---------- CUSTOM DATASET ------------ 26 | CUSTOM_DATASET: 27 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 28 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log" 29 | # TRAIN_FILE: 
"/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_12k_filter.log" 30 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_14k_rm_same_folder.log" # best 31 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green/taobao_seg_14k_green_tryon_1k_shuf.txt" 32 | # TRAIN_FILE : "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green/taobao_14k_green_tryon_1k_filter_3w_2k_shuf.log" # NEW 17K 33 | 34 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/taobao_green_shein_shuf_27k.log" # NEW 27K 35 | # VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 36 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/train_seg_shuf_25k.log" 37 | VAL_FILE : "/data/jiangmingchao/data/code/U-2-Net/makeData/seg_val_rm_hole.log" 38 | NUM_CLASSES: 1 # used for celoss 39 | # ----------Transformes ---------------- 40 | DATA: 41 | CROP_SIZE: [480, 480] 42 | 43 | # TODO 44 | # ---------- Others DATASET-------------- 45 | # ---------- Custom DATASET ------------- 46 | 47 | # -----------SUMMARY---------------- 48 | SUMMARY: 49 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_480_crop_0.3_1.2_5E-4_circle_2_25k_baseline" 50 | CHECKPOINTS: "checkpoints" 51 | LOG: "log_dir" 52 | 53 | # -----------OPTIMIZER-------------- 54 | OPTIMIZER: 55 | OPTIM_NAME: "ADAMW" # ADAMW 56 | LEARNING_RATE: 0.0005 # 0.001 57 | COSINE: 0 58 | FIX: 0 59 | CRICLE: 1 60 | CRICLE_STEPS: 2 61 | WEIGHT_DECAY: 0.01 # 0.01 62 | MOMENTUM: 0.9 63 | BATCH_SIZE: 16 64 | NUM_WORKERS: 48 65 | 66 | # ------------EPOCHS------------------ 67 | WARMUP_EPOCHS: 10 68 | MAX_EPOCHS: 400 69 | FREQENCE: 1 70 | 71 | # ------------SycnBN---------------- 72 | SYNCBN: 0 73 | 74 | # -----------LOSSES ----------------- 75 | LOSS: "BCE+DICE" 76 | 77 | # -----------PRETRAIN -------------- 78 | PRETRAIN: False 79 | PRETRAIN_WEIGHTS: /data/jiangmingchao/data/AICKPT/Seg/U2Net/300_epoch_800_crop_0.3_1.2_cricle_2_from_27kpretrain_5e-4_new27k/checkpoints/best_ckpt_losses_0.17932851133602007_miou_0.9900042523340764.pth 80 | # PRETRAIN_WEIGHTS: /data/jiangmingchao/data/AICKPT/Seg/U2Net/300_epoch_320_crop_0.3_1.2_cricle_2_new27k/checkpoints/best_ckpt_losses_0.37233564257621765_miou_0.9840403229107935.pth 81 | # PRETRAIN_WEIGHTS: /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_300_epoch_320_crop_0.3_1.2_relu_1_bce_1_dice_cricle_2/checkpoints/best_ckpt_losses_0.36753059923648834_miou_0.9835136381246792.pth 82 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_800_relu_pretrain_17k/checkpoints/best_ckpt_losses_0.4262606395142419_miou_0.9889711563277688.pth" 83 | 84 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_320x320_relu/checkpoints/best_ckpt_losses_0.9227258222443717_miou_0.9785976140677624.pth" 85 | 86 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_800x800_15k_20epoch_add_green_RandomCropScale_0.5/checkpoints/best_ckpt_losses_0.43711744035993305_miou_0.9885654761160251.pth" 87 | 88 | 89 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_320x320_elu/checkpoints/best_ckpt_losses_0.9618568846157619_miou_0.9797444421563668.pth" 90 | 91 | 92 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_320x320_14k_data/checkpoints/best_ckpt_losses_1.1548381788390023_miou_0.9731329759628107.pth 93 | 94 | # best 95 | # 
"/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_800x800_14k_finetune_20epoch/checkpoints/best_ckpt_losses_0.5155016886336463_miou_0.9867006791631263.pth" 96 | 97 | # ------------Batch Aug -------------- 98 | BATCH_AUG: 99 | MIXUP: False 100 | 101 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_768x768_14k_data_320_pretrain_long_epoch/checkpoints/best_ckpt_losses_0.5636463910341263_miou_0.9844545931472927.pth 102 | 103 | 104 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ 105 | 106 | # /data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log 107 | 108 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768/checkpoints/best_ckpt_losses_0.6692883372306824_miou_0.9821306581457698.pth 109 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768_20k_data/checkpoints/best_ckpt_losses_1.0762560623032706_miou_0.9748090077617877.pth 110 | 111 | 112 | # ACCUMULATE 113 | ACCUMULATE: False 114 | ACCUMULATE_STEPS: 2 -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_pretrain_640.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [768, 768] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768_two_node/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "ADAMW" # ADAMW 46 | LEARNING_RATE: 0.001 # 0.001 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.01 # 0.01 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 8 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 60 57 | FREQENCE: 1 58 | 59 | SYNCBN: 1 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE" 63 | 64 | # -----------PRETRAIN -------------- 65 | PRETRAIN: True 66 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_1e-3/checkpoints/best_ckpt_losses_1.2692347168922424_miou_0.9704867602975709.pth" 67 
| 68 | 69 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_pretrain_768_colorjitter.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [768, 512] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_sgd_1e-2_768x512_color/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" # ADAMW 46 | LEARNING_RATE: 0.01 # 0.001 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 # 0.01 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 12 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 80 57 | FREQENCE: 1 58 | 59 | SYNCBN: 1 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE" 63 | 64 | # -----------PRETRAIN -------------- 65 | PRETRAIN: True 66 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_1e-3/checkpoints/best_ckpt_losses_1.2692347168922424_miou_0.9704867602975709.pth" 67 | 68 | 69 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_bce_dice_tv.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: 
"/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 1 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [320, 320] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_tv_refine_gt_adamw_1e-3/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "ADAMW" 46 | LEARNING_RATE: 0.001 # 0.00015 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.01 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 16 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 5 56 | MAX_EPOCHS: 120 57 | FREQENCE: 1 58 | 59 | SYNCBN: 1 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "BCE+DICE+TV" 63 | 64 | 65 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_ce_dice_l1_no_pretrain_640.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 2 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [640, 640] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_ce_dice_cl1_640_no_pretrain/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 12 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 150 57 | FREQENCE: 
1 58 | 59 | SYNCBN: 0 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "CE+DICE+L1" 63 | 64 | # -----------PRETRAIN -------------- 65 | PRETRAIN: False 66 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_cl1/checkpoints/best_ckpt_losses_2.156236001423427_miou_0.9645676702550512.pth" -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_ce_dice_l1_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "u2net" 11 | DATASET_TYPE: "CUSTOM" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ---------- CUSTOM DATASET ------------ 25 | CUSTOM_DATASET: 26 | TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 27 | VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 28 | NUM_CLASSES: 2 # used for celoss 29 | # ----------Transformes ---------------- 30 | DATA: 31 | CROP_SIZE: [640, 640] 32 | 33 | # TODO 34 | # ---------- Others DATASET-------------- 35 | # ---------- Custom DATASET ------------- 36 | 37 | # -----------SUMMARY---------------- 38 | SUMMARY: 39 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_ce_dice_cl1_640_pretrain/" 40 | CHECKPOINTS: "checkpoints" 41 | LOG: "log_dir" 42 | 43 | # -----------OPTIMIZER-------------- 44 | OPTIMIZER: 45 | OPTIM_NAME: "SGD" 46 | LEARNING_RATE: 0.01 47 | COSINE: 1 48 | FIX: 0 49 | WEIGHT_DECAY: 0.0001 50 | MOMENTUM: 0.9 51 | BATCH_SIZE: 12 52 | NUM_WORKERS: 32 53 | 54 | # ------------EPOCHS------------------ 55 | WARMUP_EPOCHS: 0 56 | MAX_EPOCHS: 150 57 | FREQENCE: 1 58 | 59 | SYNCBN: 0 60 | 61 | # -----------LOSSES ----------------- 62 | LOSS: "CE+DICE+L1" 63 | 64 | # -----------PRETRAIN -------------- 65 | PRETRAIN: True 66 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_cl1/checkpoints/best_ckpt_losses_2.156236001423427_miou_0.9645676702550512.pth" -------------------------------------------------------------------------------- /hyparam/U2Net/baseline_ce_pretrain_480_data.yaml: -------------------------------------------------------------------------------- 1 | # u2net baseline hyparameters 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | # MODEL_NAME: "u2net" 11 | MODEL_NAME: "u2net" 12 | DATASET_TYPE: "CUSTOM" 13 | # ------------VOC DATASET --------------- 14 | VOC_DATASET: 15 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 16 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 17 | NUM_CLASSES: 21 18 | 19 | # ------------VOC_AUG DATASET 
--------------- 20 | VOC_AUG_DATASET: 21 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 22 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 23 | NUM_CLASSES: 21 24 | 25 | # ---------- CUSTOM DATASET ------------ 26 | CUSTOM_DATASET: 27 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_train.log" 28 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log" 29 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_12k_filter.log" 30 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_14k_rm_same_folder.log" # best 31 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green/taobao_seg_14k_green_tryon_1k_shuf.txt" 32 | # TRAIN_FILE : "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green/taobao_14k_green_tryon_1k_filter_3w_2k_shuf.log" # NEW 17K 33 | 34 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/taobao_green_shein_shuf_27k.log" # NEW 27K 35 | # VAL_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_no_repeat_val.log" 36 | 37 | # TRAIN_FILE: "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/train_seg_shuf_25k.log" 38 | TRAIN_FILE : "/data/jiangmingchao/data/code/U-2-Net/makeData/taobao_green_shein/train_seg_27k_refine_hole_shuf.log" 39 | VAL_FILE : "/data/jiangmingchao/data/code/U-2-Net/makeData/seg_val_rm_hole.log" 40 | # NUM_CLASSES: 1 # used for bceloss 41 | NUM_CLASSES: 2 42 | # ----------Transformes ---------------- 43 | DATA: 44 | # CROP_SIZE: [480, 480] 45 | CROP_SIZE: [800, 800] 46 | 47 | # TODO 48 | # ---------- Others DATASET-------------- 49 | # ---------- Custom DATASET ------------- 50 | 51 | # -----------SUMMARY---------------- 52 | SUMMARY: 53 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_800_crop_0.3_1.2_1E-4_circle_2_27k_copypaste_sgd_ce" 54 | CHECKPOINTS: "checkpoints" 55 | LOG: "log_dir" 56 | 57 | # -----------OPTIMIZER-------------- 58 | OPTIMIZER: 59 | OPTIM_NAME: "SGD" # ADAMW 60 | LEARNING_RATE: 0.01 # 0.001 61 | COSINE: 1 62 | FIX: 0 63 | CRICLE: 0 64 | CRICLE_STEPS: 2 65 | WEIGHT_DECAY: 0.0001 # 0.01 66 | MOMENTUM: 0.9 67 | BATCH_SIZE: 8 68 | NUM_WORKERS: 48 69 | 70 | # ------------EPOCHS------------------ 71 | WARMUP_EPOCHS: 10 72 | MAX_EPOCHS: 200 73 | FREQENCE: 1 74 | 75 | # ------------SycnBN---------------- 76 | SYNCBN: 0 77 | 78 | # -----------LOSSES ----------------- 79 | # LOSS: "BCE+DICE" 80 | LOSS: "CE" 81 | 82 | # -----------PRETRAIN -------------- 83 | PRETRAIN: False 84 | PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_800_crop_0.3_1.2_1E-4_circle_2_27k_copypaste/checkpoints/best_ckpt_losses_0.1756678047989096_miou_0.9899707500586136.pth" 85 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_480_crop_0.3_1.2_5E-4_circle_2_25k_baseline_ce/checkpoints/best_ckpt_losses_0.08288952388933726_miou_0.9853206462283025.pth" 86 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_480_crop_0.3_1.2_5E-4_circle_2_25k_baseline/checkpoints/best_ckpt_losses_0.255808310849326_miou_0.9866068313419727.pth" 87 | 88 | 89 | 90 | 91 | # PRETRAIN_WEIGHTS: /data/jiangmingchao/data/AICKPT/Seg/U2Net/300_epoch_800_crop_0.3_1.2_cricle_2_from_27kpretrain_5e-4_new27k/checkpoints/best_ckpt_losses_0.17932851133602007_miou_0.9900042523340764.pth 92 | # PRETRAIN_WEIGHTS: 
/data/jiangmingchao/data/AICKPT/Seg/U2Net/300_epoch_320_crop_0.3_1.2_cricle_2_new27k/checkpoints/best_ckpt_losses_0.37233564257621765_miou_0.9840403229107935.pth 93 | # PRETRAIN_WEIGHTS: /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_300_epoch_320_crop_0.3_1.2_relu_1_bce_1_dice_cricle_2/checkpoints/best_ckpt_losses_0.36753059923648834_miou_0.9835136381246792.pth 94 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_800_relu_pretrain_17k/checkpoints/best_ckpt_losses_0.4262606395142419_miou_0.9889711563277688.pth" 95 | 96 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_320x320_relu/checkpoints/best_ckpt_losses_0.9227258222443717_miou_0.9785976140677624.pth" 97 | 98 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_800x800_15k_20epoch_add_green_RandomCropScale_0.5/checkpoints/best_ckpt_losses_0.43711744035993305_miou_0.9885654761160251.pth" 99 | 100 | 101 | # PRETRAIN_WEIGHTS: "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_320x320_elu/checkpoints/best_ckpt_losses_0.9618568846157619_miou_0.9797444421563668.pth" 102 | 103 | 104 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_320x320_14k_data/checkpoints/best_ckpt_losses_1.1548381788390023_miou_0.9731329759628107.pth 105 | 106 | # best 107 | # "/data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_800x800_14k_finetune_20epoch/checkpoints/best_ckpt_losses_0.5155016886336463_miou_0.9867006791631263.pth" 108 | 109 | # ------------Batch Aug -------------- 110 | BATCH_AUG: 111 | MIXUP: False 112 | 113 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adamw_real_768x768_14k_data_320_pretrain_long_epoch/checkpoints/best_ckpt_losses_0.5636463910341263_miou_0.9844545931472927.pth 114 | 115 | 116 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt/ 117 | 118 | # /data/jiangmingchao/data/code/U-2-Net/makeData/taobao_seg_20k.log 119 | 120 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768/checkpoints/best_ckpt_losses_0.6692883372306824_miou_0.9821306581457698.pth 121 | # /data/jiangmingchao/data/AICKPT/Seg/U2Net/baseline_bce_dice_refine_gt_adam_768x768_20k_data/checkpoints/best_ckpt_losses_1.0762560623032706_miou_0.9748090077617877.pth 122 | 123 | 124 | # ACCUMULATE 125 | ACCUMULATE: False 126 | ACCUMULATE_STEPS: 2 -------------------------------------------------------------------------------- /hyparam/UNet/unet_full.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "unet_full" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # 
----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/UNet/unet_full_gpux1_512" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.01 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 16 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/UNet/unet_full2.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "unet_full" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/UNet/unet_full_gpux1_512_sgd0.1" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.1 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/UNet/unet_full3.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "unet_full" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: 
"/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/UNet/unet_full_gpux1_512_sgd0.001" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.001 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/UNet/unet_full_adam.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "unet_full" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transformes ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/UNet/unet_full_gpux1_512_adam" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "ADAM" 41 | LEARNING_RATE: 0.001 42 | BETAS: [0.9, 0.999] 43 | EPS: 0.00000001 44 | COSINE: 0 45 | FIX: 1 46 | WEIGHT_DECAY: 0.0 47 | MOMENTUM: 0.9 48 | BATCH_SIZE: 8 49 | NUM_WORKERS: 32 50 | 51 | # ------------EPOCHS------------------ 52 | WARMUP_EPOCHS: 0 53 | MAX_EPOCHS: 150 54 | FREQENCE: 1 55 | 56 | SYNCBN: 0 57 | 58 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/UNet/unet_resnet50.yaml: -------------------------------------------------------------------------------- 1 | # base hyparams for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "unet_resnet50" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG 
DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transforms ---------------- 25 | DATA: 26 | CROP_SIZE: [512, 512] 27 | 28 | # TODO 29 | # ---------- Others DATASET-------------- 30 | # ---------- Custom DATASET ------------- 31 | 32 | # -----------SUMMARY---------------- 33 | SUMMARY: 34 | SAVE_PATH: "/data/jiangmingchao/data/AICKPT/Seg/UNet/unet_resnet50" 35 | CHECKPOINTS: "checkpoints" 36 | LOG: "log_dir" 37 | 38 | # -----------OPTIMIZER-------------- 39 | OPTIMIZER: 40 | OPTIM_NAME: "SGD" 41 | LEARNING_RATE: 0.01 42 | COSINE: 1 43 | FIX: 0 44 | WEIGHT_DECAY: 0.0001 45 | MOMENTUM: 0.9 46 | BATCH_SIZE: 8 47 | NUM_WORKERS: 32 48 | 49 | # ------------EPOCHS------------------ 50 | WARMUP_EPOCHS: 0 51 | MAX_EPOCHS: 150 52 | FREQENCE: 1 53 | 54 | SYNCBN: 0 55 | 56 | # -----------LOSSES ----------------- -------------------------------------------------------------------------------- /hyparam/base.yaml: -------------------------------------------------------------------------------- 1 | # base hyperparameters for training 2 | # ------------DDP ---------------- 3 | DIST: 4 | DISTRIBUTED: 1 5 | NGPU: 1 6 | RANK: -1 7 | DIST_BACKEND: "NCCL" 8 | LOCAL_RANK: -1 9 | 10 | MODEL_NAME: "FCN_8S" 11 | DATASET_TYPE: "VOC" 12 | # ------------VOC DATASET --------------- 13 | VOC_DATASET: 14 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_train.log" 15 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC/seg_val.log" 16 | NUM_CLASSES: 21 17 | 18 | # ------------VOC_AUG DATASET --------------- 19 | VOC_AUG_DATASET: 20 | TRAIN_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_train.log" 21 | VAL_FILE: "/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/VOC_AUG/seg_val.log" 22 | NUM_CLASSES: 21 23 | 24 | # ----------Transforms ---------------- 25 | DATA: 26 | CROP_SIZE: [320, 320] 27 | 28 | 29 | # TODO 30 | # ---------- Others DATASET-------------- 31 | # ---------- Custom DATASET ------------- 32 | 33 | # -----------SUMMARY---------------- 34 | SUMMARY: 35 | SAVE_PATH: "/data/jiangmingchao/data/AICutDataset/Segmentation/" 36 | CHECKPOINTS: "checkpoints" 37 | LOG: "log_dir" 38 | 39 | # -----------OPTIMIZER-------------- 40 | OPTIMIZER: 41 | OPTIM_NAME: "SGD" 42 | LEARNING_RATE: 1e-1 43 | COSINE: 1 44 | FIX: 0 45 | WEIGHT_DECAY: 1e-4 46 | MOMENTUM: 0.9 47 | BATCH_SIZE: 64 48 | NUM_WORKERS: 32 49 | 50 | # ------------EPOCHS------------------ 51 | WARMUP_EPOCHS: 5 52 | MAX_EPOCHS: 90 53 | FREQENCE: 1 54 | 55 | SYNCBN: 0
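All of the `hyparam/*.yaml` files above follow this `base.yaml` schema, and `main.py` (further below) turns them into a dot-accessible `opt` object via `config.config.parse_yaml` (e.g. `opt.DIST.NGPU`, `opt.OPTIMIZER.BATCH_SIZE`). A minimal sketch of that idea, assuming only PyYAML; the helper name `dict_to_namespace` is hypothetical and the repo's real `parse_yaml` may differ in details:

```python
import yaml
from types import SimpleNamespace

def dict_to_namespace(d):
    # recursively wrap nested dicts so opt.OPTIMIZER.BATCH_SIZE-style access works
    if isinstance(d, dict):
        return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})
    return d

with open("hyparam/base.yaml") as f:
    opt = dict_to_namespace(yaml.safe_load(f))

print(opt.MODEL_NAME)            # FCN_8S
print(opt.OPTIMIZER.BATCH_SIZE)  # 64
# caveat: YAML 1.1 loaders such as PyYAML read the bare "1e-1" above as a
# string, so LEARNING_RATE may need an explicit float() cast.
```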
-------------------------------------------------------------------------------- /inference_api.py: -------------------------------------------------------------------------------- 1 | """Inference 2 | @author: FlyEgle 3 | @datetime: 2022-01-26 4 | """ 5 | import os 6 | import cv2 7 | import json 8 | import torch 9 | import shutil 10 | import numpy as np 11 | import torch.nn.functional as F 12 | 13 | from models.model_factory import ModelFactory 14 | from utils.DataAugments import Normalize, Scale, ToTensor 15 | from torch.cuda.amp import autocast as autocast 16 | 17 | from tqdm import tqdm 18 | from PIL import Image 19 | 20 | 21 | def aug(images): 22 | images, _ = Scale((800, 800))(images, images) 23 | images, _ = Normalize(normalize=True)(images, images) 24 | images, _ = ToTensor()(images, images) 25 | return images 26 | 27 | 28 | def load_ckpt(net, model_ckpt): 29 | state_dict = torch.load(model_ckpt, map_location="cpu")['state_dict'] 30 | net.load_state_dict(state_dict) 31 | print(f"load the ckpt {model_ckpt}") 32 | return net 33 | 34 | 35 | class SegNet: 36 | def __init__(self, model_name, num_classes, weights): 37 | self.model_name = model_name 38 | self.num_classes = num_classes 39 | self.weights = weights 40 | 41 | # build model 42 | model_factory = ModelFactory() 43 | self.net = model_factory.getattr(model_name)(num_classes=self.num_classes) 44 | load_ckpt(self.net, self.weights) 45 | 46 | # cuda & eval 47 | if torch.cuda.is_available(): 48 | self.net.cuda() 49 | 50 | self.net.eval() 51 | 52 | @torch.no_grad() 53 | def infer(self, images): 54 | """images: np.ndarray RGB 55 | Returns: 56 | mask : 0/255 uint8 map 57 | mask_map: masked RGB image (mask_map_white is the white-background variant) 58 | """ 59 | src_h, src_w, c = images.shape 60 | img = aug(images) 61 | img.unsqueeze_(0) # chw->bchw 62 | 63 | with autocast(): 64 | img = img.cuda() 65 | outputs = self.net(img) # outputs is a list 66 | 67 | output = outputs[0] 68 | output = F.interpolate(output, size=[src_h, src_w], mode='bilinear', align_corners=True) 69 | # # matting 70 | output = torch.sigmoid(output) # b, 1, h, w 71 | output = output.cpu().numpy() 72 | 73 | # if the model was trained with ce loss, use argmax over softmax instead 74 | # output = torch.argmax(torch.softmax(output, dim=1), dim=1) 75 | # output = output.unsqueeze(1) 76 | # output = output.cpu().numpy() 77 | 78 | mask = output 79 | 80 | # the binarization threshold controls how much context is kept in the mask 81 | output[output >= 0.9] = 1 82 | output[output < 0.9] = 0 83 | 84 | # make uint8 mask 85 | # mask = output 86 | mask = output.astype(np.uint8) 87 | 88 | # apply the mask to the RGB image 89 | mask_map = np.zeros((mask.shape[2], mask.shape[3], 3)) 90 | mask_map[:,:,0] = mask[0,0,:,:] * images[:,:,0] 91 | mask_map[:,:,1] = mask[0,0,:,:] * images[:,:,1] 92 | mask_map[:,:,2] = mask[0,0,:,:] * images[:,:,2] 93 | 94 | mask_map_white = mask_map.copy() 95 | mask_map_white[mask_map_white==0] = 255 96 | 97 | return mask[0,0,:,:]*255, mask_map, mask_map_white 98 | 99 | 100 | def load_data(path): 101 | if os.path.isdir(path): 102 | data_list = [os.path.join(path, x) for x in os.listdir(path)] 103 | else: 104 | data_list = [x.strip() for x in open(path).readlines()] 105 | if "image_path" in data_list[0]: 106 | data_list = [json.loads(x)["image_path"] for x in data_list] 107 | 108 | return data_list 109 | 110 | 111 | def concat_image(image1, image2): 112 | w, h = image1.size 113 | new = Image.new("RGB", (w*2, h), 255) 114 | new.paste(image1, (0, 0)) 115 | new.paste(image2, (w, 0)) 116 | new_img = new.resize((w // 4, h // 4)) 117 | return new_img 118 | 119 | 120 |
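# --- Added sketch (not part of the original file) ----------------------------
# SegNet.infer above hard-codes the 0.9 binarization cut-off. A hypothetical
# helper that exposes it; lower thresholds keep more context around the
# object, higher thresholds give tighter, cleaner silhouettes.
def binarize_mask(prob_map, threshold=0.9):
    """prob_map: sigmoid scores in [0, 1] -> uint8 {0, 1} mask."""
    return (prob_map >= threshold).astype(np.uint8)
# e.g. mask = binarize_mask(output, threshold=0.5) for a more permissive mask
# ------------------------------------------------------------------------------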
121 | if __name__ == "__main__": 122 | weights = "/data/jiangmingchao/data/AICKPT/Seg/U2Net/400_epoch_800_crop_0.3_1.2_1E-4_circle_2_27k_copypaste_sgd/checkpoints/best_ckpt_epoch_133_losses_0.21228863086019242_miou_0.9878880708590272.pth" 123 | model_name = "u2net" 124 | num_classes = 1 125 | model = SegNet(model_name, num_classes, weights) 126 | 127 | # -----------------------single image inference------------------------------------- 128 | # test_image_path = "/data/jiangmingchao/data/code/SegmentationLight/111.png" 129 | 130 | # image = cv2.imread(test_image_path) 131 | # img = Image.fromarray(image) 132 | # # print(img.mode) 133 | # # image = image[:,:,:] 134 | 135 | # if image is not None: 136 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 137 | # mask, mask_map, mask_map_w = model.infer(image) 138 | # # merge = cv2.addWeighted(image, 0.5, mask, 0.5, 0) 139 | # color_mask = mask.copy() 140 | # color_mask = np.concatenate( 141 | # (np.expand_dims(color_mask, axis=-1), 142 | # np.expand_dims(color_mask, axis=-1), 143 | # np.expand_dims(color_mask, axis=-1), 144 | # ), -1) 145 | # # print(color_mask[color_mask==255]) 146 | # color_mask[:,:,0][color_mask[:,:,0]==255] = 128 147 | # color_mask[:,:,1][color_mask[:,:,1]==255] = 0 148 | # color_mask[:,:,2][color_mask[:,:,2]==255] = 128 149 | 150 | # merge = 0.5 * image[:,:,::-1] + 0.5 * color_mask 151 | # merge = merge.astype(np.uint8) 152 | 153 | # cv2.imwrite("/data/jiangmingchao/data/code/SegmentationLight/tmp/mask.png", mask, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 154 | # cv2.imwrite("/data/jiangmingchao/data/code/SegmentationLight/tmp/mask_map.png", mask_map[:,:,::-1], [cv2.IMWRITE_PNG_COMPRESSION, 0]) 155 | # cv2.imwrite("/data/jiangmingchao/data/code/SegmentationLight/tmp/mask_map_w.png", mask_map_w[:,:,::-1], [cv2.IMWRITE_PNG_COMPRESSION, 0]) 156 | # cv2.imwrite("/data/jiangmingchao/data/code/SegmentationLight/tmp/merge.png", merge, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 157 | # else: 158 | # raise IOError(f"{test_image_path} does not exist!") 159 | 160 | # ------------------------ folder inference ------------------------------------------- 161 | test_folder = "/data/jiangmingchao/data/code/cluster/shein_2k_1w.log" 162 | # test_image_list = [os.path.join(test_folder, x) for x in os.listdir(test_folder)] 163 | 164 | out_folder = "/data/jiangmingchao/data/dataset/shein_2k_1w_720_patch/mask" 165 | 166 | if not os.path.exists(out_folder): 167 | os.makedirs(out_folder) 168 | else: 169 | shutil.rmtree(out_folder) 170 | os.makedirs(out_folder) 171 | 172 | test_image_list = load_data(test_folder) 173 | for img_path in tqdm(test_image_list): 174 | img = cv2.imread(img_path) 175 | if img is not None: 176 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 177 | mask, mask_map, _ = model.infer(img) 178 | img_name = img_path.split('/')[-1].split('.')[0]+'.png' 179 | cv2.imwrite(os.path.join(out_folder, img_name), mask, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 180 | else: 181 | raise IOError(f"{img_path} does not exist!") 182 | 183 | 184 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/losses/__init__.py -------------------------------------------------------------------------------- /losses/generatorLoss.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : FlyEgle 3 | @describe : generator loss functions for generation tasks 4 | @datetime : 2022-04-07 5 | """ 6 | import torch 7 | import warnings 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | def _fspecial_gauss_1d(size, sigma): 14 | r"""Create 1-D gauss kernel 15 | Args: 16 | size (int): the size of gauss kernel 17 | sigma (float): sigma of normal distribution 18 | Returns: 19 | torch.Tensor: 1D kernel (1 x 1 x size) 20 | """ 21 | coords = torch.arange(size).to(dtype=torch.float) 22 | coords -= size // 2 23 | 24 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 25 | g /= g.sum() 26 | 27 | return g.unsqueeze(0).unsqueeze(0) 28 | 29 |
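# --- Added note (not part of the original file) -------------------------------
# The kernel above is deliberately 1-D: a Gaussian is separable, so
# gaussian_filter below blurs H and W with two 1-D convolutions (O(k) per
# pixel each) instead of a single k x k convolution (O(k^2)). A quick sanity
# check of the kernel itself, e.g.:
#
#   win = _fspecial_gauss_1d(11, 1.5)          # shape (1, 1, 11)
#   assert abs(win.sum().item() - 1.0) < 1e-6  # normalized to sum to 1
# -------------------------------------------------------------------------------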
30 | def gaussian_filter(input, win): 31 | r""" Blur input with 1-D kernel 32 | Args: 33 | input (torch.Tensor): a batch of tensors to be blurred 34 | win (torch.Tensor): 1-D gauss kernel 35 | Returns: 36 | torch.Tensor: blurred tensors 37 | """ 38 | assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape 39 | if len(input.shape) == 4: 40 | conv = F.conv2d 41 | elif len(input.shape) == 5: 42 | conv = F.conv3d 43 | else: 44 | raise NotImplementedError(input.shape) 45 | 46 | C = input.shape[1] 47 | out = input 48 | for i, s in enumerate(input.shape[2:]): 49 | if s >= win.shape[-1]: 50 | out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C) 51 | else: 52 | warnings.warn( 53 | f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}" 54 | ) 55 | 56 | return out 57 | 58 | 59 | def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)): 60 | r""" Calculate ssim index for X and Y 61 | Args: 62 | X (torch.Tensor): images 63 | Y (torch.Tensor): images 64 | win (torch.Tensor): 1-D gauss kernel 65 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 66 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 67 | Returns: 68 | torch.Tensor: ssim results. 69 | """ 70 | K1, K2 = K 71 | # batch, channel, [depth,] height, width = X.shape 72 | compensation = 1.0 73 | 74 | C1 = (K1 * data_range) ** 2 75 | C2 = (K2 * data_range) ** 2 76 | 77 | win = win.to(X.device, dtype=X.dtype) 78 | 79 | mu1 = gaussian_filter(X, win) 80 | mu2 = gaussian_filter(Y, win) 81 | 82 | mu1_sq = mu1.pow(2) 83 | mu2_sq = mu2.pow(2) 84 | mu1_mu2 = mu1 * mu2 85 | 86 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) 87 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) 88 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) 89 | 90 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 91 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 92 | 93 | ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1) 94 | cs = torch.flatten(cs_map, 2).mean(-1) 95 | return ssim_per_channel, cs 96 | 97 |
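# --- Added note (not part of the original file) -------------------------------
# _ssim above is the standard SSIM index (Wang et al., 2004) with
# alpha = beta = gamma = 1:
#
#   SSIM(x, y) = [(2*mu_x*mu_y + C1) * (2*sigma_xy + C2)] /
#                [(mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2)]
#
# with C1 = (K1 * data_range)^2 and C2 = (K2 * data_range)^2. In the code,
# cs_map is the second (contrast/structure) factor and ssim_map the full
# product; both are then averaged over the spatial dimensions per channel.
# -------------------------------------------------------------------------------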
119 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu 120 | Returns: 121 | torch.Tensor: ssim results 122 | """ 123 | if not X.shape == Y.shape: 124 | raise ValueError("Input images should have the same dimensions.") 125 | 126 | for d in range(len(X.shape) - 1, 1, -1): 127 | X = X.squeeze(dim=d) 128 | Y = Y.squeeze(dim=d) 129 | 130 | if len(X.shape) not in (4, 5): 131 | raise ValueError( 132 | f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 133 | 134 | if not X.type() == Y.type(): 135 | raise ValueError("Input images should have the same dtype.") 136 | 137 | if win is not None: # set win_size 138 | win_size = win.shape[-1] 139 | 140 | if not (win_size % 2 == 1): 141 | raise ValueError("Window size should be odd.") 142 | 143 | if win is None: 144 | win = _fspecial_gauss_1d(win_size, win_sigma) 145 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 146 | 147 | ssim_per_channel, cs = _ssim( 148 | X, Y, data_range=data_range, win=win, size_average=False, K=K) 149 | if nonnegative_ssim: 150 | ssim_per_channel = torch.relu(ssim_per_channel) 151 | 152 | if size_average: 153 | return ssim_per_channel.mean() 154 | else: 155 | return ssim_per_channel.mean(1) 156 | 157 | # ------------------SSIM Loss---------------------- 158 | class SSIM(nn.Module): 159 | def __init__( 160 | self, 161 | data_range=255, 162 | size_average=True, 163 | win_size=11, 164 | win_sigma=1.5, 165 | channel=3, 166 | spatial_dims=2, 167 | K=(0.01, 0.03), 168 | nonnegative_ssim=False, 169 | ): 170 | r""" class for ssim 171 | Args: 172 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 173 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 174 | win_size: (int, optional): the size of gauss kernel 175 | win_sigma: (float, optional): sigma of normal distribution 176 | channel (int, optional): input channels (default: 3) 177 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 178 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu. 
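        Example (illustrative usage, not from the repo; assumes batched float images in [0, 255]):
            >>> ssim_module = SSIM(data_range=255, channel=3)
            >>> X = torch.rand(4, 3, 256, 256) * 255
            >>> Y = X * 0.5 + torch.rand(4, 3, 256, 256) * 127
            >>> loss = 1 - ssim_module(X, Y)  # SSIM is a similarity, so 1 - SSIM acts as a loss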
179 | """ 180 | 181 | super(SSIM, self).__init__() 182 | self.win_size = win_size 183 | self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat( 184 | [channel, 1] + [1] * spatial_dims) 185 | self.size_average = size_average 186 | self.data_range = data_range 187 | self.K = K 188 | self.nonnegative_ssim = nonnegative_ssim 189 | 190 | def forward(self, X, Y): 191 | return ssim( 192 | X, 193 | Y, 194 | data_range=self.data_range, 195 | size_average=self.size_average, 196 | win=self.win, 197 | K=self.K, 198 | nonnegative_ssim=self.nonnegative_ssim, 199 | ) 200 | 201 | 202 | # ------------------------- tvloss ------------------------ 203 | class TVLoss(nn.Module): 204 | def __init__(self, tv_loss_weight=1): 205 | super(TVLoss, self).__init__() 206 | self.tv_loss_weight = tv_loss_weight 207 | 208 | def forward(self, x): 209 | batch_size = x.size()[0] 210 | h_x = x.size()[2] 211 | w_x = x.size()[3] 212 | count_h = self.tensor_size(x[:, :, 1:, :]) 213 | count_w = self.tensor_size(x[:, :, :, 1:]) 214 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 215 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 216 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 217 | 218 | @staticmethod 219 | def tensor_size(t): 220 | return t.size()[1] * t.size()[2] * t.size()[3] 221 | 222 | 223 | 224 | if __name__ == '__main__': 225 | inputs = torch.randn(1, 1, 224, 224) 226 | targets = torch.empty(1,1,224,224).random_(2) 227 | 228 | loss = TVLoss() 229 | print(loss(inputs)) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Traing model 2 | @author: FlyEgle 3 | @datetime: 2022-01-20 4 | """ 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | 8 | import os 9 | import math 10 | import torch 11 | import numpy as np 12 | import torch.nn as nn 13 | 14 | import torch.distributed as dist 15 | 16 | from torch.cuda.amp import autocast as autocast 17 | from torch.nn.parallel import DistributedDataParallel as DistParallel 18 | from torch.utils.data import DistributedSampler, DataLoader 19 | 20 | # model 21 | from config.config import build_argparse, parse_yaml 22 | # Optimizer 23 | from utils.Optim import BuildOptim 24 | # Metric 25 | from utils.Metirc import SegmentationMetric 26 | # loss 27 | from utils.Loss import LossBar 28 | # DataSet 29 | from datasets.voc2012.dataset import VocSemanticSegDataSet, build_transformers, build_val_transformers 30 | # Model 31 | from models.model_factory import ModelFactory 32 | # Complex Aug 33 | from utils.FuseAugments import MixUP, MixCriterion 34 | # Train 35 | from train import train_one_epoch, val_one_epoch 36 | # Summary 37 | from utils.Summary import LoggerRecord, LoggerInfo, CkptRecord 38 | # Utils function 39 | from utils.utils import Load_state_dict 40 | 41 | 42 | # random init seed 43 | def random_init(seed=100): 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | np.random.seed(seed) 47 | torch.backends.cudnn.benchmark = True 48 | torch.backends.cudnn.deterministic = False 49 | 50 | 51 | def translate_state_dict(state_dict): 52 | new_state_dict = {} 53 | for key, value in state_dict.items(): 54 | if 'module' in key: 55 | new_state_dict[key[7:]] = value 56 | else: 57 | new_state_dict[key] = value 58 | return new_state_dict 59 | 60 | 61 | def main_worker(args, opt): 62 | total_rank = opt.DIST.NGPU * torch.cuda.device_count() 63 | print('rank: {} / 
{}'.format(args.local_rank, total_rank)) 64 | dist.init_process_group(backend=opt.DIST.DIST_BACKEND) 65 | torch.cuda.set_device(args.local_rank) 66 | 67 | ngpus_per_node = total_rank 68 | 69 | if opt.DATASET_TYPE.lower() == "voc": 70 | NUM_CLASSES = opt.VOC_DATASET.NUM_CLASSES 71 | TRAIN_DATA = opt.VOC_DATASET.TRAIN_FILE 72 | VAL_DATA = opt.VOC_DATASET.VAL_FILE 73 | elif opt.DATASET_TYPE.lower() == "voc_aug": 74 | NUM_CLASSES = opt.VOC_DATASET.NUM_CLASSES 75 | TRAIN_DATA = opt.VOC_DATASET.TRAIN_FILE 76 | VAL_DATA = opt.VOC_DATASET.VAL_FILE 77 | elif opt.DATASET_TYPE.lower() == "custom": 78 | NUM_CLASSES = opt.CUSTOM_DATASET.NUM_CLASSES 79 | TRAIN_DATA = opt.CUSTOM_DATASET.TRAIN_FILE 80 | VAL_DATA = opt.CUSTOM_DATASET.VAL_FILE 81 | 82 | # training metric 83 | train_metric = SegmentationMetric("Train", NUM_CLASSES) 84 | val_metric = SegmentationMetric("Val", NUM_CLASSES) 85 | 86 | # model 87 | model_factory = ModelFactory() 88 | net = model_factory.getattr(opt.MODEL_NAME)(num_classes=NUM_CLASSES) 89 | if opt.PRETRAIN: 90 | state = Load_state_dict(opt.PRETRAIN_WEIGHTS, net) 91 | net.load_state_dict(state) 92 | 93 | print("Load the pretrain from real domain dataset!!!") 94 | 95 | # resume 96 | if opt.RESUME: 97 | ckpt = torch.load(opt.RESUME_CHECKPOINTS, map_location="cpu") 98 | resume_start_epoch = ckpt["epoch"] 99 | optim_state_dict = ckpt["optimizer"] 100 | 101 | state = Load_state_dict(opt.RESUME_CHECKPOINTS, net) 102 | net.load_state_dict(state) 103 | print("Load the resume checkpoints for follow training!!!") 104 | 105 | 106 | if args.local_rank == 0: 107 | print(f"===============model arch ===============") 108 | print(net) 109 | 110 | if torch.cuda.is_available(): 111 | net.cuda(args.local_rank) 112 | 113 | # build loss function 114 | criterion = LossBar(opt.LOSS.lower())() 115 | if "mutil" in opt.MODEL_NAME.lower(): 116 | aux_criterion = nn.L1Loss() 117 | else: 118 | aux_criterion = None 119 | 120 | # mixup 121 | if opt.BATCH_AUG.MIXUP: 122 | mixup = MixUP(alpha=1.0, cuda=torch.cuda.is_available()) 123 | mix_criterion = MixCriterion(criterion) 124 | else: 125 | mixup = None 126 | mix_criterion = criterion 127 | 128 | # build Optim 129 | optim = BuildOptim( 130 | opt.OPTIMIZER.OPTIM_NAME, 131 | opt.OPTIMIZER.LEARNING_RATE, 132 | opt.OPTIMIZER.WEIGHT_DECAY, 133 | opt.OPTIMIZER.MOMENTUM 134 | )(net.parameters()) 135 | 136 | if opt.DIST.DISTRIBUTED and opt.SYNCBN: 137 | net = nn.SyncBatchNorm.convert_sync_batchnorm(net) 138 | 139 | if opt.DIST.DISTRIBUTED: 140 | net = DistParallel(net, device_ids=[args.local_rank], find_unused_parameters=False) 141 | # dataset & dataloader 142 | TrainDataset = VocSemanticSegDataSet( 143 | TRAIN_DATA, 144 | transformers=build_transformers(opt.DATA.CROP_SIZE), 145 | # transformers=build_val_transformers(opt.DATA.CROP_SIZE), # used for the fixres 146 | train_phase=True 147 | ) 148 | ValidationDataset = VocSemanticSegDataSet( 149 | VAL_DATA, 150 | transformers=build_val_transformers(opt.DATA.CROP_SIZE), 151 | train_phase=False 152 | ) 153 | 154 | if args.local_rank == 0: 155 | print("Training Dataset length: ", len(TrainDataset)) 156 | print("Validation Dataset length: ", len(ValidationDataset)) 157 | 158 | if opt.DIST.DISTRIBUTED: 159 | TrainSampler = DistributedSampler(TrainDataset) 160 | ValidationSampler = DistributedSampler(ValidationDataset) 161 | else: 162 | TrainSampler = None 163 | ValidationSampler = None 164 | 165 | # dataloader 166 | TrainLoader = DataLoader( 167 | dataset = TrainDataset, 168 | batch_size = opt.OPTIMIZER.BATCH_SIZE, 169 | shuffle 
= (TrainSampler is None), 170 | num_workers = opt.OPTIMIZER.NUM_WORKERS, 171 | pin_memory = True, 172 | sampler = TrainSampler, 173 | drop_last = True 174 | ) 175 | ValidationLoader = DataLoader( 176 | dataset = ValidationDataset, 177 | batch_size = opt.OPTIMIZER.BATCH_SIZE, 178 | shuffle = (ValidationSampler is None), 179 | num_workers = opt.OPTIMIZER.NUM_WORKERS, 180 | pin_memory = True, 181 | sampler = ValidationSampler, 182 | drop_last = False 183 | ) 184 | 185 | # log & ckpt 186 | if args.local_rank == 0: 187 | logger_writter = LoggerRecord(os.path.join(opt.SUMMARY.SAVE_PATH, opt.SUMMARY.LOG)) # log 188 | logger_info = LoggerInfo(os.path.join(opt.SUMMARY.SAVE_PATH, opt.SUMMARY.LOG)) # logger 189 | ckpt_saver = CkptRecord(os.path.join(opt.SUMMARY.SAVE_PATH, opt.SUMMARY.CHECKPOINTS)) # ckpt 190 | else: 191 | logger_writter = None 192 | logger_info = None 193 | ckpt_saver = None 194 | 195 | # train_batch = math.ceil(len(TrainLoader) / (args.batch_size * ngpus_per_node)) 196 | train_batch = len(TrainLoader) 197 | total_batch = train_batch * opt.MAX_EPOCHS 198 | print("train_batch: ", train_batch) 199 | 200 | # training params 201 | if opt.RESUME: 202 | start_epoch = resume_start_epoch 203 | optim.load_state_dict(optim_state_dict) 204 | batch_iter = train_batch * start_epoch 205 | else: 206 | start_epoch = 1 207 | batch_iter = 0 208 | 209 | 210 | val_batch = math.ceil(len(ValidationDataset) / (opt.OPTIMIZER.BATCH_SIZE * ngpus_per_node)) 211 | 212 | scaler = torch.cuda.amp.GradScaler() 213 | # training loop 214 | for epoch in range(start_epoch, opt.MAX_EPOCHS + 1): 215 | if opt.DIST.DISTRIBUTED: 216 | TrainSampler.set_epoch(epoch) 217 | # train 218 | batch_iter, scaler = train_one_epoch( 219 | args, opt, scaler, net, TrainLoader, 220 | mixup, mix_criterion, aux_criterion, 221 | optim, epoch, batch_iter, total_batch, 222 | train_batch, logger_writter, logger_info, 223 | train_metric 224 | ) 225 | # val 226 | if epoch % opt.FREQENCE == 0: 227 | val_losses, val_pa, val_mpa, val_miou, val_fwiou = val_one_epoch( 228 | args, opt, ValidationLoader, net, 229 | criterion, epoch, val_batch, 230 | logger_writter, logger_info, 231 | val_metric) 232 | # save ckpt 233 | if args.local_rank == 0: 234 | model_state = translate_state_dict(net.state_dict()) 235 | state_dict = { 236 | 'epoch': epoch, 237 | 'state_dict': model_state, 238 | 'optimizer': optim.state_dict() 239 | } 240 | ckpt_saver.SaveBestCkpt(state_dict, epoch, val_losses, val_miou) 241 | 242 | net.train() 243 | 244 | 245 | if __name__ == "__main__": 246 | args = build_argparse() 247 | 248 | opt = parse_yaml(args.hyp) 249 | print(opt) 250 | random_init() 251 | 252 | main_worker(args, opt) -------------------------------------------------------------------------------- /models/DeepLab/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/DeepLab/__init__.py -------------------------------------------------------------------------------- /models/FCN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/FCN/__init__.py -------------------------------------------------------------------------------- /models/FCN/fcn_mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates a MobileNetV2 Model as 
defined in: 3 | Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. (2018). 4 | MobileNetV2: Inverted Residuals and Linear Bottlenecks 5 | arXiv preprint arXiv:1801.04381. 6 | import from https://github.com/tonylins/pytorch-mobilenet-v2 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | 13 | __all__ = ['mobilenetv2'] 14 | 15 | model_url = '/data/jiangmingchao/data/code/SegmentationLight/models/FCN/pretrained/mobilenetv2-c5e733a8.pth' 16 | 17 | def _make_divisible(v, divisor, min_value=None): 18 | """ 19 | This function is taken from the original tf repo. 20 | It ensures that all layers have a channel number that is divisible by 8 21 | It can be seen here: 22 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 23 | :param v: 24 | :param divisor: 25 | :param min_value: 26 | :return: 27 | """ 28 | if min_value is None: 29 | min_value = divisor 30 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 31 | # Make sure that round down does not go down by more than 10%. 32 | if new_v < 0.9 * v: 33 | new_v += divisor 34 | return new_v 35 | 36 | 37 | def conv_3x3_bn(inp, oup, stride): 38 | return nn.Sequential( 39 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 40 | nn.BatchNorm2d(oup), 41 | nn.ReLU6(inplace=True) 42 | ) 43 | 44 | 45 | def conv_1x1_bn(inp, oup): 46 | return nn.Sequential( 47 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 48 | nn.BatchNorm2d(oup), 49 | nn.ReLU6(inplace=True) 50 | ) 51 | 52 | 53 | class InvertedResidual(nn.Module): 54 | def __init__(self, inp, oup, stride, expand_ratio): 55 | super(InvertedResidual, self).__init__() 56 | assert stride in [1, 2] 57 | 58 | hidden_dim = round(inp * expand_ratio) 59 | self.identity = stride == 1 and inp == oup 60 | 61 | if expand_ratio == 1: 62 | self.conv = nn.Sequential( 63 | # dw 64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 65 | nn.BatchNorm2d(hidden_dim), 66 | nn.ReLU6(inplace=True), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | nn.BatchNorm2d(oup), 70 | ) 71 | else: 72 | self.conv = nn.Sequential( 73 | # pw 74 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 75 | nn.BatchNorm2d(hidden_dim), 76 | nn.ReLU6(inplace=True), 77 | # dw 78 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 79 | nn.BatchNorm2d(hidden_dim), 80 | nn.ReLU6(inplace=True), 81 | # pw-linear 82 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 83 | nn.BatchNorm2d(oup), 84 | ) 85 | 86 | def forward(self, x): 87 | if self.identity: 88 | return x + self.conv(x) 89 | else: 90 | return self.conv(x) 91 | 92 | 93 | class MobileNetV2(nn.Module): 94 | def __init__(self, num_classes=1000, width_mult=1., pretrained=True): 95 | super(MobileNetV2, self).__init__() 96 | # setting of inverted residual blocks 97 | self.cfgs = [ 98 | # t, c, n, s 99 | [1, 16, 1, 1], 100 | [6, 24, 2, 2], 101 | [6, 32, 3, 2], 102 | [6, 64, 4, 2], 103 | [6, 96, 3, 1], 104 | [6, 160, 3, 2], 105 | [6, 320, 1, 1], 106 | ] 107 | 108 | # building first layer 109 | input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8) 110 | layers = [conv_3x3_bn(3, input_channel, 2)] 111 | # building inverted residual blocks 112 | block = InvertedResidual 113 | for t, c, n, s in self.cfgs: 114 | output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8) 115 | for i in range(n): 116 | layers.append(block(input_channel, output_channel, s if 
i == 0 else 1, t)) 117 | input_channel = output_channel 118 | self.features = nn.Sequential(*layers) 119 | # building last several layers 120 | output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280 121 | self.conv = conv_1x1_bn(input_channel, output_channel) 122 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 123 | self.classifier = nn.Linear(output_channel, num_classes) 124 | 125 | self._initialize_weights() 126 | if pretrained: 127 | self._load_weights() 128 | 129 | def forward(self, x): 130 | features = [] 131 | for idx, feats in enumerate(self.features): 132 | x = feats(x) 133 | if idx == 6 or idx == 13: 134 | features.append(x) 135 | # x = self.features(x) 136 | x = self.conv(x) 137 | features.append(x) 138 | # x = self.avgpool(x) 139 | # x = x.view(x.size(0), -1) 140 | # x = self.classifier(x) 141 | return features[0], features[1], features[2] 142 | 143 | def _initialize_weights(self): 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 147 | m.weight.data.normal_(0, math.sqrt(2. / n)) 148 | if m.bias is not None: 149 | m.bias.data.zero_() 150 | elif isinstance(m, nn.BatchNorm2d): 151 | m.weight.data.fill_(1) 152 | m.bias.data.zero_() 153 | elif isinstance(m, nn.Linear): 154 | m.weight.data.normal_(0, 0.01) 155 | m.bias.data.zero_() 156 | 157 | def _load_weights(self): 158 | state_dict = torch.load(model_url, map_location="cpu") 159 | self.load_state_dict(state_dict) 160 | print("Load imagenet pretrain!!!") 161 | 162 | 163 | class FCNMobilenetv2_8S(nn.Module): 164 | def __init__(self, num_classes=21, pretrained=True): 165 | super(FCNMobilenetv2_8S, self).__init__() 166 | self.NUM_CLASSES = num_classes 167 | self.backbone = MobileNetV2(pretrained=pretrained) 168 | 169 | self.smooth_conv1 = nn.Conv2d(1280, 96, 3, 1, 1) 170 | self.relu1 = nn.ReLU(inplace=True) 171 | self.bn1 = nn.BatchNorm2d(96) 172 | 173 | self.smooth_conv2 = nn.Conv2d(96, 32, 3, 1, 1) 174 | self.relu2 = nn.ReLU(inplace=True) 175 | self.bn2 = nn.BatchNorm2d(32) 176 | 177 | self.smooth = nn.Conv2d(32, 32, 3, 1, 1) 178 | self.relu3 = nn.ReLU(inplace=True) 179 | self.bn3 = nn.BatchNorm2d(32) 180 | 181 | self.classification = nn.Conv2d(32, self.NUM_CLASSES, 1, 1, 0) 182 | 183 | def forward(self, x): 184 | b, c, h, w = x.shape 185 | p2, p3, p4 = self.backbone(x) 186 | 187 | p4 = self.relu1(self.bn1(self.smooth_conv1(p4))) 188 | out2 = F.interpolate(p4, size=(p3.shape[2], p3.shape[3]), mode='bilinear', align_corners=True) 189 | p3 = out2 + p3 190 | out3 = self.relu2(self.bn2(self.smooth_conv2(p3))) 191 | out4 = F.interpolate(out3, size=(p2.shape[2], p2.shape[3]), mode='bilinear', align_corners=True) 192 | p2 = out4 + p2 193 | 194 | p2 = self.relu3(self.bn3(self.smooth(p2))) 195 | out = F.interpolate(p2, size=(h, w), mode='bilinear', align_corners=True) 196 | out = self.classification(out) 197 | return out 198 | 199 | 200 | if __name__ == '__main__': 201 | net = FCNMobilenetv2_8S(21) 202 | inputs = torch.randn(1, 3, 512, 512) 203 | print(net) 204 | outputs = net(inputs) 205 | print(outputs.shape) 206 | -------------------------------------------------------------------------------- /models/FCN/fcn_mobilenetv3.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV3 in PyTorch. 2 | See the paper "Inverted Residuals and Linear Bottlenecks: 3 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
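Note: the title quoted above is from the MobileNetV2 paper (arXiv:1801.04381);
MobileNetV3 itself is introduced in "Searching for MobileNetV3" (arXiv:1905.02244).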
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import init 9 | 10 | __all__ = ['MobileNetv3'] 11 | 12 | model_url = "/data/jiangmingchao/data/code/SegmentationLight/models/FCN/pretrained/mbv3_large.pth.tar" 13 | 14 | 15 | class hswish(nn.Module): 16 | def forward(self, x): 17 | out = x * F.relu6(x + 3, inplace=True) / 6 18 | return out 19 | 20 | 21 | class hsigmoid(nn.Module): 22 | def forward(self, x): 23 | out = F.relu6(x + 3, inplace=True) / 6 24 | return out 25 | 26 | 27 | class SeModule(nn.Module): 28 | def __init__(self, in_size, reduction=4): 29 | super(SeModule, self).__init__() 30 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 31 | self.se = nn.Sequential( 32 | nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False), 33 | nn.BatchNorm2d(in_size // reduction), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False), 36 | nn.BatchNorm2d(in_size), 37 | hsigmoid() 38 | ) 39 | 40 | def forward(self, x): 41 | return x * self.se(x) 42 | 43 | 44 | class Block(nn.Module): 45 | '''expand + depthwise + pointwise''' 46 | def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride): 47 | super(Block, self).__init__() 48 | self.stride = stride 49 | self.se = semodule 50 | 51 | self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False) 52 | self.bn1 = nn.BatchNorm2d(expand_size) 53 | self.nolinear1 = nolinear 54 | self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False) 55 | self.bn2 = nn.BatchNorm2d(expand_size) 56 | self.nolinear2 = nolinear 57 | self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False) 58 | self.bn3 = nn.BatchNorm2d(out_size) 59 | 60 | self.shortcut = nn.Sequential() 61 | if stride == 1 and in_size != out_size: 62 | self.shortcut = nn.Sequential( 63 | nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False), 64 | nn.BatchNorm2d(out_size), 65 | ) 66 | 67 | def forward(self, x): 68 | out = self.nolinear1(self.bn1(self.conv1(x))) 69 | out = self.nolinear2(self.bn2(self.conv2(out))) 70 | out = self.bn3(self.conv3(out)) 71 | if self.se != None: 72 | out = self.se(out) 73 | out = out + self.shortcut(x) if self.stride==1 else out 74 | return out 75 | 76 | 77 | class MobileNetV3(nn.Module): 78 | def __init__(self, num_classes=1000, pretrained=True): 79 | super(MobileNetV3, self).__init__() 80 | self.pretrained = pretrained 81 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(16) 83 | self.hs1 = hswish() 84 | 85 | self.bneck = nn.Sequential( 86 | Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1), # 0 87 | Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2), # 1 88 | Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1), # 2 89 | Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2), # 3 90 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1), # 4 91 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1), # 5 92 | Block(3, 40, 240, 80, hswish(), None, 2), # 6 93 | Block(3, 80, 200, 80, hswish(), None, 1), # 7 94 | Block(3, 80, 184, 80, hswish(), None, 1), # 8 95 | Block(3, 80, 184, 80, hswish(), None, 1), # 9 96 | Block(3, 80, 480, 112, hswish(), SeModule(112), 1), # 10 97 | Block(3, 112, 672, 112, hswish(), SeModule(112), 1), # 11 98 | 
Block(5, 112, 672, 160, hswish(), SeModule(160), 1), # 12 99 | Block(5, 160, 672, 160, hswish(), SeModule(160), 2), # 13 100 | Block(5, 160, 960, 160, hswish(), SeModule(160), 1), # 14 101 | ) 102 | 103 | self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False) 104 | self.bn2 = nn.BatchNorm2d(960) 105 | self.hs2 = hswish() 106 | self.linear3 = nn.Linear(960, 1280) 107 | self.bn3 = nn.BatchNorm1d(1280) 108 | self.hs3 = hswish() 109 | self.linear4 = nn.Linear(1280, num_classes) 110 | self.init_params() 111 | 112 | if self.pretrained: 113 | self._load_weights() 114 | 115 | def init_params(self): 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | init.kaiming_normal_(m.weight, mode='fan_out') 119 | if m.bias is not None: 120 | init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | init.constant_(m.weight, 1) 123 | init.constant_(m.bias, 0) 124 | elif isinstance(m, nn.Linear): 125 | init.normal_(m.weight, std=0.001) 126 | if m.bias is not None: 127 | init.constant_(m.bias, 0) 128 | 129 | def _load_weights(self): 130 | state_dict = {} 131 | state = torch.load(model_url, map_location="cpu")['state_dict'] 132 | for s in state: 133 | state_dict[s[7:]] = state[s] 134 | 135 | self.load_state_dict(state_dict) 136 | print("Load the imagenet Pretrain!!!") 137 | 138 | def forward(self, x): 139 | out = self.hs1(self.bn1(self.conv1(x))) 140 | features = [] 141 | for idx, block in enumerate(self.bneck): 142 | out = block(out) 143 | if idx == 1 or idx == 3 or idx == 6: 144 | features.append(out) 145 | 146 | out = self.hs2(self.bn2(self.conv2(out))) 147 | features.append(out) 148 | # out = F.adaptive_avg_pool2d(out, 1) 149 | # out = out.view(out.size(0), -1) 150 | # out = self.hs3(self.bn3(self.linear3(out))) 151 | # out = self.linear4(out) 152 | return features[0], features[1], features[2], features[3] 153 | 154 | 155 | class FCNMobileNetv3_8S(nn.Module): 156 | def __init__(self, num_classes=21, pretrained=True): 157 | super(FCNMobileNetv3_8S, self).__init__() 158 | self.backbone = MobileNetV3(pretrained=pretrained) 159 | 160 | self.NUM_CLASSES = num_classes 161 | 162 | self.smooth_conv1 = nn.Conv2d(960, 80, 3, 1, 1) 163 | self.relu1 = nn.ReLU(inplace=True) 164 | self.bn1 = nn.BatchNorm2d(80) 165 | 166 | self.smooth_conv2 = nn.Conv2d(80, 40, 3, 1, 1) 167 | self.relu2 = nn.ReLU(inplace=True) 168 | self.bn2 = nn.BatchNorm2d(40) 169 | 170 | self.smooth = nn.Conv2d(40, 40, 3, 1, 1) 171 | self.relu3 = nn.ReLU(inplace=True) 172 | self.bn3 = nn.BatchNorm2d(40) 173 | 174 | self.classification = nn.Conv2d(40, self.NUM_CLASSES, 1, 1, 0) 175 | 176 | 177 | def forward(self, x): 178 | b, c, h, w = x.shape 179 | p1, p2, p3, p4 = self.backbone(x) 180 | 181 | p4 = self.relu1(self.bn1(self.smooth_conv1(p4))) 182 | out2 = F.interpolate(p4, size=(p3.shape[2], p3.shape[3]), mode='bilinear', align_corners=True) 183 | p3 = out2 + p3 184 | out3 = self.relu2(self.bn2(self.smooth_conv2(p3))) 185 | out4 = F.interpolate(out3, size=(p2.shape[2], p2.shape[3]), mode='bilinear', align_corners=True) 186 | p2 = out4 + p2 187 | 188 | p2 = self.relu3(self.bn3(self.smooth(p2))) 189 | out = F.interpolate(p2, size=(h, w), mode='bilinear', align_corners=True) 190 | out = self.classification(out) 191 | return out 192 | 193 | 194 | if __name__ == '__main__': 195 | net = FCNMobileNetv3_8S() 196 | inputs = torch.randn(1, 3, 512, 512) 197 | outputs = net(inputs) 198 | print(outputs.shape) -------------------------------------------------------------------------------- 
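A minimal smoke test for the two FCN-8s heads above (a sketch, not part of the repo; it assumes the hard-coded `model_url` checkpoints are unavailable, so `pretrained=False` skips loading them):

import torch

from models.FCN.fcn_mobilenetv2 import FCNMobilenetv2_8S
from models.FCN.fcn_mobilenetv3 import FCNMobileNetv3_8S

# Build each head without the local ImageNet weights and check the output shapes.
for Net in (FCNMobilenetv2_8S, FCNMobileNetv3_8S):
    net = Net(num_classes=21, pretrained=False).eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 512, 512))  # -> (1, 21, 512, 512)
        mask = logits.argmax(dim=1)                # per-pixel class ids, (1, 512, 512)
    print(Net.__name__, tuple(logits.shape), tuple(mask.shape))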
/models/FCN/pretrained/mbv3_large.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/FCN/pretrained/mbv3_large.pth.tar -------------------------------------------------------------------------------- /models/HRNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/HRNet/__init__.py -------------------------------------------------------------------------------- /models/HRNet/config/default.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Copyright (c) Microsoft 4 | # Licensed under the MIT License. 5 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | 14 | from yacs.config import CfgNode as CN 15 | 16 | 17 | _C = CN() 18 | 19 | _C.OUTPUT_DIR = '' 20 | _C.LOG_DIR = '' 21 | _C.GPUS = (0,) 22 | _C.WORKERS = 4 23 | _C.PRINT_FREQ = 20 24 | _C.AUTO_RESUME = False 25 | _C.PIN_MEMORY = True 26 | _C.RANK = 0 27 | 28 | # Cudnn related params 29 | _C.CUDNN = CN() 30 | _C.CUDNN.BENCHMARK = True 31 | _C.CUDNN.DETERMINISTIC = False 32 | _C.CUDNN.ENABLED = True 33 | 34 | # common params for NETWORK 35 | _C.MODEL = CN() 36 | _C.MODEL.NAME = 'seg_hrnet' 37 | _C.MODEL.PRETRAINED = '' 38 | _C.MODEL.ALIGN_CORNERS = True 39 | _C.MODEL.NUM_OUTPUTS = 1 40 | _C.MODEL.EXTRA = CN(new_allowed=True) 41 | 42 | 43 | _C.MODEL.OCR = CN() 44 | _C.MODEL.OCR.MID_CHANNELS = 512 45 | _C.MODEL.OCR.KEY_CHANNELS = 256 46 | _C.MODEL.OCR.DROPOUT = 0.05 47 | _C.MODEL.OCR.SCALE = 1 48 | 49 | _C.LOSS = CN() 50 | _C.LOSS.USE_OHEM = False 51 | _C.LOSS.OHEMTHRES = 0.9 52 | _C.LOSS.OHEMKEEP = 100000 53 | _C.LOSS.CLASS_BALANCE = False 54 | _C.LOSS.BALANCE_WEIGHTS = [1] 55 | 56 | # DATASET related params 57 | _C.DATASET = CN() 58 | _C.DATASET.ROOT = '' 59 | _C.DATASET.DATASET = 'cityscapes' 60 | _C.DATASET.NUM_CLASSES = 19 61 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst' 62 | _C.DATASET.EXTRA_TRAIN_SET = '' 63 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst' 64 | 65 | # training 66 | _C.TRAIN = CN() 67 | 68 | _C.TRAIN.FREEZE_LAYERS = '' 69 | _C.TRAIN.FREEZE_EPOCHS = -1 70 | _C.TRAIN.NONBACKBONE_KEYWORDS = [] 71 | _C.TRAIN.NONBACKBONE_MULT = 10 72 | 73 | _C.TRAIN.IMAGE_SIZE = [1024, 512] # width * height 74 | _C.TRAIN.BASE_SIZE = 2048 75 | _C.TRAIN.DOWNSAMPLERATE = 1 76 | _C.TRAIN.FLIP = True 77 | _C.TRAIN.MULTI_SCALE = True 78 | _C.TRAIN.SCALE_FACTOR = 16 79 | 80 | _C.TRAIN.RANDOM_BRIGHTNESS = False 81 | _C.TRAIN.RANDOM_BRIGHTNESS_SHIFT_VALUE = 10 82 | 83 | _C.TRAIN.LR_FACTOR = 0.1 84 | _C.TRAIN.LR_STEP = [90, 110] 85 | _C.TRAIN.LR = 0.01 86 | _C.TRAIN.EXTRA_LR = 0.001 87 | 88 | _C.TRAIN.OPTIMIZER = 'sgd' 89 | _C.TRAIN.MOMENTUM = 0.9 90 | _C.TRAIN.WD = 0.0001 91 | _C.TRAIN.NESTEROV = False 92 | _C.TRAIN.IGNORE_LABEL = -1 93 | 94 | _C.TRAIN.BEGIN_EPOCH = 0 95 | _C.TRAIN.END_EPOCH = 484 96 | _C.TRAIN.EXTRA_EPOCH = 0 97 | 98 | _C.TRAIN.RESUME = False 99 | 100 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 101 | _C.TRAIN.SHUFFLE = True 102 | # only using some training samples 103 | 
_C.TRAIN.NUM_SAMPLES = 0 104 | 105 | # testing 106 | _C.TEST = CN() 107 | 108 | _C.TEST.IMAGE_SIZE = [2048, 1024] # width * height 109 | _C.TEST.BASE_SIZE = 2048 110 | 111 | _C.TEST.BATCH_SIZE_PER_GPU = 32 112 | # only testing some samples 113 | _C.TEST.NUM_SAMPLES = 0 114 | 115 | _C.TEST.MODEL_FILE = '' 116 | _C.TEST.FLIP_TEST = False 117 | _C.TEST.MULTI_SCALE = False 118 | _C.TEST.SCALE_LIST = [1] 119 | 120 | _C.TEST.OUTPUT_INDEX = -1 121 | 122 | # debug 123 | _C.DEBUG = CN() 124 | _C.DEBUG.DEBUG = False 125 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 126 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 127 | _C.DEBUG.SAVE_HEATMAPS_GT = False 128 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 129 | 130 | 131 | def update_config(cfg, args): 132 | cfg.defrost() 133 | 134 | cfg.merge_from_file(args) 135 | cfg.freeze() 136 | 137 | 138 | -------------------------------------------------------------------------------- /models/HRNet/config/seg_hrnet_w48.yaml: -------------------------------------------------------------------------------- 1 | CUDNN: 2 | BENCHMARK: true 3 | DETERMINISTIC: false 4 | ENABLED: true 5 | GPUS: (0,1,2,3,4,5,6,7) 6 | OUTPUT_DIR: 'output' 7 | LOG_DIR: 'log' 8 | WORKERS: 8 9 | PRINT_FREQ: 10 10 | 11 | DATASET: 12 | DATASET: ade20k 13 | ROOT: '../../../../dataset/ade20k/' 14 | TEST_SET: 'val.lst' 15 | TRAIN_SET: 'train.lst' 16 | NUM_CLASSES: 1 17 | MODEL: 18 | NAME: seg_hrnet 19 | NUM_OUTPUTS: 1 20 | PRETRAINED: '../../../../dataset/pretrained_models/HRNet_W48_C_ssld_pretrained.pth' 21 | EXTRA: 22 | FINAL_CONV_KERNEL: 1 23 | STAGE1: 24 | NUM_MODULES: 1 25 | NUM_RANCHES: 1 26 | BLOCK: BOTTLENECK 27 | NUM_BLOCKS: 28 | - 4 29 | NUM_CHANNELS: 30 | - 64 31 | FUSE_METHOD: SUM 32 | STAGE2: 33 | NUM_MODULES: 1 34 | NUM_BRANCHES: 2 35 | BLOCK: BASIC 36 | NUM_BLOCKS: 37 | - 4 38 | - 4 39 | NUM_CHANNELS: 40 | - 48 41 | - 96 42 | FUSE_METHOD: SUM 43 | STAGE3: 44 | NUM_MODULES: 4 45 | NUM_BRANCHES: 3 46 | BLOCK: BASIC 47 | NUM_BLOCKS: 48 | - 4 49 | - 4 50 | - 4 51 | NUM_CHANNELS: 52 | - 48 53 | - 96 54 | - 192 55 | FUSE_METHOD: SUM 56 | STAGE4: 57 | NUM_MODULES: 3 58 | NUM_BRANCHES: 4 59 | BLOCK: BASIC 60 | NUM_BLOCKS: 61 | - 4 62 | - 4 63 | - 4 64 | - 4 65 | NUM_CHANNELS: 66 | - 48 67 | - 96 68 | - 192 69 | - 384 70 | FUSE_METHOD: SUM 71 | LOSS: 72 | USE_OHEM: true 73 | OHEMTHRES: 0.9 74 | OHEMKEEP: 131072 75 | TRAIN: 76 | IMAGE_SIZE: 77 | - 520 78 | - 520 79 | BASE_SIZE: 520 80 | BATCH_SIZE_PER_GPU: 2 81 | SHUFFLE: true 82 | BEGIN_EPOCH: 0 83 | END_EPOCH: 120 84 | RESUME: true 85 | OPTIMIZER: sgd 86 | LR: 0.02 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: false 90 | FLIP: true 91 | MULTI_SCALE: true 92 | DOWNSAMPLERATE: 1 93 | IGNORE_LABEL: 255 94 | SCALE_FACTOR: 11 95 | TEST: 96 | IMAGE_SIZE: 97 | - 520 98 | - 520 99 | BASE_SIZE: 520 100 | BATCH_SIZE_PER_GPU: 1 101 | NUM_SAMPLES: 200 102 | FLIP_TEST: false 103 | MULTI_SCALE: false 104 | -------------------------------------------------------------------------------- /models/SegNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/SegNet/__init__.py -------------------------------------------------------------------------------- /models/SegNet/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class InvertedResidual(nn.Module): 7 | def __init__(self, inp, oup, stride, 
expand_ratio): 8 | super(InvertedResidual, self).__init__() 9 | assert stride in [1, 2] 10 | 11 | hidden_dim = round(inp * expand_ratio) 12 | self.identity = stride == 1 and inp == oup 13 | 14 | if expand_ratio == 1: 15 | self.conv = nn.Sequential( 16 | # dw 17 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 18 | nn.BatchNorm2d(hidden_dim), 19 | nn.ReLU6(inplace=True), 20 | # pw-linear 21 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 22 | nn.BatchNorm2d(oup), 23 | ) 24 | else: 25 | self.conv = nn.Sequential( 26 | # pw 27 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 28 | nn.BatchNorm2d(hidden_dim), 29 | nn.ReLU6(inplace=True), 30 | # dw 31 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 32 | nn.BatchNorm2d(hidden_dim), 33 | nn.ReLU6(inplace=True), 34 | # pw-linear 35 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | ) 38 | 39 | def forward(self, x): 40 | if self.identity: 41 | return x + self.conv(x) 42 | else: 43 | return self.conv(x) 44 | 45 | 46 | if __name__ == '__main__': 47 | net = InvertedResidual(320, 320, 1, 1) 48 | print(net) 49 | -------------------------------------------------------------------------------- /models/SegNet/seg_mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | SegNet with mobilenetv2 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | __all__ = ['mobilenetv2'] 10 | 11 | model_url = '/data/jiangmingchao/data/code/SegmentationLight/models/FCN/pretrained/mobilenetv2-c5e733a8.pth' 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 
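    # Worked examples (illustrative values, not from the repo): _make_divisible(30, 8)
    # rounds 30 up to 32, while _make_divisible(18, 8) first rounds down to 16
    # (more than a 10% drop), so the guard below bumps it to 24.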
28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | def conv_3x3_bn(inp, oup, stride): 34 | return nn.Sequential( 35 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.ReLU6(inplace=True) 38 | ) 39 | 40 | 41 | def conv_1x1_bn(inp, oup): 42 | return nn.Sequential( 43 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | nn.ReLU6(inplace=True) 46 | ) 47 | 48 | 49 | class InvertedResidual(nn.Module): 50 | def __init__(self, inp, oup, stride, expand_ratio): 51 | super(InvertedResidual, self).__init__() 52 | assert stride in [1, 2] 53 | 54 | hidden_dim = round(inp * expand_ratio) 55 | self.identity = stride == 1 and inp == oup 56 | 57 | if expand_ratio == 1: 58 | self.conv = nn.Sequential( 59 | # dw 60 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 61 | nn.BatchNorm2d(hidden_dim), 62 | nn.ReLU6(inplace=True), 63 | # pw-linear 64 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 65 | nn.BatchNorm2d(oup), 66 | ) 67 | else: 68 | self.conv = nn.Sequential( 69 | # pw 70 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 71 | nn.BatchNorm2d(hidden_dim), 72 | nn.ReLU6(inplace=True), 73 | # dw 74 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 75 | nn.BatchNorm2d(hidden_dim), 76 | nn.ReLU6(inplace=True), 77 | # pw-linear 78 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 79 | nn.BatchNorm2d(oup), 80 | ) 81 | 82 | def forward(self, x): 83 | if self.identity: 84 | return x + self.conv(x) 85 | else: 86 | return self.conv(x) 87 | 88 | 89 | class MobileNetV2(nn.Module): 90 | def __init__(self, num_classes=1000, width_mult=1., pretrained=True): 91 | super(MobileNetV2, self).__init__() 92 | # setting of inverted residual blocks 93 | self.cfgs = [ 94 | # t, c, n, s 95 | [1, 16, 1, 1], 96 | [6, 24, 2, 2], 97 | [6, 32, 3, 2], 98 | [6, 64, 4, 2], 99 | [6, 96, 3, 1], 100 | [6, 160, 3, 2], 101 | [6, 320, 1, 1], 102 | ] 103 | 104 | # building first layer 105 | input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8) 106 | layers = [conv_3x3_bn(3, input_channel, 2)] 107 | # building inverted residual blocks 108 | block = InvertedResidual 109 | for t, c, n, s in self.cfgs: 110 | output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8) 111 | for i in range(n): 112 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t)) 113 | input_channel = output_channel 114 | self.features = nn.Sequential(*layers) 115 | # building last several layers 116 | output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280 117 | self.conv = conv_1x1_bn(input_channel, output_channel) 118 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 119 | self.classifier = nn.Linear(output_channel, num_classes) 120 | 121 | self._initialize_weights() 122 | if pretrained: 123 | self._load_weights() 124 | 125 | def forward(self, x): 126 | features = [] 127 | for idx, feats in enumerate(self.features): 128 | x = feats(x) 129 | if idx==3 or idx == 6 or idx == 13 or idx == 17: 130 | features.append(x) 131 | # x = self.features(x) 132 | # x = self.conv(x) 133 | # features.append(x) 134 | return features[0], features[1], features[2], features[3] 135 | 136 | def _initialize_weights(self): 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 141 | if m.bias is not None: 142 | m.bias.data.zero_() 143 | elif isinstance(m, nn.BatchNorm2d): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | elif isinstance(m, nn.Linear): 147 | m.weight.data.normal_(0, 0.01) 148 | m.bias.data.zero_() 149 | 150 | def _load_weights(self): 151 | state_dict = torch.load(model_url, map_location="cpu") 152 | self.load_state_dict(state_dict) 153 | print("Load imagenet pretrain!!!") 154 | 155 | 156 | class SegNetMobilenetV2(nn.Module): 157 | def __init__(self, num_classes=21, pretrained=True): 158 | super(SegNetMobilenetV2, self).__init__() 159 | self.NUM_CLASSES = num_classes 160 | self.backbone = MobileNetV2(pretrained=pretrained) 161 | 162 | # d-block 320 -> 96 163 | self.d_block1 = nn.Sequential( 164 | InvertedResidual(320, 960, 1, 1), 165 | InvertedResidual(960, 320, 1, 1), 166 | InvertedResidual(320, 96, 1, 1) 167 | ) 168 | self.d_block2 = nn.Sequential( 169 | InvertedResidual(96, 320, 1, 1), 170 | InvertedResidual(320, 96, 1, 1), 171 | InvertedResidual(96, 64, 1, 1) 172 | ) 173 | self.d_block3 = nn.Sequential( 174 | InvertedResidual(64, 96, 1, 1), 175 | InvertedResidual(96, 64, 1, 1), 176 | InvertedResidual(64, 32, 1, 1) 177 | ) 178 | self.d_block4 = nn.Sequential( 179 | InvertedResidual(32, 64, 1, 1), 180 | InvertedResidual(64, 32, 1, 1), 181 | InvertedResidual(32, 24, 1, 1) 182 | ) 183 | self.d_block5 = nn.Sequential( 184 | InvertedResidual(24, 32, 1, 1), 185 | InvertedResidual(32, 24, 1, 1) 186 | ) 187 | self.classification = nn.Conv2d(24, self.NUM_CLASSES, 1, 1, 0) 188 | 189 | def forward(self, x): 190 | b, c, h, w = x.shape 191 | _, _, _, p4 = self.backbone(x) 192 | 193 | # up + block 194 | x = F.interpolate(p4, scale_factor=2.0, mode='bilinear', align_corners=True) 195 | x = self.d_block1(x) 196 | 197 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 198 | x = self.d_block2(x) 199 | 200 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 201 | x = self.d_block3(x) 202 | 203 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 204 | x = self.d_block4(x) 205 | 206 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 207 | x = self.d_block5(x) 208 | 209 | out = self.classification(x) 210 | return out 211 | 212 | 213 | if __name__ == '__main__': 214 | net = SegNetMobilenetV2(21) 215 | inputs = torch.randn(1, 3, 512, 512) 216 | # print(net) 217 | outputs = net(inputs) 218 | print(outputs.shape) 219 | -------------------------------------------------------------------------------- /models/SegNet/seg_vgg16.py: -------------------------------------------------------------------------------- 1 | """ 2 | SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 3 | @author: FlyEgle 4 | @datetime: 2022-01-15 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.nn.functional as F 10 | 11 | 12 | VGG_CKPT = "/data/jiangmingchao/data/AICKPT/r50_losses_1.0856279492378236.pth" 13 | 14 | 15 | class ConvBnRelu(nn.Module): 16 | def __init__(self, kernel_size, in_channels, out_channels, stride=1, padding=1, bn=True): 17 | super(ConvBnRelu, self).__init__() 18 | self.kernel_size = kernel_size 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.padding = padding 22 | self.stride = stride 23 | self.BN = bn 24 | 25 | self.conv = nn.Conv2d( 26 | in_channels=self.in_channels, 27 | out_channels=self.out_channels, 28 | kernel_size=self.kernel_size, 29 | 
padding=self.padding, 30 | stride=self.stride, 31 | bias=False) 32 | self.relu = nn.ReLU(inplace=True) 33 | if self.BN: 34 | self.bn = nn.BatchNorm2d(self.out_channels) 35 | 36 | def forward(self, x): 37 | if self.BN: 38 | x = self.relu(self.bn(self.conv(x))) 39 | else: 40 | x = self.relu(self.conv(x)) 41 | 42 | return x 43 | 44 | class ConvBn(nn.Module): 45 | def __init__(self, kernel_size, in_channels, out_channels, stride=1, padding=1, bn=True): 46 | super(ConvBn, self).__init__() 47 | self.kernel_size = kernel_size 48 | self.in_channels = in_channels 49 | self.out_channels = out_channels 50 | self.stride = stride 51 | self.padding = padding 52 | self.BN = bn 53 | 54 | self.conv = nn.Conv2d( 55 | in_channels=self.in_channels, 56 | out_channels=self.out_channels, 57 | kernel_size=self.kernel_size, 58 | stride=self.stride, 59 | padding=self.padding, 60 | bias=False 61 | ) 62 | if self.BN: 63 | self.bn = nn.BatchNorm2d(self.out_channels) 64 | 65 | def forward(self, x): 66 | if self.BN: 67 | x = self.bn(self.conv(x)) 68 | else: 69 | x = self.conv(x) 70 | return x 71 | 72 | 73 | class Block(nn.Module): 74 | def __init__(self, layer_name, layer_num, in_channels, out_channels, block_id=1): 75 | super(Block, self).__init__() 76 | self.layer = layer_num 77 | self.in_channels = in_channels 78 | self.out_channels = out_channels 79 | self.layer_list = [] 80 | 81 | for _ in range(self.layer): 82 | if layer_name.lower() == "convbnrelu": 83 | self.layer_list.append(ConvBnRelu( 84 | 3, self.in_channels, self.out_channels 85 | )) 86 | elif layer_name.lower() == "convbn": 87 | self.layer_list.append(ConvBn( 88 | 3, self.in_channels, self.out_channels 89 | )) 90 | 91 | self.in_channels = self.out_channels 92 | 93 | self.block = nn.Sequential(*self.layer_list) 94 | 95 | def forward(self, x): 96 | return self.block(x) 97 | 98 | 99 | class SegNetVgg16(nn.Module): 100 | def __init__(self, num_classes): 101 | super(SegNetVgg16, self).__init__() 102 | 103 | self.num_classes = num_classes 104 | # Encoder Blocks 105 | self.block1 = Block("convbnrelu", 2, 3, 64) 106 | self.block2 = Block("convbnrelu", 2, 64, 128) 107 | self.block3 = Block("convbnrelu", 3, 128, 256) 108 | self.block4 = Block("convbnrelu", 3, 256, 512) 109 | self.block5 = Block("convbnrelu", 3, 512, 512) 110 | 111 | # Encoder Poolings 112 | self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True) 113 | self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True) 114 | self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True) 115 | self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True) 116 | self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True) 117 | 118 | # smooth 119 | # self.smoothblock = Block("convbnrelu", 3, 512, 512) 120 | 121 | # Decoder 122 | self.d_block1 = Block("convbnrelu", 3, 512, 512) 123 | self.d_block2 = Block("convbnrelu", 3, 512, 256) 124 | self.d_block3 = Block("convbnrelu", 3, 256, 128) 125 | self.d_block4 = Block("convbnrelu", 2, 128, 64) 126 | self.d_block5 = nn.Sequential( 127 | ConvBnRelu(3, 64, 64, 1, 1), 128 | ConvBnRelu(3, 64, 64, 1, 1), 129 | ) 130 | 131 | self.classification = nn.Conv2d(64, self.num_classes, 1, 1) 132 | 133 | # Decoding Poolings 134 | self.d_pool1 = nn.MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2)) 135 | self.d_pool2 = nn.MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2)) 136 | self.d_pool3 = nn.MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2)) 137 | self.d_pool4 = 
nn.MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2)) 138 | self.d_pool5 = nn.MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2)) 139 | 140 | # self.d_pool1 = nn.Upsample(scale_factor=2.0, mode='bilinear') 141 | # self.d_pool2 = nn.Upsample(scale_factor=2.0, mode='bilinear') 142 | # self.d_pool3 = nn.Upsample(scale_factor=2.0, mode='bilinear') 143 | # self.d_pool4 = nn.Upsample(scale_factor=2.0, mode='bilinear') 144 | # self.d_pool5 = nn.Upsample(scale_factor=2.0, mode='bilinear') 145 | 146 | self._initialize_weights() 147 | self._load_pretrain(VGG_CKPT) 148 | 149 | def forward(self, x): 150 | # encoder 151 | x = self.block1(x) 152 | x, p1_indices = self.pool1(x) 153 | x = self.block2(x) 154 | x, p2_indices = self.pool2(x) 155 | x = self.block3(x) 156 | x, p3_indices = self.pool3(x) 157 | x = self.block4(x) 158 | x, p4_indices = self.pool4(x) 159 | x = self.block5(x) 160 | x, p5_indices = self.pool5(x) 161 | 162 | # x = self.smoothblock(x) 163 | 164 | # decoder 165 | b, c, h, w = x.shape 166 | x = self.d_pool5(x, p5_indices, output_size=torch.Size([b, c, h*2, w*2])) 167 | # x = self.d_pool5(x) 168 | x = self.d_block1(x) 169 | 170 | b, c, h, w = x.shape 171 | x = self.d_pool4(x, p4_indices, output_size=torch.Size([b, c, h*2, w*2])) 172 | # x = self.d_pool4(x) 173 | x = self.d_block2(x) 174 | 175 | b, c, h, w = x.shape 176 | x = self.d_pool3(x, p3_indices, output_size=torch.Size([b, c, h*2, w*2])) 177 | # x = self.d_pool3(x) 178 | x = self.d_block3(x) 179 | 180 | b, c, h, w = x.shape 181 | x = self.d_pool2(x, p2_indices, output_size=torch.Size([b, c, h*2, w*2])) 182 | # x = self.d_pool2(x) 183 | x = self.d_block4(x) 184 | 185 | b, c, h, w= x.shape 186 | x = self.d_pool1(x, p1_indices, output_size=torch.Size([b, c, h*2, w*2])) 187 | # x = self.d_pool1(x) 188 | x = self.d_block5(x) 189 | 190 | outputs = self.classification(x) 191 | return outputs 192 | 193 | def _initialize_weights(self) -> None: 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 197 | if m.bias is not None: 198 | nn.init.constant_(m.bias, 0) 199 | elif isinstance(m, nn.BatchNorm2d): 200 | nn.init.constant_(m.weight, 1) 201 | nn.init.constant_(m.bias, 0) 202 | 203 | def _load_pretrain(self, VGG_CKPT): 204 | state_dict = torch.load(VGG_CKPT, map_location="cpu")['state_dict'] 205 | model_state_dict = self.state_dict() 206 | pretrain_state = {} 207 | for key, value in model_state_dict.items(): 208 | if key in state_dict and value.shape == state_dict[key].shape: 209 | pretrain_state[key] = state_dict[key] 210 | 211 | model_state_dict.update(pretrain_state) 212 | self.load_state_dict(model_state_dict) 213 | print("Load vgg imagenet pretrain!!!") 214 | 215 | 216 | if __name__ == "__main__": 217 | inputs = torch.randn(1, 3, 512, 512) 218 | model = SegNetVgg16(21) 219 | print(model) 220 | # for k, v in model.state_dict().items(): 221 | # print(k) 222 | # print(model.state_dict().keys()) 223 | # print(model) 224 | outputs = model(inputs) 225 | print(outputs.shape) 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /models/U2Net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/U2Net/__init__.py -------------------------------------------------------------------------------- /models/U2Net/u2net_bp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | __all__ = ['U2NET_full', 'U2NET_lite'] 7 | 8 | 9 | def _upsample_like(x, size): 10 | return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x) 11 | 12 | 13 | def _size_map(x, height): 14 | # {height: size} for Upsample 15 | size = list(x.shape[-2:]) 16 | sizes = {} 17 | for h in range(1, height): 18 | sizes[h] = size 19 | size = [math.ceil(w / 2) for w in size] 20 | return sizes 21 | 22 | 23 | class REBNCONV(nn.Module): 24 | def __init__(self, in_ch=3, out_ch=3, dilate=1): 25 | super(REBNCONV, self).__init__() 26 | 27 | self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate) 28 | self.bn_s1 = nn.BatchNorm2d(out_ch) 29 | self.relu_s1 = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | return self.relu_s1(self.bn_s1(self.conv_s1(x))) 33 | 34 | 35 | class RSU(nn.Module): 36 | def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False): 37 | super(RSU, self).__init__() 38 | self.name = name 39 | self.height = height 40 | self.dilated = dilated 41 | self._make_layers(height, in_ch, mid_ch, out_ch, dilated) 42 | 43 | def forward(self, x): 44 | sizes = _size_map(x, self.height) 45 | x = self.rebnconvin(x) 46 | 47 | # U-Net like symmetric encoder-decoder structure 48 | def unet(x, height=1): 49 | if height < self.height: 50 | x1 = getattr(self, f'rebnconv{height}')(x) 51 | if not self.dilated and height < self.height - 1: 52 | x2 = unet(getattr(self, 'downsample')(x1), height + 1) 53 | else: 54 | x2 = unet(x1, height + 1) 55 | 56 | x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1)) 57 | return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x 58 | else: 59 | return getattr(self, f'rebnconv{height}')(x) 60 | 61 | return x + unet(x) 62 | 63 | def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False): 64 | self.add_module('rebnconvin', REBNCONV(in_ch, out_ch)) 65 | self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True)) 66 | 67 | self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch)) 68 | self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch)) 69 | 70 | for i in range(2, height): 71 | dilate = 1 if not dilated else 2 ** (i - 1) 72 | self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate)) 73 | self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate)) 74 | 75 | dilate = 2 if not dilated else 2 ** (height - 1) 76 | self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate)) 77 | 78 | 79 | class U2NET(nn.Module): 80 | def __init__(self, cfgs, out_ch): 81 | super(U2NET, self).__init__() 82 | self.out_ch = out_ch 83 | self._make_layers(cfgs) 84 | 85 | def forward(self, x): 86 | sizes = _size_map(x, self.height) 87 | maps = [] # storage for maps 88 | 89 | # side saliency map 90 | def unet(x, height=1): 91 | if height < 6: 92 | x1 = getattr(self, f'stage{height}')(x) 93 | x2 = unet(getattr(self, 'downsample')(x1), height + 1) 94 | x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1)) 95 | side(x, height) 96 | return _upsample_like(x, sizes[height - 1]) if height > 1 else x 97 | else: 98 | x = getattr(self, f'stage{height}')(x) 99 | side(x, height) 100 | return _upsample_like(x, sizes[height - 1]) 101 | 102 | def side(x, h): 103 | # side output saliency map (before sigmoid) 104 | x = getattr(self, f'side{h}')(x) 105 | x = _upsample_like(x, sizes[1]) 106 | 
maps.append(x) 107 | 108 | def fuse(): 109 | # fuse saliency probability maps 110 | maps.reverse() 111 | x = torch.cat(maps, 1) 112 | x = getattr(self, 'outconv')(x) 113 | maps.insert(0, x) 114 | return [torch.sigmoid(x) for x in maps] 115 | 116 | unet(x) 117 | maps = fuse() 118 | return maps 119 | 120 | def _make_layers(self, cfgs): 121 | self.height = int((len(cfgs) + 1) / 2) 122 | self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True)) 123 | for k, v in cfgs.items(): 124 | # build rsu block 125 | self.add_module(k, RSU(v[0], *v[1])) 126 | if v[2] > 0: 127 | # build side layer 128 | self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1)) 129 | # build fuse layer 130 | self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)) 131 | 132 | 133 | def U2NET_full(): 134 | full = { 135 | # cfgs for building RSUs and sides 136 | # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]} 137 | 'stage1': ['En_1', (7, 3, 32, 64), -1], 138 | 'stage2': ['En_2', (6, 64, 32, 128), -1], 139 | 'stage3': ['En_3', (5, 128, 64, 256), -1], 140 | 'stage4': ['En_4', (4, 256, 128, 512), -1], 141 | 'stage5': ['En_5', (4, 512, 256, 512, True), -1], 142 | 'stage6': ['En_6', (4, 512, 256, 512, True), 512], 143 | 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512], 144 | 'stage4d': ['De_4', (4, 1024, 128, 256), 256], 145 | 'stage3d': ['De_3', (5, 512, 64, 128), 128], 146 | 'stage2d': ['De_2', (6, 256, 32, 64), 64], 147 | 'stage1d': ['De_1', (7, 128, 16, 64), 64], 148 | } 149 | return U2NET(cfgs=full, out_ch=1) 150 | 151 | 152 | def U2NET_lite(): 153 | lite = { 154 | # cfgs for building RSUs and sides 155 | # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]} 156 | 'stage1': ['En_1', (7, 3, 16, 64), -1], 157 | 'stage2': ['En_2', (6, 64, 16, 64), -1], 158 | 'stage3': ['En_3', (5, 64, 16, 64), -1], 159 | 'stage4': ['En_4', (4, 64, 16, 64), -1], 160 | 'stage5': ['En_5', (4, 64, 16, 64, True), -1], 161 | 'stage6': ['En_6', (4, 64, 16, 64, True), 64], 162 | 'stage5d': ['De_5', (4, 128, 16, 64, True), 64], 163 | 'stage4d': ['De_4', (4, 128, 16, 64), 64], 164 | 'stage3d': ['De_3', (5, 128, 16, 64), 64], 165 | 'stage2d': ['De_2', (6, 128, 16, 64), 64], 166 | 'stage1d': ['De_1', (7, 128, 16, 64), 64], 167 | } 168 | return U2NET(cfgs=lite, out_ch=1) 169 | 170 | 171 | if __name__ == '__main__': 172 | model = U2NET_full() 173 | print(model) 174 | inputs = torch.randn(1,3,320,320) 175 | outputs = model(inputs) 176 | for o in outputs: 177 | print(o.shape) -------------------------------------------------------------------------------- /models/UNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/UNet/__init__.py -------------------------------------------------------------------------------- /models/UNet/unet.py: -------------------------------------------------------------------------------- 1 | """ U-Net model """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class ConvBnReLU(nn.Module): 8 | def __init__(self, inp, oup, kernel, stride, padding): 9 | super(ConvBnReLU, self).__init__() 10 | self.conv = nn.Sequential( 11 | nn.Conv2d(inp, oup, kernel, stride, padding), 12 | nn.BatchNorm2d(oup), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | return x 19 | 20 | 21 | class Conv3x3(nn.Module): 22 | def 
__init__(self, inp, oup): 23 | super(Conv3x3, self).__init__() 24 | self.conv = nn.Sequential( 25 | ConvBnReLU(inp, oup, 3, 1, 1), 26 | ConvBnReLU(oup, oup, 3, 1, 1), 27 | ) 28 | 29 | def forward(self, x): 30 | return self.conv(x) 31 | 32 | 33 | class UNet(nn.Module): 34 | def __init__(self, num_classes=21): 35 | super(UNet, self).__init__() 36 | self.num_classes = num_classes 37 | 38 | self.encoder_block1 = Conv3x3(3, 64) 39 | self.encoder_block2 = Conv3x3(64, 128) 40 | self.encoder_block3 = Conv3x3(128, 256) 41 | self.encoder_block4 = Conv3x3(256, 512) 42 | self.encoder_block5 = Conv3x3(512, 1024) 43 | 44 | self.pool1 = nn.MaxPool2d(2, 2) 45 | self.pool2 = nn.MaxPool2d(2, 2) 46 | self.pool3 = nn.MaxPool2d(2, 2) 47 | self.pool4 = nn.MaxPool2d(2, 2) 48 | 49 | self.smooth1 = ConvBnReLU(1024, 512, 1, 1, 0) 50 | self.smooth2 = ConvBnReLU(512, 256, 1, 1, 0) 51 | self.smooth3 = ConvBnReLU(256, 128, 1, 1, 0) 52 | self.smooth4 = ConvBnReLU(128, 64, 1, 1, 0) 53 | 54 | self.decoder_block1 = Conv3x3(1024, 512) 55 | self.decoder_block2 = Conv3x3(512, 256) 56 | self.decoder_block3 = Conv3x3(256, 128) 57 | self.decoder_block4 = Conv3x3(128, 64) 58 | 59 | self.classification = nn.Conv2d(64, self.num_classes, 1) 60 | 61 | def forward(self, x): 62 | e1 = self.encoder_block1(x) # (bs, 64, 512, 512) 63 | 64 | e2 = self.pool1(self.encoder_block2(e1)) # (bs, 128, 256, 256) 65 | e3 = self.pool2(self.encoder_block3(e2)) # (bs, 256, 128, 128) 66 | e4 = self.pool3(self.encoder_block4(e3)) # (bs, 512, 64, 64) 67 | x = self.pool4(self.encoder_block5(e4)) # (bs, 1024, 32, 32) 68 | 69 | x = self.smooth1(x) 70 | d5 = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) # (bs, 512, 64, 64) 71 | 72 | x = self.decoder_block1(torch.cat([e4, d5], dim=1)) 73 | x = self.smooth2(x) 74 | d4 = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 75 | 76 | x = self.decoder_block2(torch.cat([e3, d4], dim=1)) 77 | x = self.smooth3(x) 78 | d3 = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 79 | 80 | x = self.decoder_block3(torch.cat([e2, d3], dim=1)) 81 | x = self.smooth4(x) 82 | d2 = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 83 | 84 | out = self.decoder_block4(torch.cat([e1, d2], dim=1)) 85 | 86 | output = self.classification(out) 87 | return output 88 | 89 | 90 | if __name__ == '__main__': 91 | inputs = torch.randn(1, 3, 512, 512) 92 | unet = UNet() 93 | outputs = unet(inputs) 94 | print(outputs.shape) 95 | 96 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/models/__init__.py -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | """Model Factory 2 | @author: FlyEgle 3 | @datetime: 2022-02-09 4 | """ 5 | import torch.nn as nn 6 | 7 | from types import FunctionType 8 | from models.FCN.fcn import FCN8s, FCN16s, FCN32s 9 | from models.FCN.fcn_resnet import FCNResNet50_8S, FCNResNet101_8S, FCNResNet152_8S 10 | from models.FCN.fcn_mobilenetv3 import FCNMobileNetv3_8S 11 | from models.FCN.fcn_mobilenetv2 import FCNMobilenetv2_8S 12 | from models.SegNet.seg_vgg16 import SegNetVgg16 13 | from models.SegNet.seg_resnet import SegResNet50 14 | from models.SegNet.seg_mobilenetv2 import SegNetMobilenetV2 
15 | from models.UNet.unet import UNet # this version does not follow the ImageNet-standard CNN backbones 16 | from models.UNet.unet_resnet import UNET 17 | from models.U2Net.u2net import U2NET, U2NET_L, U2NET_SCALE, U2NET_MUTIL_TASK # this version is used for 0-1 segmentation 18 | from models.DeepLab.deeplab import make_deeplab 19 | # seg-hrnet 20 | from models.HRNet.hrnet_seg import get_seg_model 21 | from models.HRNet.config.default import _C as config 22 | from models.HRNet.config.default import update_config 23 | 24 | # make hrnet 25 | def makeHRNet(num_classes=1): 26 | args="models/HRNet/config/seg_hrnet_w48.yaml" 27 | update_config(config, args) 28 | # print(config) 29 | # config.dataset.num_classes = num_classes 30 | model = get_seg_model(config, use_fpn=True) 31 | return model 32 | 33 | class ModelFactory: 34 | def __init__(self): 35 | # model class 36 | self.__MODEL_DICT__ = { 37 | 'fcn_8s': FCN8s, 38 | 'fcn_16s': FCN16s, 39 | 'fcn_32s': FCN32s, 40 | 'fcn_8s_resnet50': FCNResNet50_8S, 41 | 'fcn_8s_resnet101': FCNResNet101_8S, 42 | 'fcn_8s_resnet152': FCNResNet152_8S, 43 | 'fcn_8s_mobilenetv3': FCNMobileNetv3_8S, 44 | 'fcn_8s_mobilenetv2': FCNMobilenetv2_8S, 45 | 'segnet_vgg16': SegNetVgg16, 46 | 'segnet_resnet50': SegResNet50, 47 | 'segnet_mobilenetv2': SegNetMobilenetV2, 48 | 'unet_full': UNet, 49 | 'unet_resnet50': UNET, 50 | 'u2net': U2NET, 51 | 'u2netl': U2NET_L, 52 | 'u2net_mutil': U2NET_MUTIL_TASK, 53 | 'u2net_scale': U2NET_SCALE, 54 | 'hrnet': makeHRNet, 55 | } 56 | # TODO: convert each factory function into a model class 57 | # model function 58 | self.__FUNCTION_DICT__ = { 59 | 'deeplab': make_deeplab, 60 | # 'u2net_full': U2NET, 61 | # 'u2net_lite': U2NETP 62 | } 63 | 64 | def setattr(self, name, value): 65 | if name in self.__MODEL_DICT__ or name in self.__FUNCTION_DICT__: 66 | print(f"{name} has already been registered, please choose a new name") 67 | else: 68 | # function 69 | if isinstance(value, FunctionType): 70 | self.__FUNCTION_DICT__[name] = value 71 | # class 72 | else: 73 | self.__MODEL_DICT__[name] = value 74 | 75 | def getattr(self, name): 76 | model_name = name.lower() 77 | if model_name in self.__MODEL_DICT__: 78 | return self.__MODEL_DICT__[model_name] 79 | elif model_name in self.__FUNCTION_DICT__: 80 | return self.__FUNCTION_DICT__[model_name] 81 | 82 | 83 | if __name__ == '__main__': 84 | 85 | factory = ModelFactory() 86 | model = factory.getattr("unet_full") 87 | print(model) 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /post_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | 7 | def getShrink(mask, shrink=2): 8 | for _ in range(shrink): 9 | contours, hierarchy = cv2.findContours( 10 | mask[:,:,0], 11 | cv2.RETR_EXTERNAL, 12 | cv2.CHAIN_APPROX_SIMPLE 13 | ) 14 | cv2.drawContours( 15 | mask, 16 | contours, 17 | -1, 18 | (125, 125, 125), 19 | 1 20 | ) 21 | 22 | mask[mask==np.array([125, 125, 125])] = 0 23 | 24 | return mask 25 | 26 | 27 | def edgePostProcess(mask, image): 28 | """Edge post process 29 | Args: 30 | mask: an ndarray map, value is [0,255], shape is (h, w, 3) 31 | image: an ndarray map, value is 0-255, shape is (h, w, 3) 32 | Returns: 33 | outputs: edge blur image 34 | """ 35 | mask[mask==255] = 1 36 | mask = getShrink(mask) 37 | 38 | image = image * mask 39 | image[image==0] = 255 40 | blur_image = cv2.GaussianBlur(image, (5, 5), 0) 41 | new_mask = np.zeros(image.shape, 
np.uint8) 42 | contours, hierarchy = cv2.findContours( 43 | mask[:,:,0], 44 | cv2.RETR_EXTERNAL, 45 | cv2.CHAIN_APPROX_SIMPLE 46 | ) 47 | cv2.drawContours(new_mask, contours, -1, (255, 255, 255), 5) 48 | output = np.where(new_mask==np.array([255, 255, 255]), blur_image, image) 49 | return output 50 | 51 | if __name__ == '__main__': 52 | # mask_folder = "/data/jiangmingchao/data/dataset/seg_exp/out_0.8_new_mask" 53 | # image_folder = "/data/jiangmingchao/data/dataset/seg_exp/out_0.8_new_out" 54 | mask_folder = "/data/jiangmingchao/data/code/SegmentationLight/tmp/mask" 55 | image_folder = "/data/jiangmingchao/data/code/SegmentationLight/tmp/out" 56 | 57 | mask_list = [os.path.join(mask_folder, x) for x in os.listdir(mask_folder)] 58 | image_list = [os.path.join(image_folder, x) for x in os.listdir(image_folder)] 59 | 60 | print(len(mask_list)) 61 | print(len(image_list)) 62 | 63 | output_folder = "/data/jiangmingchao/data/code/SegmentationLight/tmp/blur" 64 | if not os.path.exists(output_folder): 65 | os.mkdir(output_folder) 66 | 67 | 68 | for data in tqdm(zip(mask_list, image_list)): 69 | mask_path, image_path = data[0], data[1] 70 | mask = cv2.imread(mask_path) 71 | cv2.imwrite("./pre.png", mask) 72 | # print(mask.shape) 73 | mask[mask==255] = 1 74 | mask = getShrink(mask) 75 | 76 | image = cv2.imread(image_path) 77 | image = image * mask 78 | image[image==0] = 255 79 | blur_image = cv2.GaussianBlur(image, (5, 5), 0) 80 | new_mask = np.zeros(image.shape, np.uint8) 81 | contours, hierarchy = cv2.findContours( 82 | mask[:,:,0], 83 | cv2.RETR_EXTERNAL, 84 | cv2.CHAIN_APPROX_SIMPLE 85 | ) 86 | 87 | # for cnt in contours: 88 | # print(cnt.shape) 89 | 90 | cv2.drawContours(new_mask, contours, -1, (255, 255, 255), 5) 91 | output = np.where(new_mask==np.array([255, 255, 255]), blur_image, image) 92 | # print(output) 93 | # output[output==0]=255 94 | cv2.imwrite(os.path.join(output_folder, mask_path.split('/')[-1]), output) 95 | 96 | # cv2.imwrite("./out.png", mask*255) 97 | # break -------------------------------------------------------------------------------- /script/FCNs/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0 python -W ignore -m torch.distributed.launch --nproc_per_node 1 main.py \ 8 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/U2Net/adjust_lr_bs.yaml 9 | -------------------------------------------------------------------------------- /script/FCNs/train2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=1 python -W ignore -m torch.distributed.launch --master_port 29501 --nproc_per_node 1 main.py \ 8 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/FCNs/fcn_mbv2_8s.yaml -------------------------------------------------------------------------------- /script/HRNet/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch \ 8 | 
--master_port 2953 \ 9 | --nproc_per_node 8 main.py \ 10 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/HRNet/baseline_320.yaml -------------------------------------------------------------------------------- /script/HRNet/train2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch \ 8 | --master_port 2953 \ 9 | --nproc_per_node 8 main.py \ 10 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/HRNet/baseline_320.yaml -------------------------------------------------------------------------------- /script/U2Net/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch \ 8 | --master_port 2959 \ 9 | --nproc_per_node 8 main.py \ 10 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/U2Net/baseline_ce_pretrain_480_data.yaml 11 | # --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/U2Net/baseline_bce_dice_pretrain_320_data.yaml -------------------------------------------------------------------------------- /script/U2Net/train2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch \ 8 | --master_port 2953 \ 9 | --nproc_per_node 8 main.py \ 10 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/U2Net/baseline_bce_dice_pretrain_320_data.yaml -------------------------------------------------------------------------------- /script/U2Net/trian3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch \ 8 | --master_port 2953 \ 9 | --nproc_per_node 8 main.py \ 10 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/U2Net/baseline_bce_dice_pretrain_320_data.yaml -------------------------------------------------------------------------------- /script/train_fcn.sh: -------------------------------------------------------------------------------- 1 | # #!/bin/bash 2 | # OMP_NUM_THREADS=1 3 | # MKL_NUM_THREADS=1 4 | # export OMP_NUM_THREADS 5 | # export MKL_NUM_THREADS 6 | # cd /data/jiangmingchao/data/code/SegmentationLight; 7 | # CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch --nproc_per_node 8 main.py \ 8 | # --batch-size 32 \ 9 | # --num-workers 32 \ 10 | # --lr 1e-3 \ 11 | # --optim-name "sgd" \ 12 | # --cosine 1 \ 13 | # --fix 0 \ 14 | # --max-epochs 300 \ 15 | # --warmup-epochs 0 \ 16 | # --num-classes 21 \ 17 | # --crop-size 520 \ 18 | # --weight-decay 5e-4 \ 19 | # --train-file 
/data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/segclass/seg_train.log \ 20 | # --val-file /data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/segclass/seg_val.log \ 21 | # --ckpt-path /data/jiangmingchao/data/AICutDataset/Segmentation/FCN/fcn_baseline_lr_1e-3_cosine_aug_rotate_300epoch_flatten_ce_modify_fcn_520/checkpoints \ 22 | # --log-dir /data/jiangmingchao/data/AICutDataset/Segmentation/FCN/fcn_baseline_lr_1e-3_cosine_aug_rotate_300epoch_flatten_ce_modify_fcn_520/log_dir 23 | 24 | #!/bin/bash 25 | OMP_NUM_THREADS=1 26 | MKL_NUM_THREADS=1 27 | export OMP_NUM_THREADS 28 | export MKL_NUM_THREADS 29 | cd /data/jiangmingchao/data/code/SegmentationLight; 30 | CUDA_VISIBLE_DEVICES=0 python -W ignore -m torch.distributed.launch --nproc_per_node 1 main.py \ 31 | --batch-size 32 \ 32 | --num-workers 32 \ 33 | --lr 7e-3 \ 34 | --optim-name "sgd" \ 35 | --cosine 1 \ 36 | --fix 0 \ 37 | --max-epochs 50 \ 38 | --warmup-epochs 0 \ 39 | --num-classes 21 \ 40 | --crop-size 500 \ 41 | --weight-decay 5e-4 \ 42 | --train-file /data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/voc_aug/seg_train.log \ 43 | --val-file /data/jiangmingchao/data/code/SegmentationLight/datasets/voc2012/data/voc_aug/seg_val.log \ 44 | --ckpt-path /data/jiangmingchao/data/AICutDataset/Segmentation/deeplab/deeplab_ddp_8nodes/checkpoints \ 45 | --log-dir /data/jiangmingchao/data/AICutDataset/Segmentation/deeplab/deeplab_ddp_8nodes/log_dir -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd /data/jiangmingchao/data/code/SegmentationLight; 7 | CUDA_VISIBLE_DEVICES=0,1 python -W ignore -m torch.distributed.launch --master_port 2953 --nproc_per_node 2 main.py \ 8 | --hyp /data/jiangmingchao/data/code/SegmentationLight/hyparam/UNet/unet_resnet50.yaml -------------------------------------------------------------------------------- /utils/FuseAugments.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: FlyEgle 3 | @datetime: 2022-04-19 4 | @describe: Make Complex Augments like Mixup, Cutmix, Mosaic 5 | """ 6 | import cv2 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | 11 | from typing import Any 12 | 13 | 14 | class MixUP: 15 | """MixUp: blend each image in the batch with a randomly permuted one 16 | Returns: 17 | mixed images, pairs of targets, lambda 18 | """ 19 | def __init__(self, alpha=1.0, cuda=True) -> None: 20 | self.alpha = alpha 21 | self.cuda = cuda 22 | 23 | def __call__(self, img, tgt): 24 | if self.alpha: 25 | lam = np.random.beta(self.alpha, self.alpha) 26 | else: 27 | lam = 1 28 | 29 | bs = img.shape[0] 30 | if self.cuda: 31 | index = torch.randperm(bs).cuda() 32 | else: 33 | index = torch.randperm(bs) 34 | 35 | mixed_img = lam * img + (1 - lam) * img[index, :] 36 | tgt_a, tgt_b = tgt, tgt[index] 37 | 38 | return mixed_img, tgt_a, tgt_b, lam 39 | 40 | 41 | class Shrink: 42 | """Shrink the mask by 1 or 2 pixels 43 | """ 44 | def __init__(self, shrink=1): 45 | self.shrink = shrink 46 | 47 | # TODO: Fast Implementation 48 | def __call__(self, img, tgt): 49 | tgt[tgt==255] = 1 50 | for _ in range(self.shrink): 51 | contours, hierarchy = cv2.findContours( 52 | tgt, 53 | cv2.RETR_EXTERNAL, 54 | cv2.CHAIN_APPROX_SIMPLE 55 | ) 56 | cv2.drawContours( 57 | tgt, 58 | contours, 59 | -1, 60 | (125, 
125, 125), 61 | 1 62 | ) 63 | tgt[tgt==np.array([125, 125, 125])] = 0 64 | 65 | return img, tgt 66 | 67 | 68 | class CutMix: 69 | def __init__(self) -> None: 70 | pass 71 | 72 | def __call__(self, img, tgt): 73 | pass 74 | 75 | 76 | class MixCriterion(nn.Module): 77 | def __init__(self, criterion): 78 | super(MixCriterion, self).__init__() 79 | self.criterion = criterion 80 | 81 | def forward(self, pred, tgts_a, tgts_b, lam): 82 | return lam * self.criterion(pred, tgts_a) + (1 - lam) * self.criterion(pred, tgts_b) 83 | 84 | 85 | class Mosaic: 86 | def __init__(self) -> None: 87 | pass 88 | 89 | def __call__(self, *args: Any, **kwds: Any) -> Any: 90 | pass -------------------------------------------------------------------------------- /utils/Loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: FlyEgle 3 | @datetime: 2022-04-07 4 | @describe: loss function manager 5 | """ 6 | import torch.nn as nn 7 | # general loss 8 | from losses.loss import CELoss, BCELoss, DiceWithBCELoss, Dice_Bce_L1, Dice_Ce_L1, IouWithDiceWithBCELoss, IOUWithBCELoss 9 | # generate loss 10 | from losses.generatorLoss import TVLoss 11 | 12 | 13 | # ------------------------ BCE + 3*DICE + 0.1*TV ------------------------- 14 | class Bce_Dice_TvLoss(nn.Module): 15 | def __init__(self): 16 | super(Bce_Dice_TvLoss, self).__init__() 17 | self.weights = { 18 | 'bce': 1.0, 19 | 'dice': 3.0, 20 | 'tv': 0.1 21 | } 22 | self.dice_bce = DiceWithBCELoss(self.weights) 23 | self.tv = TVLoss() 24 | 25 | def forward(self, pred, tgts): 26 | if len(tgts.shape) == 3: 27 | tgts = tgts.unsqueeze(1).float() 28 | else: 29 | tgts = tgts.float() 30 | 31 | losses = self.dice_bce(pred, tgts) + self.weights['tv'] * self.tv(pred) 32 | return losses 33 | 34 | 35 | class LossBar: 36 | """LossBar: map a loss-name string to its criterion 37 | """ 38 | def __init__(self, loss_name): 39 | self.loss_name = loss_name.lower() 40 | 41 | def __call__(self): 42 | weights = { 43 | 'bce': 1.0, 44 | 'dice': 1.0, 45 | 'iou': 1.0 46 | } 47 | if self.loss_name == "bce": 48 | criterion = BCELoss() 49 | elif self.loss_name == "ce": 50 | criterion = CELoss(flatten=False) 51 | elif self.loss_name == "bce+dice": 52 | weights = {'bce':1.0, 'dice': 1.0} 53 | criterion = DiceWithBCELoss(weights, mining=False) 54 | elif self.loss_name == "bce+dice+iou": 55 | weights = {'bce':1.0, 'dice':1.0, 'iou': 1.0} 56 | criterion = IouWithDiceWithBCELoss(weights, mining=False) 57 | elif self.loss_name == "bce+iou": 58 | criterion = IOUWithBCELoss(weights, mining=False) 59 | elif self.loss_name == "bce+dice+l1": 60 | weights = {'bce':1.0, 'dice': 3.0} 61 | criterion = Dice_Bce_L1(weights) 62 | elif self.loss_name == "ce+dice+l1": 63 | criterion = Dice_Ce_L1() 64 | elif self.loss_name == "bce+dice+tv": 65 | criterion = Bce_Dice_TvLoss() 66 | else: 67 | raise NotImplementedError(f"{self.loss_name} loss is not supported") 68 | return criterion 69 | -------------------------------------------------------------------------------- /utils/LrSheduler.py: -------------------------------------------------------------------------------- 1 | """LR scheduler 2 | @author: FlyEgle 3 | @datetime: 2022-01-21 4 | """ 5 | import math 6 | 7 | 8 | def StepLR(opt, epoch, batch_iter, optimizer, train_batch): 9 | """Sets the learning rate 10 | # Adapted from PyTorch Imagenet example: 11 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 12 | """ 13 | total_epochs = opt.MAX_EPOCHS 14 | warm_epochs = opt.WARMUP_EPOCHS 15 | if epoch <= warm_epochs: 16 | lr_adj = (batch_iter + 1) / (warm_epochs * train_batch) 17 | elif epoch < int(0.3 * total_epochs): 18 | 
lr_adj = 1. 19 | elif epoch < int(0.6 * total_epochs): 20 | lr_adj = 1e-1 21 | elif epoch < int(0.8 * total_epochs): 22 | lr_adj = 1e-2 23 | else: 24 | lr_adj = 1e-3 25 | 26 | for param_group in optimizer.param_groups: 27 | param_group['lr'] = opt.OPTIMIZER.LEARNING_RATE * lr_adj 28 | return opt.OPTIMIZER.LEARNING_RATE * lr_adj 29 | 30 | 31 | def CosineLR(opt, epoch, batch_iter, optimizer, train_batch): 32 | """Cosine Learning rate 33 | """ 34 | total_epochs = opt.MAX_EPOCHS 35 | warm_epochs = opt.WARMUP_EPOCHS 36 | if epoch <= warm_epochs: 37 | lr_adj = (batch_iter + 1) / (warm_epochs * train_batch) + 1e-6 38 | else: 39 | lr_adj = 1/2 * (1 + math.cos(batch_iter * math.pi / ((total_epochs - warm_epochs) * train_batch))) + 1e-6 40 | 41 | for param_group in optimizer.param_groups: 42 | param_group['lr'] = opt.OPTIMIZER.LEARNING_RATE * lr_adj 43 | return opt.OPTIMIZER.LEARNING_RATE * lr_adj 44 | 45 | 46 | def FixLR(opt): 47 | """Fix Learning Rate 48 | """ 49 | lr_adj = 1 50 | return opt.OPTIMIZER.LEARNING_RATE * lr_adj 51 | 52 | 53 | def PolyLR(opt, epoch, batch_iter, optimizer, train_batch): 54 | """Poly LR, following 'Rethinking Atrous Convolution for Semantic Image Segmentation': lr = base_lr * (1 - iter / max_iter) ** 0.9 55 | """ 56 | total_batch = train_batch * opt.MAX_EPOCHS 57 | lr_adj = math.pow(1 - batch_iter / total_batch, 0.9) 58 | 59 | for param_group in optimizer.param_groups: 60 | param_group['lr'] = opt.OPTIMIZER.LEARNING_RATE * lr_adj 61 | 62 | return opt.OPTIMIZER.LEARNING_RATE * lr_adj 63 | 64 | 65 | # build the per-cycle iteration index lists needed by CirCleLR 66 | def make_iter(opt, train_batch): 67 | total_epochs = opt.MAX_EPOCHS 68 | warm_epochs = opt.WARMUP_EPOCHS 69 | cricle = opt.OPTIMIZER.CRICLE_STEPS 70 | 71 | warm_iter = warm_epochs*train_batch 72 | epochs_list = [x+1 for x in range(total_epochs * train_batch)] 73 | # print(len(epochs_list)) 74 | cricle_epochs = int((total_epochs - warm_epochs) * train_batch / cricle) 75 | # print(cricle_epochs) 76 | cricle_epochs_list = [None for _ in range(cricle)] 77 | 78 | for i in range(cricle): 79 | if i == 0: 80 | cricle_epochs_list[i] = epochs_list[warm_iter: cricle_epochs + warm_iter] 81 | elif i == cricle - 1: 82 | cricle_epochs_list[i] = epochs_list[cricle_epochs*i + warm_iter: ] 83 | else: 84 | cricle_epochs_list[i] = epochs_list[cricle_epochs*i + warm_iter: cricle_epochs*(i+1)+ warm_iter] 85 | 86 | return cricle_epochs_list 87 | 88 | 89 | def CirCleLR(opt, epoch, batch_iter, optimizer, train_batch, cricle_epochs_list): 90 | """Cyclic cosine LR with warmup 91 | """ 92 | warm_epochs = opt.WARMUP_EPOCHS 93 | 94 | warm_iter = warm_epochs * train_batch 95 | 96 | if epoch <= warm_epochs: 97 | lr_adj = (batch_iter + 1) / (warm_epochs * train_batch) + 1e-6 98 | else: 99 | for i in range(len(cricle_epochs_list)): 100 | # print(batch_iter) 101 | # print(cricle_epochs_list[0]) 102 | # restart the batchidx 103 | if i == 0: 104 | if (batch_iter+1) in cricle_epochs_list[i]: 105 | batch_idx = batch_iter + 1 - warm_iter 106 | # print("batch_idx: ", batch_idx) 107 | lr_adj = 1/2 * (1 + math.cos(batch_idx * math.pi / (len(cricle_epochs_list[i])))) + 1e-6 108 | else: 109 | if (batch_iter+1) in cricle_epochs_list[i]: 110 | batch_idx = batch_iter + 1 - warm_iter - len(cricle_epochs_list[i]) * i 111 | # print("batch_idx: ", batch_idx) 112 | lr_adj = 1/2 * (1 + math.cos(batch_idx * math.pi / (len(cricle_epochs_list[i])))) + 1e-6 113 | 114 | for param_group in optimizer.param_groups: 115 | param_group['lr'] = opt.OPTIMIZER.LEARNING_RATE * lr_adj 116 | return opt.OPTIMIZER.LEARNING_RATE * lr_adj 117 | 
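118 | 119 | if __name__ == "__main__": 120 | # Minimal usage sketch (illustrative only): drive CosineLR with a dummy optimizer 121 | # and a SimpleNamespace stand-in for the real config object that main.py builds 122 | # from the yaml files under hyparam/. The field names mirror the attributes the 123 | # schedulers read above; the numbers are made up for this demo. 124 | import torch 125 | from types import SimpleNamespace 126 | 127 | opt = SimpleNamespace( 128 | MAX_EPOCHS=10, 129 | WARMUP_EPOCHS=1, 130 | OPTIMIZER=SimpleNamespace(LEARNING_RATE=0.1), 131 | ) 132 | optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1) 133 | 134 | train_batch = 100 # iterations per epoch 135 | for epoch in range(1, opt.MAX_EPOCHS + 1): 136 | for it in range(train_batch): 137 | batch_iter = (epoch - 1) * train_batch + it 138 | lr = CosineLR(opt, epoch, batch_iter, optimizer, train_batch) 139 | print(f"epoch {epoch}: lr = {lr:.6f}") 140 | 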
-------------------------------------------------------------------------------- /utils/Metirc.py: -------------------------------------------------------------------------------- 1 | """Metric for Segmentation 2 | - Pixel Accuracy(PA) 3 | - Mean Pixel Accuracy(MPA) 4 | - Mean Intersection over Union (MIoU) 5 | - Frequency Weighted Intersection over Union (FWIoU) 6 | 7 | @author: FlyEgle 8 | @datetime: 2022-01-20 9 | """ 10 | import cv2 11 | import torch 12 | import numpy as np 13 | 14 | 15 | def generate_mask(outputs, score=0.5): 16 | """build binary mask from outputs 17 | """ 18 | # N 1 H W 19 | outputs = torch.sigmoid(outputs) 20 | # N H W 21 | outputs[outputs > score] = 1 22 | outputs[outputs <= score] = 0 23 | return outputs.detach().squeeze(1).cpu().numpy().astype(np.int32) 24 | 25 | def confusion_matrix(y_true, y_pred): 26 | # this version is too slow for large segmentation maps 27 | classes = np.max(y_true) + 1 28 | matrix = np.zeros((classes, classes), dtype=np.int32) 29 | for i in range(len(y_true)): 30 | if y_true[i] == y_pred[i]: 31 | matrix[y_true[i], y_true[i]] += 1 32 | else: 33 | matrix[y_true[i], y_pred[i]] += 1 34 | 35 | return matrix 36 | 37 | # TODO: handle classes that are absent from the current batch 38 | def fast_confusion_matrix(y_true, y_pred, num_classes): 39 | # this version is faster than the sklearn confusion-matrix implementation 40 | if not isinstance(y_true, np.ndarray): 41 | y_true = y_true.cpu().numpy() 42 | y_pred = y_pred.cpu().numpy() 43 | 44 | if len(y_true.shape) > 1: 45 | y_true = y_true.flatten() 46 | y_pred = y_pred.flatten() 47 | 48 | classes = num_classes 49 | numbers = classes * y_true + y_pred 50 | vector = np.bincount(numbers, minlength=classes ** 2) 51 | matrix = vector.reshape((classes, classes)) 52 | return matrix 53 | 54 | 55 | def pixel_accuracy(confusion_matrix): 56 | """pixel accuracy: correct pixels / all pixels""" 57 | pa = np.diag(confusion_matrix).sum() / confusion_matrix.sum() 58 | return pa 59 | 60 | 61 | def mean_pixel_accuracy(confusion_matrix): 62 | """mean pixel accuracy""" 63 | Mpa = np.nanmean(np.diag(confusion_matrix) / np.sum(confusion_matrix, axis=1)) 64 | return Mpa 65 | 66 | 67 | def mean_intersection_over_union(confusion_matrix): 68 | """Mean Intersection over Union """ 69 | Union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - np.diag(confusion_matrix) 70 | Inter = np.diag(confusion_matrix) 71 | IoU = Inter / Union 72 | MIoU = np.nanmean(IoU) 73 | return MIoU 74 | 75 | 76 | def frequency_weighted_intersection_over_union(confusion_matrix): 77 | """Intersection over Union weighted by class frequency 78 | """ 79 | Union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - np.diag(confusion_matrix) 80 | Inter = np.diag(confusion_matrix) 81 | IoU = Inter / Union 82 | freq = np.sum(confusion_matrix, axis=1) / confusion_matrix.sum() 83 | FwIOU = (freq[freq>0] * IoU[freq>0]).sum() 84 | return FwIOU 85 | 86 | 87 | # translate the outputs & targets tensor to numpy 88 | def make_outputs(outputs): 89 | if isinstance(outputs, np.ndarray): 90 | return np.argmax(outputs, axis=1) 91 | else: 92 | return torch.argmax(outputs, axis=1).cpu().numpy() 93 | 94 | def make_targets(targets): 95 | if isinstance(targets, np.ndarray): 96 | return targets 97 | else: 98 | targets = targets.cpu().numpy() 99 | return targets 100 | 101 | 102 | # batch calculate 103 | def calc_semantic_segmentation_confusion(pred_labels, gt_labels, num_classes): 104 | """batch accumulate confusion matrix 105 | """ 106 | # guard against negative labels, which would break np.bincount below 107 | 
pred_labels[pred_labels<0] = 0 108 | 109 | pred_labels = iter(pred_labels) 110 | gt_labels = iter(gt_labels) 111 | 112 | n_class = num_classes 113 | confusion = np.zeros((n_class, n_class), dtype=np.int32) # (n_class, n_class) 114 | 115 | for pred_label, gt_label in zip(pred_labels, gt_labels): 116 | 117 | if pred_label.ndim != 2 or gt_label.ndim != 2: 118 | raise ValueError('ndim of labels should be two.') 119 | 120 | if pred_label.shape != gt_label.shape: 121 | raise ValueError('Shape of ground truth and prediction should be same.') 122 | 123 | # Dynamically expand the confusion matrix if necessary. 124 | lb_max = np.max((pred_label, gt_label)) 125 | # print(lb_max) 126 | if lb_max >= n_class: # if a label exceeds the preset class count, expand the matrix 127 | expanded_confusion = np.zeros((lb_max + 1, lb_max + 1), dtype=np.int32) 128 | expanded_confusion[0:n_class, 0:n_class] = confusion 129 | 130 | n_class = lb_max + 1 131 | confusion = expanded_confusion 132 | 133 | # Count statistics from valid pixels 134 | mask = gt_label >= 0 135 | confusion += np.bincount(n_class * gt_label[mask].astype(int) + pred_label[mask], minlength=n_class ** 2).reshape((n_class, n_class)) 136 | 137 | for iter_ in (pred_labels, gt_labels): 138 | # This code assumes any iterator does not contain None as its items. 139 | if next(iter_, None) is not None: 140 | raise ValueError('Lengths of input iterables need to be the same') 141 | 142 | return confusion 143 | 144 | 145 | class SegmentationMetric: 146 | """Segmentation metric: accumulate the confusion matrix over all batches; with DDP training only the current rank's results are recorded 147 | """ 148 | def __init__(self, name, classes): 149 | self.name = name 150 | self.classes = classes 151 | self.batch_num = 0 152 | self.confusion_matrix = 0 153 | self.total_losses = 0.0 154 | 155 | def update(self, outputs, targets, losses, bce=False, score=0.5): 156 | if bce: 157 | outputs = generate_mask(outputs, score=score) 158 | else: 159 | outputs = make_outputs(outputs) 160 | 161 | targets = make_targets(targets) 162 | 163 | # print(outputs.shape) 164 | # print(targets.shape) 165 | # N,H,W 166 | assert outputs.shape == targets.shape, "outputs shape must match the targets" 167 | 168 | # losses 169 | if isinstance(losses, torch.Tensor): 170 | self.batch_losses = losses.data.item() 171 | else: 172 | self.batch_losses = losses 173 | 174 | self.total_losses += self.batch_losses 175 | 176 | # batch confusion matrix 177 | self.batch_confusion_matrix = calc_semantic_segmentation_confusion(outputs, targets, self.classes) 178 | 179 | # all batch accumulate confusion matrix 180 | self.confusion_matrix += self.batch_confusion_matrix 181 | # batch losses 182 | self.batch_num += 1 183 | 184 | def reset(self): 185 | self.batch_num = 0 186 | self.confusion_matrix = 0 187 | self.total_losses = 0.0 188 | 189 | @property 190 | def average(self): 191 | return self.total_losses / self.batch_num 192 | 193 | # batch 194 | @property 195 | def batch_metric(self): 196 | # print(self.batch_confusion_matrix) 197 | batch_pa, batch_mpa, batch_miou, batch_fwiou = self._cal_metric(self.batch_confusion_matrix) 198 | return batch_pa, batch_mpa, batch_miou, batch_fwiou 199 | 200 | # epoch 201 | @property 202 | def epoch_metric(self): 203 | epoch_pa, epoch_mpa, epoch_miou, epoch_fwiou = self._cal_metric(self.confusion_matrix) 204 | return epoch_pa, epoch_mpa, epoch_miou, epoch_fwiou 205 | 206 | # return pa, mpa, miou, fwiou 207 | def _cal_metric(self, confusion_matrix): 208 | pa = self._pa(confusion_matrix) 209 | mpa = 
self._mpa(confusion_matrix) 210 | miou = self._miou(confusion_matrix) 211 | fwiou = self._fwiou(confusion_matrix) 212 | return pa, mpa, miou, fwiou 213 | 214 | def _pa(self, confusion_matrix): 215 | pa = pixel_accuracy(confusion_matrix) 216 | return pa 217 | 218 | def _mpa(self, confusion_matrix): 219 | mpa = mean_pixel_accuracy(confusion_matrix) 220 | return mpa 221 | 222 | def _miou(self, confusion_matrix): 223 | miou = mean_intersection_over_union(confusion_matrix) 224 | return miou 225 | 226 | def _fwiou(self, confusion_matrix): 227 | fwiou = frequency_weighted_intersection_over_union(confusion_matrix) 228 | return fwiou 229 | 230 | 231 | if __name__ == "__main__": 232 | # batch_true = torch.empty(3, 4, 4, dtype=torch.long).random_(0, 21).numpy() 233 | # batch_pred = torch.empty(3, 4, 4, dtype=torch.long).random_(0, 21).numpy() 234 | 235 | batch_true = np.array([ 236 | [[0,1,1],[1,2,2],[1,1,2]], 237 | [[0,1,1],[1,2,2],[1,1,2]], 238 | [[0,1,1],[1,2,2],[1,1,2]], 239 | ]) 240 | 241 | 242 | batch_pred = np.array([ 243 | [[0,1,0],[1,2,1],[0,1,0]], 244 | [[0,1,0],[1,2,1],[0,1,0]], 245 | [[0,1,0],[1,2,1],[0,1,0]], 246 | ]) 247 | 248 | matrix = calc_semantic_segmentation_confusion(batch_pred, batch_true, num_classes=21) 249 | pa = pixel_accuracy(matrix) 250 | print(pa) 251 | mpa = mean_pixel_accuracy(matrix) 252 | print(mpa) 253 | miou = mean_intersection_over_union(matrix) 254 | print(miou) 255 | fwiou = frequency_weighted_intersection_over_union(matrix) 256 | print(fwiou) 257 | 258 | 259 | -------------------------------------------------------------------------------- /utils/Optim.py: -------------------------------------------------------------------------------- 1 | """Optimizer 2 | @author: FlyEgle 3 | @datetime: 2022-01-20 4 | """ 5 | from torch.optim import SGD, Adam, AdamW 6 | 7 | 8 | class BuildOptim: 9 | def __init__(self, optim_name, lr, weight_decay, momentum, betas=(0.9, 0.999), eps=1e-8): 10 | _names_ = ['sgd', 'adam', 'adamw'] 11 | if not optim_name.lower() in _names_: 12 | raise NotImplementedError(f"{optim_name} has not been implemented") 13 | self.optim_name = optim_name 14 | self.lr = lr 15 | self.betas = betas 16 | self.eps = eps 17 | self.weight_decay = weight_decay 18 | self.momentum = momentum 19 | 20 | def _sgd(self, params): 21 | return SGD( 22 | params, 23 | lr=self.lr, 24 | weight_decay=self.weight_decay, 25 | momentum=self.momentum 26 | ) 27 | 28 | def _adam(self, params): 29 | return Adam( 30 | params, 31 | lr=self.lr, 32 | betas=self.betas, 33 | eps=self.eps, 34 | weight_decay=self.weight_decay 35 | ) 36 | 37 | def _adamw(self, params): 38 | return AdamW( 39 | params, 40 | lr=self.lr, 41 | betas=self.betas, 42 | eps=self.eps, 43 | weight_decay=self.weight_decay 44 | ) 45 | 46 | def __call__(self, *args): 47 | if self.optim_name.lower() == "sgd": 48 | return self._sgd(*args) 49 | elif self.optim_name.lower() == "adam": 50 | return self._adam(*args) 51 | elif self.optim_name.lower() == "adamw": 52 | return self._adamw(*args) 53 | else: 54 | raise NotImplementedError(f"{self.optim_name} has not been implemented") -------------------------------------------------------------------------------- /utils/Summary.py: -------------------------------------------------------------------------------- 1 | """Record the model training logs and generated images 2 | @author: FlyEgle 3 | @datetime: 2022-01-21 4 | """ 5 | import os 6 | import json 7 | import torch 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | class LoggerRecord: 13 | def 
__init__(self, log_path) -> None: 14 | if os.path.exists(log_path): 15 | if not os.path.isdir(log_path): 16 | raise FileExistsError(f"{log_path} must be a folder!!!") 17 | else: 18 | os.makedirs(log_path) 19 | 20 | self.logger = SummaryWriter(log_path) 21 | 22 | def BatchMetricWritter(self, losses, pa, mpa, miou, fwiou, lr, batch_time, batch_iter, flag): 23 | self.logger.add_scalar(f"{flag}/batch/Loss", losses.data.item(), batch_iter) 24 | self.logger.add_scalar(f"{flag}/batch/PA", pa, batch_iter) 25 | self.logger.add_scalar(f"{flag}/batch/MPA", mpa, batch_iter) 26 | self.logger.add_scalar(f"{flag}/batch/MIoU", miou, batch_iter) 27 | self.logger.add_scalar(f"{flag}/batch/FwIoU", fwiou, batch_iter) 28 | self.logger.add_scalar(f"{flag}/batch/LearningRate", lr, batch_iter) 29 | self.logger.add_scalar(f"{flag}/batch/Times", batch_time, batch_iter) 30 | 31 | def EpochMetricWritter(self, losses, pa, mpa, miou, fwiou, epoch, flag): 32 | self.logger.add_scalar(f"{flag}/epoch/Loss", losses, epoch) 33 | self.logger.add_scalar(f"{flag}/epoch/PA", pa, epoch) 34 | self.logger.add_scalar(f"{flag}/epoch/MPA", mpa, epoch) 35 | self.logger.add_scalar(f"{flag}/epoch/MIoU", miou, epoch) 36 | self.logger.add_scalar(f"{flag}/epoch/FwIoU", fwiou, epoch) 37 | 38 | 39 | # todo 40 | def ImageCallback(self): 41 | raise NotImplementedError 42 | 43 | 44 | class CkptRecord: 45 | def __init__(self, ckpt_path): 46 | if os.path.exists(ckpt_path): 47 | if not os.path.isdir(ckpt_path): 48 | raise FileExistsError(f"{ckpt_path} must be a folder!!!") 49 | else: 50 | os.makedirs(ckpt_path) 51 | 52 | self.best_losses = np.inf 53 | self.best_pa = 0.0 54 | self.best_mpa = 0.0 55 | self.best_miou = 0.0 56 | 57 | self.ckpt_queue = [] 58 | self.ckpt_path = ckpt_path 59 | 60 | def SaveBestMIOU(self, state_dict, losses, miou, maxNum=10): 61 | """keep only the last maxNum best ckpts""" 62 | 63 | if miou > self.best_miou: 64 | self.best_miou = miou 65 | output_name = f"best_ckpt_losses_{losses}_miou_{miou}.pth" 66 | save_path = os.path.join(self.ckpt_path, output_name) 67 | torch.save(state_dict, save_path) 68 | 69 | if os.path.exists(save_path): 70 | self.ckpt_queue.append(save_path) 71 | else: 72 | print(f"{save_path} does not exist!!!") 73 | 74 | if len(self.ckpt_queue) > maxNum: 75 | old_path = self.ckpt_queue.pop(0) 76 | os.remove(old_path) 77 | 78 | def SaveBestCkpt(self, state_dict, epoch, losses, miou): 79 | loss_positive = False 80 | miou_positive = False 81 | 82 | if losses < self.best_losses: 83 | self.best_losses = losses 84 | loss_positive = True 85 | 86 | if miou > self.best_miou: 87 | self.best_miou = miou 88 | miou_positive = True 89 | 90 | if loss_positive or miou_positive: 91 | output_name = f"best_ckpt_epoch_{epoch}_losses_{losses}_miou_{miou}.pth" 92 | torch.save(state_dict, os.path.join(self.ckpt_path, output_name)) 93 | 94 | 95 | def SaveCkpt(self, state_dict, losses, pa, mpa, miou): 96 | output_name = f"ckpt_losses_{losses}_pa_{pa}_mpa_{mpa}_miou_{miou}.pth" 97 | torch.save(state_dict, os.path.join(self.ckpt_path, output_name)) 98 | 99 | 100 | # TODO: draw confusion matrix 101 | def draw_confusion_matrix(confusion_matrix): 102 | pass 103 | 104 | 105 | class LoggerInfo: 106 | """save the training log 107 | """ 108 | def __init__(self, save_path): 109 | self.save_path = save_path 110 | if not os.path.exists(self.save_path): 111 | os.makedirs(self.save_path) 112 | self.logger_name = "train.log" 113 | self.logger_path = os.path.join(self.save_path, self.logger_name) 114 | 115 | def write(self, result): 116 | with 
open(self.logger_path, "a") as file: 117 | file.write(json.dumps(result) + '\n') 118 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/segmentationlight/45c3e574f578bac046bd6027d2f3dbb7d106e015/utils/__init__.py -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def Load_state_dict(ckpt, net): 6 | state_dict = torch.load(ckpt, map_location="cpu")['state_dict'] 7 | model_state_dict = net.state_dict() 8 | new_state_dict = {} 9 | 10 | count = 0 11 | total = 0 12 | for k, v in model_state_dict.items(): 13 | total += 1 14 | if k in state_dict and v.shape == state_dict[k].shape: 15 | new_state_dict[k] = state_dict[k] 16 | count += 1 17 | 18 | print(f"{count}/{total} kernels loaded from the pretrained checkpoint!!!") 19 | 20 | return new_state_dict 21 | 22 | 23 | def check_tensor(vector, name=None): # True if vector contains NaN or Inf 24 | if isinstance(vector, torch.Tensor): 25 | if torch.any(torch.isnan(vector)) or torch.any(torch.isinf(vector)): 26 | return True 27 | elif isinstance(vector, np.ndarray): 28 | if np.any(np.isnan(vector)) or np.any(np.isinf(vector)): 29 | return True 30 | elif isinstance(vector, list): 31 | vector = np.asarray(vector, dtype=np.float32) 32 | if np.any(np.isnan(vector)) or np.any(np.isinf(vector)): 33 | return True 34 | return False 35 | 36 | 37 | if __name__ == '__main__': 38 | inputs = torch.zeros(1,1,4,4) 39 | inputs[:,:,0,0] = torch.inf 40 | 41 | inputs[:,:,0,1] = torch.nan 42 | print(inputs) 43 | print(check_tensor(inputs)) 44 | 45 | --------------------------------------------------------------------------------