├── pretrained └── .gitkeep ├── segmentron ├── data │ ├── __init__.py │ └── dataloader │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── transparent11.py │ │ ├── acdc.py │ │ ├── seg_data_base.py │ │ ├── stanford2d3d.py │ │ ├── cityscapes.py │ │ └── densepass.py ├── solver │ ├── __init__.py │ ├── optimizer.py │ └── lr_scheduler.py ├── config │ ├── __init__.py │ ├── config.py │ └── settings.py ├── __init__.py ├── modules │ ├── __init__.py │ ├── csrc │ │ ├── vision.cpp │ │ └── criss_cross_attention │ │ │ └── ca.h │ ├── norm.py │ ├── sync_bn │ │ └── syncbn.py │ ├── basic.py │ ├── drop.py │ └── batch_norm.py ├── utils │ ├── __init__.py │ ├── env.py │ ├── logger.py │ ├── default_setup.py │ ├── options.py │ ├── filesystem.py │ ├── registry.py │ ├── download.py │ ├── parallel.py │ └── score.py └── models │ ├── __init__.py │ ├── backbones │ ├── __init__.py │ ├── build.py │ ├── mobilenet.py │ └── eespnet.py │ ├── model_zoo.py │ ├── segbase.py │ ├── pvt_fpt.py │ ├── pvt_fpt_joint.py │ ├── pvtv2_mit_fpt.py │ └── pvt2_mit_fpt_joint.py ├── tools ├── data │ ├── random_split.py │ ├── check_label.py │ └── resize_datasets.py ├── demo_vis.png ├── dist_test.sh ├── dist_train.sh ├── demo.py └── eval.py ├── workdirs ├── cocostuff │ └── .gitkeep └── trans10kv2 │ └── .gitkeep ├── colors.npy ├── sounds ├── both_1.wav ├── left_1_2x.wav └── right_1_2x.wav ├── trans4trans_fig_1.jpg ├── name2label.json ├── demo.sh ├── configs ├── acdc │ ├── pvt_tiny_FPT.yaml │ ├── pvt_small_FPT.yaml │ ├── pvt_medium_FPT.yaml │ ├── pvt_v2_mit_medium_FPT.yaml │ ├── pvt_v2_mit_small_FPT.yaml │ └── pvt_v2_mit_tiny_FPT.yaml ├── trans10kv2 │ ├── pvt_tiny_FPT.yaml │ ├── pvt_small_FPT.yaml │ └── pvt_medium_FPT.yaml ├── cityscapes │ ├── pvt_tiny_FPT.yaml │ ├── pvt_small_FPT.yaml │ ├── pvt_medium_FPT.yaml │ ├── pvt_v2_mit_tiny_FPT.yaml │ ├── pvt_v2_mit_medium_FPT.yaml │ └── pvt_v2_mit_small_FPT.yaml ├── joint_stan_trans │ ├── pvt_tiny_FPT.yaml │ ├── pvt_medium_FPT64.yaml │ └── pvt_small_FPT64.yaml ├── stanford2d3d │ ├── pvt_tiny_FPT.yaml │ ├── pvt_small_FPT.yaml │ └── pvt_medium_FPT.yaml └── joint_cs_acdc │ ├── pvt_tiny_FPT.yaml │ ├── pvt_small_FPT.yaml │ ├── pvt_v2_mit_small_FPT.yaml │ ├── pvt_v2_mit_tiny_FPT.yaml │ ├── pvt_medium_FPT.yaml │ └── pvt_v2_mit_medium_FPT.yaml ├── setup.py ├── coco_model.txt ├── .gitignore └── README.md /pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentron/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/data/random_split.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /workdirs/cocostuff/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /workdirs/trans10kv2/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentron/solver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/data/check_label.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | -------------------------------------------------------------------------------- /segmentron/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .settings import cfg -------------------------------------------------------------------------------- /segmentron/__init__.py: -------------------------------------------------------------------------------- 1 | from . import modules, models, utils, data -------------------------------------------------------------------------------- /colors.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/Trans4Trans/HEAD/colors.npy -------------------------------------------------------------------------------- /sounds/both_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/Trans4Trans/HEAD/sounds/both_1.wav -------------------------------------------------------------------------------- /tools/demo_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/Trans4Trans/HEAD/tools/demo_vis.png -------------------------------------------------------------------------------- /sounds/left_1_2x.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/Trans4Trans/HEAD/sounds/left_1_2x.wav -------------------------------------------------------------------------------- /sounds/right_1_2x.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/Trans4Trans/HEAD/sounds/right_1_2x.wav -------------------------------------------------------------------------------- /trans4trans_fig_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/Trans4Trans/HEAD/trans4trans_fig_1.jpg -------------------------------------------------------------------------------- /segmentron/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Seg NN Modules""" 2 | 3 | from .basic import * 4 | from .module import * 5 | from .batch_norm import get_norm 6 | -------------------------------------------------------------------------------- /name2label.json: -------------------------------------------------------------------------------- 1 | {"": 0, "beam": 1, "board": 2, "bookcase": 3, "ceiling": 4, "chair": 5, "clutter": 6, "column": 7, "door": 8, "floor": 9, "sofa": 10, "table": 11, "wall": 12, "window": 13} -------------------------------------------------------------------------------- /segmentron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from __future__ import absolute_import 3 | 4 | from .download import download, check_sha1 5 | from .filesystem import makedirs 6 | -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd trans4trans 4 | 5 | eval "$(conda shell.bash hook)" 6 | conda activate trans4trans 7 | 8 | python demo_r200.py --config-file configs/trans10kv2/pvt_tiny_FPT.yaml 9 | 
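# NOTE: demo_r200.py is not included in this snapshot. The demo entry point that
# does ship here is tools/demo.py, which accepts the same --config-file flag plus
# an --input-img path (see segmentron/utils/options.py), e.g. a hypothetical run:
#
#   python tools/demo.py --config-file configs/trans10kv2/pvt_tiny_FPT.yaml \
#       --input-img tools/demo_vis.png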
-------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | set -x 4 | 5 | CONFIG=$1 6 | 7 | python -m torch.distributed.launch --nproc_per_node=8 \ 8 | $(dirname "$0")/eval.py --config-file $CONFIG ${@:2} -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | set -x 4 | 5 | CONFIG=$1 6 | GPUS=${GPUS:-4} 7 | 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS \ 9 | $(dirname "$0")/train.py --config-file $CONFIG ${@:2} -------------------------------------------------------------------------------- /segmentron/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Model Zoo""" 2 | from .model_zoo import MODEL_REGISTRY 3 | from .pvt_fpt import PVT_FPT 4 | from .pvt_fpt_joint import PVT_FPT_JOINT 5 | from .pvtv2_mit_fpt import PVTV2_MIT_FPT 6 | from .pvt2_mit_fpt_joint import PVTV2_FPT_JOINT -------------------------------------------------------------------------------- /segmentron/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import BACKBONE_REGISTRY, get_segmentation_backbone 2 | from .xception import * 3 | from .mobilenet import * 4 | from .resnet import * 5 | from .hrnet import * 6 | from .eespnet import * 7 | from .pvt import * 8 | from .pvtv2_mix_transformer import * -------------------------------------------------------------------------------- /segmentron/modules/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "criss_cross_attention/ca.h" 3 | 4 | namespace segmentron { 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("ca_forward", &ca_forward, "ca_forward"); 8 | m.def("ca_backward", &ca_backward, "ca_backward"); 9 | m.def("ca_map_forward", &ca_map_forward, "ca_map_forward"); 10 | m.def("ca_map_backward", &ca_map_backward, "ca_map_backward"); 11 | } 12 | 13 | } // namespace segmentron 14 | -------------------------------------------------------------------------------- /configs/acdc/pvt_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_tiny.pth" 12 | MODEL_SAVE_DIR: 'workdirs/acdc/pvt_tiny_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: 'workdirs/acdc/pvt_small_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.0001 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_tiny" 25 | EMB_CHANNELS: 64 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/acdc/pvt_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_small.pth" 
12 | MODEL_SAVE_DIR: 'workdirs/acdc/pvt_small_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: 'workdirs/acdc/pvt_small_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.0001 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_small" 25 | EMB_CHANNELS: 128 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/trans10kv2/pvt_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "transparent11" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 100 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_tiny.pth" 12 | MODEL_SAVE_DIR: 'workdirs/trans10kv2/pvt_tiny_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: "workdirs/trans10kv2/pvt_tiny_FPT/best_model.pth" 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.0001 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_tiny" 25 | 26 | 27 | AUG: 28 | CROP: False 29 | -------------------------------------------------------------------------------- /configs/acdc/pvt_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_medium.pth" 12 | MODEL_SAVE_DIR: 'workdirs/acdc/pvt_medium_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: 'workdirs/acdc/pvt_medium_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.0001 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_medium" 25 | EMB_CHANNELS: 256 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/trans10kv2/pvt_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "transparent11" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 100 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_small.pth" 12 | MODEL_SAVE_DIR: 'workdirs/trans10kv2/pvt_small_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: "workdirs/trans10kv2/pvt_small_FPT/best_model.pth" 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.0001 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_small" 25 | 26 | 27 | AUG: 28 | CROP: False 29 | -------------------------------------------------------------------------------- /configs/acdc/pvt_v2_mit_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512,512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b3.pth" 12 | MODEL_SAVE_DIR: 'workdirs/acdc/pvtv2_mit_b3_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512,512) 17 | TEST_MODEL_PATH: 'workdirs/acdc/pvtv2_mit_b3_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | 
OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVTV2_MIT_FPT" 24 | BACKBONE: "mit_b3" 25 | EMB_CHANNELS: 256 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/acdc/pvt_v2_mit_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 8 9 | CROP_SIZE: (512,512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b2.pth" 12 | MODEL_SAVE_DIR: 'workdirs/acdc/pvtv2_mit_b2_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512,512) 17 | TEST_MODEL_PATH: 'workdirs/acdc/pvtv2_mit_b2_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVTV2_MIT_FPT" 24 | BACKBONE: "mit_b2" 25 | EMB_CHANNELS: 128 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/acdc/pvt_v2_mit_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 12 9 | CROP_SIZE: (512,512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b1.pth" 12 | MODEL_SAVE_DIR: 'workdirs/acdc/pvtv2_mit_b1_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512,512) 17 | TEST_MODEL_PATH: 'workdirs/acdc/pvtv2_mit_b1_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVTV2_MIT_FPT" 24 | BACKBONE: "mit_b1" 25 | EMB_CHANNELS: 64 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/trans10kv2/pvt_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "transparent11" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 100 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_medium.pth" 12 | MODEL_SAVE_DIR: 'workdirs/trans10kv2/pvt_medium_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: "workdirs/trans10kv2/pvt_medium_FPT/best_model.pth" 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.0001 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_medium" 25 | 26 | 27 | AUG: 28 | CROP: False 29 | -------------------------------------------------------------------------------- /configs/cityscapes/pvt_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_tiny.pth" 12 | MODEL_SAVE_DIR: 'workdirs/cityscapes/pvt_tiny_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: 'workdirs/cityscapes/pvt_small_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_tiny" 25 | EMB_CHANNELS: 64 26 | 27 | AUG: 28 | CROP: True 29 | 
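# NOTE: TEST_MODEL_PATH above points at workdirs/cityscapes/pvt_small_FPT although
# MODEL_SAVE_DIR for this config is workdirs/cityscapes/pvt_tiny_FPT (the same
# mismatch appears in configs/acdc/pvt_tiny_FPT.yaml). To evaluate the model
# trained with this config, the path would presumably be:
#   TEST_MODEL_PATH: 'workdirs/cityscapes/pvt_tiny_FPT/best_model.pth'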
-------------------------------------------------------------------------------- /configs/joint_stan_trans/pvt_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "transparent11" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "stanford2d3d" 9 | 10 | TRAIN: 11 | EPOCHS: 100 12 | ITERS: 20000 13 | BATCH_SIZE: 4 14 | CROP_SIZE: (512, 512) 15 | BASE_SIZE: 512 16 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_tiny.pth" 17 | MODEL_SAVE_DIR: 'workdirs/joint_stan_trans/pvt_tiny_FPT' 18 | APEX: True 19 | TEST: 20 | BATCH_SIZE: 4 21 | CROP_SIZE: (512, 512) 22 | 23 | SOLVER: 24 | OPTIMIZER: "adamw" 25 | LR: 0.0001 26 | MODEL: 27 | MODEL_NAME: "PVT_FPT_JOINT" 28 | BACKBONE: "pvt_tiny" 29 | 30 | AUG: 31 | CROP: False 32 | -------------------------------------------------------------------------------- /configs/cityscapes/pvt_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_small.pth" 12 | MODEL_SAVE_DIR: 'workdirs/cityscapes/pvt_small_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: 'workdirs/cityscapes/pvt_small_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_small" 25 | EMB_CHANNELS: 128 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/cityscapes/pvt_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512, 512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_medium.pth" 12 | MODEL_SAVE_DIR: 'workdirs/cityscapes/pvt_medium_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512, 512) 17 | TEST_MODEL_PATH: 'workdirs/cityscapes/pvt_medium_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVT_FPT" 24 | BACKBONE: "pvt_medium" 25 | EMB_CHANNELS: 256 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/joint_stan_trans/pvt_medium_FPT64.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "transparent11" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "stanford2d3d" 9 | 10 | TRAIN: 11 | EPOCHS: 100 12 | ITERS: 20000 13 | BATCH_SIZE: 4 14 | CROP_SIZE: (512, 512) 15 | BASE_SIZE: 512 16 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_medium.pth" 17 | MODEL_SAVE_DIR: 'workdirs/joint_stan_trans/pvt_medium_FPT64' 18 | APEX: True 19 | TEST: 20 | BATCH_SIZE: 4 21 | CROP_SIZE: (512, 512) 22 | 23 | SOLVER: 24 | OPTIMIZER: "adamw" 25 | LR: 0.0001 26 | MODEL: 27 | MODEL_NAME: "PVT_FPT_JOINT" 28 | BACKBONE: "pvt_medium" 29 | 30 | AUG: 31 | CROP: False 32 | -------------------------------------------------------------------------------- /configs/joint_stan_trans/pvt_small_FPT64.yaml: 
-------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "transparent11" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "stanford2d3d" 9 | 10 | TRAIN: 11 | EPOCHS: 100 12 | ITERS: 20000 13 | BATCH_SIZE: 4 14 | CROP_SIZE: (512, 512) 15 | BASE_SIZE: 512 16 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_small.pth" 17 | MODEL_SAVE_DIR: 'workdirs/joint_stan_trans/pvt_small_FPT64' 18 | APEX: True 19 | TEST: 20 | BATCH_SIZE: 4 21 | CROP_SIZE: (512, 512) 22 | 23 | SOLVER: 24 | OPTIMIZER: "adamw" 25 | LR: 0.0001 26 | MODEL: 27 | MODEL_NAME: "PVT_FPT_JOINT" 28 | BACKBONE: "pvt_small" 29 | 30 | AUG: 31 | CROP: False 32 | -------------------------------------------------------------------------------- /configs/cityscapes/pvt_v2_mit_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 0 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512,512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b1.pth" 12 | MODEL_SAVE_DIR: 'workdirs/cityscapes/pvtv2_mit_b1_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512,512) 17 | TEST_MODEL_PATH: 'workdirs/cityscapes/pvtv2_mit_b1_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVTV2_MIT_FPT" 24 | BACKBONE: "mit_b1" 25 | EMB_CHANNELS: 64 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/stanford2d3d/pvt_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "stanford2d3d" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | IGNORE_INDEX: 0 7 | TRAIN: 8 | EPOCHS: 10 9 | BATCH_SIZE: 4 10 | CROP_SIZE: (512, 512) 11 | BASE_SIZE: 512 12 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_tiny.pth" 13 | MODEL_SAVE_DIR: 'workdirs/stanford2d3d/pvt_tiny_FPT' 14 | APEX: True 15 | TEST: 16 | BATCH_SIZE: 1 17 | CROP_SIZE: (512, 512) 18 | TEST_MODEL_PATH: 'workdirs/stanford2d3d/pvt_tiny_FPT/best_model.pth' 19 | 20 | SOLVER: 21 | OPTIMIZER: "adamw" 22 | LR: 0.0001 23 | MODEL: 24 | MODEL_NAME: "PVT_FPT" 25 | BACKBONE: "pvt_tiny" 26 | 27 | 28 | AUG: 29 | CROP: False 30 | -------------------------------------------------------------------------------- /configs/cityscapes/pvt_v2_mit_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512,512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b3.pth" 12 | MODEL_SAVE_DIR: 'workdirs/cityscapes/pvtv2_mit_b3_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512,512) 17 | TEST_MODEL_PATH: 'workdirs/cityscapes/pvtv2_mit_b3_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVTV2_MIT_FPT" 24 | BACKBONE: "mit_b3" 25 | EMB_CHANNELS: 256 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/cityscapes/pvt_v2_mit_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "cityscape" 3 | MEAN: 
[0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | TRAIN: 7 | EPOCHS: 200 8 | BATCH_SIZE: 4 9 | CROP_SIZE: (512,512) 10 | BASE_SIZE: 512 11 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b2.pth" 12 | MODEL_SAVE_DIR: 'workdirs/cityscapes/pvtv2_mit_b2_FPT' 13 | APEX: True 14 | TEST: 15 | BATCH_SIZE: 4 16 | CROP_SIZE: (512,512) 17 | TEST_MODEL_PATH: 'workdirs/cityscapes/pvtv2_mit_b2_FPT/best_model.pth' 18 | 19 | SOLVER: 20 | OPTIMIZER: "adamw" 21 | LR: 0.00005 22 | MODEL: 23 | MODEL_NAME: "PVTV2_MIT_FPT" 24 | BACKBONE: "mit_b2" 25 | EMB_CHANNELS: 128 26 | 27 | AUG: 28 | CROP: True 29 | -------------------------------------------------------------------------------- /configs/stanford2d3d/pvt_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "stanford2d3d" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | IGNORE_INDEX: 0 7 | TRAIN: 8 | EPOCHS: 10 9 | BATCH_SIZE: 4 10 | CROP_SIZE: (512, 512) 11 | BASE_SIZE: 512 12 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_small.pth" 13 | MODEL_SAVE_DIR: 'workdirs/stanford2d3d/pvt_small_FPT' 14 | APEX: True 15 | TEST: 16 | BATCH_SIZE: 1 17 | CROP_SIZE: (512, 512) 18 | TEST_MODEL_PATH: 'workdirs/stanford2d3d/pvt_small_FPT/best_model.pth' 19 | 20 | SOLVER: 21 | OPTIMIZER: "adamw" 22 | LR: 0.0001 23 | MODEL: 24 | MODEL_NAME: "PVT_FPT" 25 | BACKBONE: "pvt_small" 26 | 27 | 28 | AUG: 29 | CROP: False 30 | -------------------------------------------------------------------------------- /configs/stanford2d3d/pvt_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "stanford2d3d" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | IGNORE_INDEX: 0 7 | TRAIN: 8 | EPOCHS: 10 9 | BATCH_SIZE: 4 10 | CROP_SIZE: (512, 512) 11 | BASE_SIZE: 512 12 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_medium.pth" 13 | MODEL_SAVE_DIR: 'workdirs/stanford2d3d/pvt_medium_FPT' 14 | APEX: True 15 | TEST: 16 | BATCH_SIZE: 1 17 | CROP_SIZE: (512, 512) 18 | TEST_MODEL_PATH: 'workdirs/stanford2d3d/pvt_medium_FPT/best_model.pth' 19 | 20 | SOLVER: 21 | OPTIMIZER: "adamw" 22 | LR: 0.0001 23 | MODEL: 24 | MODEL_NAME: "PVT_FPT" 25 | BACKBONE: "pvt_medium" 26 | 27 | 28 | AUG: 29 | CROP: False 30 | -------------------------------------------------------------------------------- /configs/joint_cs_acdc/pvt_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "cityscape" 9 | 10 | TRAIN: 11 | EPOCHS: 200 12 | BATCH_SIZE: 2 13 | CROP_SIZE: (512, 512) 14 | BASE_SIZE: 512 15 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_tiny.pth" 16 | MODEL_SAVE_DIR: 'workdirs/joint_cs_acdc/pvt_tiny_FPT' 17 | APEX: True 18 | TEST: 19 | BATCH_SIZE: 2 20 | CROP_SIZE: (512, 512) 21 | TEST_MODEL_PATH: 'workdirs/joint_cs_acdc/pvt_tiny_FPT/best_model.pth' 22 | 23 | SOLVER: 24 | OPTIMIZER: "adamw" 25 | LR: 0.0001 26 | MODEL: 27 | MODEL_NAME: "PVT_FPT_JOINT" 28 | BACKBONE: "pvt_tiny" 29 | EMB_CHANNELS: 64 30 | 31 | AUG: 32 | CROP: True 33 | -------------------------------------------------------------------------------- /configs/joint_cs_acdc/pvt_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 
| 7 | DATASET2: 8 | NAME: "cityscape" 9 | 10 | TRAIN: 11 | EPOCHS: 200 12 | BATCH_SIZE: 2 13 | CROP_SIZE: (512, 512) 14 | BASE_SIZE: 512 15 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_small.pth" 16 | MODEL_SAVE_DIR: 'workdirs/joint_cs_acdc/pvt_small_FPT' 17 | APEX: True 18 | TEST: 19 | BATCH_SIZE: 2 20 | CROP_SIZE: (512, 512) 21 | TEST_MODEL_PATH: 'workdirs/joint_cs_acdc/pvt_small_FPT/best_model.pth' 22 | 23 | SOLVER: 24 | OPTIMIZER: "adamw" 25 | LR: 0.0001 26 | MODEL: 27 | MODEL_NAME: "PVT_FPT_JOINT" 28 | BACKBONE: "pvt_small" 29 | EMB_CHANNELS: 128 30 | 31 | AUG: 32 | CROP: True 33 | -------------------------------------------------------------------------------- /configs/joint_cs_acdc/pvt_v2_mit_small_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "cityscape" 9 | TRAIN: 10 | EPOCHS: 200 11 | BATCH_SIZE: 2 12 | CROP_SIZE: (512,512) 13 | BASE_SIZE: 512 14 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b2.pth" 15 | MODEL_SAVE_DIR: 'workdirs/joint_cs_acdc/pvtv2_mit_b2_FPT' 16 | APEX: True 17 | TEST: 18 | BATCH_SIZE: 2 19 | CROP_SIZE: (512,512) 20 | TEST_MODEL_PATH: 'workdirs/joint_cs_acdc/pvtv2_mit_b2_FPT/best_model.pth' 21 | 22 | SOLVER: 23 | OPTIMIZER: "adamw" 24 | LR: 0.00005 25 | MODEL: 26 | MODEL_NAME: "PVTV2_FPT_JOINT" 27 | BACKBONE: "mit_b2" 28 | EMB_CHANNELS: 128 29 | 30 | AUG: 31 | CROP: True 32 | -------------------------------------------------------------------------------- /configs/joint_cs_acdc/pvt_v2_mit_tiny_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "cityscape" 9 | TRAIN: 10 | EPOCHS: 200 11 | BATCH_SIZE: 2 12 | CROP_SIZE: (512,512) 13 | BASE_SIZE: 512 14 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b1.pth" 15 | MODEL_SAVE_DIR: 'workdirs/joint_cs_acdc/pvtv2_mit_b1_FPT' 16 | APEX: True 17 | TEST: 18 | BATCH_SIZE: 2 19 | CROP_SIZE: (512,512) 20 | TEST_MODEL_PATH: 'workdirs/joint_cs_acdc/pvtv2_mit_b1_FPT/best_model.pth' 21 | 22 | SOLVER: 23 | OPTIMIZER: "adamw" 24 | LR: 0.00005 25 | MODEL: 26 | MODEL_NAME: "PVTV2_FPT_JOINT" 27 | BACKBONE: "mit_b1" 28 | EMB_CHANNELS: 64 29 | 30 | AUG: 31 | CROP: True 32 | -------------------------------------------------------------------------------- /configs/joint_cs_acdc/pvt_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "cityscape" 9 | 10 | TRAIN: 11 | EPOCHS: 200 12 | BATCH_SIZE: 2 13 | CROP_SIZE: (512, 512) 14 | BASE_SIZE: 512 15 | BACKBONE_PRETRAINED_PATH: "pretrained/pvt_medium.pth" 16 | MODEL_SAVE_DIR: 'workdirs/joint_cs_acdc/pvt_medium_FPT' 17 | APEX: True 18 | TEST: 19 | BATCH_SIZE: 2 20 | CROP_SIZE: (512, 512) 21 | TEST_MODEL_PATH: 'workdirs/joint_cs_acdc/pvt_medium_FPT/best_model.pth' 22 | 23 | SOLVER: 24 | OPTIMIZER: "adamw" 25 | LR: 0.0001 26 | MODEL: 27 | MODEL_NAME: "PVT_FPT_JOINT" 28 | BACKBONE: "pvt_medium" 29 | EMB_CHANNELS: 256 30 | 31 | AUG: 32 | CROP: True 33 | -------------------------------------------------------------------------------- /configs/joint_cs_acdc/pvt_v2_mit_medium_FPT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 
| NAME: "acdc" 3 | MEAN: [0.485, 0.456, 0.406] 4 | STD: [0.229, 0.224, 0.225] 5 | WORKERS: 8 6 | 7 | DATASET2: 8 | NAME: "cityscape" 9 | TRAIN: 10 | EPOCHS: 200 11 | BATCH_SIZE: 2 12 | CROP_SIZE: (512,512) 13 | BASE_SIZE: 512 14 | BACKBONE_PRETRAINED_PATH: "pretrained/mit_b3.pth" 15 | MODEL_SAVE_DIR: 'workdirs/joint_cs_acdc/pvtv2_mit_b3_FPT' 16 | APEX: True 17 | TEST: 18 | BATCH_SIZE: 2 19 | CROP_SIZE: (512,512) 20 | TEST_MODEL_PATH: 'workdirs/joint_cs_acdc/pvtv2_mit_b3_FPT/best_model.pth' 21 | 22 | SOLVER: 23 | OPTIMIZER: "adamw" 24 | LR: 0.00005 25 | MODEL: 26 | MODEL_NAME: "PVTV2_FPT_JOINT" 27 | BACKBONE: "mit_b3" 28 | EMB_CHANNELS: 256 29 | 30 | AUG: 31 | CROP: True 32 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides data loaders and transformers for popular vision datasets. 3 | """ 4 | from .cityscapes import CitySegmentation 5 | from .transparent11 import TransparentSegmentation 6 | from .stanford2d3d import Stanford2d3dSegmentation 7 | from .cocostuff import COCOStuffSegmentation 8 | from .acdc import ACDCSegmentation 9 | from .densepass import DensePASSSegmentation 10 | 11 | datasets = { 12 | 'cityscape': CitySegmentation, 13 | 'transparent11': TransparentSegmentation, 14 | 'stanford2d3d': Stanford2d3dSegmentation, 15 | 'cocostuff': COCOStuffSegmentation, 16 | 'acdc': ACDCSegmentation, 17 | 'densepass': DensePASSSegmentation, 18 | } 19 | 20 | 21 | def get_segmentation_dataset(name, **kwargs): 22 | """Segmentation Datasets""" 23 | return datasets[name.lower()](**kwargs) 24 | -------------------------------------------------------------------------------- /segmentron/utils/env.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | from datetime import datetime 8 | import torch 9 | 10 | __all__ = ["seed_all_rng"] 11 | 12 | 13 | def seed_all_rng(seed=None): 14 | """ 15 | Set the random seed for the RNG in torch, numpy and python. 16 | 17 | Args: 18 | seed (int): if None, will use a strong random seed. 
19 | """ 20 | if seed is None: 21 | seed = ( 22 | os.getpid() 23 | + int(datetime.now().strftime("%S%f")) 24 | + int.from_bytes(os.urandom(2), "big") 25 | ) 26 | logger = logging.getLogger(__name__) 27 | logger.info("Using a generated random seed {}".format(seed)) 28 | np.random.seed(seed) 29 | torch.set_rng_state(torch.manual_seed(seed).get_state()) 30 | random.seed(seed) 31 | -------------------------------------------------------------------------------- /segmentron/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | __all__ = ['setup_logger'] 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 9 | if distributed_rank > 0: 10 | return 11 | 12 | logging.root.name = name 13 | logging.root.setLevel(logging.INFO) 14 | ch = logging.StreamHandler(stream=sys.stdout) 15 | ch.setLevel(logging.DEBUG) 16 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 17 | ch.setFormatter(formatter) 18 | logging.root.addHandler(ch) 19 | 20 | if save_dir: 21 | if not os.path.exists(save_dir): 22 | os.makedirs(save_dir) 23 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) 24 | fh.setLevel(logging.DEBUG) 25 | fh.setFormatter(formatter) 26 | logging.root.addHandler(fh) 27 | -------------------------------------------------------------------------------- /segmentron/utils/default_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | import torch 5 | 6 | from .distributed import get_rank, synchronize 7 | from .logger import setup_logger 8 | from .env import seed_all_rng 9 | from ..config import cfg 10 | 11 | def default_setup(args): 12 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 13 | args.num_gpus = num_gpus 14 | args.distributed = num_gpus > 1 15 | 16 | if not args.no_cuda and torch.cuda.is_available(): 17 | # cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = True 19 | args.device = "cuda" 20 | else: 21 | args.distributed = False 22 | args.device = "cpu" 23 | if args.distributed: 24 | torch.cuda.set_device(args.local_rank) 25 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 26 | synchronize() 27 | 28 | save_dir = cfg.TRAIN.MODEL_SAVE_DIR if cfg.PHASE == 'train' else None 29 | setup_logger("Segmentron", save_dir, get_rank(), filename='{}_{}_{}_{}_log.txt'.format( 30 | cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP)) 31 | 32 | logging.info("Using {} GPUs".format(num_gpus)) 33 | logging.info(args) 34 | logging.info(json.dumps(cfg, indent=8)) 35 | 36 | seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + get_rank()) -------------------------------------------------------------------------------- /segmentron/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='Segmentron') 5 | parser.add_argument('--config-file' ,metavar="FILE", 6 | help='config file path') 7 | # cuda setting 8 | parser.add_argument('--no-cuda', action='store_true', default=False, 9 | help='disables CUDA training') 10 | parser.add_argument('--local_rank', type=int, default=0) 11 | # checkpoint and log 12 | parser.add_argument('--resume', type=str, default=None, 13 | help='put the path to resuming file if needed') 14 | 
parser.add_argument('--log-iter', type=int, default=10, 15 | help='print log every log-iter') 16 | # for evaluation 17 | parser.add_argument('--val-epoch', type=int, default=1, 18 | help='run validation every val-epoch') 19 | parser.add_argument('--skip-val', action='store_true', default=False, 20 | help='skip validation during training') 21 | parser.add_argument('--test', action='store_true', default=False, 22 | help='test model') 23 | parser.add_argument('--vis', action='store_true', default=False, 24 | help='visualize images') 25 | # for visual 26 | parser.add_argument('--input-img', type=str, default='tools/demo_vis.png', 27 | help='path to the input image or a directory of images') 28 | # config options 29 | parser.add_argument('opts', help='See config for all options', 30 | default=None, nargs=argparse.REMAINDER) 31 | args = parser.parse_args() 32 | 33 | return args -------------------------------------------------------------------------------- /tools/data/resize_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import os.path as osp 4 | import random 5 | 6 | 7 | dataset_root = "./datasets/transparent/Trans10K" 8 | dataset_out_root = "/mnt/lustre/share_data/lixiangtai/datasets/transparent/Trans10K_cls12" 9 | 10 | 11 | def check_dir(dir): 12 | if not os.path.isdir(dir): 13 | os.makedirs(dir) 14 | 15 | 16 | def random_size(shot_size): 17 | if shot_size > 1200: 18 | return random.randint(800, 1200) 19 | else: 20 | return shot_size 21 | 22 | 23 | def resize_dataset(i_root, o_root, mode="train"): 24 | 25 | out_images_path = osp.join(o_root, mode, "images") 26 | out_masks_path = osp.join(o_root, mode, "masks_12") 27 | 28 | check_dir(o_root) 29 | check_dir(out_images_path) 30 | check_dir(out_masks_path) 31 | 32 | image_path = osp.join(i_root, mode, "images") 33 | mask_path = osp.join(i_root, mode, "masks_12") 34 | for i in os.listdir(image_path): 35 | img = Image.open(osp.join(image_path, i)) 36 | basename, _ = os.path.splitext(i) 37 | mask = Image.open(osp.join(mask_path, basename+"_mask.png")) 38 | assert img.size == mask.size 39 | w, h = img.size 40 | 41 | if w > h: 42 | oh =random_size(h) 43 | ow = int(1.0 * w * oh / h) 44 | else: 45 | ow = random_size(w) 46 | 47 | oh = int(1.0 * h * ow / w) 48 | 49 | new_img = img.resize((ow, oh), Image.BILINEAR) 50 | new_mask = mask.resize((ow, oh), Image.NEAREST) 51 | new_img.save(osp.join(out_images_path, i)) 52 | new_mask.save(osp.join(out_masks_path, basename+"_mask.png")) 53 | print("process image", i) 54 | 55 | 56 | resize_dataset(dataset_root, dataset_out_root) 57 | resize_dataset(dataset_root, dataset_out_root, mode='test') 58 | resize_dataset(dataset_root, dataset_out_root, mode='validation') -------------------------------------------------------------------------------- /segmentron/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """Filesystem utility functions.""" 2 | from __future__ import absolute_import 3 | import os 4 | import errno 5 | import torch 6 | import logging 7 | 8 | from ..config import cfg 9 | 10 | def save_checkpoint(model, epoch, optimizer=None, lr_scheduler=None, is_best=False): 11 | """Save Checkpoint""" 12 | directory = os.path.expanduser(cfg.TRAIN.MODEL_SAVE_DIR) 13 | if not os.path.exists(directory): 14 | os.makedirs(directory) 15 | filename = '{}.pth'.format(str(epoch)) 16 | filename = os.path.join(directory, filename) 17 | model_state_dict = model.module.state_dict() if 
hasattr(model, 'module') else model.state_dict() 18 | if is_best: 19 | best_filename = 'best_model.pth' 20 | best_filename = os.path.join(directory, best_filename) 21 | torch.save(model_state_dict, best_filename) 22 | else: 23 | save_state = { 24 | 'epoch': epoch, 25 | 'state_dict': model_state_dict, 26 | 'optimizer': optimizer.state_dict(), 27 | 'lr_scheduler': lr_scheduler.state_dict() 28 | } 29 | if not os.path.exists(filename): 30 | torch.save(save_state, filename) 31 | logging.info('Epoch {} model saved in: {}'.format(epoch, filename)) 32 | 33 | pre_filename = '{}.pth'.format(str(epoch - 1)) 34 | pre_filename = os.path.join(directory, pre_filename) 35 | try: 36 | if os.path.exists(pre_filename): 37 | os.remove(pre_filename) 38 | except OSError as e: 39 | logging.info(e) 40 | 41 | def makedirs(path): 42 | """Create directory recursively if not exists. 43 | Similar to `makedir -p`, you can skip checking existence before this function. 44 | Parameters 45 | ---------- 46 | path : str 47 | Path of the desired dir 48 | """ 49 | try: 50 | os.makedirs(path) 51 | except OSError as exc: 52 | if exc.errno != errno.EEXIST: 53 | raise 54 | 55 | -------------------------------------------------------------------------------- /segmentron/models/model_zoo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from collections import OrderedDict 5 | from segmentron.utils.registry import Registry 6 | from ..config import cfg 7 | 8 | MODEL_REGISTRY = Registry("MODEL") 9 | MODEL_REGISTRY.__doc__ = """ 10 | Registry for segment model, i.e. the whole model. 11 | 12 | The registered object will be called with `obj()` 13 | and expected to return a `nn.Module` object. 14 | """ 15 | 16 | 17 | def get_segmentation_model(): 18 | """ 19 | Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`. 
20 | """ 21 | model_name = cfg.MODEL.MODEL_NAME 22 | model = MODEL_REGISTRY.get(model_name)() 23 | load_model_pretrain(model) 24 | return model 25 | 26 | 27 | def load_model_pretrain(model): 28 | if cfg.PHASE == 'train': 29 | if cfg.TRAIN.PRETRAINED_MODEL_PATH: 30 | logging.info('load pretrained model from {}'.format(cfg.TRAIN.PRETRAINED_MODEL_PATH)) 31 | state_dict_to_load = torch.load(cfg.TRAIN.PRETRAINED_MODEL_PATH) 32 | keys_wrong_shape = [] 33 | state_dict_suitable = OrderedDict() 34 | state_dict = model.state_dict() 35 | for k, v in state_dict_to_load.items(): 36 | if v.shape == state_dict[k].shape: 37 | state_dict_suitable[k] = v 38 | else: 39 | keys_wrong_shape.append(k) 40 | logging.info('Shape unmatched weights: {}'.format(keys_wrong_shape)) 41 | msg = model.load_state_dict(state_dict_suitable, strict=False) 42 | logging.info(msg) 43 | else: 44 | if cfg.TEST.TEST_MODEL_PATH: 45 | logging.info('load test model from {}'.format(cfg.TEST.TEST_MODEL_PATH)) 46 | model_dic = torch.load(cfg.TEST.TEST_MODEL_PATH, map_location='cuda:0') 47 | if 'state_dict' in model_dic.keys(): 48 | # load the last checkpoint 49 | model_dic = model_dic['state_dict'] 50 | msg = model.load_state_dict(model_dic, strict=False) 51 | logging.info(msg) -------------------------------------------------------------------------------- /tools/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | cur_path = os.path.abspath(os.path.dirname(__file__)) 6 | root_path = os.path.split(cur_path)[0] 7 | sys.path.append(root_path) 8 | 9 | from torchvision import transforms 10 | from PIL import Image 11 | from segmentron.utils.visualize import get_color_pallete 12 | from segmentron.models.model_zoo import get_segmentation_model 13 | from segmentron.utils.options import parse_args 14 | from segmentron.utils.default_setup import default_setup 15 | from segmentron.config import cfg 16 | 17 | 18 | def demo(): 19 | args = parse_args() 20 | cfg.update_from_file(args.config_file) 21 | cfg.PHASE = 'test' 22 | cfg.ROOT_PATH = root_path 23 | cfg.check_and_freeze() 24 | default_setup(args) 25 | 26 | # output folder 27 | output_dir = os.path.join(cfg.VISUAL.OUTPUT_DIR, 'vis_result_{}_{}_{}_{}'.format( 28 | cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP)) 29 | if not os.path.exists(output_dir): 30 | os.makedirs(output_dir) 31 | 32 | # image transform 33 | transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD), 36 | ]) 37 | 38 | model = get_segmentation_model().to(args.device) 39 | model.eval() 40 | 41 | if os.path.isdir(args.input_img): 42 | img_paths = [os.path.join(args.input_img, x) for x in os.listdir(args.input_img)] 43 | else: 44 | img_paths = [args.input_img] 45 | for img_path in img_paths: 46 | image = Image.open(img_path).convert('RGB') 47 | size = image.size 48 | image = image.resize((512, 512)) 49 | images = transform(image).unsqueeze(0).to(args.device) 50 | with torch.no_grad(): 51 | output = model(images) 52 | 53 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy() 54 | mask = get_color_pallete(pred, cfg.DATASET.NAME).resize(size) 55 | outname = os.path.splitext(os.path.split(img_path)[-1])[0] + '.png' 56 | mask.save(os.path.join(output_dir, outname)) 57 | 58 | 59 | if __name__ == '__main__': 60 | demo() 61 | -------------------------------------------------------------------------------- /segmentron/utils/registry.py: 
-------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | 3 | import logging 4 | import torch 5 | 6 | from ..config import cfg 7 | 8 | class Registry(object): 9 | """ 10 | The registry that provides name -> object mapping, to support third-party users' custom modules. 11 | 12 | To create a registry (inside segmentron): 13 | 14 | .. code-block:: python 15 | 16 | BACKBONE_REGISTRY = Registry('BACKBONE') 17 | 18 | To register an object: 19 | 20 | .. code-block:: python 21 | 22 | @BACKBONE_REGISTRY.register() 23 | class MyBackbone(): 24 | ... 25 | 26 | Or: 27 | 28 | .. code-block:: python 29 | 30 | BACKBONE_REGISTRY.register(MyBackbone) 31 | """ 32 | 33 | def __init__(self, name): 34 | """ 35 | Args: 36 | name (str): the name of this registry 37 | """ 38 | self._name = name 39 | 40 | self._obj_map = {} 41 | 42 | def _do_register(self, name, obj): 43 | assert ( 44 | name not in self._obj_map 45 | ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name) 46 | self._obj_map[name] = obj 47 | 48 | def register(self, obj=None, name=None): 49 | """ 50 | Register the given object under the the name `obj.__name__`. 51 | Can be used as either a decorator or not. See docstring of this class for usage. 52 | """ 53 | if obj is None: 54 | # used as a decorator 55 | def deco(func_or_class, name=name): 56 | if name is None: 57 | name = func_or_class.__name__ 58 | self._do_register(name, func_or_class) 59 | return func_or_class 60 | 61 | return deco 62 | 63 | # used as a function call 64 | if name is None: 65 | name = obj.__name__ 66 | self._do_register(name, obj) 67 | 68 | 69 | 70 | def get(self, name): 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) 74 | 75 | return ret 76 | 77 | def get_list(self): 78 | return list(self._obj_map.keys()) 79 | -------------------------------------------------------------------------------- /segmentron/modules/csrc/criss_cross_attention/ca.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace segmentron { 6 | at::Tensor ca_forward_cuda( 7 | const at::Tensor& t, 8 | const at::Tensor& f); 9 | 10 | std::tuple ca_backward_cuda( 11 | const at::Tensor& dw, 12 | const at::Tensor& t, 13 | const at::Tensor& f); 14 | 15 | at::Tensor ca_map_forward_cuda( 16 | const at::Tensor& weight, 17 | const at::Tensor& g); 18 | 19 | std::tuple ca_map_backward_cuda( 20 | const at::Tensor& dout, 21 | const at::Tensor& weight, 22 | const at::Tensor& g); 23 | 24 | 25 | at::Tensor ca_forward(const at::Tensor& t, 26 | const at::Tensor& f) { 27 | if (t.type().is_cuda()) { 28 | #ifdef WITH_CUDA 29 | return ca_forward_cuda(t, f); 30 | #else 31 | AT_ERROR("Not compiled with GPU support"); 32 | #endif 33 | } 34 | AT_ERROR("Not implemented on the CPU"); 35 | } 36 | 37 | std::tuple ca_backward(const at::Tensor& dw, 38 | const at::Tensor& t, 39 | const at::Tensor& f) { 40 | if (dw.type().is_cuda()) { 41 | #ifdef WITH_CUDA 42 | return ca_backward_cuda(dw, t, f); 43 | #else 44 | AT_ERROR("Not compiled with GPU support"); 45 | #endif 46 | } 47 | AT_ERROR("Not implemented on the CPU"); 48 | } 49 | 50 | at::Tensor ca_map_forward(const at::Tensor& weight, 51 | const at::Tensor& g) { 52 | if (weight.type().is_cuda()) { 53 | #ifdef WITH_CUDA 54 | return ca_map_forward_cuda(weight, g); 55 | #else 56 | AT_ERROR("Not compiled with 
GPU support"); 57 | #endif 58 | } 59 | AT_ERROR("Not implemented on the CPU"); 60 | } 61 | 62 | std::tuple ca_map_backward(const at::Tensor& dout, 63 | const at::Tensor& weight, 64 | const at::Tensor& g) { 65 | if (dout.type().is_cuda()) { 66 | #ifdef WITH_CUDA 67 | return ca_map_backward_cuda(dout, weight, g); 68 | #else 69 | AT_ERROR("Not compiled with GPU support"); 70 | #endif 71 | } 72 | AT_ERROR("Not implemented on the CPU"); 73 | } 74 | 75 | } // namespace segmentron 76 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import errno 4 | import tarfile 5 | from six.moves import urllib 6 | from torch.utils.model_zoo import tqdm 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | def check_integrity(fpath, md5=None): 20 | if md5 is None: 21 | return True 22 | if not os.path.isfile(fpath): 23 | return False 24 | md5o = hashlib.md5() 25 | with open(fpath, 'rb') as f: 26 | # read in 1MB chunks 27 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 28 | md5o.update(chunk) 29 | md5c = md5o.hexdigest() 30 | if md5c != md5: 31 | return False 32 | return True 33 | 34 | def makedir_exist_ok(dirpath): 35 | try: 36 | os.makedirs(dirpath) 37 | except OSError as e: 38 | if e.errno == errno.EEXIST: 39 | pass 40 | else: 41 | pass 42 | 43 | def download_url(url, root, filename=None, md5=None): 44 | """Download a file from a url and place it in root.""" 45 | root = os.path.expanduser(root) 46 | if not filename: 47 | filename = os.path.basename(url) 48 | fpath = os.path.join(root, filename) 49 | 50 | makedir_exist_ok(root) 51 | 52 | # downloads file 53 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 54 | print('Using downloaded and verified file: ' + fpath) 55 | else: 56 | try: 57 | print('Downloading ' + url + ' to ' + fpath) 58 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 59 | except OSError: 60 | if url[:5] == 'https': 61 | url = url.replace('https:', 'http:') 62 | print('Failed download. Trying https -> http instead.' 63 | ' Downloading ' + url + ' to ' + fpath) 64 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 65 | 66 | def download_extract(url, root, filename, md5): 67 | download_url(url, root, filename, md5) 68 | with tarfile.open(os.path.join(root, filename), "r") as tar: 69 | tar.extractall(path=root) -------------------------------------------------------------------------------- /segmentron/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | from setuptools import find_packages, setup 6 | import torch 7 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 8 | 9 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 10 | assert torch_ver >= [1, 1], "Requires PyTorch >= 1.1" 11 | 12 | 13 | def get_extensions(): 14 | this_dir = os.path.dirname(os.path.abspath(__file__)) 15 | extensions_dir = os.path.join(this_dir, "segmentron", "modules", "csrc") 16 | 17 | main_source = os.path.join(extensions_dir, "vision.cpp") 18 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) 19 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob( 20 | os.path.join(extensions_dir, "*.cu") 21 | ) 22 | 23 | sources = [main_source] + sources 24 | 25 | extension = CppExtension 26 | 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | 41 | # It's better if pytorch can do this by default .. 
42 | CC = os.environ.get("CC", None) 43 | if CC is not None: 44 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) 45 | 46 | sources = [os.path.join(extensions_dir, s) for s in sources] 47 | 48 | include_dirs = [extensions_dir] 49 | 50 | ext_modules = [ 51 | extension( 52 | "segmentron._C", 53 | sources, 54 | include_dirs=include_dirs, 55 | define_macros=define_macros, 56 | extra_compile_args=extra_compile_args, 57 | ) 58 | ] 59 | 60 | return ext_modules 61 | 62 | 63 | setup( 64 | name="segmentron", 65 | version="0.1", 66 | author="LikeLy-Journey", 67 | url="https://github.com/LikeLy-Journey/SegmenTron", 68 | description="platform for semantic segmentation base on pytorch.", 69 | # packages=find_packages(exclude=("configs", "tests")), 70 | # python_requires=">=3.6", 71 | # install_requires=[ 72 | # "termcolor>=1.1", 73 | # "Pillow", 74 | # "yacs>=0.1.6", 75 | # "tabulate", 76 | # "cloudpickle", 77 | # "matplotlib", 78 | # "tqdm>4.29.0", 79 | # "tensorboard", 80 | # ], 81 | # extras_require={"all": ["shapely", "psutil"]}, 82 | ext_modules=get_extensions(), 83 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 84 | ) 85 | -------------------------------------------------------------------------------- /segmentron/models/backbones/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ...utils.download import download 7 | from ...utils.registry import Registry 8 | from ...config import cfg 9 | 10 | BACKBONE_REGISTRY = Registry("BACKBONE") 11 | BACKBONE_REGISTRY.__doc__ = """ 12 | Registry for backbone, i.e. resnet. 13 | 14 | The registered object will be called with `obj()` 15 | and expected to return a `nn.Module` object. 
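A hypothetical registration sketch (the name 'toy_backbone' and the ToyBackbone
class are illustrative only, not part of this repository):

    @BACKBONE_REGISTRY.register(name='toy_backbone')
    def toy_backbone(norm_layer):
        return ToyBackbone(norm_layer=norm_layer)

get_segmentation_backbone('toy_backbone') would then build it via
BACKBONE_REGISTRY.get('toy_backbone')(norm_layer).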
16 | """ 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | 'resnet50c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet50-25c4b509.pth', 25 | 'resnet101c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet101-2a57e44d.pth', 26 | 'resnet152c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet152-0d43d698.pth', 27 | 'xception65': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/tf-xception65-270e81cf.pth', 28 | 'hrnet_w18_small_v1': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/hrnet-w18-small-v1-08f8ae64.pth', 29 | 'mobilenet_v2': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/mobilenetV2-15498621.pth', 30 | } 31 | 32 | 33 | def load_backbone_pretrained(model, backbone): 34 | if cfg.PHASE == 'train' and cfg.TRAIN.BACKBONE_PRETRAINED and (not cfg.TRAIN.PRETRAINED_MODEL_PATH): 35 | if os.path.isfile(cfg.TRAIN.BACKBONE_PRETRAINED_PATH): 36 | logging.info('Load backbone pretrained model from {}'.format( 37 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH 38 | )) 39 | msg = model.load_state_dict(torch.load(cfg.TRAIN.BACKBONE_PRETRAINED_PATH), strict=False) 40 | logging.info(msg) 41 | elif backbone not in model_urls: 42 | logging.info('{} has no pretrained model'.format(backbone)) 43 | return 44 | else: 45 | logging.info('load backbone pretrained model from url..') 46 | try: 47 | msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) 48 | except Exception as e: 49 | logging.warning(e) 50 | logging.info('Use torch download failed, try custom method!') 51 | 52 | msg = model.load_state_dict(torch.load(download(model_urls[backbone], 53 | path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False) 54 | logging.info(msg) 55 | 56 | 57 | 58 | def get_segmentation_backbone(backbone, norm_layer=torch.nn.BatchNorm2d): 59 | """ 60 | Built the backbone model, defined by `cfg.MODEL.BACKBONE`. 
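A minimal usage sketch (assuming 'mobilenet_v2' is among the registered
backbone names):

    encoder = get_segmentation_backbone('mobilenet_v2', norm_layer=torch.nn.BatchNorm2d)

Pretrained weights are then loaded by `load_backbone_pretrained` according to
the cfg.TRAIN.* settings.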
61 | """ 62 | model = BACKBONE_REGISTRY.get(backbone)(norm_layer) 63 | load_backbone_pretrained(model, backbone) 64 | return model -------------------------------------------------------------------------------- /coco_model.txt: -------------------------------------------------------------------------------- 1 | [167, 200, 7, 2 | 127, 228, 215, 3 | 26, 135, 248, 4 | 238, 73, 166, 5 | 91, 210, 215, 6 | 122, 20, 236, 7 | 234, 173, 35, 8 | 34, 98, 46, 9 | 115, 11, 206, 10 | 52, 251, 238, 11 | 209, 156, 236, 12 | 239, 10, 0, 13 | 26, 122, 36, 14 | 162, 181, 66, 15 | 26, 64, 22, 16 | 46, 226, 200, 17 | 89, 176, 6, 18 | 103, 36, 32, 19 | 74, 89, 159, 20 | 250, 215, 25, 21 | 57, 246, 82, 22 | 51, 156, 111, 23 | 139, 114, 219, 24 | 65, 208, 253, 25 | 33, 184, 119, 26 | 230, 239, 58, 27 | 176, 141, 158, 28 | 21, 29, 31, 29 | 135, 133, 163, 30 | 152, 241, 248, 31 | 253, 54, 7, 32 | 231, 86, 229, 33 | 179, 220, 46, 34 | 155, 217, 185, 35 | 58, 251, 190, 36 | 40, 201, 63, 37 | 236, 52, 220, 38 | 71, 203, 170, 39 | 96, 56, 41, 40 | 252, 231, 125, 41 | 255, 60, 100, 42 | 11, 172, 184, 43 | 127, 46, 248, 44 | 1, 105, 163, 45 | 191, 218, 95, 46 | 87, 160, 119, 47 | 149, 223, 79, 48 | 216, 180, 245, 49 | 58, 226, 163, 50 | 11, 43, 118, 51 | 20, 23, 100, 52 | 71, 222, 109, 53 | 124, 197, 150, 54 | 38, 106, 43, 55 | 115, 73, 156, 56 | 113, 110, 50, 57 | 94, 2, 184, 58 | 163, 168, 155, 59 | 83, 39, 145, 60 | 150, 169, 81, 61 | 134, 25, 2, 62 | 145, 49, 138, 63 | 46, 27, 209, 64 | 145, 187, 117, 65 | 197, 9, 211, 66 | 179, 12, 118, 67 | 107, 241, 133, 68 | 255, 176, 224, 69 | 49, 56, 217, 70 | 10, 227, 177, 71 | 152, 117, 25, 72 | 139, 76, 23, 73 | 53, 191, 10, 74 | 14, 244, 90, 75 | 247, 94, 189, 76 | 202, 160, 149, 77 | 24, 31, 150, 78 | 164, 236, 24, 79 | 47, 10, 204, 80 | 84, 187, 44, 81 | 17, 153, 55, 82 | 9, 191, 39, 83 | 216, 53, 216, 84 | 54, 13, 26, 85 | 241, 13, 196, 86 | 157, 90, 225, 87 | 99, 195, 27, 88 | 20, 186, 253, 89 | 175, 192, 0, 90 | 81, 11, 238, 91 | 137, 83, 196, 92 | 53, 186, 24, 93 | 231, 20, 101, 94 | 246, 223, 173, 95 | 75, 202, 249, 96 | 9, 188, 201, 97 | 216, 83, 7, 98 | 152, 92, 54, 99 | 137, 192, 79, 100 | 242, 169, 49, 101 | 99, 65, 207, 102 | 178, 112, 1, 103 | 120, 135, 40, 104 | 71, 220, 82, 105 | 180, 83, 172, 106 | 68, 137, 75, 107 | 46, 58, 15, 108 | 0, 80, 68, 109 | 175, 86, 173, 110 | 19, 208, 152, 111 | 215, 235, 142, 112 | 95, 30, 166, 113 | 246, 193, 8, 114 | 222, 19, 72, 115 | 177, 29, 183, 116 | 238, 61, 178, 117 | 246, 136, 87, 118 | 199, 207, 174, 119 | 218, 149, 231, 120 | 98, 179, 168, 121 | 23, 10, 10, 122 | 223, 9, 253, 123 | 206, 114, 95, 124 | 177, 242, 152, 125 | 115, 189, 142, 126 | 254, 105, 107, 127 | 59, 175, 153, 128 | 42, 114, 178, 129 | 50, 121, 91, 130 | 78, 238, 175, 131 | 232, 201, 123, 132 | 61, 39, 248, 133 | 76, 43, 218, 134 | 121, 191, 38, 135 | 13, 164, 242, 136 | 83, 70, 160, 137 | 109, 2, 64, 138 | 252, 81, 105, 139 | 151, 107, 83, 140 | 31, 95, 170, 141 | 7, 238, 218, 142 | 227, 49, 19, 143 | 56, 102, 49, 144 | 152, 241, 48, 145 | 110, 35, 108, 146 | 59, 198, 242, 147 | 186, 189, 39, 148 | 26, 157, 41, 149 | 183, 16, 169, 150 | 114, 26, 104, 151 | 131, 142, 127, 152 | 118, 85, 219, 153 | 203, 84, 210, 154 | 245, 16, 127, 155 | 57, 238, 110, 156 | 223, 225, 154, 157 | 143, 21, 231, 158 | 12, 215, 113, 159 | 117, 58, 3, 160 | 170, 201, 252, 161 | 60, 190, 197, 162 | 38, 22, 24, 163 | 37, 155, 237, 164 | 175, 41, 211, 165 | 188, 151, 129, 166 | 231, 92, 102, 167 | 229, 112, 245, 168 | 157, 182, 40, 169 | 1, 60, 204, 170 | 57, 58, 19, 171 | 
156, 199, 180, 172 | 211, 47, 8] -------------------------------------------------------------------------------- /segmentron/solver/optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from torch import optim 5 | from segmentron.config import cfg 6 | 7 | 8 | def _set_batch_norm_attr(named_modules, attr, value): 9 | for m in named_modules: 10 | if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)): 11 | setattr(m[1], attr, value) 12 | 13 | 14 | def _get_paramters(model): 15 | params_list = list() 16 | if hasattr(model, 'encoder') and model.encoder is not None and hasattr(model, 'decoder'): 17 | params_list.append({'params': model.encoder.parameters(), 'lr': cfg.SOLVER.LR}) 18 | if cfg.MODEL.BN_EPS_FOR_ENCODER: 19 | logging.info('Set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) 20 | _set_batch_norm_attr(model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) 21 | 22 | for module in model.decoder: 23 | params_list.append({'params': getattr(model, module).parameters(), 24 | 'lr': cfg.SOLVER.LR * cfg.SOLVER.DECODER_LR_FACTOR}) 25 | 26 | if cfg.MODEL.BN_EPS_FOR_DECODER: 27 | logging.info('Set bn custom eps for bn in decoder: {}'.format(cfg.MODEL.BN_EPS_FOR_DECODER)) 28 | for module in model.decoder: 29 | _set_batch_norm_attr(getattr(model, module).named_modules(), 'eps', 30 | cfg.MODEL.BN_EPS_FOR_DECODER) 31 | else: 32 | logging.info('Model do not have encoder or decoder, params list was from model.parameters(), ' 33 | 'and arguments BN_EPS_FOR_ENCODER, BN_EPS_FOR_DECODER, DECODER_LR_FACTOR not used!') 34 | params_list = model.parameters() 35 | 36 | if cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE in ['BN']: 37 | logging.info('Set bn custom momentum: {}'.format(cfg.MODEL.BN_MOMENTUM)) 38 | _set_batch_norm_attr(model.named_modules(), 'momentum', cfg.MODEL.BN_MOMENTUM) 39 | elif cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE not in ['BN']: 40 | logging.info('Batch norm type is {}, custom bn momentum is not effective!'.format(cfg.MODEL.BN_TYPE)) 41 | 42 | return params_list 43 | 44 | 45 | def get_optimizer(model): 46 | parameters = _get_paramters(model) 47 | opt_lower = cfg.SOLVER.OPTIMIZER.lower() 48 | 49 | if opt_lower == 'sgd': 50 | optimizer = optim.SGD( 51 | parameters, lr=cfg.SOLVER.LR, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 52 | elif opt_lower == 'adam': 53 | optimizer = optim.Adam( 54 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 55 | elif opt_lower == 'adamw': 56 | optimizer = optim.AdamW( 57 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 58 | elif opt_lower == 'adadelta': 59 | optimizer = optim.Adadelta( 60 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 61 | elif opt_lower == 'rmsprop': 62 | optimizer = optim.RMSprop( 63 | parameters, lr=cfg.SOLVER.LR, alpha=0.9, eps=cfg.SOLVER.EPSILON, 64 | momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 65 | else: 66 | raise ValueError("Expected optimizer method in [sgd, adam, adadelta, rmsprop], but received " 67 | "{}".format(opt_lower)) 68 | 69 | return optimizer 70 | -------------------------------------------------------------------------------- /segmentron/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import requests 4 | from tqdm import 
tqdm 5 | 6 | def check_sha1(filename, sha1_hash): 7 | """Check whether the sha1 hash of the file content matches the expected hash. 8 | Parameters 9 | ---------- 10 | filename : str 11 | Path to the file. 12 | sha1_hash : str 13 | Expected sha1 hash in hexadecimal digits. 14 | Returns 15 | ------- 16 | bool 17 | Whether the file content matches the expected hash. 18 | """ 19 | sha1 = hashlib.sha1() 20 | with open(filename, 'rb') as f: 21 | while True: 22 | data = f.read(1048576) 23 | if not data: 24 | break 25 | sha1.update(data) 26 | 27 | sha1_file = sha1.hexdigest() 28 | l = min(len(sha1_file), len(sha1_hash)) 29 | return sha1.hexdigest()[0:l] == sha1_hash[0:l] 30 | 31 | def download(url, path=None, overwrite=False, sha1_hash=None): 32 | """Download an given URL 33 | Parameters 34 | ---------- 35 | url : str 36 | URL to download 37 | path : str, optional 38 | Destination path to store downloaded file. By default stores to the 39 | current directory with same name as in url. 40 | overwrite : bool, optional 41 | Whether to overwrite destination file if already exists. 42 | sha1_hash : str, optional 43 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 44 | but doesn't match. 45 | Returns 46 | ------- 47 | str 48 | The file path of the downloaded file. 49 | """ 50 | if path is None: 51 | fname = url.split('/')[-1] 52 | else: 53 | path = os.path.expanduser(path) 54 | if os.path.isdir(path): 55 | fname = os.path.join(path, url.split('/')[-1]) 56 | else: 57 | fname = path 58 | 59 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 60 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 61 | if not os.path.exists(dirname): 62 | os.makedirs(dirname) 63 | 64 | print('Downloading %s from %s...'%(fname, url)) 65 | r = requests.get(url, stream=True) 66 | if r.status_code != 200: 67 | raise RuntimeError("Failed downloading url %s"%url) 68 | total_length = r.headers.get('content-length') 69 | with open(fname, 'wb') as f: 70 | if total_length is None: # no content length header 71 | for chunk in r.iter_content(chunk_size=1024): 72 | if chunk: # filter out keep-alive new chunks 73 | f.write(chunk) 74 | else: 75 | total_length = int(total_length) 76 | for chunk in tqdm(r.iter_content(chunk_size=1024), 77 | total=int(total_length / 1024. + 0.5), 78 | unit='KB', unit_scale=False, dynamic_ncols=True): 79 | f.write(chunk) 80 | 81 | if sha1_hash and not check_sha1(fname, sha1_hash): 82 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 83 | 'The repo may be outdated or download may be incomplete. 
' \ 84 | 'If the "repo_url" is overridden, consider switching to ' \ 85 | 'the default repo.'.format(fname)) 86 | 87 | return fname -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /segmentron/data/dataloader/transparent11.py: -------------------------------------------------------------------------------- 1 | """Transparent Semantic Segmentation Dataset.""" 2 | import os 3 | import logging 4 | import torch 5 | import numpy as np 6 | 7 | from PIL import Image 8 | from .seg_data_base import SegmentationDataset 9 | 10 | 11 | class TransparentSegmentation(SegmentationDataset): 12 | BASE_DIR = 'Trans10K_cls12' 13 | NUM_CLASS = 12 14 | 15 | def __init__(self, root='datasets/transparent', split='test', mode=None, transform=None, **kwargs): 16 | super(TransparentSegmentation, self).__init__(root, split, mode, transform, **kwargs) 17 | root = os.path.join(self.root, self.BASE_DIR) 18 | assert os.path.exists(root), "Please put the data in {SEG_ROOT}/datasets/transparent" 19 | self.images, self.masks = _get_trans10k_pairs(root, split) 20 | assert (len(self.images) == len(self.masks)) 21 | if len(self.images) == 0: 22 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 23 | logging.info('Found {} images in the folder {}'.format(len(self.images), root)) 24 | 25 | def _mask_transform(self, mask): 26 | return torch.LongTensor(np.array(mask).astype('int32')) 27 | 28 | def _val_sync_transform_resize(self, img, mask): 29 | short_size = self.crop_size 30 | img = img.resize(short_size, Image.BILINEAR) 31 | mask = mask.resize(short_size, Image.NEAREST) 32 | 33 | # final transform 34 | img, mask = self._img_transform(img), self._mask_transform(mask) 35 | return img, mask 36 | 37 | def __getitem__(self, index): 38 | img = Image.open(self.images[index]).convert('RGB') 39 | if self.mode == 'test': 40 | img = self._img_transform(img) 41 | if self.transform is not None: 42 | img = self.transform(img) 43 | return img, os.path.basename(self.images[index]) 44 | mask = Image.open(self.masks[index]).convert("P") 45 | # synchrosized transform 46 | if self.mode == 'train': 47 | img, mask = self._sync_transform(img, mask, resize=True) 48 | elif self.mode == 'val': 49 | img, mask = self._val_sync_transform_resize(img, mask) 50 | else: 51 | assert self.mode == 'testval' 52 | img, mask = self._val_sync_transform_resize(img, mask) 53 | # general resize, normalize and to Tensor 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | return img, mask, 
os.path.basename(self.images[index]) 57 | 58 | def __len__(self): 59 | return len(self.images) 60 | 61 | @property 62 | def pred_offset(self): 63 | return 1 64 | 65 | @property 66 | def classes(self): 67 | """Category names.""" 68 | return ('Background', 'Shelf', 'Jar or Tank', 'Freezer', 'Window', 69 | 'Glass Door', 'Eyeglass', 'Cup', 'Floor Glass', 'Glass Bow', 70 | 'Water Bottle', 'Storage Box') 71 | 72 | 73 | def _get_trans10k_pairs(folder, mode='train'): 74 | img_paths = [] 75 | mask_paths = [] 76 | if mode == 'train': 77 | img_folder = os.path.join(folder, 'train/images') 78 | mask_folder = os.path.join(folder, 'train/masks_12') 79 | elif mode == "val": 80 | img_folder = os.path.join(folder, 'validation/images') 81 | mask_folder = os.path.join(folder, 'validation/masks_12') 82 | else: 83 | assert mode == "test" 84 | img_folder = os.path.join(folder, 'test/images') 85 | mask_folder = os.path.join(folder, 'test/masks_12') 86 | 87 | for filename in os.listdir(img_folder): 88 | basename, _ = os.path.splitext(filename) 89 | if filename.endswith(".jpg"): 90 | imgpath = os.path.join(img_folder, filename) 91 | maskname = basename + '_mask.png' 92 | maskpath = os.path.join(mask_folder, maskname) 93 | if os.path.isfile(maskpath): 94 | img_paths.append(imgpath) 95 | mask_paths.append(maskpath) 96 | else: 97 | logging.info('cannot find the mask:', maskpath) 98 | 99 | return img_paths, mask_paths 100 | 101 | 102 | if __name__ == '__main__': 103 | train_dataset = TransparentSegmentation() 104 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/acdc.py: -------------------------------------------------------------------------------- 1 | """ACDC dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | import glob 7 | 8 | from PIL import Image 9 | from .seg_data_base import SegmentationDataset 10 | import random 11 | 12 | 13 | class ACDCSegmentation(SegmentationDataset): 14 | BASE_DIR = 'acdc' 15 | NUM_CLASS = 19 16 | 17 | def __init__(self, root='datasets/acdc', split='train', mode=None, transform=None, **kwargs): 18 | super(ACDCSegmentation, self).__init__(root, split, mode, transform, **kwargs) 19 | assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/acdc" 20 | self.images, self.mask_paths = _get_acdc_pairs(self.root, self.split) 21 | assert (len(self.images) == len(self.mask_paths)) 22 | if len(self.images) == 0: 23 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 24 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 25 | 23, 24, 25, 26, 27, 28, 31, 32, 33] 26 | self._key = np.array([-1, -1, -1, -1, -1, -1, 27 | -1, -1, 0, 1, -1, -1, 28 | 2, 3, 4, -1, -1, -1, 29 | 5, -1, 6, 7, 8, 9, 30 | 10, 11, 12, 13, 14, 15, 31 | -1, -1, 16, 17, 18]) 32 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') 33 | 34 | def _class_to_index(self, mask): 35 | values = np.unique(mask) 36 | for value in values: 37 | assert (value in self._mapping) 38 | index = np.digitize(mask.ravel(), self._mapping, right=True) 39 | return self._key[index].reshape(mask.shape) 40 | def _val_sync_transform_resize(self, img, mask): 41 | w, h = img.size 42 | x1 = random.randint(0, w - self.crop_size[1]) 43 | y1 = random.randint(0, h - self.crop_size[0]) 44 | img = img.crop((x1, y1, x1 + self.crop_size[1], y1 + self.crop_size[0])) 45 | mask = mask.crop((x1, y1, x1 + self.crop_size[1], y1 + self.crop_size[0])) 46 | 47 | img, mask = self._img_transform(img), 
self._mask_transform(mask) 48 | return img, mask 49 | 50 | def __getitem__(self, index): 51 | img = Image.open(self.images[index]).convert('RGB') 52 | if self.mode == 'test': 53 | if self.transform is not None: 54 | img = self.transform(img) 55 | return img, os.path.basename(self.images[index]) 56 | mask = Image.open(self.mask_paths[index]) 57 | if self.mode == 'train': 58 | img, mask = self._sync_transform(img, mask, resize=True) 59 | elif self.mode == 'val': 60 | img, mask = self._val_sync_transform_resize(img, mask) 61 | else: 62 | assert self.mode == 'testval' 63 | img, mask = self._val_sync_transform_resize(img, mask) 64 | if self.transform is not None: 65 | img = self.transform(img) 66 | return img, mask, os.path.basename(self.images[index]) 67 | 68 | def _mask_transform(self, mask): 69 | target = self._class_to_index(np.array(mask).astype('int32')) 70 | return torch.LongTensor(np.array(target).astype('int32')) 71 | 72 | def __len__(self): 73 | return len(self.images) 74 | 75 | @property 76 | def pred_offset(self): 77 | return 0 78 | 79 | @property 80 | def classes(self): 81 | """Category names.""" 82 | return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 83 | 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 84 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle') 85 | 86 | 87 | def _get_acdc_pairs(folder, split='train'): 88 | img_paths = [] 89 | mask_paths = [] 90 | if split == 'test': 91 | split = 'val' 92 | img_paths_temp = glob.glob(os.path.join(folder, 'rgb_anon/*/{}/*/*_rgb_anon.png'.format(split))) 93 | for imgpath in img_paths_temp: 94 | maskpath = imgpath.replace('/rgb_anon/', '/gt/').replace('rgb_anon.png', 'gt_labelIds.png') 95 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 96 | img_paths.append(imgpath) 97 | mask_paths.append(maskpath) 98 | else: 99 | logging.info('cannot find the mask or image:', imgpath, maskpath) 100 | logging.info('Found {} images in the folder {}'.format(len(img_paths), folder)) 101 | return img_paths, mask_paths 102 | 103 | 104 | if __name__ == '__main__': 105 | dataset = ACDCSegmentation() 106 | -------------------------------------------------------------------------------- /segmentron/models/segbase.py: -------------------------------------------------------------------------------- 1 | """Base Model for Semantic Segmentation""" 2 | import math 3 | import numbers 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .backbones import get_segmentation_backbone 10 | from ..data.dataloader import datasets 11 | from ..modules import get_norm 12 | from ..config import cfg 13 | __all__ = ['SegBaseModel'] 14 | 15 | 16 | class SegBaseModel(nn.Module): 17 | r"""Base Model for Semantic Segmentation 18 | """ 19 | def __init__(self, need_backbone=True): 20 | super(SegBaseModel, self).__init__() 21 | self.nclass = datasets[cfg.DATASET.NAME].NUM_CLASS 22 | self.aux = cfg.SOLVER.AUX 23 | self.norm_layer = get_norm(cfg.MODEL.BN_TYPE) 24 | self.backbone = None 25 | self.encoder = None 26 | if need_backbone: 27 | self.get_backbone() 28 | 29 | def get_backbone(self): 30 | self.backbone = cfg.MODEL.BACKBONE.lower() 31 | self.encoder = get_segmentation_backbone(self.backbone, self.norm_layer) 32 | 33 | def base_forward(self, x): 34 | """forwarding backbone network""" 35 | c1, c2, c3, c4 = self.encoder(x) 36 | return c1, c2, c3, c4 37 | 38 | def demo(self, x): 39 | pred = self.forward(x) 40 | if self.aux: 41 | pred = pred[0] 42 | return pred 43 
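# Note on the multi-scale evaluation in `evaluate` below (a description of the
# existing code, not new behaviour): for each scale in cfg.TEST.SCALES the input
# is resized so that its long side equals base_size * scale, optionally padded up
# to the scaled cfg.TEST.CROP_SIZE, passed through the network (plus a horizontally
# flipped pass when cfg.TEST.FLIP is set), and the logits are resized back to the
# original resolution and summed into `scores`.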
| 44 | def evaluate(self, image): 45 | """evaluating network with inputs and targets""" 46 | scales = cfg.TEST.SCALES 47 | flip = cfg.TEST.FLIP 48 | crop_size = _to_tuple(cfg.TEST.CROP_SIZE) if cfg.TEST.CROP_SIZE else None 49 | batch, _, h, w = image.shape 50 | base_size = max(h, w) 51 | # scores = torch.zeros((batch, self.nclass, h, w)).to(image.device) 52 | scores = None 53 | for scale in scales: 54 | long_size = int(math.ceil(base_size * scale)) 55 | if h > w: 56 | height = long_size 57 | width = int(1.0 * w * long_size / h + 0.5) 58 | else: 59 | width = long_size 60 | height = int(1.0 * h * long_size / w + 0.5) 61 | 62 | # resize image to current size 63 | cur_img = _resize_image(image, height, width) 64 | if crop_size is not None: 65 | assert crop_size[0] >= h and crop_size[1] >= w 66 | crop_size_scaled = (int(math.ceil(crop_size[0] * scale)), 67 | int(math.ceil(crop_size[1] * scale))) 68 | cur_img = _pad_image(cur_img, crop_size_scaled) 69 | outputs = self.forward(cur_img)[0][..., :height, :width] 70 | if flip: 71 | outputs += _flip_image(self.forward(_flip_image(cur_img))[0])[..., :height, :width] 72 | 73 | score = _resize_image(outputs, h, w) 74 | 75 | if scores is None: 76 | scores = score 77 | else: 78 | scores += score 79 | return scores 80 | 81 | 82 | def _resize_image(img, h, w): 83 | return F.interpolate(img, size=[h, w], mode='bilinear', align_corners=True) 84 | 85 | 86 | def _pad_image(img, crop_size): 87 | b, c, h, w = img.shape 88 | assert(c == 3) 89 | padh = crop_size[0] - h if h < crop_size[0] else 0 90 | padw = crop_size[1] - w if w < crop_size[1] else 0 91 | if padh == 0 and padw == 0: 92 | return img 93 | img_pad = F.pad(img, (0, padh, 0, padw)) 94 | 95 | # TODO clean this code 96 | # mean = cfg.DATASET.MEAN 97 | # std = cfg.DATASET.STD 98 | # pad_values = -np.array(mean) / np.array(std) 99 | # img_pad = torch.zeros((b, c, h + padh, w + padw)).to(img.device) 100 | # for i in range(c): 101 | # # print(img[:, i, :, :].unsqueeze(1).shape) 102 | # img_pad[:, i, :, :] = torch.squeeze( 103 | # F.pad(img[:, i, :, :].unsqueeze(1), (0, padh, 0, padw), 104 | # 'constant', value=pad_values[i]), 1) 105 | # assert(img_pad.shape[2] >= crop_size[0] and img_pad.shape[3] >= crop_size[1]) 106 | 107 | return img_pad 108 | 109 | 110 | def _crop_image(img, h0, h1, w0, w1): 111 | return img[:, :, h0:h1, w0:w1] 112 | 113 | 114 | def _flip_image(img): 115 | assert(img.ndim == 4) 116 | return img.flip((3)) 117 | 118 | 119 | def _to_tuple(size): 120 | if isinstance(size, (list, tuple)): 121 | assert len(size), 'Expect eval crop size contains two element, ' \ 122 | 'but received {}'.format(len(size)) 123 | return tuple(size) 124 | elif isinstance(size, numbers.Number): 125 | return tuple((size, size)) 126 | else: 127 | raise ValueError('Unsupport datatype: {}'.format(type(size))) 128 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | 6 | cur_path = os.path.abspath(os.path.dirname(__file__)) 7 | root_path = os.path.split(cur_path)[0] 8 | sys.path.append(root_path) 9 | 10 | import logging 11 | import torch 12 | import torch.nn as nn 13 | import torch.utils.data as data 14 | import torch.nn.functional as F 15 | 16 | from tabulate import tabulate 17 | from torchvision import transforms 18 | from segmentron.data.dataloader import get_segmentation_dataset 19 | from segmentron.models.model_zoo 
import get_segmentation_model 20 | from segmentron.utils.score import SegmentationMetric 21 | from segmentron.utils.distributed import synchronize, make_data_sampler, make_batch_data_sampler 22 | from segmentron.config import cfg 23 | from segmentron.utils.options import parse_args 24 | from segmentron.utils.default_setup import default_setup 25 | 26 | 27 | class Evaluator(object): 28 | def __init__(self, args): 29 | self.args = args 30 | self.device = torch.device(args.device) 31 | 32 | # image transform 33 | input_transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD), 36 | ]) 37 | 38 | # dataset and dataloader 39 | val_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='test', mode='testval', transform=input_transform) 40 | val_sampler = make_data_sampler(val_dataset, False, args.distributed) 41 | val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False) 42 | self.val_loader = data.DataLoader(dataset=val_dataset, 43 | batch_sampler=val_batch_sampler, 44 | num_workers=cfg.DATASET.WORKERS, 45 | pin_memory=True) 46 | self.classes = val_dataset.classes 47 | # create network 48 | self.model = get_segmentation_model().to(self.device) 49 | 50 | if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \ 51 | cfg.MODEL.BN_EPS_FOR_ENCODER: 52 | logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) 53 | self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) 54 | 55 | if args.distributed: 56 | self.model = nn.parallel.DistributedDataParallel(self.model, 57 | device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 58 | self.model.to(self.device) 59 | 60 | self.metric = SegmentationMetric(val_dataset.num_class, args.distributed) 61 | 62 | def set_batch_norm_attr(self, named_modules, attr, value): 63 | for m in named_modules: 64 | if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm): 65 | setattr(m[1], attr, value) 66 | 67 | def eval(self): 68 | self.metric.reset() 69 | self.model.eval() 70 | if self.args.distributed: 71 | model = self.model.module 72 | else: 73 | model = self.model 74 | 75 | logging.info("Start validation, Total sample: {:d}".format(len(self.val_loader))) 76 | import time 77 | time_start = time.time() 78 | for i, (image, target, filename) in enumerate(self.val_loader): 79 | image = image.to(self.device) 80 | target = target.to(self.device) 81 | 82 | with torch.no_grad(): 83 | output = model.evaluate(image) 84 | 85 | self.metric.update(output, target) 86 | pixAcc, mIoU = self.metric.get() 87 | logging.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format( 88 | i + 1, pixAcc * 100, mIoU * 100)) 89 | 90 | synchronize() 91 | pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True) 92 | logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start)) 93 | logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format( 94 | pixAcc * 100, mIoU * 100)) 95 | 96 | headers = ['class id', 'class name', 'iou'] 97 | table = [] 98 | for i, cls_name in enumerate(self.classes): 99 | table.append([cls_name, category_iou[i]]) 100 | logging.info('Category iou: \n {}'.format(tabulate(table, headers, tablefmt='grid', showindex="always", 101 | numalign='center', stralign='center'))) 102 | 103 | 104 | if __name__ == '__main__': 105 | args = parse_args() 106 | 
cfg.update_from_file(args.config_file) 107 | cfg.update_from_list(args.opts) 108 | cfg.PHASE = 'test' 109 | cfg.ROOT_PATH = root_path 110 | cfg.check_and_freeze() 111 | 112 | default_setup(args) 113 | 114 | evaluator = Evaluator(args) 115 | evaluator.eval() 116 | -------------------------------------------------------------------------------- /segmentron/config/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import codecs 7 | import yaml 8 | import six 9 | import time 10 | 11 | from ast import literal_eval 12 | 13 | class SegmentronConfig(dict): 14 | def __init__(self, *args, **kwargs): 15 | super(SegmentronConfig, self).__init__(*args, **kwargs) 16 | self.immutable = False 17 | 18 | def __setattr__(self, key, value, create_if_not_exist=True): 19 | if key in ["immutable"]: 20 | self.__dict__[key] = value 21 | return 22 | 23 | t = self 24 | keylist = key.split(".") 25 | for k in keylist[:-1]: 26 | t = t.__getattr__(k, create_if_not_exist) 27 | 28 | t.__getattr__(keylist[-1], create_if_not_exist) 29 | t[keylist[-1]] = value 30 | 31 | def __getattr__(self, key, create_if_not_exist=True): 32 | if key in ["immutable"]: 33 | if key not in self.__dict__: 34 | self.__dict__[key] = False 35 | return self.__dict__[key] 36 | 37 | if not key in self: 38 | if not create_if_not_exist: 39 | raise KeyError 40 | self[key] = SegmentronConfig() 41 | return self[key] 42 | 43 | def __setitem__(self, key, value): 44 | # 45 | if self.immutable: 46 | raise AttributeError( 47 | 'Attempted to set "{}" to "{}", but SegConfig is immutable'. 48 | format(key, value)) 49 | # 50 | if isinstance(value, six.string_types): 51 | try: 52 | value = literal_eval(value) 53 | except ValueError: 54 | pass 55 | except SyntaxError: 56 | pass 57 | super(SegmentronConfig, self).__setitem__(key, value) 58 | 59 | def update_from_other_cfg(self, other): 60 | if isinstance(other, dict): 61 | other = SegmentronConfig(other) 62 | assert isinstance(other, SegmentronConfig) 63 | cfg_list = [("", other)] 64 | while len(cfg_list): 65 | prefix, tdic = cfg_list[0] 66 | cfg_list = cfg_list[1:] 67 | for key, value in tdic.items(): 68 | key = "{}.{}".format(prefix, key) if prefix else key 69 | if isinstance(value, dict): 70 | cfg_list.append((key, value)) 71 | continue 72 | try: 73 | self.__setattr__(key, value, create_if_not_exist=False) 74 | except KeyError: 75 | raise KeyError('Non-existent config key: {}'.format(key)) 76 | 77 | def remove_irrelevant_cfg(self): 78 | model_name = self.MODEL.MODEL_NAME 79 | 80 | from ..models.model_zoo import MODEL_REGISTRY 81 | model_list = MODEL_REGISTRY.get_list() 82 | model_list_lower = [x.lower() for x in model_list] 83 | 84 | assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\ 85 | .format(model_list, model_name) 86 | pop_keys = [] 87 | for key in self.MODEL.keys(): 88 | if key.lower() in model_list_lower: 89 | if model_name.lower() == 'pointrend' and \ 90 | key.lower() == self.MODEL.POINTREND.BASEMODEL.lower(): 91 | continue 92 | if key.lower() in model_list_lower and key.lower() != model_name.lower(): 93 | if model_name.lower() in ['pvt_trans2seg', 'pvt_fpt']: continue 94 | pop_keys.append(key) 95 | for key in pop_keys: 96 | self.MODEL.pop(key) 97 | 98 | 99 | 100 | def check_and_freeze(self): 101 | self.TIME_STAMP = time.strftime('%Y-%m-%d-%H-%M', 
time.localtime()) 102 | # TODO: remove irrelevant config and then freeze 103 | self.remove_irrelevant_cfg() 104 | self.immutable = True 105 | 106 | def update_from_list(self, config_list): 107 | if len(config_list) % 2 != 0: 108 | raise ValueError( 109 | "Command line options config format error! Please check it: {}". 110 | format(config_list)) 111 | for key, value in zip(config_list[0::2], config_list[1::2]): 112 | try: 113 | self.__setattr__(key, value, create_if_not_exist=False) 114 | except KeyError: 115 | raise KeyError('Non-existent config key: {}'.format(key)) 116 | 117 | def update_from_file(self, config_file): 118 | with codecs.open(config_file, 'r', 'utf-8') as file: 119 | loaded_cfg = yaml.load(file, Loader=yaml.FullLoader) 120 | self.update_from_other_cfg(loaded_cfg) 121 | 122 | def set_immutable(self, immutable): 123 | self.immutable = immutable 124 | for value in self.values(): 125 | if isinstance(value, SegmentronConfig): 126 | value.set_immutable(immutable) 127 | 128 | def is_immutable(self): 129 | return self.immutable -------------------------------------------------------------------------------- /segmentron/models/pvt_fpt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from functools import partial 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | from ..config import cfg 9 | from .backbones.pvt import Attention, Mlp 10 | 11 | 12 | __all__ = ['PVT_FPT'] 13 | 14 | 15 | @MODEL_REGISTRY.register(name='PVT_FPT') 16 | class PVT_FPT(SegBaseModel): 17 | 18 | def __init__(self, ncls=None): 19 | super().__init__() 20 | if self.backbone.startswith('mobilenet'): 21 | c1_channels = 24 22 | c4_channels = 320 23 | elif self.backbone.startswith('resnet18'): 24 | c1_channels = 64 25 | c4_channels = 512 26 | elif self.backbone.startswith('pvt'): 27 | c1_channels = 64 28 | c4_channels = 512 29 | elif self.backbone.startswith('resnet34'): 30 | c1_channels = 64 31 | c4_channels = 512 32 | elif self.backbone.startswith('hrnet_w18_small_v1'): 33 | c1_channels = 16 34 | c4_channels = 128 35 | else: 36 | c1_channels = 256 37 | c4_channels = 2048 38 | 39 | vit_params = cfg.MODEL.TRANS4TRANS 40 | hid_dim = cfg.MODEL.TRANS4TRANS.hid_dim 41 | 42 | assert cfg.AUG.CROP == False and cfg.TRAIN.CROP_SIZE[0] == cfg.TRAIN.CROP_SIZE[1]\ 43 | == cfg.TRAIN.BASE_SIZE == cfg.TEST.CROP_SIZE[0] == cfg.TEST.CROP_SIZE[1] 44 | c4_HxW = (cfg.TRAIN.BASE_SIZE // 32) ** 2 45 | 46 | vit_params['decoder_feat_HxW'] = c4_HxW 47 | vit_params['nclass'] = self.nclass if ncls is None else ncls 48 | vit_params['emb_chans'] = cfg.MODEL.EMB_CHANNELS 49 | 50 | self.fpt_head = FPTHead(vit_params) 51 | if self.aux: 52 | self.auxlayer = _FCNHead(728, self.nclass) 53 | self.__setattr__('decoder', ['fpt_head', 'auxlayer'] if self.aux else ['fpt_head']) 54 | 55 | 56 | def forward(self, x): 57 | size = x.size()[2:] 58 | c1, c2, c3, c4 = self.encoder(x) 59 | 60 | outputs = list() 61 | x = self.fpt_head(c1, c2, c3, c4) 62 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 63 | 64 | outputs.append(x) 65 | if self.aux: 66 | auxout = self.auxlayer(c3) 67 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 68 | outputs.append(auxout) 69 | return tuple(outputs) 70 | 71 | class ProjEmbed(nn.Module): 72 | """ feature map to Projected Embedding 73 | """ 74 | def __init__(self, in_chans=512, emb_chans=128): 75 | super().__init__() 76 | 
self.proj = nn.Linear(in_chans, emb_chans) 77 | self.norm = nn.LayerNorm(emb_chans) 78 | 79 | def forward(self, x): 80 | x = self.proj(x.flatten(2).transpose(1, 2)) 81 | x = self.norm(x) 82 | return x 83 | 84 | class HeadBlock(nn.Module): 85 | def __init__(self, in_chans=512, emb_chans=64, num_heads=2, sr_ratio=4): 86 | super().__init__() 87 | self.proj = ProjEmbed(in_chans=in_chans, emb_chans=emb_chans) 88 | self.norm1 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 89 | self.attn = Attention(emb_chans, num_heads=num_heads, sr_ratio=sr_ratio) 90 | self.drop_path = nn.Identity() 91 | self.norm2 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 92 | mlp_ratio = 2 93 | mlp_hidden_dim = int(emb_chans * mlp_ratio) 94 | self.mlp = Mlp(in_features=emb_chans, hidden_features=mlp_hidden_dim, act_layer=nn.Hardswish) 95 | 96 | def forward(self, x): 97 | B, C, H, W = x.shape 98 | x = self.proj(x) 99 | 100 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 101 | x = x + self.drop_path(self.mlp(self.norm2(x))) 102 | 103 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 104 | return x 105 | 106 | 107 | 108 | class FPTHead(nn.Module): 109 | def __init__(self, vit_params): 110 | super().__init__() 111 | sr_ratio = [4, 4, 4, 1] 112 | emb_chans = vit_params['emb_chans'] 113 | self.head1 = HeadBlock(in_chans=64, emb_chans=emb_chans, sr_ratio=sr_ratio[0]) 114 | self.head2 = HeadBlock(in_chans=128, emb_chans=emb_chans, sr_ratio=sr_ratio[1]) 115 | self.head3 = HeadBlock(in_chans=320, emb_chans=emb_chans, sr_ratio=sr_ratio[2]) 116 | self.head4 = HeadBlock(in_chans=512, emb_chans=emb_chans, sr_ratio=sr_ratio[3]) 117 | 118 | self.pred = nn.Conv2d(emb_chans, vit_params['nclass'], 1) 119 | 120 | 121 | def forward(self, c1, c2, c3, c4): 122 | size = c1.size()[2:] 123 | 124 | c4 = self.head4(c4) 125 | out = F.interpolate(c4, size, mode='bilinear', align_corners=True) 126 | 127 | c3 = self.head3(c3) 128 | out += F.interpolate(c3, size, mode='bilinear', align_corners=True) 129 | 130 | c2 = self.head2(c2) 131 | out += F.interpolate(c2, size, mode='bilinear', align_corners=True) 132 | 133 | out += self.head1(c1) 134 | out = self.pred(out) 135 | return out 136 | 137 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/seg_data_base.py: -------------------------------------------------------------------------------- 1 | """Base segmentation dataset""" 2 | import os 3 | import random 4 | import numpy as np 5 | import torchvision 6 | 7 | from PIL import Image, ImageOps, ImageFilter 8 | from ...config import cfg 9 | from IPython import embed 10 | 11 | __all__ = ['SegmentationDataset'] 12 | 13 | 14 | class SegmentationDataset(object): 15 | """Segmentation Base Dataset""" 16 | 17 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): 18 | super(SegmentationDataset, self).__init__() 19 | self.root = os.path.join(cfg.ROOT_PATH, root) 20 | self.transform = transform 21 | self.split = split 22 | self.mode = mode if mode is not None else split 23 | self.base_size = base_size 24 | self.crop_size = self.to_tuple(crop_size) 25 | self.color_jitter = self._get_color_jitter() 26 | 27 | def to_tuple(self, size): 28 | if isinstance(size, (list, tuple)): 29 | return tuple(size) 30 | elif isinstance(size, (int, float)): 31 | return tuple((size, size)) 32 | else: 33 | raise ValueError('Unsupport datatype: {}'.format(type(size))) 34 | 35 | def _get_color_jitter(self): 36 | color_jitter = cfg.AUG.COLOR_JITTER 37 | if color_jitter is None: 38 | return 
None 39 | if isinstance(color_jitter, (list, tuple)): 40 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation 41 | # or 4 if also augmenting hue 42 | assert len(color_jitter) in (3, 4) 43 | else: 44 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue 45 | color_jitter = (float(color_jitter),) * 3 46 | return torchvision.transforms.ColorJitter(*color_jitter) 47 | 48 | def _val_sync_transform(self, img, mask): 49 | outsize = self.crop_size 50 | short_size = min(outsize) 51 | w, h = img.size 52 | if w > h: 53 | oh = short_size 54 | ow = int(1.0 * w * oh / h) 55 | else: 56 | ow = short_size 57 | oh = int(1.0 * h * ow / w) 58 | img = img.resize((ow, oh), Image.BILINEAR) 59 | mask = mask.resize((ow, oh), Image.NEAREST) 60 | # center crop 61 | w, h = img.size 62 | x1 = int(round((w - outsize[1]) / 2.)) 63 | y1 = int(round((h - outsize[0]) / 2.)) 64 | img = img.crop((x1, y1, x1 + outsize[1], y1 + outsize[0])) 65 | mask = mask.crop((x1, y1, x1 + outsize[1], y1 + outsize[0])) 66 | 67 | # final transform 68 | img, mask = self._img_transform(img), self._mask_transform(mask) 69 | return img, mask 70 | 71 | def _sync_transform(self, img, mask, resize=False): 72 | # first resize image to fix size 73 | if resize: 74 | img = img.resize(self.crop_size, Image.BILINEAR) 75 | mask = mask.resize(self.crop_size, Image.NEAREST) 76 | 77 | # random mirror 78 | if cfg.AUG.MIRROR and random.random() < 0.5: 79 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 80 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 81 | 82 | # random crop 83 | if cfg.AUG.CROP: 84 | crop_size = self.crop_size 85 | # random scale (short edge) 86 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 87 | w, h = img.size 88 | if h > w: 89 | ow = short_size 90 | oh = int(1.0 * h * ow / w) 91 | else: 92 | oh = short_size 93 | ow = int(1.0 * w * oh / h) 94 | img = img.resize((ow, oh), Image.BILINEAR) 95 | mask = mask.resize((ow, oh), Image.NEAREST) 96 | # pad crop 97 | if short_size < min(crop_size): 98 | padh = crop_size[0] - oh if oh < crop_size[0] else 0 99 | padw = crop_size[1] - ow if ow < crop_size[1] else 0 100 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 101 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=-1) 102 | # random crop crop_size 103 | w, h = img.size 104 | x1 = random.randint(0, w - crop_size[1]) 105 | y1 = random.randint(0, h - crop_size[0]) 106 | img = img.crop((x1, y1, x1 + crop_size[1], y1 + crop_size[0])) 107 | mask = mask.crop((x1, y1, x1 + crop_size[1], y1 + crop_size[0])) 108 | 109 | # gaussian blur as in PSP 110 | if cfg.AUG.BLUR_PROB > 0 and random.random() < cfg.AUG.BLUR_PROB: 111 | radius = cfg.AUG.BLUR_RADIUS if cfg.AUG.BLUR_RADIUS > 0 else random.random() 112 | img = img.filter(ImageFilter.GaussianBlur(radius=radius)) 113 | 114 | # color jitter 115 | if self.color_jitter: 116 | img = self.color_jitter(img) 117 | 118 | # final transform 119 | img, mask = self._img_transform(img), self._mask_transform(mask) 120 | 121 | return img, mask 122 | 123 | def _img_transform(self, img): 124 | return np.array(img) 125 | 126 | def _mask_transform(self, mask): 127 | return np.array(mask).astype('int32') 128 | 129 | @property 130 | def num_class(self): 131 | """Number of categories.""" 132 | return self.NUM_CLASS 133 | 134 | @property 135 | def pred_offset(self): 136 | return 0 137 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/stanford2d3d.py: 
-------------------------------------------------------------------------------- 1 | """Stanford2d3d Dataset.""" 2 | import os 3 | import logging 4 | import torch 5 | import numpy as np 6 | import glob 7 | import json 8 | from PIL import Image 9 | from segmentron.data.dataloader.seg_data_base import SegmentationDataset 10 | import torchvision 11 | 12 | __FOLD__ = { 13 | '1_train': ['area_1', 'area_2', 'area_3', 'area_4', 'area_6'], 14 | '1_val': ['area_5a', 'area_5b'], 15 | '2_train': ['area_1', 'area_3', 'area_5a', 'area_5b', 'area_6'], 16 | '2_val': ['area_2', 'area_4'], 17 | '3_train': ['area_2', 'area_4', 'area_5a', 'area_5b'], 18 | '3_val': ['area_1', 'area_3', 'area_6'] 19 | } 20 | 21 | class Stanford2d3dSegmentation(SegmentationDataset): 22 | BASE_DIR = '' 23 | NUM_CLASS = 13 24 | fold = 1 25 | 26 | def __init__(self, root='datasets/stanford2d3d', split='test', mode=None, transform=None, **kwargs): 27 | super(Stanford2d3dSegmentation, self).__init__(root, split, mode, transform, **kwargs) 28 | root = os.path.join(self.root, self.BASE_DIR) 29 | assert os.path.exists(root), "Please put the data in {SEG_ROOT}/datasets/" 30 | self.images, self.masks = _get_stanford2d3d_pairs(root, self.fold, split) 31 | assert (len(self.images) == len(self.masks)) 32 | if len(self.images) == 0: 33 | raise RuntimeError("Found 0 images in {}".format(os.path.join(root, split))) 34 | logging.info('Found {} images in the folder {}'.format(len(self.images), os.path.join(root, split))) 35 | with open('semantic_labels.json') as f: 36 | id2name = [name.split('_')[0] for name in json.load(f)] + [''] 37 | with open('name2label.json') as f: 38 | name2id = json.load(f) 39 | self.colors = np.load('colors.npy') 40 | self.id2label = np.array([name2id[name] for name in id2name], np.uint8) 41 | 42 | def _mask_transform(self, mask): 43 | return torch.LongTensor(np.array(mask).astype('int32')) 44 | 45 | def _val_sync_transform_resize(self, img, mask): 46 | short_size = self.crop_size 47 | img = img.resize(short_size, Image.BICUBIC) 48 | mask = mask.resize(short_size, Image.NEAREST) 49 | 50 | # final transform 51 | img, mask = self._img_transform(img), self._mask_transform(mask) 52 | return img, mask 53 | 54 | def __getitem__(self, index): 55 | img = Image.open(self.images[index]).convert('RGB') 56 | if self.mode == 'test': 57 | img = self._img_transform(img) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | return img, os.path.basename(self.images[index]) 61 | mask = Image.open(self.masks[index]) 62 | mask = _color2id(mask, img, self.id2label) 63 | if self.mode == 'train': 64 | img, mask = self._sync_transform(img, mask, resize=True) 65 | elif self.mode == 'val': 66 | img, mask = self._val_sync_transform_resize(img, mask) 67 | else: 68 | assert self.mode == 'testval' 69 | img, mask = self._val_sync_transform_resize(img, mask) 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | 73 | mask[mask == 255] = -1 74 | return img, mask, os.path.basename(self.images[index]) 75 | 76 | def __len__(self): 77 | return len(self.images) 78 | 79 | @property 80 | def pred_offset(self): 81 | return 1 82 | 83 | @property 84 | def classes(self): 85 | """Category names.""" 86 | return ('beam', 'board', 'bookcase', 'ceiling', 'chair', 87 | 'clutter', 'column', 'door', 'floor', 'sofa', 88 | 'table', 'wall', 'window') 89 | 90 | def _get_stanford2d3d_pairs(folder, fold, mode='train'): 91 | '''image is jpg, label is png''' 92 | img_paths = [] 93 | if mode == 'train': 94 | area_ids = 
__FOLD__['{}_{}'.format(fold, mode)] 95 | elif mode == 'val': 96 | area_ids = __FOLD__['{}_{}'.format(fold, mode)] 97 | else: 98 | raise NotImplementedError 99 | for a in area_ids: 100 | img_paths += glob.glob(os.path.join(folder, '{}/data/rgb/*_rgb.png'.format(a))) 101 | img_paths = sorted(img_paths) 102 | mask_paths = [imgpath.replace('rgb', 'semantic') for imgpath in img_paths] 103 | return img_paths, mask_paths 104 | 105 | def _color2id(mask, img, id2label): 106 | mask = np.array(mask, np.int32) 107 | unk = (mask[..., 0] != 0) 108 | mask = id2label[mask[..., 1] * 256 + mask[..., 2]] 109 | mask[unk] = 0 110 | mask[np.array(img, np.int8).sum(2) == 0] = 0 111 | mask -= 1 112 | return Image.fromarray(mask) 113 | 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | from torchvision import transforms 119 | import torch.utils.data as data 120 | input_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((.485, .456, .406), (.229, .224, .225)),]) 121 | trainset = Stanford2d3dSegmentation(split='train', transform=input_transform) 122 | train_data = data.DataLoader(trainset, 2, shuffle=True, num_workers=0) 123 | for i, data in enumerate(train_data): 124 | imgs, targets, _ = data 125 | print(imgs.shape) 126 | if i == 0: 127 | img = torchvision.utils.make_grid(imgs).numpy() 128 | img = np.transpose(img, (1, 2, 0)) 129 | img = img[:, :, ::-1] 130 | plt.imshow(img) 131 | plt.show() 132 | -------------------------------------------------------------------------------- /segmentron/models/pvt_fpt_joint.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from functools import partial 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _FCNHead 8 | from ..config import cfg 9 | from ..data.dataloader import datasets 10 | from .backbones.pvt import Attention, Mlp 11 | 12 | __all__ = ['PVT_FPT_JOINT'] 13 | 14 | 15 | 16 | @MODEL_REGISTRY.register(name='PVT_FPT_JOINT') 17 | class PVT_FPT_JOINT(SegBaseModel): 18 | """ PVT with 2 joint FPT decoder. 
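Both FPT decoder heads share the single PVT encoder; the forward pass returns
one prediction map per head (plus an optional auxiliary output), presumably so
that the joint_* configs can train one model on two datasets at the same time.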
""" 19 | 20 | def __init__(self): 21 | super().__init__() 22 | if self.backbone.startswith('mobilenet'): 23 | c1_channels = 24 24 | c4_channels = 320 25 | elif self.backbone.startswith('resnet18'): 26 | c1_channels = 64 27 | c4_channels = 512 28 | elif self.backbone.startswith('pvt'): 29 | c1_channels = 64 30 | c4_channels = 512 31 | elif self.backbone.startswith('resnet34'): 32 | c1_channels = 64 33 | c4_channels = 512 34 | elif self.backbone.startswith('hrnet_w18_small_v1'): 35 | c1_channels = 16 36 | c4_channels = 128 37 | else: 38 | c1_channels = 256 39 | c4_channels = 2048 40 | 41 | vit_params = cfg.MODEL.TRANS4TRANS 42 | hid_dim = cfg.MODEL.TRANS4TRANS.hid_dim 43 | 44 | assert cfg.AUG.CROP == False and cfg.TRAIN.CROP_SIZE[0] == cfg.TRAIN.CROP_SIZE[1]\ 45 | == cfg.TRAIN.BASE_SIZE == cfg.TEST.CROP_SIZE[0] == cfg.TEST.CROP_SIZE[1] 46 | 47 | c4_HxW = (cfg.TRAIN.BASE_SIZE // 32) ** 2 48 | 49 | vit_params['decoder_feat_HxW'] = c4_HxW 50 | vit_params['nclass'] = self.nclass 51 | 52 | self.fpt_head_1 = FPTHead(vit_params, c1_channels=c1_channels, c4_channels=c4_channels, hid_dim=hid_dim) 53 | self.fpt_head_2 = FPTHead(vit_params, c1_channels=c1_channels, c4_channels=c4_channels, hid_dim=hid_dim) 54 | decoders = ['fpt_head_1', 'fpt_head_2'] 55 | if self.aux: 56 | self.auxlayer = _FCNHead(728, self.nclass) 57 | decoders.append('auxlayer') 58 | self.__setattr__('decoder', decoders) 59 | 60 | 61 | def forward(self, x): 62 | size = x.size()[2:] 63 | c1, c2, c3, c4 = self.encoder(x) 64 | 65 | outputs = list() 66 | x_1 = self.fpt_head_1(c1, c2, c3, c4) 67 | x_1 = F.interpolate(x_1, size, mode='bilinear', align_corners=True) 68 | outputs.append(x_1) 69 | 70 | x_2 = self.fpt_head_2(c1, c2, c3, c4) 71 | x_2 = F.interpolate(x_2, size, mode='bilinear', align_corners=True) 72 | outputs.append(x_2) 73 | 74 | if self.aux: 75 | auxout = self.auxlayer(c3) 76 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 77 | outputs.append(auxout) 78 | return tuple(outputs) 79 | 80 | class ProjEmbed(nn.Module): 81 | """ feature map to Projected Embedding 82 | """ 83 | def __init__(self, in_chans=512, emb_chans=128): 84 | super().__init__() 85 | self.proj = nn.Linear(in_chans, emb_chans) 86 | self.norm = nn.LayerNorm(emb_chans) 87 | 88 | def forward(self, x): 89 | x = self.proj(x.flatten(2).transpose(1, 2)) 90 | x = self.norm(x) 91 | return x 92 | 93 | class HeadBlock(nn.Module): 94 | def __init__(self, in_chans=512, emb_chans=64, num_heads=2, sr_ratio=4): 95 | super().__init__() 96 | self.proj = ProjEmbed(in_chans=in_chans, emb_chans=emb_chans) 97 | self.norm1 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 98 | self.attn = Attention(emb_chans, num_heads=num_heads, sr_ratio=sr_ratio) 99 | self.drop_path = nn.Identity() 100 | self.norm2 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 101 | mlp_ratio = 2 102 | mlp_hidden_dim = int(emb_chans * mlp_ratio) 103 | self.mlp = Mlp(in_features=emb_chans, hidden_features=mlp_hidden_dim, act_layer=nn.Hardswish) 104 | 105 | def forward(self, x): 106 | B, C, H, W = x.shape 107 | x = self.proj(x) 108 | 109 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 110 | x = x + self.drop_path(self.mlp(self.norm2(x))) 111 | 112 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 113 | return x 114 | 115 | 116 | 117 | class FPTHead(nn.Module): 118 | def __init__(self, vit_params, c1_channels=256, c4_channels=2048, hid_dim=64, norm_layer=nn.BatchNorm2d): 119 | super().__init__() 120 | sr_ratio = [4, 4, 4, 1] 121 | emb_chans = 64 122 | self.head1 = 
HeadBlock(in_chans=64, emb_chans=emb_chans, sr_ratio=sr_ratio[0]) 123 | self.head2 = HeadBlock(in_chans=128, emb_chans=emb_chans, sr_ratio=sr_ratio[1]) 124 | self.head3 = HeadBlock(in_chans=320, emb_chans=emb_chans, sr_ratio=sr_ratio[2]) 125 | self.head4 = HeadBlock(in_chans=512, emb_chans=emb_chans, sr_ratio=sr_ratio[3]) 126 | 127 | 128 | def forward(self, c1, c2, c3, c4): 129 | size = c1.size()[2:] 130 | 131 | c4 = self.head4(c4) 132 | out = F.interpolate(c4, size, mode='bilinear', align_corners=True) 133 | 134 | c3 = self.head3(c3) 135 | out += F.interpolate(c3, size, mode='bilinear', align_corners=True) 136 | 137 | c2 = self.head2(c2) 138 | out += F.interpolate(c2, size, mode='bilinear', align_corners=True) 139 | 140 | out += self.head1(c1) 141 | return out 142 | 143 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | 7 | from PIL import Image 8 | from .seg_data_base import SegmentationDataset 9 | 10 | 11 | class CitySegmentation(SegmentationDataset): 12 | BASE_DIR = 'cityscapes' 13 | NUM_CLASS = 19 14 | 15 | def __init__(self, root='datasets/cityscapes', split='train', mode=None, transform=None, **kwargs): 16 | super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs) 17 | assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/cityscapes" 18 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split) 19 | assert (len(self.images) == len(self.mask_paths)) 20 | if len(self.images) == 0: 21 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 22 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23 | 23, 24, 25, 26, 27, 28, 31, 32, 33] 24 | self._key = np.array([-1, -1, -1, -1, -1, -1, 25 | -1, -1, 0, 1, -1, -1, 26 | 2, 3, 4, -1, -1, -1, 27 | 5, -1, 6, 7, 8, 9, 28 | 10, 11, 12, 13, 14, 15, 29 | -1, -1, 16, 17, 18]) 30 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') 31 | 32 | def _class_to_index(self, mask): 33 | # assert the value 34 | values = np.unique(mask) 35 | for value in values: 36 | assert (value in self._mapping) 37 | index = np.digitize(mask.ravel(), self._mapping, right=True) 38 | return self._key[index].reshape(mask.shape) 39 | 40 | def __getitem__(self, index): 41 | img = Image.open(self.images[index]).convert('RGB') 42 | if self.mode == 'test': 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | return img, os.path.basename(self.images[index]) 46 | mask = Image.open(self.mask_paths[index]) 47 | if self.mode == 'train': 48 | img, mask = self._sync_transform(img, mask) 49 | elif self.mode == 'val': 50 | img, mask = self._val_sync_transform(img, mask) 51 | else: 52 | assert self.mode == 'testval' 53 | img, mask = self._img_transform(img), self._mask_transform(mask) 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | return img, mask, os.path.basename(self.images[index]) 57 | 58 | def _mask_transform(self, mask): 59 | target = self._class_to_index(np.array(mask).astype('int32')) 60 | return torch.LongTensor(np.array(target).astype('int32')) 61 | 62 | def __len__(self): 63 | return len(self.images) 64 | 65 | @property 66 | def pred_offset(self): 67 | return 0 68 | 69 | @property 70 | def classes(self): 71 | """Category names.""" 72 | return ('road', 'sidewalk', 
'building', 'wall', 'fence', 'pole', 'traffic light', 73 | 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 74 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle') 75 | 76 | 77 | def _get_city_pairs(folder, split='train'): 78 | def get_path_pairs(img_folder, mask_folder): 79 | img_paths = [] 80 | mask_paths = [] 81 | for root, _, files in os.walk(img_folder): 82 | for filename in files: 83 | if filename.startswith('._'): 84 | continue 85 | if filename.endswith('.png'): 86 | imgpath = os.path.join(root, filename) 87 | foldername = os.path.basename(os.path.dirname(imgpath)) 88 | maskname = filename.replace('leftImg8bit', 'gtFine_labelIds') 89 | maskpath = os.path.join(mask_folder, foldername, maskname) 90 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 91 | img_paths.append(imgpath) 92 | mask_paths.append(maskpath) 93 | else: 94 | logging.info('cannot find the mask or image:', imgpath, maskpath) 95 | logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 96 | return img_paths, mask_paths 97 | 98 | if split in ('train', 'val'): 99 | img_folder = os.path.join(folder, 'leftImg8bit/' + split) 100 | mask_folder = os.path.join(folder, 'gtFine/' + split) 101 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 102 | return img_paths, mask_paths 103 | else: 104 | assert split == 'trainval' 105 | logging.info('trainval set') 106 | train_img_folder = os.path.join(folder, 'leftImg8bit/train') 107 | train_mask_folder = os.path.join(folder, 'gtFine/train') 108 | val_img_folder = os.path.join(folder, 'leftImg8bit/val') 109 | val_mask_folder = os.path.join(folder, 'gtFine/val') 110 | train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder) 111 | val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder) 112 | img_paths = train_img_paths + val_img_paths 113 | mask_paths = train_mask_paths + val_mask_paths 114 | return img_paths, mask_paths 115 | 116 | 117 | if __name__ == '__main__': 118 | dataset = CitySegmentation() 119 | -------------------------------------------------------------------------------- /segmentron/models/pvtv2_mit_fpt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from functools import partial 4 | from timm.models.layers import trunc_normal_ 5 | import math 6 | from .segbase import SegBaseModel 7 | from .model_zoo import MODEL_REGISTRY 8 | from ..modules import _FCNHead 9 | from ..config import cfg 10 | from .backbones.pvtv2_mix_transformer import Attention, Mlp 11 | 12 | 13 | __all__ = ['PVTV2_MIT_FPT'] 14 | 15 | 16 | @MODEL_REGISTRY.register(name='PVTV2_MIT_FPT') 17 | class PVTV2_MIT_FPT(SegBaseModel): 18 | 19 | def __init__(self): 20 | super().__init__() 21 | if self.backbone.startswith('mobilenet'): 22 | c1_channels = 24 23 | c4_channels = 320 24 | elif self.backbone.startswith('resnet18'): 25 | c1_channels = 64 26 | c4_channels = 512 27 | elif self.backbone.startswith('mit'): 28 | c1_channels = 64 29 | c4_channels = 512 30 | elif self.backbone.startswith('resnet34'): 31 | c1_channels = 64 32 | c4_channels = 512 33 | elif self.backbone.startswith('hrnet_w18_small_v1'): 34 | c1_channels = 16 35 | c4_channels = 128 36 | else: 37 | c1_channels = 256 38 | c4_channels = 2048 39 | 40 | vit_params = cfg.MODEL.TRANS4TRANS 41 | c4_HxW = (cfg.TRAIN.BASE_SIZE // 32) ** 2 42 | 43 | vit_params['decoder_feat_HxW'] = c4_HxW 44 | vit_params['nclass'] = self.nclass 45 | 
vit_params['emb_chans'] = cfg.MODEL.EMB_CHANNELS 46 | 47 | self.fpt_head = FPTHead(vit_params) 48 | if self.aux: 49 | self.auxlayer = _FCNHead(728, self.nclass) 50 | self.__setattr__('decoder', ['fpt_head', 'auxlayer'] if self.aux else ['fpt_head']) 51 | 52 | 53 | def forward(self, x): 54 | size = x.size()[2:] 55 | c1, c2, c3, c4 = self.encoder(x) 56 | 57 | outputs = list() 58 | x = self.fpt_head(c1, c2, c3, c4) 59 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 60 | 61 | outputs.append(x) 62 | if self.aux: 63 | auxout = self.auxlayer(c3) 64 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 65 | outputs.append(auxout) 66 | return tuple(outputs) 67 | 68 | class ProjEmbed(nn.Module): 69 | """ feature map to Projected Embedding 70 | """ 71 | def __init__(self, in_chans=512, emb_chans=128): 72 | super().__init__() 73 | self.proj = nn.Linear(in_chans, emb_chans) 74 | self.norm = nn.LayerNorm(emb_chans) 75 | self.apply(self._init_weights) 76 | 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, nn.Linear): 80 | trunc_normal_(m.weight, std=.02) 81 | if isinstance(m, nn.Linear) and m.bias is not None: 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, nn.LayerNorm): 84 | nn.init.constant_(m.bias, 0) 85 | nn.init.constant_(m.weight, 1.0) 86 | elif isinstance(m, nn.Conv2d): 87 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | fan_out //= m.groups 89 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 90 | if m.bias is not None: 91 | m.bias.data.zero_() 92 | def forward(self, x): 93 | x = self.proj(x.flatten(2).transpose(1, 2)) 94 | x = self.norm(x) 95 | return x 96 | 97 | class HeadBlock(nn.Module): 98 | def __init__(self, in_chans=512, emb_chans=64, num_heads=2, sr_ratio=4): 99 | super().__init__() 100 | self.proj = ProjEmbed(in_chans=in_chans, emb_chans=emb_chans) 101 | self.norm1 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 102 | self.attn = Attention(emb_chans, num_heads=num_heads, sr_ratio=sr_ratio) 103 | self.drop_path = nn.Identity() 104 | self.norm2 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 105 | mlp_ratio = 2 106 | mlp_hidden_dim = int(emb_chans * mlp_ratio) 107 | self.mlp = Mlp(in_features=emb_chans, hidden_features=mlp_hidden_dim, act_layer=nn.Hardswish) 108 | 109 | def forward(self, x): 110 | B, C, H, W = x.shape 111 | x = self.proj(x) 112 | 113 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 114 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 115 | 116 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 117 | return x 118 | 119 | 120 | 121 | class FPTHead(nn.Module): 122 | def __init__(self, vit_params): 123 | super().__init__() 124 | sr_ratio = [4, 4, 4, 1] 125 | emb_chans = vit_params['emb_chans'] 126 | self.head1 = HeadBlock(in_chans=64, emb_chans=emb_chans, sr_ratio=sr_ratio[0]) 127 | self.head2 = HeadBlock(in_chans=128, emb_chans=emb_chans, sr_ratio=sr_ratio[1]) 128 | self.head3 = HeadBlock(in_chans=320, emb_chans=emb_chans, sr_ratio=sr_ratio[2]) 129 | self.head4 = HeadBlock(in_chans=512, emb_chans=emb_chans, sr_ratio=sr_ratio[3]) 130 | 131 | self.pred = nn.Conv2d(emb_chans, vit_params['nclass'], 1) 132 | 133 | 134 | def forward(self, c1, c2, c3, c4): 135 | size = c1.size()[2:] 136 | 137 | c4 = self.head4(c4) 138 | out = F.interpolate(c4, size, mode='bilinear', align_corners=True) 139 | 140 | c3 = self.head3(c3) 141 | out += F.interpolate(c3, size, mode='bilinear', align_corners=True) 142 | 143 | c2 = self.head2(c2) 144 | out += F.interpolate(c2, size, mode='bilinear', 
align_corners=True) 145 | 146 | out += self.head1(c1) 147 | out = self.pred(out) 148 | return out 149 | 150 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/densepass.py: -------------------------------------------------------------------------------- 1 | """Prepare DensePASS dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | 7 | import torchvision 8 | from PIL import Image 9 | from segmentron.data.dataloader.seg_data_base import SegmentationDataset 10 | import random 11 | from torch.utils import data 12 | 13 | 14 | class DensePASSSegmentation(SegmentationDataset): 15 | BASE_DIR = 'DensePASS' 16 | NUM_CLASS = 19 17 | 18 | def __init__(self, root='datasets/DensePASS', split='val', mode=None, transform=None, **kwargs): 19 | super(DensePASSSegmentation, self).__init__(root, split, mode, transform, **kwargs) 20 | assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/DensePASS_train_pseudo_val" 21 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split) 22 | self.crop_size = [400, 2048] 23 | assert (len(self.images) == len(self.mask_paths)) 24 | if len(self.images) == 0: 25 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 26 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 27 | 23, 24, 25, 26, 27, 28, 31, 32, 33] 28 | self._key = np.array([-1, -1, -1, -1, -1, -1, 29 | -1, -1, 0, 1, -1, -1, 30 | 2, 3, 4, -1, -1, -1, 31 | 5, -1, 6, 7, 8, 9, 32 | 10, 11, 12, 13, 14, 15, 33 | -1, -1, 16, 17, 18]) 34 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') 35 | 36 | def _class_to_index(self, mask): 37 | values = np.unique(mask) 38 | for value in values: 39 | assert (value in self._mapping) 40 | index = np.digitize(mask.ravel(), self._mapping, right=True) 41 | return self._key[index].reshape(mask.shape) 42 | def _val_sync_transform_resize(self, img, mask): 43 | w, h = img.size 44 | x1 = random.randint(0, w - self.crop_size[1]) 45 | y1 = random.randint(0, h - self.crop_size[0]) 46 | img = img.crop((x1, y1, x1 + self.crop_size[1], y1 + self.crop_size[0])) 47 | mask = mask.crop((x1, y1, x1 + self.crop_size[1], y1 + self.crop_size[0])) 48 | 49 | img, mask = self._img_transform(img), self._mask_transform(mask) 50 | return img, mask 51 | 52 | def __getitem__(self, index): 53 | img = Image.open(self.images[index]).convert('RGB') 54 | 55 | if self.mode == 'test': 56 | if self.transform is not None: 57 | img = self.transform(img) 58 | return img, os.path.basename(self.images[index]) 59 | mask = Image.open(self.mask_paths[index]) 60 | if self.mode == 'train': 61 | img, mask = self._sync_transform(img, mask, resize=True) 62 | elif self.mode == 'val': 63 | img, mask = self._val_sync_transform_resize(img, mask) 64 | else: 65 | assert self.mode == 'testval' 66 | img, mask = self._val_sync_transform_resize(img, mask) 67 | if self.transform is not None: 68 | img = self.transform(img) 69 | return img, mask, os.path.basename(self.images[index]) 70 | 71 | def _mask_transform(self, mask): 72 | return torch.LongTensor(np.array(mask).astype('int32')) 73 | 74 | def __len__(self): 75 | return len(self.images) 76 | 77 | @property 78 | def pred_offset(self): 79 | return 0 80 | 81 | @property 82 | def classes(self): 83 | """Category names.""" 84 | return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 85 | 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 86 | 'truck', 'bus', 'train', 
'motorcycle', 'bicycle') 87 | 88 | 89 | def _get_city_pairs(folder, split='train'): 90 | def get_path_pairs(img_folder, mask_folder): 91 | img_paths = [] 92 | mask_paths = [] 93 | for root, _, files in os.walk(img_folder): 94 | for filename in files: 95 | if filename.startswith('._'): 96 | continue 97 | if filename.endswith('.png'): 98 | imgpath = os.path.join(root, filename) 99 | foldername = os.path.basename(os.path.dirname(imgpath)) 100 | maskname = filename.replace('_.png', '_labelTrainIds.png') 101 | maskpath = os.path.join(mask_folder, foldername, maskname) 102 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 103 | img_paths.append(imgpath) 104 | mask_paths.append(maskpath) 105 | else: 106 | logging.info('cannot find the mask or image:', imgpath, maskpath) 107 | logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 108 | return img_paths, mask_paths 109 | 110 | if split in ('train', 'val'): 111 | img_folder = os.path.join(folder, 'leftImg8bit/' + split) 112 | mask_folder = os.path.join(folder, 'gtFine/' + split) 113 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 114 | return img_paths, mask_paths 115 | else: 116 | assert split == 'test' 117 | logging.info('test set, but only val set') 118 | val_img_folder = os.path.join(folder, 'leftImg8bit/val') 119 | val_mask_folder = os.path.join(folder, 'gtFine/val') 120 | img_paths, mask_paths = get_path_pairs(val_img_folder, val_mask_folder) 121 | 122 | return img_paths, mask_paths 123 | 124 | 125 | if __name__ == '__main__': 126 | dst = DensePASSSegmentation(split='train', mode='train') 127 | trainloader = data.DataLoader(dst, batch_size=1) 128 | for i, data in enumerate(trainloader): 129 | imgs, labels, *args = data 130 | break -------------------------------------------------------------------------------- /segmentron/modules/sync_bn/syncbn.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Synchronized Cross-GPU Batch Normalization Module""" 12 | import warnings 13 | import torch 14 | 15 | from torch.nn.modules.batchnorm import _BatchNorm 16 | from queue import Queue 17 | from .functions import * 18 | 19 | __all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'] 20 | 21 | 22 | # Adopt from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py 23 | class SyncBatchNorm(_BatchNorm): 24 | """Cross-GPU Synchronized Batch normalization (SyncBN) 25 | 26 | Parameters: 27 | num_features: num_features from an expected input of 28 | size batch_size x num_features x height x width 29 | eps: a value added to the denominator for numerical stability. 30 | Default: 1e-5 31 | momentum: the value used for the running_mean and running_var 32 | computation. Default: 0.1 33 | sync: a boolean value that when set to ``True``, synchronize across 34 | different gpus. Default: ``True`` 35 | activation : str 36 | Name of the activation functions, one of: `leaky_relu` or `none`. 37 | slope : float 38 | Negative slope for the `leaky_relu` activation. 
39 | 40 | Shape: 41 | - Input: :math:`(N, C, H, W)` 42 | - Output: :math:`(N, C, H, W)` (same shape as input) 43 | Reference: 44 | .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015* 45 | .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* 46 | Examples: 47 | >>> m = SyncBatchNorm(100) 48 | >>> net = torch.nn.DataParallel(m) 49 | >>> output = net(input) 50 | """ 51 | 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01, inplace=True): 53 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True) 54 | self.activation = activation 55 | self.inplace = False if activation == 'none' else inplace 56 | self.slope = slope 57 | self.devices = list(range(torch.cuda.device_count())) 58 | self.sync = sync if len(self.devices) > 1 else False 59 | # Initialize queues 60 | self.worker_ids = self.devices[1:] 61 | self.master_queue = Queue(len(self.worker_ids)) 62 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 63 | 64 | def forward(self, x): 65 | # resize the input to (B, C, -1) 66 | input_shape = x.size() 67 | x = x.view(input_shape[0], self.num_features, -1) 68 | if x.get_device() == self.devices[0]: 69 | # Master mode 70 | extra = { 71 | "is_master": True, 72 | "master_queue": self.master_queue, 73 | "worker_queues": self.worker_queues, 74 | "worker_ids": self.worker_ids 75 | } 76 | else: 77 | # Worker mode 78 | extra = { 79 | "is_master": False, 80 | "master_queue": self.master_queue, 81 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 82 | } 83 | if self.inplace: 84 | return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 85 | extra, self.sync, self.training, self.momentum, self.eps, 86 | self.activation, self.slope).view(input_shape) 87 | else: 88 | return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 89 | extra, self.sync, self.training, self.momentum, self.eps, 90 | self.activation, self.slope).view(input_shape) 91 | 92 | def extra_repr(self): 93 | if self.activation == 'none': 94 | return 'sync={}'.format(self.sync) 95 | else: 96 | return 'sync={}, act={}, slope={}, inplace={}'.format( 97 | self.sync, self.activation, self.slope, self.inplace) 98 | 99 | 100 | class BatchNorm1d(SyncBatchNorm): 101 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 102 | 103 | def __init__(self, *args, **kwargs): 104 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 105 | .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning) 106 | super(BatchNorm1d, self).__init__(*args, **kwargs) 107 | 108 | 109 | class BatchNorm2d(SyncBatchNorm): 110 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 111 | 112 | def __init__(self, *args, **kwargs): 113 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 
114 | .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning) 115 | super(BatchNorm2d, self).__init__(*args, **kwargs) 116 | 117 | 118 | class BatchNorm3d(SyncBatchNorm): 119 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 120 | 121 | def __init__(self, *args, **kwargs): 122 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 123 | .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning) 124 | super(BatchNorm3d, self).__init__(*args, **kwargs) 125 | -------------------------------------------------------------------------------- /segmentron/utils/parallel.py: -------------------------------------------------------------------------------- 1 | """Utils for Semantic Segmentation""" 2 | import threading 3 | import torch 4 | import torch.cuda.comm as comm 5 | from torch.nn.parallel.data_parallel import DataParallel 6 | from torch.nn.parallel._functions import Broadcast 7 | from torch.autograd import Function 8 | 9 | __all__ = ['DataParallelModel', 'DataParallelCriterion'] 10 | 11 | 12 | class Reduce(Function): 13 | @staticmethod 14 | def forward(ctx, *inputs): 15 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 16 | inputs = sorted(inputs, key=lambda i: i.get_device()) 17 | return comm.reduce_add(inputs) 18 | 19 | @staticmethod 20 | def backward(ctx, gradOutputs): 21 | return Broadcast.apply(ctx.target_gpus, gradOutputs) 22 | 23 | 24 | class DataParallelModel(DataParallel): 25 | """Data parallelism 26 | 27 | Hide the difference of single/multiple GPUs to the user. 28 | In the forward pass, the module is replicated on each device, 29 | and each replica handles a portion of the input. During the backwards 30 | pass, gradients from each replica are summed into the original module. 31 | 32 | The batch size should be larger than the number of GPUs used. 33 | 34 | Parameters 35 | ---------- 36 | module : object 37 | Network to be parallelized. 38 | sync : bool 39 | enable synchronization (default: False). 40 | Inputs: 41 | - **inputs**: list of input 42 | Outputs: 43 | - **outputs**: list of output 44 | Example:: 45 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 46 | >>> output = net(input_var) # input_var can be on any device, including CPU 47 | """ 48 | 49 | def gather(self, outputs, output_device): 50 | return outputs 51 | 52 | def replicate(self, module, device_ids): 53 | modules = super(DataParallelModel, self).replicate(module, device_ids) 54 | return modules 55 | 56 | 57 | # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py 58 | class DataParallelCriterion(DataParallel): 59 | """ 60 | Calculate loss in multiple-GPUs, which balance the memory usage for 61 | Semantic Segmentation. 62 | 63 | The targets are splitted across the specified devices by chunking in 64 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 
65 | 66 | Example:: 67 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 68 | >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 69 | >>> y = net(x) 70 | >>> loss = criterion(y, target) 71 | """ 72 | 73 | def forward(self, inputs, *targets, **kwargs): 74 | # the inputs should be the outputs of DataParallelModel 75 | if not self.device_ids: 76 | return self.module(inputs, *targets, **kwargs) 77 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 78 | if len(self.device_ids) == 1: 79 | return self.module(inputs, *targets[0], **kwargs[0]) 80 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 81 | outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs) 82 | return Reduce.apply(*outputs) / len(outputs) 83 | 84 | 85 | def get_a_var(obj): 86 | if isinstance(obj, torch.Tensor): 87 | return obj 88 | 89 | if isinstance(obj, list) or isinstance(obj, tuple): 90 | for result in map(get_a_var, obj): 91 | if isinstance(result, torch.Tensor): 92 | return result 93 | 94 | if isinstance(obj, dict): 95 | for result in map(get_a_var, obj.items()): 96 | if isinstance(result, torch.Tensor): 97 | return result 98 | return None 99 | 100 | 101 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 102 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 103 | contained in :attr:`inputs` (positional), attr:'targets' (positional) and :attr:`kwargs_tup` (keyword) 104 | on each of :attr:`devices`. 105 | 106 | Args: 107 | modules (Module): modules to be parallelized 108 | inputs (tensor): inputs to the modules 109 | targets (tensor): targets to the modules 110 | devices (list of int or torch.device): CUDA devices 111 | :attr:`modules`, :attr:`inputs`, :attr:'targets' :attr:`kwargs_tup` (if given), and 112 | :attr:`devices` (if given) should all have same length. Moreover, each 113 | element of :attr:`inputs` can either be a single object as the only argument 114 | to a module, or a collection of positional arguments. 
115 | """ 116 | assert len(modules) == len(inputs) 117 | assert len(targets) == len(inputs) 118 | if kwargs_tup is not None: 119 | assert len(modules) == len(kwargs_tup) 120 | else: 121 | kwargs_tup = ({},) * len(modules) 122 | if devices is not None: 123 | assert len(modules) == len(devices) 124 | else: 125 | devices = [None] * len(modules) 126 | lock = threading.Lock() 127 | results = {} 128 | grad_enabled = torch.is_grad_enabled() 129 | 130 | def _worker(i, module, input, target, kwargs, device=None): 131 | torch.set_grad_enabled(grad_enabled) 132 | if device is None: 133 | device = get_a_var(input).get_device() 134 | try: 135 | with torch.cuda.device(device): 136 | output = module(*(list(input) + target), **kwargs) 137 | with lock: 138 | results[i] = output 139 | except Exception as e: 140 | with lock: 141 | results[i] = e 142 | 143 | if len(modules) > 1: 144 | threads = [threading.Thread(target=_worker, 145 | args=(i, module, input, target, kwargs, device)) 146 | for i, (module, input, target, kwargs, device) in 147 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 148 | 149 | for thread in threads: 150 | thread.start() 151 | for thread in threads: 152 | thread.join() 153 | else: 154 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 155 | 156 | outputs = [] 157 | for i in range(len(inputs)): 158 | output = results[i] 159 | if isinstance(output, Exception): 160 | raise output 161 | outputs.append(output) 162 | return outputs 163 | -------------------------------------------------------------------------------- /segmentron/models/pvt2_mit_fpt_joint.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from functools import partial 4 | from timm.models.layers import trunc_normal_ 5 | import math 6 | from .segbase import SegBaseModel 7 | from .model_zoo import MODEL_REGISTRY 8 | from ..modules import _FCNHead 9 | from ..config import cfg 10 | from ..data.dataloader import datasets 11 | from .backbones.pvtv2_mix_transformer import Attention, Mlp 12 | 13 | 14 | __all__ = ['PVTV2_FPT_JOINT'] 15 | 16 | 17 | 18 | @MODEL_REGISTRY.register(name='PVTV2_FPT_JOINT') 19 | class PVTV2_FPT_JOINT(SegBaseModel): 20 | """ PVTv2 with 2 joint FPT decoder. 
""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | if self.backbone.startswith('mobilenet'): 25 | c1_channels = 24 26 | c4_channels = 320 27 | elif self.backbone.startswith('resnet18'): 28 | c1_channels = 64 29 | c4_channels = 512 30 | elif self.backbone.startswith('pvt'): 31 | c1_channels = 64 32 | c4_channels = 512 33 | elif self.backbone.startswith('resnet34'): 34 | c1_channels = 64 35 | c4_channels = 512 36 | elif self.backbone.startswith('hrnet_w18_small_v1'): 37 | c1_channels = 16 38 | c4_channels = 128 39 | else: 40 | c1_channels = 256 41 | c4_channels = 2048 42 | 43 | vit_params = cfg.MODEL.TRANS4TRANS 44 | hid_dim = cfg.MODEL.TRANS4TRANS.hid_dim 45 | 46 | assert cfg.TRAIN.CROP_SIZE[0] == cfg.TRAIN.CROP_SIZE[1]\ 47 | == cfg.TRAIN.BASE_SIZE == cfg.TEST.CROP_SIZE[0] == cfg.TEST.CROP_SIZE[1] 48 | c4_HxW = (cfg.TRAIN.BASE_SIZE // 32) ** 2 49 | 50 | vit_params['decoder_feat_HxW'] = c4_HxW 51 | vit_params['emb_chans'] = cfg.MODEL.EMB_CHANNELS 52 | 53 | vit_params['nclass'] = self.nclass 54 | self.fpt_head_1 = FPTHead(vit_params, c1_channels=c1_channels, c4_channels=c4_channels, hid_dim=hid_dim) 55 | vit_params['nclass'] = datasets[cfg.DATASET2.NAME].NUM_CLASS 56 | self.fpt_head_2 = FPTHead(vit_params, c1_channels=c1_channels, c4_channels=c4_channels, hid_dim=hid_dim) 57 | decoders = ['fpt_head_1', 'fpt_head_2'] 58 | if self.aux: 59 | self.auxlayer = _FCNHead(728, self.nclass) 60 | decoders.append('auxlayer') 61 | self.__setattr__('decoder', decoders) 62 | 63 | 64 | def forward(self, x): 65 | size = x.size()[2:] 66 | c1, c2, c3, c4 = self.encoder(x) 67 | 68 | outputs = list() 69 | x_1 = self.fpt_head_1(c1, c2, c3, c4) 70 | x_1 = F.interpolate(x_1, size, mode='bilinear', align_corners=True) 71 | outputs.append(x_1) 72 | 73 | x_2 = self.fpt_head_2(c1, c2, c3, c4) 74 | x_2 = F.interpolate(x_2, size, mode='bilinear', align_corners=True) 75 | outputs.append(x_2) 76 | 77 | if self.aux: 78 | auxout = self.auxlayer(c3) 79 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 80 | outputs.append(auxout) 81 | return tuple(outputs) 82 | 83 | class ProjEmbed(nn.Module): 84 | """ feature map to Projected Embedding 85 | """ 86 | def __init__(self, in_chans=512, emb_chans=128): 87 | super().__init__() 88 | self.proj = nn.Linear(in_chans, emb_chans) 89 | self.norm = nn.LayerNorm(emb_chans) 90 | def _init_weights(self, m): 91 | if isinstance(m, nn.Linear): 92 | trunc_normal_(m.weight, std=.02) 93 | if isinstance(m, nn.Linear) and m.bias is not None: 94 | nn.init.constant_(m.bias, 0) 95 | elif isinstance(m, nn.LayerNorm): 96 | nn.init.constant_(m.bias, 0) 97 | nn.init.constant_(m.weight, 1.0) 98 | elif isinstance(m, nn.Conv2d): 99 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | fan_out //= m.groups 101 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 102 | if m.bias is not None: 103 | m.bias.data.zero_() 104 | def forward(self, x): 105 | x = self.proj(x.flatten(2).transpose(1, 2)) 106 | # x = self.act1(self.bn1(self.fc1(x))).flatten(2).transpose(1, 2) 107 | x = self.norm(x) 108 | return x 109 | 110 | class HeadBlock(nn.Module): 111 | def __init__(self, in_chans=512, emb_chans=64, num_heads=2, sr_ratio=4): 112 | super().__init__() 113 | self.proj = ProjEmbed(in_chans=in_chans, emb_chans=emb_chans) 114 | self.norm1 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 115 | self.attn = Attention(emb_chans, num_heads=num_heads, sr_ratio=sr_ratio) 116 | self.drop_path = nn.Identity() 117 | self.norm2 = partial(nn.LayerNorm, eps=1e-6)(emb_chans) 118 | 
mlp_ratio = 2 119 | mlp_hidden_dim = int(emb_chans * mlp_ratio) 120 | self.mlp = Mlp(in_features=emb_chans, hidden_features=mlp_hidden_dim, act_layer=nn.Hardswish) 121 | 122 | def forward(self, x): 123 | B, C, H, W = x.shape 124 | x = self.proj(x) 125 | 126 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 127 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 128 | 129 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 130 | return x 131 | 132 | 133 | 134 | class FPTHead(nn.Module): 135 | def __init__(self, vit_params, c1_channels=256, c4_channels=2048, hid_dim=64, norm_layer=nn.BatchNorm2d): 136 | super().__init__() 137 | sr_ratio = [4, 4, 4, 1] 138 | emb_chans = vit_params['emb_chans'] 139 | self.head1 = HeadBlock(in_chans=64, emb_chans=emb_chans, sr_ratio=sr_ratio[0]) 140 | self.head2 = HeadBlock(in_chans=128, emb_chans=emb_chans, sr_ratio=sr_ratio[1]) 141 | self.head3 = HeadBlock(in_chans=320, emb_chans=emb_chans, sr_ratio=sr_ratio[2]) 142 | self.head4 = HeadBlock(in_chans=512, emb_chans=emb_chans, sr_ratio=sr_ratio[3]) 143 | 144 | self.pred = nn.Conv2d(emb_chans, vit_params['nclass'], 1) 145 | 146 | 147 | def forward(self, c1, c2, c3, c4): 148 | size = c1.size()[2:] 149 | 150 | c4 = self.head4(c4) 151 | out = F.interpolate(c4, size, mode='bilinear', align_corners=True) 152 | 153 | c3 = self.head3(c3) 154 | out += F.interpolate(c3, size, mode='bilinear', align_corners=True) 155 | 156 | c2 = self.head2(c2) 157 | out += F.interpolate(c2, size, mode='bilinear', align_corners=True) 158 | 159 | out += self.head1(c1) 160 | out = self.pred(out) 161 | return out 162 | 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Trans4Trans 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/trans4trans-efficient-transformer-for/semantic-segmentation-on-trans10k)](https://paperswithcode.com/sota/semantic-segmentation-on-trans10k?p=trans4trans-efficient-transformer-for) 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/trans4trans-efficient-transformer-for-1/semantic-segmentation-on-dada-seg)](https://paperswithcode.com/sota/semantic-segmentation-on-dada-seg?p=trans4trans-efficient-transformer-for-1) 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/trans4trans-efficient-transformer-for/semantic-segmentation-on-eventscape)](https://paperswithcode.com/sota/semantic-segmentation-on-eventscape?p=trans4trans-efficient-transformer-for) 8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/trans4trans-efficient-transformer-for-1/semantic-segmentation-on-cityscapes-val)](https://paperswithcode.com/sota/semantic-segmentation-on-cityscapes-val?p=trans4trans-efficient-transformer-for-1) 10 | 11 | ## Introduction 12 | ![trans4trans](trans4trans_fig_1.jpg) 13 | 14 | We build upon a portable system based on the proposed Trans4Trans model, aiming to assist the people with visual impairment to correctly interact with general and transparent objects in daily living. 15 | 16 | Please refer to our conference paper. 17 | 18 | **Trans4Trans: Efficient Transformer for Transparent Object Segmentation to Help Visually Impaired People Navigate in the Real World**, ICCVW 2021, [[paper](https://arxiv.org/pdf/2107.03172.pdf)]. 
19 | 20 | For more details and the driving scene segmentation on benchmarks including Cityscapes, ACDC, and DADAseg, please refer to the journal version. 21 | 22 | **Trans4Trans: Efficient Transformer for Transparent Object and Semantic Scene Segmentation in Real-World Navigation Assistance**, T-ITS 2021, [[paper](https://arxiv.org/pdf/2108.09174.pdf)]. 23 | 24 | 25 | ## Installation 26 | 27 | Create the environment: 28 | 29 | ```bash 30 | conda create -n trans4trans python=3.7 31 | conda activate trans4trans 32 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 33 | conda install pyyaml pillow requests tqdm ipython scipy opencv-python thop tabulate 34 | ``` 35 | 36 | Then install this package: 37 | 38 | ```bash 39 | python setup.py develop --user 40 | ``` 41 | 42 | ## Datasets 43 | Create a `datasets` directory and prepare the [ACDC](https://acdc.vision.ee.ethz.ch/), [Cityscapes](https://www.cityscapes-dataset.com/), [DensePASS](https://github.com/chma1024/DensePASS#dataset), [Stanford2D3D](http://buildingparser.stanford.edu/dataset.html), and [Trans10K](https://github.com/xieenze/Trans2Seg#data-preparation) datasets with the structure below: 44 | 45 | ```text 46 | ./datasets 47 | ├── acdc 48 | │ ├── gt 49 | │ └── rgb_anon 50 | ├── cityscapes 51 | │ ├── gtFine 52 | │ └── leftImg8bit 53 | ├── DensePASS 54 | │ ├── gtFine 55 | │ └── leftImg8bit 56 | ├── stanford2d3d 57 | │ ├── area_1 58 | │ ├── area_2 59 | │ ├── area_3 60 | │ ├── area_4 61 | │ ├── area_5a 62 | │ ├── area_5b 63 | │ └── area_6 64 | └── transparent/Trans10K_cls12 65 | │ ├── test 66 | │ ├── train 67 | │ └── validation 68 | ``` 69 | 70 | Create a `pretrained` directory and prepare the [pretrained models](https://github.com/whai362/PVT#image-classification) as: 71 | 72 | ```text 73 | . 74 | ├── pvt_medium.pth 75 | ├── pvt_small.pth 76 | ├── pvt_tiny.pth 77 | └── v2 78 | ├── pvt_medium.pth 79 | ├── pvt_small.pth 80 | └── pvt_tiny.pth 81 | ``` 82 | 83 | ## Training 84 | 85 | Before training, please modify the [config](./configs) file to match your own paths. 86 | 87 | We train our models on 4 1080Ti GPUs, for example: 88 | 89 | ```bash 90 | python -m torch.distributed.launch --nproc_per_node=4 tools/train.py --config-file configs/trans10kv2/pvt_tiny_FPT.yaml 91 | ``` 92 | 93 | We recommend using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) framework to train models at higher resolutions (e.g. 768x768). 94 | 95 | ## Evaluation 96 | 97 | Before testing a model, please change `TEST_MODEL_PATH` in the config file. 98 | 99 | ```bash 100 | python -m torch.distributed.launch --nproc_per_node=4 tools/eval.py --config-file configs/trans10kv2/pvt_tiny_FPT.yaml 101 | ``` 102 | The weights can be downloaded from [BaiduDrive](https://pan.baidu.com/s/1N2VbNRwrqsQELMw7A6_xew?pwd=mq2v). 103 | 104 | 105 | 106 | ## Assistive system demo 107 | 108 | We use an Intel RealSense R200 camera on our assistive system, so the librealsense library is required. 109 | Please install the [**librealsense legacy**](https://github.com/IntelRealSense/librealsense/tree/legacy) version. 110 | Then install the dependencies for the demo: 111 | 112 | ```bash 113 | pip install pyrealsense==2.2 114 | pip install pyttsx3 115 | pip install pydub 116 | ``` 117 | 118 | Download the pretrained weights from [GoogleDrive](https://drive.google.com/drive/folders/1_b1oSheDtniegqirWNPj7xb9VcaU-25r?usp=sharing) and save them as `workdirs/cocostuff/model_cocostuff.pth` and `workdirs/trans10kv2/model_trans.pth`, as sketched below.
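For reference, a minimal sketch of the expected `workdirs` layout once both demo weights are in place (file names taken from the paths above):

```text
./workdirs
├── cocostuff
│   └── model_cocostuff.pth
└── trans10kv2
    └── model_trans.pth
```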
119 | 120 | (Optional) Please check [`demo.py`](./demo.py) to customize the configuration, for example the speech volume and frequency. 121 | 122 | After installation and downloading the weights, run `bash demo.sh`. 123 | 124 | 125 | ## References 126 | * [Segmentron](https://github.com/LikeLy-Journey/SegmenTron) 127 | * [Trans2Seg](https://github.com/xieenze/Trans2Seg) 128 | * [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) 129 | 130 | ## License 131 | 132 | This repository is under the Apache-2.0 license. For commercial use, please contact the authors. 133 | 134 | 135 | ## Citations 136 | 137 | If you are interested in this work, please cite the following works: 138 | 139 | ```text 140 | @inproceedings{zhang2021trans4trans, 141 | title={Trans4Trans: Efficient transformer for transparent object segmentation to help visually impaired people navigate in the real world}, 142 | author={Zhang, Jiaming and Yang, Kailun and Constantinescu, Angela and Peng, Kunyu and M{\"u}ller, Karin and Stiefelhagen, Rainer}, 143 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 144 | pages={1760--1770}, 145 | year={2021} 146 | } 147 | ``` 148 | 149 | ```text 150 | @article{zhang2022trans4trans, 151 | title={Trans4Trans: Efficient transformer for transparent object and semantic scene segmentation in real-world navigation assistance}, 152 | author={Zhang, Jiaming and Yang, Kailun and Constantinescu, Angela and Peng, Kunyu and M{\"u}ller, Karin and Stiefelhagen, Rainer}, 153 | journal={IEEE Transactions on Intelligent Transportation Systems}, 154 | year={2022}, 155 | publisher={IEEE} 156 | } 157 | ``` -------------------------------------------------------------------------------- /segmentron/models/backbones/mobilenet.py: -------------------------------------------------------------------------------- 1 | """MobileNet and MobileNetV2.""" 2 | import torch.nn as nn 3 | 4 | from .build import BACKBONE_REGISTRY 5 | from ...modules import _ConvBNReLU, _DepthwiseConv, InvertedResidual 6 | from ...config import cfg 7 | 8 | __all__ = ['MobileNet', 'MobileNetV2'] 9 | 10 | 11 | class MobileNet(nn.Module): 12 | def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d): 13 | super(MobileNet, self).__init__() 14 | multiplier = cfg.MODEL.BACKBONE_SCALE 15 | conv_dw_setting = [ 16 | [64, 1, 1], 17 | [128, 2, 2], 18 | [256, 2, 2], 19 | [512, 6, 2], 20 | [1024, 2, 2]] 21 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32 22 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, norm_layer=norm_layer)] 23 | 24 | for c, n, s in conv_dw_setting: 25 | out_channels = int(c * multiplier) 26 | for i in range(n): 27 | stride = s if i == 0 else 1 28 | features.append(_DepthwiseConv(input_channels, out_channels, stride, norm_layer)) 29 | input_channels = out_channels 30 | self.last_inp_channels = int(1024 * multiplier) 31 | features.append(nn.AdaptiveAvgPool2d(1)) 32 | self.features = nn.Sequential(*features) 33 | 34 | self.classifier = nn.Linear(int(1024 * multiplier), num_classes) 35 | 36 | # weight initialization 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 40 | if m.bias is not None: 41 | nn.init.zeros_(m.bias) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | nn.init.ones_(m.weight) 44 | nn.init.zeros_(m.bias) 45 | elif isinstance(m, nn.Linear): 46 | nn.init.normal_(m.weight, 0, 0.01) 47 | nn.init.zeros_(m.bias) 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = 
self.classifier(x.view(x.size(0), x.size(1))) 52 | return x 53 | 54 | 55 | class MobileNetV2(nn.Module): 56 | def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d): 57 | super(MobileNetV2, self).__init__() 58 | output_stride = cfg.MODEL.OUTPUT_STRIDE 59 | self.multiplier = cfg.MODEL.BACKBONE_SCALE 60 | if output_stride == 32: 61 | dilations = [1, 1] 62 | elif output_stride == 16: 63 | dilations = [1, 2] 64 | elif output_stride == 8: 65 | dilations = [2, 4] 66 | else: 67 | raise NotImplementedError 68 | inverted_residual_setting = [ 69 | # t, c, n, s 70 | [1, 16, 1, 1], 71 | [6, 24, 2, 2], 72 | [6, 32, 3, 2], 73 | [6, 64, 4, 2], 74 | [6, 96, 3, 1], 75 | [6, 160, 3, 2], 76 | [6, 320, 1, 1]] 77 | # building first layer 78 | input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32 79 | # last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280 80 | self.conv1 = _ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer) 81 | 82 | # building inverted residual blocks 83 | self.planes = input_channels 84 | self.block1 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[0:1], 85 | norm_layer=norm_layer) 86 | self.block2 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[1:2], 87 | norm_layer=norm_layer) 88 | self.block3 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[2:3], 89 | norm_layer=norm_layer) 90 | self.block4 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[3:5], 91 | dilations[0], norm_layer=norm_layer) 92 | self.block5 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[5:], 93 | dilations[1], norm_layer=norm_layer) 94 | self.last_inp_channels = self.planes 95 | 96 | # building last several layers 97 | # features = list() 98 | # features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer)) 99 | # features.append(nn.AdaptiveAvgPool2d(1)) 100 | # self.features = nn.Sequential(*features) 101 | # 102 | # self.classifier = nn.Sequential( 103 | # nn.Dropout2d(0.2), 104 | # nn.Linear(last_channels, num_classes)) 105 | 106 | # weight initialization 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 110 | if m.bias is not None: 111 | nn.init.zeros_(m.bias) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.ones_(m.weight) 114 | nn.init.zeros_(m.bias) 115 | elif isinstance(m, nn.Linear): 116 | nn.init.normal_(m.weight, 0, 0.01) 117 | if m.bias is not None: 118 | nn.init.zeros_(m.bias) 119 | 120 | def _make_layer(self, block, planes, inverted_residual_setting, dilation=1, norm_layer=nn.BatchNorm2d): 121 | features = list() 122 | for t, c, n, s in inverted_residual_setting: 123 | out_channels = int(c * self.multiplier) 124 | stride = s if dilation == 1 else 1 125 | features.append(block(planes, out_channels, stride, t, dilation, norm_layer)) 126 | planes = out_channels 127 | for i in range(n - 1): 128 | features.append(block(planes, out_channels, 1, t, norm_layer=norm_layer)) 129 | planes = out_channels 130 | self.planes = planes 131 | return nn.Sequential(*features) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.block1(x) 136 | c1 = self.block2(x) 137 | c2 = self.block3(c1) 138 | c3 = self.block4(c2) 139 | c4 = self.block5(c3) 140 | 141 | # x = self.features(x) 142 | # x = self.classifier(x.view(x.size(0), x.size(1))) 143 | return c1, c2, c3, c4 144 | 145 | 146 | 
@BACKBONE_REGISTRY.register() 147 | def mobilenet_v1(norm_layer=nn.BatchNorm2d): 148 | return MobileNet(norm_layer=norm_layer) 149 | 150 | 151 | @BACKBONE_REGISTRY.register() 152 | def mobilenet_v2(norm_layer=nn.BatchNorm2d): 153 | return MobileNetV2(norm_layer=norm_layer) 154 | 155 | -------------------------------------------------------------------------------- /segmentron/models/backbones/eespnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ...modules import _ConvBNPReLU, _ConvBN, _BNPReLU, EESP 7 | from .build import BACKBONE_REGISTRY 8 | from ...config import cfg 9 | 10 | __all__ = ['EESPNet', 'eespnet'] 11 | 12 | 13 | class DownSampler(nn.Module): 14 | 15 | def __init__(self, in_channels, out_channels, k=4, r_lim=9, reinf=True, inp_reinf=3, norm_layer=None): 16 | super(DownSampler, self).__init__() 17 | channels_diff = out_channels - in_channels 18 | self.eesp = EESP(in_channels, channels_diff, stride=2, k=k, 19 | r_lim=r_lim, down_method='avg', norm_layer=norm_layer) 20 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2) 21 | if reinf: 22 | self.inp_reinf = nn.Sequential( 23 | _ConvBNPReLU(inp_reinf, inp_reinf, 3, 1, 1), 24 | _ConvBN(inp_reinf, out_channels, 1, 1)) 25 | self.act = nn.PReLU(out_channels) 26 | 27 | def forward(self, x, x2=None): 28 | avg_out = self.avg(x) 29 | eesp_out = self.eesp(x) 30 | output = torch.cat([avg_out, eesp_out], 1) 31 | if x2 is not None: 32 | w1 = avg_out.size(2) 33 | while True: 34 | x2 = F.avg_pool2d(x2, kernel_size=3, padding=1, stride=2) 35 | w2 = x2.size(2) 36 | if w2 == w1: 37 | break 38 | output = output + self.inp_reinf(x2) 39 | 40 | return self.act(output) 41 | 42 | 43 | class EESPNet(nn.Module): 44 | def __init__(self, num_classes=1000, scale=1, reinf=True, norm_layer=nn.BatchNorm2d): 45 | super(EESPNet, self).__init__() 46 | inp_reinf = 3 if reinf else None 47 | reps = [0, 3, 7, 3] 48 | r_lim = [13, 11, 9, 7, 5] 49 | K = [4] * len(r_lim) 50 | 51 | # set out_channels 52 | base, levels, base_s = 32, 5, 0 53 | out_channels = [base] * levels 54 | for i in range(levels): 55 | if i == 0: 56 | base_s = int(base * scale) 57 | base_s = math.ceil(base_s / K[0]) * K[0] 58 | out_channels[i] = base if base_s > base else base_s 59 | else: 60 | out_channels[i] = base_s * pow(2, i) 61 | if scale <= 1.5: 62 | out_channels.append(1024) 63 | elif scale in [1.5, 2]: 64 | out_channels.append(1280) 65 | else: 66 | raise ValueError("Unknown scale value.") 67 | 68 | self.level1 = _ConvBNPReLU(3, out_channels[0], 3, 2, 1, norm_layer=norm_layer) 69 | 70 | self.level2_0 = DownSampler(out_channels[0], out_channels[1], k=K[0], r_lim=r_lim[0], 71 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 72 | 73 | self.level3_0 = DownSampler(out_channels[1], out_channels[2], k=K[1], r_lim=r_lim[1], 74 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 75 | self.level3 = nn.ModuleList() 76 | for i in range(reps[1]): 77 | self.level3.append(EESP(out_channels[2], out_channels[2], k=K[2], r_lim=r_lim[2], 78 | norm_layer=norm_layer)) 79 | 80 | self.level4_0 = DownSampler(out_channels[2], out_channels[3], k=K[2], r_lim=r_lim[2], 81 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 82 | self.level4 = nn.ModuleList() 83 | for i in range(reps[2]): 84 | self.level4.append(EESP(out_channels[3], out_channels[3], k=K[3], r_lim=r_lim[3], 85 | norm_layer=norm_layer)) 86 | 87 | self.level5_0 = 
DownSampler(out_channels[3], out_channels[4], k=K[3], r_lim=r_lim[3], 88 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 89 | self.level5 = nn.ModuleList() 90 | for i in range(reps[2]): 91 | self.level5.append(EESP(out_channels[4], out_channels[4], k=K[4], r_lim=r_lim[4], 92 | norm_layer=norm_layer)) 93 | 94 | self.level5.append(_ConvBNPReLU(out_channels[4], out_channels[4], 3, 1, 1, 95 | groups=out_channels[4], norm_layer=norm_layer)) 96 | self.level5.append(_ConvBNPReLU(out_channels[4], out_channels[5], 1, 1, 0, 97 | groups=K[4], norm_layer=norm_layer)) 98 | 99 | self.fc = nn.Linear(out_channels[5], num_classes) 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 104 | if m.bias is not None: 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.BatchNorm2d): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, nn.Linear): 110 | nn.init.normal_(m.weight, std=0.001) 111 | if m.bias is not None: 112 | nn.init.constant_(m.bias, 0) 113 | 114 | def forward(self, x, seg=True): 115 | out_l1 = self.level1(x) 116 | 117 | out_l2 = self.level2_0(out_l1, x) 118 | 119 | out_l3_0 = self.level3_0(out_l2, x) 120 | for i, layer in enumerate(self.level3): 121 | if i == 0: 122 | out_l3 = layer(out_l3_0) 123 | else: 124 | out_l3 = layer(out_l3) 125 | 126 | out_l4_0 = self.level4_0(out_l3, x) 127 | for i, layer in enumerate(self.level4): 128 | if i == 0: 129 | out_l4 = layer(out_l4_0) 130 | else: 131 | out_l4 = layer(out_l4) 132 | 133 | if not seg: 134 | out_l5_0 = self.level5_0(out_l4) # down-sampled 135 | for i, layer in enumerate(self.level5): 136 | if i == 0: 137 | out_l5 = layer(out_l5_0) 138 | else: 139 | out_l5 = layer(out_l5) 140 | 141 | output_g = F.adaptive_avg_pool2d(out_l5, output_size=1) 142 | output_g = F.dropout(output_g, p=0.2, training=self.training) 143 | output_1x1 = output_g.view(output_g.size(0), -1) 144 | 145 | return self.fc(output_1x1) 146 | return out_l1, out_l2, out_l3, out_l4 147 | 148 | 149 | @BACKBONE_REGISTRY.register() 150 | def eespnet(norm_layer=nn.BatchNorm2d): 151 | return EESPNet(norm_layer=norm_layer) 152 | 153 | # def eespnet(pretrained=False, **kwargs): 154 | # model = EESPNet(**kwargs) 155 | # if pretrained: 156 | # raise ValueError("Don't support pretrained") 157 | # return model 158 | 159 | 160 | if __name__ == '__main__': 161 | img = torch.randn(1, 3, 224, 224) 162 | model = eespnet() 163 | out = model(img) 164 | -------------------------------------------------------------------------------- /segmentron/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # this code heavily reference: detectron2 2 | from __future__ import division 3 | import math 4 | import torch 5 | 6 | from typing import List 7 | from bisect import bisect_right 8 | from segmentron.config import cfg 9 | 10 | __all__ = ['get_scheduler'] 11 | 12 | 13 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 14 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3, 15 | warmup_iters=500, warmup_method='linear', last_epoch=-1): 16 | if warmup_method not in ("constant", "linear"): 17 | raise ValueError( 18 | "Only 'constant' or 'linear' warmup_method accepted " 19 | "got {}".format(warmup_method)) 20 | 21 | self.target_lr = target_lr 22 | self.max_iters = max_iters 23 | self.power = power 24 | self.warmup_factor = warmup_factor 25 | self.warmup_iters 
= warmup_iters 26 | self.warmup_method = warmup_method 27 | 28 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch) 29 | 30 | def get_lr(self): 31 | N = self.max_iters - self.warmup_iters 32 | T = self.last_epoch - self.warmup_iters 33 | if self.last_epoch < self.warmup_iters: 34 | if self.warmup_method == 'constant': 35 | warmup_factor = self.warmup_factor 36 | elif self.warmup_method == 'linear': 37 | alpha = float(self.last_epoch) / self.warmup_iters 38 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 39 | else: 40 | raise ValueError("Unknown warmup type.") 41 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs] 42 | factor = pow(1 - T / N, self.power) 43 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs] 44 | 45 | 46 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 47 | def __init__( 48 | self, 49 | optimizer: torch.optim.Optimizer, 50 | milestones: List[int], 51 | gamma: float = 0.1, 52 | warmup_factor: float = 0.001, 53 | warmup_iters: int = 1000, 54 | warmup_method: str = "linear", 55 | last_epoch: int = -1, 56 | ): 57 | if not list(milestones) == sorted(milestones): 58 | raise ValueError( 59 | "Milestones should be a list of" " increasing integers. Got {}", milestones 60 | ) 61 | self.milestones = milestones 62 | self.gamma = gamma 63 | self.warmup_factor = warmup_factor 64 | self.warmup_iters = warmup_iters 65 | self.warmup_method = warmup_method 66 | super().__init__(optimizer, last_epoch) 67 | 68 | def get_lr(self) -> List[float]: 69 | warmup_factor = _get_warmup_factor_at_iter( 70 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 71 | ) 72 | return [ 73 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 74 | for base_lr in self.base_lrs 75 | ] 76 | 77 | def _compute_values(self) -> List[float]: 78 | # The new interface 79 | return self.get_lr() 80 | 81 | 82 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 83 | def __init__( 84 | self, 85 | optimizer: torch.optim.Optimizer, 86 | max_iters: int, 87 | warmup_factor: float = 0.001, 88 | warmup_iters: int = 1000, 89 | warmup_method: str = "linear", 90 | last_epoch: int = -1, 91 | ): 92 | self.max_iters = max_iters 93 | self.warmup_factor = warmup_factor 94 | self.warmup_iters = warmup_iters 95 | self.warmup_method = warmup_method 96 | super().__init__(optimizer, last_epoch) 97 | 98 | def get_lr(self) -> List[float]: 99 | warmup_factor = _get_warmup_factor_at_iter( 100 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 101 | ) 102 | # Different definitions of half-cosine with warmup are possible. For 103 | # simplicity we multiply the standard half-cosine schedule by the warmup 104 | # factor. An alternative is to start the period of the cosine at warmup_iters 105 | # instead of at 0. In the case that warmup_iters << max_iters the two are 106 | # very close to each other. 107 | return [ 108 | base_lr 109 | * warmup_factor 110 | * 0.5 111 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) 112 | for base_lr in self.base_lrs 113 | ] 114 | 115 | def _compute_values(self) -> List[float]: 116 | # The new interface 117 | return self.get_lr() 118 | 119 | 120 | def _get_warmup_factor_at_iter( 121 | method: str, iter: int, warmup_iters: int, warmup_factor: float 122 | ) -> float: 123 | """ 124 | Return the learning rate warmup factor at a specific iteration. 
125 | See https://arxiv.org/abs/1706.02677 for more details. 126 | 127 | Args: 128 | method (str): warmup method; either "constant" or "linear". 129 | iter (int): iteration at which to calculate the warmup factor. 130 | warmup_iters (int): the number of warmup iterations. 131 | warmup_factor (float): the base warmup factor (the meaning changes according 132 | to the method used). 133 | 134 | Returns: 135 | float: the effective warmup factor at the given iteration. 136 | """ 137 | if iter >= warmup_iters: 138 | return 1.0 139 | 140 | if method == "constant": 141 | return warmup_factor 142 | elif method == "linear": 143 | alpha = iter / warmup_iters 144 | return warmup_factor * (1 - alpha) + alpha 145 | else: 146 | raise ValueError("Unknown warmup method: {}".format(method)) 147 | 148 | 149 | def get_scheduler(optimizer, max_iters, iters_per_epoch): 150 | mode = cfg.SOLVER.LR_SCHEDULER.lower() 151 | warm_up_iters = iters_per_epoch * cfg.SOLVER.WARMUP.EPOCHS 152 | if mode == 'poly': 153 | return WarmupPolyLR(optimizer, max_iters=max_iters, power=cfg.SOLVER.POLY.POWER, 154 | warmup_factor=cfg.SOLVER.WARMUP.FACTOR, warmup_iters=warm_up_iters, 155 | warmup_method=cfg.SOLVER.WARMUP.METHOD) 156 | elif mode == 'cosine': 157 | return WarmupCosineLR(optimizer, max_iters=max_iters, warmup_factor=cfg.SOLVER.WARMUP.FACTOR, 158 | warmup_iters=warm_up_iters, warmup_method=cfg.SOLVER.WARMUP.METHOD) 159 | elif mode == 'step': 160 | milestones = [x * iters_per_epoch for x in cfg.SOLVER.STEP.DECAY_EPOCH] 161 | return WarmupMultiStepLR(optimizer, milestones=milestones, gamma=cfg.SOLVER.STEP.GAMMA, 162 | warmup_factor=cfg.SOLVER.WARMUP.FACTOR, warmup_iters=warm_up_iters, 163 | warmup_method=cfg.SOLVER.WARMUP.METHOD) 164 | else: 165 | raise ValueError("not support lr scheduler method!") 166 | 167 | -------------------------------------------------------------------------------- /segmentron/modules/basic.py: -------------------------------------------------------------------------------- 1 | """Basic Module for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from collections import OrderedDict 6 | 7 | __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual', 8 | 'SeparableConv2d'] 9 | 10 | _USE_FIXED_PAD = False 11 | 12 | 13 | def _pytorch_padding(kernel_size, stride=1, dilation=1, **_): 14 | if _USE_FIXED_PAD: 15 | return 0 # FIXME remove once verified 16 | else: 17 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 18 | 19 | # FIXME remove once verified 20 | fp = _fixed_padding(kernel_size, dilation) 21 | assert all(padding == p for p in fp) 22 | 23 | return padding 24 | 25 | 26 | def _fixed_padding(kernel_size, dilation): 27 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 28 | pad_total = kernel_size_effective - 1 29 | pad_beg = pad_total // 2 30 | pad_end = pad_total - pad_beg 31 | return [pad_beg, pad_end, pad_beg, pad_end] 32 | 33 | 34 | class SeparableConv2d(nn.Module): 35 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, relu_first=True, 36 | bias=False, norm_layer=nn.BatchNorm2d): 37 | super().__init__() 38 | depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, 39 | stride=stride, padding=dilation, 40 | dilation=dilation, groups=inplanes, bias=bias) 41 | bn_depth = norm_layer(inplanes) 42 | pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias) 43 | bn_point = norm_layer(planes) 44 | 45 | if relu_first: 46 | self.block = nn.Sequential(OrderedDict([('relu', nn.ReLU()), 
47 | ('depthwise', depthwise), 48 | ('bn_depth', bn_depth), 49 | ('pointwise', pointwise), 50 | ('bn_point', bn_point) 51 | ])) 52 | else: 53 | self.block = nn.Sequential(OrderedDict([('depthwise', depthwise), 54 | ('bn_depth', bn_depth), 55 | ('relu1', nn.ReLU(inplace=True)), 56 | ('pointwise', pointwise), 57 | ('bn_point', bn_point), 58 | ('relu2', nn.ReLU(inplace=True)) 59 | ])) 60 | 61 | def forward(self, x): 62 | return self.block(x) 63 | 64 | 65 | class _ConvBNReLU(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 67 | dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d): 68 | super(_ConvBNReLU, self).__init__() 69 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 70 | self.bn = norm_layer(out_channels) 71 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 72 | 73 | def forward(self, x): 74 | x = self.conv(x) 75 | x = self.bn(x) 76 | x = self.relu(x) 77 | return x 78 | 79 | 80 | class _ConvBNPReLU(nn.Module): 81 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 82 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d): 83 | super(_ConvBNPReLU, self).__init__() 84 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 85 | self.bn = norm_layer(out_channels) 86 | self.prelu = nn.PReLU(out_channels) 87 | 88 | def forward(self, x): 89 | x = self.conv(x) 90 | x = self.bn(x) 91 | x = self.prelu(x) 92 | return x 93 | 94 | 95 | class _ConvBN(nn.Module): 96 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 97 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 98 | super(_ConvBN, self).__init__() 99 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 100 | self.bn = norm_layer(out_channels) 101 | 102 | def forward(self, x): 103 | x = self.conv(x) 104 | x = self.bn(x) 105 | return x 106 | 107 | 108 | class _BNPReLU(nn.Module): 109 | def __init__(self, out_channels, norm_layer=nn.BatchNorm2d): 110 | super(_BNPReLU, self).__init__() 111 | self.bn = norm_layer(out_channels) 112 | self.prelu = nn.PReLU(out_channels) 113 | 114 | def forward(self, x): 115 | x = self.bn(x) 116 | x = self.prelu(x) 117 | return x 118 | 119 | 120 | # ----------------------------------------------------------------- 121 | # For MobileNet 122 | # ----------------------------------------------------------------- 123 | class _DepthwiseConv(nn.Module): 124 | """conv_dw in MobileNet""" 125 | 126 | def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs): 127 | super(_DepthwiseConv, self).__init__() 128 | self.conv = nn.Sequential( 129 | _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer), 130 | _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)) 131 | 132 | def forward(self, x): 133 | return self.conv(x) 134 | 135 | 136 | # ----------------------------------------------------------------- 137 | # For MobileNetV2 138 | # ----------------------------------------------------------------- 139 | class InvertedResidual(nn.Module): 140 | def __init__(self, in_channels, out_channels, stride, expand_ratio, dilation=1, norm_layer=nn.BatchNorm2d): 141 | super(InvertedResidual, self).__init__() 142 | assert stride in [1, 2] 143 | self.use_res_connect = stride == 1 and in_channels == out_channels 144 | 145 | layers = list() 146 | inter_channels = 
int(round(in_channels * expand_ratio)) 147 | if expand_ratio != 1: 148 | # pw 149 | layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer)) 150 | layers.extend([ 151 | # dw 152 | _ConvBNReLU(inter_channels, inter_channels, 3, stride, dilation, dilation, 153 | groups=inter_channels, relu6=True, norm_layer=norm_layer), 154 | # pw-linear 155 | nn.Conv2d(inter_channels, out_channels, 1, bias=False), 156 | norm_layer(out_channels)]) 157 | self.conv = nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | if self.use_res_connect: 161 | return x + self.conv(x) 162 | else: 163 | return self.conv(x) 164 | 165 | 166 | if __name__ == '__main__': 167 | x = torch.randn(1, 32, 64, 64) 168 | model = InvertedResidual(32, 64, 2, 1) 169 | out = model(x) 170 | -------------------------------------------------------------------------------- /segmentron/modules/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 3 | Papers: 4 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 5 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 6 | Code: 7 | DropBlock impl inspired by two Tensorflow impl that I liked: 8 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 9 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def drop_block_2d( 18 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, 19 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 20 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 21 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training 22 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 23 | """ 24 | B, C, H, W = x.shape 25 | total_size = W * H 26 | clipped_block_size = min(block_size, min(W, H)) 27 | # seed_drop_rate, the gamma parameter 28 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 29 | (W - block_size + 1) * (H - block_size + 1)) 30 | 31 | # Forces the block to be inside the feature map. 
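# A rough sanity check of the seed rate above (illustrative numbers, not taken from the
# paper): with a 32x32 feature map, drop_prob=0.1, block_size=7 and gamma_scale=1.0,
#   gamma = 0.1 * 1024 / 7**2 / (26 * 26) ~= 0.0031,
# so roughly 0.3% of valid positions seed a dropped block; after each seed is grown into a
# 7x7 block by the max-pool below, the expected dropped fraction is back near drop_prob.
# The valid_block mask computed next zeroes seeds whose block would cross the border.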
32 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) 33 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ 34 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) 35 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) 36 | 37 | if batchwise: 38 | # one mask for whole batch, quite a bit faster 39 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) 40 | else: 41 | uniform_noise = torch.rand_like(x) 42 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) 43 | block_mask = -F.max_pool2d( 44 | -block_mask, 45 | kernel_size=clipped_block_size, # block_size, 46 | stride=1, 47 | padding=clipped_block_size // 2) 48 | 49 | if with_noise: 50 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 51 | if inplace: 52 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) 53 | else: 54 | x = x * block_mask + normal_noise * (1 - block_mask) 55 | else: 56 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) 57 | if inplace: 58 | x.mul_(block_mask * normalize_scale) 59 | else: 60 | x = x * block_mask * normalize_scale 61 | return x 62 | 63 | 64 | def drop_block_fast_2d( 65 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, 66 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 67 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 68 | DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid 69 | block mask at edges. 70 | """ 71 | B, C, H, W = x.shape 72 | total_size = W * H 73 | clipped_block_size = min(block_size, min(W, H)) 74 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 75 | (W - block_size + 1) * (H - block_size + 1)) 76 | 77 | if batchwise: 78 | # one mask for whole batch, quite a bit faster 79 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma 80 | else: 81 | # mask per batch element 82 | block_mask = torch.rand_like(x) < gamma 83 | block_mask = F.max_pool2d( 84 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) 85 | 86 | if with_noise: 87 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 88 | if inplace: 89 | x.mul_(1. - block_mask).add_(normal_noise * block_mask) 90 | else: 91 | x = x * (1. - block_mask) + normal_noise * block_mask 92 | else: 93 | block_mask = 1 - block_mask 94 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) 95 | if inplace: 96 | x.mul_(block_mask * normalize_scale) 97 | else: 98 | x = x * block_mask * normalize_scale 99 | return x 100 | 101 | 102 | class DropBlock2d(nn.Module): 103 | """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf 104 | """ 105 | def __init__(self, 106 | drop_prob=0.1, 107 | block_size=7, 108 | gamma_scale=1.0, 109 | with_noise=False, 110 | inplace=False, 111 | batchwise=False, 112 | fast=True): 113 | super(DropBlock2d, self).__init__() 114 | self.drop_prob = drop_prob 115 | self.gamma_scale = gamma_scale 116 | self.block_size = block_size 117 | self.with_noise = with_noise 118 | self.inplace = inplace 119 | self.batchwise = batchwise 120 | self.fast = fast # FIXME finish comparisons of fast vs not 121 | 122 | def forward(self, x): 123 | if not self.training or not self.drop_prob: 124 | return x 125 | if self.fast: 126 | return drop_block_fast_2d( 127 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 128 | else: 129 | return drop_block_2d( 130 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 131 | 132 | 133 | def drop_path(x, drop_prob: float = 0., training: bool = False): 134 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 135 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 136 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 137 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 138 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 139 | 'survival rate' as the argument. 140 | """ 141 | if drop_prob == 0. or not training: 142 | return x 143 | keep_prob = 1 - drop_prob 144 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 145 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 146 | random_tensor.floor_() # binarize 147 | output = x.div(keep_prob) * random_tensor 148 | return output 149 | 150 | 151 | class DropPath(nn.Module): 152 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 153 | """ 154 | def __init__(self, drop_prob=None): 155 | super(DropPath, self).__init__() 156 | self.drop_prob = drop_prob 157 | 158 | def forward(self, x): 159 | return drop_path(x, self.drop_prob, self.training) -------------------------------------------------------------------------------- /segmentron/utils/score.py: -------------------------------------------------------------------------------- 1 | """Evaluation Metrics for Semantic Segmentation""" 2 | import torch 3 | import numpy as np 4 | from torch import distributed as dist 5 | import copy 6 | 7 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union', 8 | 'pixelAccuracy', 'intersectionAndUnion', 'hist_info', 'compute_score'] 9 | 10 | 11 | class SegmentationMetric(object): 12 | """Computes pixAcc and mIoU metric scores 13 | """ 14 | 15 | def __init__(self, nclass, distributed): 16 | super(SegmentationMetric, self).__init__() 17 | self.nclass = nclass 18 | self.distributed = distributed 19 | self.reset() 20 | 21 | def update(self, preds, labels): 22 | """Updates the internal evaluation result. 23 | 24 | Parameters 25 | ---------- 26 | labels : 'NumpyArray' or list of `NumpyArray` 27 | The labels of the data. 28 | preds : 'NumpyArray' or list of `NumpyArray` 29 | Predicted values. 
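Example
-------
A minimal sketch (shapes follow batch_pix_accuracy below: preds are (N, nclass, H, W)
scores, labels are (N, H, W) class indices; nclass=11 is only an illustrative value):

    metric = SegmentationMetric(nclass=11, distributed=False)
    metric.update(preds, labels)
    pixAcc, mIoU = metric.get()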
30 | """ 31 | 32 | def reduce_tensor(tensor): 33 | rt = tensor.clone() 34 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 35 | return rt 36 | 37 | def evaluate_worker(self, pred, label): 38 | correct, labeled = batch_pix_accuracy(pred, label) 39 | inter, union = batch_intersection_union(pred, label, self.nclass) 40 | if self.distributed: 41 | correct = reduce_tensor(correct) 42 | labeled = reduce_tensor(labeled) 43 | inter = reduce_tensor(inter.cuda()) 44 | union = reduce_tensor(union.cuda()) 45 | torch.cuda.synchronize() 46 | self.total_correct += correct.item() 47 | self.total_label += labeled.item() 48 | if self.total_inter.device != inter.device: 49 | self.total_inter = self.total_inter.to(inter.device) 50 | self.total_union = self.total_union.to(union.device) 51 | self.total_inter += inter 52 | self.total_union += union 53 | 54 | if isinstance(preds, torch.Tensor): 55 | evaluate_worker(self, preds, labels) 56 | elif isinstance(preds, (list, tuple)): 57 | for (pred, label) in zip(preds, labels): 58 | evaluate_worker(self, pred, label) 59 | 60 | def get(self, return_category_iou=False): 61 | """Gets the current evaluation result. 62 | 63 | Returns 64 | ------- 65 | metrics : tuple of float 66 | pixAcc and mIoU 67 | """ 68 | pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label) # remove np.spacing(1) 69 | IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union) 70 | mIoU = IoU.mean().item() 71 | if return_category_iou: 72 | return pixAcc, mIoU, IoU.cpu().numpy() 73 | return pixAcc, mIoU 74 | 75 | def reset(self): 76 | """Resets the internal evaluation result to initial state.""" 77 | self.total_inter = torch.zeros(self.nclass) 78 | self.total_union = torch.zeros(self.nclass) 79 | self.total_correct = 0 80 | self.total_label = 0 81 | 82 | 83 | def batch_pix_accuracy(output, target): 84 | """PixAcc""" 85 | # inputs are numpy array, output 4D, target 3D 86 | predict = torch.argmax(output.long(), 1) + 1 87 | target = target.long() + 1 88 | 89 | pixel_labeled = torch.sum(target > 0)#.item() 90 | pixel_correct = torch.sum((predict == target) * (target > 0))#.item() 91 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 92 | return pixel_correct, pixel_labeled 93 | 94 | 95 | def batch_intersection_union(output, target, nclass): 96 | """mIoU""" 97 | # inputs are numpy array, output 4D, target 3D 98 | mini = 1 99 | maxi = nclass 100 | nbins = nclass 101 | predict = torch.argmax(output, 1) + 1 102 | target = target.float() + 1 103 | 104 | predict = predict.float() * (target > 0).float() 105 | intersection = predict * (predict == target).float() 106 | # areas of intersection and union 107 | # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary. 
108 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 109 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 110 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 111 | area_union = area_pred + area_lab - area_inter 112 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area" 113 | return area_inter.float(), area_union.float() 114 | 115 | 116 | def pixelAccuracy(imPred, imLab): 117 | """ 118 | This function takes the prediction and label of a single image, returns pixel-wise accuracy 119 | To compute over many images do: 120 | for i = range(Nimages): 121 | (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = \ 122 | pixelAccuracy(imPred[i], imLab[i]) 123 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled)) 124 | """ 125 | # Remove classes from unlabeled pixels in gt image. 126 | # We should not penalize detections in unlabeled portions of the image. 127 | pixel_labeled = np.sum(imLab >= 0) 128 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) 129 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 130 | return (pixel_accuracy, pixel_correct, pixel_labeled) 131 | 132 | 133 | def intersectionAndUnion(imPred, imLab, numClass): 134 | """ 135 | This function takes the prediction and label of a single image, 136 | returns intersection and union areas for each class 137 | To compute over many images do: 138 | for i in range(Nimages): 139 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i]) 140 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1) 141 | """ 142 | # Remove classes from unlabeled pixels in gt image. 143 | # We should not penalize detections in unlabeled portions of the image. 
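# Zeroing predictions where imLab < 0 pushes them outside the histogram range (1, numClass)
# used below, so unlabeled pixels contribute to neither intersection nor union.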
144 | imPred = imPred * (imLab >= 0) 145 | 146 | # Compute area intersection: 147 | intersection = imPred * (imPred == imLab) 148 | (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass)) 149 | 150 | # Compute area union: 151 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 152 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 153 | area_union = area_pred + area_lab - area_intersection 154 | return (area_intersection, area_union) 155 | 156 | 157 | def hist_info(pred, label, num_cls): 158 | assert pred.shape == label.shape 159 | k = (label >= 0) & (label < num_cls) 160 | labeled = np.sum(k) 161 | correct = np.sum((pred[k] == label[k])) 162 | 163 | return np.bincount(num_cls * label[k].astype(int) + pred[k], minlength=num_cls ** 2).reshape(num_cls, 164 | num_cls), labeled, correct 165 | 166 | 167 | def compute_score(hist, correct, labeled): 168 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 169 | mean_IU = np.nanmean(iu) 170 | mean_IU_no_back = np.nanmean(iu[1:]) 171 | freq = hist.sum(1) / hist.sum() 172 | # freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() 173 | mean_pixel_acc = correct / labeled 174 | 175 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 176 | -------------------------------------------------------------------------------- /segmentron/modules/batch_norm.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | import logging 3 | import torch 4 | import torch.distributed as dist 5 | from torch import nn 6 | from torch.autograd.function import Function 7 | from ..utils.distributed import get_world_size 8 | 9 | 10 | class FrozenBatchNorm2d(nn.Module): 11 | """ 12 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 13 | 14 | It contains non-trainable buffers called 15 | "weight" and "bias", "running_mean", "running_var", 16 | initialized to perform identity transformation. 17 | 18 | The pre-trained backbone models from Caffe2 only contain "weight" and "bias", 19 | which are computed from the original four parameters of BN. 20 | The affine transform `x * weight + bias` will perform the equivalent 21 | computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. 22 | When loading a backbone model from Caffe2, "running_mean" and "running_var" 23 | will be left unchanged as identity transformation. 24 | 25 | Other pre-trained backbone models may contain all 4 parameters. 26 | 27 | The forward is implemented by `F.batch_norm(..., training=False)`. 
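Equivalently, as in `forward` below, it can be computed as a plain affine transform
with precomputed `scale = weight / sqrt(running_var + eps)` and
`bias = bias - running_mean * scale`, i.e. `out = x * scale + bias`.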
28 | """ 29 | 30 | _version = 3 31 | 32 | def __init__(self, num_features, eps=1e-5): 33 | super().__init__() 34 | self.num_features = num_features 35 | self.eps = eps 36 | self.register_buffer("weight", torch.ones(num_features)) 37 | self.register_buffer("bias", torch.zeros(num_features)) 38 | self.register_buffer("running_mean", torch.zeros(num_features)) 39 | self.register_buffer("running_var", torch.ones(num_features) - eps) 40 | 41 | def forward(self, x): 42 | scale = self.weight * (self.running_var + self.eps).rsqrt() 43 | bias = self.bias - self.running_mean * scale 44 | scale = scale.reshape(1, -1, 1, 1) 45 | bias = bias.reshape(1, -1, 1, 1) 46 | return x * scale + bias 47 | 48 | def _load_from_state_dict( 49 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 50 | ): 51 | version = local_metadata.get("version", None) 52 | 53 | if version is None or version < 2: 54 | # No running_mean/var in early versions 55 | # This will silent the warnings 56 | if prefix + "running_mean" not in state_dict: 57 | state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) 58 | if prefix + "running_var" not in state_dict: 59 | state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) 60 | 61 | if version is not None and version < 3: 62 | # logger = logging.getLogger(__name__) 63 | logging.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))) 64 | # In version < 3, running_var are used without +eps. 65 | state_dict[prefix + "running_var"] -= self.eps 66 | 67 | super()._load_from_state_dict( 68 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 69 | ) 70 | 71 | def __repr__(self): 72 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) 73 | 74 | @classmethod 75 | def convert_frozen_batchnorm(cls, module): 76 | """ 77 | Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. 78 | 79 | Args: 80 | module (torch.nn.Module): 81 | 82 | Returns: 83 | If module is BatchNorm/SyncBatchNorm, returns a new module. 84 | Otherwise, in-place convert module and return it. 
85 | 86 | Similar to convert_sync_batchnorm in 87 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 88 | """ 89 | bn_module = nn.modules.batchnorm 90 | bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) 91 | res = module 92 | if isinstance(module, bn_module): 93 | res = cls(module.num_features) 94 | if module.affine: 95 | res.weight.data = module.weight.data.clone().detach() 96 | res.bias.data = module.bias.data.clone().detach() 97 | res.running_mean.data = module.running_mean.data 98 | res.running_var.data = module.running_var.data + module.eps 99 | else: 100 | for name, child in module.named_children(): 101 | new_child = cls.convert_frozen_batchnorm(child) 102 | if new_child is not child: 103 | res.add_module(name, new_child) 104 | return res 105 | 106 | 107 | def groupNorm(num_channels, eps=1e-5, momentum=0.1, affine=True): 108 | return nn.GroupNorm(min(32, num_channels), num_channels, eps=eps, affine=affine) 109 | 110 | 111 | def get_norm(norm): 112 | """ 113 | Args: 114 | norm (str or callable): 115 | 116 | Returns: 117 | nn.Module or None: the normalization layer 118 | """ 119 | support_norm_type = ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN'] 120 | assert norm in support_norm_type, 'Unknown norm type {}, support norm types are {}'.format( 121 | norm, support_norm_type) 122 | if isinstance(norm, str): 123 | if len(norm) == 0: 124 | return None 125 | norm = { 126 | "BN": nn.BatchNorm2d, 127 | "SyncBN": NaiveSyncBatchNorm, 128 | "FrozenBN": FrozenBatchNorm2d, 129 | "GN": groupNorm, 130 | "nnSyncBN": nn.SyncBatchNorm, # keep for debugging 131 | }[norm] 132 | return norm 133 | 134 | 135 | class AllReduce(Function): 136 | @staticmethod 137 | def forward(ctx, input): 138 | input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())] 139 | # Use allgather instead of allreduce since I don't trust in-place operations .. 140 | dist.all_gather(input_list, input, async_op=False) 141 | inputs = torch.stack(input_list, dim=0) 142 | return torch.sum(inputs, dim=0) 143 | 144 | @staticmethod 145 | def backward(ctx, grad_output): 146 | dist.all_reduce(grad_output, async_op=False) 147 | return grad_output 148 | 149 | 150 | class NaiveSyncBatchNorm(nn.BatchNorm2d): 151 | """ 152 | `torch.nn.SyncBatchNorm` has known unknown bugs. 153 | It produces significantly worse AP (and sometimes goes NaN) 154 | when the batch size on each worker is quite different 155 | (e.g., when scale augmentation is used, or when it is applied to mask head). 156 | 157 | Use this implementation before `nn.SyncBatchNorm` is fixed. 158 | It is slower than `nn.SyncBatchNorm`. 
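In this codebase it is typically obtained through the `get_norm` factory below rather
than instantiated directly; a small sketch (the string matches the `cfg.MODEL.BN_TYPE`
names listed in settings.py):

    norm_layer = get_norm('SyncBN')   # resolves to NaiveSyncBatchNorm
    bn = norm_layer(64)               # constructed like nn.BatchNorm2d(64)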
159 | """ 160 | 161 | def forward(self, input): 162 | if get_world_size() == 1 or not self.training: 163 | return super().forward(input) 164 | 165 | assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" 166 | C = input.shape[1] 167 | mean = torch.mean(input, dim=[0, 2, 3]) 168 | meansqr = torch.mean(input * input, dim=[0, 2, 3]) 169 | 170 | vec = torch.cat([mean, meansqr], dim=0) 171 | vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size()) 172 | 173 | mean, meansqr = torch.split(vec, C) 174 | var = meansqr - mean * mean 175 | self.running_mean += self.momentum * (mean.detach() - self.running_mean) 176 | self.running_var += self.momentum * (var.detach() - self.running_var) 177 | 178 | invstd = torch.rsqrt(var + self.eps) 179 | scale = self.weight * invstd 180 | bias = self.bias - mean * scale 181 | scale = scale.reshape(1, -1, 1, 1) 182 | bias = bias.reshape(1, -1, 1, 1) 183 | return input * scale + bias 184 | -------------------------------------------------------------------------------- /segmentron/config/settings.py: -------------------------------------------------------------------------------- 1 | from .config import SegmentronConfig 2 | 3 | cfg = SegmentronConfig() 4 | 5 | ########################## basic set ########################################### 6 | # random seed 7 | cfg.SEED = 1024 8 | # train time stamp, auto generate, do not need to set 9 | cfg.TIME_STAMP = '' 10 | # root path 11 | cfg.ROOT_PATH = '' 12 | # model phase ['train', 'test'] 13 | cfg.PHASE = 'train' 14 | 15 | ########################## dataset config ######################################### 16 | # dataset name 17 | cfg.DATASET.NAME = '' 18 | # pixel mean 19 | cfg.DATASET.MEAN = [0.5, 0.5, 0.5] 20 | # pixel std 21 | cfg.DATASET.STD = [0.5, 0.5, 0.5] 22 | # dataset ignore index 23 | cfg.DATASET.IGNORE_INDEX = -1 24 | # workers 25 | cfg.DATASET.WORKERS = 8 26 | # val dataset mode 27 | cfg.DATASET.MODE = 'testval' 28 | ########################### data augment ###################################### 29 | # data augment image mirror 30 | cfg.AUG.MIRROR = True 31 | # blur probability 32 | cfg.AUG.BLUR_PROB = 0.0 33 | # blur radius 34 | cfg.AUG.BLUR_RADIUS = 0.0 35 | # color jitter, float or tuple: (0.1, 0.2, 0.3, 0.4) 36 | cfg.AUG.COLOR_JITTER = None 37 | cfg.AUG.CROP = True 38 | ########################### train config ########################################## 39 | # epochs 40 | cfg.TRAIN.EPOCHS = 30 41 | # iterations 42 | cfg.TRAIN.ITERS = 40000 43 | # batch size 44 | cfg.TRAIN.BATCH_SIZE = 1 45 | # train crop size 46 | cfg.TRAIN.CROP_SIZE = 769 47 | # train base size 48 | cfg.TRAIN.BASE_SIZE = 1024 49 | # model output dir 50 | cfg.TRAIN.MODEL_SAVE_DIR = 'workdirs/' 51 | # log dir 52 | cfg.TRAIN.LOG_SAVE_DIR = cfg.TRAIN.MODEL_SAVE_DIR 53 | # pretrained model for eval or finetune 54 | cfg.TRAIN.PRETRAINED_MODEL_PATH = '' 55 | # use pretrained backbone model over imagenet 56 | cfg.TRAIN.BACKBONE_PRETRAINED = True 57 | # backbone pretrained model path, if not specific, will load from url when backbone pretrained enabled 58 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH = '' 59 | # resume model path 60 | cfg.TRAIN.RESUME_MODEL_PATH = '' 61 | # whether to use synchronize bn 62 | cfg.TRAIN.SYNC_BATCH_NORM = True 63 | # save model every checkpoint-epoch 64 | cfg.TRAIN.SNAPSHOT_EPOCH = 1 65 | # apex training? 
66 | cfg.TRAIN.APEX = False 67 | ########################### optimizer config ################################## 68 | # base learning rate 69 | cfg.SOLVER.LR = 1e-4 70 | # optimizer method 71 | cfg.SOLVER.OPTIMIZER = "sgd" 72 | # optimizer epsilon 73 | cfg.SOLVER.EPSILON = 1e-8 74 | # optimizer momentum 75 | cfg.SOLVER.MOMENTUM = 0.9 76 | # weight decay 77 | cfg.SOLVER.WEIGHT_DECAY = 1e-4 #0.00004 78 | # decoder lr x10 79 | cfg.SOLVER.DECODER_LR_FACTOR = 10.0 80 | # lr scheduler mode 81 | cfg.SOLVER.LR_SCHEDULER = "poly" 82 | # poly power 83 | cfg.SOLVER.POLY.POWER = 0.9 84 | # step gamma 85 | cfg.SOLVER.STEP.GAMMA = 0.1 86 | # milestone of step lr scheduler 87 | cfg.SOLVER.STEP.DECAY_EPOCH = [10, 20] 88 | # warm up epochs can be float 89 | cfg.SOLVER.WARMUP.EPOCHS = 0. 90 | # warm up factor 91 | cfg.SOLVER.WARMUP.FACTOR = 1.0 / 3 92 | # warm up method 93 | cfg.SOLVER.WARMUP.METHOD = 'linear' 94 | # whether to use ohem 95 | cfg.SOLVER.OHEM = False 96 | # whether to use aux loss 97 | cfg.SOLVER.AUX = False 98 | # aux loss weight 99 | cfg.SOLVER.AUX_WEIGHT = 0.4 100 | # loss name 101 | cfg.SOLVER.LOSS_NAME = '' 102 | ########################## test config ########################################### 103 | # val/test model path 104 | cfg.TEST.TEST_MODEL_PATH = '' 105 | # test batch size 106 | cfg.TEST.BATCH_SIZE = 1 107 | # eval crop size 108 | cfg.TEST.CROP_SIZE = None 109 | # multiscale eval 110 | cfg.TEST.SCALES = [1.0] 111 | # flip 112 | cfg.TEST.FLIP = False 113 | 114 | ########################## visual config ########################################### 115 | # visual result output dir 116 | cfg.VISUAL.OUTPUT_DIR = '../runs/visual/' 117 | 118 | ########################## model ####################################### 119 | # model name 120 | cfg.MODEL.MODEL_NAME = '' 121 | # model backbone 122 | cfg.MODEL.BACKBONE = '' 123 | # model backbone channel scale 124 | cfg.MODEL.BACKBONE_SCALE = 1.0 125 | # support resnet b, c. b is standard resnet in pytorch official repo 126 | # cfg.MODEL.RESNET_VARIANT = 'b' 127 | # multi branch loss weight 128 | cfg.MODEL.MULTI_LOSS_WEIGHT = [1.0] 129 | # gn groups 130 | cfg.MODEL.DEFAULT_GROUP_NUMBER = 32 131 | # whole model default epsilon 132 | cfg.MODEL.DEFAULT_EPSILON = 1e-5 133 | # batch norm, support ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN'] 134 | cfg.MODEL.BN_TYPE = 'BN' 135 | # batch norm epsilon for encoder, if set None will use api default value. 136 | cfg.MODEL.BN_EPS_FOR_ENCODER = None 137 | # batch norm epsilon for decoder, if set None will use api default value. 138 | cfg.MODEL.BN_EPS_FOR_DECODER = None 139 | # backbone output stride 140 | cfg.MODEL.OUTPUT_STRIDE = 16 141 | # BatchNorm momentum, if set None will use api default value.
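# (for reference, torch.nn.BatchNorm2d uses momentum=0.1 as its api default)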
142 | cfg.MODEL.BN_MOMENTUM = None 143 | 144 | cfg.MODEL.EMB_CHANNELS = 64 145 | 146 | ########################## DANet config #################################### 147 | # danet param 148 | cfg.MODEL.DANET.MULTI_DILATION = None 149 | # danet param 150 | cfg.MODEL.DANET.MULTI_GRID = False 151 | 152 | ########################## DeepLab config #################################### 153 | # whether to use aspp 154 | cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP = True 155 | # whether to use decoder 156 | cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER = True 157 | # whether aspp use sep conv 158 | cfg.MODEL.DEEPLABV3_PLUS.ASPP_WITH_SEP_CONV = True 159 | # whether decoder use sep conv 160 | cfg.MODEL.DEEPLABV3_PLUS.DECODER_USE_SEP_CONV = True 161 | 162 | ########################## UNET config ####################################### 163 | # upsample mode 164 | # cfg.MODEL.UNET.UPSAMPLE_MODE = 'bilinear' 165 | 166 | ########################## OCNet config ###################################### 167 | # ['base', 'pyramid', 'asp'] 168 | cfg.MODEL.OCNet.OC_ARCH = 'base' 169 | 170 | ########################## EncNet config ###################################### 171 | cfg.MODEL.ENCNET.SE_LOSS = True 172 | cfg.MODEL.ENCNET.SE_WEIGHT = 0.2 173 | cfg.MODEL.ENCNET.LATERAL = True 174 | 175 | 176 | ########################## CCNET config ###################################### 177 | cfg.MODEL.CCNET.RECURRENCE = 2 178 | 179 | ########################## CGNET config ###################################### 180 | cfg.MODEL.CGNET.STAGE2_BLOCK_NUM = 3 181 | cfg.MODEL.CGNET.STAGE3_BLOCK_NUM = 21 182 | 183 | ########################## PointRend config ################################## 184 | cfg.MODEL.POINTREND.BASEMODEL = 'DeepLabV3_Plus' 185 | 186 | ########################## hrnet config ###################################### 187 | cfg.MODEL.HRNET.PRETRAINED_LAYERS = ['*'] 188 | cfg.MODEL.HRNET.STEM_INPLANES = 64 189 | cfg.MODEL.HRNET.FINAL_CONV_KERNEL = 1 190 | cfg.MODEL.HRNET.WITH_HEAD = True 191 | # stage 1 192 | cfg.MODEL.HRNET.STAGE1.NUM_MODULES = 1 193 | cfg.MODEL.HRNET.STAGE1.NUM_BRANCHES = 1 194 | cfg.MODEL.HRNET.STAGE1.NUM_BLOCKS = [1] 195 | cfg.MODEL.HRNET.STAGE1.NUM_CHANNELS = [32] 196 | cfg.MODEL.HRNET.STAGE1.BLOCK = 'BOTTLENECK' 197 | cfg.MODEL.HRNET.STAGE1.FUSE_METHOD = 'SUM' 198 | # stage 2 199 | cfg.MODEL.HRNET.STAGE2.NUM_MODULES = 1 200 | cfg.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2 201 | cfg.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4] 202 | cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64] 203 | cfg.MODEL.HRNET.STAGE2.BLOCK = 'BASIC' 204 | cfg.MODEL.HRNET.STAGE2.FUSE_METHOD = 'SUM' 205 | # stage 3 206 | cfg.MODEL.HRNET.STAGE3.NUM_MODULES = 1 207 | cfg.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3 208 | cfg.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4] 209 | cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128] 210 | cfg.MODEL.HRNET.STAGE3.BLOCK = 'BASIC' 211 | cfg.MODEL.HRNET.STAGE3.FUSE_METHOD = 'SUM' 212 | # stage 4 213 | cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 1 214 | cfg.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4 215 | cfg.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 216 | cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 217 | cfg.MODEL.HRNET.STAGE4.BLOCK = 'BASIC' 218 | cfg.MODEL.HRNET.STAGE4.FUSE_METHOD = 'SUM' 219 | 220 | 221 | ########################## translab config ###################################### 222 | cfg.MODEL.TRANSLAB.BOUNDARY_WEIGHT = 5 223 | 224 | ########################## trans4trans config ##################################### 225 | cfg.MODEL.TRANS4TRANS.embed_dim = 256 226 | cfg.MODEL.TRANS4TRANS.depth = 4 227 | 
cfg.MODEL.TRANS4TRANS.num_heads = 8 228 | cfg.MODEL.TRANS4TRANS.mlp_ratio = 3. 229 | cfg.MODEL.TRANS4TRANS.hid_dim = 64 230 | 231 | 232 | 233 | --------------------------------------------------------------------------------
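Note on how the solver settings above are consumed by segmentron/solver/lr_scheduler.py:
a minimal sketch, where the SGD optimizer wraps a single dummy parameter and the
iteration counts are placeholders, purely for illustration:

    import torch
    from segmentron.config import cfg
    from segmentron.solver.lr_scheduler import get_scheduler

    cfg.SOLVER.LR_SCHEDULER = 'poly'   # or 'cosine' / 'step'
    cfg.SOLVER.WARMUP.EPOCHS = 1.0     # warmup length in epochs, may be fractional
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=cfg.SOLVER.LR)
    lr_scheduler = get_scheduler(optimizer, max_iters=40000, iters_per_epoch=500)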