├── INSTALL.md ├── README.md ├── assets ├── demo.png ├── snake_city.png ├── snake_tensorboard.png └── vis_city.png ├── config.yaml ├── data └── record │ ├── events.out.tfevents.1705824855.medicalimaginglab │ ├── events.out.tfevents.1705824921.medicalimaginglab │ ├── events.out.tfevents.1705824973.medicalimaginglab │ ├── events.out.tfevents.1705825034.medicalimaginglab │ ├── events.out.tfevents.1705825091.medicalimaginglab │ ├── events.out.tfevents.1705825155.medicalimaginglab │ ├── events.out.tfevents.1705825364.medicalimaginglab │ ├── events.out.tfevents.1705825410.medicalimaginglab │ ├── events.out.tfevents.1705825464.medicalimaginglab │ ├── events.out.tfevents.1705825543.medicalimaginglab │ ├── events.out.tfevents.1705825597.medicalimaginglab │ ├── events.out.tfevents.1705825689.medicalimaginglab │ ├── events.out.tfevents.1705825770.medicalimaginglab │ ├── events.out.tfevents.1705825879.medicalimaginglab │ ├── events.out.tfevents.1705826211.medicalimaginglab │ ├── events.out.tfevents.1705826246.medicalimaginglab │ ├── events.out.tfevents.1705826419.medicalimaginglab │ ├── events.out.tfevents.1705826453.medicalimaginglab │ ├── events.out.tfevents.1705826494.medicalimaginglab │ ├── events.out.tfevents.1705826520.medicalimaginglab │ ├── events.out.tfevents.1705826629.medicalimaginglab │ └── events.out.tfevents.1705826773.medicalimaginglab ├── lib ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── config │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── config.cpython-37.pyc │ │ ├── config.cpython-38.pyc │ │ ├── config.cpython-39.pyc │ │ ├── yacs.cpython-37.pyc │ │ ├── yacs.cpython-38.pyc │ │ └── yacs.cpython-39.pyc │ ├── config.py │ └── yacs.py ├── csrc │ ├── DCNv1 │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── dcn_v2.py │ │ ├── dcn_v2_onnx.py │ │ ├── make.sh │ │ ├── setup.py │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── dcn_v2_cpu.cpp │ │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ │ ├── dcn_v2_im2col_cpu.h │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ │ └── vision.h │ │ │ ├── cuda │ │ │ │ ├── dcn_v2_cuda.cu │ │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ │ └── vision.h │ │ │ ├── dcn_v2.h │ │ │ └── vision.cpp │ │ ├── testcpu.py │ │ └── testcuda.py │ ├── extreme_utils │ │ ├── _ext.cpython-37m-x86_64-linux-gnu.so │ │ ├── build │ │ │ ├── lib.linux-x86_64-cpython-37 │ │ │ │ └── _ext.cpython-37m-x86_64-linux-gnu.so │ │ │ └── temp.linux-x86_64-cpython-37 │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ └── data │ │ │ │ └── tzx │ │ │ │ └── snake │ │ │ │ └── lib │ │ │ │ └── csrc │ │ │ │ └── extreme_utils │ │ │ │ ├── src │ │ │ │ ├── nms.o │ │ │ │ └── utils.o │ │ │ │ └── utils.o │ │ ├── setup.py │ │ ├── src │ │ │ ├── cuda_common.h │ │ │ ├── nms.cu │ │ │ ├── nms.h │ │ │ └── utils.cu │ │ ├── utils.cpp │ │ └── utils.h │ └── roi_align_layer │ │ ├── ROIAlign.h │ │ ├── _roi_align.cpython-37m-x86_64-linux-gnu.so │ │ ├── build │ │ ├── lib.linux-x86_64-cpython-37 │ │ │ └── _roi_align.cpython-37m-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-cpython-37 │ │ │ ├── .ninja_deps │ │ │ ├── .ninja_log │ │ │ ├── build.ninja │ │ │ └── data │ │ │ └── tzx │ │ │ └── snake │ │ │ └── lib │ │ │ └── csrc │ │ │ └── roi_align_layer │ │ │ ├── cpu │ │ │ └── ROIAlign_cpu.o │ │ │ ├── cuda │ │ │ └── ROIAlign_cuda.o │ │ │ └── vision.o │ │ ├── cpu │ │ ├── ROIAlign_cpu.cpp │ │ └── vision.h │ │ ├── cuda │ │ ├── 
ROIAlign_cuda.cu │ │ └── vision.h │ │ ├── roi_align.py │ │ ├── setup.py │ │ └── vision.cpp ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── collate_batch.cpython-37.pyc │ │ ├── dataset_catalog.cpython-37.pyc │ │ ├── make_dataset.cpython-37.pyc │ │ ├── samplers.cpython-37.pyc │ │ └── transforms.cpython-37.pyc │ ├── collate_batch.py │ ├── dataset_catalog.py │ ├── make_dataset.py │ ├── medical │ │ ├── __pycache__ │ │ │ └── snake.cpython-37.pyc │ │ └── snake.py │ ├── samplers.py │ └── transforms.py ├── networks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── dcn_v2.cpython-37.pyc │ │ ├── make_network.cpython-37.pyc │ │ └── make_network.cpython-38.pyc │ ├── ct_rcnn │ │ ├── __init__.py │ │ ├── cp_head.py │ │ ├── ct_rcnn.py │ │ └── dla.py │ ├── dcn_v2.py │ ├── make_network.py │ ├── rcnn_snake │ │ ├── __init__.py │ │ ├── cp_head.py │ │ ├── ct_rcnn_snake.py │ │ ├── dla.py │ │ ├── evolve.py │ │ └── snake.py │ └── snake │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── ct_snake.cpython-37.pyc │ │ ├── dla.cpython-37.pyc │ │ ├── evolve.cpython-37.pyc │ │ ├── snake.cpython-37.pyc │ │ └── unethead.cpython-37.pyc │ │ ├── ct_snake.py │ │ ├── ct_snake_.py │ │ ├── ct_snake_crossatt.py │ │ ├── ct_snake单独训练unet.py │ │ ├── ct_snake用预训练unet.py │ │ ├── dla.py │ │ ├── evolve.py │ │ ├── snake.py │ │ ├── unet.py │ │ └── unethead.py ├── train │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── optimizer.cpython-37.pyc │ │ ├── optimizer.cpython-38.pyc │ │ ├── recorder.cpython-37.pyc │ │ ├── recorder.cpython-38.pyc │ │ ├── scheduler.cpython-37.pyc │ │ └── scheduler.cpython-38.pyc │ ├── optimizer.py │ ├── recorder.py │ ├── scheduler.py │ └── trainers │ │ ├── PolyProcess.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── CCQLoss.cpython-37.pyc │ │ ├── PolyProcess.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── make_trainer.cpython-37.pyc │ │ ├── make_trainer.cpython-38.pyc │ │ ├── snake.cpython-37.pyc │ │ ├── snakerec.cpython-37.pyc │ │ ├── trainer.cpython-37.pyc │ │ └── trainer.cpython-38.pyc │ │ ├── ct_rcnn.py │ │ ├── make_trainer.py │ │ ├── rcnn_snake.py │ │ ├── snake.py │ │ ├── snake2.py │ │ ├── snakerec.py │ │ └── trainer.py └── utils │ ├── __pycache__ │ ├── data_utils.cpython-37.pyc │ ├── data_utils.cpython-38.pyc │ ├── getedge.cpython-37.pyc │ ├── img_utils.cpython-37.pyc │ └── net_utils.cpython-37.pyc │ ├── base_utils.py │ ├── data_utils.py │ ├── getedge.py │ ├── img_utils.py │ ├── net_utils.py │ ├── optimizer │ ├── __pycache__ │ │ ├── lr_scheduler.cpython-37.pyc │ │ ├── lr_scheduler.cpython-38.pyc │ │ ├── radam.cpython-37.pyc │ │ └── radam.cpython-38.pyc │ ├── lr_scheduler.py │ └── radam.py │ ├── rcnn_snake │ ├── rcnn_snake_config.py │ └── rcnn_snake_utils.py │ └── snake │ ├── __pycache__ │ ├── active_spline.cpython-37.pyc │ ├── snake_cityscapes_utils.cpython-37.pyc │ ├── snake_config.cpython-37.pyc │ ├── snake_decode.cpython-37.pyc │ ├── snake_gcn_utils.cpython-37.pyc │ ├── snake_voc_utils.cpython-37.pyc │ └── visualize_utils.cpython-37.pyc │ ├── active_spline.py │ ├── snake_cityscapes_coco_utils.py │ ├── snake_cityscapes_utils.py │ ├── snake_coco_utils.py │ ├── snake_config.py │ ├── snake_decode.py │ ├── snake_eval_utils.py │ ├── snake_gcn_utils.py │ ├── snake_kins_utils.py │ ├── snake_poly_utils.py │ ├── snake_voc_utils.py │ └── visualize_utils.py ├── poly2mask.py ├── process ├── 
__pycache__ │ └── metrics.cpython-37.pyc └── metrics.py ├── requirements.txt ├── run.py ├── run_visual.py ├── test.py ├── tools ├── __pycache__ │ ├── demo.cpython-37.pyc │ ├── demo.cpython-38.pyc │ └── visualization.cpython-37.pyc ├── demo.py └── visualization.py └── train_net.py
/INSTALL.md:
--------------------------------------------------------------------------------
### Set up the Python environment

```
conda create -n snake python=3.7
conda activate snake

# make sure the PyTorch CUDA version is consistent with the system CUDA version
# e.g., if your system CUDA is 9.0, install torch 1.1 built for CUDA 9.0
pip install torch==1.1.0 -f https://download.pytorch.org/whl/cu90/stable

pip install Cython==0.28.2
pip install -r requirements.txt

# install apex
cd
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 39e153a3159724432257a8fc118807b359f4d1c8
export CUDA_HOME="/usr/local/cuda-9.0"
python setup.py install --cuda_ext --cpp_ext
```

### Compile the CUDA extensions under `lib/csrc`

```
ROOT=/path/to/snake
cd $ROOT/lib/csrc
export CUDA_HOME="/usr/local/cuda-9.0"
cd DCNv1
python setup.py build_ext --inplace
cd ../extreme_utils
python setup.py build_ext --inplace
cd ../roi_align_layer
python setup.py build_ext --inplace
```
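A quick way to verify the consistency mentioned above is to compare the CUDA version PyTorch was built with against the toolkit that `CUDA_HOME` points to (`torch.version.cuda` is standard PyTorch; this is just a sanity check, not part of the build):

```python
import torch

print(torch.__version__)          # e.g. 1.1.0
print(torch.version.cuda)         # CUDA version the wheel was built with, e.g. '9.0'
print(torch.cuda.is_available())  # True if the driver/toolkit setup is usable
```

The printed CUDA version should match what `nvcc --version` under `$CUDA_HOME` reports before compiling the extensions.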
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Progressive Deep Snake for Instance Boundary Extraction in Medical Images


## Installation

Please see [INSTALL.md](INSTALL.md).

## Data preparation
We provide the preprocessed MRSpineSeg and Verse20 datasets. Please use the following links to download them:

- [MRSpineSeg](https://pan.baidu.com/s/1N-0_Odxe0MI6aJbxipExgQ?pwd=1234) (code: 1234)
- [Verse20](https://pan.baidu.com/s/1TyMgLM_5zwMg6QIs4ORavw?pwd=1234) (code: 1234)

Then unzip them into `./data/dataset`:

```bash
unzip MRSpineSeg.zip -d ./data/dataset/MRSpineSeg
unzip Verse20.zip -d ./data/dataset/Verse20
```
## Training

### Training on MRSpineSeg
1. Change 'model_dir' in ./config.yaml to './data/model/MRSpineSeg';
2. Change 'data_path' to './data/dataset/MRSpineSeg';
3. Run the code:
```bash
python train_net.py
```
### Training on Verse20
1. Change 'model_dir' in ./config.yaml to './data/model/Verse20';
2. Change 'data_path' to './data/dataset/Verse20';
3. Run the code:
```bash
python train_net.py
```
## Testing
```bash
python test.py
```


## Acknowledgment

--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/assets/demo.png
--------------------------------------------------------------------------------
/assets/snake_city.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/assets/snake_city.png
--------------------------------------------------------------------------------
/assets/snake_tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/assets/snake_tensorboard.png
--------------------------------------------------------------------------------
/assets/vis_city.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/assets/vis_city.png
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
model: 'sbd'
network: 'ro_34'
evolution_num: 3
task: 'snake'
resume: true
ff_num: 0
circle_rate: 1.0
layer1rate: [0.0, 0.25]
ifmultistage: True
isrec: False
ct_weight: 1.0
wh_weight: 0.1
replace_box: True


model_dir: "path/to/save/model"
data_path: "path/to/dataset"

gpus: (0,1)

train:
    optim: 'adam'
    lr: 1e-4
    milestones: (80, 120)
    gamma: 0.2
    batch_size: 2
    dataset: 'MedicalTrain'
    num_workers: 0
    epoch: 150
test:
    dataset: 'MedicalTest'
    batch_size: 8

heads: {'ct_hm': 1, 'wh': 2}
segm_or_bbox: 'segm'
--------------------------------------------------------------------------------
/data/record/events.out.tfevents.1705824855.medicalimaginglab:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705824855.medicalimaginglab
--------------------------------------------------------------------------------
/data/record/events.out.tfevents.1705824921.medicalimaginglab:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705824921.medicalimaginglab
--------------------------------------------------------------------------------
/data/record/events.out.tfevents.1705824973.medicalimaginglab:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705824973.medicalimaginglab
--------------------------------------------------------------------------------
/data/record/events.out.tfevents.1705825034.medicalimaginglab:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825034.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825091.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825091.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825155.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825155.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825364.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825364.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825410.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825410.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825464.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825464.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825543.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825543.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825597.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825597.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825689.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825689.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825770.medicalimaginglab: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825770.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705825879.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705825879.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826211.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826211.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826246.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826246.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826419.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826419.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826453.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826453.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826494.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826494.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826520.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826520.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826629.medicalimaginglab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826629.medicalimaginglab -------------------------------------------------------------------------------- /data/record/events.out.tfevents.1705826773.medicalimaginglab: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/data/record/events.out.tfevents.1705826773.medicalimaginglab -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/__init__.py -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import cfg, args 2 | -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/yacs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/yacs.cpython-37.pyc 
--------------------------------------------------------------------------------
/lib/config/__pycache__/yacs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/yacs.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/yacs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/config/__pycache__/yacs.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/config/config.py:
--------------------------------------------------------------------------------
from .yacs import CfgNode as CN
import argparse
import os

cfg = CN()

# model
cfg.model = 'hello'
cfg.model_dir = 'data/model'
cfg.layer1rate = [0.0, 0.1]
cfg.ifmultistage = True
cfg.isrec = True
cfg.ct_weight = 0.1
cfg.wh_weight = 0.2
cfg.replace_box = True

# network
cfg.network = 'dla_34'
cfg.evolution_num = 30

# network heads
cfg.heads = CN()

# task
cfg.task = ''

# gpus
cfg.gpus = [0]

# if load the pretrained network
cfg.resume = True


# -----------------------------------------------------------------------------
# train
# -----------------------------------------------------------------------------
cfg.train = CN()

cfg.train.dataset = 'CocoTrain'
cfg.train.epoch = 140
cfg.train.num_workers = 8
cfg.data_path = '/data/tzx/data/images'

# use adam as default
cfg.train.optim = 'adam'
cfg.train.lr = 1e-4
cfg.train.weight_decay = 5e-4

cfg.train.warmup = False
cfg.train.scheduler = ''
cfg.train.milestones = [80, 120, 200, 240]
cfg.train.gamma = 0.5

cfg.train.batch_size = 4

# test
cfg.test = CN()
cfg.test.dataset = 'CocoVal'
cfg.test.batch_size = 1
cfg.test.epoch = -1

# recorder
cfg.record_dir = 'data/record'

# result
cfg.result_dir = 'data/result'

# evaluation
cfg.skip_eval = False

cfg.save_ep = 30
cfg.eval_ep = 5

cfg.use_gt_det = False

# -----------------------------------------------------------------------------
# snake
# -----------------------------------------------------------------------------
cfg.ct_score = 0.05
cfg.demo_path = ''
cfg.ff_num = 0


def parse_cfg(cfg, args):
    if len(cfg.task) == 0:
        raise ValueError('task must be specified')

    # assign the gpus
    os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus])

    cfg.det_dir = os.path.join(cfg.model_dir, cfg.task, args.det)

    # assign the network head conv
    cfg.head_conv = 64 if 'res' in cfg.network else 256

    #cfg.model_dir = os.path.join(cfg.model_dir, cfg.workname, cfg.model)
    #cfg.record_dir = os.path.join(cfg.record_dir, cfg.workname, cfg.model)
    #cfg.result_dir = os.path.join(cfg.result_dir, cfg.workname, cfg.model)


def make_cfg(args):
    cfg.merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)
    parse_cfg(cfg, args)
    return cfg


parser = argparse.ArgumentParser()
parser.add_argument("--cfg_file", default="config.yaml", type=str)
parser.add_argument('--test', action='store_true', dest='test', default=False)
parser.add_argument("--type", type=str, default="")
parser.add_argument('--det', type=str, default='')
parser.add_argument('-f', type=str, default='')
parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
args = parser.parse_args()
if len(args.type) > 0:
    cfg.task = "run"
cfg = make_cfg(args)
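Because `make_cfg` first merges `config.yaml` and then the leftover command-line tokens via `cfg.merge_from_list(args.opts)`, any key can be overridden at launch time (e.g. `python train_net.py train.lr 5e-5`) without editing the YAML. A minimal standalone sketch of that merge semantics, assuming the vendored `yacs.py` behaves like the pip-installable upstream `yacs` package:

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.train = CN()
cfg.train.lr = 1e-4
cfg.train.batch_size = 4

# the same "KEY VALUE" pairs that train_net.py accepts as trailing arguments
cfg.merge_from_list(['train.lr', 5e-5, 'train.batch_size', 2])
print(cfg.train.lr, cfg.train.batch_size)  # -> 5e-05 2
```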
--------------------------------------------------------------------------------
/lib/csrc/DCNv1/.gitignore:
--------------------------------------------------------------------------------
.vscode
.idea
*.so
*.o
*pyc
_ext
build
DCNv2.egg-info
dist
vendor/

--------------------------------------------------------------------------------
/lib/csrc/DCNv1/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2019, Charles Shang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/lib/csrc/DCNv1/README.md:
--------------------------------------------------------------------------------
# DCNv2 latest

- Adds support for PyTorch 1.11 (may not be backward-compatible).
- Tested on Ubuntu 20.04, Python 3.8 (conda), CUDA 11.4.

PyTorch 1.11 is confirmed to work, but this version is not compatible with earlier PyTorch releases. If you need PyTorch 1.10 or earlier, please use the pytorch1.6 branch or the last git commit.

Starting a new project with the latest stable PyTorch 1.11 is recommended.


## Install

```bash
$ python3 setup.py build develop
```

## Updates

- **2021.03.24**: PyTorch 1.8 is confirmed to work with the master branch; feel free to use it.
- **2021.02.18**: Happy new year! PyTorch 1.7 is finally supported on the master branch! **Lower versions should also work in theory; if not, please file an issue!**
- **2020.09.23**: The master branch now works with PyTorch 1.6 by default; for older versions you will need the separate branch.
- **2020.08.25**: Check out the pytorch1.6 branch for PyTorch 1.6 support; you will hit an error like `THCudaBlas_Sgemv undefined` if you build the master branch with PyTorch 1.6. The master branch works with PyTorch 1.5.

## Contact

If you have any questions, please post them on this platform: http://t.manaai.cn
--------------------------------------------------------------------------------
/lib/csrc/DCNv1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/DCNv1/__init__.py
--------------------------------------------------------------------------------
/lib/csrc/DCNv1/make.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
sudo rm *.so
sudo rm -r build/
sudo python3 setup.py build develop
--------------------------------------------------------------------------------
/lib/csrc/DCNv1/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import glob
import os

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

requirements = ["torch", "torchvision"]


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
    os.environ["CC"] = "g++"
    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {"cxx": []}
    define_macros = []

    # build the CUDA extension only when a GPU build is possible
    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
    else:
        # raise NotImplementedError('Cuda is not available')
        pass

    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            "_ext",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules


setup(
    name="DCNv2",
    version="0.1",
    author="charlesshang",
    url="https://github.com/charlesshang/DCNv2",
    description="deformable convolutional networks",
    packages=find_packages(exclude=("configs", "tests")),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
--------------------------------------------------------------------------------
/lib/csrc/DCNv1/src/cpu/dcn_v2_im2col_cpu.h:
-------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | // modified from the CUDA version for CPU use by Daniel K. 
Suhendro 64 | 65 | #ifndef DCN_V2_IM2COL_CPU 66 | #define DCN_V2_IM2COL_CPU 67 | 68 | #ifdef __cplusplus 69 | extern "C" 70 | { 71 | #endif 72 | 73 | void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask, 81 | const int batch_size, const int channels, const int height_im, const int width_im, 82 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 84 | const int dilation_h, const int dilation_w, 85 | const int deformable_group, float *grad_im); 86 | 87 | void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 88 | const int batch_size, const int channels, const int height_im, const int width_im, 89 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 90 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 91 | const int dilation_h, const int dilation_w, 92 | const int deformable_group, 93 | float *grad_offset, float *grad_mask); 94 | 95 | #ifdef __cplusplus 96 | } 97 | #endif 98 | 99 | #endif -------------------------------------------------------------------------------- /lib/csrc/DCNv1/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cpu_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); 
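For orientation, the `dcn_v2_cpu_forward` declaration above (and its CUDA twin below) is what the Python wrapper `dcn_v2.py` ultimately calls through the compiled `_ext` module. A sketch of the tensor shapes involved, with illustrative sizes, following the usual DCNv2 convention of two offset channels and one mask channel per kernel tap and deformable group:

```python
import torch

N, C_in, H, W = 2, 16, 32, 32   # input batch
C_out, k, groups = 32, 3, 1     # 3x3 deformable conv, one deformable group

x      = torch.randn(N, C_in, H, W)
weight = torch.randn(C_out, C_in, k, k)
bias   = torch.zeros(C_out)
offset = torch.zeros(N, 2 * k * k * groups, H, W)  # (dy, dx) for every kernel tap
mask   = torch.ones(N, k * k * groups, H, W)       # modulation scalar for every tap
```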
-------------------------------------------------------------------------------- /lib/csrc/DCNv1/src/cuda/dcn_v2_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 
58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | 64 | #ifndef DCN_V2_IM2COL_CUDA 65 | #define DCN_V2_IM2COL_CUDA 66 | 67 | #ifdef __cplusplus 68 | extern "C" 69 | { 70 | #endif 71 | 72 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 73 | const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 81 | const float *data_col, const float *data_offset, const float *data_mask, 82 | const int batch_size, const int channels, const int height_im, const int width_im, 83 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 85 | const int dilation_h, const int dilation_w, 86 | const int deformable_group, float *grad_im); 87 | 88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 90 | const int batch_size, const int channels, const int height_im, const int width_im, 91 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 93 | const int dilation_h, const int dilation_w, 94 | const int deformable_group, 95 | float *grad_offset, float *grad_mask); 96 | 97 | #ifdef __cplusplus 98 | } 99 | #endif 100 | 101 | #endif -------------------------------------------------------------------------------- /lib/csrc/DCNv1/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | at::Tensor 5 | dcn_v2_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cuda_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 
| const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /lib/csrc/DCNv1/src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "dcn_v2.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); 6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); 7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); 8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); 9 | } 10 | -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/_ext.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/extreme_utils/_ext.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/build/lib.linux-x86_64-cpython-37/_ext.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/extreme_utils/build/lib.linux-x86_64-cpython-37/_ext.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/.ninja_deps -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 0 12031 1675996368658665154 /data/tzx/snake/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/nms.o 96ece70217737dde 3 | 1 12356 1675996368986673014 /data/tzx/snake/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/utils.o 4881b352aa1f0705 4 | 1 26575 1675996383199003813 /data/tzx/snake/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/utils.o 5d3239440a336996 5 | -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda-11.1/bin/nvcc 4 | 5 | cflags = -pthread -B /data/public/miniconda3/envs/snake3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/data/tzx/snake/lib/csrc/extreme_utils -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include 
-I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/TH -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/data/public/miniconda3/envs/snake3/include/python3.7m -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_ext -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 7 | cuda_cflags = -I/data/tzx/snake/lib/csrc/extreme_utils -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/TH -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/data/public/miniconda3/envs/snake3/include/python3.7m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_ext -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_61,code=compute_61 -gencode=arch=compute_61,code=sm_61 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /data/tzx/snake/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/nms.o: cuda_compile /data/tzx/snake/lib/csrc/extreme_utils/src/nms.cu 24 | build /data/tzx/snake/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/utils.o: cuda_compile /data/tzx/snake/lib/csrc/extreme_utils/src/utils.cu 25 | build /data/tzx/snake/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/utils.o: compile /data/tzx/snake/lib/csrc/extreme_utils/utils.cpp 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/nms.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/nms.o -------------------------------------------------------------------------------- /lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/utils.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/src/utils.o -------------------------------------------------------------------------------- 
/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/utils.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/extreme_utils/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/extreme_utils/utils.o
--------------------------------------------------------------------------------
/lib/csrc/extreme_utils/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
import glob


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    main_file = glob.glob(os.path.join(this_dir, '*.cpp'))
    source_cuda = glob.glob(os.path.join(this_dir, 'src', '*.cu'))
    sources = main_file + source_cuda
    include_dirs = [this_dir]
    ext_modules = [
        CUDAExtension(
            name='_ext',
            sources=sources,
            include_dirs=include_dirs
        )
    ]
    return ext_modules


setup(
    ext_modules=get_extensions(),
    cmdclass={'build_ext': BuildExtension}
)
--------------------------------------------------------------------------------
/lib/csrc/extreme_utils/src/cuda_common.h:
--------------------------------------------------------------------------------
#include
#include
#include
#include
#include
#include

#ifndef CUDA_COMMON_H_
#define CUDA_COMMON_H_

#define DIST(x1,y1,z1,x2,y2,z2) (((x1)-(x2))*((x1)-(x2))+((y1)-(y2))*((y1)-(y2))+((z1)-(z2))*((z1)-(z2)))
#define DIST2D(x1,y1,x2,y2) (((x1)-(x2))*((x1)-(x2))+((y1)-(y2))*((y1)-(y2)))

#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }

// print a readable message and (optionally) abort when a CUDA call fails
void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort) exit(code);
    }
}

// smallest power of two that is >= val
int infTwoExp(int val)
{
    int inf=1;
    while(val>inf) inf<<=1;
    return inf;
}

// choose grid (bdim*) and block (tdim*) sizes for a 3D problem,
// packing at most 1024 threads per block, innermost axis first
void getGPULayout(
    int dim0,int dim1,int dim2,
    int* bdim0,int* bdim1,int* bdim2,
    int* tdim0,int* tdim1,int* tdim2
)
{
    (*tdim2)=64;
    if(dim2<(*tdim2)) (*tdim2)=infTwoExp(dim2);
    (*bdim2)=dim2/(*tdim2);
    if(dim2%(*tdim2)>0) (*bdim2)++;

    (*tdim1)=1024/(*tdim2);
    if(dim1<(*tdim1)) (*tdim1)=infTwoExp(dim1);
    (*bdim1)=dim1/(*tdim1);
    if(dim1%(*tdim1)>0) (*bdim1)++;

    (*tdim0)=1024/((*tdim1)*(*tdim2));
    if(dim0<(*tdim0)) (*tdim0)=infTwoExp(dim0);
    (*bdim0)=dim0/(*tdim0);
    if(dim0%(*tdim0)>0) (*bdim0)++;
}
#endif
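To make the launch-layout logic easier to follow, here is a line-for-line Python transcription of `infTwoExp` and `getGPULayout` above (for illustration only; it is not part of the build):

```python
def inf_two_exp(val):
    """Smallest power of two >= val."""
    inf = 1
    while val > inf:
        inf <<= 1
    return inf

def get_gpu_layout(dim0, dim1, dim2):
    """Pick grid/block dims for a 3D problem, packing at most
    1024 threads per block and filling the innermost axis first."""
    tdim2 = min(64, inf_two_exp(dim2))
    bdim2 = -(-dim2 // tdim2)                       # ceil(dim2 / tdim2)
    tdim1 = min(1024 // tdim2, inf_two_exp(dim1))
    bdim1 = -(-dim1 // tdim1)
    tdim0 = min(1024 // (tdim1 * tdim2), inf_two_exp(dim0))
    bdim0 = -(-dim0 // tdim0)
    return (bdim0, bdim1, bdim2), (tdim0, tdim1, tdim2)

# e.g. a (N=2, H=128, W=100) problem:
# get_gpu_layout(2, 128, 100) -> ((2, 8, 2), (1, 16, 64))
```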
2 | #include <ATen/ATen.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | 
5 | #include <THC/THC.h>
6 | #include <THC/THCDeviceUtils.cuh>
7 | 
8 | #include <vector>
9 | #include <iostream>
10 | 
11 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
12 | 
13 | __device__ inline float devIoU(float const * const a, float const * const b) {
14 |   float left = max(a[0], b[0]), right = min(a[2], b[2]);
15 |   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
16 |   float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
17 |   float interS = width * height;
18 |   float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
19 |   float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
20 |   return interS / (Sa + Sb - interS);
21 | }
22 | 
23 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
24 |                            const float *dev_boxes, unsigned long long *dev_mask) {
25 |   const int row_start = blockIdx.y;
26 |   const int col_start = blockIdx.x;
27 | 
28 |   // if (row_start > col_start) return;
29 | 
30 |   const int row_size =
31 |         min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
32 |   const int col_size =
33 |         min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
34 | 
35 |   __shared__ float block_boxes[threadsPerBlock * 5];
36 |   if (threadIdx.x < col_size) {
37 |     block_boxes[threadIdx.x * 5 + 0] =
38 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
39 |     block_boxes[threadIdx.x * 5 + 1] =
40 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
41 |     block_boxes[threadIdx.x * 5 + 2] =
42 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
43 |     block_boxes[threadIdx.x * 5 + 3] =
44 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
45 |     block_boxes[threadIdx.x * 5 + 4] =
46 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
47 |   }
48 |   __syncthreads();
49 | 
50 |   if (threadIdx.x < row_size) {
51 |     const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
52 |     const float *cur_box = dev_boxes + cur_box_idx * 5;
53 |     int i = 0;
54 |     unsigned long long t = 0;
55 |     int start = 0;
56 |     if (row_start == col_start) {
57 |       start = threadIdx.x + 1;
58 |     }
59 |     for (i = start; i < col_size; i++) {
60 |       if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
61 |         t |= 1ULL << i;
62 |       }
63 |     }
64 |     const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
65 |     dev_mask[cur_box_idx * col_blocks + col_start] = t;
66 |   }
67 | }
68 | 
69 | // boxes is a N x 5 tensor
70 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
71 |   using scalar_t = float;
72 |   AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
73 |   auto scores = boxes.select(1, 4);
74 |   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
75 |   auto boxes_sorted = boxes.index_select(0, order_t);
76 | 
77 |   int boxes_num = boxes.size(0);
78 | 
79 |   const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
80 | 
81 |   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
82 | 
83 |   THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
84 | 
85 |   unsigned long long* mask_dev = NULL;
86 |   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
87 |   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
88 | 
89 |   mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
90 | 
91 |   dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
92 |               THCCeilDiv(boxes_num, threadsPerBlock));
93 |   dim3 threads(threadsPerBlock);
94 |   nms_kernel<<<blocks, threads>>>(boxes_num,
95 |                                   nms_overlap_thresh,
96 |                                   boxes_dev,
97 |                                   mask_dev);
98 | 
99 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
100 |   THCudaCheck(cudaMemcpy(&mask_host[0],
101 |                         mask_dev,
102 |                         sizeof(unsigned long long) * boxes_num * col_blocks,
103 |                         cudaMemcpyDeviceToHost));
104 | 
105 |   std::vector<unsigned long long> remv(col_blocks);
106 |   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
107 | 
108 |   at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
109 |   int64_t* keep_out = keep.data<int64_t>();
110 | 
111 |   int num_to_keep = 0;
112 |   for (int i = 0; i < boxes_num; i++) {
113 |     int nblock = i / threadsPerBlock;
114 |     int inblock = i % threadsPerBlock;
115 | 
116 |     if (!(remv[nblock] & (1ULL << inblock))) {
117 |       keep_out[num_to_keep++] = i;
118 |       unsigned long long *p = &mask_host[0] + i * col_blocks;
119 |       for (int j = nblock; j < col_blocks; j++) {
120 |         remv[j] |= p[j];
121 |       }
122 |     }
123 |   }
124 | 
125 |   THCudaFree(state, mask_dev);
126 |   // TODO improve this part
127 |   return std::get<0>(order_t.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}).sort(0, false));
128 | }
129 | 
130 | 
--------------------------------------------------------------------------------
/lib/csrc/extreme_utils/src/nms.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 | 
4 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
5 | 
6 | 
--------------------------------------------------------------------------------
/lib/csrc/extreme_utils/utils.cpp:
--------------------------------------------------------------------------------
1 | #include "utils.h"
2 | 
3 | 
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 |     m.def("collect_extreme_point", &collect_extreme_point, "collect_extreme_point");
6 |     m.def("calculate_edge_num", &calculate_edge_num, "calculate_edge_num");
7 |     m.def("calculate_wnp", &calculate_wnp, "calculate_wnp");
8 |     m.def("roll_array", &roll_array, "roll_array");
9 |     m.def("nms", &nms, "non-maximum suppression");
10 | }
11 | 
--------------------------------------------------------------------------------
/lib/csrc/extreme_utils/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 | #include "src/nms.h"
4 | 
5 | 
6 | at::Tensor collect_extreme_point(
7 |     const at::Tensor& ext_hm,
8 |     const at::Tensor& bbox,
9 |     const at::Tensor& radius,
10 |     const at::Tensor& vote,
11 |     const at::Tensor& ct
12 | );
13 | 
14 | 
15 | void calculate_edge_num(
16 |     at::Tensor& edge_num,
17 |     const at::Tensor& edge_num_sum,
18 |     const at::Tensor& edge_idx_sort,
19 |     const int p_num
20 | );
21 | 
22 | 
23 | std::tuple<at::Tensor, at::Tensor> calculate_wnp(
24 |     const at::Tensor& edge_num,
25 |     const at::Tensor& edge_start_idx,
26 |     const int p_num
27 | );
28 | 
29 | 
30 | at::Tensor roll_array(
31 |     const at::Tensor& array,
32 |     const at::Tensor& step
33 | );
34 | 
35 | 
36 | at::Tensor nms(const at::Tensor& dets,
37 |                const at::Tensor& scores,
38 |                const float threshold) {
39 |   if (dets.numel() == 0)
40 |     return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
41 |   auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
42 |   return nms_cuda(b, threshold);
43 | }
44 | 
--------------------------------------------------------------------------------
/lib/csrc/roi_align_layer/ROIAlign.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
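// Dispatch note: ROIAlign_forward below routes CUDA tensors to the .cu kernel
// when the extension was built WITH_CUDA and otherwise falls back to the CPU
// implementation; ROIAlign_backward has no CPU path, so a CPU-only build
// supports inference but not training through this layer.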
2 | #pragma once 3 | 4 | #include "cpu/vision.h" 5 | 6 | #ifdef WITH_CUDA 7 | #include "cuda/vision.h" 8 | #endif 9 | 10 | // Interface for Python 11 | at::Tensor ROIAlign_forward(const at::Tensor& input, 12 | const at::Tensor& rois, 13 | const float spatial_scale, 14 | const int pooled_height, 15 | const int pooled_width, 16 | const int sampling_ratio) { 17 | if (input.type().is_cuda()) { 18 | #ifdef WITH_CUDA 19 | return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); 20 | #else 21 | AT_ERROR("Not compiled with GPU support"); 22 | #endif 23 | } 24 | return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); 25 | } 26 | 27 | at::Tensor ROIAlign_backward(const at::Tensor& grad, 28 | const at::Tensor& rois, 29 | const float spatial_scale, 30 | const int pooled_height, 31 | const int pooled_width, 32 | const int batch_size, 33 | const int channels, 34 | const int height, 35 | const int width, 36 | const int sampling_ratio) { 37 | if (grad.type().is_cuda()) { 38 | #ifdef WITH_CUDA 39 | return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio); 40 | #else 41 | AT_ERROR("Not compiled with GPU support"); 42 | #endif 43 | } 44 | AT_ERROR("Not implemented on the CPU"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/_roi_align.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/roi_align_layer/_roi_align.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/build/lib.linux-x86_64-cpython-37/_roi_align.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/roi_align_layer/build/lib.linux-x86_64-cpython-37/_roi_align.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/.ninja_deps -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 1 11896 1675997119675944592 /data/tzx/snake/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cuda/ROIAlign_cuda.o 8c56fca9874496b7 3 | 0 20252 1675997128027843332 /data/tzx/snake/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cpu/ROIAlign_cpu.o 7fd4bcdc0948f548 4 | 1 24264 1675997132031794273 /data/tzx/snake/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/vision.o 9ff3cec9bf1e015f 5 | -------------------------------------------------------------------------------- 
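The _roi_align extension built here implements the same operator as torchvision.ops.roi_align with the legacy aligned=False convention, which gives a convenient cross-check after rebuilding. A rough sketch of that check, assuming torchvision (for the reference kernel), apex (imported by roi_align.py below), and a CUDA device are available:

    # Compare the compiled extension with torchvision's reference kernel.
    import torch
    from torchvision.ops import roi_align as tv_roi_align
    from lib.csrc.roi_align_layer.roi_align import ROIAlign

    feat = torch.randn(1, 64, 128, 128).cuda()
    rois = torch.tensor([[0., 10., 10., 50., 60.]]).cuda()  # [batch_idx, x1, y1, x2, y2]
    out_ext = ROIAlign((7, 7), spatial_scale=1., sampling_ratio=2)(feat, rois)
    out_ref = tv_roi_align(feat, rois, (7, 7), spatial_scale=1., sampling_ratio=2)
    print((out_ext - out_ref).abs().max())  # expected to be ~0 if the kernels agree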
/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda-11.1/bin/nvcc 4 | 5 | cflags = -pthread -B /data/public/miniconda3/envs/snake3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/data/tzx/snake/lib/csrc/roi_align_layer -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/TH -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/data/public/miniconda3/envs/snake3/include/python3.7m -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_roi_align -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 7 | cuda_cflags = -DWITH_CUDA -I/data/tzx/snake/lib/csrc/roi_align_layer -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/TH -I/data/public/miniconda3/envs/snake3/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/data/public/miniconda3/envs/snake3/include/python3.7m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_roi_align -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_61,code=compute_61 -gencode=arch=compute_61,code=sm_61 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /data/tzx/snake/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cpu/ROIAlign_cpu.o: compile /data/tzx/snake/lib/csrc/roi_align_layer/cpu/ROIAlign_cpu.cpp 24 | build /data/tzx/snake/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cuda/ROIAlign_cuda.o: cuda_compile /data/tzx/snake/lib/csrc/roi_align_layer/cuda/ROIAlign_cuda.cu 25 | build /data/tzx/snake/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/vision.o: compile /data/tzx/snake/lib/csrc/roi_align_layer/vision.cpp 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cpu/ROIAlign_cpu.o: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cpu/ROIAlign_cpu.o
--------------------------------------------------------------------------------
/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cuda/ROIAlign_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/cuda/ROIAlign_cuda.o
--------------------------------------------------------------------------------
/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/vision.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/csrc/roi_align_layer/build/temp.linux-x86_64-cpython-37/data/tzx/snake/lib/csrc/roi_align_layer/vision.o
--------------------------------------------------------------------------------
/lib/csrc/roi_align_layer/cpu/vision.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #pragma once
3 | #include <torch/extension.h>
4 | 
5 | 
6 | at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
7 |                                 const at::Tensor& rois,
8 |                                 const float spatial_scale,
9 |                                 const int pooled_height,
10 |                                 const int pooled_width,
11 |                                 const int sampling_ratio);
12 | 
13 | 
14 | at::Tensor nms_cpu(const at::Tensor& dets,
15 |                    const at::Tensor& scores,
16 |                    const float threshold);
--------------------------------------------------------------------------------
/lib/csrc/roi_align_layer/roi_align.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
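# Dependency note: this module requires NVIDIA apex for amp.float_function,
# which keeps the ROIAlign forward in fp32 under mixed-precision training;
# without apex installed, importing this module fails.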
2 | import torch 3 | from torch import nn 4 | from torch.autograd import Function 5 | from torch.autograd.function import once_differentiable 6 | from torch.nn.modules.utils import _pair 7 | 8 | import lib.csrc.roi_align_layer._roi_align as _roi_align 9 | 10 | from apex import amp 11 | 12 | class _ROIAlign(Function): 13 | @staticmethod 14 | def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): 15 | ctx.save_for_backward(roi) 16 | ctx.output_size = _pair(output_size) 17 | ctx.spatial_scale = spatial_scale 18 | ctx.sampling_ratio = sampling_ratio 19 | ctx.input_shape = input.size() 20 | output = _roi_align.roi_align_forward( 21 | input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio 22 | ) 23 | return output 24 | 25 | @staticmethod 26 | @once_differentiable 27 | def backward(ctx, grad_output): 28 | rois, = ctx.saved_tensors 29 | output_size = ctx.output_size 30 | spatial_scale = ctx.spatial_scale 31 | sampling_ratio = ctx.sampling_ratio 32 | bs, ch, h, w = ctx.input_shape 33 | grad_input = _roi_align.roi_align_backward( 34 | grad_output, 35 | rois, 36 | spatial_scale, 37 | output_size[0], 38 | output_size[1], 39 | bs, 40 | ch, 41 | h, 42 | w, 43 | sampling_ratio, 44 | ) 45 | return grad_input, None, None, None, None 46 | 47 | 48 | roi_align_func = _ROIAlign.apply 49 | 50 | class ROIAlign(nn.Module): 51 | def __init__(self, output_size, spatial_scale=1., sampling_ratio=0): 52 | super(ROIAlign, self).__init__() 53 | self.output_size = output_size 54 | self.spatial_scale = spatial_scale 55 | self.sampling_ratio = sampling_ratio 56 | 57 | @amp.float_function 58 | def forward(self, input, rois): 59 | return roi_align_func( 60 | input, rois, self.output_size, self.spatial_scale, self.sampling_ratio 61 | ) 62 | 63 | def __repr__(self): 64 | tmpstr = self.__class__.__name__ + "(" 65 | tmpstr += "output_size=" + str(self.output_size) 66 | tmpstr += ", spatial_scale=" + str(self.spatial_scale) 67 | tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) 68 | tmpstr += ")" 69 | return tmpstr 70 | -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
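# Build note: get_extensions() below selects CppExtension vs. CUDAExtension at
# build time; exporting FORCE_CUDA=1 forces the CUDA build even when no GPU is
# visible (e.g. inside a container build stage without device access).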
2 | #!/usr/bin/env python 3 | 4 | import glob 5 | import os 6 | 7 | import torch 8 | from setuptools import find_packages 9 | from setuptools import setup 10 | from torch.utils.cpp_extension import CUDA_HOME 11 | from torch.utils.cpp_extension import CppExtension 12 | from torch.utils.cpp_extension import CUDAExtension 13 | 14 | requirements = ["torch", "torchvision"] 15 | 16 | 17 | def get_extensions(): 18 | this_dir = os.path.dirname(os.path.abspath(__file__)) 19 | extensions_dir = this_dir 20 | 21 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 22 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 23 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 24 | 25 | sources = main_file + source_cpu 26 | extension = CppExtension 27 | 28 | extra_compile_args = {"cxx": []} 29 | define_macros = [] 30 | 31 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 32 | extension = CUDAExtension 33 | sources += source_cuda 34 | define_macros += [("WITH_CUDA", None)] 35 | extra_compile_args["nvcc"] = [ 36 | "-DCUDA_HAS_FP16=1", 37 | "-D__CUDA_NO_HALF_OPERATORS__", 38 | "-D__CUDA_NO_HALF_CONVERSIONS__", 39 | "-D__CUDA_NO_HALF2_OPERATORS__", 40 | ] 41 | 42 | sources = [os.path.join(extensions_dir, s) for s in sources] 43 | 44 | include_dirs = [extensions_dir] 45 | 46 | ext_modules = [ 47 | extension( 48 | "_roi_align", 49 | sources, 50 | include_dirs=include_dirs, 51 | define_macros=define_macros, 52 | extra_compile_args=extra_compile_args, 53 | ) 54 | ] 55 | 56 | return ext_modules 57 | 58 | 59 | setup( 60 | ext_modules=get_extensions(), 61 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 62 | ) 63 | -------------------------------------------------------------------------------- /lib/csrc/roi_align_layer/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | #include "ROIAlign.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward"); 6 | m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward"); 7 | } 8 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataset import make_data_loader 2 | -------------------------------------------------------------------------------- /lib/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/collate_batch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/__pycache__/collate_batch.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/dataset_catalog.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/__pycache__/dataset_catalog.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/make_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/__pycache__/make_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/samplers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/__pycache__/samplers.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/dataset_catalog.py: -------------------------------------------------------------------------------- 1 | #from lib.config import cfg 2 | 3 | 4 | class DatasetCatalog(object): 5 | dataset_attrs = { 6 | 'CocoTrain': { 7 | 'id': 'coco', 8 | 'data_root': 'data/coco/train2017', 9 | 'ann_file': 'data/coco/annotations/instances_train2017.json', 10 | 'split': 'train' 11 | }, 12 | 'CocoVal': { 13 | 'id': 'coco', 14 | 'data_root': 'data/coco/val2017', 15 | 'ann_file': 'data/coco/annotations/instances_val2017.json', 16 | 'split': 'test' 17 | }, 18 | 'CocoMini': { 19 | 'id': 'coco', 20 | 'data_root': 'data/coco/val2017', 21 | 'ann_file': 'data/coco/annotations/instances_val2017.json', 22 | 'split': 'mini' 23 | }, 24 | 'CocoTest': { 25 | 'id': 'coco_test', 26 | 'data_root': 
'data/coco/test2017', 27 | 'ann_file': 'data/coco/annotations/image_info_test-dev2017.json', 28 | 'split': 'test' 29 | }, 30 | 'CityscapesTrain': { 31 | 'id': 'cityscapes', 32 | 'data_root': 'data/cityscapes/leftImg8bit', 33 | 'ann_file': ('data/cityscapes/annotations/train', 'data/cityscapes/annotations/train_val'), 34 | 'split': 'train' 35 | }, 36 | 'CityscapesVal': { 37 | 'id': 'cityscapes', 38 | 'data_root': 'data/cityscapes/leftImg8bit', 39 | 'ann_file': 'data/cityscapes/annotations/val', 40 | 'split': 'val' 41 | }, 42 | 'CityscapesCocoVal': { 43 | 'id': 'cityscapes_coco', 44 | 'data_root': 'data/cityscapes/leftImg8bit/val', 45 | 'ann_file': 'data/cityscapes/coco_ann/instance_val.json', 46 | 'split': 'val' 47 | }, 48 | 'CityCocoBox': { 49 | 'id': 'cityscapes_coco', 50 | 'data_root': 'data/cityscapes/leftImg8bit/val', 51 | 'ann_file': 'data/cityscapes/coco_ann/instance_box_val.json', 52 | 'split': 'val' 53 | }, 54 | 'CityscapesMini': { 55 | 'id': 'cityscapes', 56 | 'data_root': 'data/cityscapes/leftImg8bit', 57 | 'ann_file': 'data/cityscapes/annotations/val', 58 | 'split': 'mini' 59 | }, 60 | 'CityscapesTest': { 61 | 'id': 'cityscapes_test', 62 | 'data_root': 'data/cityscapes/leftImg8bit/test' 63 | }, 64 | 'SbdTrain': { 65 | 'id': 'sbd', 66 | 'data_root': 'data/sbd/img', 67 | 'ann_file': 'data/sbd/annotations/sbd_train_instance.json', 68 | 'split': 'train' 69 | 70 | }, 71 | 'SbdVal': { 72 | 'id': 'sbd', 73 | 'data_root': 'data/sbd/img', 74 | 'ann_file': 'data/sbd/annotations/sbd_trainval_instance.json', 75 | 'split': 'val' 76 | }, 77 | 'SbdMini': { 78 | 'id': 'sbd', 79 | 'data_root': 'data/sbd/img', 80 | 'ann_file': 'data/sbd/annotations/sbd_trainval_instance.json', 81 | 'split': 'mini' 82 | }, 83 | 'MedicalTest': { 84 | 'id': 'medical', 85 | 'data_root': '/home/amax/Titan_Five/TZX/deep_sanke/images', 86 | 'ann_file': 'data/sbd/annotations/sbd_trainval_instance.json', 87 | 'split': 'mini', 88 | }, 89 | 'MedicalTrain': { 90 | 'id': 'medical', 91 | 'data_root': '/home/amax/Titan_Five/TZX/deep_sanke/images', 92 | 'ann_file': 'data/sbd/annotations/sbd_trainval_instance.json', 93 | 'split': 'train', 94 | 95 | }, 96 | 'VocVal': { 97 | 'id': 'voc', 98 | 'data_root': 'data/voc/JPEGImages', 99 | 'ann_file': 'data/voc/annotations/voc_val_instance.json', 100 | 'split': 'val' 101 | }, 102 | 'KinsTrain': { 103 | 'id': 'kins', 104 | 'data_root': 'data/kitti/training/image_2', 105 | 'ann_file': 'data/kitti/training/instances_train.json', 106 | 'split': 'train' 107 | }, 108 | 'KinsVal': { 109 | 'id': 'kins', 110 | 'data_root': 'data/kitti/testing/image_2', 111 | 'ann_file': 'data/kitti/testing/instances_val.json', 112 | 'split': 'val' 113 | }, 114 | 'KinsMini': { 115 | 'id': 'kins', 116 | 'data_root': 'data/kitti/testing/image_2', 117 | 'ann_file': 'data/kitti/testing/instances_val.json', 118 | 'split': 'mini' 119 | } 120 | 121 | } 122 | 123 | @staticmethod 124 | def get(name): 125 | attrs = DatasetCatalog.dataset_attrs[name] 126 | return attrs.copy() 127 | 128 | if __name__ == '__main__': 129 | print(DatasetCatalog.get('KinsMini')) 130 | 131 | -------------------------------------------------------------------------------- /lib/datasets/make_dataset.py: -------------------------------------------------------------------------------- 1 | from .transforms import make_transforms 2 | from . 
import samplers 3 | from .dataset_catalog import DatasetCatalog 4 | import torch 5 | import torch.utils.data 6 | import imp 7 | import os 8 | from .collate_batch import make_collator 9 | import random 10 | import numpy as np 11 | 12 | 13 | torch.multiprocessing.set_sharing_strategy('file_system') 14 | 15 | 16 | def _dataset_factory(data_source, task): 17 | module = '.'.join(['lib.datasets', data_source, task]) 18 | path = os.path.join('lib/datasets', data_source, task+'.py') 19 | print("dataloader",module, path) 20 | dataset = imp.load_source(module, path).Dataset 21 | return dataset 22 | 23 | 24 | def make_dataset(cfg, dataset_name, transforms, is_train=True): 25 | args = DatasetCatalog.get(dataset_name) 26 | data_source = args['id'] 27 | dataset = _dataset_factory(data_source, cfg.task) 28 | del args['id'] 29 | # args['cfg'] = cfg 30 | # args['transforms'] = transforms 31 | # args['is_train'] = is_train 32 | 33 | dataset = dataset(**args) 34 | return dataset 35 | 36 | 37 | def make_data_sampler(dataset, shuffle): 38 | if shuffle: 39 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 40 | else: 41 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 42 | return sampler 43 | 44 | 45 | def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter): 46 | batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last) 47 | if max_iter != -1: 48 | batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, max_iter) 49 | return batch_sampler 50 | 51 | GLOBAL_SEED = 1 52 | 53 | def set_seed(seed): 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | 60 | GLOBAL_WORKER_ID = None 61 | def worker_init_fn(worker_id): 62 | global GLOBAL_WORKER_ID 63 | GLOBAL_WORKER_ID = worker_id 64 | set_seed(GLOBAL_SEED + worker_id) 65 | 66 | def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1): 67 | if is_train: 68 | batch_size = cfg.train.batch_size 69 | shuffle = True 70 | drop_last = False 71 | else: 72 | batch_size = cfg.test.batch_size 73 | shuffle = True if is_distributed else False 74 | drop_last = False 75 | 76 | dataset_name = cfg.train.dataset if is_train else cfg.test.dataset 77 | 78 | transforms = make_transforms(cfg, is_train) 79 | dataset = make_dataset(cfg, dataset_name, transforms, is_train) 80 | sampler = make_data_sampler(dataset, shuffle) 81 | batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter) 82 | num_workers = cfg.train.num_workers 83 | collator = make_collator(cfg) 84 | data_loader = torch.utils.data.DataLoader( 85 | dataset, 86 | batch_sampler=batch_sampler, 87 | num_workers=num_workers, 88 | collate_fn=collator, 89 | worker_init_fn=worker_init_fn 90 | ) 91 | 92 | return data_loader 93 | -------------------------------------------------------------------------------- /lib/datasets/medical/__pycache__/snake.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/datasets/medical/__pycache__/snake.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from torch.utils.data.sampler import BatchSampler 3 | import numpy as np 4 | import 
torch 5 | import math 6 | import torch.distributed as dist 7 | 8 | 9 | class ImageSizeBatchSampler(Sampler): 10 | def __init__(self, sampler, batch_size, drop_last, min_size=600, max_size=800, size_int=8): 11 | self.sampler = sampler 12 | self.batch_size = batch_size 13 | self.drop_last = drop_last 14 | self.hmin = min_size 15 | self.hmax = max_size 16 | self.wmin = min_size 17 | self.wmax = max_size 18 | self.size_int = size_int 19 | self.hint = (self.hmax-self.hmin)//self.size_int+1 20 | self.wint = (self.wmax-self.wmin)//self.size_int+1 21 | 22 | def generate_height_width(self): 23 | hi, wi = np.random.randint(0, self.hint), np.random.randint(0, self.wint) 24 | h, w = self.hmin + hi * self.size_int, self.wmin + wi * self.size_int 25 | return h, w 26 | 27 | def __iter__(self): 28 | batch = [] 29 | h, w = self.generate_height_width() 30 | for idx in self.sampler: 31 | batch.append((idx, h, w)) 32 | if len(batch) == self.batch_size: 33 | h, w = self.generate_height_width() 34 | yield batch 35 | batch = [] 36 | if len(batch) > 0 and not self.drop_last: 37 | yield batch 38 | 39 | def __len__(self): 40 | if self.drop_last: 41 | return len(self.sampler) // self.batch_size 42 | else: 43 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 44 | 45 | 46 | class IterationBasedBatchSampler(BatchSampler): 47 | """ 48 | Wraps a BatchSampler, resampling from it until 49 | a specified number of iterations have been sampled 50 | """ 51 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 52 | self.batch_sampler = batch_sampler 53 | self.num_iterations = num_iterations 54 | self.start_iter = start_iter 55 | 56 | def __iter__(self): 57 | iteration = self.start_iter 58 | while iteration <= self.num_iterations: 59 | for batch in self.batch_sampler: 60 | iteration += 1 61 | if iteration > self.num_iterations: 62 | break 63 | yield batch 64 | 65 | def __len__(self): 66 | return self.num_iterations 67 | -------------------------------------------------------------------------------- /lib/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | class Compose(object): 2 | def __init__(self, transforms): 3 | self.transforms = transforms 4 | 5 | def __call__(self, img, kpts=None): 6 | for t in self.transforms: 7 | img, kpts = t(img, kpts) 8 | if kpts is None: 9 | return img 10 | else: 11 | return img, kpts 12 | 13 | def __repr__(self): 14 | format_string = self.__class__.__name__ + "(" 15 | for t in self.transforms: 16 | format_string += "\n" 17 | format_string += " {0}".format(t) 18 | format_string += "\n)" 19 | return format_string 20 | 21 | 22 | class ToTensor(object): 23 | def __call__(self, img, kpts): 24 | return img / 255., kpts 25 | 26 | 27 | class Normalize(object): 28 | def __init__(self, mean, std): 29 | self.mean = mean 30 | self.std = std 31 | 32 | def __call__(self, img, kpts): 33 | img -= self.mean 34 | img /= self.std 35 | return img, kpts 36 | 37 | 38 | def make_transforms(cfg, is_train): 39 | if is_train is True: 40 | transform = Compose( 41 | [ 42 | ToTensor(), 43 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 44 | ] 45 | ) 46 | else: 47 | transform = Compose( 48 | [ 49 | ToTensor(), 50 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 51 | ] 52 | ) 53 | 54 | return transform 55 | -------------------------------------------------------------------------------- /lib/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.make_network import make_network, get_network 2 | -------------------------------------------------------------------------------- /lib/networks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /lib/networks/__pycache__/dcn_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/__pycache__/dcn_v2.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/__pycache__/make_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/__pycache__/make_network.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/__pycache__/make_network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/__pycache__/make_network.cpython-38.pyc -------------------------------------------------------------------------------- /lib/networks/ct_rcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .ct_rcnn import get_network as get_rcnn 2 | from lib.utils.snake import snake_config 3 | 4 | 5 | _network_factory = { 6 | 'rcnn': get_rcnn 7 | } 8 | 9 | 10 | def get_network(cfg): 11 | arch = cfg.network 12 | heads = cfg.heads 13 | head_conv = cfg.head_conv 14 | num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0 15 | arch = arch[:arch.find('_')] if '_' in arch else arch 16 | get_model = _network_factory[arch] 17 | network = get_model(num_layers, heads, head_conv, snake_config.down_ratio, cfg.det_dir) 18 | return network 19 | 20 | -------------------------------------------------------------------------------- /lib/networks/ct_rcnn/cp_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from lib.csrc.roi_align_layer.roi_align import ROIAlign 3 | from lib.utils.rcnn_snake import rcnn_snake_config, rcnn_snake_utils 4 | import torch 5 | from lib.csrc.extreme_utils import _ext 6 | 7 | 8 | def fill_fc_weights(layers): 9 | for m in layers.modules(): 10 | if isinstance(m, nn.Conv2d): 11 | if m.bias is not None: 12 | nn.init.constant_(m.bias, 0) 13 | 14 | 15 | class ComponentDetection(nn.Module): 16 | def __init__(self): 17 | super(ComponentDetection, self).__init__() 18 | 19 | self.pooler = ROIAlign((rcnn_snake_config.roi_h, rcnn_snake_config.roi_w)) 20 | 21 | self.fusion = nn.Sequential( 22 | nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True), 23 | nn.ReLU(inplace=True), 
24 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 29 | nn.ReLU(inplace=True) 30 | ) 31 | 32 | self.heads = {'cp_hm': 1, 'cp_wh': 2} 33 | for head in self.heads: 34 | classes = self.heads[head] 35 | fc = nn.Sequential( 36 | nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2), 37 | nn.Conv2d(256, classes, kernel_size=1, stride=1) 38 | ) 39 | if 'hm' in head: 40 | fc[-1].bias.data.fill_(-2.19) 41 | else: 42 | fill_fc_weights(fc) 43 | self.__setattr__(head, fc) 44 | 45 | def prepare_training(self, cnn_feature, output, batch): 46 | w = cnn_feature.size(3) 47 | xs = (batch['act_ind'] % w).float()[..., None] 48 | ys = (batch['act_ind'] // w).float()[..., None] 49 | wh = batch['awh'] 50 | bboxes = torch.cat([xs - wh[..., 0:1] / 2, 51 | ys - wh[..., 1:2] / 2, 52 | xs + wh[..., 0:1] / 2, 53 | ys + wh[..., 1:2] / 2], dim=2) 54 | rois = rcnn_snake_utils.box_to_roi(bboxes, batch['act_01'].byte()) 55 | roi = self.pooler(cnn_feature, rois) 56 | return roi 57 | 58 | def nms_class_box(self, box, score, cls, cls_num): 59 | box_score_cls = [] 60 | 61 | for j in range(cls_num): 62 | ind = (cls == j).nonzero().view(-1) 63 | if len(ind) == 0: 64 | continue 65 | 66 | box_ = box[ind] 67 | score_ = score[ind] 68 | ind = _ext.nms(box_, score_, rcnn_snake_config.max_ct_overlap) 69 | 70 | box_ = box_[ind] 71 | score_ = score_[ind] 72 | 73 | ind = score_ > rcnn_snake_config.ct_score 74 | box_ = box_[ind] 75 | score_ = score_[ind] 76 | label_ = torch.full([len(box_)], j).to(box_.device).float() 77 | 78 | box_score_cls.append([box_, score_, label_]) 79 | 80 | return box_score_cls 81 | 82 | def nms_abox(self, output): 83 | box = output['detection'][..., :4] 84 | score = output['detection'][..., 4] 85 | cls = output['detection'][..., 5] 86 | 87 | batch_size = box.size(0) 88 | cls_num = output['act_hm'].size(1) 89 | 90 | box_score_cls = [] 91 | for i in range(batch_size): 92 | box_score_cls_ = self.nms_class_box(box[i], score[i], cls[i], cls_num) 93 | box_score_cls_ = [torch.cat(d, dim=0) for d in list(zip(*box_score_cls_))] 94 | box_score_cls.append(box_score_cls_) 95 | 96 | box, score, cls = list(zip(*box_score_cls)) 97 | ind = torch.cat([torch.full([len(box[i])], i) for i in range(len(box))], dim=0) 98 | box = torch.cat(box, dim=0) 99 | score = torch.stack(score, dim=1) 100 | cls = torch.stack(cls, dim=1) 101 | 102 | detection = torch.cat([box, score, cls], dim=1) 103 | 104 | return detection, ind 105 | 106 | def prepare_testing(self, cnn_feature, output): 107 | if rcnn_snake_config.nms_ct: 108 | detection, ind = self.nms_abox(output) 109 | else: 110 | ind = output['detection'][..., 4] > rcnn_snake_config.ct_score 111 | detection = output['detection'][ind] 112 | ind = torch.cat([torch.full([ind[i].sum()], i) for i in range(len(ind))], dim=0) 113 | 114 | ind = ind.to(cnn_feature.device) 115 | abox = detection[:, :4] 116 | roi = torch.cat([ind[:, None], abox], dim=1) 117 | 118 | roi = self.pooler(cnn_feature, roi) 119 | output.update({'detection': detection, 'roi_ind': ind}) 120 | 121 | return roi 122 | 123 | def decode_cp_detection(self, cp_hm, cp_wh, output): 124 | abox = output['detection'][..., :4] 125 | adet = output['detection'] 126 | ind = output['roi_ind'] 127 | box, cp_ind = rcnn_snake_utils.decode_cp_detection(torch.sigmoid(cp_hm), cp_wh, abox, adet) 128 | output.update({'cp_box': box, 'cp_ind': cp_ind}) 129 | 130 | 
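    # The forward pass below has two disjoint paths: given a training batch it
    # pools ROIs from the ground-truth 'act_ind'/'awh' boxes and applies the
    # cp_hm/cp_wh heads; in eval mode it instead NMS-filters the detected
    # aggregate boxes, pools those, and decodes per-component boxes from the
    # predicted center heatmap.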
def forward(self, output, cnn_feature, batch=None): 131 | z = {} 132 | 133 | if batch is not None and 'test' not in batch['meta']: 134 | roi = self.prepare_training(cnn_feature, output, batch) 135 | roi = self.fusion(roi) 136 | for head in self.heads: 137 | z[head] = self.__getattr__(head)(roi) 138 | 139 | if not self.training: 140 | with torch.no_grad(): 141 | roi = self.prepare_testing(cnn_feature, output) 142 | roi = self.fusion(roi) 143 | cp_hm = self.cp_hm(roi) 144 | cp_wh = self.cp_wh(roi) 145 | self.decode_cp_detection(cp_hm, cp_wh, output) 146 | 147 | output.update(z) 148 | 149 | return output 150 | 151 | -------------------------------------------------------------------------------- /lib/networks/ct_rcnn/ct_rcnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .dla import DLASeg 3 | from lib.utils import net_utils 4 | from .cp_head import ComponentDetection 5 | import torch 6 | from lib.utils.snake import snake_decode 7 | from lib.utils import data_utils 8 | 9 | 10 | class Network(nn.Module): 11 | def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 12 | super(Network, self).__init__() 13 | 14 | self.dla = DLASeg('dla{}'.format(num_layers), heads, 15 | pretrained=True, 16 | down_ratio=down_ratio, 17 | final_kernel=1, 18 | last_level=5, 19 | head_conv=head_conv) 20 | self.cp = ComponentDetection() 21 | 22 | def decode_detection(self, output, h, w): 23 | ct_hm = output['act_hm'] 24 | wh = output['awh'] 25 | ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh) 26 | detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) 27 | output.update({'ct': ct, 'detection': detection}) 28 | return ct, detection 29 | 30 | def forward(self, x, batch=None): 31 | output, cnn_feature = self.dla(x) 32 | with torch.no_grad(): 33 | self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3)) 34 | output = self.cp(output, cnn_feature, batch) 35 | return output 36 | 37 | 38 | def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 39 | network = Network(num_layers, heads, head_conv, down_ratio, det_dir) 40 | return network 41 | 42 | -------------------------------------------------------------------------------- /lib/networks/make_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imp 3 | import sys 4 | 5 | 6 | _network_factory = { 7 | } 8 | 9 | 10 | def get_network(cfg): 11 | arch = cfg.network 12 | heads = cfg.heads 13 | head_conv = cfg.head_conv 14 | num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0 15 | arch = arch[:arch.find('_')] if '_' in arch else arch 16 | get_model = _network_factory[arch] 17 | network = get_model(num_layers, heads, head_conv) 18 | return network 19 | 20 | 21 | def make_network(cfg): 22 | 23 | module = '.'.join(['lib.networks', cfg.task]) 24 | path = os.path.join('lib/networks', cfg.task, '__init__.py') 25 | 26 | return imp.load_source(module, path).get_network(cfg) 27 | -------------------------------------------------------------------------------- /lib/networks/rcnn_snake/__init__.py: -------------------------------------------------------------------------------- 1 | from .ct_rcnn_snake import get_network as get_rcnn 2 | from lib.utils.snake import snake_config 3 | 4 | 5 | _network_factory = { 6 | 'rcnn': get_rcnn 7 | } 8 | 9 | 10 | def get_network(cfg): 11 | arch = cfg.network 12 | heads = cfg.heads 13 | head_conv = cfg.head_conv 14 
| num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0 15 | arch = arch[:arch.find('_')] if '_' in arch else arch 16 | get_model = _network_factory[arch] 17 | network = get_model(num_layers, heads, head_conv, snake_config.down_ratio, cfg.det_dir) 18 | return network 19 | 20 | -------------------------------------------------------------------------------- /lib/networks/rcnn_snake/cp_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from lib.csrc.roi_align_layer.roi_align import ROIAlign 3 | from lib.utils.rcnn_snake import rcnn_snake_config, rcnn_snake_utils 4 | import torch 5 | from lib.csrc.extreme_utils import _ext 6 | 7 | 8 | def fill_fc_weights(layers): 9 | for m in layers.modules(): 10 | if isinstance(m, nn.Conv2d): 11 | if m.bias is not None: 12 | nn.init.constant_(m.bias, 0) 13 | 14 | 15 | class ComponentDetection(nn.Module): 16 | def __init__(self): 17 | super(ComponentDetection, self).__init__() 18 | 19 | self.pooler = ROIAlign((rcnn_snake_config.roi_h, rcnn_snake_config.roi_w)) 20 | 21 | self.fusion = nn.Sequential( 22 | nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 29 | nn.ReLU(inplace=True) 30 | ) 31 | 32 | self.heads = {'cp_hm': 1, 'cp_wh': 2} 33 | for head in self.heads: 34 | classes = self.heads[head] 35 | fc = nn.Sequential( 36 | nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2), 37 | nn.Conv2d(256, classes, kernel_size=1, stride=1) 38 | ) 39 | if 'hm' in head: 40 | fc[-1].bias.data.fill_(-2.19) 41 | else: 42 | fill_fc_weights(fc) 43 | self.__setattr__(head, fc) 44 | 45 | def prepare_training(self, cnn_feature, output, batch): 46 | w = cnn_feature.size(3) 47 | xs = (batch['act_ind'] % w).float()[..., None] 48 | ys = (batch['act_ind'] // w).float()[..., None] 49 | wh = batch['awh'] 50 | bboxes = torch.cat([xs - wh[..., 0:1] / 2, 51 | ys - wh[..., 1:2] / 2, 52 | xs + wh[..., 0:1] / 2, 53 | ys + wh[..., 1:2] / 2], dim=2) 54 | rois = rcnn_snake_utils.box_to_roi(bboxes, batch['act_01'].byte()) 55 | roi = self.pooler(cnn_feature, rois) 56 | return roi 57 | 58 | def nms_class_box(self, box, score, cls, cls_num): 59 | box_score_cls = [] 60 | 61 | for j in range(cls_num): 62 | ind = (cls == j).nonzero().view(-1) 63 | if len(ind) == 0: 64 | continue 65 | 66 | box_ = box[ind] 67 | score_ = score[ind] 68 | ind = _ext.nms(box_, score_, rcnn_snake_config.max_ct_overlap) 69 | 70 | box_ = box_[ind] 71 | score_ = score_[ind] 72 | 73 | ind = score_ > rcnn_snake_config.ct_score 74 | box_ = box_[ind] 75 | score_ = score_[ind] 76 | label_ = torch.full([len(box_)], j).to(box_.device).float() 77 | 78 | box_score_cls.append([box_, score_, label_]) 79 | 80 | return box_score_cls 81 | 82 | def nms_abox(self, output): 83 | box = output['detection'][..., :4] 84 | score = output['detection'][..., 4] 85 | cls = output['detection'][..., 5] 86 | 87 | batch_size = box.size(0) 88 | cls_num = output['act_hm'].size(1) 89 | 90 | box_score_cls = [] 91 | for i in range(batch_size): 92 | box_score_cls_ = self.nms_class_box(box[i], score[i], cls[i], cls_num) 93 | box_score_cls_ = [torch.cat(d, dim=0) for d in list(zip(*box_score_cls_))] 94 | box_score_cls.append(box_score_cls_) 95 | 96 | box, score, cls = list(zip(*box_score_cls)) 97 | ind = 
torch.cat([torch.full([len(box[i])], i) for i in range(len(box))], dim=0) 98 | box = torch.cat(box, dim=0) 99 | score = torch.stack(score, dim=1) 100 | cls = torch.stack(cls, dim=1) 101 | 102 | detection = torch.cat([box, score, cls], dim=1) 103 | 104 | return detection, ind 105 | 106 | def prepare_testing(self, cnn_feature, output): 107 | if rcnn_snake_config.nms_ct: 108 | detection, ind = self.nms_abox(output) 109 | else: 110 | ind = output['detection'][..., 4] > rcnn_snake_config.ct_score 111 | detection = output['detection'][ind] 112 | ind = torch.cat([torch.full([ind[i].sum()], i) for i in range(len(ind))], dim=0) 113 | 114 | ind = ind.to(cnn_feature.device) 115 | abox = detection[:, :4] 116 | roi = torch.cat([ind[:, None], abox], dim=1) 117 | 118 | roi = self.pooler(cnn_feature, roi) 119 | output.update({'detection': detection, 'roi_ind': ind}) 120 | 121 | return roi 122 | 123 | def decode_cp_detection(self, cp_hm, cp_wh, output): 124 | abox = output['detection'][..., :4] 125 | adet = output['detection'] 126 | ind = output['roi_ind'] 127 | box, cp_ind = rcnn_snake_utils.decode_cp_detection(torch.sigmoid(cp_hm), cp_wh, abox, adet) 128 | output.update({'cp_box': box, 'cp_ind': cp_ind}) 129 | 130 | def forward(self, output, cnn_feature, batch=None): 131 | z = {} 132 | 133 | if batch is not None and 'test' not in batch['meta']: 134 | roi = self.prepare_training(cnn_feature, output, batch) 135 | roi = self.fusion(roi) 136 | for head in self.heads: 137 | z[head] = self.__getattr__(head)(roi) 138 | 139 | if not self.training: 140 | with torch.no_grad(): 141 | roi = self.prepare_testing(cnn_feature, output) 142 | roi = self.fusion(roi) 143 | cp_hm = self.cp_hm(roi) 144 | cp_wh = self.cp_wh(roi) 145 | self.decode_cp_detection(cp_hm, cp_wh, output) 146 | 147 | output.update(z) 148 | 149 | return output 150 | 151 | -------------------------------------------------------------------------------- /lib/networks/rcnn_snake/ct_rcnn_snake.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .dla import DLASeg 3 | from lib.utils import net_utils 4 | from .cp_head import ComponentDetection 5 | import torch 6 | from lib.utils.snake import snake_decode 7 | from lib.utils import data_utils 8 | from .evolve import Evolution 9 | from lib.config import cfg 10 | import os 11 | 12 | 13 | class Network(nn.Module): 14 | def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 15 | super(Network, self).__init__() 16 | 17 | self.dla = DLASeg('dla{}'.format(num_layers), heads, 18 | pretrained=True, 19 | down_ratio=down_ratio, 20 | final_kernel=1, 21 | last_level=5, 22 | head_conv=head_conv) 23 | self.cp = ComponentDetection() 24 | self.gcn = Evolution() 25 | 26 | det_dir = os.path.join(os.path.dirname(cfg.model_dir), cfg.det_model) 27 | net_utils.load_network(self, det_dir, strict=False) 28 | 29 | def decode_detection(self, output, h, w): 30 | ct_hm = output['act_hm'] 31 | wh = output['awh'] 32 | ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh) 33 | detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) 34 | output.update({'ct': ct, 'detection': detection}) 35 | return ct, detection 36 | 37 | def forward(self, x, batch=None): 38 | output, cnn_feature = self.dla(x) 39 | with torch.no_grad(): 40 | self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3)) 41 | output = self.cp(output, cnn_feature, batch) 42 | output = self.gcn(output, cnn_feature, batch) 43 | return output 44 | 45 
| 46 | def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 47 | network = Network(num_layers, heads, head_conv, down_ratio, det_dir) 48 | return network 49 | 50 | -------------------------------------------------------------------------------- /lib/networks/rcnn_snake/snake.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class CircConv(nn.Module): 6 | def __init__(self, state_dim, out_state_dim=None, n_adj=4): 7 | super(CircConv, self).__init__() 8 | 9 | self.n_adj = n_adj 10 | out_state_dim = state_dim if out_state_dim is None else out_state_dim 11 | self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1) 12 | 13 | def forward(self, input, adj): 14 | input = torch.cat([input[..., -self.n_adj:], input, input[..., :self.n_adj]], dim=2) 15 | return self.fc(input) 16 | 17 | 18 | class DilatedCircConv(nn.Module): 19 | def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1): 20 | super(DilatedCircConv, self).__init__() 21 | 22 | self.n_adj = n_adj 23 | self.dilation = dilation 24 | out_state_dim = state_dim if out_state_dim is None else out_state_dim 25 | self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation) 26 | 27 | def forward(self, input, adj): 28 | if self.n_adj != 0: 29 | input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2) 30 | return self.fc(input) 31 | 32 | 33 | _conv_factory = { 34 | 'grid': CircConv, 35 | 'dgrid': DilatedCircConv 36 | } 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1): 41 | super(BasicBlock, self).__init__() 42 | 43 | self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.norm = nn.BatchNorm1d(out_state_dim) 46 | 47 | def forward(self, x, adj=None): 48 | x = self.conv(x, adj) 49 | x = self.relu(x) 50 | x = self.norm(x) 51 | return x 52 | 53 | 54 | class Snake(nn.Module): 55 | def __init__(self, state_dim, feature_dim, conv_type='dgrid'): 56 | super(Snake, self).__init__() 57 | 58 | self.head = BasicBlock(feature_dim, state_dim, conv_type) 59 | 60 | self.res_layer_num = 7 61 | dilation = [1, 1, 1, 2, 2, 4, 4] 62 | for i in range(self.res_layer_num): 63 | conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i]) 64 | self.__setattr__('res'+str(i), conv) 65 | 66 | fusion_state_dim = 256 67 | self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1) 68 | self.prediction = nn.Sequential( 69 | nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1), 70 | nn.ReLU(inplace=True), 71 | nn.Conv1d(256, 64, 1), 72 | nn.ReLU(inplace=True), 73 | nn.Conv1d(64, 2, 1) 74 | ) 75 | 76 | def forward(self, x, adj): 77 | states = [] 78 | 79 | x = self.head(x, adj) 80 | states.append(x) 81 | for i in range(self.res_layer_num): 82 | x = self.__getattr__('res'+str(i))(x, adj) + x 83 | states.append(x) 84 | 85 | state = torch.cat(states, dim=1) 86 | global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0] 87 | global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2)) 88 | state = torch.cat([global_state, state], dim=1) 89 | x = self.prediction(state) 90 | 91 | return x 92 | -------------------------------------------------------------------------------- 
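The Snake head above is length-preserving along the contour: each circular conv wraps the point sequence before a 1-D convolution, so P input vertices yield P per-vertex offsets. A minimal shape check follows; the dims (state_dim=128, feature_dim=66) mirror how evolve.py instantiates the head and are assumptions here:

    # Sketch: run the circular-conv Snake head on a batch of contours.
    import torch
    from lib.networks.rcnn_snake.snake import Snake

    snake = Snake(state_dim=128, feature_dim=66, conv_type='dgrid')
    x = torch.randn(8, 66, 40)    # [n_contours, feature_dim, n_points]
    offsets = snake(x, adj=None)  # adj is accepted but unused by the conv blocks
    assert offsets.shape == (8, 2, 40)  # one (dx, dy) offset per contour point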
/lib/networks/snake/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.utils.snake import snake_config 2 | from .ct_snake import get_network as get_ro 3 | 4 | 5 | _network_factory = { 6 | 'ro': get_ro 7 | } 8 | 9 | 10 | def get_network(cfg): 11 | arch = cfg.network 12 | heads = cfg.heads 13 | head_conv = cfg.head_conv 14 | num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0 #34 15 | arch = arch[:arch.find('_')] if '_' in arch else arch 16 | get_model = _network_factory[arch] 17 | network = get_model(num_layers, heads, head_conv, snake_config.down_ratio, cfg.det_dir) 18 | return network -------------------------------------------------------------------------------- /lib/networks/snake/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/snake/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/snake/__pycache__/ct_snake.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/snake/__pycache__/ct_snake.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/snake/__pycache__/dla.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/snake/__pycache__/dla.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/snake/__pycache__/evolve.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/snake/__pycache__/evolve.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/snake/__pycache__/snake.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/snake/__pycache__/snake.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/snake/__pycache__/unethead.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu19351118/PDS-pytorch/fc25d0ce48bc4354a97aeaffb91efb560e83db02/lib/networks/snake/__pycache__/unethead.cpython-37.pyc -------------------------------------------------------------------------------- /lib/networks/snake/ct_snake.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .dla import DLASeg 3 | from .evolve import Evolution 4 | from lib.utils import net_utils, data_utils 5 | from lib.utils.snake import snake_decode 6 | import torch 7 | from lib.config import cfg 8 | from .unethead import UNet 9 | import time 10 | import sys 11 | 12 | class Network(nn.Module): 13 | def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 14 | super(Network, self).__init__() 15 | 16 | self.dla = 
DLASeg('dla{}'.format(num_layers), heads, 17 | pretrained=True, 18 | down_ratio=down_ratio, 19 | final_kernel=1, 20 | last_level=5, 21 | head_conv=head_conv) 22 | self.gcn = Evolution() 23 | self.sum =torch.zeros((1,100,6)) 24 | self.counter = 0 25 | 26 | def decode_detection(self, output, h, w): 27 | ct_hm = output['ct_hm'] 28 | wh = output['wh'] 29 | ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh) 30 | detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) 31 | output.update({'ct': ct, 'detection': detection}) 32 | return ct, detection 33 | 34 | def use_gt_detection(self, output, batch): 35 | _, _, height, width = output['ct_hm'].size() 36 | ct_01 = batch['ct_01'].byte() 37 | 38 | ct_ind = batch['ct_ind'][ct_01] 39 | 40 | xs, ys = ct_ind % width, ct_ind // width 41 | xs, ys = xs[:, None].float(), ys[:, None].float() 42 | ct = torch.cat([xs, ys], dim=1) 43 | 44 | wh = batch['wh'][ct_01] 45 | bboxes = torch.cat([xs - wh[..., 0:1] / 2, 46 | ys - wh[..., 1:2] / 2, 47 | xs + wh[..., 0:1] / 2, 48 | ys + wh[..., 1:2] / 2], dim=1) 49 | score = torch.ones([len(bboxes)]).to(bboxes)[:, None] 50 | ct_cls = batch['ct_cls'][ct_01].float()[:, None] 51 | detection = torch.cat([bboxes, score, ct_cls], dim=1) 52 | 53 | output['ct'] = ct[None] 54 | output['detection'] = detection[None] 55 | return output 56 | 57 | def forward(self, x, batch=None): 58 | output, cnn_feature = self.dla(x) 59 | #cnn feature [batch_size/num_gpus, 64, 128, 128] 60 | #out put ['ct_hm', 'wh'] [-, 2, 128, 128] [-, 2, 128, 128] 61 | with torch.no_grad(): 62 | ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3)) 63 | 64 | if cfg.use_gt_det: 65 | self.use_gt_detection(output, batch) 66 | 67 | 68 | output = self.gcn(output, cnn_feature, batch) 69 | #print(output['wh'].shape) #['ct_hm', 'wh', 'ct', 'detection', 'it_ex', 'ex', 'it_py', 'py'] 70 | return output 71 | 72 | 73 | def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 74 | network = Network(num_layers, heads, head_conv, down_ratio, det_dir) 75 | return network 76 | -------------------------------------------------------------------------------- /lib/networks/snake/ct_snake_.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .dla import DLASeg 3 | from .evolve import Evolution 4 | from lib.utils import net_utils, data_utils 5 | from lib.utils.snake import snake_decode 6 | import torch 7 | import torch.nn.functional as F 8 | from lib.config import cfg 9 | from .unethead import UNet 10 | import time 11 | from torchvision.utils import save_image 12 | import sys 13 | 14 | class Network(nn.Module): 15 | def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 16 | super(Network, self).__init__() 17 | 18 | self.dla = DLASeg('dla{}'.format(num_layers), heads, 19 | pretrained=True, 20 | down_ratio=down_ratio, 21 | final_kernel=1, 22 | last_level=5, 23 | head_conv=head_conv) 24 | self.gcn = Evolution() 25 | self.unet = UNet() 26 | pretrain_model_dict = torch.load(cfg.unet_pretrain_model) 27 | self.unet.load_state_dict(pretrain_model_dict) 28 | 29 | def decode_detection(self, output, h, w): 30 | ct_hm = output['ct_hm'] 31 | wh = output['wh'] 32 | ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh) 33 | detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) 34 | output.update({'ct': ct, 'detection': detection}) 35 | return ct, detection 36 | 37 | def use_gt_detection(self, 
output, batch): 38 | _, _, height, width = output['ct_hm'].size() 39 | ct_01 = batch['ct_01'].byte() 40 | 41 | ct_ind = batch['ct_ind'][ct_01] 42 | 43 | xs, ys = ct_ind % width, ct_ind // width 44 | xs, ys = xs[:, None].float(), ys[:, None].float() 45 | ct = torch.cat([xs, ys], dim=1) 46 | 47 | wh = batch['wh'][ct_01] 48 | bboxes = torch.cat([xs - wh[..., 0:1] / 2, 49 | ys - wh[..., 1:2] / 2, 50 | xs + wh[..., 0:1] / 2, 51 | ys + wh[..., 1:2] / 2], dim=1) 52 | score = torch.ones([len(bboxes)]).to(bboxes)[:, None] 53 | ct_cls = batch['ct_cls'][ct_01].float()[:, None] 54 | detection = torch.cat([bboxes, score, ct_cls], dim=1) 55 | 56 | output['ct'] = ct[None] 57 | output['detection'] = detection[None] 58 | return output 59 | 60 | def forward(self, x, batch=None): 61 | x_input_unet = F.interpolate(x, scale_factor=0.25,mode='nearest') 62 | output, cnn_feature = self.dla(x) 63 | unet_mapfeature = self.unet(x_input_unet) 64 | zeros = torch.zeros_like(unet_mapfeature) 65 | unet_mapfeature = torch.where(unet_mapfeature < 0, zeros, unet_mapfeature) 66 | for i in range(unet_mapfeature.shape[0]): 67 | for j in range(unet_mapfeature.shape[1]): 68 | min1 = unet_mapfeature[i][j].min() 69 | print(i,j,unet_mapfeature[i][j].max(),min1, unet_mapfeature[i][j].mean()) 70 | unet_mapfeature[i][j] = 255*torch.div((unet_mapfeature[i][j]-min1),(unet_mapfeature[i][j].max()-min1)) 71 | 72 | with torch.no_grad(): 73 | ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3)) 74 | 75 | if cfg.use_gt_det: 76 | self.use_gt_detection(output, batch) 77 | cnn_feature = torch.cat([cnn_feature, unet_mapfeature], dim=1) 78 | output = self.gcn(output, cnn_feature, batch) 79 | return output 80 | 81 | 82 | def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 83 | network = Network(num_layers, heads, head_conv, down_ratio, det_dir) 84 | return network 85 | -------------------------------------------------------------------------------- /lib/networks/snake/ct_snake单独训练unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .dla import DLASeg 3 | from .evolve import Evolution 4 | from lib.utils import net_utils, data_utils 5 | from lib.utils.snake import snake_decode 6 | import torch 7 | import torch.nn.functional as F 8 | from lib.config import cfg 9 | from .unethead import UNet 10 | import time 11 | from torchvision.utils import save_image 12 | import sys 13 | 14 | class Network(nn.Module): 15 | def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 16 | super(Network, self).__init__() 17 | 18 | self.dla = DLASeg('dla{}'.format(num_layers), heads, 19 | pretrained=True, 20 | down_ratio=down_ratio, 21 | final_kernel=1, 22 | last_level=5, 23 | head_conv=head_conv) 24 | self.gcn = Evolution() 25 | self.unet = UNet() 26 | 27 | def decode_detection(self, output, h, w): 28 | ct_hm = output['ct_hm'] 29 | wh = output['wh'] 30 | ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh) 31 | detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) 32 | output.update({'ct': ct, 'detection': detection}) 33 | return ct, detection 34 | 35 | def use_gt_detection(self, output, batch): 36 | _, _, height, width = output['ct_hm'].size() 37 | ct_01 = batch['ct_01'].byte() 38 | 39 | ct_ind = batch['ct_ind'][ct_01] 40 | 41 | xs, ys = ct_ind % width, ct_ind // width 42 | xs, ys = xs[:, None].float(), ys[:, None].float() 43 | ct = torch.cat([xs, ys], dim=1) 44 | 45 | wh = 
batch['wh'][ct_01] 46 | bboxes = torch.cat([xs - wh[..., 0:1] / 2, 47 | ys - wh[..., 1:2] / 2, 48 | xs + wh[..., 0:1] / 2, 49 | ys + wh[..., 1:2] / 2], dim=1) 50 | score = torch.ones([len(bboxes)]).to(bboxes)[:, None] 51 | ct_cls = batch['ct_cls'][ct_01].float()[:, None] 52 | detection = torch.cat([bboxes, score, ct_cls], dim=1) 53 | 54 | output['ct'] = ct[None] 55 | output['detection'] = detection[None] 56 | return output 57 | 58 | def forward(self, x, unet_input, batch=None): 59 | #x:[batch_size/num_gpus, 3, 512, 512] 60 | unet_input = unet_input.to(torch.float32) 61 | unet_input = unet_input.transpose(2,3) 62 | unet_input = unet_input.transpose(1,2) 63 | x_input_unet = F.interpolate(unet_input, scale_factor=0.25,mode='nearest') 64 | save_image(x_input_unet/255, "/data/tzx/AADebug_img/x_input_unet.jpg") 65 | 66 | output, cnn_feature = self.dla(x) 67 | 68 | mapE, mapA, mapB = self.unet(x_input_unet) 69 | 70 | unet_mapfeature = torch.cat([mapE, mapA, mapB ], dim=1) 71 | zeros = torch.zeros_like(unet_mapfeature) 72 | unet_mapfeature = torch.where(unet_mapfeature < 0, zeros, unet_mapfeature) 73 | unet_mapfeature = 255*(unet_mapfeature-unet_mapfeature.min())/(unet_mapfeature.max()-unet_mapfeature.min()) 74 | save_image(mapA/255,"/data/tzx/AADebug_img/mapA.jpg") 75 | sys.exit(0) 76 | 77 | 78 | 79 | with torch.no_grad(): 80 | ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3)) 81 | 82 | if cfg.use_gt_det: 83 | self.use_gt_detection(output, batch) 84 | cnn_feature = torch.cat([cnn_feature, unet_mapfeature], dim=1) 85 | output = self.gcn(output, cnn_feature, batch) 86 | return output, mapE, mapA, mapB 87 | 88 | 89 | def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 90 | network = Network(num_layers, heads, head_conv, down_ratio, det_dir) 91 | return network 92 | -------------------------------------------------------------------------------- /lib/networks/snake/ct_snake用预训练unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .dla import DLASeg 3 | from .evolve import Evolution 4 | from lib.utils import net_utils, data_utils 5 | from lib.utils.snake import snake_decode 6 | import torch 7 | import torch.nn.functional as F 8 | from lib.config import cfg 9 | from .unethead import UNet 10 | import time 11 | from torchvision.utils import save_image 12 | import sys 13 | 14 | class Network(nn.Module): 15 | def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 16 | super(Network, self).__init__() 17 | 18 | self.dla = DLASeg('dla{}'.format(num_layers), heads, 19 | pretrained=True, 20 | down_ratio=down_ratio, 21 | final_kernel=1, 22 | last_level=5, 23 | head_conv=head_conv) 24 | self.gcn = Evolution() 25 | self.unet = UNet() 26 | pretrain_model_dict = torch.load(cfg.unet_pretrain_model) 27 | self.unet.load_state_dict(pretrain_model_dict) 28 | 29 | def decode_detection(self, output, h, w): 30 | ct_hm = output['ct_hm'] 31 | wh = output['wh'] 32 | ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh) 33 | detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) 34 | output.update({'ct': ct, 'detection': detection}) 35 | return ct, detection 36 | 37 | def use_gt_detection(self, output, batch): 38 | _, _, height, width = output['ct_hm'].size() 39 | ct_01 = batch['ct_01'].byte() 40 | 41 | ct_ind = batch['ct_ind'][ct_01] 42 | 43 | xs, ys = ct_ind % width, ct_ind // width 44 | xs, ys = xs[:, None].float(), ys[:, None].float() 45 
| ct = torch.cat([xs, ys], dim=1) 46 | 47 | wh = batch['wh'][ct_01] 48 | bboxes = torch.cat([xs - wh[..., 0:1] / 2, 49 | ys - wh[..., 1:2] / 2, 50 | xs + wh[..., 0:1] / 2, 51 | ys + wh[..., 1:2] / 2], dim=1) 52 | score = torch.ones([len(bboxes)]).to(bboxes)[:, None] 53 | ct_cls = batch['ct_cls'][ct_01].float()[:, None] 54 | detection = torch.cat([bboxes, score, ct_cls], dim=1) 55 | 56 | output['ct'] = ct[None] 57 | output['detection'] = detection[None] 58 | return output 59 | 60 | def forward(self, x, unet_input, batch=None): 61 | #x:[batch_size/num_gpus, 3, 512, 512] 62 | unet_input = unet_input.to(torch.float32) 63 | unet_input = unet_input.transpose(2,3) 64 | unet_input = unet_input.transpose(1,2) 65 | #save_image(unet_input[0]/255,'/data/tzx/snake_envo_num/visual_result/debug/unetinput.png') 66 | x_input_unet = F.interpolate(unet_input, scale_factor=0.25,mode='nearest') 67 | output, cnn_feature = self.dla(x) 68 | 69 | unet_mapfeature = self.unet(x_input_unet) 70 | #print(unet_mapfeature,unet_mapfeature.max(),unet_mapfeature.min(), unet_mapfeature.mean()) 71 | #sys.exit(0) 72 | zeros = torch.zeros_like(unet_mapfeature) 73 | unet_mapfeature = torch.where(unet_mapfeature < 0, zeros, unet_mapfeature) 74 | unet_mapfeature = 255*(unet_mapfeature-unet_mapfeature.min())/(unet_mapfeature.max()-unet_mapfeature.min()) 75 | 76 | #print(unet_mapfeature.shape) 77 | 78 | #save_image(x[0],'/data/tzx/snake_envo_num/visual_result/debug/x.png') 79 | #save_image(unet_mapfeature[0],'/data/tzx/snake_envo_num/visual_result/debug/mape_visual.png') 80 | #sys.exit(0) 81 | #print(cnn_feature.shape) 82 | 83 | #cnn feature [batch_size/num_gpus, 64, 128, 128] 84 | #out put ['ct_hm', 'wh'] [-, 2, 128, 128] [-, 2, 128, 128] 85 | 86 | 87 | 88 | with torch.no_grad(): 89 | ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3)) 90 | 91 | if cfg.use_gt_det: 92 | self.use_gt_detection(output, batch) 93 | #print(output['ct_hm'].shape,output['wh'].shape,output['ct'].shape,output['detection'].shape) 94 | cnn_feature = torch.cat([cnn_feature, unet_mapfeature], dim=1) 95 | output = self.gcn(output, cnn_feature, batch) 96 | #print(output['wh'].shape) #['ct_hm', 'wh', 'ct', 'detection', 'it_ex', 'ex', 'it_py', 'py'] 97 | 98 | return output 99 | 100 | 101 | def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''): 102 | network = Network(num_layers, heads, head_conv, down_ratio, det_dir) 103 | return network 104 | -------------------------------------------------------------------------------- /lib/networks/snake/snake.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import sys 4 | 5 | class CircConv(nn.Module): 6 | def __init__(self, state_dim, out_state_dim=None, n_adj=4): 7 | super(CircConv, self).__init__() 8 | 9 | self.n_adj = n_adj 10 | out_state_dim = state_dim if out_state_dim is None else out_state_dim 11 | self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1) 12 | 13 | def forward(self, input, adj): 14 | input = torch.cat([input[..., -self.n_adj:], input, input[..., :self.n_adj]], dim=2) 15 | return self.fc(input) 16 | 17 | 18 | class DilatedCircConv(nn.Module): 19 | def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1): 20 | super(DilatedCircConv, self).__init__() 21 | 22 | self.n_adj = n_adj 23 | self.dilation = dilation 24 | out_state_dim = state_dim if out_state_dim is None else out_state_dim 25 | self.fc = nn.Conv1d(state_dim, out_state_dim, 
kernel_size=self.n_adj*2+1, dilation=self.dilation) 26 | 27 | def forward(self, input, adj): 28 | if self.n_adj != 0: 29 | input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2) 30 | return self.fc(input) 31 | 32 | 33 | _conv_factory = { 34 | 'grid': CircConv, 35 | 'dgrid': DilatedCircConv 36 | } 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1): 41 | super(BasicBlock, self).__init__() 42 | 43 | self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.norm = nn.BatchNorm1d(out_state_dim) 46 | 47 | def forward(self, x, adj=None): 48 | x = self.conv(x, adj) 49 | x = self.relu(x) 50 | x = self.norm(x) 51 | return x 52 | 53 | 54 | class Snake(nn.Module): 55 | def __init__(self, state_dim, feature_dim, conv_type='dgrid'): 56 | super(Snake, self).__init__() 57 | 58 | self.head = BasicBlock(feature_dim, state_dim, conv_type) 59 | 60 | self.res_layer_num = 7 61 | dilation = [1, 1, 1, 2, 2, 4, 4] 62 | for i in range(self.res_layer_num): 63 | conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i]) 64 | self.__setattr__('res'+str(i), conv) 65 | 66 | fusion_state_dim = 256 67 | self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1) 68 | self.prediction = nn.Sequential( 69 | nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1), 70 | nn.ReLU(inplace=True), 71 | nn.Conv1d(256, 64, 1), 72 | nn.ReLU(inplace=True), 73 | nn.Conv1d(64, 2, 1) 74 | ) 75 | 76 | def forward(self, x, adj): 77 | states = [] 78 | 79 | x = self.head(x, adj) # [16, 128, 40] 80 | states.append(x) 81 | for i in range(self.res_layer_num): 82 | x = self.__getattr__('res'+str(i))(x, adj) + x 83 | states.append(x) 84 | 85 | state = torch.cat(states, dim=1) 86 | 87 | global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0] 88 | 89 | global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2)) 90 | state = torch.cat([global_state, state], dim=1) 91 | x = self.prediction(state) 92 | 93 | return x 94 | -------------------------------------------------------------------------------- /lib/networks/snake/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from fightingcv_attention.attention.SelfAttention import ScaledDotProductAttention 5 | 6 | 7 | class DoubleConv(nn.Module): 8 | 9 | def __init__(self, in_channels, out_channels, mid_channels=None): 10 | super().__init__() 11 | if not mid_channels: 12 | mid_channels = out_channels 13 | self.double_conv = nn.Sequential( 14 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 15 | nn.BatchNorm2d(mid_channels), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 18 | nn.BatchNorm2d(out_channels), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | def forward(self, x): 23 | return self.double_conv(x) 24 | 25 | 26 | class Down(nn.Module): 27 | 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.maxpool_conv = nn.Sequential( 31 | nn.MaxPool2d(2), 32 | DoubleConv(in_channels, out_channels) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.maxpool_conv(x) 37 | 38 | 39 | class Up(nn.Module): 40 | def __init__(self, in_channels, out_channels, 
bilinear=True):
41 |         super().__init__()
42 | 
43 |         # if bilinear, use the normal convolutions to reduce the number of channels
44 |         if bilinear:
45 |             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
46 |             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
47 |         else:
48 |             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
49 |             self.conv = DoubleConv(in_channels, out_channels)
50 | 
51 |     def forward(self, x1, x2):
52 |         x1 = self.up(x1)
53 |         # input is CHW
54 |         diffY = x2.size()[2] - x1.size()[2]
55 |         diffX = x2.size()[3] - x1.size()[3]
56 | 
57 |         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
58 |                         diffY // 2, diffY - diffY // 2])
59 |         # if you have padding issues, see
60 |         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
61 |         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
62 |         x = torch.cat([x2, x1], dim=1)
63 |         return self.conv(x)
64 | 
65 | 
66 | class OutConv(nn.Module):
67 |     def __init__(self, in_channels, out_channels):
68 |         super(OutConv, self).__init__()
69 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
70 | 
71 |     def forward(self, x):
72 |         return self.conv(x)
73 | 
74 | 
75 | 
76 | class CrossAttentionUNet(nn.Module):
77 |     def __init__(self, n_channels, n_classes, bilinear=False):
78 |         super(CrossAttentionUNet, self).__init__()
79 |         self.n_channels = n_channels
80 |         self.n_classes = n_classes
81 |         self.bilinear = bilinear
82 | 
83 |         self.inc = (DoubleConv(n_channels, 64))
84 |         self.down1 = (Down(64, 128))
85 |         self.down2 = (Down(128, 256))
86 |         self.down3 = (Down(256, 512))
87 |         self.attention_block = ScaledDotProductAttention(d_model=256, d_k=256, d_v=256, h=8)
88 |         factor = 2 if bilinear else 1
89 |         #self.down4 = (Down(512, 1024 // factor))
90 |         #self.up1 = (Up(1024, 512 // factor, bilinear))
91 |         self.up2 = (Up(512, 256 // factor, bilinear))
92 |         self.up3 = (Up(256, 128 // factor, bilinear))
93 |         self.up4 = (Up(128, 64, bilinear))
94 |         self.outc = (OutConv(64, n_classes))
95 | 
96 |     def forward(self, xpre, xmid, xlate):
97 |         # features from the first UNet block
98 |         xmid1 = self.inc(xmid)
99 |         xpre1 = self.inc(xpre)
100 |         xlate1 = self.inc(xlate)
101 |         # result of the first downsampling stage
102 |         xmid2 = self.down1(xmid1)
103 |         xpre2 = self.down1(xpre1)
104 |         xlate2 = self.down1(xlate1)
105 |         # second stage
106 |         xmid3 = self.down2(xmid2)
107 |         xpre3 = self.down2(xpre2)
108 |         xlate3 = self.down2(xlate2)
109 |         # third stage
110 |         xmid4 = self.down3(xmid3)
111 |         xpre4 = self.down3(xpre3)
112 |         xlate4 = self.down3(xlate3)
113 |         # reshape to the (batch, tokens, features) layout the attention block expects
114 |         bs,ntoken,_,_ = xmid4.shape
115 |         xmid4=xmid4.view(bs,ntoken,-1)
116 |         xpre4=xpre4.view(bs,ntoken,-1)
117 |         xlate4=xlate4.view(bs,ntoken,-1)
118 |         xmid4 = self.attention_block(xmid4,xpre4,xlate4)
119 |         xmid4 = xmid4.view(bs,ntoken,16,16)  # assumes 128x128 inputs, so the bottleneck is 16x16
120 |         # first upsampling
121 |         xmid = self.up2(xmid4, xmid3)
122 | 
123 |         # second upsampling
124 |         xmid = self.up3(xmid, xmid2)
125 | 
126 |         # third upsampling
127 |         xmid = self.up4(xmid, xmid1)
128 |         return xmid
129 | 
130 | 
131 | 
132 | if __name__=="__main__":
133 |     a=torch.zeros((32,64,128,128))
134 |     net = CrossAttentionUNet(64,2)
135 |     print(net(a,a,a).shape)
--------------------------------------------------------------------------------
/lib/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainers import make_trainer
2 | from .optimizer import make_optimizer
3 | from .scheduler import make_lr_scheduler, set_lr_scheduler
4 | from .recorder import make_recorder
5 | 
6 | 
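`lib/train/__init__.py` above re-exports the four factory helpers that make up the training surface of this repo. A sketch of how they are typically wired together (assumptions: `make_network` is exported by `lib/networks/make_network.py`, `cfg.train.epoch` exists in the config, and `train_loader` stands in for your own `DataLoader`):

```python
from lib.config import cfg
from lib.networks import make_network  # assumed export of lib/networks/make_network.py
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder

network = make_network(cfg)
trainer = make_trainer(cfg, network)           # wraps the net with its task-specific loss
optimizer = make_optimizer(cfg, network)
scheduler = make_lr_scheduler(cfg, optimizer)
recorder = make_recorder(cfg)                  # tensorboardX writer under cfg.record_dir

for epoch in range(cfg.train.epoch):           # assumed config field
    recorder.epoch = epoch
    trainer.train(epoch, train_loader, optimizer, recorder)  # train_loader: your DataLoader
    scheduler.step()
```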
--------------------------------------------------------------------------------
/lib/train/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lib.utils.optimizer.radam import RAdam
3 | 
4 | 
5 | _optimizer_factory = {
6 |     'adam': torch.optim.Adam,
7 |     'radam': RAdam,
8 |     'sgd': torch.optim.SGD
9 | }
10 | 
11 | 
12 | def make_optimizer(cfg, net):
13 |     params = []
14 |     lr = cfg.train.lr
15 |     weight_decay = cfg.train.weight_decay
16 | 
17 |     for key, value in net.named_parameters():
18 |         if not value.requires_grad:
19 |             continue
20 |         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
21 | 
22 |     if 'adam' in 
cfg.train.optim: 23 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, weight_decay=weight_decay) 24 | else: 25 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, momentum=0.9) 26 | 27 | return optimizer 28 | -------------------------------------------------------------------------------- /lib/train/recorder.py: -------------------------------------------------------------------------------- 1 | from collections import deque, defaultdict 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | import os 5 | 6 | 7 | class SmoothedValue(object): 8 | """Track a series of values and provide access to smoothed values over a 9 | window or the global series average. 10 | """ 11 | 12 | def __init__(self, window_size=20): 13 | self.deque = deque(maxlen=window_size) 14 | self.total = 0.0 15 | self.count = 0 16 | 17 | def update(self, value): 18 | self.deque.append(value) 19 | self.count += 1 20 | self.total += value 21 | 22 | @property 23 | def median(self): 24 | d = torch.tensor(list(self.deque)) 25 | return d.median().item() 26 | 27 | @property 28 | def avg(self): 29 | d = torch.tensor(list(self.deque)) 30 | return d.mean().item() 31 | 32 | @property 33 | def global_avg(self): 34 | return self.total / self.count 35 | 36 | 37 | class Recorder(object): 38 | def __init__(self, cfg): 39 | log_dir = cfg.record_dir 40 | if not cfg.resume: 41 | os.system('rm -rf {}'.format(log_dir)) 42 | self.writer = SummaryWriter(log_dir=log_dir) 43 | 44 | # scalars 45 | self.epoch = 0 46 | self.step = 0 47 | self.loss_stats = defaultdict(SmoothedValue) 48 | self.batch_time = SmoothedValue() 49 | self.data_time = SmoothedValue() 50 | 51 | # images 52 | self.image_stats = defaultdict(object) 53 | if 'process_'+cfg.task in globals(): 54 | self.processor = globals()['process_'+cfg.task] 55 | else: 56 | self.processor = None 57 | 58 | def update_loss_stats(self, loss_dict): 59 | for k, v in loss_dict.items(): 60 | self.loss_stats[k].update(v.detach().cpu()) 61 | 62 | def update_image_stats(self, image_stats): 63 | if self.processor is None: 64 | return 65 | image_stats = self.processor(image_stats) 66 | for k, v in image_stats.items(): 67 | self.image_stats[k] = v.detach().cpu() 68 | 69 | def record(self, prefix, step=-1, loss_stats=None, image_stats=None): 70 | pattern = prefix + '/{}' 71 | step = step if step >= 0 else self.step 72 | loss_stats = loss_stats if loss_stats else self.loss_stats 73 | 74 | for k, v in loss_stats.items(): 75 | if isinstance(v, SmoothedValue): 76 | self.writer.add_scalar(pattern.format(k), v.median, step) 77 | else: 78 | self.writer.add_scalar(pattern.format(k), v, step) 79 | 80 | if self.processor is None: 81 | return 82 | image_stats = self.processor(image_stats) if image_stats else self.image_stats 83 | for k, v in image_stats.items(): 84 | self.writer.add_image(pattern.format(k), v, step) 85 | 86 | def state_dict(self): 87 | scalar_dict = {} 88 | scalar_dict['step'] = self.step 89 | return scalar_dict 90 | 91 | def load_state_dict(self, scalar_dict): 92 | self.step = scalar_dict['step'] 93 | 94 | def __str__(self): 95 | loss_state = [] 96 | for k, v in self.loss_stats.items(): 97 | loss_state.append('{}: {:.4f}'.format(k, v.avg)) 98 | loss_state = ' '.join(loss_state) 99 | 100 | recording_state = ' '.join(['epoch: {}', 'step: {}', '{}', 'data: {:.4f}', 'batch: {:.4f}']) 101 | return recording_state.format(self.epoch, self.step, loss_state, self.data_time.avg, self.batch_time.avg) 102 | 103 | 104 | def make_recorder(cfg): 105 | return Recorder(cfg) 106 | 107 | 
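`Recorder` above smooths every scalar over a 20-step window before it reaches TensorBoard, via `SmoothedValue`. A quick illustration of the smoothing behaviour (hypothetical usage, not part of the repo):

```python
v = SmoothedValue(window_size=3)
for x in [1.0, 2.0, 10.0]:
    v.update(x)
print(v.median)      # 2.0  -- window median, robust to the outlier
print(v.avg)         # 4.33 -- window mean
print(v.global_avg)  # 4.33 -- running mean over every value seen so far
```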
--------------------------------------------------------------------------------
/lib/train/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import MultiStepLR
2 | from collections import Counter
3 | from lib.utils.optimizer.lr_scheduler import WarmupMultiStepLR, ManualStepLR
4 | 
5 | 
6 | def make_lr_scheduler(cfg, optimizer):
7 |     if cfg.train.warmup:
8 |         scheduler = WarmupMultiStepLR(optimizer, cfg.train.milestones, cfg.train.gamma, 1.0/3, 5, 'linear')
9 |     elif cfg.train.scheduler == 'manual':
10 |         scheduler = ManualStepLR(optimizer, milestones=cfg.train.milestones, gammas=cfg.train.gammas)
11 |     else:
12 |         scheduler = MultiStepLR(optimizer, milestones=cfg.train.milestones, gamma=cfg.train.gamma)
13 |     return scheduler
14 | 
15 | 
16 | def set_lr_scheduler(cfg, scheduler):
17 |     if cfg.train.warmup:
18 |         scheduler.milestones = cfg.train.milestones
19 |     else:
20 |         scheduler.milestones = Counter(cfg.train.milestones)
21 |     scheduler.gamma = cfg.train.gamma
22 | 
--------------------------------------------------------------------------------
/lib/train/trainers/PolyProcess.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from PIL import Image, ImageDraw, ImageMath
5 | import sys
6 | 
7 | def draw_poly(poly,values,im_shape,brush_size):
8 |     """ Returns a MxN (im_shape) array with values in the pixels crossed
9 |     by the edges of the polygon (poly). total_points is the maximum number
10 |     of pixels used for the linear interpolation.
11 |     """
12 |     u = poly[:,0]
13 |     v = poly[:,1]
14 |     b = np.round(brush_size/2)
15 |     image = Image.fromarray(np.zeros(im_shape))
16 |     image2 = Image.fromarray(np.zeros(im_shape))
17 |     d = ImageDraw.Draw(image)
18 |     if type(values) is int:
19 |         values = np.ones(np.shape(u)) * values  # all-ones array, then scaled by values
20 |     for n in range(len(poly)):
21 |         d.ellipse([(v[n]-b,u[n]-b),(v[n]+b,u[n]+b)], fill=values[n])  # draws a small ellipse (brush disc) at each point
22 |     image2 = ImageMath.eval("convert(max(a, b), 'F')", a=image, b=image2)
23 |     return torch.from_numpy(np.array(image2))  # pixels on the snake and their 4-neighborhoods take the value 5, everywhere else 0
24 | 
25 | def draw_poly_fill(poly,im_shape,values=1):
26 |     """Returns a MxN (im_shape) array with 1s in the interior of the polygon
27 |     defined by (poly) and 0s outside."""
28 |     u = poly[:, 0]
29 |     v = poly[:, 1]
30 |     image = Image.fromarray(np.zeros(im_shape))
31 |     d = ImageDraw.Draw(image)
32 |     if not values == 1:
33 |         if type(values) is int:
34 |             values = np.ones(np.shape(u)) * values
35 |     d.polygon(np.column_stack((v, u)).reshape(-1).tolist(), fill=values, outline=1)
36 |     return np.array(image)
37 | 
38 | def batch_mask_convert(contours, im_shape):
39 |     '''
40 |     Returns masks in (imH, imW, batchno), 0-1 binary, PyTorch Tensor
41 |     '''
42 |     batch_mask = np.zeros([contours.shape[0], im_shape[0], im_shape[1]])
43 | 
44 |     for i in range(contours.shape[0]):
45 |         batch_mask[i,:,:] = draw_poly_fill(contours[i,:,:].detach().cpu().numpy(), im_shape, values=1)
46 | 
47 |     return torch.from_numpy(batch_mask)
48 | 
49 | 
50 | def GTpoly(poly, im_shape, brush_size, GTmask):
51 |     GTmask = GTmask.cpu().numpy()
52 |     # GTmask is a float in [0, 1]; np.round turns it into an exact mask
53 |     #ret, GTmask = cv2.threshold(GTmask, 80, 255, cv2.THRESH_BINARY)
54 |     side = cv2.Canny(GTmask.astype(np.uint8), 200, 255)
55 |     mask_contour, thresh = cv2.findContours(side, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
56 |     poly = poly.cpu().numpy()
57 | 
58 |     imsize = im_shape[0]
59 | 
60 |     u = poly[:,0]
61 |     v = poly[:,1]
62 |     b = np.round(brush_size/2)
63 |     image = Image.fromarray(np.zeros(im_shape))
64 |     image2 = Image.fromarray(np.zeros(im_shape))
65 |     d = ImageDraw.Draw(image)
66 | 
67 |     values = np.ones(np.shape(u))
68 |     for i in range(v.shape[0]):
69 |         dist = cv2.pointPolygonTest(mask_contour[0],[v[i],u[i]],True)
70 |         if dist==0:  # on the boundary
71 |             values[i] = 0
72 |         elif abs(dist) > 5:  # shifted far inward/outward
73 |             values[i] = abs(dist/imsize * 10)
74 |         else:
75 |             values[i] = abs(dist/imsize * 3)
76 | 
77 |     for n in range(len(poly)):
78 |         d.ellipse([(v[n]-b,u[n]-b),(v[n]+b,u[n]+b)], fill=values[n])  # the list actually specifies the ellipse's bounding box, so this dots the points along the contour
79 |     image2 = ImageMath.eval("convert(max(a, b), 'F')", a=image, b=image2)
80 |     return torch.from_numpy(np.array(image2))  # pixels on the snake and their 4-neighborhoods carry the assigned values, everywhere else 0
81 | 
82 | def derivatives_poly(poly):
83 |     """
84 |     :param poly: the Lx2 polygon array [u,v]
85 |     :return: der1, der2, length-L arrays of derivative magnitudes
86 |     """
87 | 
88 |     poly = poly.cpu().numpy()
89 |     u = poly[:, 0]
90 |     v = poly[:, 1]
91 |     L = len(u)
92 |     der1_mat = -np.roll(np.eye(L), -1, axis=1) + \
93 |                np.roll(np.eye(L), -1, axis=0)  # first order derivative, central difference
94 |     # the matrix above has 0 on the main diagonal, 1 on the (cyclic) superdiagonal and -1 on the subdiagonal; multiplied with a coordinate vector it yields the first-order central difference
95 | 
96 |     der2_mat = np.roll(np.eye(L), -1, axis=0) + \
97 |                np.roll(np.eye(L), -1, axis=1) - \
98 |                2 * np.eye(L)  # second order derivative, central difference
99 |     # -2 on the main diagonal, 1 on the super- and subdiagonals; multiplied with a coordinate vector it yields the second-order difference
100 |     der1 = np.sqrt(np.power(np.matmul(der1_mat, u), 2) + \
101 |                    np.power(np.matmul(der1_mat, v), 2))  # magnitude of the first-order difference at every snake point (u and v differences squared, summed, square-rooted), length L
102 |     der2 = np.sqrt(np.power(np.matmul(der2_mat, u), 2) + \
103 |                    np.power(np.matmul(der2_mat, v), 2))  # magnitude of the second-order difference at every snake point, length L
104 |     return torch.from_numpy(der1), torch.from_numpy(der2)
--------------------------------------------------------------------------------
/lib/train/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_trainer import make_trainer
2 | 
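`derivatives_poly` in `PolyProcess.py` above builds cyclic first- and second-order difference matrices, so on a closed square every vertex should report the same derivative magnitudes. A quick sanity check (hypothetical usage, not part of the repo):

```python
import torch
poly = torch.tensor([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])  # unit square, Lx2 [u, v]
der1, der2 = derivatives_poly(poly)
print(der1)  # four values of sqrt(2): |p[i+1] - p[i-1]| at each corner
print(der2)  # four values of sqrt(2): |p[i+1] + p[i-1] - 2*p[i]| at each corner
```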
--------------------------------------------------------------------------------
/lib/train/trainers/ct_rcnn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from lib.utils import net_utils
3 | import torch
4 | 
5 | 
6 | class NetworkWrapper(nn.Module):
7 |     def __init__(self, net):
8 |         super(NetworkWrapper, self).__init__()
9 | 
10 |         self.net = net
11 | 
12 |         self.act_crit = net_utils.FocalLoss()
13 |         self.awh_crit = net_utils.IndL1Loss1d('smooth_l1')
14 |         self.cp_crit = net_utils.FocalLoss()
15 |         self.cp_wh_crit = net_utils.IndL1Loss1d('smooth_l1')
16 | 
17 |     def forward(self, batch):
18 |         output = self.net(batch['inp'], batch)
19 | 
20 |         scalar_stats = {}
21 |         loss = 0
22 | 
23 |         act_loss = self.act_crit(net_utils.sigmoid(output['act_hm']), batch['act_hm'])
24 |         scalar_stats.update({'act_loss': act_loss})
25 |         loss += act_loss
26 | 
27 |         awh_loss = self.awh_crit(output['awh'], batch['awh'], batch['act_ind'], batch['act_01'])
28 |         scalar_stats.update({'awh_loss': awh_loss})
29 |         loss += 0.1 * awh_loss
30 | 
31 |         act_01 = batch['act_01'].byte()
32 | 
33 |         cp_loss = self.cp_crit(net_utils.sigmoid(output['cp_hm']), batch['cp_hm'][act_01])
34 |         scalar_stats.update({'cp_loss': cp_loss})
35 |         loss += cp_loss
36 | 
37 |         cp_wh, cp_ind, 
cp_01 = [batch[k][act_01] for k in ['cp_wh', 'cp_ind', 'cp_01']] 38 | cp_wh_loss = self.cp_wh_crit(output['cp_wh'], cp_wh, cp_ind, cp_01) 39 | scalar_stats.update({'cp_wh_loss': cp_wh_loss}) 40 | loss += 0.1 * cp_wh_loss 41 | 42 | scalar_stats.update({'loss': loss}) 43 | image_stats = {} 44 | 45 | return output, loss, scalar_stats, image_stats 46 | 47 | -------------------------------------------------------------------------------- /lib/train/trainers/make_trainer.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | import imp 3 | import os 4 | from lib.config import cfg 5 | 6 | 7 | def _wrapper_factory(cfg, network): 8 | module = '.'.join(['lib.train.trainers', cfg.task]) 9 | if cfg.isrec and cfg.ifmultistage: 10 | path = os.path.join('lib/train/trainers', cfg.task+'rec.py') 11 | else: 12 | path = os.path.join('lib/train/trainers', cfg.task+'.py') 13 | network_wrapper = imp.load_source(module, path).NetworkWrapper(network) 14 | return network_wrapper 15 | 16 | 17 | def make_trainer(cfg, network): 18 | network = _wrapper_factory(cfg, network) 19 | return Trainer(network) 20 | -------------------------------------------------------------------------------- /lib/train/trainers/rcnn_snake.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from lib.utils import net_utils 3 | import torch 4 | 5 | 6 | class NetworkWrapper(nn.Module): 7 | def __init__(self, net): 8 | super(NetworkWrapper, self).__init__() 9 | 10 | self.net = net 11 | 12 | self.act_crit = net_utils.FocalLoss() 13 | self.awh_crit = net_utils.IndL1Loss1d('smooth_l1') 14 | self.cp_crit = net_utils.FocalLoss() 15 | self.cp_wh_crit = net_utils.IndL1Loss1d('smooth_l1') 16 | self.ex_crit = torch.nn.functional.smooth_l1_loss 17 | self.py_crit = torch.nn.functional.smooth_l1_loss 18 | 19 | def forward(self, batch): 20 | output = self.net(batch['inp'], batch) 21 | 22 | scalar_stats = {} 23 | loss = 0 24 | 25 | act_loss = self.act_crit(net_utils.sigmoid(output['act_hm']), batch['act_hm']) 26 | scalar_stats.update({'act_loss': act_loss}) 27 | loss += act_loss 28 | 29 | awh_loss = self.awh_crit(output['awh'], batch['awh'], batch['act_ind'], batch['act_01']) 30 | scalar_stats.update({'awh_loss': awh_loss}) 31 | loss += 0.1 * awh_loss 32 | 33 | act_01 = batch['act_01'].byte() 34 | 35 | cp_loss = self.cp_crit(net_utils.sigmoid(output['cp_hm']), batch['cp_hm'][act_01]) 36 | scalar_stats.update({'cp_loss': cp_loss}) 37 | loss += cp_loss 38 | 39 | cp_wh, cp_ind, cp_01 = [batch[k][act_01] for k in ['cp_wh', 'cp_ind', 'cp_01']] 40 | cp_wh_loss = self.cp_wh_crit(output['cp_wh'], cp_wh, cp_ind, cp_01) 41 | scalar_stats.update({'cp_wh_loss': cp_wh_loss}) 42 | loss += 0.1 * cp_wh_loss 43 | 44 | ex_loss = self.ex_crit(output['ex_pred'], output['i_gt_4py']) 45 | scalar_stats.update({'ex_loss': ex_loss}) 46 | loss += ex_loss 47 | 48 | py_loss = 0 49 | output['py_pred'] = [output['py_pred'][-1]] 50 | for i in range(len(output['py_pred'])): 51 | py_loss += self.py_crit(output['py_pred'][i], output['i_gt_py']) / len(output['py_pred']) 52 | scalar_stats.update({'py_loss': py_loss}) 53 | loss += py_loss 54 | 55 | scalar_stats.update({'loss': loss}) 56 | image_stats = {} 57 | 58 | return output, loss, scalar_stats, image_stats 59 | 60 | -------------------------------------------------------------------------------- /lib/train/trainers/snake.py: -------------------------------------------------------------------------------- 1 | 
import torch.nn as nn
2 | from lib.utils import net_utils
3 | import torch
4 | import sys
5 | import numpy as np
6 | import cv2
7 | from lib.config import cfg
8 | 
9 | import numpy.linalg as LA
10 | 
11 | 
12 | 
13 | 
14 | 
15 | def intermediate_signal(gtpoly):
16 |     decerate = cfg.layer1rate
17 |     circle_rate=cfg.circle_rate
18 |     sup_polys=[]
19 |     for rate in decerate:
20 |         for i in range(gtpoly.shape[0]):
21 |             gtpyiter = gtpoly[i,:,:]  # 128 x 2
22 |             # first compute the center of the poly and tile it to 128 points
23 |             center = torch.mean(gtpyiter,dim=0)
24 |             center_vector = center.repeat(128,1)
25 |             # Euclidean distance from every point to the center, and its mean
26 |             pdist = nn.PairwiseDistance(p=2)
27 |             pdist_result = pdist(gtpyiter,center_vector)
28 |             mean_pdist = torch.mean(pdist_result)
29 |             # unit vectors from the center towards every point, scaled by the mean distance to form a circle
30 |             vector_poly = gtpyiter-center
31 |             pdist_result1 = pdist_result.unsqueeze(dim=1).repeat(1,2)
32 |             vector_poly_circle = mean_pdist*vector_poly/pdist_result1
33 |             # gap between each point's distance to the center and the mean distance
34 |             gap_dist = pdist_result-mean_pdist
35 |             # compute the scaling weights
36 |             percentage1 = 1+ rate*gap_dist/mean_pdist
37 |             percentage1 = percentage1.unsqueeze(dim=1)
38 |             percentage1 = percentage1.repeat(1,2)
39 |             percentage1 = circle_rate*percentage1
40 |             # apply the scaling weights to obtain the simplified polygon
41 |             if i==0:
42 |                 layer1_sup_poly = vector_poly_circle.mul(percentage1)+center_vector
43 |                 layer1_sup_poly = layer1_sup_poly.unsqueeze(dim=0)
44 |             else:
45 |                 layer1_sup_poly_toappend = vector_poly_circle.mul(percentage1)+center_vector
46 |                 layer1_sup_poly_toappend = layer1_sup_poly_toappend.unsqueeze(dim=0)
47 |                 layer1_sup_poly = torch.cat((layer1_sup_poly,layer1_sup_poly_toappend),dim=0)
48 |         sup_polys.append(layer1_sup_poly)
49 |     return sup_polys
50 | 
51 | 
52 | 
53 | 
54 | class NetworkWrapper(nn.Module):
55 |     def __init__(self, net):
56 |         super(NetworkWrapper, self).__init__()
57 |         self.net = net
58 |         self.ct_crit = net_utils.FocalLoss()
59 |         self.wh_crit = net_utils.IndL1Loss1d('smooth_l1')
60 |         self.reg_crit = net_utils.IndL1Loss1d('smooth_l1')
61 |         self.ex_crit = torch.nn.functional.smooth_l1_loss
62 |         self.py_crit = torch.nn.functional.smooth_l1_loss
63 | 
64 | 
65 |     def forward(self, batch):
66 |         output = self.net(batch['inp'], batch)  # to also feed the original image: output = self.net(batch['inp'], batch['orig_img'], batch)
67 |         scalar_stats = {}
68 |         loss = 0
69 | 
70 |         ct_loss = self.ct_crit(net_utils.sigmoid(output['ct_hm']), batch['ct_hm'])
71 |         scalar_stats.update({'ct_loss': ct_loss})
72 |         loss += cfg.ct_weight*ct_loss
73 |         wh_loss = self.wh_crit(output['wh'], batch['wh'], batch['ct_ind'], batch['ct_01'])
74 |         scalar_stats.update({'wh_loss': wh_loss})
75 |         loss += cfg.wh_weight*wh_loss
76 | 
77 |         ex_loss = self.ex_crit(output['ex_pred'], output['i_gt_4py'])
78 |         scalar_stats.update({'ex_loss': ex_loss})
79 |         loss += ex_loss
80 |         py_loss = 0
81 |         # our contribution: supervise the output of every evolution stage with a loss
82 |         sup_polys=intermediate_signal(output['i_gt_py'])
83 | 
84 |         ifmultistage=cfg.ifmultistage
85 |         if ifmultistage:
86 |             sup_signal = []
87 |             for poly in sup_polys:
88 |                 sup_signal.append(poly)
89 |             sup_signal.append(output['i_gt_py'])
90 |             for i in range(len(output['py_pred'])):
91 |                 py_loss += self.py_crit(output['py_pred'][i], sup_signal[i]) / len(output['py_pred'])
92 |         else:
93 |             output['py_pred'] = [output['py_pred'][-1]]
94 |             for i in range(len(output['py_pred'])):
95 |                 py_loss += self.py_crit(output['py_pred'][i], output['i_gt_py']) / len(output['py_pred'])
96 |         scalar_stats.update({'py_loss': py_loss})
97 |         loss += py_loss
98 | 
99 | 
100 |         scalar_stats.update({'loss': loss})
101 |         image_stats = {}
102 | 
103 |         return output, loss, scalar_stats, 
image_stats 104 | 105 | -------------------------------------------------------------------------------- /lib/train/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import torch 4 | import tqdm 5 | import sys 6 | from torch.nn import DataParallel 7 | import time 8 | from torchvision import utils as vutils 9 | from lib.config import cfg 10 | 11 | 12 | class Trainer(object): 13 | def __init__(self, network): 14 | network = network.cuda() 15 | network = DataParallel(network) 16 | self.network = network 17 | 18 | def reduce_loss_stats(self, loss_stats): 19 | reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()} 20 | return reduced_losses 21 | 22 | def to_cuda(self, batch): 23 | for k in batch: 24 | if k == 'meta': 25 | continue 26 | if isinstance(batch[k], tuple): 27 | batch[k] = [b.cuda() for b in batch[k]] 28 | else: 29 | batch[k] = batch[k].cuda() 30 | return batch 31 | 32 | def train(self, epoch, data_loader, optimizer, recorder): 33 | max_iter = len(data_loader) 34 | self.network.train() 35 | end = time.time() 36 | print(cfg.model_dir) 37 | 38 | for iteration, batch in enumerate(data_loader): 39 | data_time = time.time() - end 40 | iteration = iteration + 1 41 | recorder.step += 1 42 | 43 | 44 | # batch = self.to_cuda(batch) 45 | #print(batch['i_gt_py'].shape)#dict_keys(['inp', 'meta', 'ct_hm', 'wh', 'ct_cls', 46 | #'ct_ind', 'ct_01', 'i_it_4py', 'c_it_4py', 'i_gt_4py', 'c_gt_4py', 'i_it_py', 'c_it_py', 'i_gt_py', 'c_gt_py']) 47 | output, loss, loss_stats, image_stats = self.network(batch) 48 | 49 | # training stage: loss; optimizer; scheduler 50 | loss = loss.mean() 51 | optimizer.zero_grad() 52 | loss.backward() 53 | torch.nn.utils.clip_grad_value_(self.network.parameters(), 40) 54 | optimizer.step() 55 | 56 | # data recording stage: loss_stats, time, image_stats 57 | loss_stats = self.reduce_loss_stats(loss_stats) 58 | recorder.update_loss_stats(loss_stats) 59 | 60 | batch_time = time.time() - end 61 | end = time.time() 62 | recorder.batch_time.update(batch_time) 63 | recorder.data_time.update(data_time) 64 | 65 | if iteration % 20 == 0 or iteration == (max_iter - 1): 66 | # print training state 67 | eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration) 68 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 69 | lr = optimizer.param_groups[0]['lr'] 70 | memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 71 | 72 | training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}']) 73 | training_state = training_state.format(eta_string, str(recorder), lr, memory) 74 | print(training_state) 75 | 76 | # record loss_stats and image_dict 77 | recorder.update_image_stats(image_stats) 78 | recorder.record('train') 79 | 80 | def val(self, epoch, data_loader, evaluator=None, recorder=None): 81 | self.network.eval() 82 | torch.cuda.empty_cache() 83 | val_loss_stats = {} 84 | data_size = len(data_loader) 85 | for batch in tqdm.tqdm(data_loader): 86 | for k in batch: 87 | if k != 'meta': 88 | batch[k] = batch[k].cuda() 89 | 90 | with torch.no_grad(): 91 | output, loss, loss_stats, image_stats = self.network(batch) 92 | if evaluator is not None: 93 | evaluator.evaluate(output, batch) 94 | 95 | loss_stats = self.reduce_loss_stats(loss_stats) 96 | for k, v in loss_stats.items(): 97 | val_loss_stats.setdefault(k, 0) 98 | val_loss_stats[k] += v 99 | 100 | loss_state = [] 101 | for k in val_loss_stats.keys(): 102 | val_loss_stats[k] /= data_size 103 | 
loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k]))
104 |         print(loss_state)
105 | 
106 |         if evaluator is not None:
107 |             result = evaluator.summarize()
108 |             val_loss_stats.update(result)
109 | 
110 |         if recorder:
111 |             recorder.record('val', epoch, val_loss_stats, image_stats)
112 | 
113 | 
--------------------------------------------------------------------------------
/lib/utils/base_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import numpy as np
4 | 
5 | 
6 | def read_pickle(pkl_path):
7 |     with open(pkl_path, 'rb') as f:
8 |         return pickle.load(f)
9 | 
10 | 
11 | def save_pickle(data, pkl_path):
12 |     os.system('mkdir -p {}'.format(os.path.dirname(pkl_path)))
13 |     with open(pkl_path, 'wb') as f:
14 |         pickle.dump(data, f)
15 | 
16 | 
17 | def project(xyz, K, RT):
18 |     """
19 |     xyz: [N, 3]
20 |     K: [3, 3]
21 |     RT: [3, 4]
22 |     """
23 |     xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T
24 |     xyz = np.dot(xyz, K.T)
25 |     xy = xyz[:, :2] / xyz[:, 2:]
26 |     return xy
27 | 
--------------------------------------------------------------------------------
/lib/utils/getedge.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import cv2
4 | import os
5 | import numpy as np
6 | from skimage import measure
7 | from matplotlib import pyplot as plt
8 | 
9 | def uniformsample(pgtnp_px2, newpnum):
10 |     print(pgtnp_px2)
11 |     pnum, cnum = pgtnp_px2.shape
12 |     assert cnum == 2
13 | 
14 |     idxnext_p = (np.arange(pnum, dtype=np.int32) + 1) % pnum
15 |     pgtnext_px2 = pgtnp_px2[idxnext_p]
16 |     edgelen_p = np.sqrt(np.sum((pgtnext_px2 - pgtnp_px2) ** 2, axis=1))
17 |     edgeidxsort_p = np.argsort(edgelen_p)
18 |     print(edgeidxsort_p)
19 | 
20 |     # two cases
21 |     # we need to remove gt points
22 |     # we simply remove shortest paths
23 |     if pnum > newpnum:
24 |         edgeidxkeep_k = edgeidxsort_p[pnum - newpnum:]
25 |         edgeidxsort_k = np.sort(edgeidxkeep_k)
26 |         pgtnp_kx2 = pgtnp_px2[edgeidxsort_k]
27 |         assert pgtnp_kx2.shape[0] == newpnum
28 |         return pgtnp_kx2
29 | 
30 |     else:
31 |         edgenum = np.round(edgelen_p * newpnum / np.sum(edgelen_p)).astype(np.int32)
32 |         for i in range(pnum):
33 |             if edgenum[i] == 0:
34 |                 edgenum[i] = 1
35 | 
36 |         # after rounding there may be a mismatch of 1 or 2 points
37 |         edgenumsum = np.sum(edgenum)
38 |         if edgenumsum != newpnum:
39 | 
40 |             if edgenumsum > newpnum:
41 | 
42 |                 id = -1
43 |                 passnum = edgenumsum - newpnum
44 |                 while passnum > 0:
45 |                     edgeid = edgeidxsort_p[id]
46 |                     if edgenum[edgeid] > passnum:
47 |                         edgenum[edgeid] -= passnum
48 |                         passnum -= passnum
49 |                     else:
50 |                         passnum -= edgenum[edgeid] - 1
51 |                         edgenum[edgeid] -= edgenum[edgeid] - 1
52 |                         id -= 1
53 |             else:
54 |                 id = -1
55 |                 edgeid = edgeidxsort_p[id]
56 |                 edgenum[edgeid] += newpnum - edgenumsum
57 | 
58 |         assert np.sum(edgenum) == newpnum
59 | 
60 |         psample = []
61 |         for i in range(pnum):
62 |             pb_1x2 = pgtnp_px2[i:i + 1]
63 |             pe_1x2 = pgtnext_px2[i:i + 1]
64 | 
65 |             pnewnum = edgenum[i]
66 |             wnp_kx1 = np.arange(edgenum[i], dtype=np.float32).reshape(-1, 1) / edgenum[i]
67 | 
68 |             pmids = pb_1x2 * (1 - wnp_kx1) + pe_1x2 * wnp_kx1
69 |             psample.append(pmids)
70 | 
71 |         psamplenp = np.concatenate(psample, axis=0)
72 |         return psamplenp
73 | 
74 | def close_contour(contour):
75 |     if not np.array_equal(contour[0], contour[-1]):
76 |         contour = np.vstack((contour, contour[0]))
77 |     return contour
78 | 
79 | def binary_mask_to_polygon(root, tolerance=0):
80 | 
81 |     mask = cv2.imread(root,0)
82 |     mask = np.array(mask)
83 |     mask = mask/128
84 |     binary_mask = mask.astype(int)
85 |     polygons = []
86 |     # pad mask to close contours of shapes which start and end at an edge
87 |     padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
88 |     contours = measure.find_contours(padded_binary_mask, 0.5)
89 |     # added by Tang Zixuan: only a single connected region is handled by default, so keep the largest connected region here
90 |     max=0
91 |     if len(contours)!=1:
92 |         for i,counter1 in enumerate(contours):
93 |             if counter1.shape[0]>max:
94 |                 max=counter1.shape[0]
95 |                 contours=[counter1]
96 | 
97 |     contours = np.subtract(contours, 1)
98 |     for contour in contours:
99 |         contour = close_contour(contour)
100 |         contour = measure.approximate_polygon(contour, tolerance)
101 |         if len(contour) < 3:
102 |             continue
103 |         contour = np.flip(contour, axis=1)
104 |         segmentation = contour.ravel().tolist()
105 |         # after padding and subtracting 1 we may get -0.5 points in our segmentation
106 |         segmentation = [0 if i < 0 else i for i in segmentation]
107 |         polygons.append(segmentation)
108 |     polys=[]
109 |     for i in range(int(len(polygons[0])/2)):
110 |         polys.append([polygons[0][2*i],polygons[0][2*i+1]])
111 | 
112 |     return np.array(polys)
113 | 
114 | 
115 | 
116 | 
117 | 
118 | 
119 | 
120 | 
121 | 
122 | if __name__ == '__main__':
123 |     np.set_printoptions(threshold=np.inf)
124 |     root = '/home/amax/Titan_Five/TZX/deep_sanke/images/4426_mask.jpg'  # change this to your own file path
125 |     poly = binary_mask_to_polygon(root)
126 |     #poly=poly[0:256:2,:]
127 |     #poly=uniformsample(poly,128)
128 |     print(poly.shape)
129 | 
130 |     inp=cv2.imread("/home/amax/Titan_Five/TZX/snake_envo_num/visual_result/dark.png")
131 |     fig, ax = plt.subplots(1, figsize=(20, 10))
132 |     fig.tight_layout()
133 |     ax.axis('off')
134 |     ax.imshow(inp)
135 |     ax.plot(poly[:, 0], poly[:, 1], color='white', linewidth=5)
136 |     plt.savefig('./visual_result/demo_result.png', bbox_inches='tight', pad_inches=0)
137 | 
138 | 
139 |     #a=np.array(instance_polys[0]).astype(int)
140 |     #print(a.shape)
141 |     #a=a*4
142 |     #for i in range(a.shape[1]):
143 |     #    point=(a[0,i,0],a[0,i,1])
144 |     #    print(point)
145 |     #    cv2.circle(img1, point, 1, (255, 0, 0), 1)
146 |     #cv2.imwrite('./visual_result/4_inp_with_instance_poly2.jpg', img1)
147 |     #print("instance_polys",)
148 |     #sys.exit(0)
149 | 
150 |     #Edge_Extract(root) #shape(278,2)
--------------------------------------------------------------------------------
/lib/utils/optimizer/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from bisect import bisect_right
3 | 
4 | import torch
5 | 
6 | 
7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
8 |     def __init__(
9 |         self,
10 |         optimizer,
11 |         milestones,
12 |         gamma=0.1,
13 |         warmup_factor=1.0 / 3,
14 |         warmup_iters=5,
15 |         warmup_method="linear",
16 |         last_epoch=-1,
17 |     ):
18 |         if not list(milestones) == sorted(milestones):
19 |             raise ValueError(
20 |                 "Milestones should be a list of" " increasing integers. 
Got {}", 21 | milestones, 22 | ) 23 | 24 | if warmup_method not in ("constant", "linear"): 25 | raise ValueError( 26 | "Only 'constant' or 'linear' warmup_method accepted" 27 | "got {}".format(warmup_method) 28 | ) 29 | self.milestones = milestones 30 | self.gamma = gamma 31 | self.warmup_factor = warmup_factor 32 | self.warmup_iters = warmup_iters 33 | self.warmup_method = warmup_method 34 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | warmup_factor = 1 38 | if self.last_epoch < self.warmup_iters: 39 | if self.warmup_method == "constant": 40 | warmup_factor = self.warmup_factor 41 | elif self.warmup_method == "linear": 42 | alpha = float(self.last_epoch) / self.warmup_iters 43 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 44 | return [ 45 | base_lr 46 | * warmup_factor 47 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class ManualStepLR(torch.optim.lr_scheduler._LRScheduler): 53 | def __init__( 54 | self, 55 | optimizer, 56 | milestones, 57 | gammas, 58 | last_epoch=-1 59 | ): 60 | if not list(milestones) == sorted(milestones): 61 | raise ValueError( 62 | "Milestones should be a list of" " increasing integers. Got {}", 63 | milestones, 64 | ) 65 | 66 | self.milestones = milestones 67 | self.gammas = gammas 68 | super(ManualStepLR, self).__init__(optimizer, last_epoch) 69 | 70 | def get_lr(self): 71 | if self.last_epoch not in self.milestones: 72 | return [group['lr'] for group in self.optimizer.param_groups] 73 | index = self.milestones.index(self.last_epoch) 74 | gamma = self.gammas[index] 75 | return [group['lr'] * gamma 76 | for group in self.optimizer.param_groups] 77 | -------------------------------------------------------------------------------- /lib/utils/rcnn_snake/rcnn_snake_config.py: -------------------------------------------------------------------------------- 1 | from lib.utils.snake.snake_config import * 2 | from lib.config import cfg 3 | 4 | 5 | cp_h, cp_w = 14, 56 6 | roi_h, roi_w = 7, 28 7 | 8 | nms_ct = True 9 | max_ct_overlap = 0.7 10 | ct_score = cfg.ct_score 11 | 12 | cp_hm_nms = False 13 | max_cp_det = 50 14 | max_cp_overlap = 0.1 15 | cp_score = 0.25 16 | 17 | segm_or_bbox = 'segm' 18 | 19 | -------------------------------------------------------------------------------- /lib/utils/rcnn_snake/rcnn_snake_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.utils.snake.snake_decode import nms, topk, transpose_and_gather_feat 3 | from lib.utils.rcnn_snake import rcnn_snake_config 4 | from lib.csrc.extreme_utils import _ext 5 | 6 | 7 | def box_to_roi(box, box_01): 8 | """ box: [b, n, 4] """ 9 | box = box[box_01] 10 | ind = torch.cat([torch.full([box_01[i].sum()], i) for i in range(len(box_01))], dim=0) 11 | ind = ind.to(box.device).float() 12 | roi = torch.cat([ind[:, None], box], dim=1) 13 | return roi 14 | 15 | 16 | def decode_cp_detection(cp_hm, cp_wh, abox, adet): 17 | batch, cat, height, width = cp_hm.size() 18 | if rcnn_snake_config.cp_hm_nms: 19 | cp_hm = nms(cp_hm) 20 | 21 | abox_w, abox_h = abox[..., 2] - abox[..., 0], abox[..., 3] - abox[..., 1] 22 | 23 | scores, inds, clses, ys, xs = topk(cp_hm, rcnn_snake_config.max_cp_det) 24 | cp_wh = transpose_and_gather_feat(cp_wh, inds) 25 | cp_wh = cp_wh.view(batch, rcnn_snake_config.max_cp_det, 2) 26 | 27 | cp_hm_h, cp_hm_w = cp_hm.size(2), cp_hm.size(3) 28 | 29 | xs = xs / cp_hm_w * abox_w[..., None] + abox[:, 
--------------------------------------------------------------------------------
/lib/utils/snake/active_spline.py:
--------------------------------------------------------------------------------
import torch

EPS = 1e-7


def sample_point(cps, p_num, alpha=0.5):
    cp_num = cps.size(1)
    p_num = int(p_num / cp_num)

    # Suppose cps is [n_batch, n_cp, 2]
    cps = torch.cat([cps, cps[:, 0, :].unsqueeze(1)], dim=1)
    auxillary_cps = torch.zeros(cps.size(0), cps.size(1) + 2, cps.size(2)).to(cps.device)
    auxillary_cps[:, 1:-1, :] = cps

    l_01 = torch.sqrt(torch.sum(torch.pow(cps[:, 0, :] - cps[:, 1, :], 2), dim=1) + EPS)
    l_last_01 = torch.sqrt(torch.sum(torch.pow(cps[:, -1, :] - cps[:, -2, :], 2), dim=1) + EPS)

    l_01.detach_().unsqueeze_(1)
    l_last_01.detach_().unsqueeze_(1)

    # print(l_last_01, l_01)

    auxillary_cps[:, 0, :] = cps[:, 0, :] - l_01 / l_last_01 * (cps[:, -1, :] - cps[:, -2, :])
    auxillary_cps[:, -1, :] = cps[:, -1, :] + l_last_01 / l_01 * (cps[:, 1, :] - cps[:, 0, :])

    # print(auxillary_cps)

    t = torch.zeros([auxillary_cps.size(0), auxillary_cps.size(1)]).to(cps.device)
    t[:, 1:] = torch.pow(torch.sum(torch.pow(auxillary_cps[:, 1:, :] - auxillary_cps[:, :-1, :], 2), dim=2), alpha / 2)
    t = torch.cumsum(t, dim=1)

    # No need to calculate gradient w.r.t t.
    t = t.detach()
    points = torch.zeros([cps.size(0), p_num * cp_num, cps.size(2)]).to(cps.device)
    temp_step = torch.arange(p_num).float().to(cps.device)
    temp_step_len = (t[:, 2:-1] - t[:, 1:-2]) / (p_num - 1)
    v = torch.matmul(temp_step_len.unsqueeze(2), temp_step.unsqueeze(0).repeat([cps.size(0), 1, 1])).reshape([cps.size(0), -1])
    v = torch.matmul(t[:, 1:-2].unsqueeze(2), torch.ones(cps.size(0), 1, p_num).to(cps.device)).reshape([cps.size(0), -1]) + v
    # vuse = v.clone()
    t0 = t[:, 0:-3].unsqueeze(2).repeat([1, 1, p_num]).reshape([cps.size(0), -1])
    t1 = t[:, 1:-2].unsqueeze(2).repeat([1, 1, p_num]).reshape([cps.size(0), -1])
    t2 = t[:, 2:-1].unsqueeze(2).repeat([1, 1, p_num]).reshape([cps.size(0), -1])
    t3 = t[:, 3:].unsqueeze(2).repeat([1, 1, p_num]).reshape([cps.size(0), -1])

    auxillary_cps0 = auxillary_cps[:, 0:-3, :].unsqueeze(2).repeat([1, 1, p_num, 1]).reshape([cps.size(0), -1, 2])
    auxillary_cps1 = auxillary_cps[:, 1:-2, :].unsqueeze(2).repeat([1, 1, p_num, 1]).reshape([cps.size(0), -1, 2])
    auxillary_cps2 = auxillary_cps[:, 2:-1, :].unsqueeze(2).repeat([1, 1, p_num, 1]).reshape([cps.size(0), -1, 2])
    auxillary_cps3 = auxillary_cps[:, 3:, :].unsqueeze(2).repeat([1, 1, p_num, 1]).reshape([cps.size(0), -1, 2])

    mx01 = ((t1 - v) / (t1 - t0)).unsqueeze(2).repeat([1, 1, 2]) * auxillary_cps0 + \
           ((v - t0) / (t1 - t0)).unsqueeze(2).repeat([1, 1, 2]) * auxillary_cps1

    mx12 = ((t2 - v) / (t2 - t1)).unsqueeze(2).repeat([1, 1, 2]) * auxillary_cps1 + \
           ((v - t1) / (t2 - t1)).unsqueeze(2).repeat([1, 1, 2]) * auxillary_cps2

    mx23 = ((t3 - v) / (t3 - t2)).unsqueeze(2).repeat([1, 1, 2]) * auxillary_cps2 + \
           ((v - t2) / (t3 - t2)).unsqueeze(2).repeat([1, 1, 2]) * auxillary_cps3

    mx012 = ((t2 - v) / (t2 - t0)).unsqueeze(2).repeat([1, 1, 2]) * mx01 \
            + ((v - t0) / (t2 - t0)).unsqueeze(2).repeat([1, 1, 2]) * mx12

    mx123 = ((t3 - v) / (t3 - t1)).unsqueeze(2).repeat([1, 1, 2]) * mx12 \
            + ((v - t1) / (t3 - t1)).unsqueeze(2).repeat([1, 1, 2]) * mx23

    points[:, :] = ((t2 - v) / (t2 - t1)).unsqueeze(2).repeat([1, 1, 2]) * mx012 \
                   + ((v - t1) / (t2 - t1)).unsqueeze(2).repeat([1, 1, 2]) * mx123

    return points
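A quick sketch (illustration only, not in the repo) of sample_point densifying a square of 4 control points into a closed Catmull-Rom-style curve; alpha = 0.5 is the centripetal parameterization.

import torch
from lib.utils.snake.active_spline import sample_point

cps = torch.tensor([[[0., 0.], [1., 0.], [1., 1.], [0., 1.]]])  # [1, 4, 2] control points
pts = sample_point(cps, 40, alpha=0.5)
print(pts.shape)  # torch.Size([1, 40, 2]): a closed curve through the square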
--------------------------------------------------------------------------------
/lib/utils/snake/snake_cityscapes_coco_utils.py:
--------------------------------------------------------------------------------
from lib.utils.snake.snake_cityscapes_utils import *

crop_scale = np.array([896, 384])
input_scale = np.array([896, 384])
scale_range = np.arange(0.4, 1.0, 0.1)


def augment(img, split, _data_rng, _eig_val, _eig_vec, mean, std, polys):
    # resize input
    height, width = img.shape[0], img.shape[1]
    center = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    scale = crop_scale
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    # random crop and flip augmentation
    flipped = False
    if split == 'train':
        scale = scale * np.random.choice(scale_range)
        seed = np.random.randint(0, len(polys))
        index = np.random.randint(0, len(polys[seed][0]))
        x, y = polys[seed][0][index]
        center[0] = x
        border = scale[0] // 2 if scale[0] < width else width - scale[0] // 2
        center[0] = np.clip(center[0], a_min=border, a_max=width - border)
        center[1] = y
        border = scale[1] // 2 if scale[1] < height else height - scale[1] // 2
        center[1] = np.clip(center[1], a_min=border, a_max=height - border)

        # flip augmentation
        if np.random.random() < 0.5:
            flipped = True
            img = img[:, ::-1, :]
            center[0] = width - center[0] - 1

    input_w, input_h = input_scale
    if split != 'train':
        center = np.array([width // 2, height // 2])
        scale = np.array([width, height])
        # input_w, input_h = int((width / 1 + 31) // 32 * 32), int((height / 1 + 31) // 32 * 32)
        input_w, input_h = int((width / 0.85 + 31) // 32 * 32), int((height / 0.85 + 31) // 32 * 32)

    trans_input = data_utils.get_affine_transform(center, scale, 0, [input_w, input_h])
    inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)

    # color augmentation
    orig_img = inp.copy()
    inp = (inp.astype(np.float32) / 255.)
    if split == 'train':
        data_utils.color_aug(_data_rng, inp, _eig_val, _eig_vec)

    # normalize the image
    inp = (inp - mean) / std
    inp = inp.transpose(2, 0, 1)

    output_h, output_w = input_h // snake_config.down_ratio, input_w // snake_config.down_ratio
    trans_output = data_utils.get_affine_transform(center, scale, 0, [output_w, output_h])
    inp_out_hw = (input_h, input_w, output_h, output_w)

    return orig_img, inp, trans_input, trans_output, flipped, center, scale, inp_out_hw
--------------------------------------------------------------------------------
/lib/utils/snake/snake_coco_utils.py:
--------------------------------------------------------------------------------
from lib.utils.snake.snake_cityscapes_utils import *

input_scale = np.array([512, 512])


def augment(img, split, _data_rng, _eig_val, _eig_vec, mean, std, polys):
    # resize input
    height, width = img.shape[0], img.shape[1]
    center = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    scale = max(height, width)
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    # random crop and flip augmentation
    flipped = False
    if split == 'train':
        scale = scale * np.random.uniform(0.6, 1.4)
        seed = np.random.randint(0, len(polys))
        index = np.random.randint(0, len(polys[seed][0]))
        x, y = polys[seed][0][index]
        center[0] = x
        border = scale[0] // 2 if scale[0] < width else width - scale[0] // 2
        center[0] = np.clip(center[0], a_min=border, a_max=width - border)
        center[1] = y
        border = scale[1] // 2 if scale[1] < height else height - scale[1] // 2
        center[1] = np.clip(center[1], a_min=border, a_max=height - border)

        # flip augmentation
        if np.random.random() < 0.5:
            flipped = True
            img = img[:, ::-1, :]
            center[0] = width - center[0] - 1

    input_w, input_h = input_scale
    if split != 'train':
        center = np.array([width // 2, height // 2])
        scale = np.array([width, height])
        x = 32
        input_w = (int(width / 1.) | (x - 1)) + 1
        input_h = (int(height / 1.) | (x - 1)) + 1
        scale = np.array([input_w, input_h])

    trans_input = data_utils.get_affine_transform(center, scale, 0, [input_w, input_h])
    inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)

    # color augmentation
    orig_img = inp.copy()
    inp = (inp.astype(np.float32) / 255.)
    if split == 'train':
        data_utils.color_aug(_data_rng, inp, _eig_val, _eig_vec)
        # data_utils.blur_aug(inp)

    # normalize the image
    inp = (inp - mean) / std
    inp = inp.transpose(2, 0, 1)

    output_h, output_w = input_h // snake_config.down_ratio, input_w // snake_config.down_ratio
    trans_output = data_utils.get_affine_transform(center, scale, 0, [output_w, output_h])
    inp_out_hw = (input_h, input_w, output_h, output_w)

    return orig_img, inp, trans_input, trans_output, flipped, center, scale, inp_out_hw
--------------------------------------------------------------------------------
/lib/utils/snake/snake_config.py:
--------------------------------------------------------------------------------
import numpy as np
from lib.config import cfg


mean = np.array([0.40789654, 0.44719302, 0.47026115],
                dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.28863828, 0.27408164, 0.27809835],
               dtype=np.float32).reshape(1, 1, 3)
data_rng = np.random.RandomState(123)
eig_val = np.array([0.2141788, 0.01817699, 0.00341571],
                   dtype=np.float32)
eig_vec = np.array([
    [-0.58752847, -0.69563484, 0.41340352],
    [-0.5832747, 0.00994535, -0.81221408],
    [-0.56089297, 0.71832671, 0.41158938]
], dtype=np.float32)

down_ratio = 4
scale = np.array([800, 800])
input_w, input_h = (800, 800)
scale_range = np.arange(0.6, 1.4, 0.1)

voc_input_h, voc_input_w = (512, 512)
voc_scale_range = np.arange(0.6, 1.4, 0.1)

box_center = False
center_scope = False

init = 'quadrangle'
init_poly_num = 40
poly_num = 128
gt_poly_num = 128
spline_num = 10

adj_num = 4

train_pred_box = False
box_iou = 0.7
confidence = 0.1
train_pred_box_only = True

train_pred_ex = False
train_nearest_gt = True

ct_score = cfg.ct_score

ro = 4

segm_or_bbox = 'segm'
--------------------------------------------------------------------------------
/lib/utils/snake/snake_eval_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import pycocotools.mask as mask_utils
from PIL import Image
import os


def poly_to_mask(poly, label, h, w):
    mask = []
    for i in range(len(poly)):
        mask_ = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask_, [np.round(poly[i]).astype(int)], int(label[i]))
        mask.append(mask_)
    return mask


def coco_poly_to_mask(poly, h, w):
    mask = []
    for i in range(len(poly)):
        rles = mask_utils.frPyObjects([poly[i].reshape(-1)], h, w)
        rle = mask_utils.merge(rles)
        mask_ = mask_utils.decode(rle)
        mask.append(mask_)
    return mask


def rcnn_poly_to_mask(poly, ind_group, label, h, w):
    mask = []
    for i in range(len(ind_group)):
        poly_ = [np.round(poly[ind]).astype(int) for ind in ind_group[i]]
        mask_ = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask_, poly_, int(label[i]))
        mask.append(mask_)
    return mask


def rcnn_coco_poly_to_mask(poly, ind_group, h, w):
    mask = []
    for i in range(len(ind_group)):
        poly_ = [poly[ind].reshape(-1) for ind in ind_group[i]]
        rles = mask_utils.frPyObjects(poly_, h, w)
        rle = mask_utils.merge(rles)
        mask_ = mask_utils.decode(rle)
        mask.append(mask_)
    return mask


def coco_poly_to_rle(poly, h, w):
    rle_ = []
    for i in range(len(poly)):
        rles = mask_utils.frPyObjects([poly[i].reshape(-1)], h, w)
        rle = mask_utils.merge(rles)
        rle['counts'] = rle['counts'].decode('utf-8')
        rle_.append(rle)
    return rle_


def rcnn_coco_poly_to_rle(poly, ind_group, h, w):
    rle_ = []
    for i in range(len(ind_group)):
        poly_ = [poly[ind].reshape(-1) for ind in ind_group[i]]
        rles = mask_utils.frPyObjects(poly_, h, w)
        rle = mask_utils.merge(rles)
        rle['counts'] = rle['counts'].decode('utf-8')
        rle_.append(rle)
    return rle_
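A small sketch (not part of the repo) of the COCO round trip used above: a polygon is encoded to RLE with frPyObjects, merged, and decoded back to a binary mask.

import numpy as np
from lib.utils.snake.snake_eval_utils import coco_poly_to_mask

triangle = np.array([[10, 10], [50, 10], [30, 40]], dtype=np.float64)
mask = coco_poly_to_mask([triangle], 64, 64)[0]
print(mask.shape, mask.sum())  # (64, 64) and roughly the triangle's area (~600 px)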
--------------------------------------------------------------------------------
/lib/utils/snake/snake_kins_utils.py:
--------------------------------------------------------------------------------
from lib.utils.snake.snake_cityscapes_utils import *

crop_scale = np.array([896, 384])
input_scale = np.array([896, 384])
scale_range = np.arange(0.4, 1.0, 0.1)


def augment(img, split, _data_rng, _eig_val, _eig_vec, mean, std, polys):
    # resize input
    height, width = img.shape[0], img.shape[1]
    center = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    scale = crop_scale
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    # random crop and flip augmentation
    flipped = False
    if split == 'train':
        scale = scale * np.random.uniform(0.25, 0.8)
        seed = np.random.randint(0, len(polys))
        index = np.random.randint(0, len(polys[seed][0]))
        x, y = polys[seed][0][index]
        center[0] = x
        border = scale[0] // 2 if scale[0] < width else width - scale[0] // 2
        center[0] = np.clip(center[0], a_min=border, a_max=width - border)
        center[1] = y
        border = scale[1] // 2 if scale[1] < height else height - scale[1] // 2
        center[1] = np.clip(center[1], a_min=border, a_max=height - border)

        # flip augmentation
        if np.random.random() < 0.5:
            flipped = True
            img = img[:, ::-1, :]
            center[0] = width - center[0] - 1

    input_w, input_h = input_scale
    if split != 'train':
        center = np.array([width // 2, height // 2])
        scale = np.array([width, height])
        x = 32
        # input_w, input_h = (width + x - 1) // x * x, (height + x - 1) // x * x
        input_w, input_h = int((width / 0.5 + x - 1) // x * x), int((height / 0.5 + x - 1) // x * x)

    trans_input = data_utils.get_affine_transform(center, scale, 0, [input_w, input_h])
    inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)

    # color augmentation
    orig_img = inp.copy()
    inp = (inp.astype(np.float32) / 255.)
    if split == 'train':
        data_utils.color_aug(_data_rng, inp, _eig_val, _eig_vec)
        # data_utils.blur_aug(inp)

    # normalize the image
    inp = (inp - mean) / std
    inp = inp.transpose(2, 0, 1)

    output_h, output_w = input_h // snake_config.down_ratio, input_w // snake_config.down_ratio
    trans_output = data_utils.get_affine_transform(center, scale, 0, [output_w, output_h])
    inp_out_hw = (input_h, input_w, output_h, output_w)

    return orig_img, inp, trans_input, trans_output, flipped, center, scale, inp_out_hw
--------------------------------------------------------------------------------
/lib/utils/snake/snake_poly_utils.py:
--------------------------------------------------------------------------------
from shapely.geometry import Polygon, MultiPolygon
# cascaded_union was removed in Shapely 2.0; unary_union is the drop-in replacement
from shapely.ops import unary_union, polygonize
import numpy as np
from lib.utils.snake import snake_config


def get_shape_poly(poly):
    shape_poly = Polygon(poly)
    if shape_poly.is_valid:
        return shape_poly

    # self-intersected situation
    line_ring = shape_poly.exterior

    # disassemble line strings at the self-intersections
    mls = line_ring.intersection(line_ring)
    # assemble polygons from multiple line strings
    polygons = polygonize(mls)
    multi_shape_poly = MultiPolygon(polygons)

    return multi_shape_poly


def poly_iou(poly1, poly2):
    poly = unary_union([poly1, poly2])
    union = poly.area
    intersection = poly1.area + poly2.area - union
    return intersection / union


def get_poly_iou_matrix(poly1, poly2):
    poly1 = [get_shape_poly(poly) for poly in poly1]
    poly2 = [get_shape_poly(poly) for poly in poly2]

    iou_matrix = np.zeros([len(poly1), len(poly2)])
    for i in range(len(poly1)):
        for j in range(len(poly2)):
            iou_matrix[i, j] = poly_iou(poly1[i], poly2[j])

    return iou_matrix


def get_poly_match_ind(poly1, poly2):
    iou_matrix = get_poly_iou_matrix(poly1, poly2)
    iou = iou_matrix.max(axis=1)
    gt_ind = iou_matrix.argmax(axis=1)
    poly_ind = np.argwhere(iou > snake_config.poly_iou).ravel()
    gt_ind = gt_ind[poly_ind]
    return poly_ind, gt_ind


def poly_nms(poly):
    iou_matrix = get_poly_iou_matrix(poly, poly)
    iou_matrix[np.arange(len(poly)), np.arange(len(poly))] = 0

    overlapped = np.zeros([len(poly)])
    poly_ = []
    ind = []
    for i in range(len(poly)):
        if overlapped[i]:
            continue
        poly_.append(poly[i])
        ind.append(i)
        overlapped[iou_matrix[i] > snake_config.poly_iou] = 1

    return poly_, ind
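A hedged sketch (not in the repo) of poly_nms on two heavily overlapping squares plus one disjoint square. It assumes snake_config.poly_iou is configured below ~0.68, which is the IoU of the first two squares here; that attribute is set elsewhere in the config, not in the snake_config.py shown above.

import numpy as np
from lib.utils.snake.snake_poly_utils import poly_nms

a = np.array([[0, 0], [10, 0], [10, 10], [0, 10]], dtype=np.float32)
b = a + 1.0    # IoU with a is 81/119, about 0.68
c = a + 100.0  # disjoint from both
kept, ind = poly_nms([a, b, c])
print(ind)     # [0, 2] whenever snake_config.poly_iou < 0.68: b is suppressed by a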
--------------------------------------------------------------------------------
/lib/utils/snake/visualize_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from lib.utils.img_utils import colors
from lib.utils import img_utils
from lib.utils.snake import snake_cityscapes_utils, snake_config


R = 8
GREEN = (18, 127, 15)
WHITE = (255, 255, 255)


def visualize_snake_detection(img, data):

    def blend_hm_img(hm, img):
        hm = np.max(hm, axis=0)
        h, w = hm.shape[:2]
        img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
        hm = np.array([255, 255, 255]) - (hm.reshape(h, w, 1) * colors[0]).astype(np.uint8)
        ratio = 0.5
        blend = (img * ratio + hm * (1 - ratio)).astype(np.uint8)
        return blend

    img = img_utils.bgr_to_rgb(img)
    blend = blend_hm_img(data['ct_hm'], img)

    plt.imshow(blend)
    ct_ind = np.array(data['ct_ind'])
    w = img.shape[1] // snake_config.down_ratio
    xs = ct_ind % w
    ys = ct_ind // w
    for i in range(len(data['wh'])):
        w, h = data['wh'][i]
        x_min, y_min = xs[i] - w / 2, ys[i] - h / 2
        x_max, y_max = xs[i] + w / 2, ys[i] + h / 2
        plt.plot([x_min, x_min, x_max, x_max, x_min], [y_min, y_max, y_max, y_min, y_min])
    plt.show()


def visualize_cp_detection(img, data):
    act_ind = data['act_ind']
    awh = data['awh']

    act_hm_w = data['act_hm'].shape[2]
    cp_h, cp_w = data['cp_hm'][0].shape[1], data['cp_hm'][0].shape[2]

    img = img_utils.bgr_to_rgb(img)
    plt.imshow(img)

    for i in range(len(act_ind)):
        act_ind_ = act_ind[i]
        ct = act_ind_ % act_hm_w, act_ind_ // act_hm_w
        w, h = awh[i]
        abox = np.array([ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2])

        cp_ind_ = data['cp_ind'][i]
        cp_wh_ = data['cp_wh'][i]

        for j in range(len(cp_ind_)):
            ct = cp_ind_[j] % cp_w, cp_ind_[j] // cp_w
            x = ct[0] / cp_w * w
            y = ct[1] / cp_h * h
            x_min = (x - cp_wh_[j][0] / 2 + abox[0]) * snake_config.down_ratio
            y_min = (y - cp_wh_[j][1] / 2 + abox[1]) * snake_config.down_ratio
            x_max = (x + cp_wh_[j][0] / 2 + abox[0]) * snake_config.down_ratio
            y_max = (y + cp_wh_[j][1] / 2 + abox[1]) * snake_config.down_ratio
            plt.plot([x_min, x_min, x_max, x_max, x_min], [y_min, y_max, y_max, y_min, y_min])

    plt.show()


def visualize_snake_evolution(img, data):
    img = img_utils.bgr_to_rgb(img)
    plt.imshow(img)
    for poly in data['i_gt_py']:
        poly = poly * 4
        poly = np.append(poly, [poly[0]], axis=0)
        plt.plot(poly[:, 0], poly[:, 1])
        plt.scatter(poly[0, 0], poly[0, 1], edgecolors='w')
    plt.show()


def visualize_snake_octagon(img, extreme_points):
    img = img_utils.bgr_to_rgb(img)
    octagons = []
    bboxes = []
    ex_points = []
    for i in range(len(extreme_points)):
        for j in range(len(extreme_points[i])):
            bbox = get_bbox(extreme_points[i][j] * 4)
            octagon = snake_cityscapes_utils.get_octagon(extreme_points[i][j] * 4)
            bboxes.append(bbox)
            octagons.append(octagon)
            ex_points.append(extreme_points[i][j])
    _, ax = plt.subplots(1)
    ax.imshow(img)
    n = len(octagons)
    for i in range(n):
        x, y, x_max, y_max = bboxes[i]
        ax.add_patch(patches.Polygon(xy=[[x, y], [x, y_max], [x_max, y_max], [x_max, y]], fill=False,
                                     linewidth=1, edgecolor='r'))
        octagon = np.append(octagons[i], octagons[i][0]).reshape(-1, 2)
        ax.plot(octagon[:, 0], octagon[:, 1])
        ax.scatter(ex_points[i][:, 0] * 4, ex_points[i][:, 1] * 4, edgecolors='w')
    plt.show()


def get_bbox(ex):
    x = ex[:, 0]
    y = ex[:, 1]
    bbox = [np.min(x), np.min(y), np.max(x), np.max(y)]
    return bbox
--------------------------------------------------------------------------------
/poly2mask.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np

x = [70, 222, 280, 330, 467, 358, 392, 280, 138, 195]

y = [190, 190, 61, 190, 190, 260, 380, 308, 380, 260]

# cv2 drawing ops expect int32 point arrays (the default int64 can be rejected)
cor_xy = np.vstack((x, y)).T.astype(np.int32)
print(cor_xy.shape)

img = np.zeros((512, 512))
print((img == 1).sum())

img = cv2.polylines(img, [cor_xy], True, 1, 1)
print((img == 1).sum())

img = cv2.fillPoly(img, [cor_xy], 1)
print((img == 1).sum())

cv2.imwrite('/home/amax/Titan_Five/TZX/snake-master/visual_result/poly2mask.jpg', img * 255)
--------------------------------------------------------------------------------
/process/metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import f1_score, average_precision_score
from skimage.morphology import binary_dilation, disk

__all__ = ['get_f1_scores', 'get_ap_scores', 'get_iou', 'WCov_metric', 'FBound_metric']

SMOOTH = 1e-6


def get_f1_scores(predict, target, ignore_index=-1):
    f1 = []
    for pred, tgt in zip(predict, target):
        # Tensor process
        pred = pred.data.cpu().numpy().reshape(-1)
        tgt = tgt.data.cpu().numpy().reshape(-1)
        p = pred[tgt != ignore_index]
        t = tgt[tgt != ignore_index]
        f1.append(f1_score(t, p))

    return f1


def get_ap_scores(predict, target, ignore_index=-1):
    ap = []
    for pred, tgt in zip(predict, target):
        # Tensor process
        pred = pred.data.cpu().numpy().reshape(-1)
        tgt = tgt.data.cpu().numpy().reshape(-1)
        p = pred[tgt != ignore_index]
        t = tgt[tgt != ignore_index]

        ap.append(average_precision_score(t, p))

    return ap


def get_iou(outputs, labels):
    outputs = outputs.squeeze(1)

    intersection = (outputs & labels).float().sum((1, 2))
    union = (outputs | labels).float().sum((1, 2))

    iou = (intersection + SMOOTH) / (union + SMOOTH)

    return iou.cpu().tolist()


def WCov_metric(X, Y):
    A1 = float(np.count_nonzero(X))
    A2 = float(np.count_nonzero(Y))
    if A1 >= A2:
        return A2 / A1
    return A1 / A2


def FBound_metric(X, Y):
    # print(X.max(), type(X), Y.max())  # debug output
    tmp1 = db_eval_boundary(X, Y, 1)[0]
    tmp2 = db_eval_boundary(X, Y, 2)[0]
    tmp3 = db_eval_boundary(X, Y, 3)[0]
    tmp4 = db_eval_boundary(X, Y, 4)[0]
    tmp5 = db_eval_boundary(X, Y, 5)[0]
    return (tmp1 + tmp2 + tmp3 + tmp4 + tmp5) / 5.0


def db_eval_boundary(foreground_mask, gt_mask, bound_th):
    """
    Compute mean, recall and decay from per-frame evaluation.
    Calculates precision/recall for boundaries between foreground_mask and
    gt_mask using morphological operators to speed it up.
    Arguments:
        foreground_mask (ndarray): binary segmentation image.
        gt_mask (ndarray): binary annotated image.
    Returns:
        F (float): boundaries F-measure
        P (float): boundaries precision
        R (float): boundaries recall
        (plus the raw match counts: fg_match, n_fg, gt_match, n_gt)
    """

    assert np.atleast_3d(foreground_mask).shape[2] == 1

    bound_pix = bound_th if bound_th >= 1 else \
        np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))

    # Get the pixel boundaries of both masks
    fg_boundary = seg2bmap(foreground_mask)
    gt_boundary = seg2bmap(gt_mask)

    fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
    gt_dil = binary_dilation(gt_boundary, disk(bound_pix))

    # Get the intersection
    gt_match = gt_boundary * fg_dil
    fg_match = fg_boundary * gt_dil

    # Area of the intersection
    n_fg = np.sum(fg_boundary)
    n_gt = np.sum(gt_boundary)

    # Compute precision and recall
    if n_fg == 0 and n_gt > 0:
        precision = 1
        recall = 0
    elif n_fg > 0 and n_gt == 0:
        precision = 0
        recall = 1
    elif n_fg == 0 and n_gt == 0:
        precision = 1
        recall = 1
    else:
        precision = np.sum(fg_match) / float(n_fg)
        recall = np.sum(gt_match) / float(n_gt)

    # Compute F measure
    if precision + recall == 0:
        F = 0
    else:
        F = 2 * precision * recall / (precision + recall)

    return F, precision, recall, np.sum(fg_match), n_fg, np.sum(gt_match), n_gt


def seg2bmap(seg, width=None, height=None):
    """
    From a segmentation, compute a binary boundary map with 1 pixel wide
    boundaries. The boundary pixels are offset by 1/2 pixel towards the
    origin from the actual segment boundary.
    Arguments:
        seg : Segments labeled from 1..k.
        width : Width of desired bmap <= seg.shape[1]
        height : Height of desired bmap <= seg.shape[0]
    Returns:
        bmap (ndarray): Binary boundary map.
    David Martin
    January 2003
    """
    seg = seg.astype(bool)  # np.bool is deprecated in recent numpy
    seg[seg > 0] = 1

    assert np.atleast_3d(seg).shape[2] == 1

    width = seg.shape[1] if width is None else width
    height = seg.shape[0] if height is None else height

    h, w = seg.shape[:2]

    ar1 = float(width) / float(height)
    ar2 = float(w) / float(h)

    # the original bitwise-| version had a precedence bug ('|' binds tighter than '>')
    assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
        "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)

    e = np.zeros_like(seg)
    s = np.zeros_like(seg)
    se = np.zeros_like(seg)

    e[:, :-1] = seg[:, 1:]
    s[:-1, :] = seg[1:, :]
    se[:-1, :-1] = seg[1:, 1:]

    b = seg ^ e | seg ^ s | seg ^ se
    b[-1, :] = seg[-1, :] ^ e[-1, :]
    b[:, -1] = seg[:, -1] ^ s[:, -1]
    b[-1, -1] = 0

    if w == width and h == height:
        bmap = b
    else:
        bmap = np.zeros((height, width))
        for x in range(w):
            for y in range(h):
                if b[y, x]:
                    # NOTE: 'width / h' is inherited from the DAVIS reference port and
                    # looks like it was meant to be 'width / w'; this branch only runs
                    # when the requested size differs from the mask size.
                    # int casts are required: float indices raise IndexError in numpy.
                    j = 1 + int(np.floor((y - 1) + height / h))
                    i = 1 + int(np.floor((x - 1) + width / h))
                    bmap[j, i] = 1

    return bmap
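A toy check of db_eval_boundary (illustration only, not in the repo): identical masks score F = 1.0, and shifting one mask by 2 px against a 1 px tolerance band scores lower.

import numpy as np
from process.metrics import db_eval_boundary

gt = np.zeros((64, 64), dtype=np.uint8)
gt[16:48, 16:48] = 1
pred = np.roll(gt, 2, axis=1)            # shift the square 2 px to the right

print(db_eval_boundary(gt, gt, 1)[0])    # 1.0
print(db_eval_boundary(pred, gt, 1)[0])  # < 1.0: vertical edges fall outside the 1 px band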
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
yacs==0.1.4
numpy==1.16.4
tqdm==4.28.1
opencv-contrib-python==3.4.2.17
opencv-python==3.4.2.17
imgaug==0.2.9
pycocotools==2.0.0
Pillow
tensorboardX==1.2
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
from lib.config import cfg, args
import numpy as np
import os


def run_dataset():
    from lib.datasets import make_data_loader
    import tqdm

    cfg.train.num_workers = 0
    data_loader = make_data_loader(cfg, is_train=False)
    for batch in tqdm.tqdm(data_loader):
        pass


def run_network():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    import time

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    total_time = 0
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            torch.cuda.synchronize()
            start = time.time()
            network(batch['inp'])
            torch.cuda.synchronize()
            total_time += time.time() - start
    print(total_time / len(data_loader))


def run_evaluate():
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    for batch in tqdm.tqdm(data_loader):
        inp = batch['inp'].cuda()
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)
    evaluator.summarize()


def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        visualizer.visualize(output, batch)


def run_sbd():
    from tools import convert_sbd
    convert_sbd.convert_sbd()


def run_demo():
    from tools import demo
    demo.demo()


def run_visual():  # fixed from the misspelled 'run_visiual'; dispatched as --type visual
    from tools import visualization
    visualization.visual()


if __name__ == '__main__':
    globals()['run_' + args.type]()
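run_network brackets the forward pass with torch.cuda.synchronize() because CUDA kernels launch asynchronously; without the syncs, time.time() would mostly measure launch overhead rather than compute. The same pattern in isolation (a sketch, assuming any CUDA module and matching input):

import time
import torch

def time_forward(model, inp, iters=10):
    torch.cuda.synchronize()      # drain pending kernels before starting the clock
    start = time.time()
    for _ in range(iters):
        with torch.no_grad():
            model(inp)
    torch.cuda.synchronize()      # wait for the timed kernels to finish
    return (time.time() - start) / iters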
--------------------------------------------------------------------------------
/run_visual.py:
--------------------------------------------------------------------------------
import sys
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = '6'

sys.argv.extend(['--cfg_file', 'configs/sbd_snake.yaml',
                 'demo_path', "/data/tzx/snake_envo_num/visual_result/MRAVBCE/PDS-C/input",
                 'ct_score', '1.2', 'ff_num', '0'])
# /home/amax/Titan_Five/TZX/snake-master/demo_images/2009_000871.jpg
# /home/amax/Titan_Five/TZX/snake-master/demo_images/4599_image.jpg
from run import run_visual
run_visual()
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import sys
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = '6'

sys.argv.extend(['--cfg_file', 'configs/sbd_snake.yaml',
                 'demo_path', "/data/tzx/data/images",
                 'ct_score', '0.3', 'ff_num', '0'])
from run import run_demo
run_demo()
--------------------------------------------------------------------------------
/tools/visualization.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import glob
import os
import cv2
import numpy as np
from lib.utils.snake import snake_config
from lib.utils import data_utils
from lib.config import cfg
import tqdm
import torch
from lib.networks import make_network
from lib.utils.net_utils import load_network
from lib.visualizers import make_visualizer
from process.metrics import FBound_metric, WCov_metric
import random
import sys


class Dataset(data.Dataset):
    def __init__(self):
        super(Dataset, self).__init__()
        self.imgs = [os.path.join(cfg.demo_path, img_name) for img_name in os.listdir(cfg.demo_path)]

    def normalize_image(self, inp):
        inp = (inp.astype(np.float32) / 255.)
        inp = (inp - snake_config.mean) / snake_config.std
        inp = inp.transpose(2, 0, 1)
        return inp

    def __getitem__(self, index):
        img_path = self.imgs[index]
        img = cv2.imread(img_path)

        img = cv2.resize(img, (512, 512))
        imgreturn = img
        width, height = img.shape[1], img.shape[0]
        center = np.array([width // 2, height // 2])
        scale = np.array([width, height])
        x = 32
        input_w = 512
        input_h = 512

        trans_input = data_utils.get_affine_transform(center, scale, 0, [input_w, input_h])
        inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
        img_org = inp
        img_org = (img_org.astype(np.float32))
        inp = self.normalize_image(inp)

        ret = {'inp': inp}
        meta = {'center': center, 'scale': scale, 'test': '', 'ann': ''}
        ret.update({"orig_img": img_org})
        ret.update({'meta': meta})

        return ret, img_path, imgreturn

    def __len__(self):
        return len(self.imgs)


def poly2mask(ex):
    ex = ex[-1] if isinstance(ex, list) else ex
    ex = ex.detach().cpu().numpy() * snake_config.down_ratio

    img = np.zeros((512, 512))
    ex = np.array(ex)
    ex = ex.astype(np.int32)
    img = cv2.polylines(img, [ex[0]], True, 1, 1)
    img = cv2.fillPoly(img, [ex[0]], 1)
    return img


def cal_iou(mask, gtmask):
    jiaoji = mask * gtmask                             # jiaoji: intersection
    bingji = ((mask + gtmask) != 0).astype(np.int16)   # bingji: union
    return jiaoji.sum() / bingji.sum()


def cal_dice(iou):
    return 2 * iou / (iou + 1)


def cal_fbound(mask, gt_mask):
    return FBound_metric(mask, gt_mask)


def cal_wcov(mask, gt_mask):
    return WCov_metric(mask, gt_mask)


def visual():
    network = make_network(cfg).cuda()
    print(cfg.model_dir)
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)

    network.eval()
    dataset = Dataset()

    for batch, path, img in tqdm.tqdm(dataset):
        try:
            print(path)
            batch['inp'] = torch.FloatTensor(batch['inp'])[None].cuda()
            batch['orig_img'] = torch.FloatTensor(batch['orig_img'])[None].cuda()
            with torch.no_grad():
                output = network(batch['inp'], batch)
            img = cv2.imread(path)
            print(len(output['py']))
            for i in range(128):
                cv2.line(img,
                         (int(output['py'][2][0, i, 0].item()), int(output['py'][2][0, i, 1].item())),
                         (int(output['py'][2][0, (i + 1) % 128, 0].item()), int(output['py'][2][0, (i + 1) % 128, 1].item())),
                         color=(255, 0, 255), thickness=2)
            cv2.imwrite(path.replace("input", "output3"), img)
        except Exception:
            # skip images the network fails on instead of aborting the whole run
            continue
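For intuition, cal_dice's closed form follows from Dice = 2|A∩B| / (|A| + |B|) and IoU = |A∩B| / |A∪B|: substituting |A| + |B| = |A∩B| + |A∪B| gives Dice = 2·IoU / (IoU + 1). A toy numeric check (illustration only, assuming the two functions above are in scope):

import numpy as np

mask = np.zeros((4, 4), dtype=np.int16)
mask[:2] = 1                    # 8 pixels
gtmask = np.zeros((4, 4), dtype=np.int16)
gtmask[1:3] = 1                 # 8 pixels, 4 of them shared
iou = cal_iou(mask, gtmask)     # 4 / 12 = 0.333...
print(cal_dice(iou))            # 2*4 / (8 + 8) = 0.5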
--------------------------------------------------------------------------------
/train_net.py:
--------------------------------------------------------------------------------
from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
import torch.multiprocessing
import numpy as np
import random
import os
import sys
import time


def train(cfg, network):
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)

    begin_epoch = load_model(network, optimizer, scheduler, recorder, cfg.model_dir, resume=cfg.resume)
    # NOTE: the line below discards the resumed epoch, so training always restarts from epoch 0
    begin_epoch = 0

    train_loader = make_data_loader(cfg, is_train=True)
    val_loader = make_data_loader(cfg, is_train=False)
    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch, cfg.model_dir)

    return network


def init_seed(seed):
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization so experiments are reproducible
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"


def main():
    time1 = time.time()
    init_seed(1)
    network = make_network(cfg)
    network = network.cuda()

    train(cfg, network)

    print("training time:", time.time() - time1)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------