├── segmentron ├── data │ ├── __init__.py │ ├── downloader │ │ ├── __init__.py │ │ ├── ade20k.py │ │ ├── cityscapes.py │ │ ├── sbu_shadow.py │ │ ├── mscoco.py │ │ └── pascal_voc.py │ └── dataloader │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── sbu_shadow.py │ │ ├── lip_parsing.py │ │ ├── pascal_aug.py │ │ ├── pascal_voc.py │ │ ├── seg_data_base.py │ │ ├── mscoco.py │ │ └── cityscapes.py ├── solver │ ├── __init__.py │ └── optimizer.py ├── config │ ├── __init__.py │ └── config.py ├── __init__.py ├── modules │ ├── __init__.py │ ├── csrc │ │ ├── vision.cpp │ │ └── criss_cross_attention │ │ │ └── ca.h │ ├── cc_attention.py │ └── sync_bn │ │ └── syncbn.py ├── utils │ ├── __init__.py │ ├── env.py │ ├── logger.py │ ├── options.py │ ├── default_setup.py │ ├── filesystem.py │ ├── registry.py │ ├── download.py │ └── parallel.py └── models │ ├── backbones │ ├── __init__.py │ ├── build.py │ ├── mobilenet.py │ └── eespnet.py │ ├── __init__.py │ ├── fcn.py │ ├── model_zoo.py │ ├── pspnet.py │ ├── hrnet_seg.py │ ├── deeplabv3_plus.py │ ├── ccnet.py │ ├── danet.py │ ├── icnet.py │ ├── espnetv2.py │ ├── dfanet.py │ ├── dunet.py │ ├── unet.py │ ├── denseaspp.py │ ├── segbase.py │ ├── edanet.py │ ├── refinenet.py │ ├── dabnet.py │ └── lednet.py ├── tools ├── demo_vis.png ├── dist_test.sh ├── dist_train.sh ├── demo.py └── eval.py ├── docs ├── images │ └── demo.png └── DATA_PREPARE.md ├── configs ├── cityscapes_cgnet.yaml ├── cityscapes_lednet.yaml ├── cityscapes_hardnet.yaml ├── cityscapes_fcn.yaml ├── cityscapes_dunet.yaml ├── cityscapes_encnet.yaml ├── cityscapes_ocnet.yaml ├── cityscapes_bisenet.yaml ├── cityscapes_ccnet_resnet.yaml ├── cityscapes_dense_aspp.yaml ├── cityscapes_dfanet.yaml ├── coco_deeplabv3_plus.yaml ├── ade20k_deeplabv3_plus.yaml ├── cityscapes_enet.yaml ├── cityscapes_unet.yaml ├── cityscapes_fpenet.yaml ├── cityscapes_icnet_resnet.yaml ├── cityscapes_contextnet.yaml ├── cityscapes_refinenet.yaml ├── cityscapes_dabnet.yaml ├── cityscapes_espnetv2.yaml ├── cityscapes_deeplabv3_plus_resnet.yaml ├── cityscapes_deeplabv3_plus.yaml ├── cityscapes_danet_resnet.yaml ├── cityscapes_deeplabv3_plus_mobilenet.yaml ├── pascal_aug_deeplabv3_plus.yaml ├── cityscapes_pointrend_deeplabv3_plus.yaml ├── pascal_voc_deeplabv3_plus.yaml ├── cityscapes_pspnet_resnet.yaml ├── cityscapes_fast_scnn.yaml ├── coco_deeplabv3_plus_mobilenet.yaml ├── cityscapes_hrnet_w18_small_v1.yaml └── cityscapes_hrnet.yaml ├── .gitignore ├── setup.py └── README.md /segmentron/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentron/solver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentron/data/downloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentron/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .settings import cfg -------------------------------------------------------------------------------- /segmentron/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import modules, models, utils, data -------------------------------------------------------------------------------- /tools/demo_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LikeLy-Journey/SegmenTron/HEAD/tools/demo_vis.png -------------------------------------------------------------------------------- /docs/images/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LikeLy-Journey/SegmenTron/HEAD/docs/images/demo.png -------------------------------------------------------------------------------- /segmentron/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Seg NN Modules""" 2 | 3 | from .basic import * 4 | from .module import * 5 | from .batch_norm import get_norm -------------------------------------------------------------------------------- /segmentron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from __future__ import absolute_import 3 | 4 | from .download import download, check_sha1 5 | from .filesystem import makedirs 6 | -------------------------------------------------------------------------------- /segmentron/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import BACKBONE_REGISTRY, get_segmentation_backbone 2 | from .xception import * 3 | from .mobilenet import * 4 | from .resnet import * 5 | from .hrnet import * 6 | from .eespnet import * 7 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | 5 | CONFIG=$1 6 | GPUS=$2 7 | 8 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \ 9 | $(dirname "$0")/eval.py --config-file $CONFIG ${@:3} 10 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | 5 | CONFIG=$1 6 | GPUS=$2 7 | 8 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \ 9 | $(dirname "$0")/train.py --config-file $CONFIG ${@:3} 10 | -------------------------------------------------------------------------------- /configs/cityscapes_cgnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "CGNet" 17 | 18 | -------------------------------------------------------------------------------- /configs/cityscapes_lednet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "LEDNet" 17 | 18 | -------------------------------------------------------------------------------- /configs/cityscapes_hardnet.yaml: 
-------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 500 7 | BATCH_SIZE: 16 8 | CROP_SIZE: 1024 9 | TEST: 10 | BATCH_SIZE: 8 11 | 12 | SOLVER: 13 | LR: 0.02 14 | WEIGHT_DECAY: 5e-4 15 | 16 | MODEL: 17 | MODEL_NAME: "HardNet" 18 | 19 | -------------------------------------------------------------------------------- /configs/cityscapes_fcn.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 769 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1025, 2049) 12 | 13 | SOLVER: 14 | LR: 0.02 15 | 16 | MODEL: 17 | MODEL_NAME: "FCN" 18 | BACKBONE: "resnet101" 19 | -------------------------------------------------------------------------------- /configs/cityscapes_dunet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "DUNet" 17 | BACKBONE: "resnet50" 18 | OUTPUT_STRIDE: 8 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_encnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "EncNet" 17 | BACKBONE: "resnet50" 18 | OUTPUT_STRIDE: 16 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_ocnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "OCNet" 17 | BACKBONE: "resnet50" 18 | OUTPUT_STRIDE: 16 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_bisenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 8 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 8 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "BiSeNet" 17 | BACKBONE: "resnet18" 18 | OUTPUT_STRIDE: 16 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_ccnet_resnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 2 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "CCNet" 17 | BACKBONE: "resnet101" 18 | OUTPUT_STRIDE: 16 19 | 20 | 
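Each YAML above overrides only the keys it names; every other option falls back to the defaults shipped in `segmentron/config`. A minimal sketch of how one of these files is consumed, mirroring the call sequence in `tools/demo.py` below (the printed values assume `configs/cityscapes_ccnet_resnet.yaml`; running from the repository root is an assumption):

```python
# Minimal sketch: load a config file the way tools/demo.py does.
# Assumes the working directory is the repository root.
import os

from segmentron.config import cfg

cfg.update_from_file('configs/cityscapes_ccnet_resnet.yaml')  # YAML overrides defaults
cfg.PHASE = 'test'
cfg.ROOT_PATH = os.getcwd()        # demo.py points this at the repo root
cfg.check_and_freeze()             # freeze cfg so later key typos fail loudly

print(cfg.MODEL.MODEL_NAME)        # CCNet
print(cfg.MODEL.BACKBONE)          # resnet101
print(cfg.SOLVER.LR)               # 0.003
```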
-------------------------------------------------------------------------------- /configs/cityscapes_dense_aspp.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | 12 | SOLVER: 13 | LR: 0.003 14 | 15 | MODEL: 16 | MODEL_NAME: "DenseASPP" 17 | BACKBONE: "resnet101" 18 | OUTPUT_STRIDE: 16 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_dfanet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1024, 2048) 12 | 13 | SOLVER: 14 | LR: 0.02 15 | 16 | MODEL: 17 | MODEL_NAME: "DFANet" 18 | BACKBONE: "xception_a" 19 | 20 | -------------------------------------------------------------------------------- /configs/coco_deeplabv3_plus.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "coco" 3 | MEAN: [0.5, 0.5, 0.5] 4 | STD: [0.5, 0.5, 0.5] 5 | MODE: 'val' 6 | TRAIN: 7 | EPOCHS: 30 8 | BATCH_SIZE: 8 9 | CROP_SIZE: 480 10 | BASE_SIZE: 520 11 | TEST: 12 | BATCH_SIZE: 8 13 | 14 | SOLVER: 15 | LR: 0.01 16 | 17 | MODEL: 18 | MODEL_NAME: "DeepLabV3_Plus" 19 | BACKBONE: "xception65" 20 | BN_EPS_FOR_ENCODER: 1e-3 21 | -------------------------------------------------------------------------------- /configs/ade20k_deeplabv3_plus.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ade20k" 3 | MEAN: [0.5, 0.5, 0.5] 4 | STD: [0.5, 0.5, 0.5] 5 | MODE: 'val' 6 | TRAIN: 7 | EPOCHS: 120 8 | BATCH_SIZE: 8 9 | CROP_SIZE: 480 10 | BASE_SIZE: 520 11 | TEST: 12 | BATCH_SIZE: 8 13 | 14 | SOLVER: 15 | LR: 0.01 16 | 17 | MODEL: 18 | MODEL_NAME: "DeepLabV3_Plus" 19 | BACKBONE: "xception65" 20 | BN_EPS_FOR_ENCODER: 1e-3 21 | -------------------------------------------------------------------------------- /configs/cityscapes_enet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1024, 2048) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.001 16 | 17 | MODEL: 18 | MODEL_NAME: "ENet" 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_unet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 2 8 | CROP_SIZE: 512 9 | TEST: 10 | BATCH_SIZE: 2 11 | CROP_SIZE: (1024, 2048) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.001 16 | 17 | MODEL: 18 | MODEL_NAME: "UNet" 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_fpenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | 
MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1024, 2048) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.001 16 | 17 | MODEL: 18 | MODEL_NAME: "FPENet" 19 | 20 | -------------------------------------------------------------------------------- /configs/cityscapes_icnet_resnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | BACKBONE_PRETRAINED: False 10 | TEST: 11 | BATCH_SIZE: 4 12 | 13 | SOLVER: 14 | LR: 0.01 15 | 16 | MODEL: 17 | MODEL_NAME: "ICNet" 18 | BACKBONE: "resnet50" 19 | OUTPUT_STRIDE: 8 20 | BACKBONE_SCALE: 0.5 21 | -------------------------------------------------------------------------------- /configs/cityscapes_contextnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1024, 2048) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.045 16 | 17 | MODEL: 18 | MODEL_NAME: "ContextNet" 19 | 20 | -------------------------------------------------------------------------------- /segmentron/modules/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "criss_cross_attention/ca.h" 3 | 4 | namespace segmentron { 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("ca_forward", &ca_forward, "ca_forward"); 8 | m.def("ca_backward", &ca_backward, "ca_backward"); 9 | m.def("ca_map_forward", &ca_map_forward, "ca_map_forward"); 10 | m.def("ca_map_backward", &ca_map_backward, "ca_map_backward"); 11 | } 12 | 13 | } // namespace segmentron 14 | -------------------------------------------------------------------------------- /configs/cityscapes_refinenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 769 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1025, 2049) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.001 16 | 17 | MODEL: 18 | MODEL_NAME: "RefineNet" 19 | BACKBONE: "resnet101" 20 | 21 | -------------------------------------------------------------------------------- /configs/cityscapes_dabnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: (512, 1024) 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1024, 2048) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.001 16 | 17 | MODEL: 18 | MODEL_NAME: "DABNet" 19 | BACKBONE: "resnet101" 20 | 21 | -------------------------------------------------------------------------------- /configs/cityscapes_espnetv2.yaml: -------------------------------------------------------------------------------- 1 | 
DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: (512, 1024) 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1024, 2048) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.001 16 | 17 | MODEL: 18 | MODEL_NAME: "ESPNetV2" 19 | BACKBONE: "eespnet" 20 | 21 | -------------------------------------------------------------------------------- /configs/cityscapes_deeplabv3_plus_resnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 769 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1025, 2049) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.02 16 | 17 | MODEL: 18 | MODEL_NAME: "DeepLabV3_Plus" 19 | BACKBONE: "resnet101" 20 | 21 | -------------------------------------------------------------------------------- /configs/cityscapes_deeplabv3_plus.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.5, 0.5, 0.5] 4 | STD: [0.5, 0.5, 0.5] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 769 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1025, 2049) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.02 16 | 17 | MODEL: 18 | MODEL_NAME: "DeepLabV3_Plus" 19 | BACKBONE: "xception65" 20 | BN_EPS_FOR_ENCODER: 1e-3 21 | -------------------------------------------------------------------------------- /configs/cityscapes_danet_resnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 240 7 | BATCH_SIZE: 2 8 | CROP_SIZE: 768 9 | TEST: 10 | BATCH_SIZE: 1 11 | # TEST_MODEL_PATH: trained_models/danet101_segmentron.pth 12 | 13 | SOLVER: 14 | LR: 0.003 15 | 16 | MODEL: 17 | MODEL_NAME: "DANet" 18 | BACKBONE: "resnet101" 19 | OUTPUT_STRIDE: 8 20 | DANET: 21 | MULTI_GRID: True 22 | MULTI_DILATION: [4, 8, 16] 23 | -------------------------------------------------------------------------------- /configs/cityscapes_deeplabv3_plus_mobilenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 400 7 | BATCH_SIZE: 8 8 | CROP_SIZE: 769 9 | 10 | TEST: 11 | BATCH_SIZE: 8 12 | CROP_SIZE: (1025, 2049) 13 | 14 | SOLVER: 15 | LR: 0.02 16 | 17 | MODEL: 18 | MODEL_NAME: "DeepLabV3_Plus" 19 | BACKBONE: "mobilenet_v2" 20 | OUTPUT_STRIDE: 16 21 | DEEPLABV3_PLUS: 22 | USE_ASPP: False 23 | ENABLE_DECODER: False 24 | 25 | -------------------------------------------------------------------------------- /configs/pascal_aug_deeplabv3_plus.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "pascal_aug" 3 | MEAN: [0.5, 0.5, 0.5] 4 | STD: [0.5, 0.5, 0.5] 5 | MODE: 'val' 6 | TRAIN: 7 | EPOCHS: 50 8 | BATCH_SIZE: 8 9 | CROP_SIZE: 480 10 | BASE_SIZE: 520 11 | # PRETRAINED_MODEL_PATH: ./runs/checkpoints/DeepLabV3_Plus_xception65_coco_2019-11-25-13-09/best_model.pth 12 | TEST: 13 | BATCH_SIZE: 8 14 | 15 | SOLVER: 16 | LR: 0.001 17 | 18 
| MODEL: 19 | MODEL_NAME: "DeepLabV3_Plus" 20 | BACKBONE: "xception65" 21 | BN_EPS_FOR_ENCODER: 1e-3 22 | -------------------------------------------------------------------------------- /configs/cityscapes_pointrend_deeplabv3_plus.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.5, 0.5, 0.5] 4 | STD: [0.5, 0.5, 0.5] 5 | TRAIN: 6 | EPOCHS: 200 7 | BATCH_SIZE: 2 8 | CROP_SIZE: 769 9 | TEST: 10 | BATCH_SIZE: 2 11 | CROP_SIZE: (1025, 2049) 12 | # TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.02 16 | 17 | MODEL: 18 | MODEL_NAME: "PointRend" 19 | BACKBONE: "xception65" 20 | BN_EPS_FOR_ENCODER: 1e-3 21 | DEEPLABV3_PLUS: 22 | ENABLE_DECODER: False 23 | -------------------------------------------------------------------------------- /configs/pascal_voc_deeplabv3_plus.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "pascal_voc" 3 | MEAN: [0.5, 0.5, 0.5] 4 | STD: [0.5, 0.5, 0.5] 5 | MODE: 'val' 6 | TRAIN: 7 | EPOCHS: 50 8 | BATCH_SIZE: 8 9 | CROP_SIZE: 480 10 | BASE_SIZE: 520 11 | # PRETRAINED_MODEL_PATH: "./runs/checkpoints/DeepLabV3_Plus_xception65_pascal_aug_2019-11-28-03-07/best_model.pth" 12 | 13 | TEST: 14 | BATCH_SIZE: 8 15 | 16 | SOLVER: 17 | LR: 0.0001 18 | 19 | MODEL: 20 | MODEL_NAME: "DeepLabV3_Plus" 21 | BACKBONE: "xception65" 22 | BN_EPS_FOR_ENCODER: 1e-3 23 | -------------------------------------------------------------------------------- /configs/cityscapes_pspnet_resnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 200 7 | BATCH_SIZE: 4 8 | CROP_SIZE: 713 9 | TEST: 10 | BATCH_SIZE: 4 11 | CROP_SIZE: (1025, 2049) 12 | # TEST_MODEL_PATH: trained_models/pspnet_resnet101_segmentron.pth 13 | 14 | SOLVER: 15 | LR: 0.01 16 | AUX: True 17 | AUX_WEIGHT: 0.4 18 | 19 | AUG: 20 | BLUR_PROB: 0.5 21 | 22 | MODEL: 23 | MODEL_NAME: "PSPNet" 24 | BACKBONE: "resnet101" 25 | OUTPUT_STRIDE: 8 26 | -------------------------------------------------------------------------------- /configs/cityscapes_fast_scnn.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | TRAIN: 6 | EPOCHS: 1000 7 | BATCH_SIZE: 12 8 | CROP_SIZE: (512, 1024) 9 | TEST: 10 | BATCH_SIZE: 4 11 | TEST_MODEL_PATH: 'runs/checkpoints/fast_scnn__cityscape_2019-11-19-02-02/best_model.pth' 12 | 13 | SOLVER: 14 | LR: 0.045 15 | DECODER_LR_FACTOR: 1.0 16 | WEIGHT_DECAY: 4e-5 17 | AUX: True 18 | AUX_WEIGHT: 0.4 19 | 20 | AUG: 21 | COLOR_JITTER: 0.4 22 | 23 | MODEL: 24 | MODEL_NAME: "FastSCNN" 25 | BN_MOMENTUM: 0.01 26 | -------------------------------------------------------------------------------- /configs/coco_deeplabv3_plus_mobilenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "coco" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | MODE: 'val' 6 | TRAIN: 7 | EPOCHS: 30 8 | BATCH_SIZE: 8 9 | CROP_SIZE: 480 10 | BASE_SIZE: 520 11 | BACKBONE_PRETRAINED_PATH: "/workspace/pretrained_models/mobilenet-convert-from-torchvision.pth" 12 | TEST: 13 | BATCH_SIZE: 8 14 | 15 | SOLVER: 16 | LR: 0.01 17 | 18 | MODEL: 19 | MODEL_NAME: "DeepLabV3_Plus" 20 | BACKBONE: "mobilenet_v2" 21 | 
OUTPUT_STRIDE: 16 22 | DEEPLABV3_PLUS: 23 | USE_ASPP: False 24 | ENABLE_DECODER: False 25 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides data loaders and transformers for popular vision datasets. 3 | """ 4 | from .mscoco import COCOSegmentation 5 | from .cityscapes import CitySegmentation 6 | from .ade import ADE20KSegmentation 7 | from .pascal_voc import VOCSegmentation 8 | from .pascal_aug import VOCAugSegmentation 9 | from .sbu_shadow import SBUSegmentation 10 | 11 | datasets = { 12 | 'ade20k': ADE20KSegmentation, 13 | 'pascal_voc': VOCSegmentation, 14 | 'pascal_aug': VOCAugSegmentation, 15 | 'coco': COCOSegmentation, 16 | 'cityscape': CitySegmentation, 17 | 'sbu': SBUSegmentation, 18 | } 19 | 20 | 21 | def get_segmentation_dataset(name, **kwargs): 22 | """Segmentation Datasets""" 23 | return datasets[name.lower()](**kwargs) 24 | -------------------------------------------------------------------------------- /segmentron/utils/env.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | from datetime import datetime 8 | import torch 9 | 10 | __all__ = ["seed_all_rng"] 11 | 12 | 13 | def seed_all_rng(seed=None): 14 | """ 15 | Set the random seed for the RNG in torch, numpy and python. 16 | 17 | Args: 18 | seed (int): if None, will use a strong random seed. 19 | """ 20 | if seed is None: 21 | seed = ( 22 | os.getpid() 23 | + int(datetime.now().strftime("%S%f")) 24 | + int.from_bytes(os.urandom(2), "big") 25 | ) 26 | logger = logging.getLogger(__name__) 27 | logger.info("Using a generated random seed {}".format(seed)) 28 | np.random.seed(seed) 29 | torch.set_rng_state(torch.manual_seed(seed).get_state()) 30 | random.seed(seed) 31 | -------------------------------------------------------------------------------- /segmentron/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Model Zoo""" 2 | from .model_zoo import MODEL_REGISTRY 3 | from .fast_scnn import FastSCNN 4 | from .deeplabv3_plus import DeepLabV3Plus 5 | from .hrnet_seg import HighResolutionNet 6 | from .fcn import FCN 7 | from .dfanet import DFANet 8 | from .pspnet import PSPNet 9 | from .icnet import ICNet 10 | from .danet import DANet 11 | # from .ccnet import CCNet 12 | from .bisenet import BiSeNet 13 | from .cgnet import CGNet 14 | from .denseaspp import DenseASPP 15 | from .dunet import DUNet 16 | from .encnet import EncNet 17 | from .lednet import LEDNet 18 | from .ocnet import OCNet 19 | from .hardnet import HardNet 20 | from .refinenet import RefineNet 21 | from .dabnet import DABNet 22 | from .unet import UNet 23 | from .fpenet import FPENet 24 | from .contextnet import ContextNet 25 | from .espnetv2 import ESPNetV2 26 | from .enet import ENet 27 | from .edanet import EDANet 28 | from .pointrend import PointRend 29 | -------------------------------------------------------------------------------- /segmentron/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | __all__ = ['setup_logger'] 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 9 | if distributed_rank > 0: 10 | return 11 | 12 | 
logging.root.name = name 13 | logging.root.setLevel(logging.INFO) 14 | # don't log results for the non-master process 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logging.root.addHandler(ch) 20 | 21 | if save_dir: 22 | if not os.path.exists(save_dir): 23 | os.makedirs(save_dir) 24 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' for add, 'w' for overwrite 25 | fh.setLevel(logging.DEBUG) 26 | fh.setFormatter(formatter) 27 | logging.root.addHandler(fh) 28 | -------------------------------------------------------------------------------- /segmentron/models/fcn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | 9 | __all__ = ['FCN'] 10 | 11 | 12 | @MODEL_REGISTRY.register() 13 | class FCN(SegBaseModel): 14 | def __init__(self): 15 | super(FCN, self).__init__() 16 | self.head = _FCNHead(2048, self.nclass) 17 | if self.aux: 18 | self.auxlayer = _FCNHead(1024, self.nclass) 19 | 20 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 21 | 22 | def forward(self, x): 23 | size = x.size()[2:] 24 | _, _, c3, c4 = self.base_forward(x) 25 | 26 | outputs = [] 27 | x = self.head(c4) 28 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 29 | outputs.append(x) 30 | if self.aux: 31 | auxout = self.auxlayer(c3) 32 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 33 | outputs.append(auxout) 34 | return tuple(outputs) 35 | -------------------------------------------------------------------------------- /segmentron/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='Segmentron') 5 | parser.add_argument('--config-file', metavar="FILE", 6 | help='config file path') 7 | # cuda setting 8 | parser.add_argument('--no-cuda', action='store_true', default=False, 9 | help='disables CUDA training') 10 | parser.add_argument('--local_rank', type=int, default=0) 11 | # checkpoint and log 12 | parser.add_argument('--resume', type=str, default=None, 13 | help='put the path to resuming file if needed') 14 | parser.add_argument('--log-iter', type=int, default=10, 15 | help='print log every log-iter') 16 | # for evaluation 17 | parser.add_argument('--val-epoch', type=int, default=1, 18 | help='run validation every val-epoch') 19 | parser.add_argument('--skip-val', action='store_true', default=False, 20 | help='skip validation during training') 21 | # for visual 22 | parser.add_argument('--input-img', type=str, default='tools/demo_vis.png', 23 | help='path to the input image or a directory of images') 24 | # config options 25 | parser.add_argument('opts', help='See config for all options', 26 | default=None, nargs=argparse.REMAINDER) 27 | args = parser.parse_args() 28 | 29 | return args -------------------------------------------------------------------------------- /segmentron/utils/default_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | import torch 5 | 6 | from .distributed import get_rank, synchronize 7 | from 
.logger import setup_logger 8 | from .env import seed_all_rng 9 | from ..config import cfg 10 | 11 | def default_setup(args): 12 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 13 | args.num_gpus = num_gpus 14 | args.distributed = num_gpus > 1 15 | 16 | if not args.no_cuda and torch.cuda.is_available(): 17 | # cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = True 19 | args.device = "cuda" 20 | else: 21 | args.distributed = False 22 | args.device = "cpu" 23 | if args.distributed: 24 | torch.cuda.set_device(args.local_rank) 25 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 26 | synchronize() 27 | 28 | # TODO 29 | # if args.save_pred: 30 | # outdir = '../runs/pred_pic/{}_{}_{}'.format(args.model, args.backbone, args.dataset) 31 | # if not os.path.exists(outdir): 32 | # os.makedirs(outdir) 33 | 34 | save_dir = cfg.TRAIN.LOG_SAVE_DIR if cfg.PHASE == 'train' else None 35 | setup_logger("Segmentron", save_dir, get_rank(), filename='{}_{}_{}_{}_log.txt'.format( 36 | cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP)) 37 | 38 | logging.info("Using {} GPUs".format(num_gpus)) 39 | logging.info(args) 40 | logging.info(json.dumps(cfg, indent=8)) 41 | 42 | seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + get_rank()) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # mac os 107 | __MACOSX/ 108 | -------------------------------------------------------------------------------- /segmentron/models/model_zoo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from collections import OrderedDict 5 | from segmentron.utils.registry import Registry 6 | from ..config import cfg 7 | 8 | MODEL_REGISTRY = Registry("MODEL") 9 | MODEL_REGISTRY.__doc__ = """ 10 | Registry for segmentation models, i.e. the whole model. 11 | 12 | The registered object will be called with `obj()` 13 | and expected to return a `nn.Module` object. 14 | """ 15 | 16 | 17 | def get_segmentation_model(): 18 | """ 19 | Build the whole model, as named by `cfg.MODEL.MODEL_NAME`.
20 | """ 21 | model_name = cfg.MODEL.MODEL_NAME 22 | model = MODEL_REGISTRY.get(model_name)() 23 | load_model_pretrain(model) 24 | return model 25 | 26 | 27 | def load_model_pretrain(model): 28 | if cfg.PHASE == 'train': 29 | if cfg.TRAIN.PRETRAINED_MODEL_PATH: 30 | logging.info('load pretrained model from {}'.format(cfg.TRAIN.PRETRAINED_MODEL_PATH)) 31 | state_dict_to_load = torch.load(cfg.TRAIN.PRETRAINED_MODEL_PATH) 32 | keys_wrong_shape = [] 33 | state_dict_suitable = OrderedDict() 34 | state_dict = model.state_dict() 35 | for k, v in state_dict_to_load.items(): 36 | if v.shape == state_dict[k].shape: 37 | state_dict_suitable[k] = v 38 | else: 39 | keys_wrong_shape.append(k) 40 | logging.info('Shape unmatched weights: {}'.format(keys_wrong_shape)) 41 | msg = model.load_state_dict(state_dict_suitable, strict=False) 42 | logging.info(msg) 43 | else: 44 | if cfg.TEST.TEST_MODEL_PATH: 45 | logging.info('load test model from {}'.format(cfg.TEST.TEST_MODEL_PATH)) 46 | msg = model.load_state_dict(torch.load(cfg.TEST.TEST_MODEL_PATH), strict=False) 47 | logging.info(msg) -------------------------------------------------------------------------------- /configs/cityscapes_hrnet_w18_small_v1.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | 6 | TRAIN: 7 | EPOCHS: 484 8 | CROP_SIZE: (512, 1024) 9 | BATCH_SIZE: 3 10 | 11 | TEST: 12 | BATCH_SIZE: 4 13 | # TEST_MODEL_PATH: trained_models/hrnet_w18_small_v1_segmentron.pth 14 | 15 | SOLVER: 16 | LR: 0.01 17 | WEIGHT_DECAY: 5e-4 18 | DECODER_LR_FACTOR: 1.0 19 | 20 | MODEL: 21 | MODEL_NAME: "HRNet" 22 | BACKBONE: "hrnet_w18_small_v1" 23 | BN_TYPE: 'BN' 24 | BN_MOMENTUM: 0.01 25 | HRNET: 26 | FINAL_CONV_KERNEL: 1 27 | STAGE1: 28 | NUM_MODULES: 1 29 | NUM_BRANCHES: 1 30 | BLOCK: BOTTLENECK 31 | NUM_BLOCKS: 32 | - 1 33 | NUM_CHANNELS: 34 | - 32 35 | FUSE_METHOD: SUM 36 | STAGE2: 37 | NUM_MODULES: 1 38 | NUM_BRANCHES: 2 39 | BLOCK: BASIC 40 | NUM_BLOCKS: 41 | - 2 42 | - 2 43 | NUM_CHANNELS: 44 | - 16 45 | - 32 46 | FUSE_METHOD: SUM 47 | STAGE3: 48 | NUM_MODULES: 1 49 | NUM_BRANCHES: 3 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 2 53 | - 2 54 | - 2 55 | NUM_CHANNELS: 56 | - 16 57 | - 32 58 | - 64 59 | FUSE_METHOD: SUM 60 | STAGE4: 61 | NUM_MODULES: 1 62 | NUM_BRANCHES: 4 63 | BLOCK: BASIC 64 | NUM_BLOCKS: 65 | - 2 66 | - 2 67 | - 2 68 | - 2 69 | NUM_CHANNELS: 70 | - 16 71 | - 32 72 | - 64 73 | - 128 74 | FUSE_METHOD: SUM 75 | 76 | -------------------------------------------------------------------------------- /configs/cityscapes_hrnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | 6 | TRAIN: 7 | EPOCHS: 400 8 | CROP_SIZE: (512, 1024) 9 | BATCH_SIZE: 4 10 | BACKBONE_PRETRAINED_PATH: /workspace/pretrained_models/hrnet_w18_small_model_v1.pth 11 | 12 | TEST: 13 | BATCH_SIZE: 4 14 | CROP_SIZE: (1024, 2048) 15 | TEST_MODEL_PATH: ../trained_models/hrnet_w18_small_v1_cityscapes_cls19_1024x2048_trainset_segmentron.pth 16 | 17 | SOLVER: 18 | LR: 0.01 19 | 20 | MODEL: 21 | MODEL_NAME: "hrnet" 22 | BACKBONE: "hrnet" 23 | BN_TYPE: 'SyncBN' 24 | HRNET: 25 | FINAL_CONV_KERNEL: 1 26 | STAGE1: 27 | NUM_MODULES: 1 28 | NUM_BRANCHES: 1 29 | BLOCK: BOTTLENECK 30 | NUM_BLOCKS: 31 | - 1 32 | NUM_CHANNELS: 33 | - 32 34 | FUSE_METHOD: SUM 35 | STAGE2: 36 | NUM_MODULES: 1 37 | NUM_BRANCHES: 2 38 | BLOCK: 
BASIC 39 | NUM_BLOCKS: 40 | - 2 41 | - 2 42 | NUM_CHANNELS: 43 | - 16 44 | - 32 45 | FUSE_METHOD: SUM 46 | STAGE3: 47 | NUM_MODULES: 1 48 | NUM_BRANCHES: 3 49 | BLOCK: BASIC 50 | NUM_BLOCKS: 51 | - 2 52 | - 2 53 | - 2 54 | NUM_CHANNELS: 55 | - 16 56 | - 32 57 | - 64 58 | FUSE_METHOD: SUM 59 | STAGE4: 60 | NUM_MODULES: 1 61 | NUM_BRANCHES: 4 62 | BLOCK: BASIC 63 | NUM_BLOCKS: 64 | - 2 65 | - 2 66 | - 2 67 | - 2 68 | NUM_CHANNELS: 69 | - 16 70 | - 32 71 | - 64 72 | - 128 73 | FUSE_METHOD: SUM 74 | 75 | -------------------------------------------------------------------------------- /docs/DATA_PREPARE.md: -------------------------------------------------------------------------------- 1 | ## data prepare 2 | 3 | It is recommended to symlink the dataset root to `$SEGMENTRON/datasets`. 4 | 5 | 6 | ``` 7 | SegmenTron 8 | |-- configs 9 | |-- datasets 10 | | |-- ade 11 | | | |-- ADEChallengeData2016 12 | | | | |-- annotations 13 | | | | `-- images 14 | | | |-- downloads 15 | | | `-- release_test 16 | | | `-- testing 17 | | |-- cityscapes 18 | | | |-- gtFine 19 | | | | |-- test 20 | | | | |-- train 21 | | | | `-- val 22 | | | `-- leftImg8bit 23 | | | |-- test 24 | | | |-- train 25 | | | `-- val 26 | | |-- coco 27 | | | |-- annotations 28 | | | |-- train2017 29 | | | `-- val2017 30 | | `-- voc 31 | | |-- VOC2007 32 | | | |-- Annotations 33 | | | |-- ImageSets 34 | | | |-- JPEGImages 35 | | | |-- SegmentationClass 36 | | | `-- SegmentationObject 37 | | |-- VOC2012 38 | | | |-- Annotations 39 | | | |-- ImageSets 40 | | | |-- JPEGImages 41 | | | |-- SegmentationClass 42 | | | `-- SegmentationObject 43 | | `-- VOCaug 44 | | |-- benchmark_code_RELEASE 45 | | `-- dataset 46 | |-- docs 47 | |-- segmentron 48 | |-- tools 49 | 50 | ``` 51 | 52 | ### cityscape 53 | Go to [Cityscape](https://www.cityscapes-dataset.com), register an account, and download the datasets. 54 | 55 | ### coco 56 | 57 | Run the following command, and it will automatically symlink ```your-download-dir``` to ```datasets/coco``` 58 | ``` 59 | python segmentron/data/downloader/mscoco.py --download-dir your-download-dir 60 | ``` 61 | 62 | ### pascal aug & voc 63 | Run the following command, and it will automatically symlink ```your-download-dir``` to ```datasets/voc``` 64 | ``` 65 | python segmentron/data/downloader/pascal_voc.py --download-dir your-download-dir 66 | ``` 67 | 68 | ### ade20k 69 | Run the following command, and it will automatically symlink ```your-download-dir``` to ```datasets/ade``` 70 | ``` 71 | python segmentron/data/downloader/ade20k.py --download-dir your-download-dir 72 | ``` -------------------------------------------------------------------------------- /segmentron/models/pspnet.py: -------------------------------------------------------------------------------- 1 | """Pyramid Scene Parsing Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .model_zoo import MODEL_REGISTRY 8 | from ..modules import _FCNHead, PyramidPooling 9 | 10 | __all__ = ['PSPNet'] 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class PSPNet(SegBaseModel): 15 | r"""Pyramid Scene Parsing Network 16 | Reference: 17 | Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. 18 | "Pyramid scene parsing network."
*CVPR*, 2017 19 | """ 20 | 21 | def __init__(self): 22 | super(PSPNet, self).__init__() 23 | self.head = _PSPHead(self.nclass) 24 | if self.aux: 25 | self.auxlayer = _FCNHead(1024, self.nclass) 26 | 27 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 28 | 29 | def forward(self, x): 30 | size = x.size()[2:] 31 | _, _, c3, c4 = self.encoder(x) 32 | outputs = [] 33 | x = self.head(c4) 34 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 35 | outputs.append(x) 36 | 37 | if self.aux: 38 | auxout = self.auxlayer(c3) 39 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 40 | outputs.append(auxout) 41 | return tuple(outputs) 42 | 43 | 44 | class _PSPHead(nn.Module): 45 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 46 | super(_PSPHead, self).__init__() 47 | self.psp = PyramidPooling(2048, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 48 | self.block = nn.Sequential( 49 | nn.Conv2d(4096, 512, 3, padding=1, bias=False), 50 | norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)), 51 | nn.ReLU(True), 52 | nn.Dropout(0.1), 53 | nn.Conv2d(512, nclass, 1) 54 | ) 55 | 56 | def forward(self, x): 57 | x = self.psp(x) 58 | return self.block(x) 59 | 60 | -------------------------------------------------------------------------------- /tools/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | cur_path = os.path.abspath(os.path.dirname(__file__)) 6 | root_path = os.path.split(cur_path)[0] 7 | sys.path.append(root_path) 8 | 9 | from torchvision import transforms 10 | from PIL import Image 11 | from segmentron.utils.visualize import get_color_pallete 12 | from segmentron.models.model_zoo import get_segmentation_model 13 | from segmentron.utils.options import parse_args 14 | from segmentron.utils.default_setup import default_setup 15 | from segmentron.config import cfg 16 | 17 | 18 | def demo(): 19 | args = parse_args() 20 | cfg.update_from_file(args.config_file) 21 | cfg.PHASE = 'test' 22 | cfg.ROOT_PATH = root_path 23 | cfg.check_and_freeze() 24 | default_setup(args) 25 | 26 | # output folder 27 | output_dir = os.path.join(cfg.VISUAL.OUTPUT_DIR, 'vis_result_{}_{}_{}_{}'.format( 28 | cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP)) 29 | if not os.path.exists(output_dir): 30 | os.makedirs(output_dir) 31 | 32 | # image transform 33 | transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD), 36 | ]) 37 | 38 | model = get_segmentation_model().to(args.device) 39 | model.eval() 40 | 41 | if os.path.isdir(args.input_img): 42 | img_paths = [os.path.join(args.input_img, x) for x in os.listdir(args.input_img)] 43 | else: 44 | img_paths = [args.input_img] 45 | for img_path in img_paths: 46 | image = Image.open(img_path).convert('RGB') 47 | images = transform(image).unsqueeze(0).to(args.device) 48 | with torch.no_grad(): 49 | output = model(images) 50 | 51 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy() 52 | mask = get_color_pallete(pred, cfg.DATASET.NAME) 53 | outname = os.path.splitext(os.path.split(img_path)[-1])[0] + '.png' 54 | mask.save(os.path.join(output_dir, outname)) 55 | 56 | 57 | if __name__ == '__main__': 58 | demo() 59 | -------------------------------------------------------------------------------- /segmentron/modules/cc_attention.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd.function import once_differentiable 6 | from segmentron import _C 7 | 8 | __all__ = ['CrissCrossAttention', 'ca_weight', 'ca_map'] 9 | 10 | 11 | class _CAWeight(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, t, f): 14 | weight = _C.ca_forward(t, f) 15 | 16 | ctx.save_for_backward(t, f) 17 | 18 | return weight 19 | 20 | @staticmethod 21 | @once_differentiable 22 | def backward(ctx, dw): 23 | t, f = ctx.saved_tensors 24 | 25 | dt, df = _C.ca_backward(dw, t, f) 26 | return dt, df 27 | 28 | 29 | class _CAMap(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, weight, g): 32 | out = _C.ca_map_forward(weight, g) 33 | 34 | ctx.save_for_backward(weight, g) 35 | 36 | return out 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, dout): 41 | weight, g = ctx.saved_tensors 42 | 43 | dw, dg = _C.ca_map_backward(dout, weight, g) 44 | 45 | return dw, dg 46 | 47 | 48 | ca_weight = _CAWeight.apply 49 | ca_map = _CAMap.apply 50 | 51 | 52 | class CrissCrossAttention(nn.Module): 53 | """Criss-Cross Attention Module""" 54 | 55 | def __init__(self, in_channels): 56 | super(CrissCrossAttention, self).__init__() 57 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 58 | self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 59 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1) 60 | self.gamma = nn.Parameter(torch.zeros(1)) 61 | 62 | def forward(self, x): 63 | proj_query = self.query_conv(x) 64 | proj_key = self.key_conv(x) 65 | proj_value = self.value_conv(x) 66 | 67 | energy = ca_weight(proj_query, proj_key) 68 | attention = F.softmax(energy, 1) 69 | out = ca_map(attention, proj_value) 70 | out = self.gamma * out + x 71 | 72 | return out 73 | -------------------------------------------------------------------------------- /segmentron/data/downloader/ade20k.py: -------------------------------------------------------------------------------- 1 | """Prepare ADE20K dataset""" 2 | import os 3 | import sys 4 | 5 | cur_path = os.path.abspath(os.path.dirname(__file__)) 6 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 7 | sys.path.append(root_path) 8 | 9 | import argparse 10 | import zipfile 11 | from segmentron.utils import download, makedirs 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Initialize ADE20K dataset.', 17 | epilog='Example: python setup_ade20k.py', 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 20 | args = parser.parse_args() 21 | return args 22 | 23 | def download_ade(path, overwrite=False): 24 | _AUG_DOWNLOAD_URLS = [ 25 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', 26 | '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'), 27 | ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 28 | 'e05747892219d10e9243933371a497e905a4860c'), 29 | ] 30 | download_dir = os.path.join(path, 'downloads') 31 | makedirs(download_dir) 32 | for url, checksum in _AUG_DOWNLOAD_URLS: 33 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum) 34 | # extract 35 | with zipfile.ZipFile(filename,"r") as zip_ref: 36 | zip_ref.extractall(path=path) 37 | 38 | 39 | if __name__ == '__main__': 40 | args = parse_args() 41 | 
default_dir = os.path.join(root_path, 'datasets/ade') 42 | if args.download_dir is not None: 43 | _TARGET_DIR = args.download_dir 44 | else: 45 | _TARGET_DIR = default_dir 46 | makedirs(_TARGET_DIR) 47 | 48 | if os.path.exists(default_dir): 49 | print('{} already exists!'.format(default_dir)) 50 | else: 51 | try: 52 | os.symlink(_TARGET_DIR, default_dir) 53 | except Exception as e: 54 | print(e) 55 | download_ade(_TARGET_DIR, overwrite=False) 56 | -------------------------------------------------------------------------------- /segmentron/data/downloader/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | cur_path = os.path.abspath(os.path.dirname(__file__)) 8 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 9 | sys.path.append(root_path) 10 | 11 | from segmentron.utils import makedirs, check_sha1 12 | 13 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/citys') 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser( 18 | description='Initialize Cityscapes dataset.', 19 | epilog='Example: python prepare_cityscapes.py', 20 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def download_city(path, overwrite=False): 27 | _CITY_DOWNLOAD_URLS = [ 28 | ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'), 29 | ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')] 30 | download_dir = os.path.join(path, 'downloads') 31 | makedirs(download_dir) 32 | for filename, checksum in _CITY_DOWNLOAD_URLS: 33 | if not check_sha1(os.path.join(download_dir, filename), checksum): 34 | raise UserWarning('File {} is missing or its content hash does not match. ' \ 35 | 'Cityscapes requires registration, so it cannot be downloaded automatically; ' \ 36 | 'please fetch it from https://www.cityscapes-dataset.com ' \ 37 | 'and place it under {}.'.format(filename, download_dir))
38 | # extract 39 | with zipfile.ZipFile(os.path.join(download_dir, filename), "r") as zip_ref: 40 | zip_ref.extractall(path=path) 41 | print("Extracted", filename) 42 | 43 | 44 | if __name__ == '__main__': 45 | args = parse_args() 46 | makedirs(os.path.expanduser('~/.torch/datasets')) 47 | if args.download_dir is not None: 48 | if os.path.isdir(_TARGET_DIR): 49 | os.remove(_TARGET_DIR) 50 | # make symlink 51 | os.symlink(args.download_dir, _TARGET_DIR) 52 | else: 53 | download_city(_TARGET_DIR, overwrite=False) 54 | -------------------------------------------------------------------------------- /segmentron/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """Filesystem utility functions.""" 2 | from __future__ import absolute_import 3 | import os 4 | import errno 5 | import torch 6 | import logging 7 | 8 | from ..config import cfg 9 | 10 | def save_checkpoint(model, epoch, optimizer=None, lr_scheduler=None, is_best=False): 11 | """Save Checkpoint""" 12 | directory = os.path.expanduser(cfg.TRAIN.MODEL_SAVE_DIR) 13 | directory = os.path.join(directory, '{}_{}_{}_{}'.format(cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, 14 | cfg.DATASET.NAME, cfg.TIME_STAMP)) 15 | if not os.path.exists(directory): 16 | os.makedirs(directory) 17 | filename = '{}.pth'.format(str(epoch)) 18 | filename = os.path.join(directory, filename) 19 | model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict() 20 | if is_best: 21 | best_filename = 'best_model.pth' 22 | best_filename = os.path.join(directory, best_filename) 23 | torch.save(model_state_dict, best_filename) 24 | else: 25 | save_state = { 26 | 'epoch': epoch, 27 | 'state_dict': model_state_dict, 28 | 'optimizer': optimizer.state_dict(), 29 | 'lr_scheduler': lr_scheduler.state_dict() 30 | } 31 | if not os.path.exists(filename): 32 | torch.save(save_state, filename) 33 | logging.info('Epoch {} model saved in: {}'.format(epoch, filename)) 34 | 35 | # remove last epoch 36 | pre_filename = '{}.pth'.format(str(epoch - 1)) 37 | pre_filename = os.path.join(directory, pre_filename) 38 | try: 39 | if os.path.exists(pre_filename): 40 | os.remove(pre_filename) 41 | except OSError as e: 42 | logging.info(e) 43 | 44 | def makedirs(path): 45 | """Create a directory recursively if it does not exist. 46 | Similar to `mkdir -p`, you can skip checking existence before this function. 47 | Parameters 48 | ---------- 49 | path : str 50 | Path of the desired dir 51 | """ 52 | try: 53 | os.makedirs(path) 54 | except OSError as exc: 55 | if exc.errno != errno.EEXIST: 56 | raise 57 | 58 | -------------------------------------------------------------------------------- /segmentron/utils/registry.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | 3 | import logging 4 | import torch 5 | 6 | from ..config import cfg 7 | 8 | class Registry(object): 9 | """ 10 | The registry that provides name -> object mapping, to support third-party users' custom modules. 11 | 12 | To create a registry (inside segmentron): 13 | 14 | .. code-block:: python 15 | 16 | BACKBONE_REGISTRY = Registry('BACKBONE') 17 | 18 | To register an object: 19 | 20 | .. code-block:: python 21 | 22 | @BACKBONE_REGISTRY.register() 23 | class MyBackbone(): 24 | ... 25 | 26 | Or: 27 | 28 | ..
code-block:: python 29 | 30 | BACKBONE_REGISTRY.register(MyBackbone) 31 | """ 32 | 33 | def __init__(self, name): 34 | """ 35 | Args: 36 | name (str): the name of this registry 37 | """ 38 | self._name = name 39 | 40 | self._obj_map = {} 41 | 42 | def _do_register(self, name, obj): 43 | assert ( 44 | name not in self._obj_map 45 | ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name) 46 | self._obj_map[name] = obj 47 | 48 | def register(self, obj=None, name=None): 49 | """ 50 | Register the given object under the name `obj.__name__`. 51 | Can be used as either a decorator or not. See docstring of this class for usage. 52 | """ 53 | if obj is None: 54 | # used as a decorator 55 | def deco(func_or_class, name=name): 56 | if name is None: 57 | name = func_or_class.__name__ 58 | self._do_register(name, func_or_class) 59 | return func_or_class 60 | 61 | return deco 62 | 63 | # used as a function call 64 | if name is None: 65 | name = obj.__name__ 66 | self._do_register(name, obj) 67 | 68 | 69 | 70 | def get(self, name): 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) 74 | 75 | return ret 76 | 77 | def get_list(self): 78 | return list(self._obj_map.keys()) 79 | -------------------------------------------------------------------------------- /segmentron/models/hrnet_seg.py: -------------------------------------------------------------------------------- 1 | # this code is heavily based on https://github.com/HRNet 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch._utils 10 | import torch.nn.functional as F 11 | 12 | from .segbase import SegBaseModel 13 | from .model_zoo import MODEL_REGISTRY 14 | from ..config import cfg 15 | 16 | 17 | @MODEL_REGISTRY.register(name='HRNet') 18 | class HighResolutionNet(SegBaseModel): 19 | def __init__(self): 20 | super(HighResolutionNet, self).__init__() 21 | self.hrnet_head = _HRNetHead(self.nclass, self.encoder.last_inp_channels) 22 | self.__setattr__('decoder', ['hrnet_head']) 23 | 24 | def forward(self, x): 25 | shape = x.shape[2:] 26 | x = self.encoder(x) 27 | x = self.hrnet_head(x) 28 | x = F.interpolate(x, size=shape, mode='bilinear', align_corners=False) 29 | return [x] 30 | 31 | 32 | class _HRNetHead(nn.Module): 33 | def __init__(self, nclass, last_inp_channels, norm_layer=nn.BatchNorm2d): 34 | super(_HRNetHead, self).__init__() 35 | 36 | self.last_layer = nn.Sequential( 37 | nn.Conv2d( 38 | in_channels=last_inp_channels, 39 | out_channels=last_inp_channels, 40 | kernel_size=1, 41 | stride=1, 42 | padding=0), 43 | 44 | norm_layer(last_inp_channels), 45 | nn.ReLU(inplace=False), 46 | nn.Conv2d( 47 | in_channels=last_inp_channels, 48 | out_channels=nclass, 49 | kernel_size=cfg.MODEL.HRNET.FINAL_CONV_KERNEL, 50 | stride=1, 51 | padding=1 if cfg.MODEL.HRNET.FINAL_CONV_KERNEL == 3 else 0) 52 | ) 53 | 54 | def forward(self, x): 55 | # Upsampling 56 | x0_h, x0_w = x[0].size(2), x[0].size(3) 57 | x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False) 58 | x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) 59 | x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False) 60 | 61 | x = torch.cat([x[0], x1, x2, x3], 1) 62 | x = self.last_layer(x) 63 | return x 64 |
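The `Registry` above is what connects a config string such as `MODEL_NAME: "HRNet"` to `HighResolutionNet`: the class registers itself under an explicit name, and `get_segmentation_model()` in `model_zoo.py` looks that name up. A minimal sketch of registering a third-party model the same way (`MySegNet` is a hypothetical example class, not part of this repo):

```python
# Minimal sketch: register a custom model so that a YAML with
# MODEL_NAME: "MySegNet" resolves to this class.
# MySegNet is a hypothetical example, not shipped with SegmenTron.
import torch.nn as nn
import torch.nn.functional as F

from segmentron.models.model_zoo import MODEL_REGISTRY
from segmentron.models.segbase import SegBaseModel


@MODEL_REGISTRY.register()            # registered under its __name__, "MySegNet"
class MySegNet(SegBaseModel):
    def __init__(self):
        super(MySegNet, self).__init__()
        self.head = nn.Conv2d(2048, self.nclass, 1)   # toy 1x1 decoder head
        self.__setattr__('decoder', ['head'])

    def forward(self, x):
        size = x.size()[2:]
        _, _, _, c4 = self.encoder(x)   # backbone features, as in pspnet.py
        out = F.interpolate(self.head(c4), size, mode='bilinear', align_corners=True)
        return (out,)
```

Registering two objects under the same name trips the assertion in `_do_register`, which is why each model file in `segmentron/models` registers exactly once.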
-------------------------------------------------------------------------------- /segmentron/data/downloader/sbu_shadow.py: -------------------------------------------------------------------------------- 1 | """Prepare SBU Shadow datasets""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | cur_path = os.path.abspath(os.path.dirname(__file__)) 8 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 9 | sys.path.append(root_path) 10 | 11 | from segmentron.utils import download, makedirs 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Initialize SBU Shadow dataset.', 17 | epilog='Example: python sbu_shadow.py --download-dir ~/SBU-shadow', 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk') 20 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 21 | parser.add_argument('--overwrite', action='store_true', 22 | help='overwrite downloaded files if set, in case they are corrupted') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | ##################################################################################### 28 | # Download and extract SBU shadow datasets into ``path`` 29 | 30 | def download_sbu(path, overwrite=False): 31 | _DOWNLOAD_URLS = [ 32 | ('http://www3.cs.stonybrook.edu/~cvl/content/datasets/shadow_db/SBU-shadow.zip'), 33 | ] 34 | download_dir = os.path.join(path, 'downloads') 35 | makedirs(download_dir) 36 | for url in _DOWNLOAD_URLS: 37 | filename = download(url, path=path, overwrite=overwrite) 38 | # extract 39 | with zipfile.ZipFile(filename, "r") as zf: 40 | zf.extractall(path=path) 41 | print("Extracted", filename) 42 | 43 | 44 | if __name__ == '__main__': 45 | args = parse_args() 46 | default_dir = os.path.join(root_path, 'datasets/sbu') 47 | if args.download_dir is not None: 48 | _TARGET_DIR = args.download_dir 49 | else: 50 | _TARGET_DIR = default_dir 51 | makedirs(_TARGET_DIR) 52 | if os.path.exists(default_dir): 53 | print('{} already exists!'.format(default_dir)) 54 | else: 55 | try: 56 | os.symlink(_TARGET_DIR, default_dir) 57 | except Exception as e: 58 | print(e) 59 | download_sbu(_TARGET_DIR, overwrite=False) 60 | -------------------------------------------------------------------------------- /segmentron/modules/csrc/criss_cross_attention/ca.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <ATen/ATen.h> 3 | #include <tuple> 4 | 5 | namespace segmentron { 6 | at::Tensor ca_forward_cuda( 7 | const at::Tensor& t, 8 | const at::Tensor& f); 9 | 10 | std::tuple<at::Tensor, at::Tensor> ca_backward_cuda( 11 | const at::Tensor& dw, 12 | const at::Tensor& t, 13 | const at::Tensor& f); 14 | 15 | at::Tensor ca_map_forward_cuda( 16 | const at::Tensor& weight, 17 | const at::Tensor& g); 18 | 19 | std::tuple<at::Tensor, at::Tensor> ca_map_backward_cuda( 20 | const at::Tensor& dout, 21 | const at::Tensor& weight, 22 | const at::Tensor& g); 23 | 24 | 25 | at::Tensor ca_forward(const at::Tensor& t, 26 | const at::Tensor& f) { 27 | if (t.type().is_cuda()) { 28 | #ifdef WITH_CUDA 29 | return ca_forward_cuda(t, f); 30 | #else 31 | AT_ERROR("Not compiled with GPU support"); 32 | #endif 33 | } 34 | AT_ERROR("Not implemented on the CPU"); 35 | } 36 | 37 | std::tuple<at::Tensor, at::Tensor> ca_backward(const at::Tensor& dw, 38 | const at::Tensor& t, 39 | const at::Tensor& f) { 40 | if (dw.type().is_cuda()) { 41 | #ifdef WITH_CUDA 42 | return ca_backward_cuda(dw, t, f); 43 | #else 44 | AT_ERROR("Not compiled with GPU support"); 45 | #endif 46 | } 47 | AT_ERROR("Not implemented on the CPU"); 48 | } 49 | 50 | at::Tensor ca_map_forward(const at::Tensor& weight, 51 | const at::Tensor& g) { 52 | if (weight.type().is_cuda()) { 53 | #ifdef WITH_CUDA 54 | return ca_map_forward_cuda(weight, g); 55 | #else 56 | AT_ERROR("Not compiled with GPU support"); 57 | #endif 58 | } 59 | AT_ERROR("Not implemented on the CPU"); 60 | } 61 | 62 | std::tuple<at::Tensor, at::Tensor> ca_map_backward(const at::Tensor& dout, 63 | const at::Tensor& weight, 64 | const at::Tensor& g) { 65 | if (dout.type().is_cuda()) { 66 | #ifdef WITH_CUDA 67 | return ca_map_backward_cuda(dout, weight, g); 68 | #else 69 | AT_ERROR("Not compiled with GPU support"); 70 | #endif 71 | } 72 | AT_ERROR("Not implemented on the CPU"); 73 | } 74 | 75 | } // namespace segmentron 76 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import errno 4 | import tarfile 5 | from six.moves import urllib 6 | from torch.utils.model_zoo import tqdm 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | def check_integrity(fpath, md5=None): 20 | if md5 is None: 21 | return True 22 | if not os.path.isfile(fpath): 23 | return False 24 | md5o = hashlib.md5() 25 | with open(fpath, 'rb') as f: 26 | # read in 1MB chunks 27 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 28 | md5o.update(chunk) 29 | md5c = md5o.hexdigest() 30 | if md5c != md5: 31 | return False 32 | return True 33 | 34 | def makedir_exist_ok(dirpath): 35 | try: 36 | os.makedirs(dirpath) 37 | except OSError as e: 38 | if e.errno == errno.EEXIST: 39 | pass 40 | else: 41 | raise 42 | 43 | def download_url(url, root, filename=None, md5=None): 44 | """Download a file from a url and place it in root.""" 45 | root = os.path.expanduser(root) 46 | if not filename: 47 | filename = os.path.basename(url) 48 | fpath = os.path.join(root, filename) 49 | 50 | makedir_exist_ok(root) 51 | 52 | # downloads file 53 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 54 | print('Using downloaded and verified file: ' + fpath) 55 | else: 56 | try: 57 | print('Downloading ' + url + ' to ' + fpath) 58 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 59 | except OSError: 60 | if url[:5] == 'https': 61 | url = url.replace('https:', 'http:') 62 | print('Failed download. Trying https -> http instead.'
63 | ' Downloading ' + url + ' to ' + fpath) 64 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 65 | 66 | def download_extract(url, root, filename, md5): 67 | download_url(url, root, filename, md5) 68 | with tarfile.open(os.path.join(root, filename), "r") as tar: 69 | tar.extractall(path=root) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | from setuptools import find_packages, setup 6 | import torch 7 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 8 | 9 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 10 | assert torch_ver >= [1, 1], "Requires PyTorch >= 1.1" 11 | 12 | 13 | def get_extensions(): 14 | this_dir = os.path.dirname(os.path.abspath(__file__)) 15 | extensions_dir = os.path.join(this_dir, "segmentron", "modules", "csrc") 16 | 17 | main_source = os.path.join(extensions_dir, "vision.cpp") 18 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) 19 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob( 20 | os.path.join(extensions_dir, "*.cu") 21 | ) 22 | 23 | sources = [main_source] + sources 24 | 25 | extension = CppExtension 26 | 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | 41 | # It's better if pytorch can do this by default .. 
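# For example, running `CC=g++-7 python setup.py develop` would append "-ccbin=g++-7" below,
# telling nvcc which host compiler to use (the compiler name here is illustrative only).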
42 | CC = os.environ.get("CC", None) 43 | if CC is not None: 44 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) 45 | 46 | sources = [os.path.join(extensions_dir, s) for s in sources] 47 | 48 | include_dirs = [extensions_dir] 49 | 50 | ext_modules = [ 51 | extension( 52 | "segmentron._C", 53 | sources, 54 | include_dirs=include_dirs, 55 | define_macros=define_macros, 56 | extra_compile_args=extra_compile_args, 57 | ) 58 | ] 59 | 60 | return ext_modules 61 | 62 | 63 | setup( 64 | name="segmentron", 65 | version="0.1", 66 | author="LikeLy-Journey", 67 | url="https://github.com/LikeLy-Journey/SegmenTron", 68 | description="platform for semantic segmentation based on PyTorch.", 69 | # packages=find_packages(exclude=("configs", "tests")), 70 | # python_requires=">=3.6", 71 | # install_requires=[ 72 | # "termcolor>=1.1", 73 | # "Pillow", 74 | # "yacs>=0.1.6", 75 | # "tabulate", 76 | # "cloudpickle", 77 | # "matplotlib", 78 | # "tqdm>4.29.0", 79 | # "tensorboard", 80 | # ], 81 | # extras_require={"all": ["shapely", "psutil"]}, 82 | ext_modules=get_extensions(), 83 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 84 | ) 85 | -------------------------------------------------------------------------------- /segmentron/models/deeplabv3_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _ConvBNReLU, SeparableConv2d, _ASPP, _FCNHead 8 | from ..config import cfg 9 | 10 | __all__ = ['DeepLabV3Plus'] 11 | 12 | 13 | @MODEL_REGISTRY.register(name='DeepLabV3_Plus') 14 | class DeepLabV3Plus(SegBaseModel): 15 | r"""DeepLabV3Plus 16 | Reference: 17 | Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic 18 | Image Segmentation."
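arXiv preprint arXiv:1802.02611 (2018).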
19 | """ 20 | def __init__(self): 21 | super(DeepLabV3Plus, self).__init__() 22 | if self.backbone.startswith('mobilenet'): 23 | c1_channels = 24 24 | c4_channels = 320 25 | else: 26 | c1_channels = 256 27 | c4_channels = 2048 28 | self.head = _DeepLabHead(self.nclass, c1_channels=c1_channels, c4_channels=c4_channels) 29 | if self.aux: 30 | self.auxlayer = _FCNHead(728, self.nclass) 31 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 32 | 33 | def forward(self, x): 34 | size = x.size()[2:] 35 | c1, _, c3, c4 = self.encoder(x) 36 | 37 | outputs = list() 38 | x = self.head(c4, c1) 39 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 40 | 41 | outputs.append(x) 42 | if self.aux: 43 | auxout = self.auxlayer(c3) 44 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 45 | outputs.append(auxout) 46 | return tuple(outputs) 47 | 48 | 49 | class _DeepLabHead(nn.Module): 50 | def __init__(self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d): 51 | super(_DeepLabHead, self).__init__() 52 | self.use_aspp = cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP 53 | self.use_decoder = cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER 54 | last_channels = c4_channels 55 | if self.use_aspp: 56 | self.aspp = _ASPP(c4_channels, 256) 57 | last_channels = 256 58 | if self.use_decoder: 59 | self.c1_block = _ConvBNReLU(c1_channels, 48, 1, norm_layer=norm_layer) 60 | last_channels += 48 61 | self.block = nn.Sequential( 62 | SeparableConv2d(last_channels, 256, 3, norm_layer=norm_layer, relu_first=False), 63 | SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False), 64 | nn.Conv2d(256, nclass, 1)) 65 | 66 | def forward(self, x, c1): 67 | size = c1.size()[2:] 68 | if self.use_aspp: 69 | x = self.aspp(x) 70 | if self.use_decoder: 71 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 72 | c1 = self.c1_block(c1) 73 | return self.block(torch.cat([x, c1], dim=1)) 74 | 75 | return self.block(x) 76 | -------------------------------------------------------------------------------- /segmentron/models/ccnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | from ..modules.cc_attention import CrissCrossAttention 9 | from ..config import cfg 10 | 11 | @MODEL_REGISTRY.register() 12 | class CCNet(SegBaseModel): 13 | r"""CCNet 14 | Reference: 15 | Zilong Huang, et al. "CCNet: Criss-Cross Attention for Semantic Segmentation." 16 | arXiv preprint arXiv:1811.11721 (2018). 
17 | """ 18 | 19 | def __init__(self): 20 | super(CCNet, self).__init__() 21 | self.head = _CCHead(self.nclass, norm_layer=self.norm_layer) 22 | if self.aux: 23 | self.auxlayer = _FCNHead(1024, self.nclass, norm_layer=self.norm_layer) 24 | 25 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 26 | 27 | def forward(self, x): 28 | size = x.size()[2:] 29 | _, _, c3, c4 = self.base_forward(x) 30 | outputs = list() 31 | x = self.head(c4) 32 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 33 | outputs.append(x) 34 | 35 | if self.aux: 36 | auxout = self.auxlayer(c3) 37 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 38 | outputs.append(auxout) 39 | return tuple(outputs) 40 | 41 | 42 | class _CCHead(nn.Module): 43 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d): 44 | super(_CCHead, self).__init__() 45 | self.rcca = _RCCAModule(2048, 512, norm_layer) 46 | self.out = nn.Conv2d(512, nclass, 1) 47 | 48 | def forward(self, x): 49 | x = self.rcca(x) 50 | x = self.out(x) 51 | return x 52 | 53 | 54 | class _RCCAModule(nn.Module): 55 | def __init__(self, in_channels, out_channels, norm_layer): 56 | super(_RCCAModule, self).__init__() 57 | self.recurrence = cfg.MODEL.CCNET.RECURRENCE 58 | inter_channels = in_channels // 4 59 | self.conva = nn.Sequential( 60 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 61 | norm_layer(inter_channels), 62 | nn.ReLU(True)) 63 | self.cca = CrissCrossAttention(inter_channels) 64 | self.convb = nn.Sequential( 65 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 66 | norm_layer(inter_channels), 67 | nn.ReLU(True)) 68 | 69 | self.bottleneck = nn.Sequential( 70 | nn.Conv2d(in_channels + inter_channels, out_channels, 3, padding=1, bias=False), 71 | norm_layer(out_channels), 72 | nn.Dropout2d(0.1)) 73 | 74 | def forward(self, x): 75 | out = self.conva(x) 76 | for i in range(self.recurrence): 77 | out = self.cca(out) 78 | out = self.convb(out) 79 | out = torch.cat([x, out], dim=1) 80 | out = self.bottleneck(out) 81 | 82 | return out 83 | -------------------------------------------------------------------------------- /segmentron/data/downloader/mscoco.py: -------------------------------------------------------------------------------- 1 | """Prepare MS COCO datasets""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from segmentron.utils import download, makedirs 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/coco') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize MS COCO dataset.', 20 | epilog='Example: python mscoco.py --download-dir ~/mscoco', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk') 23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 24 | parser.add_argument('--overwrite', action='store_true', 25 | help='overwrite downloaded files if set, in case they are corrupted') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def download_coco(path, overwrite=False): 31 | _DOWNLOAD_URLS = [ 32 | ('http://images.cocodataset.org/zips/train2017.zip', 33 | '10ad623668ab00c62c096f0ed636d6aff41faca5'), 
34 | ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', 35 | '8551ee4bb5860311e79dace7e79cb91e432e78b3'), 36 | ('http://images.cocodataset.org/zips/val2017.zip', 37 | '4950dc9d00dbe1c933ee0170f5797584351d2a41'), 38 | # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip', 39 | # '46cdcf715b6b4f67e980b529534e79c2edffe084'), 40 | # test2017.zip, for those who want to attend the competition. 41 | # ('http://images.cocodataset.org/zips/test2017.zip', 42 | # '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'), 43 | ] 44 | makedirs(path) 45 | for url, checksum in _DOWNLOAD_URLS: 46 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 47 | # extract 48 | with zipfile.ZipFile(filename) as zf: 49 | zf.extractall(path=path) 50 | 51 | 52 | if __name__ == '__main__': 53 | args = parse_args() 54 | default_dir = os.path.join(root_path, 'datasets/coco') 55 | if args.download_dir is not None: 56 | path = args.download_dir 57 | else: 58 | path = default_dir 59 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'train2017')) \ 60 | or not os.path.isdir(os.path.join(path, 'val2017')) \ 61 | or not os.path.isdir(os.path.join(path, 'annotations')): 62 | if args.no_download: 63 | raise ValueError(('{} is not a valid directory, make sure it is present,' 64 | ' or drop the "--no-download" flag to download it automatically.').format(path)) 65 | else: 66 | download_coco(path, overwrite=args.overwrite) 67 | 68 | # make symlink 69 | try: 70 | os.symlink(path, default_dir) 71 | except Exception as e: 72 | print(e) 73 | -------------------------------------------------------------------------------- /segmentron/models/backbones/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ...utils.download import download 7 | from ...utils.registry import Registry 8 | from ...config import cfg 9 | 10 | BACKBONE_REGISTRY = Registry("BACKBONE") 11 | BACKBONE_REGISTRY.__doc__ = """ 12 | Registry for backbone, e.g. resnet. 13 | 14 | The registered object will be called with `obj()` 15 | and expected to return a `nn.Module` object.
16 | """ 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | 'resnet50c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet50-25c4b509.pth', 25 | 'resnet101c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet101-2a57e44d.pth', 26 | 'resnet152c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet152-0d43d698.pth', 27 | 'xception65': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/tf-xception65-270e81cf.pth', 28 | 'hrnet_w18_small_v1': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/hrnet-w18-small-v1-08f8ae64.pth', 29 | 'mobilenet_v2': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/mobilenetV2-15498621.pth', 30 | } 31 | 32 | 33 | def load_backbone_pretrained(model, backbone): 34 | if cfg.PHASE == 'train' and cfg.TRAIN.BACKBONE_PRETRAINED and (not cfg.TRAIN.PRETRAINED_MODEL_PATH): 35 | if os.path.isfile(cfg.TRAIN.BACKBONE_PRETRAINED_PATH): 36 | logging.info('Load backbone pretrained model from {}'.format( 37 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH 38 | )) 39 | msg = model.load_state_dict(torch.load(cfg.TRAIN.BACKBONE_PRETRAINED_PATH), strict=False) 40 | logging.info(msg) 41 | elif backbone not in model_urls: 42 | logging.info('{} has no pretrained model'.format(backbone)) 43 | return 44 | else: 45 | logging.info('load backbone pretrained model from url..') 46 | try: 47 | msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) 48 | except Exception as e: 49 | logging.warning(e) 50 | logging.info('Use torch download failed, try custom method!') 51 | 52 | msg = model.load_state_dict(torch.load(download(model_urls[backbone], 53 | path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False) 54 | logging.info(msg) 55 | 56 | 57 | def get_segmentation_backbone(backbone, norm_layer=torch.nn.BatchNorm2d): 58 | """ 59 | Built the backbone model, defined by `cfg.MODEL.BACKBONE`. 
60 | """ 61 | model = BACKBONE_REGISTRY.get(backbone)(norm_layer) 62 | load_backbone_pretrained(model, backbone) 63 | return model 64 | 65 | -------------------------------------------------------------------------------- /segmentron/solver/optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from torch import optim 5 | from segmentron.config import cfg 6 | 7 | 8 | def _set_batch_norm_attr(named_modules, attr, value): 9 | for m in named_modules: 10 | if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)): 11 | setattr(m[1], attr, value) 12 | 13 | 14 | def _get_paramters(model): 15 | params_list = list() 16 | if hasattr(model, 'encoder') and model.encoder is not None and hasattr(model, 'decoder'): 17 | params_list.append({'params': model.encoder.parameters(), 'lr': cfg.SOLVER.LR}) 18 | if cfg.MODEL.BN_EPS_FOR_ENCODER: 19 | logging.info('Set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) 20 | _set_batch_norm_attr(model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) 21 | 22 | for module in model.decoder: 23 | params_list.append({'params': getattr(model, module).parameters(), 24 | 'lr': cfg.SOLVER.LR * cfg.SOLVER.DECODER_LR_FACTOR}) 25 | 26 | if cfg.MODEL.BN_EPS_FOR_DECODER: 27 | logging.info('Set bn custom eps for bn in decoder: {}'.format(cfg.MODEL.BN_EPS_FOR_DECODER)) 28 | for module in model.decoder: 29 | _set_batch_norm_attr(getattr(model, module).named_modules(), 'eps', 30 | cfg.MODEL.BN_EPS_FOR_DECODER) 31 | else: 32 | logging.info('Model do not have encoder or decoder, params list was from model.parameters(), ' 33 | 'and arguments BN_EPS_FOR_ENCODER, BN_EPS_FOR_DECODER, DECODER_LR_FACTOR not used!') 34 | params_list = model.parameters() 35 | 36 | if cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE in ['BN']: 37 | logging.info('Set bn custom momentum: {}'.format(cfg.MODEL.BN_MOMENTUM)) 38 | _set_batch_norm_attr(model.named_modules(), 'momentum', cfg.MODEL.BN_MOMENTUM) 39 | elif cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE not in ['BN']: 40 | logging.info('Batch norm type is {}, custom bn momentum is not effective!'.format(cfg.MODEL.BN_TYPE)) 41 | 42 | return params_list 43 | 44 | 45 | def get_optimizer(model): 46 | parameters = _get_paramters(model) 47 | opt_lower = cfg.SOLVER.OPTIMIZER.lower() 48 | 49 | if opt_lower == 'sgd': 50 | optimizer = optim.SGD( 51 | parameters, lr=cfg.SOLVER.LR, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 52 | elif opt_lower == 'adam': 53 | optimizer = optim.Adam( 54 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 55 | elif opt_lower == 'adadelta': 56 | optimizer = optim.Adadelta( 57 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 58 | elif opt_lower == 'rmsprop': 59 | optimizer = optim.RMSprop( 60 | parameters, lr=cfg.SOLVER.LR, alpha=0.9, eps=cfg.SOLVER.EPSILON, 61 | momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 62 | else: 63 | raise ValueError("Expected optimizer method in [sgd, adam, adadelta, rmsprop], but received " 64 | "{}".format(opt_lower)) 65 | 66 | return optimizer 67 | -------------------------------------------------------------------------------- /segmentron/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import requests 4 | from tqdm import tqdm 5 | 6 | def check_sha1(filename, sha1_hash): 7 
| """Check whether the sha1 hash of the file content matches the expected hash. 8 | Parameters 9 | ---------- 10 | filename : str 11 | Path to the file. 12 | sha1_hash : str 13 | Expected sha1 hash in hexadecimal digits. 14 | Returns 15 | ------- 16 | bool 17 | Whether the file content matches the expected hash. 18 | """ 19 | sha1 = hashlib.sha1() 20 | with open(filename, 'rb') as f: 21 | while True: 22 | data = f.read(1048576) 23 | if not data: 24 | break 25 | sha1.update(data) 26 | 27 | sha1_file = sha1.hexdigest() 28 | l = min(len(sha1_file), len(sha1_hash)) 29 | return sha1.hexdigest()[0:l] == sha1_hash[0:l] 30 | 31 | def download(url, path=None, overwrite=False, sha1_hash=None): 32 | """Download an given URL 33 | Parameters 34 | ---------- 35 | url : str 36 | URL to download 37 | path : str, optional 38 | Destination path to store downloaded file. By default stores to the 39 | current directory with same name as in url. 40 | overwrite : bool, optional 41 | Whether to overwrite destination file if already exists. 42 | sha1_hash : str, optional 43 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 44 | but doesn't match. 45 | Returns 46 | ------- 47 | str 48 | The file path of the downloaded file. 49 | """ 50 | if path is None: 51 | fname = url.split('/')[-1] 52 | else: 53 | path = os.path.expanduser(path) 54 | if os.path.isdir(path): 55 | fname = os.path.join(path, url.split('/')[-1]) 56 | else: 57 | fname = path 58 | 59 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 60 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 61 | if not os.path.exists(dirname): 62 | os.makedirs(dirname) 63 | 64 | print('Downloading %s from %s...'%(fname, url)) 65 | r = requests.get(url, stream=True) 66 | if r.status_code != 200: 67 | raise RuntimeError("Failed downloading url %s"%url) 68 | total_length = r.headers.get('content-length') 69 | with open(fname, 'wb') as f: 70 | if total_length is None: # no content length header 71 | for chunk in r.iter_content(chunk_size=1024): 72 | if chunk: # filter out keep-alive new chunks 73 | f.write(chunk) 74 | else: 75 | total_length = int(total_length) 76 | for chunk in tqdm(r.iter_content(chunk_size=1024), 77 | total=int(total_length / 1024. + 0.5), 78 | unit='KB', unit_scale=False, dynamic_ncols=True): 79 | f.write(chunk) 80 | 81 | if sha1_hash and not check_sha1(fname, sha1_hash): 82 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 83 | 'The repo may be outdated or download may be incomplete. 
' \ 84 | 'If the "repo_url" is overridden, consider switching to ' \ 85 | 'the default repo.'.format(fname)) 86 | 87 | return fname -------------------------------------------------------------------------------- /segmentron/models/danet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .segbase import SegBaseModel 8 | from .model_zoo import MODEL_REGISTRY 9 | from ..modules import _FCNHead, PAM_Module, CAM_Module 10 | 11 | __all__ = ['DANet'] 12 | 13 | 14 | @MODEL_REGISTRY.register() 15 | class DANet(SegBaseModel): 16 | r"""DANet model from the paper `"Dual Attention Network for Scene Segmentation" 17 | ` 18 | """ 19 | def __init__(self): 20 | super(DANet, self).__init__() 21 | self.head = DANetHead(2048, self.nclass) 22 | if self.aux: 23 | self.auxlayer = _FCNHead(728, self.nclass) 24 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 25 | 26 | def forward(self, x): 27 | imsize = x.size()[2:] 28 | _, _, c3, c4 = self.encoder(x) 29 | 30 | x = self.head(c4) 31 | x = list(x) 32 | x[0] = F.interpolate(x[0], imsize, mode='bilinear', align_corners=True) 33 | x[1] = F.interpolate(x[1], imsize, mode='bilinear', align_corners=True) 34 | x[2] = F.interpolate(x[2], imsize, mode='bilinear', align_corners=True) 35 | 36 | outputs = list() 37 | outputs.append(x[0]) 38 | outputs.append(x[1]) 39 | outputs.append(x[2]) 40 | 41 | return tuple(outputs) 42 | 43 | 44 | class DANetHead(nn.Module): 45 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): 46 | super(DANetHead, self).__init__() 47 | inter_channels = in_channels // 4 48 | self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 49 | norm_layer(inter_channels), 50 | nn.ReLU()) 51 | 52 | self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 53 | norm_layer(inter_channels), 54 | nn.ReLU()) 55 | 56 | self.sa = PAM_Module(inter_channels) 57 | self.sc = CAM_Module(inter_channels) 58 | self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 59 | norm_layer(inter_channels), 60 | nn.ReLU()) 61 | self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 62 | norm_layer(inter_channels), 63 | nn.ReLU()) 64 | 65 | self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1)) 66 | self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1)) 67 | 68 | self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1)) 69 | 70 | def forward(self, x): 71 | feat1 = self.conv5a(x) 72 | sa_feat = self.sa(feat1) 73 | sa_conv = self.conv51(sa_feat) 74 | sa_output = self.conv6(sa_conv) 75 | 76 | feat2 = self.conv5c(x) 77 | sc_feat = self.sc(feat2) 78 | sc_conv = self.conv52(sc_feat) 79 | sc_output = self.conv7(sc_conv) 80 | 81 | feat_sum = sa_conv+sc_conv 82 | 83 | sasc_output = self.conv8(feat_sum) 84 | 85 | output = [sasc_output] 86 | output.append(sa_output) 87 | output.append(sc_output) 88 | return tuple(output) 89 | 90 | -------------------------------------------------------------------------------- /segmentron/models/icnet.py: -------------------------------------------------------------------------------- 1 | """Image Cascade Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from 
.segbase import SegBaseModel 7 | from .model_zoo import MODEL_REGISTRY 8 | from ..modules.basic import _ConvBNReLU 9 | from ..config import cfg 10 | 11 | __all__ = ['ICNet'] 12 | 13 | 14 | @MODEL_REGISTRY.register() 15 | class ICNet(SegBaseModel): 16 | """Image Cascade Network""" 17 | 18 | def __init__(self): 19 | super(ICNet, self).__init__() 20 | self.conv_sub1 = nn.Sequential( 21 | _ConvBNReLU(3, 32, 3, 2), 22 | _ConvBNReLU(32, 32, 3, 2), 23 | _ConvBNReLU(32, 64, 3, 2) 24 | ) 25 | 26 | self.head = _ICHead(self.nclass) 27 | self.__setattr__('decoder', ['conv_sub1', 'head']) 28 | 29 | def forward(self, x): 30 | size = x.size()[2:] 31 | # sub 1 32 | x_sub1 = self.conv_sub1(x) 33 | 34 | # sub 2 35 | x_sub2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True) 36 | _, x_sub2, _, _ = self.encoder(x_sub2) 37 | 38 | # sub 4 39 | x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 40 | _, _, _, x_sub4 = self.encoder(x_sub4) 41 | 42 | outputs = self.head(x_sub1, x_sub2, x_sub4, size) 43 | 44 | return tuple(outputs) 45 | 46 | 47 | class _ICHead(nn.Module): 48 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d): 49 | super(_ICHead, self).__init__() 50 | scale = cfg.MODEL.BACKBONE_SCALE 51 | self.cff_12 = CascadeFeatureFusion(int(512 * scale), 64, 128, nclass, norm_layer) 52 | self.cff_24 = CascadeFeatureFusion(int(2048 * scale), int(512 * scale), 128, nclass, norm_layer) 53 | self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False) 54 | 55 | def forward(self, x_sub1, x_sub2, x_sub4, size): 56 | outputs = list() 57 | x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2) 58 | outputs.append(x_24_cls) 59 | x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1) 60 | outputs.append(x_12_cls) 61 | 62 | up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True) 63 | up_x2 = self.conv_cls(up_x2) 64 | outputs.append(up_x2) 65 | 66 | up_x8 = F.interpolate(up_x2, size, mode='bilinear', align_corners=True) 67 | outputs.append(up_x8) 68 | # 1 -> 1/4 -> 1/8 -> 1/16 69 | outputs.reverse() 70 | 71 | return outputs 72 | 73 | 74 | class CascadeFeatureFusion(nn.Module): 75 | """CFF Unit""" 76 | 77 | def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d): 78 | super(CascadeFeatureFusion, self).__init__() 79 | self.conv_low = nn.Sequential( 80 | nn.Conv2d(low_channels, out_channels, 3, padding=2, dilation=2, bias=False), 81 | norm_layer(out_channels) 82 | ) 83 | self.conv_high = nn.Sequential( 84 | nn.Conv2d(high_channels, out_channels, 1, bias=False), 85 | norm_layer(out_channels) 86 | ) 87 | self.conv_low_cls = nn.Conv2d(out_channels, nclass, 1, bias=False) 88 | 89 | def forward(self, x_low, x_high): 90 | x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True) 91 | x_low = self.conv_low(x_low) 92 | x_high = self.conv_high(x_high) 93 | x = x_low + x_high 94 | x = F.relu(x, inplace=True) 95 | x_low_cls = self.conv_low_cls(x_low) 96 | 97 | return x, x_low_cls 98 | -------------------------------------------------------------------------------- /segmentron/models/espnetv2.py: -------------------------------------------------------------------------------- 1 | "ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .model_zoo import MODEL_REGISTRY 8 | from ..modules import _ConvBNPReLU, EESP, _BNPReLU, _FCNHead 9 | from 
..config import cfg 10 | 11 | 12 | @MODEL_REGISTRY.register() 13 | class ESPNetV2(SegBaseModel): 14 | r"""ESPNetV2 15 | Reference: 16 | Sachin Mehta, et al. "ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network." 17 | arXiv preprint arXiv:1811.11431 (2018). 18 | """ 19 | 20 | def __init__(self, **kwargs): 21 | super(ESPNetV2, self).__init__() 22 | self.proj_L4_C = _ConvBNPReLU(256, 128, 1, **kwargs) 23 | self.pspMod = nn.Sequential( 24 | EESP(256, 128, stride=1, k=4, r_lim=7, **kwargs), 25 | _PSPModule(128, 128, **kwargs)) 26 | self.project_l3 = nn.Sequential( 27 | nn.Dropout2d(0.1), 28 | nn.Conv2d(128, self.nclass, 1, bias=False)) 29 | self.act_l3 = _BNPReLU(self.nclass, **kwargs) 30 | self.project_l2 = _ConvBNPReLU(64 + self.nclass, self.nclass, 1, **kwargs) 31 | self.project_l1 = nn.Sequential( 32 | nn.Dropout2d(0.1), 33 | nn.Conv2d(32 + self.nclass, self.nclass, 1, bias=False)) 34 | 35 | self.__setattr__('exclusive', ['proj_L4_C', 'pspMod', 'project_l3', 'act_l3', 'project_l2', 'project_l1']) 36 | 37 | def forward(self, x): 38 | size = x.size()[2:] 39 | out_l1, out_l2, out_l3, out_l4 = self.encoder(x, seg=True) 40 | out_l4_proj = self.proj_L4_C(out_l4) 41 | up_l4_to_l3 = F.interpolate(out_l4_proj, scale_factor=2, mode='bilinear', align_corners=True) 42 | merged_l3_upl4 = self.pspMod(torch.cat([out_l3, up_l4_to_l3], 1)) 43 | proj_merge_l3_bef_act = self.project_l3(merged_l3_upl4) 44 | proj_merge_l3 = self.act_l3(proj_merge_l3_bef_act) 45 | out_up_l3 = F.interpolate(proj_merge_l3, scale_factor=2, mode='bilinear', align_corners=True) 46 | merge_l2 = self.project_l2(torch.cat([out_l2, out_up_l3], 1)) 47 | out_up_l2 = F.interpolate(merge_l2, scale_factor=2, mode='bilinear', align_corners=True) 48 | merge_l1 = self.project_l1(torch.cat([out_l1, out_up_l2], 1)) 49 | 50 | outputs = list() 51 | merge1_l1 = F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True) 52 | outputs.append(merge1_l1) 53 | if self.aux: 54 | # different from paper 55 | auxout = F.interpolate(proj_merge_l3_bef_act, size, mode='bilinear', align_corners=True) 56 | outputs.append(auxout) 57 | 58 | return tuple(outputs) 59 | 60 | 61 | # different from PSPNet 62 | class _PSPModule(nn.Module): 63 | def __init__(self, in_channels, out_channels=1024, sizes=(1, 2, 4, 8), **kwargs): 64 | super(_PSPModule, self).__init__() 65 | self.stages = nn.ModuleList( 66 | [nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels, bias=False) for _ in sizes]) 67 | self.project = _ConvBNPReLU(in_channels * (len(sizes) + 1), out_channels, 1, 1, **kwargs) 68 | 69 | def forward(self, x): 70 | size = x.size()[2:] 71 | feats = [x] 72 | for stage in self.stages: 73 | x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1) 74 | upsampled = F.interpolate(stage(x), size, mode='bilinear', align_corners=True) 75 | feats.append(upsampled) 76 | return self.project(torch.cat(feats, dim=1)) 77 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/sbu_shadow.py: -------------------------------------------------------------------------------- 1 | """SBU Shadow Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .seg_data_base import SegmentationDataset 8 | 9 | 10 | class SBUSegmentation(SegmentationDataset): 11 | """SBU Shadow Segmentation Dataset 12 | """ 13 | NUM_CLASS = 2 14 | 15 | def __init__(self, root='datasets/sbu', split='train', mode=None, transform=None, **kwargs): 
16 | super(SBUSegmentation, self).__init__(root, split, mode, transform, **kwargs) 17 | assert os.path.exists(self.root) 18 | self.images, self.masks = _get_sbu_pairs(self.root, self.split) 19 | assert (len(self.images) == len(self.masks)) 20 | if len(self.images) == 0: 21 | raise RuntimeError("Found 0 images in subfolders of: " + root + "\n") 22 | 23 | def __getitem__(self, index): 24 | img = Image.open(self.images[index]).convert('RGB') 25 | if self.mode == 'test': 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | return img, os.path.basename(self.images[index]) 29 | mask = Image.open(self.masks[index]) 30 | # synchronized transform 31 | if self.mode == 'train': 32 | img, mask = self._sync_transform(img, mask) 33 | elif self.mode == 'val': 34 | img, mask = self._val_sync_transform(img, mask) 35 | else: 36 | assert self.mode == 'testval' 37 | img, mask = self._img_transform(img), self._mask_transform(mask) 38 | # general resize, normalize and toTensor 39 | if self.transform is not None: 40 | img = self.transform(img) 41 | return img, mask, os.path.basename(self.images[index]) 42 | 43 | def _mask_transform(self, mask): 44 | target = np.array(mask).astype('int32') 45 | target[target > 0] = 1 46 | return torch.from_numpy(target).long() 47 | 48 | def __len__(self): 49 | return len(self.images) 50 | 51 | @property 52 | def pred_offset(self): 53 | return 0 54 | 55 | 56 | def _get_sbu_pairs(folder, split='train'): 57 | def get_path_pairs(img_folder, mask_folder): 58 | img_paths = [] 59 | mask_paths = [] 60 | for root, _, files in os.walk(img_folder): 61 | print(root) 62 | for filename in files: 63 | if filename.endswith('.jpg'): 64 | imgpath = os.path.join(root, filename) 65 | maskname = filename.replace('.jpg', '.png') 66 | maskpath = os.path.join(mask_folder, maskname) 67 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 68 | img_paths.append(imgpath) 69 | mask_paths.append(maskpath) 70 | else: 71 | print('cannot find the mask or image:', imgpath, maskpath) 72 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 73 | return img_paths, mask_paths 74 | 75 | if split == 'train': 76 | img_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowImages') 77 | mask_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowMasks') 78 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 79 | else: 80 | assert split in ('val', 'test') 81 | img_folder = os.path.join(folder, 'SBU-Test/ShadowImages') 82 | mask_folder = os.path.join(folder, 'SBU-Test/ShadowMasks') 83 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 84 | return img_paths, mask_paths 85 | 86 | 87 | if __name__ == '__main__': 88 | dataset = SBUSegmentation(base_size=280, crop_size=256) -------------------------------------------------------------------------------- /segmentron/data/dataloader/lip_parsing.py: -------------------------------------------------------------------------------- 1 | """Look into Person Dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from segmentron.data.dataloader.seg_data_base import SegmentationDataset 8 | 9 | 10 | class LIPSegmentation(SegmentationDataset): 11 | """Look into person parsing dataset """ 12 | 13 | BASE_DIR = 'LIP' 14 | NUM_CLASS = 20 15 | 16 | def __init__(self, root='datasets/LIP', split='train', mode=None, transform=None, **kwargs): 17 | super(LIPSegmentation, self).__init__(root, split, mode, transform, **kwargs) 18 | _trainval_image_dir = 
os.path.join(root, 'TrainVal_images') 19 | _testing_image_dir = os.path.join(root, 'Testing_images') 20 | _trainval_mask_dir = os.path.join(root, 'TrainVal_parsing_annotations') 21 | if split == 'train': 22 | _image_dir = os.path.join(_trainval_image_dir, 'train_images') 23 | _mask_dir = os.path.join(_trainval_mask_dir, 'train_segmentations') 24 | _split_f = os.path.join(_trainval_image_dir, 'train_id.txt') 25 | elif split == 'val': 26 | _image_dir = os.path.join(_trainval_image_dir, 'val_images') 27 | _mask_dir = os.path.join(_trainval_mask_dir, 'val_segmentations') 28 | _split_f = os.path.join(_trainval_image_dir, 'val_id.txt') 29 | elif split == 'test': 30 | _image_dir = os.path.join(_testing_image_dir, 'testing_images') 31 | _split_f = os.path.join(_testing_image_dir, 'test_id.txt') 32 | else: 33 | raise RuntimeError('Unknown dataset split.') 34 | 35 | self.images = [] 36 | self.masks = [] 37 | with open(os.path.join(_split_f), 'r') as lines: 38 | for line in lines: 39 | _image = os.path.join(_image_dir, line.rstrip('\n') + '.jpg') 40 | assert os.path.isfile(_image) 41 | self.images.append(_image) 42 | if split != 'test': 43 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + '.png') 44 | assert os.path.isfile(_mask) 45 | self.masks.append(_mask) 46 | 47 | if split != 'test': 48 | assert (len(self.images) == len(self.masks)) 49 | print('Found {} {} images in the folder {}'.format(len(self.images), split, root)) 50 | 51 | def __getitem__(self, index): 52 | img = Image.open(self.images[index]).convert('RGB') 53 | if self.mode == 'test': 54 | img = self._img_transform(img) 55 | if self.transform is not None: 56 | img = self.transform(img) 57 | return img, os.path.basename(self.images[index]) 58 | mask = Image.open(self.masks[index]) 59 | # synchronized transform 60 | if self.mode == 'train': 61 | img, mask = self._sync_transform(img, mask) 62 | elif self.mode == 'val': 63 | img, mask = self._val_sync_transform(img, mask) 64 | else: 65 | assert self.mode == 'testval' 66 | img, mask = self._img_transform(img), self._mask_transform(mask) 67 | # general resize, normalize and toTensor 68 | if self.transform is not None: 69 | img = self.transform(img) 70 | 71 | return img, mask, os.path.basename(self.images[index]) 72 | 73 | def __len__(self): 74 | return len(self.images) 75 | 76 | def _mask_transform(self, mask): 77 | target = np.array(mask).astype('int32') 78 | return torch.from_numpy(target).long() 79 | 80 | @property 81 | def classes(self): 82 | """Category name.""" 83 | return ('background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 84 | 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', 85 | 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe', 86 | 'rightShoe') 87 | 88 | 89 | if __name__ == '__main__': 90 | dataset = LIPSegmentation(base_size=280, crop_size=256) -------------------------------------------------------------------------------- /segmentron/models/dfanet.py: -------------------------------------------------------------------------------- 1 | """ Deep Feature Aggregation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .backbones.xception import Enc, FCAttention 8 | from .model_zoo import MODEL_REGISTRY 9 | from ..modules import _ConvBNReLU 10 | 11 | __all__ = ['DFANet'] 12 | 13 | 14 | @MODEL_REGISTRY.register() 15 | class DFANet(SegBaseModel): 16 | def __init__(self, **kwargs): 17 | super(DFANet, self).__init__() 18 | 19 | self.enc2_2 = Enc(240, 48, 4, 
**kwargs) 20 | self.enc3_2 = Enc(144, 96, 6, **kwargs) 21 | self.enc4_2 = Enc(288, 192, 4, **kwargs) 22 | self.fca_2 = FCAttention(192, **kwargs) 23 | 24 | self.enc2_3 = Enc(240, 48, 4, **kwargs) 25 | self.enc3_3 = Enc(144, 96, 6, **kwargs) 26 | self.enc3_4 = Enc(288, 192, 4, **kwargs) 27 | self.fca_3 = FCAttention(192, **kwargs) 28 | 29 | self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 30 | self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 31 | self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 32 | self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs) 33 | 34 | self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 35 | self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 36 | self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 37 | self.conv_out = nn.Conv2d(32, self.nclass, 1) 38 | 39 | self.__setattr__('decoder', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 40 | 'fca_3', 'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 41 | 'fca_1_reduce', 'fca_2_reduce', 'fca_3_reduce', 'conv_out']) 42 | 43 | def forward(self, x): 44 | # backbone 45 | stage1_conv1 = self.encoder.conv1(x) 46 | stage1_enc2 = self.encoder.enc2(stage1_conv1) 47 | stage1_enc3 = self.encoder.enc3(stage1_enc2) 48 | stage1_enc4 = self.encoder.enc4(stage1_enc3) 49 | stage1_fca = self.encoder.fca(stage1_enc4) 50 | stage1_out = F.interpolate(stage1_fca, scale_factor=4, mode='bilinear', align_corners=True) 51 | 52 | # stage2 53 | stage2_enc2 = self.enc2_2(torch.cat([stage1_enc2, stage1_out], dim=1)) 54 | stage2_enc3 = self.enc3_2(torch.cat([stage1_enc3, stage2_enc2], dim=1)) 55 | stage2_enc4 = self.enc4_2(torch.cat([stage1_enc4, stage2_enc3], dim=1)) 56 | stage2_fca = self.fca_2(stage2_enc4) 57 | stage2_out = F.interpolate(stage2_fca, scale_factor=4, mode='bilinear', align_corners=True) 58 | 59 | # stage3 60 | stage3_enc2 = self.enc2_3(torch.cat([stage2_enc2, stage2_out], dim=1)) 61 | stage3_enc3 = self.enc3_3(torch.cat([stage2_enc3, stage3_enc2], dim=1)) 62 | stage3_enc4 = self.enc3_4(torch.cat([stage2_enc4, stage3_enc3], dim=1)) 63 | stage3_fca = self.fca_3(stage3_enc4) 64 | 65 | stage1_enc2_decoder = self.enc2_1_reduce(stage1_enc2) 66 | stage2_enc2_decoder = F.interpolate(self.enc2_2_reduce(stage2_enc2), scale_factor=2, 67 | mode='bilinear', align_corners=True) 68 | stage3_enc2_decoder = F.interpolate(self.enc2_3_reduce(stage3_enc2), scale_factor=4, 69 | mode='bilinear', align_corners=True) 70 | fusion = stage1_enc2_decoder + stage2_enc2_decoder + stage3_enc2_decoder 71 | fusion = self.conv_fusion(fusion) 72 | 73 | stage1_fca_decoder = F.interpolate(self.fca_1_reduce(stage1_fca), scale_factor=4, 74 | mode='bilinear', align_corners=True) 75 | stage2_fca_decoder = F.interpolate(self.fca_2_reduce(stage2_fca), scale_factor=8, 76 | mode='bilinear', align_corners=True) 77 | stage3_fca_decoder = F.interpolate(self.fca_3_reduce(stage3_fca), scale_factor=16, 78 | mode='bilinear', align_corners=True) 79 | fusion = fusion + stage1_fca_decoder + stage2_fca_decoder + stage3_fca_decoder 80 | 81 | outputs = list() 82 | out = self.conv_out(fusion) 83 | out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True) 84 | outputs.append(out) 85 | 86 | return tuple(outputs) 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch for Semantic Segmentation 2 | ## Introduction 3 | This repository contains some models for semantic segmentation and 
the pipeline of training and testing models, 4 | implemented in PyTorch. 5 | 6 | ![](docs/images/demo.png) 7 | 8 | ## Model zoo 9 | 10 | |Model|Backbone|Datasets|eval size|Mean IoU(paper)|Mean IoU(this repo)| 11 | |:-:|:-:|:-:|:-:|:-:|:-:| 12 | |DeepLabv3_plus|xception65|cityscape(val)|(1025,2049)|78.8|[78.93](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/deeplabv3_plus_xception_segmentron.pth)| 13 | |DeepLabv3_plus|xception65|coco(val)|480/520|-|[70.50](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/deeplabv3_plus_xception_coco_segmentron.pth)| 14 | |DeepLabv3_plus|xception65|pascal_aug(val)|480/520|-|[89.56](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/deeplabv3_plus_xception_pascal_aug_segmentron.pth)| 15 | |DeepLabv3_plus|xception65|pascal_voc(val)|480/520|-|[88.39](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/deeplabv3_plus_xception_pascal_voc_segmentron.pth)| 16 | |DeepLabv3_plus|resnet101|cityscape(val)|(1025,2049)|-|[78.27](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/deeplabv3_plus_resnet101_segmentron.pth)| 17 | |Danet|resnet101|cityscape(val)|(1024,2048)|79.9|[79.34](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/danet101_segmentron.pth)| 18 | |Pspnet|resnet101|cityscape(val)|(1025,2049)|78.63|[77.00](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/pspnet_resnet101_segmentron.pth)| 19 | 20 | ### real-time models 21 | |Model|Backbone|Datasets|eval size|Mean IoU(paper)|Mean IoU(this repo)|FPS| 22 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 23 | |ICnet|resnet50(0.5)|cityscape(val)|(1024,2048)|67.8|-|41.39| 24 | |DeepLabv3_plus|mobilenetV2|cityscape(val)|(1024,2048)|70.7|[70.3](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/deeplabv3_plus_mobilenetv2_segmentron.pth)|46.64| 25 | |BiSeNet|resnet18|cityscape(val)|(1024,2048)|-|-|39.90| 26 | |LEDNet|-|cityscape(val)|(1024,2048)|-|-|31.78| 27 | |CGNet|-|cityscape(val)|(1024,2048)|-|-|46.11| 28 | |HardNet|-|cityscape(val)|(1024,2048)|75.9|-|69.06| 29 | |DFANet|xceptionA|cityscape(val)|(1024,2048)|70.3|-|21.46| 30 | |HRNet|w18_small_v1|cityscape(val)|(1024,2048)|70.3|[70.5](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/hrnet_w18_small_v1_segmentron.pth)|66.01| 31 | |Fast_SCNN|-|cityscape(val)|(1024,2048)|68.3|[68.9](https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/fast_scnn_segmentron.pth)|145.77| 32 | 33 | FPS was tested on V100. 34 | 35 | ## Environments 36 | 37 | - python 3 38 | - torch >= 1.1.0 39 | - torchvision 40 | - pyyaml 41 | - Pillow 42 | - numpy 43 | 44 | ## INSTALL 45 | 46 | ``` 47 | python setup.py develop 48 | ``` 49 | 50 | If you do not want to run CCNet, you do not need to install the extension; just comment out the following line in ```segmentron/models/__init__.py``` 51 | ``` 52 | from .ccnet import CCNet 53 | ``` 54 | ## Dataset preparation 55 | Cityscapes, COCO, Pascal VOC and ADE20K are supported now. 56 | 57 | Please refer to [DATA_PREPARE.md](docs/DATA_PREPARE.md) for dataset preparation. 58 | 59 | ## Pretrained backbone models 60 | 61 | Pretrained backbone models will be downloaded automatically to the PyTorch default directory(```~/.cache/torch/checkpoints/```). 
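If your machine cannot reach the download URLs, `segmentron/models/backbones/build.py` also accepts a local checkpoint: it checks `cfg.TRAIN.BACKBONE_PRETRAINED_PATH` before falling back to the URL table. A minimal config sketch (the path below is a placeholder, not a file shipped with the repo):
```
TRAIN:
  BACKBONE_PRETRAINED: True
  BACKBONE_PRETRAINED_PATH: "your_local_backbone_weights.pth"
```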
62 | 63 | ## Code structure 64 | ``` 65 | ├── configs                   # yaml config file 66 | ├── segmentron                # core code 67 | ├── tools                     # train eval code 68 | └── datasets                  # put datasets here 69 | ``` 70 | 71 | ## Train 72 | ### Train with a single GPU 73 | ``` 74 | CUDA_VISIBLE_DEVICES=0 python -u tools/train.py --config-file configs/cityscapes_deeplabv3_plus.yaml 75 | ``` 76 | ### Train with multiple GPUs 77 | ``` 78 | CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments] 79 | ``` 80 | 81 | ## Eval 82 | ### Eval with a single GPU 83 | You can download a trained model from the model zoo table above, or train one yourself. 84 | ``` 85 | CUDA_VISIBLE_DEVICES=0 python -u ./tools/eval.py --config-file configs/cityscapes_deeplabv3_plus.yaml \ 86 | TEST.TEST_MODEL_PATH your_test_model_path 87 | 88 | ``` 89 | ### Eval with multiple GPUs 90 | ``` 91 | CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh ${CONFIG_FILE} ${GPU_NUM} \ 92 | TEST.TEST_MODEL_PATH your_test_model_path 93 | ``` 94 | 95 | ## References 96 | - [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) 97 | - [detectron2](https://github.com/facebookresearch/detectron2) 98 | - [gluon-cv](https://github.com/dmlc/gluon-cv) 99 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/pascal_aug.py: -------------------------------------------------------------------------------- 1 | """Pascal Augmented VOC Semantic Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import scipy.io as sio 5 | import numpy as np 6 | import logging 7 | from PIL import Image 8 | from .seg_data_base import SegmentationDataset 9 | 10 | 11 | class VOCAugSegmentation(SegmentationDataset): 12 | """Pascal VOC Augmented Semantic Segmentation Dataset. 13 | 14 | Parameters 15 | ---------- 16 | root : string 17 | Path to VOCdevkit folder. 
Default is './datasets/voc' 18 | split: string 19 | 'train', 'val' or 'test' 20 | transform : callable, optional 21 | A function that transforms the image 22 | Examples 23 | -------- 24 | >>> from torchvision import transforms 25 | >>> import torch.utils.data as data 26 | >>> # Transforms for Normalization 27 | >>> input_transform = transforms.Compose([ 28 | >>> transforms.ToTensor(), 29 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 30 | >>> ]) 31 | >>> # Create Dataset 32 | >>> trainset = VOCAugSegmentation(split='train', transform=input_transform) 33 | >>> # Create Training Loader 34 | >>> train_data = data.DataLoader( 35 | >>>     trainset, 4, shuffle=True, 36 | >>>     num_workers=4) 37 | """ 38 | BASE_DIR = 'VOCaug/dataset/' 39 | NUM_CLASS = 21 40 | 41 | def __init__(self, root='datasets/voc', split='train', mode=None, transform=None, **kwargs): 42 | super(VOCAugSegmentation, self).__init__(root, split, mode, transform, **kwargs) 43 | # train/val/test splits are pre-cut 44 | _voc_root = os.path.join(root, self.BASE_DIR) 45 | _mask_dir = os.path.join(_voc_root, 'cls') 46 | _image_dir = os.path.join(_voc_root, 'img') 47 | if split == 'train': 48 | _split_f = os.path.join(_voc_root, 'trainval.txt') 49 | elif split == 'val': 50 | _split_f = os.path.join(_voc_root, 'val.txt') 51 | else: 52 | raise RuntimeError('Unknown dataset split: {}'.format(split)) 53 | 54 | self.images = [] 55 | self.masks = [] 56 | with open(os.path.join(_split_f), "r") as lines: 57 | for line in lines: 58 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") 59 | assert os.path.isfile(_image) 60 | self.images.append(_image) 61 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".mat") 62 | assert os.path.isfile(_mask) 63 | self.masks.append(_mask) 64 | 65 | assert (len(self.images) == len(self.masks)) 66 | print('Found {} images in the folder {}'.format(len(self.images), _voc_root)) 67 | 68 | def __getitem__(self, index): 69 | img = Image.open(self.images[index]).convert('RGB') 70 | target = self._load_mat(self.masks[index]) 71 | # synchronized transform 72 | if self.mode == 'train': 73 | img, target = self._sync_transform(img, target) 74 | elif self.mode == 'val': 75 | img, target = self._val_sync_transform(img, target) 76 | elif self.mode == 'testval': 77 | logging.warning("Use mode of testval, you should set batch size=1") 78 | img, target = self._img_transform(img), self._mask_transform(target) 79 | else: 80 | raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode)) 81 | # general resize, normalize and toTensor 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, target, os.path.basename(self.images[index]) 85 | 86 | def _mask_transform(self, mask): 87 | return torch.LongTensor(np.array(mask).astype('int32')) 88 | 89 | def _load_mat(self, filename): 90 | mat = sio.loadmat(filename, mat_dtype=True, squeeze_me=True, struct_as_record=False) 91 | mask = mat['GTcls'].Segmentation 92 | return Image.fromarray(mask) 93 | 94 | def __len__(self): 95 | return len(self.images) 96 | 97 | @property 98 | def classes(self): 99 | """Category names.""" 100 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 101 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 102 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 103 | 'tv') 104 | 105 | 106 | if __name__ == '__main__': 107 | dataset = VOCAugSegmentation() --------------------------------------------------------------------------------
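Since `testval` mode returns full-size images and masks (no synchronized crop), different samples generally have different shapes and cannot be collated into one batch; a minimal evaluation-loader sketch following the warning above (transforms as in the class docstring):

import torch.utils.data as data
from torchvision import transforms
from segmentron.data.dataloader.pascal_aug import VOCAugSegmentation

input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])
# mode='testval' skips cropping, so keep batch_size=1
valset = VOCAugSegmentation(split='val', mode='testval', transform=input_transform)
val_loader = data.DataLoader(valset, batch_size=1, shuffle=False)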
/segmentron/models/dunet.py: -------------------------------------------------------------------------------- 1 | """Decoders Matter for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .model_zoo import MODEL_REGISTRY 8 | from .fcn import _FCNHead 9 | 10 | __all__ = ['DUNet'] 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class DUNet(SegBaseModel): 15 | """Decoders Matter for Semantic Segmentation 16 | Reference: 17 | Zhi Tian, Tong He, Chunhua Shen, and Youliang Yan. 18 | "Decoders Matter for Semantic Segmentation: 19 | Data-Dependent Decoding Enables Flexible Feature Aggregation." CVPR, 2019 20 | """ 21 | 22 | def __init__(self): 23 | super(DUNet, self).__init__() 24 | self.head = _DUHead(2144, norm_layer=self.norm_layer) 25 | self.dupsample = DUpsampling(256, self.nclass, scale_factor=8) 26 | if self.aux: 27 | self.auxlayer = _FCNHead(1024, 256, norm_layer=self.norm_layer) 28 | self.aux_dupsample = DUpsampling(256, self.nclass, scale_factor=8) 29 | 30 | self.__setattr__('decoder', ['dupsample', 'head', 'auxlayer', 'aux_dupsample'] if self.aux else 31 | ['dupsample', 'head']) 32 | 33 | def forward(self, x): 34 | _, c2, c3, c4 = self.encoder(x) 35 | outputs = [] 36 | x = self.head(c2, c3, c4) 37 | x = self.dupsample(x) 38 | outputs.append(x) 39 | 40 | if self.aux: 41 | auxout = self.auxlayer(c3) 42 | auxout = self.aux_dupsample(auxout) 43 | outputs.append(auxout) 44 | return tuple(outputs) 45 | 46 | 47 | class FeatureFused(nn.Module): 48 | """Module for fused features""" 49 | 50 | def __init__(self, inter_channels=48, norm_layer=nn.BatchNorm2d): 51 | super(FeatureFused, self).__init__() 52 | self.conv2 = nn.Sequential( 53 | nn.Conv2d(512, inter_channels, 1, bias=False), 54 | norm_layer(inter_channels), 55 | nn.ReLU(True) 56 | ) 57 | self.conv3 = nn.Sequential( 58 | nn.Conv2d(1024, inter_channels, 1, bias=False), 59 | norm_layer(inter_channels), 60 | nn.ReLU(True) 61 | ) 62 | 63 | def forward(self, c2, c3, c4): 64 | size = c4.size()[2:] 65 | c2 = self.conv2(F.interpolate(c2, size, mode='bilinear', align_corners=True)) 66 | c3 = self.conv3(F.interpolate(c3, size, mode='bilinear', align_corners=True)) 67 | fused_feature = torch.cat([c4, c3, c2], dim=1) 68 | return fused_feature 69 | 70 | 71 | class _DUHead(nn.Module): 72 | def __init__(self, in_channels, norm_layer=nn.BatchNorm2d): 73 | super(_DUHead, self).__init__() 74 | self.fuse = FeatureFused(norm_layer=norm_layer) 75 | self.block = nn.Sequential( 76 | nn.Conv2d(in_channels, 256, 3, padding=1, bias=False), 77 | norm_layer(256), 78 | nn.ReLU(True), 79 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 80 | norm_layer(256), 81 | nn.ReLU(True) 82 | ) 83 | 84 | def forward(self, c2, c3, c4): 85 | fused_feature = self.fuse(c2, c3, c4) 86 | out = self.block(fused_feature) 87 | return out 88 | 89 | 90 | class DUpsampling(nn.Module): 91 | """DUsampling module""" 92 | 93 | def __init__(self, in_channels, out_channels, scale_factor=2): 94 | super(DUpsampling, self).__init__() 95 | self.scale_factor = scale_factor 96 | self.conv_w = nn.Conv2d(in_channels, out_channels * scale_factor * scale_factor, 1, bias=False) 97 | 98 | def forward(self, x): 99 | x = self.conv_w(x) 100 | n, c, h, w = x.size() 101 | 102 | # N, C, H, W --> N, W, H, C 103 | x = x.permute(0, 3, 2, 1).contiguous() 104 | 105 | # N, W, H, C --> N, W, H * scale, C // scale 106 | x = x.view(n, w, h * self.scale_factor, c // self.scale_factor) 107 | 108 | # N, W, H * scale, C // scale 
--> N, H * scale, W, C // scale 109 | x = x.permute(0, 2, 1, 3).contiguous() 110 | 111 | # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) 112 | x = x.view(n, h * self.scale_factor, w * self.scale_factor, c // (self.scale_factor * self.scale_factor)) 113 | 114 | # N, H * scale, W * scale, C // (scale ** 2) -- > N, C // (scale ** 2), H * scale, W * scale 115 | x = x.permute(0, 3, 1, 2) 116 | 117 | return x 118 | -------------------------------------------------------------------------------- /segmentron/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | from ..config import cfg 9 | 10 | __all__ = ['UNet'] 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class UNet(SegBaseModel): 15 | 16 | def __init__(self): 17 | super(UNet, self).__init__(need_backbone=False) 18 | self.inc = DoubleConv(3, 64) 19 | self.down1 = Down(64, 128) 20 | self.down2 = Down(128, 256) 21 | self.down3 = Down(256, 512) 22 | self.down4 = Down(512, 512) 23 | self.head = _UNetHead(self.nclass) 24 | 25 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 26 | 27 | def forward(self, x): 28 | size = x.size()[2:] 29 | x1 = self.inc(x) 30 | x2 = self.down1(x1) 31 | x3 = self.down2(x2) 32 | x4 = self.down3(x3) 33 | x5 = self.down4(x4) 34 | 35 | outputs = list() 36 | x = self.head(x1, x2, x3, x4, x5) 37 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 38 | 39 | outputs.append(x) 40 | return tuple(outputs) 41 | 42 | 43 | class _UNetHead(nn.Module): 44 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d): 45 | super(_UNetHead, self).__init__() 46 | bilinear = True 47 | self.up1 = Up(1024, 256, bilinear) 48 | self.up2 = Up(512, 128, bilinear) 49 | self.up3 = Up(256, 64, bilinear) 50 | self.up4 = Up(128, 64, bilinear) 51 | self.outc = OutConv(64, nclass) 52 | 53 | def forward(self, x1, x2, x3, x4, x5): 54 | x = self.up1(x5, x4) 55 | x = self.up2(x, x3) 56 | x = self.up3(x, x2) 57 | x = self.up4(x, x1) 58 | 59 | logits = self.outc(x) 60 | return logits 61 | 62 | 63 | class DoubleConv(nn.Module): 64 | """(convolution => [BN] => ReLU) * 2""" 65 | 66 | def __init__(self, in_channels, out_channels): 67 | super().__init__() 68 | self.double_conv = nn.Sequential( 69 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 70 | nn.BatchNorm2d(out_channels), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 73 | nn.BatchNorm2d(out_channels), 74 | nn.ReLU(inplace=True) 75 | ) 76 | 77 | def forward(self, x): 78 | return self.double_conv(x) 79 | 80 | 81 | class Down(nn.Module): 82 | """Downscaling with maxpool then double conv""" 83 | 84 | def __init__(self, in_channels, out_channels): 85 | super().__init__() 86 | self.maxpool_conv = nn.Sequential( 87 | nn.MaxPool2d(2), 88 | DoubleConv(in_channels, out_channels) 89 | ) 90 | 91 | def forward(self, x): 92 | return self.maxpool_conv(x) 93 | 94 | 95 | class Up(nn.Module): 96 | """Upscaling then double conv""" 97 | 98 | def __init__(self, in_channels, out_channels, bilinear=True): 99 | super().__init__() 100 | 101 | # if bilinear, use the normal convolutions to reduce the number of channels 102 | if bilinear: 103 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 104 | else: 105 | self.up = 
nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 106 | 107 | self.conv = DoubleConv(in_channels, out_channels) 108 | 109 | def forward(self, x1, x2): 110 | x1 = self.up(x1) 111 | # input is CHW 112 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 113 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 114 | 115 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 116 | diffY // 2, diffY - diffY // 2]) 117 | # if you have padding issues, see 118 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 119 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 120 | x = torch.cat([x2, x1], dim=1) 121 | return self.conv(x) 122 | 123 | 124 | class OutConv(nn.Module): 125 | def __init__(self, in_channels, out_channels): 126 | super(OutConv, self).__init__() 127 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 128 | 129 | def forward(self, x): 130 | return self.conv(x) 131 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """Pascal VOC Semantic Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .seg_data_base import SegmentationDataset 8 | 9 | 10 | class VOCSegmentation(SegmentationDataset): 11 | """Pascal VOC Semantic Segmentation Dataset. 12 | 13 | Parameters 14 | ---------- 15 | root : string 16 | Path to VOCdevkit folder. Default is './datasets/VOCdevkit' 17 | split: string 18 | 'train', 'val' or 'test' 19 | transform : callable, optional 20 | A function that transforms the image 21 | Examples 22 | -------- 23 | >>> from torchvision import transforms 24 | >>> import torch.utils.data as data 25 | >>> # Transforms for Normalization 26 | >>> input_transform = transforms.Compose([ 27 | >>> transforms.ToTensor(), 28 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 29 | >>> ]) 30 | >>> # Create Dataset 31 | >>> trainset = VOCSegmentation(split='train', transform=input_transform) 32 | >>> # Create Training Loader 33 | >>> train_data = data.DataLoader( 34 | >>> trainset, 4, shuffle=True, 35 | >>> num_workers=4) 36 | """ 37 | BASE_DIR = 'VOC2012' 38 | NUM_CLASS = 21 39 | 40 | def __init__(self, root='datasets/voc', split='train', mode=None, transform=None, **kwargs): 41 | super(VOCSegmentation, self).__init__(root, split, mode, transform, **kwargs) 42 | _voc_root = os.path.join(root, self.BASE_DIR) 43 | _mask_dir = os.path.join(_voc_root, 'SegmentationClass') 44 | _image_dir = os.path.join(_voc_root, 'JPEGImages') 45 | # train/val/test splits are pre-cut 46 | _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation') 47 | if split == 'train': 48 | _split_f = os.path.join(_splits_dir, 'train.txt') 49 | elif split == 'val': 50 | _split_f = os.path.join(_splits_dir, 'val.txt') 51 | elif split == 'test': 52 | _split_f = os.path.join(_splits_dir, 'test.txt') 53 | else: 54 | raise RuntimeError('Unknown dataset split.') 55 | 56 | self.images = [] 57 | self.masks = [] 58 | with open(os.path.join(_split_f), "r") as lines: 59 | for line in lines: 60 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") 61 | assert os.path.isfile(_image) 62 | self.images.append(_image) 63 | if split != 'test': 64 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png") 65 | assert os.path.isfile(_mask) 66 | 
self.masks.append(_mask) 67 | 68 | if split != 'test': 69 | assert (len(self.images) == len(self.masks)) 70 | print('Found {} images in the folder {}'.format(len(self.images), _voc_root)) 71 | 72 | def __getitem__(self, index): 73 | img = Image.open(self.images[index]).convert('RGB') 74 | if self.mode == 'test': 75 | img = self._img_transform(img) 76 | if self.transform is not None: 77 | img = self.transform(img) 78 | return img, os.path.basename(self.images[index]) 79 | mask = Image.open(self.masks[index]) 80 | # synchronized transform 81 | if self.mode == 'train': 82 | img, mask = self._sync_transform(img, mask) 83 | elif self.mode == 'val': 84 | img, mask = self._val_sync_transform(img, mask) 85 | else: 86 | assert self.mode == 'testval' 87 | img, mask = self._img_transform(img), self._mask_transform(mask) 88 | # general resize, normalize and toTensor 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | return img, mask, os.path.basename(self.images[index]) 93 | 94 | def __len__(self): 95 | return len(self.images) 96 | 97 | def _mask_transform(self, mask): 98 | target = np.array(mask).astype('int32') 99 | target[target == 255] = -1 100 | return torch.from_numpy(target).long() 101 | 102 | @property 103 | def classes(self): 104 | """Category names.""" 105 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 106 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 107 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 108 | 'tv') 109 | 110 | 111 | if __name__ == '__main__': 112 | dataset = VOCSegmentation() -------------------------------------------------------------------------------- /segmentron/models/denseaspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from .fcn import _FCNHead 8 | 9 | __all__ = ['DenseASPP'] 10 | 11 | 12 | @MODEL_REGISTRY.register() 13 | class DenseASPP(SegBaseModel): 14 | def __init__(self): 15 | super(DenseASPP, self).__init__() 16 | 17 | in_channels = self.encoder.last_inp_channels 18 | 19 | self.head = _DenseASPPHead(in_channels, self.nclass, norm_layer=self.norm_layer) 20 | 21 | if self.aux: 22 | self.auxlayer = _FCNHead(in_channels, self.nclass) 23 | 24 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 25 | 26 | def forward(self, x): 27 | size = x.size()[2:] 28 | _, _, c3, c4 = self.encoder(x) 29 | # TODO add densenet as backbone 30 | # if self.dilate_scale > 8: 31 | # features = F.interpolate(c4, scale_factor=2, mode='bilinear', align_corners=True) 32 | outputs = [] 33 | x = self.head(c4) 34 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 35 | outputs.append(x) 36 | 37 | if self.aux: 38 | auxout = self.auxlayer(c3) 39 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 40 | outputs.append(auxout) 41 | return tuple(outputs) 42 | 43 | 44 | class _DenseASPPHead(nn.Module): 45 | def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None): 46 | super(_DenseASPPHead, self).__init__() 47 | self.dense_aspp_block = _DenseASPPBlock(in_channels, 256, 64, norm_layer, norm_kwargs) 48 | self.block = nn.Sequential( 49 | nn.Dropout(0.1), 50 | nn.Conv2d(in_channels + 5 * 64, nclass, 1) 51 | ) 52 | 53 | def forward(self, x): 54 | x = self.dense_aspp_block(x) 55 | return self.block(x) 56 | 57 | 58 | 
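# Note on the channel arithmetic in _DenseASPPHead above: _DenseASPPBlock (below)
# runs five dilated branches (rates 3, 6, 12, 18, 24) and concatenates each
# branch's inter_channels2 = 64 output maps back onto its input, so the head's
# classifier receives in_channels + 5 * 64 channels -- matching its 1x1 conv.
# A quick shape check (a sketch; the 1024-channel input is an assumed example):
#
#     blk = _DenseASPPBlock(1024, 256, 64)
#     out = blk(torch.randn(1, 1024, 32, 32))
#     out.shape  # torch.Size([1, 1344, 32, 32]), since 1024 + 5 * 64 = 1344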
class _DenseASPPConv(nn.Sequential): 59 | def __init__(self, in_channels, inter_channels, out_channels, atrous_rate, 60 | drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None): 61 | super(_DenseASPPConv, self).__init__() 62 | self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1)), 63 | self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs))), 64 | self.add_module('relu1', nn.ReLU(True)), 65 | self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate)), 66 | self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs))), 67 | self.add_module('relu2', nn.ReLU(True)), 68 | self.drop_rate = drop_rate 69 | 70 | def forward(self, x): 71 | features = super(_DenseASPPConv, self).forward(x) 72 | if self.drop_rate > 0: 73 | features = F.dropout(features, p=self.drop_rate, training=self.training) 74 | return features 75 | 76 | 77 | class _DenseASPPBlock(nn.Module): 78 | def __init__(self, in_channels, inter_channels1, inter_channels2, 79 | norm_layer=nn.BatchNorm2d, norm_kwargs=None): 80 | super(_DenseASPPBlock, self).__init__() 81 | self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1, 82 | norm_layer, norm_kwargs) 83 | self.aspp_6 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 6, 0.1, 84 | norm_layer, norm_kwargs) 85 | self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 12, 0.1, 86 | norm_layer, norm_kwargs) 87 | self.aspp_18 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 18, 0.1, 88 | norm_layer, norm_kwargs) 89 | self.aspp_24 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 24, 0.1, 90 | norm_layer, norm_kwargs) 91 | 92 | def forward(self, x): 93 | aspp3 = self.aspp_3(x) 94 | x = torch.cat([aspp3, x], dim=1) 95 | 96 | aspp6 = self.aspp_6(x) 97 | x = torch.cat([aspp6, x], dim=1) 98 | 99 | aspp12 = self.aspp_12(x) 100 | x = torch.cat([aspp12, x], dim=1) 101 | 102 | aspp18 = self.aspp_18(x) 103 | x = torch.cat([aspp18, x], dim=1) 104 | 105 | aspp24 = self.aspp_24(x) 106 | x = torch.cat([aspp24, x], dim=1) 107 | 108 | return x 109 | -------------------------------------------------------------------------------- /segmentron/models/segbase.py: -------------------------------------------------------------------------------- 1 | """Base Model for Semantic Segmentation""" 2 | import math 3 | import numbers 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .backbones import get_segmentation_backbone 10 | from ..data.dataloader import datasets 11 | from ..modules import get_norm 12 | from ..config import cfg 13 | __all__ = ['SegBaseModel'] 14 | 15 | 16 | class SegBaseModel(nn.Module): 17 | r"""Base Model for Semantic Segmentation 18 | """ 19 | def __init__(self, need_backbone=True): 20 | super(SegBaseModel, self).__init__() 21 | self.nclass = datasets[cfg.DATASET.NAME].NUM_CLASS 22 | self.aux = cfg.SOLVER.AUX 23 | self.norm_layer = get_norm(cfg.MODEL.BN_TYPE) 24 | self.backbone = None 25 | self.encoder = None 26 | if need_backbone: 27 | self.get_backbone() 28 | 29 | def get_backbone(self): 30 | self.backbone = cfg.MODEL.BACKBONE.lower() 31 | self.encoder = get_segmentation_backbone(self.backbone, self.norm_layer) 32 | 33 | def base_forward(self, x): 34 | """forwarding backbone network""" 35 | c1, c2, c3, c4 = 
self.encoder(x) 36 | return c1, c2, c3, c4 37 | 38 | def demo(self, x): 39 | pred = self.forward(x) 40 | if self.aux: 41 | pred = pred[0] 42 | return pred 43 | 44 | def evaluate(self, image): 45 | """evaluate the network on ``image`` with multi-scale and optional flip testing""" 46 | scales = cfg.TEST.SCALES 47 | flip = cfg.TEST.FLIP 48 | crop_size = _to_tuple(cfg.TEST.CROP_SIZE) if cfg.TEST.CROP_SIZE else None 49 | batch, _, h, w = image.shape 50 | base_size = max(h, w) 51 | # scores = torch.zeros((batch, self.nclass, h, w)).to(image.device) 52 | scores = None 53 | for scale in scales: 54 | long_size = int(math.ceil(base_size * scale)) 55 | if h > w: 56 | height = long_size 57 | width = int(1.0 * w * long_size / h + 0.5) 58 | else: 59 | width = long_size 60 | height = int(1.0 * h * long_size / w + 0.5) 61 | 62 | # resize image to current size 63 | cur_img = _resize_image(image, height, width) 64 | if crop_size is not None: 65 | assert crop_size[0] >= h and crop_size[1] >= w 66 | crop_size_scaled = (int(math.ceil(crop_size[0] * scale)), 67 | int(math.ceil(crop_size[1] * scale))) 68 | cur_img = _pad_image(cur_img, crop_size_scaled) 69 | outputs = self.forward(cur_img)[0][..., :height, :width] 70 | if flip: 71 | outputs += _flip_image(self.forward(_flip_image(cur_img))[0])[..., :height, :width] 72 | 73 | score = _resize_image(outputs, h, w) 74 | 75 | if scores is None: 76 | scores = score 77 | else: 78 | scores += score 79 | return scores 80 | 81 | 82 | def _resize_image(img, h, w): 83 | return F.interpolate(img, size=[h, w], mode='bilinear', align_corners=True) 84 | 85 | 86 | def _pad_image(img, crop_size): 87 | b, c, h, w = img.shape 88 | assert(c == 3) 89 | padh = crop_size[0] - h if h < crop_size[0] else 0 90 | padw = crop_size[1] - w if w < crop_size[1] else 0 91 | if padh == 0 and padw == 0: 92 | return img 93 | img_pad = F.pad(img, (0, padw, 0, padh))  # F.pad order for NCHW input: (w_left, w_right, h_top, h_bottom) 94 | 95 | # TODO clean this code 96 | # mean = cfg.DATASET.MEAN 97 | # std = cfg.DATASET.STD 98 | # pad_values = -np.array(mean) / np.array(std) 99 | # img_pad = torch.zeros((b, c, h + padh, w + padw)).to(img.device) 100 | # for i in range(c): 101 | # # print(img[:, i, :, :].unsqueeze(1).shape) 102 | # img_pad[:, i, :, :] = torch.squeeze( 103 | # F.pad(img[:, i, :, :].unsqueeze(1), (0, padw, 0, padh), 104 | # 'constant', value=pad_values[i]), 1) 105 | # assert(img_pad.shape[2] >= crop_size[0] and img_pad.shape[3] >= crop_size[1]) 106 | 107 | return img_pad 108 | 109 | 110 | def _crop_image(img, h0, h1, w0, w1): 111 | return img[:, :, h0:h1, w0:w1] 112 | 113 | 114 | def _flip_image(img): 115 | assert(img.ndim == 4) 116 | return img.flip(3) 117 | 118 | 119 | def _to_tuple(size): 120 | if isinstance(size, (list, tuple)): 121 | assert len(size) == 2, 'Expected eval crop size to contain two elements, ' \ 122 | 'but received {}'.format(len(size)) 123 | return tuple(size) 124 | elif isinstance(size, numbers.Number): 125 | return (size, size) 126 | else: 127 | raise ValueError('Unsupported datatype: {}'.format(type(size))) 128 | -------------------------------------------------------------------------------- /segmentron/data/downloader/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """Prepare PASCAL VOC datasets""" 2 | import os 3 | import sys 4 | 5 | cur_path = os.path.abspath(os.path.dirname(__file__)) 6 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 7 | sys.path.append(root_path) 8 | 9 | import argparse 10 | import shutil 11 | import tarfile 12 | from segmentron.utils import download, makedirs 13 |
14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser( 17 | description='Initialize PASCAL VOC dataset.', 18 | epilog='Example: python pascal_voc.py --download-dir ~/VOCdevkit', 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk') 21 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 22 | parser.add_argument('--overwrite', action='store_true', 23 | help='overwrite downloaded files if set, in case they are corrupted') 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | ##################################################################################### 29 | # Download and extract VOC datasets into ``path`` 30 | 31 | def download_voc(path, overwrite=False): 32 | _DOWNLOAD_URLS = [ 33 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 34 | '34ed68851bce2a36e2a223fa52c661d592c66b3c'), 35 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 36 | '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'), 37 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 38 | '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')] 39 | makedirs(path) 40 | for url, checksum in _DOWNLOAD_URLS: 41 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 42 | # extract 43 | with tarfile.open(filename) as tar: 44 | tar.extractall(path=path) 45 | 46 | 47 | ##################################################################################### 48 | # Download and extract the VOC augmented segmentation dataset into ``path`` 49 | 50 | def download_aug(path, overwrite=False): 51 | _AUG_DOWNLOAD_URLS = [ 52 | ('http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz', 53 | '7129e0a480c2d6afb02b517bb18ac54283bfaa35')] 54 | makedirs(path) 55 | for url, checksum in _AUG_DOWNLOAD_URLS: 56 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 57 | # extract 58 | with tarfile.open(filename) as tar: 59 | tar.extractall(path=path) 60 | shutil.move(os.path.join(path, 'benchmark_RELEASE'), 61 | os.path.join(path, 'VOCaug')) 62 | filenames = ['VOCaug/dataset/train.txt', 'VOCaug/dataset/val.txt'] 63 | # generate trainval.txt 64 | with open(os.path.join(path, 'VOCaug/dataset/trainval.txt'), 'w') as outfile: 65 | for fname in filenames: 66 | fname = os.path.join(path, fname) 67 | with open(fname) as infile: 68 | for line in infile: 69 | outfile.write(line) 70 | 71 | 72 | if __name__ == '__main__': 73 | args = parse_args() 74 | 75 | default_dir = os.path.join(root_path, 'datasets/voc') 76 | if args.download_dir is not None: 77 | path = args.download_dir 78 | else: 79 | path = default_dir 80 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'VOC2007')) \ 81 | or not os.path.isdir(os.path.join(path, 'VOC2012')): 82 | if args.no_download: 83 | raise ValueError(('{} is not a valid directory; make sure it is present,'
84 | ' or rerun without --no-download so it can be fetched automatically'.format(path))) 85 | else: 86 | download_voc(path, overwrite=args.overwrite) 87 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2007'), os.path.join(path, 'VOC2007')) 88 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2012'), os.path.join(path, 'VOC2012')) 89 | shutil.rmtree(os.path.join(path, 'VOCdevkit')) 90 | 91 | if not os.path.isdir(os.path.join(path, 'VOCaug')): 92 | if args.no_download: 93 | raise ValueError(('{} is not a valid directory; make sure it is present,' 94 | ' or rerun without --no-download so it can be fetched automatically'.format(path))) 95 | else: 96 | download_aug(path, overwrite=args.overwrite) 97 | 98 | try: 99 | os.symlink(path, default_dir) 100 | except Exception as e: 101 | print(e) 102 | -------------------------------------------------------------------------------- /segmentron/models/edanet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | from ..config import cfg 9 | 10 | __all__ = ["EDANet"] 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class EDANet(SegBaseModel): 15 | def __init__(self): 16 | super(EDANet, self).__init__(need_backbone=False) 17 | 18 | self.layers = nn.ModuleList() 19 | self.dilation1 = [1, 1, 1, 2, 2] 20 | self.dilation2 = [2, 2, 4, 4, 8, 8, 16, 16] 21 | 22 | # DownsamplerBlock1 23 | self.layers.append(DownsamplerBlock(3, 15)) 24 | 25 | # DownsamplerBlock2 26 | self.layers.append(DownsamplerBlock(15, 60)) 27 | 28 | # EDA module 1-1 ~ 1-5 29 | for i in range(5): 30 | self.layers.append(EDABlock(60 + 40 * i, self.dilation1[i])) 31 | 32 | # DownsamplerBlock3 33 | self.layers.append(DownsamplerBlock(260, 130)) 34 | 35 | # EDA module 2-1 ~ 2-8 36 | for j in range(8): 37 | self.layers.append(EDABlock(130 + 40 * j, self.dilation2[j])) 38 | 39 | # Projection layer 40 | self.project_layer = nn.Conv2d(450, self.nclass, kernel_size=1) 41 | 42 | self.weights_init() 43 | 44 | def weights_init(self): 45 | for idx, m in enumerate(self.modules()): 46 | classname = m.__class__.__name__ 47 | if classname.find('Conv') != -1: 48 | m.weight.data.normal_(0.0, 0.02) 49 | elif classname.find('BatchNorm') != -1: 50 | m.weight.data.normal_(1.0, 0.02) 51 | m.bias.data.fill_(0) 52 | 53 | def forward(self, x): 54 | 55 | output = x 56 | 57 | for layer in self.layers: 58 | output = layer(output) 59 | 60 | output = self.project_layer(output) 61 | 62 | output = F.interpolate(output, size=x.size()[2:], mode='bilinear', align_corners=True) 63 | # Bilinear interpolation x8 64 | # output = F.interpolate(output, scale_factor=8, mode='bilinear', align_corners=True) 65 | # 66 | # # Bilinear interpolation x2 (inference only) 67 | # if not self.training: 68 | # output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=True) 69 | 70 | return tuple([output]) 71 | 72 | 73 | class DownsamplerBlock(nn.Module): 74 | def __init__(self, ninput, noutput): 75 | super(DownsamplerBlock, self).__init__() 76 | 77 | self.ninput = ninput 78 | self.noutput = noutput 79 | 80 | if self.ninput < self.noutput: 81 | # Wout > Win 82 | self.conv = nn.Conv2d(ninput, noutput - ninput, kernel_size=3, stride=2, padding=1) 83 | self.pool = nn.MaxPool2d(2, stride=2) 84 | else: 85 | # Wout < Win 86 | self.conv = nn.Conv2d(ninput, noutput, kernel_size=3, stride=2, padding=1) 87 | 88 | self.bn =
nn.BatchNorm2d(noutput) 89 | 90 | def forward(self, x): 91 | if self.ninput < self.noutput: 92 | output = torch.cat([self.conv(x), self.pool(x)], 1) 93 | else: 94 | output = self.conv(x) 95 | 96 | output = self.bn(output) 97 | return F.relu(output) 98 | 99 | 100 | class EDABlock(nn.Module): 101 | def __init__(self, ninput, dilated, k=40, dropprob=0.02): 102 | super(EDABlock, self).__init__() 103 | 104 | # k: growthrate 105 | # dropprob:a dropout layer between the last ReLU and the concatenation of each module 106 | 107 | self.conv1x1 = nn.Conv2d(ninput, k, kernel_size=1) 108 | self.bn0 = nn.BatchNorm2d(k) 109 | 110 | self.conv3x1_1 = nn.Conv2d(k, k, kernel_size=(3, 1), padding=(1, 0)) 111 | self.conv1x3_1 = nn.Conv2d(k, k, kernel_size=(1, 3), padding=(0, 1)) 112 | self.bn1 = nn.BatchNorm2d(k) 113 | 114 | self.conv3x1_2 = nn.Conv2d(k, k, (3, 1), stride=1, padding=(dilated, 0), dilation=dilated) 115 | self.conv1x3_2 = nn.Conv2d(k, k, (1, 3), stride=1, padding=(0, dilated), dilation=dilated) 116 | self.bn2 = nn.BatchNorm2d(k) 117 | 118 | self.dropout = nn.Dropout2d(dropprob) 119 | 120 | def forward(self, x): 121 | input = x 122 | 123 | output = self.conv1x1(x) 124 | output = self.bn0(output) 125 | output = F.relu(output) 126 | 127 | output = self.conv3x1_1(output) 128 | output = self.conv1x3_1(output) 129 | output = self.bn1(output) 130 | output = F.relu(output) 131 | 132 | output = self.conv3x1_2(output) 133 | output = self.conv1x3_2(output) 134 | output = self.bn2(output) 135 | output = F.relu(output) 136 | 137 | if (self.dropout.p != 0): 138 | output = self.dropout(output) 139 | 140 | output = torch.cat([output, input], 1) 141 | # print output.size() #check the output 142 | return output -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | 6 | cur_path = os.path.abspath(os.path.dirname(__file__)) 7 | root_path = os.path.split(cur_path)[0] 8 | sys.path.append(root_path) 9 | 10 | import logging 11 | import torch 12 | import torch.nn as nn 13 | import torch.utils.data as data 14 | import torch.nn.functional as F 15 | 16 | from tabulate import tabulate 17 | from torchvision import transforms 18 | from segmentron.data.dataloader import get_segmentation_dataset 19 | from segmentron.models.model_zoo import get_segmentation_model 20 | from segmentron.utils.score import SegmentationMetric 21 | from segmentron.utils.distributed import synchronize, make_data_sampler, make_batch_data_sampler 22 | from segmentron.config import cfg 23 | from segmentron.utils.options import parse_args 24 | from segmentron.utils.default_setup import default_setup 25 | 26 | 27 | class Evaluator(object): 28 | def __init__(self, args): 29 | self.args = args 30 | self.device = torch.device(args.device) 31 | 32 | # image transform 33 | input_transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD), 36 | ]) 37 | 38 | # dataset and dataloader 39 | val_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='val', mode='testval', transform=input_transform) 40 | val_sampler = make_data_sampler(val_dataset, False, args.distributed) 41 | val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False) 42 | self.val_loader = data.DataLoader(dataset=val_dataset, 43 | batch_sampler=val_batch_sampler, 44 | 
num_workers=cfg.DATASET.WORKERS, 45 | pin_memory=True) 46 | self.classes = val_dataset.classes 47 | # create network 48 | self.model = get_segmentation_model().to(self.device) 49 | 50 | if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \ 51 | cfg.MODEL.BN_EPS_FOR_ENCODER: 52 | logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) 53 | self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) 54 | 55 | if args.distributed: 56 | self.model = nn.parallel.DistributedDataParallel(self.model, 57 | device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 58 | self.model.to(self.device) 59 | 60 | self.metric = SegmentationMetric(val_dataset.num_class, args.distributed) 61 | 62 | def set_batch_norm_attr(self, named_modules, attr, value): 63 | for m in named_modules: 64 | if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm): 65 | setattr(m[1], attr, value) 66 | 67 | def eval(self): 68 | self.metric.reset() 69 | self.model.eval() 70 | if self.args.distributed: 71 | model = self.model.module 72 | else: 73 | model = self.model 74 | 75 | logging.info("Start validation, Total sample: {:d}".format(len(self.val_loader))) 76 | import time 77 | time_start = time.time() 78 | for i, (image, target, filename) in enumerate(self.val_loader): 79 | image = image.to(self.device) 80 | target = target.to(self.device) 81 | 82 | with torch.no_grad(): 83 | output = model.evaluate(image) 84 | 85 | self.metric.update(output, target) 86 | pixAcc, mIoU = self.metric.get() 87 | logging.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format( 88 | i + 1, pixAcc * 100, mIoU * 100)) 89 | 90 | synchronize() 91 | pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True) 92 | logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start)) 93 | logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format( 94 | pixAcc * 100, mIoU * 100)) 95 | 96 | headers = ['class id', 'class name', 'iou'] 97 | table = [] 98 | for i, cls_name in enumerate(self.classes): 99 | table.append([cls_name, category_iou[i]]) 100 | logging.info('Category iou: \n {}'.format(tabulate(table, headers, tablefmt='grid', showindex="always", 101 | numalign='center', stralign='center'))) 102 | 103 | 104 | if __name__ == '__main__': 105 | args = parse_args() 106 | cfg.update_from_file(args.config_file) 107 | cfg.update_from_list(args.opts) 108 | cfg.PHASE = 'test' 109 | cfg.ROOT_PATH = root_path 110 | cfg.check_and_freeze() 111 | 112 | default_setup(args) 113 | 114 | evaluator = Evaluator(args) 115 | evaluator.eval() 116 | -------------------------------------------------------------------------------- /segmentron/config/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import codecs 7 | import yaml 8 | import six 9 | import time 10 | 11 | from ast import literal_eval 12 | 13 | class SegmentronConfig(dict): 14 | def __init__(self, *args, **kwargs): 15 | super(SegmentronConfig, self).__init__(*args, **kwargs) 16 | self.immutable = False 17 | 18 | def __setattr__(self, key, value, create_if_not_exist=True): 19 | if key in ["immutable"]: 20 | self.__dict__[key] = value 21 | return 22 | 23 | t = self 24 | keylist = key.split(".") 25 | for k in 
keylist[:-1]: 26 | t = t.__getattr__(k, create_if_not_exist) 27 | 28 | t.__getattr__(keylist[-1], create_if_not_exist) 29 | t[keylist[-1]] = value 30 | 31 | def __getattr__(self, key, create_if_not_exist=True): 32 | if key in ["immutable"]: 33 | if key not in self.__dict__: 34 | self.__dict__[key] = False 35 | return self.__dict__[key] 36 | 37 | if not key in self: 38 | if not create_if_not_exist: 39 | raise KeyError 40 | self[key] = SegmentronConfig() 41 | return self[key] 42 | 43 | def __setitem__(self, key, value): 44 | # 45 | if self.immutable: 46 | raise AttributeError( 47 | 'Attempted to set "{}" to "{}", but SegConfig is immutable'. 48 | format(key, value)) 49 | # 50 | if isinstance(value, six.string_types): 51 | try: 52 | value = literal_eval(value) 53 | except ValueError: 54 | pass 55 | except SyntaxError: 56 | pass 57 | super(SegmentronConfig, self).__setitem__(key, value) 58 | 59 | def update_from_other_cfg(self, other): 60 | if isinstance(other, dict): 61 | other = SegmentronConfig(other) 62 | assert isinstance(other, SegmentronConfig) 63 | cfg_list = [("", other)] 64 | while len(cfg_list): 65 | prefix, tdic = cfg_list[0] 66 | cfg_list = cfg_list[1:] 67 | for key, value in tdic.items(): 68 | key = "{}.{}".format(prefix, key) if prefix else key 69 | if isinstance(value, dict): 70 | cfg_list.append((key, value)) 71 | continue 72 | try: 73 | self.__setattr__(key, value, create_if_not_exist=False) 74 | except KeyError: 75 | raise KeyError('Non-existent config key: {}'.format(key)) 76 | 77 | def remove_irrelevant_cfg(self): 78 | model_name = self.MODEL.MODEL_NAME 79 | 80 | from ..models.model_zoo import MODEL_REGISTRY 81 | model_list = MODEL_REGISTRY.get_list() 82 | model_list_lower = [x.lower() for x in model_list] 83 | 84 | assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\ 85 | .format(model_list, model_name) 86 | pop_keys = [] 87 | for key in self.MODEL.keys(): 88 | if key.lower() in model_list_lower: 89 | if model_name.lower() == 'pointrend' and \ 90 | key.lower() == self.MODEL.POINTREND.BASEMODEL.lower(): 91 | continue 92 | if key.lower() in model_list_lower and key.lower() != model_name.lower(): 93 | pop_keys.append(key) 94 | for key in pop_keys: 95 | self.MODEL.pop(key) 96 | 97 | 98 | 99 | def check_and_freeze(self): 100 | self.TIME_STAMP = time.strftime('%Y-%m-%d-%H-%M', time.localtime()) 101 | # TODO: remove irrelevant config and then freeze 102 | self.remove_irrelevant_cfg() 103 | self.immutable = True 104 | 105 | def update_from_list(self, config_list): 106 | if len(config_list) % 2 != 0: 107 | raise ValueError( 108 | "Command line options config format error! Please check it: {}". 
109 | format(config_list)) 110 | for key, value in zip(config_list[0::2], config_list[1::2]): 111 | try: 112 | self.__setattr__(key, value, create_if_not_exist=False) 113 | except KeyError: 114 | raise KeyError('Non-existent config key: {}'.format(key)) 115 | 116 | def update_from_file(self, config_file): 117 | with codecs.open(config_file, 'r', 'utf-8') as file: 118 | loaded_cfg = yaml.load(file, Loader=yaml.FullLoader) 119 | self.update_from_other_cfg(loaded_cfg) 120 | 121 | def set_immutable(self, immutable): 122 | self.immutable = immutable 123 | for value in self.values(): 124 | if isinstance(value, SegmentronConfig): 125 | value.set_immutable(immutable) 126 | 127 | def is_immutable(self): 128 | return self.immutable -------------------------------------------------------------------------------- /segmentron/data/dataloader/seg_data_base.py: -------------------------------------------------------------------------------- 1 | """Base segmentation dataset""" 2 | import os 3 | import random 4 | import numpy as np 5 | import torchvision 6 | 7 | from PIL import Image, ImageOps, ImageFilter 8 | from ...config import cfg 9 | 10 | __all__ = ['SegmentationDataset'] 11 | 12 | 13 | class SegmentationDataset(object): 14 | """Segmentation Base Dataset""" 15 | 16 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): 17 | super(SegmentationDataset, self).__init__() 18 | self.root = os.path.join(cfg.ROOT_PATH, root) 19 | self.transform = transform 20 | self.split = split 21 | self.mode = mode if mode is not None else split 22 | self.base_size = base_size 23 | self.crop_size = self.to_tuple(crop_size) 24 | self.color_jitter = self._get_color_jitter() 25 | 26 | def to_tuple(self, size): 27 | if isinstance(size, (list, tuple)): 28 | return tuple(size) 29 | elif isinstance(size, (int, float)): 30 | return tuple((size, size)) 31 | else: 32 | raise ValueError('Unsupport datatype: {}'.format(type(size))) 33 | 34 | def _get_color_jitter(self): 35 | color_jitter = cfg.AUG.COLOR_JITTER 36 | if color_jitter is None: 37 | return None 38 | if isinstance(color_jitter, (list, tuple)): 39 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation 40 | # or 4 if also augmenting hue 41 | assert len(color_jitter) in (3, 4) 42 | else: 43 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue 44 | color_jitter = (float(color_jitter),) * 3 45 | return torchvision.transforms.ColorJitter(*color_jitter) 46 | 47 | def _val_sync_transform(self, img, mask): 48 | outsize = self.crop_size 49 | short_size = min(outsize) 50 | w, h = img.size 51 | if w > h: 52 | oh = short_size 53 | ow = int(1.0 * w * oh / h) 54 | else: 55 | ow = short_size 56 | oh = int(1.0 * h * ow / w) 57 | img = img.resize((ow, oh), Image.BILINEAR) 58 | mask = mask.resize((ow, oh), Image.NEAREST) 59 | # center crop 60 | w, h = img.size 61 | x1 = int(round((w - outsize[1]) / 2.)) 62 | y1 = int(round((h - outsize[0]) / 2.)) 63 | img = img.crop((x1, y1, x1 + outsize[1], y1 + outsize[0])) 64 | mask = mask.crop((x1, y1, x1 + outsize[1], y1 + outsize[0])) 65 | 66 | # final transform 67 | img, mask = self._img_transform(img), self._mask_transform(mask) 68 | return img, mask 69 | 70 | def _sync_transform(self, img, mask): 71 | # random mirror 72 | if cfg.AUG.MIRROR and random.random() < 0.5: 73 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 74 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 75 | crop_size = self.crop_size 76 | # random scale (short edge) 77 | short_size = 
random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 78 | w, h = img.size 79 | if h > w: 80 | ow = short_size 81 | oh = int(1.0 * h * ow / w) 82 | else: 83 | oh = short_size 84 | ow = int(1.0 * w * oh / h) 85 | img = img.resize((ow, oh), Image.BILINEAR) 86 | mask = mask.resize((ow, oh), Image.NEAREST) 87 | # pad crop 88 | if short_size < min(crop_size): 89 | padh = crop_size[0] - oh if oh < crop_size[0] else 0 90 | padw = crop_size[1] - ow if ow < crop_size[1] else 0 91 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 92 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=-1) 93 | # random crop crop_size 94 | w, h = img.size 95 | x1 = random.randint(0, w - crop_size[1]) 96 | y1 = random.randint(0, h - crop_size[0]) 97 | img = img.crop((x1, y1, x1 + crop_size[1], y1 + crop_size[0])) 98 | mask = mask.crop((x1, y1, x1 + crop_size[1], y1 + crop_size[0])) 99 | # gaussian blur as in PSP 100 | if cfg.AUG.BLUR_PROB > 0 and random.random() < cfg.AUG.BLUR_PROB: 101 | radius = cfg.AUG.BLUR_RADIUS if cfg.AUG.BLUR_RADIUS > 0 else random.random() 102 | img = img.filter(ImageFilter.GaussianBlur(radius=radius)) 103 | # color jitter 104 | if self.color_jitter: 105 | img = self.color_jitter(img) 106 | # final transform 107 | img, mask = self._img_transform(img), self._mask_transform(mask) 108 | return img, mask 109 | 110 | def _img_transform(self, img): 111 | return np.array(img) 112 | 113 | def _mask_transform(self, mask): 114 | return np.array(mask).astype('int32') 115 | 116 | @property 117 | def num_class(self): 118 | """Number of categories.""" 119 | return self.NUM_CLASS 120 | 121 | @property 122 | def pred_offset(self): 123 | return 0 124 | -------------------------------------------------------------------------------- /segmentron/models/refinenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | 9 | __all__ = ['RefineNet'] 10 | 11 | 12 | @MODEL_REGISTRY.register() 13 | class RefineNet(SegBaseModel): 14 | 15 | def __init__(self): 16 | super(RefineNet, self).__init__() 17 | self.head = _RefineHead(self.nclass, norm_layer=self.norm_layer) 18 | if self.aux: 19 | self.auxlayer = _FCNHead(728, self.nclass) 20 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 21 | 22 | def forward(self, x): 23 | size = x.size()[2:] 24 | c1, c2, c3, c4 = self.encoder(x) 25 | 26 | outputs = list() 27 | x = self.head(c1, c2, c3, c4) 28 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 29 | 30 | outputs.append(x) 31 | if self.aux: 32 | auxout = self.auxlayer(c3) 33 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 34 | outputs.append(auxout) 35 | return tuple(outputs) 36 | 37 | 38 | class _RefineHead(nn.Module): 39 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d): 40 | super(_RefineHead, self).__init__() 41 | self.do = nn.Dropout(p=0.5) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.p_ims1d2_outl1_dimred = nn.Conv2d(2048, 512, 1, bias=False) 44 | self.mflow_conv_g1_pool = self._make_crp(512, 512, 4) 45 | self.mflow_conv_g1_b3_joint_varout_dimred = nn.Conv2d(512, 256, 1, bias=False) 46 | self.p_ims1d2_outl2_dimred = nn.Conv2d(1024, 256, 1, bias=False) 47 | self.adapt_stage2_b2_joint_varout_dimred = nn.Conv2d(256, 256, 1, bias=False) 48 | self.mflow_conv_g2_pool = 
self._make_crp(256, 256, 4) 49 | self.mflow_conv_g2_b3_joint_varout_dimred = nn.Conv2d(256, 256, 1, bias=False) 50 | 51 | self.p_ims1d2_outl3_dimred = nn.Conv2d(512, 256, 1, bias=False) 52 | self.adapt_stage3_b2_joint_varout_dimred = nn.Conv2d(256, 256, 1, bias=False) 53 | self.mflow_conv_g3_pool = self._make_crp(256, 256, 4) 54 | self.mflow_conv_g3_b3_joint_varout_dimred = nn.Conv2d(256, 256, 1, bias=False) 55 | 56 | self.p_ims1d2_outl4_dimred = nn.Conv2d(256, 256, 1, bias=False) 57 | self.adapt_stage4_b2_joint_varout_dimred = nn.Conv2d(256, 256, 1, bias=False) 58 | self.mflow_conv_g4_pool = self._make_crp(256, 256, 4) 59 | 60 | self.clf_conv = nn.Conv2d(256, nclass, kernel_size=3, stride=1, 61 | padding=1, bias=True) 62 | 63 | def _make_crp(self, in_planes, out_planes, stages): 64 | layers = [CRPBlock(in_planes, out_planes, stages)] 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, l1, l2, l3, l4): 68 | l4 = self.do(l4) 69 | l3 = self.do(l3) 70 | 71 | x4 = self.p_ims1d2_outl1_dimred(l4) 72 | x4 = self.relu(x4) 73 | x4 = self.mflow_conv_g1_pool(x4) 74 | x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) 75 | x4 = F.interpolate(x4, size=l3.size()[2:], mode='bilinear', align_corners=True) 76 | 77 | x3 = self.p_ims1d2_outl2_dimred(l3) 78 | x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) 79 | x3 = x3 + x4 80 | x3 = F.relu(x3) 81 | x3 = self.mflow_conv_g2_pool(x3) 82 | x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) 83 | x3 = F.interpolate(x3, size=l2.size()[2:], mode='bilinear', align_corners=True) 84 | 85 | x2 = self.p_ims1d2_outl3_dimred(l2) 86 | x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) 87 | x2 = x2 + x3 88 | x2 = F.relu(x2) 89 | x2 = self.mflow_conv_g3_pool(x2) 90 | x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) 91 | x2 = F.interpolate(x2, size=l1.size()[2:], mode='bilinear', align_corners=True) 92 | 93 | x1 = self.p_ims1d2_outl4_dimred(l1) 94 | x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) 95 | x1 = x1 + x2 96 | x1 = F.relu(x1) 97 | x1 = self.mflow_conv_g4_pool(x1) 98 | 99 | out = self.clf_conv(x1) 100 | return out 101 | 102 | 103 | class CRPBlock(nn.Module): 104 | 105 | def __init__(self, in_planes, out_planes, n_stages): 106 | super(CRPBlock, self).__init__() 107 | for i in range(n_stages): 108 | setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'), 109 | nn.Conv2d(in_planes if (i == 0) else out_planes, 110 | out_planes, 1, stride=1, 111 | bias=False)) 112 | self.stride = 1 113 | self.n_stages = n_stages 114 | self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 115 | 116 | def forward(self, x): 117 | top = x 118 | for i in range(self.n_stages): 119 | top = self.maxpool(top) 120 | top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top) 121 | x = top + x 122 | return x -------------------------------------------------------------------------------- /segmentron/data/dataloader/mscoco.py: -------------------------------------------------------------------------------- 1 | """MSCOCO Semantic Segmentation pretraining for VOC.""" 2 | import os 3 | import pickle 4 | import torch 5 | import numpy as np 6 | 7 | from tqdm import trange 8 | from PIL import Image 9 | from .seg_data_base import SegmentationDataset 10 | 11 | 12 | class COCOSegmentation(SegmentationDataset): 13 | """COCO Semantic Segmentation Dataset for VOC Pre-training. 14 | 15 | Parameters 16 | ---------- 17 | root : string 18 | Path to ADE20K folder. 
Default is './datasets/coco' 19 | split: string 20 | 'train', 'val' or 'test' 21 | transform : callable, optional 22 | A function that transforms the image 23 | Examples 24 | -------- 25 | >>> from torchvision import transforms 26 | >>> import torch.utils.data as data 27 | >>> # Transforms for Normalization 28 | >>> input_transform = transforms.Compose([ 29 | >>> transforms.ToTensor(), 30 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)), 31 | >>> ]) 32 | >>> # Create Dataset 33 | >>> trainset = COCOSegmentation(split='train', transform=input_transform) 34 | >>> # Create Training Loader 35 | >>> train_data = data.DataLoader( 36 | >>> trainset, 4, shuffle=True, 37 | >>> num_workers=4) 38 | """ 39 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 40 | 1, 64, 20, 63, 7, 72] 41 | NUM_CLASS = 21 42 | 43 | def __init__(self, root='datasets/coco', split='train', mode=None, transform=None, **kwargs): 44 | super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs) 45 | # lazy import pycocotools 46 | from pycocotools.coco import COCO 47 | from pycocotools import mask 48 | if split == 'train': 49 | print('train set') 50 | ann_file = os.path.join(root, 'annotations/instances_train2017.json') 51 | ids_file = os.path.join(root, 'annotations/train_ids.pkl') 52 | self.root = os.path.join(root, 'train2017') 53 | else: 54 | print('val set') 55 | ann_file = os.path.join(root, 'annotations/instances_val2017.json') 56 | ids_file = os.path.join(root, 'annotations/val_ids.pkl') 57 | self.root = os.path.join(root, 'val2017') 58 | self.coco = COCO(ann_file) 59 | self.coco_mask = mask 60 | if os.path.exists(ids_file): 61 | with open(ids_file, 'rb') as f: 62 | self.ids = pickle.load(f) 63 | else: 64 | ids = list(self.coco.imgs.keys()) 65 | self.ids = self._preprocess(ids, ids_file) 66 | self.transform = transform 67 | 68 | def __getitem__(self, index): 69 | coco = self.coco 70 | img_id = self.ids[index] 71 | img_metadata = coco.loadImgs(img_id)[0] 72 | path = img_metadata['file_name'] 73 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 74 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) 75 | mask = Image.fromarray(self._gen_seg_mask( 76 | cocotarget, img_metadata['height'], img_metadata['width'])) 77 | # synchrosized transform 78 | if self.mode == 'train': 79 | img, mask = self._sync_transform(img, mask) 80 | elif self.mode == 'val': 81 | img, mask = self._val_sync_transform(img, mask) 82 | else: 83 | assert self.mode == 'testval' 84 | img, mask = self._img_transform(img), self._mask_transform(mask) 85 | # general resize, normalize and toTensor 86 | if self.transform is not None: 87 | img = self.transform(img) 88 | return img, mask, os.path.basename(path) 89 | 90 | def __len__(self): 91 | return len(self.ids) 92 | 93 | def _mask_transform(self, mask): 94 | return torch.LongTensor(np.array(mask).astype('int32')) 95 | 96 | def _gen_seg_mask(self, target, h, w): 97 | mask = np.zeros((h, w), dtype=np.uint8) 98 | coco_mask = self.coco_mask 99 | for instance in target: 100 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w) 101 | m = coco_mask.decode(rle) 102 | cat = instance['category_id'] 103 | if cat in self.CAT_LIST: 104 | c = self.CAT_LIST.index(cat) 105 | else: 106 | continue 107 | if len(m.shape) < 3: 108 | mask[:, :] += (mask == 0) * (m * c) 109 | else: 110 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) 111 | return mask 112 | 113 | def _preprocess(self, ids, ids_file): 114 | 
print("Preprocessing mask, this will take a while." + \ 115 | "But don't worry, it only run once for each split.") 116 | tbar = trange(len(ids)) 117 | new_ids = [] 118 | for i in tbar: 119 | img_id = ids[i] 120 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 121 | img_metadata = self.coco.loadImgs(img_id)[0] 122 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width']) 123 | # more than 1k pixels 124 | if (mask > 0).sum() > 1000: 125 | new_ids.append(img_id) 126 | tbar.set_description('Doing: {}/{}, got {} qualified images'. \ 127 | format(i, len(ids), len(new_ids))) 128 | print('Found number of qualified images: ', len(new_ids)) 129 | with open(ids_file, 'wb') as f: 130 | pickle.dump(new_ids, f) 131 | return new_ids 132 | 133 | @property 134 | def classes(self): 135 | """Category names.""" 136 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 137 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 138 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 139 | 'tv') 140 | -------------------------------------------------------------------------------- /segmentron/modules/sync_bn/syncbn.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Synchronized Cross-GPU Batch Normalization Module""" 12 | import warnings 13 | import torch 14 | 15 | from torch.nn.modules.batchnorm import _BatchNorm 16 | from queue import Queue 17 | from .functions import * 18 | 19 | __all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'] 20 | 21 | 22 | # Adopt from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py 23 | class SyncBatchNorm(_BatchNorm): 24 | """Cross-GPU Synchronized Batch normalization (SyncBN) 25 | 26 | Parameters: 27 | num_features: num_features from an expected input of 28 | size batch_size x num_features x height x width 29 | eps: a value added to the denominator for numerical stability. 30 | Default: 1e-5 31 | momentum: the value used for the running_mean and running_var 32 | computation. Default: 0.1 33 | sync: a boolean value that when set to ``True``, synchronize across 34 | different gpus. Default: ``True`` 35 | activation : str 36 | Name of the activation functions, one of: `leaky_relu` or `none`. 37 | slope : float 38 | Negative slope for the `leaky_relu` activation. 39 | 40 | Shape: 41 | - Input: :math:`(N, C, H, W)` 42 | - Output: :math:`(N, C, H, W)` (same shape as input) 43 | Reference: 44 | .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015* 45 | .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." 
*CVPR 2018* 46 | Examples: 47 | >>> m = SyncBatchNorm(100) 48 | >>> net = torch.nn.DataParallel(m) 49 | >>> output = net(input) 50 | """ 51 | 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01, inplace=True): 53 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True) 54 | self.activation = activation 55 | self.inplace = False if activation == 'none' else inplace 56 | self.slope = slope 57 | self.devices = list(range(torch.cuda.device_count())) 58 | self.sync = sync if len(self.devices) > 1 else False 59 | # Initialize queues 60 | self.worker_ids = self.devices[1:] 61 | self.master_queue = Queue(len(self.worker_ids)) 62 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 63 | 64 | def forward(self, x): 65 | # resize the input to (B, C, -1) 66 | input_shape = x.size() 67 | x = x.view(input_shape[0], self.num_features, -1) 68 | if x.get_device() == self.devices[0]: 69 | # Master mode 70 | extra = { 71 | "is_master": True, 72 | "master_queue": self.master_queue, 73 | "worker_queues": self.worker_queues, 74 | "worker_ids": self.worker_ids 75 | } 76 | else: 77 | # Worker mode 78 | extra = { 79 | "is_master": False, 80 | "master_queue": self.master_queue, 81 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 82 | } 83 | if self.inplace: 84 | return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 85 | extra, self.sync, self.training, self.momentum, self.eps, 86 | self.activation, self.slope).view(input_shape) 87 | else: 88 | return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 89 | extra, self.sync, self.training, self.momentum, self.eps, 90 | self.activation, self.slope).view(input_shape) 91 | 92 | def extra_repr(self): 93 | if self.activation == 'none': 94 | return 'sync={}'.format(self.sync) 95 | else: 96 | return 'sync={}, act={}, slope={}, inplace={}'.format( 97 | self.sync, self.activation, self.slope, self.inplace) 98 | 99 | 100 | class BatchNorm1d(SyncBatchNorm): 101 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 102 | 103 | def __init__(self, *args, **kwargs): 104 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 105 | .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning) 106 | super(BatchNorm1d, self).__init__(*args, **kwargs) 107 | 108 | 109 | class BatchNorm2d(SyncBatchNorm): 110 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 111 | 112 | def __init__(self, *args, **kwargs): 113 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 114 | .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning) 115 | super(BatchNorm2d, self).__init__(*args, **kwargs) 116 | 117 | 118 | class BatchNorm3d(SyncBatchNorm): 119 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 120 | 121 | def __init__(self, *args, **kwargs): 122 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 
123 |                       .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
124 |         super(BatchNorm3d, self).__init__(*args, **kwargs)
125 |
--------------------------------------------------------------------------------
/segmentron/utils/parallel.py:
--------------------------------------------------------------------------------
1 | """Utils for Semantic Segmentation"""
2 | import threading
3 | import torch
4 | import torch.cuda.comm as comm
5 | from torch.nn.parallel.data_parallel import DataParallel
6 | from torch.nn.parallel._functions import Broadcast
7 | from torch.autograd import Function
8 | 
9 | __all__ = ['DataParallelModel', 'DataParallelCriterion']
10 | 
11 | 
12 | class Reduce(Function):
13 |     @staticmethod
14 |     def forward(ctx, *inputs):
15 |         ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
16 |         inputs = sorted(inputs, key=lambda i: i.get_device())
17 |         return comm.reduce_add(inputs)
18 | 
19 |     @staticmethod
20 |     def backward(ctx, gradOutputs):
21 |         return Broadcast.apply(ctx.target_gpus, gradOutputs)
22 | 
23 | 
24 | class DataParallelModel(DataParallel):
25 |     """Data parallelism
26 | 
27 |     Hides the difference between single- and multi-GPU execution from the user.
28 |     In the forward pass, the module is replicated on each device,
29 |     and each replica handles a portion of the input. During the backward
30 |     pass, gradients from each replica are summed into the original module.
31 | 
32 |     The batch size should be larger than the number of GPUs used.
33 | 
34 |     Parameters
35 |     ----------
36 |     module : object
37 |         Network to be parallelized.
38 |     sync : bool
39 |         Enable synchronization (default: False).
40 |     Inputs:
41 |         - **inputs**: list of input
42 |     Outputs:
43 |         - **outputs**: list of output
44 |     Example::
45 |         >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
46 |         >>> output = net(input_var)  # input_var can be on any device, including CPU
47 |     """
48 | 
49 |     def gather(self, outputs, output_device):
50 |         return outputs
51 | 
52 |     def replicate(self, module, device_ids):
53 |         modules = super(DataParallelModel, self).replicate(module, device_ids)
54 |         return modules
55 | 
56 | 
57 | # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
58 | class DataParallelCriterion(DataParallel):
59 |     """
60 |     Calculates the loss on multiple GPUs, which balances memory usage for
61 |     semantic segmentation.
62 | 
63 |     The targets are split across the specified devices by chunking in
64 |     the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
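    The per-device losses are then summed by ``Reduce`` and divided by the number
    of replicas in ``forward``, so the returned value is the mean loss across GPUs.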
65 |
66 |     Example::
67 |         >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
68 |         >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
69 |         >>> y = net(x)
70 |         >>> loss = criterion(y, target)
71 |     """
72 | 
73 |     def forward(self, inputs, *targets, **kwargs):
74 |         # the inputs should be the outputs of DataParallelModel
75 |         if not self.device_ids:
76 |             return self.module(inputs, *targets, **kwargs)
77 |         targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
78 |         if len(self.device_ids) == 1:
79 |             return self.module(inputs, *targets[0], **kwargs[0])
80 |         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
81 |         outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs)
82 |         return Reduce.apply(*outputs) / len(outputs)
83 | 
84 | 
85 | def get_a_var(obj):
86 |     if isinstance(obj, torch.Tensor):
87 |         return obj
88 | 
89 |     if isinstance(obj, (list, tuple)):
90 |         for result in map(get_a_var, obj):
91 |             if isinstance(result, torch.Tensor):
92 |                 return result
93 | 
94 |     if isinstance(obj, dict):
95 |         for result in map(get_a_var, obj.items()):
96 |             if isinstance(result, torch.Tensor):
97 |                 return result
98 |     return None
99 | 
100 | 
101 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
102 |     r"""Applies each `module` in :attr:`modules` in parallel on arguments
103 |     contained in :attr:`inputs` (positional), :attr:`targets` (positional) and :attr:`kwargs_tup` (keyword)
104 |     on each of :attr:`devices`.
105 | 
106 |     Args:
107 |         modules (Module): modules to be parallelized
108 |         inputs (tensor): inputs to the modules
109 |         targets (tensor): targets to the modules
110 |         devices (list of int or torch.device): CUDA devices
111 |     :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup` (if given), and
112 |     :attr:`devices` (if given) should all have the same length. Moreover, each
113 |     element of :attr:`inputs` can either be a single object as the only argument
114 |     to a module, or a collection of positional arguments.
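
    Example (a minimal sketch; ``criterion0``/``criterion1`` stand for per-device
    criterion replicas and ``(y0,)``/``(t0,)``, ``(y1,)``/``(t1,)`` for hypothetical
    per-device output and target tuples)::

        >>> losses = criterion_parallel_apply([criterion0, criterion1],
        ...                                   [(y0,), (y1,)], [(t0,), (t1,)])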
115 |     """
116 |     assert len(modules) == len(inputs)
117 |     assert len(targets) == len(inputs)
118 |     if kwargs_tup is not None:
119 |         assert len(modules) == len(kwargs_tup)
120 |     else:
121 |         kwargs_tup = ({},) * len(modules)
122 |     if devices is not None:
123 |         assert len(modules) == len(devices)
124 |     else:
125 |         devices = [None] * len(modules)
126 |     lock = threading.Lock()
127 |     results = {}
128 |     grad_enabled = torch.is_grad_enabled()
129 | 
130 |     def _worker(i, module, input, target, kwargs, device=None):
131 |         torch.set_grad_enabled(grad_enabled)
132 |         if device is None:
133 |             device = get_a_var(input).get_device()
134 |         try:
135 |             with torch.cuda.device(device):
136 |                 output = module(*(list(input) + list(target)), **kwargs)  # target may be a tuple from scatter
137 |             with lock:
138 |                 results[i] = output
139 |         except Exception as e:
140 |             with lock:
141 |                 results[i] = e
142 | 
143 |     if len(modules) > 1:
144 |         threads = [threading.Thread(target=_worker,
145 |                                     args=(i, module, input, target, kwargs, device))
146 |                    for i, (module, input, target, kwargs, device) in
147 |                    enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
148 | 
149 |         for thread in threads:
150 |             thread.start()
151 |         for thread in threads:
152 |             thread.join()
153 |     else:
154 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
155 | 
156 |     outputs = []
157 |     for i in range(len(inputs)):
158 |         output = results[i]
159 |         if isinstance(output, Exception):
160 |             raise output
161 |         outputs.append(output)
162 |     return outputs
163 |
--------------------------------------------------------------------------------
/segmentron/models/dabnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .segbase import SegBaseModel
6 | from .model_zoo import MODEL_REGISTRY
7 | from ..modules import _FCNHead
8 | from ..config import cfg
9 | 
10 | __all__ = ["DABNet"]
11 | 
12 | 
13 | @MODEL_REGISTRY.register()
14 | class DABNet(SegBaseModel):
15 |     def __init__(self, block_1=3, block_2=6):
16 |         super().__init__(need_backbone=False)
17 |         self.init_conv = nn.Sequential(
18 |             Conv(3, 32, 3, 2, padding=1, bn_acti=True),
19 |             Conv(32, 32, 3, 1, padding=1, bn_acti=True),
20 |             Conv(32, 32, 3, 1, padding=1, bn_acti=True),
21 |         )
22 | 
23 |         self.down_1 = InputInjection(1)  # down-sample the image once (2x)
24 |         self.down_2 = InputInjection(2)  # down-sample the image twice (4x)
25 |         self.down_3 = InputInjection(3)  # down-sample the image three times (8x)
26 | 
27 |         self.bn_prelu_1 = BNPReLU(32 + 3)
28 | 
29 |         # DAB Block 1
30 |         self.downsample_1 = DownSamplingBlock(32 + 3, 64)
31 |         self.DAB_Block_1 = nn.Sequential()
32 |         for i in range(0, block_1):
33 |             self.DAB_Block_1.add_module("DAB_Module_1_" + str(i), DABModule(64, d=2))
34 |         self.bn_prelu_2 = BNPReLU(128 + 3)
35 | 
36 |         # DAB Block 2
37 |         dilation_block_2 = [4, 4, 8, 8, 16, 16]
38 |         self.downsample_2 = DownSamplingBlock(128 + 3, 128)
39 |         self.DAB_Block_2 = nn.Sequential()
40 |         for i in range(0, block_2):
41 |             self.DAB_Block_2.add_module("DAB_Module_2_" + str(i),
42 |                                         DABModule(128, d=dilation_block_2[i]))
43 |         self.bn_prelu_3 = BNPReLU(256 + 3)
44 | 
45 |         self.classifier = nn.Sequential(Conv(259, self.nclass, 1, 1, padding=0))
46 | 
47 |     def forward(self, input):
48 | 
49 |         output0 = self.init_conv(input)
50 | 
51 |         down_1 = self.down_1(input)
52 |         down_2 = self.down_2(input)
53 |         down_3 = self.down_3(input)
54 | 
55 |         output0_cat = self.bn_prelu_1(torch.cat([output0, down_1], 1))
56 | 
57 |         # DAB Block 1
58 |         output1_0 = self.downsample_1(output0_cat)
59 |         output1 = self.DAB_Block_1(output1_0)
60 |         output1_cat = self.bn_prelu_2(torch.cat([output1, output1_0, down_2], 1))
61 | 
62 |         # DAB Block 2
63 |         output2_0 = self.downsample_2(output1_cat)
64 |         output2 = self.DAB_Block_2(output2_0)
65 |         output2_cat = self.bn_prelu_3(torch.cat([output2, output2_0, down_3], 1))
66 | 
67 |         out = self.classifier(output2_cat)
68 |         out = F.interpolate(out, input.size()[2:], mode='bilinear', align_corners=False)
69 | 
70 |         outputs = list()
71 |         outputs.append(out)
72 |         return outputs
73 | 
74 | 
75 | class Conv(nn.Module):
76 |     def __init__(self, nIn, nOut, kSize, stride, padding, dilation=(1, 1), groups=1, bn_acti=False, bias=False):
77 |         super().__init__()
78 | 
79 |         self.bn_acti = bn_acti
80 | 
81 |         self.conv = nn.Conv2d(nIn, nOut, kernel_size=kSize,
82 |                               stride=stride, padding=padding,
83 |                               dilation=dilation, groups=groups, bias=bias)
84 | 
85 |         if self.bn_acti:
86 |             self.bn_prelu = BNPReLU(nOut)
87 | 
88 |     def forward(self, input):
89 |         output = self.conv(input)
90 | 
91 |         if self.bn_acti:
92 |             output = self.bn_prelu(output)
93 | 
94 |         return output
95 | 
96 | 
97 | class BNPReLU(nn.Module):
98 |     def __init__(self, nIn):
99 |         super().__init__()
100 |         self.bn = nn.BatchNorm2d(nIn, eps=1e-3)
101 |         self.acti = nn.PReLU(nIn)
102 | 
103 |     def forward(self, input):
104 |         output = self.bn(input)
105 |         output = self.acti(output)
106 | 
107 |         return output
108 | 
109 | 
110 | class DABModule(nn.Module):
111 |     def __init__(self, nIn, d=1, kSize=3, dkSize=3):
112 |         super().__init__()
113 | 
114 |         self.bn_relu_1 = BNPReLU(nIn)
115 |         self.conv3x3 = Conv(nIn, nIn // 2, kSize, 1, padding=1, bn_acti=True)
116 | 
117 |         self.dconv3x1 = Conv(nIn // 2, nIn // 2, (dkSize, 1), 1,
118 |                              padding=(1, 0), groups=nIn // 2, bn_acti=True)
119 |         self.dconv1x3 = Conv(nIn // 2, nIn // 2, (1, dkSize), 1,
120 |                              padding=(0, 1), groups=nIn // 2, bn_acti=True)
121 |         self.ddconv3x1 = Conv(nIn // 2, nIn // 2, (dkSize, 1), 1,
122 |                               padding=(1 * d, 0), dilation=(d, 1), groups=nIn // 2, bn_acti=True)
123 |         self.ddconv1x3 = Conv(nIn // 2, nIn // 2, (1, dkSize), 1,
124 |                               padding=(0, 1 * d), dilation=(1, d), groups=nIn // 2, bn_acti=True)
125 | 
126 |         self.bn_relu_2 = BNPReLU(nIn // 2)
127 |         self.conv1x1 = Conv(nIn // 2, nIn, 1, 1, padding=0, bn_acti=False)
128 | 
129 |     def forward(self, input):
130 |         output = self.bn_relu_1(input)
131 |         output = self.conv3x3(output)
132 | 
133 |         br1 = self.dconv3x1(output)
134 |         br1 = self.dconv1x3(br1)
135 |         br2 = self.ddconv3x1(output)
136 |         br2 = self.ddconv1x3(br2)
137 | 
138 |         output = br1 + br2
139 |         output = self.bn_relu_2(output)
140 |         output = self.conv1x1(output)
141 | 
142 |         return output + input
143 | 
144 | 
145 | class DownSamplingBlock(nn.Module):
146 |     def __init__(self, nIn, nOut):
147 |         super().__init__()
148 |         self.nIn = nIn
149 |         self.nOut = nOut
150 | 
151 |         if self.nIn < self.nOut:
152 |             nConv = nOut - nIn
153 |         else:
154 |             nConv = nOut
155 | 
156 |         self.conv3x3 = Conv(nIn, nConv, kSize=3, stride=2, padding=1)
157 |         self.max_pool = nn.MaxPool2d(2, stride=2)
158 |         self.bn_prelu = BNPReLU(nOut)
159 | 
160 |     def forward(self, input):
161 |         output = self.conv3x3(input)
162 | 
163 |         if self.nIn < self.nOut:
164 |             max_pool = self.max_pool(input)
165 |             output = torch.cat([output, max_pool], 1)
166 | 
167 |         output = self.bn_prelu(output)
168 | 
169 |         return output
170 | 
171 | 
172 | class InputInjection(nn.Module):
173 |     def __init__(self, ratio):
174 |         super().__init__()
175 |         self.pool = nn.ModuleList()
176 |         for i in range(0, ratio):
177 |             self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
178 | 
179 |     def forward(self, input):
180 |         for pool in self.pool:
181 |             input = pool(input)
182 | 
183 |         return input
--------------------------------------------------------------------------------
/segmentron/models/lednet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .segbase import SegBaseModel
6 | from .model_zoo import MODEL_REGISTRY
7 | from ..modules import _ConvBNReLU
8 | 
9 | __all__ = ['LEDNet']
10 | 
11 | 
12 | @MODEL_REGISTRY.register()
13 | class LEDNet(SegBaseModel):
14 |     r"""LEDNet
15 |     Reference:
16 |         Yu Wang, et al. "LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation."
17 |         arXiv preprint arXiv:1905.02423 (2019).
18 |     """
19 | 
20 |     def __init__(self):
21 |         super(LEDNet, self).__init__(need_backbone=False)
22 |         self.encoder = nn.Sequential(
23 |             Downsampling(3, 32),
24 |             SSnbt(32, norm_layer=self.norm_layer),
25 |             SSnbt(32, norm_layer=self.norm_layer),
26 |             SSnbt(32, norm_layer=self.norm_layer),
27 |             Downsampling(32, 64),
28 |             SSnbt(64, norm_layer=self.norm_layer),
29 |             SSnbt(64, norm_layer=self.norm_layer),
30 |             Downsampling(64, 128),
31 |             SSnbt(128, norm_layer=self.norm_layer),
32 |             SSnbt(128, 2, norm_layer=self.norm_layer),
33 |             SSnbt(128, 5, norm_layer=self.norm_layer),
34 |             SSnbt(128, 9, norm_layer=self.norm_layer),
35 |             SSnbt(128, 2, norm_layer=self.norm_layer),
36 |             SSnbt(128, 5, norm_layer=self.norm_layer),
37 |             SSnbt(128, 9, norm_layer=self.norm_layer),
38 |             SSnbt(128, 17, norm_layer=self.norm_layer),
39 |         )
40 |         self.head = APNModule(128, self.nclass, norm_layer=self.norm_layer)
41 | 
42 |         self.__setattr__('decoder', ['head'])
43 | 
44 |     def forward(self, x):
45 |         size = x.size()[2:]
46 |         x = self.encoder(x)
47 |         x = self.head(x)
48 |         outputs = list()
49 |         x = F.interpolate(x, size, mode='bilinear', align_corners=True)
50 |         outputs.append(x)
51 | 
52 |         return tuple(outputs)
53 | 
54 | 
55 | class Downsampling(nn.Module):
56 |     def __init__(self, in_channels, out_channels):
57 |         super(Downsampling, self).__init__()
58 |         self.conv1 = nn.Conv2d(in_channels, out_channels // 2, 3, 2, 2, bias=False)
59 |         self.conv2 = nn.Conv2d(in_channels, out_channels // 2, 3, 2, 2, bias=False)
60 |         self.pool = nn.MaxPool2d(kernel_size=2, stride=1)
61 | 
62 |     def forward(self, x):
63 |         x1 = self.conv1(x)
64 |         x1 = self.pool(x1)
65 | 
66 |         x2 = self.conv2(x)
67 |         x2 = self.pool(x2)
68 | 
69 |         return torch.cat([x1, x2], dim=1)
70 | 
71 | 
72 | class SSnbt(nn.Module):
73 |     def __init__(self, in_channels, dilation=1, norm_layer=nn.BatchNorm2d):
74 |         super(SSnbt, self).__init__()
75 |         inter_channels = in_channels // 2
76 |         self.branch1 = nn.Sequential(
77 |             nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(1, 0), bias=False),
78 |             nn.ReLU(True),
79 |             nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, 1), bias=False),
80 |             norm_layer(inter_channels),
81 |             nn.ReLU(True),
82 |             nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(dilation, 0), dilation=(dilation, 1),
83 |                       bias=False),
84 |             nn.ReLU(True),
85 |             nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, dilation), dilation=(1, dilation),
86 |                       bias=False),
87 |             norm_layer(inter_channels),
88 |             nn.ReLU(True))
89 | 
90 |         self.branch2 = nn.Sequential(
91 |             nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, 1), bias=False),
92 |             nn.ReLU(True),
93 |             nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(1, 0), bias=False),
94 |             norm_layer(inter_channels),
95 |             nn.ReLU(True),
96 |             nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, dilation), dilation=(1, dilation),
97 |                       bias=False),
98 |             nn.ReLU(True),
99 |             nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(dilation, 0), dilation=(dilation, 1),
100 |                       bias=False),
101 |             norm_layer(inter_channels),
102 |             nn.ReLU(True))
103 | 
104 |         self.relu = nn.ReLU(True)
105 | 
106 |     @staticmethod
107 |     def channel_shuffle(x, groups):
108 |         n, c, h, w = x.size()
109 | 
110 |         channels_per_group = c // groups
111 |         x = x.view(n, groups, channels_per_group, h, w)
112 |         x = torch.transpose(x, 1, 2).contiguous()
113 |         x = x.view(n, -1, h, w)
114 | 
115 |         return x
116 | 
117 |     def forward(self, x):
118 |         # channel split
119 |         x1, x2 = x.split(x.size(1) // 2, 1)
120 | 
121 |         x1 = self.branch1(x1)
122 |         x2 = self.branch2(x2)
123 | 
124 |         out = torch.cat([x1, x2], dim=1)
125 |         out = self.relu(out + x)
126 |         out = self.channel_shuffle(out, groups=2)
127 | 
128 |         return out
129 | 
130 | 
131 | class APNModule(nn.Module):
132 |     def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d):
133 |         super(APNModule, self).__init__()
134 |         self.conv1 = _ConvBNReLU(in_channels, in_channels, 3, 2, 1, norm_layer=norm_layer)
135 |         self.conv2 = _ConvBNReLU(in_channels, in_channels, 5, 2, 2, norm_layer=norm_layer)
136 |         self.conv3 = _ConvBNReLU(in_channels, in_channels, 7, 2, 3, norm_layer=norm_layer)
137 |         self.level1 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
138 |         self.level2 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
139 |         self.level3 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
140 |         self.level4 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
141 |         self.level5 = nn.Sequential(
142 |             nn.AdaptiveAvgPool2d(1),
143 |             _ConvBNReLU(in_channels, nclass, 1))
144 | 
145 |     def forward(self, x):
146 |         h, w = x.size()[2:]  # spatial size is (H, W)
147 |         branch3 = self.conv1(x)
148 |         branch2 = self.conv2(branch3)
149 |         branch1 = self.conv3(branch2)
150 | 
151 |         out = self.level1(branch1)
152 |         out = F.interpolate(out, ((h + 3) // 4, (w + 3) // 4), mode='bilinear', align_corners=True)
153 |         out = self.level2(branch2) + out
154 |         out = F.interpolate(out, ((h + 1) // 2, (w + 1) // 2), mode='bilinear', align_corners=True)
155 |         out = self.level3(branch3) + out
156 |         out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
157 |         out = self.level4(x) * out
158 |         out = self.level5(x) + out
159 |         return out
160 |
--------------------------------------------------------------------------------
/segmentron/data/dataloader/cityscapes.py:
--------------------------------------------------------------------------------
1 | """Prepare Cityscapes dataset"""
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 | 
7 | from PIL import Image
8 | from .seg_data_base import SegmentationDataset
9 | 
10 | 
11 | class CitySegmentation(SegmentationDataset):
12 |     """Cityscapes Semantic Segmentation Dataset.
13 | 
14 |     Parameters
15 |     ----------
16 |     root : string
17 |         Path to Cityscapes folder. Default is './datasets/cityscapes'
18 |     split: string
19 |         'train', 'val' or 'test'
20 |     transform : callable, optional
21 |         A function that transforms the image
22 |     Examples
23 |     --------
24 |     >>> from torchvision import transforms
25 |     >>> import torch.utils.data as data
26 |     >>> # Transforms for Normalization
27 |     >>> input_transform = transforms.Compose([
28 |     >>>     transforms.ToTensor(),
29 |     >>>     transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
30 |     >>> ])
31 |     >>> # Create Dataset
32 |     >>> trainset = CitySegmentation(split='train', transform=input_transform)
33 |     >>> # Create Training Loader
34 |     >>> train_data = data.DataLoader(
35 |     >>>     trainset, 4, shuffle=True,
36 |     >>>     num_workers=4)
37 |     """
38 |     BASE_DIR = 'cityscapes'
39 |     NUM_CLASS = 19
40 | 
41 |     def __init__(self, root='datasets/cityscapes', split='train', mode=None, transform=None, **kwargs):
42 |         super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs)
43 |         # self.root = os.path.join(root, self.BASE_DIR)
44 |         assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/cityscapes"
45 |         self.images, self.mask_paths = _get_city_pairs(self.root, self.split)
46 |         assert (len(self.images) == len(self.mask_paths))
47 |         if len(self.images) == 0:
48 |             raise RuntimeError("Found 0 images in subfolders of: " + root + "\n")
49 |         self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
50 |                               23, 24, 25, 26, 27, 28, 31, 32, 33]
51 |         self._key = np.array([-1, -1, -1, -1, -1, -1,
52 |                               -1, -1, 0, 1, -1, -1,
53 |                               2, 3, 4, -1, -1, -1,
54 |                               5, -1, 6, 7, 8, 9,
55 |                               10, 11, 12, 13, 14, 15,
56 |                               -1, -1, 16, 17, 18])
57 |         self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
58 | 
59 |     def _class_to_index(self, mask):
60 |         # assert that all mask values are valid
61 |         values = np.unique(mask)
62 |         for value in values:
63 |             assert (value in self._mapping)
64 |         index = np.digitize(mask.ravel(), self._mapping, right=True)
65 |         return self._key[index].reshape(mask.shape)
66 | 
67 |     def __getitem__(self, index):
68 |         img = Image.open(self.images[index]).convert('RGB')
69 |         if self.mode == 'test':
70 |             if self.transform is not None:
71 |                 img = self.transform(img)
72 |             return img, os.path.basename(self.images[index])
73 |         mask = Image.open(self.mask_paths[index])
74 |         # synchronized transform
75 |         if self.mode == 'train':
76 |             img, mask = self._sync_transform(img, mask)
77 |         elif self.mode == 'val':
78 |             img, mask = self._val_sync_transform(img, mask)
79 |         else:
80 |             assert self.mode == 'testval'
81 |             img, mask = self._img_transform(img), self._mask_transform(mask)
82 |         # general resize, normalize and toTensor
83 |         if self.transform is not None:
84 |             img = self.transform(img)
85 |         return img, mask, os.path.basename(self.images[index])
86 | 
87 |     def _mask_transform(self, mask):
88 |         target = self._class_to_index(np.array(mask).astype('int32'))
89 |         return torch.LongTensor(np.array(target).astype('int32'))
90 | 
91 |     def __len__(self):
92 |         return len(self.images)
93 | 
94 |     @property
95 |     def pred_offset(self):
96 |         return 0
97 | 
98 |     @property
99 |     def classes(self):
100 |         """Category names."""
101 |         return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
102 |                 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
103 |                 'truck', 'bus', 'train', 'motorcycle', 'bicycle')
104 | 
105 | 
106 | def _get_city_pairs(folder, split='train'):
107 |     def get_path_pairs(img_folder, mask_folder):
108 |         img_paths = []
109 |         mask_paths = []
110 |         for root, _, files in os.walk(img_folder):
111 |             for filename in files:
112 |                 if filename.startswith('._'):
113 |                     continue
114 |                 if filename.endswith('.png'):
115 |                     imgpath = os.path.join(root, filename)
116 |                     foldername = os.path.basename(os.path.dirname(imgpath))
117 |                     maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
118 |                     maskpath = os.path.join(mask_folder, foldername, maskname)
119 |                     if os.path.isfile(imgpath) and os.path.isfile(maskpath):
120 |                         img_paths.append(imgpath)
121 |                         mask_paths.append(maskpath)
122 |                     else:
123 |                         logging.info('cannot find the mask or image: %s, %s', imgpath, maskpath)
124 |         logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder))
125 |         return img_paths, mask_paths
126 | 
127 |     if split in ('train', 'val'):
128 |         img_folder = os.path.join(folder, 'leftImg8bit/' + split)
129 |         mask_folder = os.path.join(folder, 'gtFine/' + split)
130 |         img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
131 |         return img_paths, mask_paths
132 |     else:
133 |         assert split == 'trainval'
134 |         logging.info('trainval set')
135 |         train_img_folder = os.path.join(folder, 'leftImg8bit/train')
136 |         train_mask_folder = os.path.join(folder, 'gtFine/train')
137 |         val_img_folder = os.path.join(folder, 'leftImg8bit/val')
138 |         val_mask_folder = os.path.join(folder, 'gtFine/val')
139 |         train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
140 |         val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
141 |         img_paths = train_img_paths + val_img_paths
142 |         mask_paths = train_mask_paths + val_mask_paths
143 |         return img_paths, mask_paths
144 | 
145 | 
146 | if __name__ == '__main__':
147 |     dataset = CitySegmentation()
148 |
--------------------------------------------------------------------------------
/segmentron/models/backbones/mobilenet.py:
--------------------------------------------------------------------------------
1 | """MobileNet and MobileNetV2."""
2 | import torch.nn as nn
3 | 
4 | from .build import BACKBONE_REGISTRY
5 | from ...modules import _ConvBNReLU, _DepthwiseConv, InvertedResidual
6 | from ...config import cfg
7 | 
8 | __all__ = ['MobileNet', 'MobileNetV2']
9 | 
10 | 
11 | class MobileNet(nn.Module):
12 |     def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d):
13 |         super(MobileNet, self).__init__()
14 |         multiplier = cfg.MODEL.BACKBONE_SCALE
15 |         conv_dw_setting = [
16 |             [64, 1, 1],
17 |             [128, 2, 2],
18 |             [256, 2, 2],
19 |             [512, 6, 2],
20 |             [1024, 2, 2]]
21 |         input_channels = int(32 * multiplier) if multiplier > 1.0 else 32
22 |         features = [_ConvBNReLU(3, input_channels, 3, 2, 1, norm_layer=norm_layer)]
23 | 
24 |         for c, n, s in conv_dw_setting:
25 |             out_channels = int(c * multiplier)
26 |             for i in range(n):
27 |                 stride = s if i == 0 else 1
28 |                 features.append(_DepthwiseConv(input_channels, out_channels, stride, norm_layer))
29 |                 input_channels = out_channels
30 |         self.last_inp_channels = int(1024 * multiplier)
31 |         features.append(nn.AdaptiveAvgPool2d(1))
32 |         self.features = nn.Sequential(*features)
33 | 
34 |         self.classifier = nn.Linear(int(1024 * multiplier), num_classes)
35 | 
36 |         # weight initialization
37 |         for m in self.modules():
38 |             if isinstance(m, nn.Conv2d):
39 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
40 |                 if m.bias is not None:
41 |                     nn.init.zeros_(m.bias)
42 |             elif isinstance(m, nn.BatchNorm2d):
43 |                 nn.init.ones_(m.weight)
44 |                 nn.init.zeros_(m.bias)
45 |             elif isinstance(m, nn.Linear):
46 |                 nn.init.normal_(m.weight, 0, 0.01)
47 |                 nn.init.zeros_(m.bias)
48 |
49 |     def forward(self, x):
50 |         x = self.features(x)
51 |         x = self.classifier(x.view(x.size(0), x.size(1)))
52 |         return x
53 | 
54 | 
55 | class MobileNetV2(nn.Module):
56 |     def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d):
57 |         super(MobileNetV2, self).__init__()
58 |         output_stride = cfg.MODEL.OUTPUT_STRIDE
59 |         self.multiplier = cfg.MODEL.BACKBONE_SCALE
60 |         if output_stride == 32:
61 |             dilations = [1, 1]
62 |         elif output_stride == 16:
63 |             dilations = [1, 2]
64 |         elif output_stride == 8:
65 |             dilations = [2, 4]
66 |         else:
67 |             raise NotImplementedError
68 |         inverted_residual_setting = [
69 |             # t, c, n, s
70 |             [1, 16, 1, 1],
71 |             [6, 24, 2, 2],
72 |             [6, 32, 3, 2],
73 |             [6, 64, 4, 2],
74 |             [6, 96, 3, 1],
75 |             [6, 160, 3, 2],
76 |             [6, 320, 1, 1]]
77 |         # building first layer
78 |         input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32
79 |         # last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
80 |         self.conv1 = _ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)
81 | 
82 |         # building inverted residual blocks
83 |         self.planes = input_channels
84 |         self.block1 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[0:1],
85 |                                        norm_layer=norm_layer)
86 |         self.block2 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[1:2],
87 |                                        norm_layer=norm_layer)
88 |         self.block3 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[2:3],
89 |                                        norm_layer=norm_layer)
90 |         self.block4 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[3:5],
91 |                                        dilations[0], norm_layer=norm_layer)
92 |         self.block5 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[5:],
93 |                                        dilations[1], norm_layer=norm_layer)
94 |         self.last_inp_channels = self.planes
95 | 
96 |         # building last several layers
97 |         # features = list()
98 |         # features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer))
99 |         # features.append(nn.AdaptiveAvgPool2d(1))
100 |         # self.features = nn.Sequential(*features)
101 |         #
102 |         # self.classifier = nn.Sequential(
103 |         #     nn.Dropout2d(0.2),
104 |         #     nn.Linear(last_channels, num_classes))
105 | 
106 |         # weight initialization
107 |         for m in self.modules():
108 |             if isinstance(m, nn.Conv2d):
109 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
110 |                 if m.bias is not None:
111 |                     nn.init.zeros_(m.bias)
112 |             elif isinstance(m, nn.BatchNorm2d):
113 |                 nn.init.ones_(m.weight)
114 |                 nn.init.zeros_(m.bias)
115 |             elif isinstance(m, nn.Linear):
116 |                 nn.init.normal_(m.weight, 0, 0.01)
117 |                 if m.bias is not None:
118 |                     nn.init.zeros_(m.bias)
119 | 
120 |     def _make_layer(self, block, planes, inverted_residual_setting, dilation=1, norm_layer=nn.BatchNorm2d):
121 |         features = list()
122 |         for t, c, n, s in inverted_residual_setting:
123 |             out_channels = int(c * self.multiplier)
124 |             stride = s if dilation == 1 else 1
125 |             features.append(block(planes, out_channels, stride, t, dilation, norm_layer))
126 |             planes = out_channels
127 |             for i in range(n - 1):
128 |                 features.append(block(planes, out_channels, 1, t, norm_layer=norm_layer))
129 |                 planes = out_channels
130 |         self.planes = planes
131 |         return nn.Sequential(*features)
132 | 
133 |     def forward(self, x):
134 |         x = self.conv1(x)
135 |         x = self.block1(x)
136 |         c1 = self.block2(x)
137 |         c2 = self.block3(c1)
138 |         c3 = self.block4(c2)
139 |         c4 = self.block5(c3)
140 | 
141 |         # x = self.features(x)
142 |         # x = self.classifier(x.view(x.size(0), x.size(1)))
143 |         return c1, c2, c3, c4
144 | 
145 | 
146 | @BACKBONE_REGISTRY.register()
147 | def mobilenet_v1(norm_layer=nn.BatchNorm2d):
148 |     return MobileNet(norm_layer=norm_layer)
149 | 
150 | 
151 | @BACKBONE_REGISTRY.register()
152 | def mobilenet_v2(norm_layer=nn.BatchNorm2d):
153 |     return MobileNetV2(norm_layer=norm_layer)
154 | 
155 |
--------------------------------------------------------------------------------
/segmentron/models/backbones/eespnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from ...modules import _ConvBNPReLU, _ConvBN, _BNPReLU, EESP
7 | from .build import BACKBONE_REGISTRY
8 | from ...config import cfg
9 | 
10 | __all__ = ['EESPNet', 'eespnet']
11 | 
12 | 
13 | class DownSampler(nn.Module):
14 | 
15 |     def __init__(self, in_channels, out_channels, k=4, r_lim=9, reinf=True, inp_reinf=3, norm_layer=None):
16 |         super(DownSampler, self).__init__()
17 |         channels_diff = out_channels - in_channels
18 |         self.eesp = EESP(in_channels, channels_diff, stride=2, k=k,
19 |                          r_lim=r_lim, down_method='avg', norm_layer=norm_layer)
20 |         self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2)
21 |         if reinf:
22 |             self.inp_reinf = nn.Sequential(
23 |                 _ConvBNPReLU(inp_reinf, inp_reinf, 3, 1, 1),
24 |                 _ConvBN(inp_reinf, out_channels, 1, 1))
25 |         self.act = nn.PReLU(out_channels)
26 | 
27 |     def forward(self, x, x2=None):
28 |         avg_out = self.avg(x)
29 |         eesp_out = self.eesp(x)
30 |         output = torch.cat([avg_out, eesp_out], 1)
31 |         if x2 is not None:
32 |             w1 = avg_out.size(2)
33 |             while True:
34 |                 x2 = F.avg_pool2d(x2, kernel_size=3, padding=1, stride=2)
35 |                 w2 = x2.size(2)
36 |                 if w2 == w1:
37 |                     break
38 |             output = output + self.inp_reinf(x2)
39 | 
40 |         return self.act(output)
41 | 
42 | 
43 | class EESPNet(nn.Module):
44 |     def __init__(self, num_classes=1000, scale=1, reinf=True, norm_layer=nn.BatchNorm2d):
45 |         super(EESPNet, self).__init__()
46 |         inp_reinf = 3 if reinf else None
47 |         reps = [0, 3, 7, 3]
48 |         r_lim = [13, 11, 9, 7, 5]
49 |         K = [4] * len(r_lim)
50 | 
51 |         # set out_channels
52 |         base, levels, base_s = 32, 5, 0
53 |         out_channels = [base] * levels
54 |         for i in range(levels):
55 |             if i == 0:
56 |                 base_s = int(base * scale)
57 |                 base_s = math.ceil(base_s / K[0]) * K[0]
58 |                 out_channels[i] = base if base_s > base else base_s
59 |             else:
60 |                 out_channels[i] = base_s * pow(2, i)
61 |         if scale <= 1.5:
62 |             out_channels.append(1024)
63 |         elif scale <= 2:
64 |             out_channels.append(1280)
65 |         else:
66 |             raise ValueError("Unknown scale value.")
67 | 
68 |         self.level1 = _ConvBNPReLU(3, out_channels[0], 3, 2, 1, norm_layer=norm_layer)
69 | 
70 |         self.level2_0 = DownSampler(out_channels[0], out_channels[1], k=K[0], r_lim=r_lim[0],
71 |                                     reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer)
72 | 
73 |         self.level3_0 = DownSampler(out_channels[1], out_channels[2], k=K[1], r_lim=r_lim[1],
74 |                                     reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer)
75 |         self.level3 = nn.ModuleList()
76 |         for i in range(reps[1]):
77 |             self.level3.append(EESP(out_channels[2], out_channels[2], k=K[2], r_lim=r_lim[2],
78 |                                     norm_layer=norm_layer))
79 | 
80 |         self.level4_0 = DownSampler(out_channels[2], out_channels[3], k=K[2], r_lim=r_lim[2],
81 |                                     reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer)
82 |         self.level4 = nn.ModuleList()
83 |         for i in range(reps[2]):
84 |             self.level4.append(EESP(out_channels[3], out_channels[3], k=K[3], r_lim=r_lim[3],
85 |                                     norm_layer=norm_layer))
86 |
87 |         self.level5_0 = DownSampler(out_channels[3], out_channels[4], k=K[3], r_lim=r_lim[3],
88 |                                     reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer)
89 |         self.level5 = nn.ModuleList()
90 |         for i in range(reps[3]):
91 |             self.level5.append(EESP(out_channels[4], out_channels[4], k=K[4], r_lim=r_lim[4],
92 |                                     norm_layer=norm_layer))
93 | 
94 |         self.level5.append(_ConvBNPReLU(out_channels[4], out_channels[4], 3, 1, 1,
95 |                                         groups=out_channels[4], norm_layer=norm_layer))
96 |         self.level5.append(_ConvBNPReLU(out_channels[4], out_channels[5], 1, 1, 0,
97 |                                         groups=K[4], norm_layer=norm_layer))
98 | 
99 |         self.fc = nn.Linear(out_channels[5], num_classes)
100 | 
101 |         for m in self.modules():
102 |             if isinstance(m, nn.Conv2d):
103 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
104 |                 if m.bias is not None:
105 |                     nn.init.constant_(m.bias, 0)
106 |             elif isinstance(m, nn.BatchNorm2d):
107 |                 nn.init.constant_(m.weight, 1)
108 |                 nn.init.constant_(m.bias, 0)
109 |             elif isinstance(m, nn.Linear):
110 |                 nn.init.normal_(m.weight, std=0.001)
111 |                 if m.bias is not None:
112 |                     nn.init.constant_(m.bias, 0)
113 | 
114 |     def forward(self, x, seg=True):
115 |         out_l1 = self.level1(x)
116 | 
117 |         out_l2 = self.level2_0(out_l1, x)
118 | 
119 |         out_l3_0 = self.level3_0(out_l2, x)
120 |         for i, layer in enumerate(self.level3):
121 |             if i == 0:
122 |                 out_l3 = layer(out_l3_0)
123 |             else:
124 |                 out_l3 = layer(out_l3)
125 | 
126 |         out_l4_0 = self.level4_0(out_l3, x)
127 |         for i, layer in enumerate(self.level4):
128 |             if i == 0:
129 |                 out_l4 = layer(out_l4_0)
130 |             else:
131 |                 out_l4 = layer(out_l4)
132 | 
133 |         if not seg:
134 |             out_l5_0 = self.level5_0(out_l4)  # down-sampled
135 |             for i, layer in enumerate(self.level5):
136 |                 if i == 0:
137 |                     out_l5 = layer(out_l5_0)
138 |                 else:
139 |                     out_l5 = layer(out_l5)
140 | 
141 |             output_g = F.adaptive_avg_pool2d(out_l5, output_size=1)
142 |             output_g = F.dropout(output_g, p=0.2, training=self.training)
143 |             output_1x1 = output_g.view(output_g.size(0), -1)
144 | 
145 |             return self.fc(output_1x1)
146 |         return out_l1, out_l2, out_l3, out_l4
147 | 
148 | 
149 | @BACKBONE_REGISTRY.register()
150 | def eespnet(norm_layer=nn.BatchNorm2d):
151 |     return EESPNet(norm_layer=norm_layer)
152 | 
153 | # def eespnet(pretrained=False, **kwargs):
154 | #     model = EESPNet(**kwargs)
155 | #     if pretrained:
156 | #         raise ValueError("Don't support pretrained")
157 | #     return model
158 | 
159 | 
160 | if __name__ == '__main__':
161 |     img = torch.randn(1, 3, 224, 224)
162 |     model = eespnet()
163 |     out = model(img)
164 |
--------------------------------------------------------------------------------