├── .idea ├── .gitignore ├── CE-Net.iml ├── codeStyles │ └── codeStyleConfig.xml ├── deployment.xml ├── dictionaries │ └── guzaiwang.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── remote-mappings.xml ├── sshConfigs.xml ├── vcs.xml └── webServers.xml ├── README.md ├── dataset └── DRIVE │ ├── test │ ├── 1st_manual │ │ ├── 01_manual1.gif │ │ ├── 02_manual1.gif │ │ ├── 03_manual1.gif │ │ ├── 04_manual1.gif │ │ ├── 05_manual1.gif │ │ ├── 06_manual1.gif │ │ ├── 07_manual1.gif │ │ ├── 08_manual1.gif │ │ ├── 09_manual1.gif │ │ ├── 10_manual1.gif │ │ ├── 11_manual1.gif │ │ ├── 12_manual1.gif │ │ ├── 13_manual1.gif │ │ ├── 14_manual1.gif │ │ ├── 15_manual1.gif │ │ ├── 16_manual1.gif │ │ ├── 17_manual1.gif │ │ ├── 18_manual1.gif │ │ ├── 19_manual1.gif │ │ └── 20_manual1.gif │ ├── 2nd_manual │ │ ├── 01_manual2.gif │ │ ├── 02_manual2.gif │ │ ├── 03_manual2.gif │ │ ├── 04_manual2.gif │ │ ├── 05_manual2.gif │ │ ├── 06_manual2.gif │ │ ├── 07_manual2.gif │ │ ├── 08_manual2.gif │ │ ├── 09_manual2.gif │ │ ├── 10_manual2.gif │ │ ├── 11_manual2.gif │ │ ├── 12_manual2.gif │ │ ├── 13_manual2.gif │ │ ├── 14_manual2.gif │ │ ├── 15_manual2.gif │ │ ├── 16_manual2.gif │ │ ├── 17_manual2.gif │ │ ├── 18_manual2.gif │ │ ├── 19_manual2.gif │ │ └── 20_manual2.gif │ ├── images │ │ ├── 01_test.tif │ │ ├── 02_test.tif │ │ ├── 03_test.tif │ │ ├── 04_test.tif │ │ ├── 05_test.tif │ │ ├── 06_test.tif │ │ ├── 07_test.tif │ │ ├── 08_test.tif │ │ ├── 09_test.tif │ │ ├── 10_test.tif │ │ ├── 11_test.tif │ │ ├── 12_test.tif │ │ ├── 13_test.tif │ │ ├── 14_test.tif │ │ ├── 15_test.tif │ │ ├── 16_test.tif │ │ ├── 17_test.tif │ │ ├── 18_test.tif │ │ ├── 19_test.tif │ │ └── 20_test.tif │ └── mask │ │ ├── 01_test_mask.gif │ │ ├── 02_test_mask.gif │ │ ├── 03_test_mask.gif │ │ ├── 04_test_mask.gif │ │ ├── 05_test_mask.gif │ │ ├── 06_test_mask.gif │ │ ├── 07_test_mask.gif │ │ ├── 08_test_mask.gif │ │ ├── 09_test_mask.gif │ │ ├── 10_test_mask.gif │ │ ├── 
11_test_mask.gif │ │ ├── 12_test_mask.gif │ │ ├── 13_test_mask.gif │ │ ├── 14_test_mask.gif │ │ ├── 15_test_mask.gif │ │ ├── 16_test_mask.gif │ │ ├── 17_test_mask.gif │ │ ├── 18_test_mask.gif │ │ ├── 19_test_mask.gif │ │ └── 20_test_mask.gif │ └── training │ ├── 1st_manual │ ├── 21_manual1.gif │ ├── 22_manual1.gif │ ├── 23_manual1.gif │ ├── 24_manual1.gif │ ├── 25_manual1.gif │ ├── 26_manual1.gif │ ├── 27_manual1.gif │ ├── 28_manual1.gif │ ├── 29_manual1.gif │ ├── 30_manual1.gif │ ├── 31_manual1.gif │ ├── 32_manual1.gif │ ├── 33_manual1.gif │ ├── 34_manual1.gif │ ├── 35_manual1.gif │ ├── 36_manual1.gif │ ├── 37_manual1.gif │ ├── 38_manual1.gif │ ├── 39_manual1.gif │ └── 40_manual1.gif │ └── images │ ├── 21_training.tif │ ├── 22_training.tif │ ├── 23_training.tif │ ├── 24_training.tif │ ├── 25_training.tif │ ├── 26_training.tif │ ├── 27_training.tif │ ├── 28_training.tif │ ├── 29_training.tif │ ├── 30_training.tif │ ├── 31_training.tif │ ├── 32_training.tif │ ├── 33_training.tif │ ├── 34_training.tif │ ├── 35_training.tif │ ├── 36_training.tif │ ├── 37_training.tif │ ├── 38_training.tif │ ├── 39_training.tif │ └── 40_training.tif ├── readme └── exp_res1.jpg └── src ├── _init_paths.py ├── lib ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset │ │ ├── Cityscape.py │ │ ├── HumanSeg.py │ │ ├── ORIGA_OD.py │ │ ├── VOC.py │ │ └── __init__.py │ ├── dataset_factory.py │ └── sample │ │ ├── __init__.py │ │ ├── binarySeg.py │ │ └── multiSeg.py ├── logger.py ├── models │ ├── __init__.py │ ├── data_parallel.py │ ├── losses.py │ ├── model.py │ ├── networks │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── backbone_factory.py │ │ │ ├── mobilenet │ │ │ │ ├── __init__.py │ │ │ │ ├── basic_module.py │ │ │ │ ├── build_mobilenet.py │ │ │ │ └── mobilenet_factory.py │ │ │ └── resnet │ │ │ │ ├── __init__.py │ │ │ │ ├── basic_module.py │ │ │ │ ├── build_resnet.py │ │ │ │ └── resnet_factory.py │ │ ├── cenet.py │ │ └── neck_blocks │ │ │ ├── __init__.py │ │ │ └── 
attention_modules │ │ │ ├── Dense_atrous.py │ │ │ └── __init__.py │ └── scatter_gather.py ├── opts.py ├── trains │ ├── __init__.py │ ├── base_trainer.py │ ├── binarySeg.py │ └── train_factory.py └── utils │ ├── __init__.py │ ├── image.py │ └── utils.py └── main.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/CE-Net.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 15 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/dictionaries/guzaiwang.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 50 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | 
-------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/remote-mappings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Context Encoder Network for 2D Medical Image Segmentation 2 | > [**CE-Net: Context Encoder Network for 2D Medical Image Segmentation**](https://arxiv.org/abs/1903.02740), 3 | > Zaiwang Gu, Jun Cheng, Huazhu Fu, Kang Zhou, Huaying Hao, Yitian Zhao, Tianyang Zhang, Shenghua Gao, Jiang Liu 4 | > *arXiv technical report ([arXiv 1903.02740](https://arxiv.org/abs/1903.02740))* 5 | 6 | 7 | Contact: [guzw@i2r.a-star.edu.sg](mailto:guzw@i2r.a-star.edu.sg) or [guzaiwang01@gmail.com](mailto:guzaiwang01@gmail.com). 
Any questions or discussions are welcome! 8 | 9 | ## Abstract 10 | 11 | Medical image segmentation is an important step 12 | in medical image analysis. With the rapid development of 13 | convolutional neural networks in image processing, deep learning 14 | has been used for medical image segmentation, such as optic 15 | disc segmentation, blood vessel detection, lung segmentation, cell 16 | segmentation, etc. Previously, U-Net based approaches have been 17 | proposed. However, the consecutive pooling and strided convolutional operations lead to the loss of some spatial information. In 18 | this paper, we propose a context encoder network (referred to as 19 | CE-Net) to capture more high-level information and preserve 20 | spatial information for 2D medical image segmentation. CE-Net mainly contains three major components: a feature encoder 21 | module, a context extractor and a feature decoder module. We 22 | use a pretrained ResNet block as the fixed feature extractor. The 23 | context extractor module is formed by a newly proposed dense 24 | atrous convolution (DAC) block and residual multi-kernel pooling 25 | (RMP) block. We applied the proposed CE-Net to different 2D 26 | medical image segmentation tasks. Comprehensive results show 27 | that the proposed method outperforms the original U-Net method 28 | and other state-of-the-art methods for optic disc segmentation, 29 | vessel detection, lung segmentation, cell contour segmentation 30 | and retinal optical coherence tomography layer segmentation. 31 | 32 | 33 | ## Use CE-Net 34 | Please start the "visdom" server before running main.py. 35 | Then, run the main.py file. 36 | 37 | We have uploaded the DRIVE dataset to run the retinal vessel detection. The other medical datasets will be 38 | uploaded in the next submission. 39 | 40 | The submission mainly contains: 41 | 1. architecture (called CE-Net) in networks/cenet.py 42 | 2. multi-class dice loss in losses.py 43 | 3. 
data augmentation in data.py 44 | 45 | Update: 46 | We have modified the loss function. 47 | The CUDA error (or warning) no longer occurs. 48 | 49 | Update: 50 | The test code has been uploaded. 51 | Besides, we release a pretrained model, which achieves an AUC score of 0.9819 on the DRIVE dataset. 52 | 53 | ## Citation 54 | 55 | If you find this project useful for your research, please use the following BibTeX entry. 56 | 57 | @article{gu2019net, 58 | title={Ce-net: Context encoder network for 2d medical image segmentation}, 59 | author={Gu, Zaiwang and Cheng, Jun and Fu, Huazhu and Zhou, Kang and Hao, Huaying and Zhao, Yitian and Zhang, Tianyang and Gao, Shenghua and Liu, Jiang}, 60 | journal={IEEE transactions on medical imaging}, 61 | volume={38}, 62 | number={10}, 63 | pages={2281--2292}, 64 | year={2019}, 65 | publisher={IEEE} 66 | } 67 | 68 | The manuscript has been accepted by IEEE Transactions on Medical Imaging (TMI). 69 | 70 | 71 | -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/01_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/01_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/02_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/02_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/03_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/03_manual1.gif 
-------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/04_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/04_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/05_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/05_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/06_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/06_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/07_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/07_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/08_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/08_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/09_manual1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/09_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/10_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/10_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/11_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/11_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/12_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/12_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/13_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/13_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/14_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/14_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/15_manual1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/15_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/16_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/16_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/17_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/17_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/18_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/18_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/19_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/19_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/1st_manual/20_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/20_manual1.gif 
-------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/01_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/01_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/02_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/02_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/03_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/03_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/04_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/04_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/05_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/05_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/06_manual2.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/06_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/07_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/07_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/08_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/08_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/09_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/09_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/10_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/10_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/11_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/11_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/12_manual2.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/12_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/13_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/13_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/14_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/14_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/15_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/15_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/16_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/16_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/17_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/17_manual2.gif 
-------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/18_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/18_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/19_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/19_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/2nd_manual/20_manual2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/20_manual2.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/01_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/01_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/02_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/02_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/03_test.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/03_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/04_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/04_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/05_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/05_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/06_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/06_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/07_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/07_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/08_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/08_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/09_test.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/09_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/10_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/10_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/11_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/11_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/12_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/12_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/13_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/13_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/14_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/14_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/15_test.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/15_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/16_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/16_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/17_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/17_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/18_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/18_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/19_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/19_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/images/20_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/20_test.tif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/01_test_mask.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/01_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/02_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/02_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/03_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/03_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/04_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/04_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/05_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/05_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/06_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/06_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/07_test_mask.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/07_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/08_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/08_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/09_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/09_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/10_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/10_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/11_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/11_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/12_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/12_test_mask.gif -------------------------------------------------------------------------------- 
/dataset/DRIVE/test/mask/13_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/13_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/14_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/14_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/15_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/15_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/16_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/16_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/17_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/17_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/18_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/18_test_mask.gif 
-------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/19_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/19_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/test/mask/20_test_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/20_test_mask.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/21_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/21_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/22_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/22_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/23_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/23_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/24_manual1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/24_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/25_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/25_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/26_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/26_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/27_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/27_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/28_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/28_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/29_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/29_manual1.gif -------------------------------------------------------------------------------- 
/dataset/DRIVE/training/1st_manual/30_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/30_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/31_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/31_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/32_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/32_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/33_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/33_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/34_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/34_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/35_manual1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/35_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/36_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/36_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/37_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/37_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/38_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/38_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/39_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/39_manual1.gif -------------------------------------------------------------------------------- /dataset/DRIVE/training/1st_manual/40_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/40_manual1.gif -------------------------------------------------------------------------------- 
/dataset/DRIVE/training/images/21_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/21_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/22_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/22_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/23_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/23_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/24_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/24_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/25_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/25_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/26_training.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/26_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/27_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/27_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/28_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/28_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/29_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/29_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/30_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/30_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/31_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/31_training.tif -------------------------------------------------------------------------------- 
/dataset/DRIVE/training/images/32_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/32_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/33_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/33_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/34_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/34_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/35_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/35_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/36_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/36_training.tif -------------------------------------------------------------------------------- /dataset/DRIVE/training/images/37_training.tif: -------------------------------------------------------------------------------- 
# ---- /src/_init_paths.py ----
import os.path as osp
import sys


def add_path(path):
    """Put ``path`` at the front of ``sys.path`` unless it is already listed.

    Args:
        path: Directory to make importable.
    """
    if path not in sys.path:
        sys.path.insert(0, path)


def __init_path():
    """Make the ``lib`` directory that sits next to this file importable.

    The directory of this file is resolved to an absolute path first, so the
    ``sys.path`` entry does not depend on the current working directory
    (``__file__`` may be a relative path when the file is run as a script).
    """
    this_dir = osp.dirname(osp.abspath(__file__))

    # Add lib to PYTHONPATH
    add_path(osp.join(this_dir, 'lib'))


# NOTE(review): the path is only registered when this file is executed
# directly; a plain ``import _init_paths`` leaves sys.path untouched.
# Confirm that this is what the callers expect.
if __name__ == '__main__':
    __init_path()
    print("system path:")
    print(sys.path)
# ---- /src/lib/datasets/dataset/Cityscape.py ----
import numpy as np
import os

import torch.utils.data as data


class Cityscape(data.Dataset):
    """Index of image/mask path pairs read from a split list file.

    Only paths are collected here; this class defines no ``__getitem__``,
    so loading the samples is left to whatever it is combined with.
    """
    num_classes = 1
    default_resolution = [448, 448]
    # Normalisation constants shaped (1, 1, 3); not referenced inside this
    # class. NOTE(review): presumably per-channel mean/std - confirm the
    # channel order and where the numbers come from.
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        """Collect the file paths of one split.

        Args:
            opt: Options object; only ``opt.data_dir`` is read here.
            split: ``'train'`` selects ``Set_A.txt``; any other value
                selects ``Set_B.txt``.
        """
        self.images = []
        self.labels = []
        self.opt = opt

        list_name = 'Set_A.txt' if split == 'train' else 'Set_B.txt'
        read_files = os.path.join(opt.data_dir, list_name)

        self.image_root_folder = os.path.join(opt.data_dir, 'crop_image')
        self.gt_root_folder = os.path.join(opt.data_dir, 'crop_mask')

        self._read_img_mask(self.image_root_folder, self.gt_root_folder, read_files)

    def _read_img_mask(self, image_folder, mask_folder, read_files):
        """Append one ``.jpg`` image path and one ``.png`` mask path per entry.

        Args:
            image_folder: Directory the image paths are built under.
            mask_folder: Directory the mask paths are built under.
            read_files: Text file with one file name per line.
        """
        # The context manager closes the list file; iterating a bare open()
        # left the handle to the garbage collector.
        with open(read_files) as list_file:
            for line in list_file:
                name = line.strip()
                if not name:
                    # Blank lines would otherwise become '<folder>/.jpg'.
                    continue
                # Stripping first keeps the newline out of entries that have
                # no extension; splitext only removes the final suffix, so
                # dots inside the stem survive.
                stem = os.path.splitext(name)[0]
                self.images.append(os.path.join(image_folder, stem + '.jpg'))
                self.labels.append(os.path.join(mask_folder, stem + '.png'))

    def __len__(self):
        """Return the number of indexed samples."""
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
# ---- /src/lib/datasets/dataset/HumanSeg.py ----
import numpy as np
import os

import torch.utils.data as data


class HumanSeg(data.Dataset):
    """Index of image/mask path pairs discovered by listing an image folder.

    Only paths are collected here; this class defines no ``__getitem__``,
    so loading the samples is left to whatever it is combined with.
    """
    num_classes = 1
    default_resolution = [448, 448]
    # Normalisation constants shaped (1, 1, 3); not referenced inside this
    # class. NOTE(review): presumably per-channel mean/std - confirm the
    # channel order and where the numbers come from.
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        """Collect the file paths of one split.

        Args:
            opt: Options object; only ``opt.data_dir`` is read here.
            split: ``'train'`` or anything else (treated as validation).
        """
        self.images = []
        self.labels = []
        self.opt = opt

        # NOTE(review): the 'train' split reads from the 'test' sub-folder.
        # Kept as in the original; confirm this matches the on-disk layout
        # and is not a typo for 'train'.
        subset = 'test' if split == 'train' else 'val'
        self.image_root_folder = os.path.join(opt.data_dir, subset, 'imgs')
        self.gt_root_folder = os.path.join(opt.data_dir, subset, 'masks')

        self._read_img_mask(self.image_root_folder, self.gt_root_folder)

    def _read_img_mask(self, image_folder, mask_folder):
        """Pair every entry of ``image_folder`` with a ``.png`` mask path.

        Args:
            image_folder: Directory that is listed for images.
            mask_folder: Directory the mask paths are built under.
        """
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order reproducible across machines and runs.
        for img_name in sorted(os.listdir(image_folder)):
            # Use the real file name for the image so non-.png images and
            # names containing extra dots still resolve to an existing file.
            stem = os.path.splitext(img_name)[0]
            self.images.append(os.path.join(image_folder, img_name))
            self.labels.append(os.path.join(mask_folder, stem + '.png'))

    def __len__(self):
        """Return the number of indexed samples."""
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
# ---- /src/lib/datasets/dataset/ORIGA_OD.py ----
import numpy as np
import os

import torch.utils.data as data


class ORIGA_OD(data.Dataset):
    """Index of image/mask path pairs read from a split list file.

    Only paths are collected here; this class defines no ``__getitem__``,
    so loading the samples is left to whatever it is combined with.
    """
    num_classes = 1
    default_resolution = [448, 448]
    # Normalisation constants shaped (1, 1, 3); not referenced inside this
    # class. NOTE(review): presumably per-channel mean/std - confirm the
    # channel order and where the numbers come from.
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        """Collect the file paths of one split.

        Args:
            opt: Options object; only ``opt.data_dir`` is read here.
            split: ``'train'`` selects ``Set_A.txt``; any other value
                selects ``Set_B.txt``.
        """
        self.images = []
        self.labels = []
        self.opt = opt

        list_name = 'Set_A.txt' if split == 'train' else 'Set_B.txt'
        read_files = os.path.join(opt.data_dir, list_name)

        self.image_root_folder = os.path.join(opt.data_dir, 'crop_image')
        self.gt_root_folder = os.path.join(opt.data_dir, 'crop_mask')

        self._read_img_mask(self.image_root_folder, self.gt_root_folder, read_files)

    def _read_img_mask(self, image_folder, mask_folder, read_files):
        """Append one ``.jpg`` image path and one ``.png`` mask path per entry.

        Args:
            image_folder: Directory the image paths are built under.
            mask_folder: Directory the mask paths are built under.
            read_files: Text file with one file name per line.
        """
        # The context manager closes the list file; iterating a bare open()
        # left the handle to the garbage collector.
        with open(read_files) as list_file:
            for line in list_file:
                name = line.strip()
                if not name:
                    # Blank lines would otherwise become '<folder>/.jpg'.
                    continue
                # Stripping first keeps the newline out of entries that have
                # no extension; splitext only removes the final suffix, so
                # dots inside the stem survive.
                stem = os.path.splitext(name)[0]
                self.images.append(os.path.join(image_folder, stem + '.jpg'))
                self.labels.append(os.path.join(mask_folder, stem + '.png'))

    def __len__(self):
        """Return the number of indexed samples."""
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
# ---- /src/lib/datasets/dataset/VOC.py ----
import numpy as np
import os

import torch.utils.data as data


class VOC(data.Dataset):
    """Index of image/mask path pairs read from a split list file.

    Only paths are collected here; this class defines no ``__getitem__``,
    so loading the samples is left to whatever it is combined with.
    """
    num_classes = 1
    default_resolution = [448, 448]
    # Normalisation constants shaped (1, 1, 3); not referenced inside this
    # class. NOTE(review): presumably per-channel mean/std - confirm the
    # channel order and where the numbers come from.
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        """Collect the file paths of one split.

        Args:
            opt: Options object; only ``opt.data_dir`` is read here.
            split: ``'train'`` selects ``Set_A.txt``; any other value
                selects ``Set_B.txt``.
        """
        self.images = []
        self.labels = []
        self.opt = opt

        list_name = 'Set_A.txt' if split == 'train' else 'Set_B.txt'
        read_files = os.path.join(opt.data_dir, list_name)

        self.image_root_folder = os.path.join(opt.data_dir, 'crop_image')
        self.gt_root_folder = os.path.join(opt.data_dir, 'crop_mask')

        self._read_img_mask(self.image_root_folder, self.gt_root_folder, read_files)

    def _read_img_mask(self, image_folder, mask_folder, read_files):
        """Append one ``.jpg`` image path and one ``.png`` mask path per entry.

        Args:
            image_folder: Directory the image paths are built under.
            mask_folder: Directory the mask paths are built under.
            read_files: Text file with one file name per line.
        """
        # The context manager closes the list file; iterating a bare open()
        # left the handle to the garbage collector.
        with open(read_files) as list_file:
            for line in list_file:
                name = line.strip()
                if not name:
                    # Blank lines would otherwise become '<folder>/.jpg'.
                    continue
                # Stripping first keeps the newline out of entries that have
                # no extension; splitext only removes the final suffix, so
                # dots inside the stem survive.
                stem = os.path.splitext(name)[0]
                self.images.append(os.path.join(image_folder, stem + '.jpg'))
                self.labels.append(os.path.join(mask_folder, stem + '.png'))

    def __len__(self):
        """Return the number of indexed samples."""
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
# ---- /src/lib/datasets/dataset_factory.py ----
from .dataset.ORIGA_OD import ORIGA_OD
from .dataset.VOC import VOC
from .dataset.Cityscape import Cityscape
from .dataset.HumanSeg import HumanSeg

from .sample.binarySeg import BinarySegDataset
# from .sample.multiSeg import MultiSegDataset

# Dataset name -> class that indexes the files of that dataset.
# NOTE(review): key spelling is not uniform ('CityScape', 'humanseg');
# callers must pass the keys exactly as written here.
dataset_factory = {
    'ORIGA_OD': ORIGA_OD,
    'VOC': VOC,
    'CityScape': Cityscape,
    'humanseg': HumanSeg,
}

# Task name -> class that supplies sample loading.
_sample_factory = {
    'binSeg': BinarySegDataset,
    # 'multiSeg': MultiSegDataset
}


def get_dataset(dataset, task):
    """Build a dataset class from a file-index class and a sampling class.

    Args:
        dataset: Key of ``dataset_factory``.
        task: Key of ``_sample_factory``.

    Returns:
        A new class named ``Dataset`` deriving from both selected classes,
        with the file-index class first in the bases.

    Raises:
        KeyError: If ``dataset`` or ``task`` is not a registered name; the
            message lists the names that are available.
    """
    try:
        index_cls = dataset_factory[dataset]
    except KeyError:
        raise KeyError('Unknown dataset {!r}; expected one of {}'.format(
            dataset, sorted(dataset_factory)))
    try:
        sample_cls = _sample_factory[task]
    except KeyError:
        raise KeyError('Unknown task {!r}; expected one of {}'.format(
            task, sorted(_sample_factory)))

    class Dataset(index_cls, sample_cls):
        pass

    return Dataset
# ---- /src/lib/datasets/sample/binarySeg.py ----
import torch
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import numpy as np
from PIL import Image

from utils.image import randomHueSaturationValue
from utils.image import randomShiftScaleRotate
from utils.image import randomHorizontalFlip
from utils.image import randomVerticleFlip, randomRotate90


class BinarySegDataset(data.Dataset):
    """Sample loading for binary segmentation.

    Reads ``self.images``, ``self.labels`` and ``self.opt``, none of which
    is set in this class; they have to be provided by the class it is
    combined with.
    """

    def __getitem__(self, index):
        """Load, resize, optionally augment and tensorise one sample.

        Args:
            index: Position in ``self.images`` / ``self.labels``.

        Returns:
            dict with ``'input'`` (the image, channels moved to the front
            and rescaled with ``/ 255.0 * 3.2 - 1.6``) and ``'gt'`` (the
            mask, binarised to 0/1), both ``torch.Tensor``.
        """
        # cv2.resize takes the target size as (width, height); the original
        # passed (height, width), which only works for square targets.
        out_size = (self.opt.width, self.opt.height)

        img = cv2.imread(self.images[index])
        img = cv2.resize(img, out_size)

        # Close the mask file as soon as the pixels are copied out.
        with Image.open(self.labels[index]) as mask_file:
            mask = np.array(mask_file)
        mask = cv2.resize(mask, out_size)

        # Data augmentation
        if self.opt.color_aug:
            img = randomHueSaturationValue(img,
                                           hue_shift_limit=(-30, 30),
                                           sat_shift_limit=(-5, 5),
                                           val_shift_limit=(-15, 15))

        if self.opt.shift_scale:
            img, mask = randomShiftScaleRotate(img, mask,
                                               shift_limit=(-0.1, 0.1),
                                               scale_limit=(-0.1, 0.1),
                                               aspect_limit=(-0.1, 0.1),
                                               rotate_limit=(-0, 0))

        if self.opt.HorizontalFlip:
            img, mask = randomHorizontalFlip(img, mask)

        if self.opt.VerticleFlip:
            img, mask = randomVerticleFlip(img, mask)

        if self.opt.rotate_90:
            img, mask = randomRotate90(img, mask)

        # NOTE(review): assumes the mask is a 2-D array with values in
        # 0..255; a multi-channel mask would break the transpose below, and
        # a palette image would yield palette indices - confirm.
        mask = np.expand_dims(mask, axis=2)
        img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
        mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
        # Same result as the original pair of in-place assignments:
        # values >= 0.5 become 1, everything below becomes 0.
        mask = (mask >= 0.5).astype(np.float32)

        img = torch.Tensor(img)
        mask = torch.Tensor(mask)

        ret = {'input': img, 'gt': mask}
        return ret
# ---- /src/lib/datasets/sample/multiSeg.py ----
import torch
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import numpy as np
from PIL import Image

from utils.image import randomHueSaturationValue
from utils.image import randomShiftScaleRotate
from utils.image import randomHorizontalFlip
from utils.image import randomVerticleFlip, randomRotate90


# NOTE(review): this module is multiSeg.py but the class is still called
# BinarySegDataset and binarises the mask; the commented-out import in
# dataset_factory.py asks for ``MultiSegDataset``. Looks like an unfinished
# copy of binarySeg.py - confirm before wiring it into the factory.
class BinarySegDataset(data.Dataset):
    """Sample loading with every augmentation applied unconditionally.

    Reads ``self.images``, ``self.labels`` and ``self.opt``, none of which
    is set in this class; they have to be provided by the class it is
    combined with.
    """

    def __getitem__(self, index):
        """Load, resize, augment and tensorise one sample.

        Args:
            index: Position in ``self.images`` / ``self.labels``.

        Returns:
            dict with ``'input'`` (the image, channels moved to the front
            and rescaled with ``/ 255.0 * 3.2 - 1.6``) and ``'gt'`` (the
            mask, binarised to 0/1), both ``torch.Tensor``.
        """
        # cv2.resize takes the target size as (width, height); the original
        # passed (height, width), which only works for square targets.
        out_size = (self.opt.width, self.opt.height)

        img = cv2.imread(self.images[index])
        img = cv2.resize(img, out_size)

        # Close the mask file as soon as the pixels are copied out.
        with Image.open(self.labels[index]) as mask_file:
            mask = np.array(mask_file)
        mask = cv2.resize(mask, out_size)

        # Data augmentation
        img = randomHueSaturationValue(img,
                                       hue_shift_limit=(-30, 30),
                                       sat_shift_limit=(-5, 5),
                                       val_shift_limit=(-15, 15))

        img, mask = randomShiftScaleRotate(img, mask,
                                           shift_limit=(-0.1, 0.1),
                                           scale_limit=(-0.1, 0.1),
                                           aspect_limit=(-0.1, 0.1),
                                           rotate_limit=(-0, 0))
        img, mask = randomHorizontalFlip(img, mask)
        img, mask = randomVerticleFlip(img, mask)
        img, mask = randomRotate90(img, mask)

        # NOTE(review): assumes the mask is a 2-D array with values in
        # 0..255; a multi-channel mask would break the transpose below -
        # confirm.
        mask = np.expand_dims(mask, axis=2)
        img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
        mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
        # Same result as the original pair of in-place assignments:
        # values >= 0.5 become 1, everything below becomes 0.
        mask = (mask >= 0.5).astype(np.float32)

        img = torch.Tensor(img)
        mask = torch.Tensor(mask)

        ret = {'input': img, 'gt': mask}
        return ret
logging to log_dir. 23 | """ 24 | if not os.path.exists(opt.save_dir): 25 | os.makedirs(opt.save_dir) 26 | if not os.path.exists(opt.debug_dir): 27 | os.makedirs(opt.debug_dir) 28 | 29 | time_str = time.strftime('%Y-%m-%d-%H-%M') 30 | 31 | args = dict((name, getattr(opt, name)) for name in dir(opt) 32 | if not name.startswith('_')) 33 | 34 | file_name = os.path.join(opt.save_dir, 'opt.txt') 35 | 36 | with open(file_name, 'wt') as opt_file: 37 | opt_file.write('==> torch version: {}\n'.format(torch.__version__)) 38 | opt_file.write('==> cudnn version: {}\n'.format( 39 | torch.backends.cudnn.version())) 40 | opt_file.write('==> Cmd:\n') 41 | opt_file.write(str(sys.argv)) 42 | opt_file.write('\n==> Opt:\n') 43 | for k, v in sorted(args.items()): 44 | opt_file.write(' %s: %s\n' % (str(k), str(v))) 45 | 46 | log_dir = opt.save_dir + '/logs_{}'.format(time_str) 47 | if USE_TENSORBOARD: 48 | self.writer = tensorboardX.SummaryWriter(log_dir=log_dir) 49 | else: 50 | if not os.path.exists(os.path.dirname(log_dir)): 51 | os.mkdir(os.path.dirname(log_dir)) 52 | if not os.path.exists(log_dir): 53 | os.mkdir(log_dir) 54 | self.log = open(log_dir + '/log.txt', 'w') 55 | try: 56 | os.system('cp {}/opt.txt {}/'.format(opt.save_dir, log_dir)) 57 | except: 58 | pass 59 | self.start_line = True 60 | 61 | def write(self, txt): 62 | if self.start_line: 63 | time_str = time.strftime('%Y-%m-%d-%H-%M') 64 | self.log.write('{}: {}'.format(time_str, txt)) 65 | else: 66 | self.log.write(txt) 67 | self.start_line = False 68 | if '\n' in txt: 69 | self.start_line = True 70 | self.log.flush() 71 | 72 | def close(self): 73 | self.log.close() 74 | 75 | def scalar_summary(self, tag, value, step): 76 | """Log a scalar variable.""" 77 | if USE_TENSORBOARD: 78 | self.writer.add_scalar(tag, value, step) -------------------------------------------------------------------------------- /src/lib/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/__init__.py -------------------------------------------------------------------------------- /src/lib/models/data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules import Module 3 | from torch.nn.parallel.scatter_gather import gather 4 | from torch.nn.parallel.replicate import replicate 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | 8 | from .scatter_gather import scatter_kwargs 9 | 10 | class _DataParallel(Module): 11 | r"""Implements data parallelism at the module level. 12 | 13 | This container parallelizes the application of the given module by 14 | splitting the input across the specified devices by chunking in the batch 15 | dimension. In the forward pass, the module is replicated on each device, 16 | and each replica handles a portion of the input. During the backwards 17 | pass, gradients from each replica are summed into the original module. 18 | 19 | The batch size should be larger than the number of GPUs used. It should 20 | also be an integer multiple of the number of GPUs so that each chunk is the 21 | same size (so that each GPU processes the same number of samples). 22 | 23 | See also: :ref:`cuda-nn-dataparallel-instead` 24 | 25 | Arbitrary positional and keyword inputs are allowed to be passed into 26 | DataParallel EXCEPT Tensors. All variables will be scattered on dim 27 | specified (default 0). Primitive types will be broadcasted, but all 28 | other types will be a shallow copy and can be corrupted if written to in 29 | the model's forward pass. 
30 | 31 | Args: 32 | module: module to be parallelized 33 | device_ids: CUDA devices (default: all devices) 34 | output_device: device location of output (default: device_ids[0]) 35 | 36 | Example:: 37 | 38 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 39 | >>> output = net(input_var) 40 | """ 41 | 42 | 43 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None): 44 | super(_DataParallel, self).__init__() 45 | 46 | if not torch.cuda.is_available(): 47 | self.module = module 48 | self.device_ids = [] 49 | return 50 | 51 | if device_ids is None: 52 | device_ids = list(range(torch.cuda.device_count())) 53 | if output_device is None: 54 | output_device = device_ids[0] 55 | self.dim = dim 56 | self.module = module 57 | self.device_ids = device_ids 58 | self.chunk_sizes = chunk_sizes 59 | self.output_device = output_device 60 | if len(self.device_ids) == 1: 61 | self.module.cuda(device_ids[0]) 62 | 63 | def forward(self, *inputs, **kwargs): 64 | if not self.device_ids: 65 | return self.module(*inputs, **kwargs) 66 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes) 67 | if len(self.device_ids) == 1: 68 | return self.module(*inputs[0], **kwargs[0]) 69 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 70 | outputs = self.parallel_apply(replicas, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def replicate(self, module, device_ids): 74 | return replicate(module, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids, chunk_sizes): 77 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes) 78 | 79 | def parallel_apply(self, replicas, inputs, kwargs): 80 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 81 | 82 | def gather(self, outputs, output_device): 83 | return gather(outputs, output_device, dim=self.dim) 84 | 85 | 86 | def data_parallel(module, inputs, 
device_ids=None, output_device=None, dim=0, module_kwargs=None): 87 | r"""Evaluates module(input) in parallel across the GPUs given in device_ids. 88 | 89 | This is the functional version of the DataParallel module. 90 | 91 | Args: 92 | module: the module to evaluate in parallel 93 | inputs: inputs to the module 94 | device_ids: GPU ids on which to replicate module 95 | output_device: GPU location of the output Use -1 to indicate the CPU. 96 | (default: device_ids[0]) 97 | Returns: 98 | a Variable containing the result of module(input) located on 99 | output_device 100 | """ 101 | if not isinstance(inputs, tuple): 102 | inputs = (inputs,) 103 | 104 | if device_ids is None: 105 | device_ids = list(range(torch.cuda.device_count())) 106 | 107 | if output_device is None: 108 | output_device = device_ids[0] 109 | 110 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) 111 | if len(device_ids) == 1: 112 | return module(*inputs[0], **module_kwargs[0]) 113 | used_device_ids = device_ids[:len(inputs)] 114 | replicas = replicate(module, used_device_ids) 115 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) 116 | return gather(outputs, output_device, dim) 117 | 118 | def DataParallel(module, device_ids=None, output_device=None, dim=0, chunk_sizes=None): 119 | if chunk_sizes is None: 120 | return torch.nn.DataParallel(module, device_ids, output_device, dim) 121 | standard_size = True 122 | for i in range(1, len(chunk_sizes)): 123 | if chunk_sizes[i] != chunk_sizes[0]: 124 | standard_size = False 125 | if standard_size: 126 | return torch.nn.DataParallel(module, device_ids, output_device, dim) 127 | return _DataParallel(module, device_ids, output_device, dim, chunk_sizes) -------------------------------------------------------------------------------- /src/lib/models/losses.py: -------------------------------------------------------------------------------- 1 | # *coding:utf-8 * 2 | 3 | import torch 4 | import torch.nn as 
# ---- src/lib/models/losses.py ----
def _soft_dice_coeff(y_true, y_pred, batch=True, smooth=0.0):
    """Return the soft Dice coefficient shared by the Dice-based losses.

    Args:
        y_true, y_pred: tensors of identical shape, batch on dim 0.
        batch: if True one coefficient is computed over the whole batch,
            otherwise one per sample and the mean is returned.
        smooth: additive smoothing; with the default 0.0 two empty masks
            produce 0/0 (NaN), exactly as the previous inline copies did.
    """
    if batch:
        i = torch.sum(y_true)
        j = torch.sum(y_pred)
        intersection = torch.sum(y_true * y_pred)
    else:
        # Flatten everything but the batch axis so 3-D inputs (boundary
        # maps) work as well as 4-D ones.
        n = y_true.size(0)
        i = y_true.reshape(n, -1).sum(1)
        j = y_pred.reshape(n, -1).sum(1)
        intersection = (y_true * y_pred).reshape(n, -1).sum(1)
    score = (2. * intersection + smooth) / (i + j + smooth)
    return score.mean()


class weighted_cross_entropy(nn.Module):
    """Cross entropy with a constant per-class weight of 52.

    NOTE(review): with identical weights and mean reduction the result equals
    the unweighted cross entropy -- confirm whether real class weights were
    intended.
    """

    def __init__(self, num_classes=12, batch=True):
        super(weighted_cross_entropy, self).__init__()
        self.batch = batch
        # Buffer instead of a hard-coded .cuda() so the loss also runs on CPU.
        self.register_buffer('weight', torch.Tensor([52.] * num_classes))

    def forward(self, y_true, y_pred):
        """Args: ``y_true`` class indices (an optional size-1 dim 1 is
        squeezed), ``y_pred`` logits with classes on dim 1."""
        y_ce_true = y_true.squeeze(dim=1).long()
        return F.cross_entropy(y_pred, y_ce_true,
                               weight=self.weight.to(y_pred.device))


class dice_loss(nn.Module):
    """Soft Dice loss, ``1 - dice``."""

    def __init__(self, batch=True, smooth=0.0):
        super(dice_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth

    def soft_dice_coeff(self, y_true, y_pred):
        return _soft_dice_coeff(y_true, y_pred, self.batch, self.smooth)

    def soft_dice_loss(self, y_true, y_pred):
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def forward(self, y_true, y_pred):
        return self.soft_dice_loss(y_true, y_pred)


def test_weight_cross_entropy():
    """Smoke test: print the weighted cross entropy of random data."""
    N = 4
    C = 12
    H, W = 128, 128

    inputs = torch.rand(N, C, H, W)
    targets = torch.LongTensor(N, H, W).random_(C)
    inputs_fl = inputs.clone().requires_grad_()
    targets_fl = targets.clone()
    print(weighted_cross_entropy()(targets_fl, inputs_fl))


class boundary_dice_bce_loss(nn.Module):
    """Dice loss on the masks plus Dice loss on their extracted boundaries.

    NOTE(review): the boundaries are computed in numpy from detached
    tensors, so the auxiliary term carries no gradient -- confirm this is
    intended. ``bce_loss`` is created but not used.
    """

    def __init__(self, batch=True, smooth=0.0):
        super(boundary_dice_bce_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth
        self.bce_loss = nn.BCELoss()

    def soft_dice_coeff(self, y_true, y_pred):
        return _soft_dice_coeff(y_true, y_pred, self.batch, self.smooth)

    def soft_dice_loss(self, y_true, y_pred):
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def batch_mask_to_boundary(self, y_true, y_pred):
        """Return boundary maps, shape (batch, H, W), for channel 0 of both inputs."""
        batch, channel, height, width = y_true.size()
        # Move to numpy once instead of once per sample.
        true_np = y_true.detach().cpu().numpy()
        pred_np = y_pred.detach().cpu().numpy()
        y_true_boundary_batch = np.zeros(shape=(batch, height, width))
        y_pred_boundary_batch = np.zeros(shape=(batch, height, width))

        for i in range(batch):
            y_true_boundary_batch[i, :, :] = mask_to_boundary(true_np[i, 0, :, :])
            y_pred_boundary_batch[i, :, :] = mask_to_boundary(pred_np[i, 0, :, :])

        # Follow the input's device rather than assuming CUDA.
        y_true_boundary_mask = torch.as_tensor(
            y_true_boundary_batch, dtype=torch.float32, device=y_true.device)
        y_pred_boundary_mask = torch.as_tensor(
            y_pred_boundary_batch, dtype=torch.float32, device=y_true.device)
        return y_true_boundary_mask, y_pred_boundary_mask

    def forward(self, y_true, y_pred):
        b = self.soft_dice_loss(y_true, y_pred)

        y_true_boundary, y_pred_boundary = self.batch_mask_to_boundary(y_true, y_pred)
        b_aux = self.soft_dice_loss(y_true_boundary, y_pred_boundary)
        return b + b_aux


class dice_bce_loss(nn.Module):
    """Soft Dice loss.

    NOTE(review): despite the name only the Dice term is returned; the BCE
    term used to be computed and thrown away. Confirm whether ``bce + dice``
    was intended before re-enabling it.
    """

    def __init__(self, batch=True, smooth=0.0):
        super(dice_bce_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth
        self.bce_loss = nn.BCELoss()

    def soft_dice_coeff(self, y_true, y_pred):
        return _soft_dice_coeff(y_true, y_pred, self.batch, self.smooth)

    def soft_dice_loss(self, y_true, y_pred):
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def forward(self, y_true, y_pred):
        return self.soft_dice_loss(y_true, y_pred)


class DiceLoss(nn.Module):
    """Per-sample smoothed Dice loss averaged over the batch."""

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, input, target):
        # Only the batch size is needed; reading size(2)/size(3) crashed on
        # the 3-D per-class slices passed by MulticlassDiceLoss.
        N = target.size(0)
        smooth = 1

        input_flat = input.reshape(N, -1)
        target_flat = target.reshape(N, -1)

        intersection = input_flat * target_flat

        # (2*I + s) / (A + B + s): the former 2*(I + s) numerator pushed the
        # coefficient above 1 and the loss below 0.
        dice = (2 * intersection.sum(1) + smooth) / (
            input_flat.sum(1) + target_flat.sum(1) + smooth)
        return 1 - dice.sum() / N


class MulticlassDiceLoss(nn.Module):
    """
    requires one hot encoded target. Applies DiceLoss on each class iteratively.
    requires input.shape[0:2] and target.shape[0:2] to be (N, C) where N is
    batch size and C is number of classes
    """

    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        """Sum the per-class Dice losses, optionally scaled by ``weights[i]``."""
        C = target.shape[1]

        dice = DiceLoss()
        totalLoss = 0

        for i in range(C):
            diceLoss = dice(input[:, i, :, :], target[:, i, :, :])
            if weights is not None:
                diceLoss *= weights[i]
            totalLoss += diceLoss

        return totalLoss


class FocalLoss(nn.Module):
    """Focal loss on logits; note the ``(target, input)`` argument order.

    Args:
        gamma: focusing exponent (0 gives plain cross entropy).
        alpha: optional class weights; a scalar ``a`` means ``[a, 1 - a]``.
        size_average: return the mean if True, else the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, target, input):
        target = torch.squeeze(target, dim=1)
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        index = target.reshape(-1, 1).long()

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, index).view(-1)
        # The modulating factor is treated as a constant, as before.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            # Index with the integer labels: gathering with the raw target
            # failed for float masks and for targets not yet flattened.
            at = self.alpha.gather(0, index.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()


# ---- src/lib/models/model.py ----
def create_model(model_name):
    """Instantiate the network registered under ``model_name``.

    Looks the constructor up in the module-level ``_model_factory`` registry.

    Raises:
        KeyError: if ``model_name`` is not registered.
    """
    get_model = _model_factory[model_name]
    return get_model()


def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None):
    """Load a checkpoint written by ``save_model`` into ``model``.

    Args:
        model: network receiving the weights (loaded non-strictly).
        model_path: checkpoint file with ``'epoch'`` and ``'state_dict'``.
        optimizer: optional optimizer to restore when ``resume`` is True.
        resume: restore optimizer state and the epoch counter.
        lr: base learning rate; when given it is decayed by 0.1 for every
            entry of ``lr_step`` already reached and written back to the
            optimizer. When None the restored optimizer rates are kept.
        lr_step: epochs at which the learning rate is decayed.

    Returns:
        ``model`` alone, or ``(model, optimizer, start_epoch)`` when an
        optimizer was passed.
    """
    start_epoch = 0
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}

    # convert data_parallel to model: drop the leading "module." prefix
    for k, v in state_dict_.items():
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = v
        else:
            state_dict[k] = v
    model_state_dict = model.state_dict()

    # check loaded parameters and created model parameters
    msg = ('If you see this, your model does not fully load the '
           'pre-trained weight. Please make sure '
           'you have correctly specified --arch xxx '
           'or set the correct --num_classes for your own dataset.')
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape{}, '
                      'loaded shape{}. {}'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}.'.format(k) + msg)
    for k in model_state_dict:
        if not (k in state_dict):
            print('No param {}.'.format(k) + msg)
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    # resume optimizer parameters
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            if lr is not None:
                # Without a base lr there is nothing to rescale; the rates
                # restored from the checkpoint are left untouched.
                start_lr = lr
                for step in (lr_step or ()):
                    if start_epoch >= step:
                        start_lr *= 0.1
                for param_group in optimizer.param_groups:
                    param_group['lr'] = start_lr
                print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    if optimizer is not None:
        return model, optimizer, start_epoch
    return model


def save_model(path, epoch, model, optimizer=None):
    """Write ``epoch``, the model weights and optionally the optimizer to ``path``.

    A ``torch.nn.DataParallel`` wrapper is unwrapped first so the saved keys
    carry no ``module.`` prefix.
    """
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    data = {'epoch': epoch,
            'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/backbones/__init__.py -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/backbone_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | from .mobilenet.mobilenet_factory import get_mobilenet_backbone 7 | from .resnet.resnet_factory import get_resnet_backbone 8 | 9 | 10 | def get_backbone_architecture(backbone_name): 11 | if "resnet" in backbone_name: 12 | return get_resnet_backbone(backbone_name) 13 | 14 | elif "mobilenet" in backbone_name: 15 | return get_mobilenet_backbone(backbone_name) -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/mobilenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/backbones/mobilenet/__init__.py -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/mobilenet/basic_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from torch import nn 6 | 7 | from collections import OrderedDict 8 | 9 | def _make_divisible(v, divisor, min_value=None): 10 | """ 11 | This function is taken from the original pytorch repo 12 | It ensures that all layers have a channel number that is divisable by 8 13 | it can be seen here: 14 | https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv2.html 15 | :param v: 16 | :param divisor: 17 | :param 
min_value: 18 | :return: 19 | """ 20 | 21 | if min_value is None: 22 | min_value = divisor 23 | 24 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 25 | 26 | if new_v < 0.9 * v: 27 | new_v += divisor 28 | 29 | return new_v 30 | 31 | 32 | class ConvBNReLU(nn.Sequential): 33 | def __init__(self, in_planes, out_planes, kernel_size=3, 34 | stride=1, group=1): 35 | padding = (kernel_size - 1) // 2 36 | super(ConvBNReLU, self).__init__( 37 | nn.Conv2d(in_planes, out_planes, kernel_size, 38 | stride, padding), 39 | nn.BatchNorm2d(out_planes), 40 | nn.ReLU6(in_planes) 41 | ) 42 | 43 | 44 | class InvertedResidual(nn.Module): 45 | def __init__(self, inp, oup, stride, expand_ratio): 46 | super(InvertedResidual, self).__init__() 47 | self.stride = stride 48 | assert stride in [1, 2] 49 | 50 | hidden_dim = int(round(inp * expand_ratio)) 51 | 52 | self.use_res_connect = self.stride == 1 and inp == oup 53 | 54 | layers = [] 55 | 56 | if expand_ratio != 1: 57 | # pixelwise 58 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 59 | 60 | layers.extend([ 61 | # depthwise 62 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, 63 | group=hidden_dim), 64 | # pw-linear 65 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 66 | nn.BatchNorm2d(oup) 67 | ]) 68 | 69 | self.conv = nn.Sequential(*layers) 70 | 71 | def forward(self, x): 72 | if self.use_res_connect: 73 | return x + self.conv(x) 74 | 75 | else: 76 | return self.conv(x) 77 | 78 | 79 | def load_model(model, state_dict): 80 | new_model = model.state_dict() 81 | new_keys = list(new_model.keys()) 82 | old_keys = list(state_dict.keys()) 83 | 84 | restore_dict = OrderedDict() 85 | 86 | for id in range(len(new_keys)): 87 | restore_dict[new_keys[id]] = state_dict[old_keys[id]] 88 | 89 | model.load_state_dict(restore_dict) 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- 
/src/lib/models/networks/backbones/mobilenet/build_mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from .basic_module import ConvBNReLU, InvertedResidual, _make_divisible, load_model 6 | 7 | model_urls = { 8 | 'mobilenetv2_10': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 9 | } 10 | 11 | 12 | class MobileNetV2(nn.Module): 13 | def __init__(self, width_mult=1.0, round_nearest=8): 14 | super(MobileNetV2, self).__init__() 15 | block = InvertedResidual 16 | input_channel = 32 17 | 18 | inverted_residual_setting = [ 19 | # t, c, n, s 20 | [1, 16, 1, 1], # 0 21 | [6, 24, 2, 2], # 1 22 | [6, 32, 3, 2], # 2 23 | [6, 64, 4, 2], # 3 24 | [6, 96, 3, 1], # 4 25 | [6, 160, 3, 2], # 5 26 | [6, 320, 1, 1], # 6 27 | ] 28 | 29 | self.feat_id = [1, 2, 4, 6] 30 | self.feat_channel = [] 31 | 32 | # only check the first element, assuming user know t,c,n,s 33 | # are required 34 | 35 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 36 | raise ValueError("inverted_residual_setting should be non-empty " 37 | "or a 4-element list, got {}".format(inverted_residual_setting)) 38 | 39 | # building first layer 40 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 41 | 42 | features = [ConvBNReLU(3, input_channel, stride=2)] 43 | 44 | # building inverted residual blocks 45 | 46 | for id, (t, c, n, s) in enumerate(inverted_residual_setting): 47 | output_channel = _make_divisible(c * width_mult, round_nearest) 48 | for i in range(n): 49 | stride = s if i == 0 else 1 50 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 51 | input_channel = output_channel 52 | if id in self.feat_id: 53 | self.__setattr__("layer%d" % id, nn.Sequential(*features)) 54 | self.feat_channel.append(output_channel) 55 | features = [] 56 | 57 | # weight initialization 58 | 
for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 61 | if m.bias is not None: 62 | nn.init.zeros_(m.bias) 63 | elif isinstance(m, nn.BatchNorm2d): 64 | nn.init.ones_(m.weight) 65 | nn.init.zeros_(m.bias) 66 | 67 | def forward(self, x): 68 | y = [] 69 | for id in self.feat_id: 70 | x = self.__getattr__("layer%d" % id)(x) 71 | y.append(x) 72 | 73 | return y 74 | 75 | def init_weights(self, model_name): 76 | url = model_urls[model_name] 77 | pretrained_state_dict = model_zoo.load_url(url) 78 | print('=> loading pretrained model {}'.format(url)) 79 | self.load_state_dict(pretrained_state_dict, strict=False) 80 | 81 | 82 | def get_mobilenetv2_10(pretrained=True, **kwargs): 83 | model = MobileNetV2(width_mult=1.0) 84 | if pretrained: 85 | model.init_weights(model_name='mobilenetv2_10') 86 | 87 | return model 88 | 89 | 90 | def get_mobilenetv2_5(pretrained=True, **kwargs): 91 | model = MobileNetV2(width_mult=0.5) 92 | if pretrained: 93 | print("MobilenetV2_5 does not have the pretrained weight") 94 | 95 | return model 96 | 97 | if __name__ == '__main__': 98 | 99 | model = get_mobilenetv2_10(pretrained=True) 100 | 101 | input = torch.zeros([1, 3, 512, 512]) 102 | feats = model(input) 103 | print(feats[0].size()) 104 | print(feats[1].size()) 105 | print(feats[2].size()) 106 | print(feats[3].size()) 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/mobilenet/mobilenet_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .build_mobilenet import get_mobilenetv2_10 6 | from .build_mobilenet import get_mobilenetv2_5 7 | 8 | _mobilenet_backbone = { 9 | 'mobilenetv2_10': get_mobilenetv2_10, 10 | 'mobilenetv2_5': get_mobilenetv2_5, 11 | 
12 | } 13 | 14 | 15 | def get_mobilenet_backbone(model_name): 16 | support_mobile_models = ['mobilenetv2_10', 'mobilenetv2_5'] 17 | assert model_name in support_mobile_models, "We just support the following models: {}".format(support_mobile_models) 18 | 19 | model = _mobilenet_backbone[model_name] 20 | 21 | return model 22 | 23 | 24 | if __name__ == '__main__': 25 | str1 = 'mobilenetv2_10' 26 | model = get_mobilenet_backbone(str1) 27 | 28 | print(model(pretrain=True)) 29 | -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/resnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/backbones/resnet/__init__.py -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/resnet/basic_module.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Dequan Wang and Xingyi Zhou 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch.nn as nn 13 | 14 | BN_MOMENTUM = 0.1 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """ 19 | 3x3 convolution with padding 20 | """ 21 | 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 33 | self.relu = nn.ReLU(inplace=True) 34 | 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.bn2(self.conv2(out)) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 62 | 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 66 | 67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 
self.expansion, momentum=BN_MOMENTUM) 69 | 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.relu(self.bn1(self.conv1(x))) 78 | 79 | out = self.conv2(self.bn2(self.conv2(out))) 80 | 81 | out = self.bn3(self.conv3(out)) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | out += residual 87 | out = self.relu(out) 88 | 89 | return out 90 | -------------------------------------------------------------------------------- /src/lib/models/networks/backbones/resnet/build_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .basic_module import BasicBlock, Bottleneck, BN_MOMENTUM 6 | 7 | import torch.nn as nn 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | class build_resnet(nn.Module): 20 | def __init__(self, block, layers): 21 | super(build_resnet, self).__init__() 22 | self.inplanes = 64 23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 24 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | 29 | self.layer1 = self._make_layers(block, 64, layers[0]) 30 | self.layer2 = self._make_layers(block, 128, layers[1], stride=2) 31 | self.layer3 = self._make_layers(block, 256, 
layers[2], stride=2) 32 | self.layer4 = self._make_layers(block, 512, layers[3], stride=2) 33 | 34 | def _make_layers(self, block, planes, blocks, stride=1): 35 | downsample = None 36 | 37 | if stride != 1 or self.inplanes != planes * block.expansion: 38 | downsample = nn.Sequential( 39 | nn.Conv2d(self.inplanes, planes * block.expansion, 40 | kernel_size=1, stride=stride, bias=False), 41 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 42 | ) 43 | 44 | layers = [] 45 | layers.append(block(self.inplanes, planes, stride, downsample)) 46 | 47 | self.inplanes = planes * block.expansion 48 | 49 | for i in range(1, blocks): 50 | layers.append(block(self.inplanes, planes)) 51 | 52 | return nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = self.bn1(x) 57 | x = self.relu(x) 58 | x = self.maxpool(x) 59 | 60 | x = self.layer1(x) 61 | stage_1_feature = x 62 | x = self.layer2(x) 63 | stage_2_feature = x 64 | x = self.layer3(x) 65 | stage_3_feature = x 66 | x = self.layer4(x) 67 | stage_4_feature = x 68 | 69 | return x 70 | 71 | def init_weights(self, resnet_model_name): 72 | url = model_urls[resnet_model_name] 73 | pretrained_state_dict = model_zoo.load_url(url) 74 | print('=> loading pretrained model {}'.format(url)) 75 | self.load_state_dict(pretrained_state_dict, strict=False) 76 | 77 | 78 | # Resnet specific parameters 79 | # resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]), 80 | # 34: (BasicBlock, [3, 4, 6, 3]), 81 | # 50: (Bottleneck, [3, 4, 6, 3]), 82 | # 101: (Bottleneck, [3, 4, 23, 3]), 83 | # 152: (Bottleneck, [3, 8, 36, 3])} 84 | 85 | def get_resnet_18(pretrain=True): 86 | model = build_resnet(BasicBlock, [2, 2, 2, 2]) 87 | if pretrain: 88 | model.init_weights(resnet_model_name='resnet18') 89 | 90 | return model 91 | 92 | 93 | def get_resnet_34(pretrain=True): 94 | model = build_resnet(BasicBlock, [3, 4, 6, 3]) 95 | if pretrain: 96 | model.init_weights(resnet_model_name='resnet34') 97 | return model 98 | 99 | 100 | 
def get_resnet_backbone(model_name):
    """Return the builder function registered for a ResNet variant.

    Args:
        model_name: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101'
            or 'resnet152'.

    Returns:
        The entry of ``_resnet_backbone`` stored under ``model_name``. This
        is the builder itself, not a network; call it to obtain the model.

    Raises:
        AssertionError: if ``model_name`` is not one of the supported names.
    """
    supported = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
    assert model_name in supported, \
        "We just support the following models: {}".format(supported)

    return _resnet_backbone[model_name]
class DACblock(nn.Module):
    """Dense atrous convolution block with four cascaded branches.

    Three 3x3 convolutions with dilation 1, 3 and 5 (padding chosen so the
    spatial size is preserved) plus a shared 1x1 convolution are combined
    into four branches whose rectified outputs are added to the input.

    Args:
        channel: number of input channels; every convolution keeps it.
    """

    def __init__(self, channel):
        super(DACblock, self).__init__()
        self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
        self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=3, padding=3)
        self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=5, padding=5)
        self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
        # Start every convolution with a zero bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Return ``x`` plus the four rectified branch outputs.

        The cascade ``dilate1 -> dilate2 -> dilate3`` is evaluated once and
        its intermediates are shared between branches, instead of
        re-running ``dilate1`` three times and ``dilate2(dilate1(x))``
        twice. The result is unchanged.
        """
        d1 = self.dilate1(x)
        d12 = self.dilate2(d1)
        d123 = self.dilate3(d12)

        # d1 is also the input of dilate2, so it must not be rectified
        # in place: an out-of-place ReLU keeps the pre-activation intact.
        branch1 = F.relu(d1)
        branch2 = F.relu(self.conv1x1(self.dilate2(x)), inplace=True)
        branch3 = F.relu(self.conv1x1(d12), inplace=True)
        branch4 = F.relu(self.conv1x1(d123), inplace=True)

        return x + branch1 + branch2 + branch3 + branch4
class DACblock_with_inception(nn.Module):
    """Inception-style replacement for the dense atrous block.

    A 1x1 branch and a 1x1 -> 3x3 branch are concatenated, fused back to
    ``channel`` maps by a 1x1 convolution, passed through the first 1x1
    convolution again and added to the input.

    Args:
        channel: number of input (and output) channels.
    """

    def __init__(self, channel):
        super(DACblock_with_inception, self).__init__()
        # Despite the names, dilate1 is a 1x1 conv and dilate3 an undilated 3x3 conv.
        self.dilate1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
        self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
        self.conv1x1 = nn.Conv2d(2 * channel, channel, kernel_size=1, dilation=1, padding=0)
        # Start every convolution with a zero bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Return ``x`` plus the fused two-branch response.

        ``dilate1(x)`` feeds both branches, so it is computed once and
        reused rather than evaluated twice.
        """
        reduced = self.dilate1(x)

        # `reduced` is still needed as the input of dilate3, so the first
        # branch uses an out-of-place ReLU.
        branch_1x1 = F.relu(reduced)
        branch_3x3 = F.relu(self.dilate3(reduced), inplace=True)

        fused = F.relu(self.conv1x1(torch.cat([branch_1x1, branch_3x3], 1)), inplace=True)
        # The same 1x1 convolution (shared weights) is applied to the fused map.
        refined = F.relu(self.dilate1(fused), inplace=True)

        return x + refined
class PSPModule(nn.Module):
    """Pyramid pooling: pool at several grid sizes, upsample, concatenate, fuse.

    Args:
        features: number of channels of the incoming feature map.
        out_features: number of channels produced by the 1x1 bottleneck.
        sizes: output grid sizes of the adaptive average pooling stages.
    """

    def __init__(self, features, out_features=1024, sizes=(2, 3, 6, 14)):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        # One pooled copy per size plus the untouched input are concatenated.
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        """Build one ``AdaptiveAvgPool2d(size) -> 1x1 conv`` pyramid level."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        """Return the fused pyramid features at the input's spatial size."""
        h, w = feats.size(2), feats.size(3)
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # what the old call resolved to, so the result is unchanged.
        priors = [
            F.interpolate(stage(feats), size=(h, w), mode='bilinear', align_corners=False)
            for stage in self.stages
        ]
        priors.append(feats)
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)
class DecoderBlock(nn.Module):
    """Bottlenecked decoder stage: 1x1 squeeze, stride-2 deconv, 1x1 expand.

    Each of the three layers is followed by batch norm and the module-level
    ``nonlinearity``. The transposed convolution (kernel 3, stride 2,
    padding 1, output_padding 1) doubles the spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        n_filters: channels of the returned feature map.
    """

    def __init__(self, in_channels, n_filters):
        super().__init__()
        squeezed = in_channels // 4  # bottleneck width

        self.conv1 = nn.Conv2d(in_channels, squeezed, 1)
        self.norm1 = nn.BatchNorm2d(squeezed)
        self.relu1 = nonlinearity

        self.deconv2 = nn.ConvTranspose2d(
            squeezed, squeezed, 3, stride=2, padding=1, output_padding=1)
        self.norm2 = nn.BatchNorm2d(squeezed)
        self.relu2 = nonlinearity

        self.conv3 = nn.Conv2d(squeezed, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nonlinearity

    def forward(self, x):
        """Apply squeeze, upsample and expand, each with norm + activation."""
        squeezed = self.relu1(self.norm1(self.conv1(x)))
        upsampled = self.relu2(self.norm2(self.deconv2(squeezed)))
        return self.relu3(self.norm3(self.conv3(upsampled)))
class CE_Net_backbone_DAC_without_atrous(nn.Module):
    """CE-Net ablation whose context block uses undilated 3x3 convolutions.

    The encoder reuses the stem and the four residual stages of a
    torchvision ResNet-34 created with ``pretrained=True``. A
    ``DACblock_without_atrous`` sits between encoder and decoder (no SPP
    block), followed by four ``DecoderBlock`` stages with additive skip
    connections and a small transposed-conv / conv head.

    Args:
        num_classes: number of channels produced by the last convolution.
        num_channels: accepted but not read by this class.
    """

    def __init__(self, num_classes=1, num_channels=3):
        super().__init__()

        ch1, ch2, ch3, ch4 = 64, 128, 256, 512

        # Encoder: borrow the pretrained ResNet-34 layers.
        backbone = models.resnet34(pretrained=True)
        self.firstconv = backbone.conv1
        self.firstbn = backbone.bn1
        self.firstrelu = backbone.relu
        self.firstmaxpool = backbone.maxpool
        self.encoder1 = backbone.layer1
        self.encoder2 = backbone.layer2
        self.encoder3 = backbone.layer3
        self.encoder4 = backbone.layer4

        # Context block at the bottleneck.
        self.dblock = DACblock_without_atrous(ch4)

        # Decoder: each stage is added to the matching encoder feature.
        self.decoder4 = DecoderBlock(ch4, ch3)
        self.decoder3 = DecoderBlock(ch3, ch2)
        self.decoder2 = DecoderBlock(ch2, ch1)
        self.decoder1 = DecoderBlock(ch1, ch1)

        # Prediction head.
        self.finaldeconv1 = nn.ConvTranspose2d(ch1, 32, 4, 2, 1)
        self.finalrelu1 = nonlinearity
        self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.finalrelu2 = nonlinearity
        self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)

    def forward(self, x):
        """Return the sigmoid of the head output for input ``x``."""
        stem = self.firstmaxpool(self.firstrelu(self.firstbn(self.firstconv(x))))
        e1 = self.encoder1(stem)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        center = self.dblock(e4)

        d4 = self.decoder4(center) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        head = self.finalrelu1(self.finaldeconv1(d1))
        head = self.finalrelu2(self.finalconv2(head))
        logits = self.finalconv3(head)

        return torch.sigmoid(logits)
class CE_Net_backbone_inception_blocks(nn.Module):
    """CE-Net ablation that uses ``DACblock_with_inception_blocks`` as context block.

    The encoder reuses the stem and the four residual stages of a
    torchvision ResNet-34 created with ``pretrained=True``. The context
    block is applied to the deepest encoder feature (no SPP block), then
    four ``DecoderBlock`` stages with additive skip connections and a small
    transposed-conv / conv head produce the prediction.

    Args:
        num_classes: number of channels produced by the last convolution.
        num_channels: accepted but not read by this class.
    """

    def __init__(self, num_classes=1, num_channels=3):
        super().__init__()

        ch1, ch2, ch3, ch4 = 64, 128, 256, 512

        # Encoder: borrow the pretrained ResNet-34 layers.
        backbone = models.resnet34(pretrained=True)
        self.firstconv = backbone.conv1
        self.firstbn = backbone.bn1
        self.firstrelu = backbone.relu
        self.firstmaxpool = backbone.maxpool
        self.encoder1 = backbone.layer1
        self.encoder2 = backbone.layer2
        self.encoder3 = backbone.layer3
        self.encoder4 = backbone.layer4

        # Context block at the bottleneck.
        self.dblock = DACblock_with_inception_blocks(ch4)

        # Decoder: each stage is added to the matching encoder feature.
        self.decoder4 = DecoderBlock(ch4, ch3)
        self.decoder3 = DecoderBlock(ch3, ch2)
        self.decoder2 = DecoderBlock(ch2, ch1)
        self.decoder1 = DecoderBlock(ch1, ch1)

        # Prediction head.
        self.finaldeconv1 = nn.ConvTranspose2d(ch1, 32, 4, 2, 1)
        self.finalrelu1 = nonlinearity
        self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.finalrelu2 = nonlinearity
        self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)

    def forward(self, x):
        """Return the sigmoid of the head output for input ``x``."""
        stem = self.firstmaxpool(self.firstrelu(self.firstbn(self.firstconv(x))))
        e1 = self.encoder1(stem)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        center = self.dblock(e4)

        d4 = self.decoder4(center) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        head = self.finalrelu1(self.finaldeconv1(d1))
        head = self.finalrelu2(self.finalconv2(head))
        logits = self.finalconv3(head)

        return torch.sigmoid(logits)
class double_conv(nn.Module):
    """Two consecutive ``3x3 conv -> batch norm -> ReLU`` stages.

    Both convolutions use padding 1, so the spatial size is preserved.

    Args:
        in_ch: channels of the input.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch.
        for src_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.conv(x)
class DACblock(nn.Module):
    """Dense atrous convolution block with four cascaded branches.

    Three 3x3 convolutions with dilation 1, 3 and 5 (padding chosen so the
    spatial size is preserved) plus a shared 1x1 convolution are combined
    into four branches whose rectified outputs are added to the input.

    Args:
        channel: number of input channels; every convolution keeps it.
    """

    def __init__(self, channel):
        super(DACblock, self).__init__()
        self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
        self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=3, padding=3)
        self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=5, padding=5)
        self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)

        # Start every convolution with a zero bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Return ``x`` plus the four rectified branch outputs.

        The activation is applied with ``F.relu``. The previous code called
        ``nn.ReLU(tensor)``, which *constructs* a module instead of
        rectifying the tensor, so the final sum raised ``TypeError``.
        """
        # Evaluate the cascade once and share its intermediates.
        d1 = self.dilate1(x)
        d12 = self.dilate2(d1)
        d123 = self.dilate3(d12)

        dilated1_out = F.relu(d1)
        dilated2_out = F.relu(self.conv1x1(self.dilate2(x)))
        dilated3_out = F.relu(self.conv1x1(d12))
        dilated4_out = F.relu(self.conv1x1(d123))

        return x + dilated1_out + dilated2_out + dilated3_out + dilated4_out
def scatter(inputs, target_gpus, dim=0, chunk_sizes=None):
    r"""
    Slices variables into chunks and distributes them across given GPUs.
    Duplicates references to objects that are not variables.

    Tuples, lists and dicts are traversed recursively and rebuilt once per
    target, so the result is a list with one entry per element of
    ``target_gpus``. Empty containers are replicated like any other
    non-variable object; previously an empty container nested inside a
    tuple/list/dict made the whole result collapse to ``[]`` and silently
    dropped its siblings.

    Args:
        inputs: object (possibly nested) to distribute.
        target_gpus: device ids; one output entry is produced per id.
        dim: dimension along which variables are split.
        chunk_sizes: per-device chunk sizes forwarded to ``Scatter.apply``.

    Returns:
        A list with one scattered copy of ``inputs`` per target device.
    """
    def scatter_map(obj):
        if isinstance(obj, Variable):
            return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        # The len() guards keep zip(*...) from seeing an empty container,
        # which would otherwise yield no per-device entries at all.
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(per_gpu) for per_gpu in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(per_gpu) for per_gpu in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # scatter_map refers to itself through its closure cell, which forms a
    # reference cycle; clearing the name afterwards lets it be freed promptly.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
23 | '1: only show the final detection results' 24 | '2: show the network output features' 25 | '3: use matplot to display' # useful when lunching training with ipython notebook 26 | '4: save all visualizations to disk') 27 | self.parser.add_argument('--demo', default='', 28 | help='path to image/ image folders/ video. ' 29 | 'or "webcam"') 30 | self.parser.add_argument('--load_model', default='', 31 | help='path to pretrained model') 32 | self.parser.add_argument('--resume', action='store_true', 33 | help='resume an experiment. ' 34 | 'Reloaded the optimizer parameter and ' 35 | 'set load_model to model_last.pth ' 36 | 'in the exp dir if load_model is empty.') 37 | 38 | # model 39 | self.parser.add_argument('--model_name', default='cenet', 40 | help='which model') 41 | 42 | # data_root 43 | self.parser.add_argument('--data_root', default='/data/zaiwang/Dataset', 44 | help='which model') 45 | 46 | # loss type 47 | self.parser.add_argument('--loss_type', default='dice_bce_loss', 48 | help='which loss type') 49 | 50 | # system 51 | self.parser.add_argument('--gpus', default='3', 52 | help='-1 for CPU, use comma for multiple gpus') 53 | self.parser.add_argument('--num_workers', type=int, default=4, 54 | help='dataloader threads. 
0 for single-thread.') 55 | self.parser.add_argument('--not_cuda_benchmark', action='store_true', 56 | help='disable when the input size is not fixed.') 57 | self.parser.add_argument('--seed', type=int, default=317, 58 | help='random seed') # from CornerNet 59 | 60 | self.parser.add_argument('--print_iter', type=int, default=0, 61 | help='disable progress bar and print to screen.') 62 | self.parser.add_argument('--hide_data_time', action='store_true', 63 | help='not display time during training.') 64 | self.parser.add_argument('--save_all', action='store_true', 65 | help='save model to disk every 5 epochs.') 66 | self.parser.add_argument('--metric', default='loss', 67 | help='main metric to save best model') 68 | self.parser.add_argument('--vis_thresh', type=float, default=0.3, 69 | help='visualization threshold.') 70 | self.parser.add_argument('--debugger_theme', default='white', 71 | choices=['white', 'black']) 72 | 73 | # input 74 | self.parser.add_argument('--input_res', type=int, default=-1, 75 | help='input height and width. -1 for default from ' 76 | 'dataset. Will be overriden by input_h | input_w') 77 | self.parser.add_argument('--input_h', type=int, default=-1, 78 | help='input height. -1 for default from dataset.') 79 | self.parser.add_argument('--input_w', type=int, default=-1, 80 | help='input width. 
def parse(self, args=''):
    """Parse the command line and derive the dependent option fields.

    Args:
        args: list of argument strings; the default ``''`` makes argparse
            read the process command line instead.

    Returns:
        The argparse namespace, extended with ``gpus_str``, ``chunk_sizes``,
        ``data_dir``, ``exp_dir``, ``save_dir`` and ``debug_dir``, and with
        ``gpus``, ``lr_step``, ``exp_id`` (and possibly ``load_model``)
        rewritten.
    """
    opt = self.parser.parse_args() if args == '' else self.parser.parse_args(args)

    # Keep the raw string, then re-index the requested devices to 0..k-1;
    # a negative first id collapses the list to [-1].
    opt.gpus_str = opt.gpus
    requested_ids = [int(token) for token in opt.gpus.split(',')]
    opt.gpus = list(range(len(requested_ids))) if requested_ids[0] >= 0 else [-1]
    print("GPUS device is {}".format(opt.gpus))

    opt.lr_step = [int(epoch) for epoch in opt.lr_step.split(',')]

    # Any debug level forces a single-device, single-sample, in-process run.
    if opt.debug > 0:
        opt.num_workers = 0
        opt.batch_size = 1
        opt.gpus = [opt.gpus[0]]
        opt.master_batch_size = -1

    device_count = len(opt.gpus)
    if opt.master_batch_size == -1:
        opt.master_batch_size = opt.batch_size // device_count

    # The first device takes master_batch_size; the remainder is split as
    # evenly as possible over the others, earlier devices taking the extras.
    leftover = opt.batch_size - opt.master_batch_size
    others = device_count - 1
    opt.chunk_sizes = [opt.master_batch_size]
    opt.chunk_sizes.extend(
        leftover // others + (rank < leftover % others) for rank in range(others))
    print('training chunk_sizes:', opt.chunk_sizes)

    opt.data_dir = os.path.join(opt.data_root, opt.dataset)
    opt.exp_dir = os.path.join(opt.root_dir, 'UBT_Seg', opt.task)
    # NOTE(review): exp_id is always rebuilt here, replacing the parsed value.
    opt.exp_id = '_'.join([opt.dataset, opt.model_name, opt.loss_type])
    opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
    opt.debug_dir = os.path.join(opt.save_dir, 'debug')
    print('The output will be saved to ', opt.save_dir)

    # Resuming without an explicit checkpoint falls back to model_last.pth
    # in the save directory (minus a trailing 'TEST', if present).
    if opt.resume and opt.load_model == '':
        model_path = opt.save_dir
        if model_path.endswith('TEST'):
            model_path = model_path[:-4]
        opt.load_model = os.path.join(model_path, 'model_last.pth')
    return opt
class BaseTrainer(object):
    """Generic train/validation loop around a model and its loss.

    Subclasses must implement ``_get_losses`` and may implement ``debug``
    and ``save_result``; the base versions raise ``NotImplementedError``.
    """

    def __init__(self, opt, model, optimizer=None):
        """Store the options and wrap ``model`` together with its loss.

        Args:
            opt: Parsed options object; ``run_epoch`` reads ``gpus``,
                ``num_iters``, ``task``, ``exp_id``, ``device``,
                ``hide_data_time``, ``print_iter``, ``debug`` and ``test``.
            model: The network to train or evaluate.
            optimizer: Optimizer stepped during the ``'train'`` phase; may
                be ``None`` when the trainer is only used for evaluation.
        """
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    # 2021-09-11
    # DataParallel(): device_ids=gpus (zero-based); intended for a single
    # machine with several GPUs.
    def set_device(self, gpus, chunk_sizes, device):
        """Move the model (and any optimizer state) onto ``device``.

        Args:
            gpus: Device ids; more than one entry wraps the model in
                ``DataParallel``.
            chunk_sizes: Per-device batch chunk sizes passed to
                ``DataParallel``.
            device: Target ``torch.device``.
        """
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # ``optimizer`` defaults to None in ``__init__``; there is no state
        # to migrate in that case.
        if self.optimizer is None:
            return
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        """Run one pass over ``data_loader``.

        Args:
            phase: ``'train'`` enables optimisation; any other value runs
                the model in eval mode without gradient tracking.
            epoch: Epoch number, used only for progress output.
            data_loader: Iterable of batch dicts.  Every entry except
                ``'meta'`` is moved to ``opt.device``; ``batch['input']``
                is fed to the model.

        Returns:
            ``(ret, results)`` where ``ret`` maps every loss-stat name to
            its running average plus ``'time'`` (elapsed minutes), and
            ``results`` is the dict handed to ``save_result`` when
            ``opt.test`` is set.
        """
        is_train = phase == 'train'
        model_with_loss = self.model_with_loss
        if is_train:
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                # Evaluate on the wrapped module directly rather than
                # through the DataParallel wrapper.
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {name: AverageMeter() for name in self.loss_stats}

        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()

        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)

            # Only the training phase needs an autograd graph; evaluation
            # must not depend on the caller wrapping us in no_grad().
            with torch.set_grad_enabled(is_train):
                output, loss, loss_stats = model_with_loss(batch)
                loss = loss.mean()

            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)

            for name, meter in avg_loss_stats.items():
                meter.update(
                    loss_stats[name].mean().item(), batch['input'].size(0))
                suffix = suffix + '|{} {:.4f} '.format(name, meter.avg)
            if not opt.hide_data_time:
                suffix = suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            # Set the text on this bar instance, not on the Bar class.
            bar.suffix = suffix
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id)

            if opt.test:
                self.save_result(output, batch, results)

            # Drop references so the tensors can be freed before the next
            # batch is processed.
            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        """Hook called per batch when ``opt.debug > 0``; subclasses override."""
        raise NotImplementedError

    def save_result(self, output, batch, results):
        """Hook called per batch when ``opt.test`` is set; subclasses override."""
        raise NotImplementedError

    def _get_losses(self, opt):
        """Return ``(loss_stat_names, loss_module)``; subclasses override."""
        raise NotImplementedError

    def val(self, epoch, data_loader):
        """Run one evaluation epoch; see ``run_epoch``."""
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        """Run one training epoch; see ``run_epoch``."""
        return self.run_epoch('train', epoch, data_loader)
def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0),
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Apply the same random shift/scale/rotate warp to an image and its mask.

    Args:
        image: Array whose ``shape`` unpacks as ``(height, width, channel)``.
        mask: Array warped with the same transform as ``image``.
        shift_limit: ``(low, high)`` shift range as a fraction of the
            width (x) and height (y).
        scale_limit: ``(low, high)`` offsets added to 1 for the scale draw.
        rotate_limit: ``(low, high)`` rotation range in degrees.
        aspect_limit: ``(low, high)`` offsets added to 1 for the aspect draw.
        borderMode: OpenCV border mode used by both warps.
        u: Probability of applying the transform.

    Returns:
        ``(image, mask)``; returned unchanged when the transform is skipped.
    """
    # ``np.math`` was only an alias of the stdlib module and was removed in
    # NumPy 2.0, so use ``math`` directly.
    import math

    if np.random.random() < u:
        height, width, channel = image.shape

        # The order of the random draws is kept exactly as before.
        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        radians = angle / 180 * math.pi
        cc = math.cos(radians) * sx
        ss = math.sin(radians) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Map the four image corners through rotate/scale about the centre,
        # then shift, and fit a perspective transform to the corner pairs.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(image, mat, (width, height),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=borderMode,
                                    borderValue=(0, 0, 0,))
        # NOTE(review): the mask is interpolated linearly as well, so a
        # binary mask can pick up intermediate values at edges — confirm
        # that callers re-threshold it.
        mask = cv2.warpPerspective(mask, mat, (width, height),
                                   flags=cv2.INTER_LINEAR,
                                   borderMode=borderMode,
                                   borderValue=(0, 0, 0,))

    return image, mask
class AverageMeter(object):
    """Keep the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every accumulated statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        The mean is left untouched while the total weight is not positive.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count <= 0:
            return
        self.avg = self.sum / self.count
def main(opt):
    """Train the model selected by ``opt``, or only evaluate it.

    When ``opt.test`` is set, a single validation pass is run, its
    predictions are handed to the dataset's ``run_eval`` and the function
    returns.  Otherwise the model is trained from ``start_epoch + 1`` to
    ``opt.num_epochs``, with checkpoints written under ``opt.save_dir``.

    Args:
        opt: Options object as produced by ``opts().parse()``.
    """
    # Completely reproducible results are not guaranteed across PyTorch
    # releases, individual commits, or different platforms.  Results may
    # also differ between CPU and GPU executions even with identical seeds.
    # torch.manual_seed() seeds the RNG for all devices (CPU and CUDA).
    torch.manual_seed(opt.seed)

    # With cudnn.benchmark=True the program spends a little extra time at
    # start-up searching for the fastest convolution algorithm for each
    # conv layer, which then speeds up the network.  This suits a fixed
    # architecture with constant input shapes (batch size, image size,
    # channels), i.e. most cases.  If the conv configurations keep
    # changing, the repeated tuning costs more time than it saves.
    torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test

    Dataset = get_dataset(opt.dataset, opt.task)

    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)

    logger = Logger(opt)
    # From here on the logger is closed on every exit path, including the
    # early return in test mode and any exception.
    try:
        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
        opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')

        print('Creating model...')

        model = create_model(opt.model_name)

        optimizer = torch.optim.Adam(model.parameters(), opt.lr)

        start_epoch = 0

        if opt.load_model != '':
            model, optimizer, start_epoch = load_model(
                model, opt.load_model, optimizer, opt.resume, opt.lr, opt.lr_step)

        Trainer = train_factory[opt.task]

        trainer = Trainer(opt, model, optimizer)
        trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)

        print("Setting up data...")
        val_loader = torch.utils.data.DataLoader(
            Dataset(opt, 'val'),
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=True
        )

        if opt.test:
            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                _, preds = trainer.val(0, val_loader)
            # run evaluation code
            val_loader.dataset.run_eval(preds, opt.save_dir)
            return

        train_loader = torch.utils.data.DataLoader(
            Dataset(opt, 'train'),
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=opt.num_workers,
            pin_memory=True,
            drop_last=True
        )

        print("Starting training")
        # Lowest validation value of ``opt.metric`` seen so far.
        best = 1e10

        for epoch in range(start_epoch + 1, opt.num_epochs + 1):
            mark = epoch if opt.save_all else 'last'
            log_dict_train, _ = trainer.train(epoch, train_loader)
            logger.write('epoch: {} |'.format(epoch))

            for k, v in log_dict_train.items():
                logger.scalar_summary('train_{}'.format(k), v, epoch)
                logger.write('{} {:8f} | '.format(k, v))
            if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
                save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(mark)),
                           epoch, model, optimizer)
                with torch.no_grad():
                    log_dict_val, preds = trainer.val(epoch, val_loader)
                for k, v in log_dict_val.items():
                    logger.scalar_summary('val_{}'.format(k), v, epoch)
                    logger.write('{} {:8f} | '.format(k, v))
                if log_dict_val[opt.metric] < best:
                    best = log_dict_val[opt.metric]
                    save_model(os.path.join(opt.save_dir, 'model_best.pth'),
                               epoch, model)
            else:
                save_model(os.path.join(opt.save_dir, 'model_last.pth'),
                           epoch, model, optimizer)

            logger.write('\n')
            if epoch in opt.lr_step:
                save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                           epoch, model, optimizer)
                # Each milestone in lr_step divides the base LR by 10 again.
                lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
                print('Drop LR to', lr)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
    finally:
        logger.close()


if __name__ == '__main__':
    opt = opts().parse()
    main(opt)