├── .idea
│   ├── .gitignore
│   ├── CE-Net.iml
│   ├── codeStyles
│   │   └── codeStyleConfig.xml
│   ├── deployment.xml
│   ├── dictionaries
│   │   └── guzaiwang.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   ├── sshConfigs.xml
│   ├── vcs.xml
│   └── webServers.xml
├── README.md
├── dataset
│   └── DRIVE
│       ├── test
│       │   ├── 1st_manual
│       │   │   ├── 01_manual1.gif
│       │   │   ├── 02_manual1.gif
│       │   │   ├── 03_manual1.gif
│       │   │   ├── 04_manual1.gif
│       │   │   ├── 05_manual1.gif
│       │   │   ├── 06_manual1.gif
│       │   │   ├── 07_manual1.gif
│       │   │   ├── 08_manual1.gif
│       │   │   ├── 09_manual1.gif
│       │   │   ├── 10_manual1.gif
│       │   │   ├── 11_manual1.gif
│       │   │   ├── 12_manual1.gif
│       │   │   ├── 13_manual1.gif
│       │   │   ├── 14_manual1.gif
│       │   │   ├── 15_manual1.gif
│       │   │   ├── 16_manual1.gif
│       │   │   ├── 17_manual1.gif
│       │   │   ├── 18_manual1.gif
│       │   │   ├── 19_manual1.gif
│       │   │   └── 20_manual1.gif
│       │   ├── 2nd_manual
│       │   │   ├── 01_manual2.gif
│       │   │   ├── 02_manual2.gif
│       │   │   ├── 03_manual2.gif
│       │   │   ├── 04_manual2.gif
│       │   │   ├── 05_manual2.gif
│       │   │   ├── 06_manual2.gif
│       │   │   ├── 07_manual2.gif
│       │   │   ├── 08_manual2.gif
│       │   │   ├── 09_manual2.gif
│       │   │   ├── 10_manual2.gif
│       │   │   ├── 11_manual2.gif
│       │   │   ├── 12_manual2.gif
│       │   │   ├── 13_manual2.gif
│       │   │   ├── 14_manual2.gif
│       │   │   ├── 15_manual2.gif
│       │   │   ├── 16_manual2.gif
│       │   │   ├── 17_manual2.gif
│       │   │   ├── 18_manual2.gif
│       │   │   ├── 19_manual2.gif
│       │   │   └── 20_manual2.gif
│       │   ├── images
│       │   │   ├── 01_test.tif
│       │   │   ├── 02_test.tif
│       │   │   ├── 03_test.tif
│       │   │   ├── 04_test.tif
│       │   │   ├── 05_test.tif
│       │   │   ├── 06_test.tif
│       │   │   ├── 07_test.tif
│       │   │   ├── 08_test.tif
│       │   │   ├── 09_test.tif
│       │   │   ├── 10_test.tif
│       │   │   ├── 11_test.tif
│       │   │   ├── 12_test.tif
│       │   │   ├── 13_test.tif
│       │   │   ├── 14_test.tif
│       │   │   ├── 15_test.tif
│       │   │   ├── 16_test.tif
│       │   │   ├── 17_test.tif
│       │   │   ├── 18_test.tif
│       │   │   ├── 19_test.tif
│       │   │   └── 20_test.tif
│       │   └── mask
│       │       ├── 01_test_mask.gif
│       │       ├── 02_test_mask.gif
│       │       ├── 03_test_mask.gif
│       │       ├── 04_test_mask.gif
│       │       ├── 05_test_mask.gif
│       │       ├── 06_test_mask.gif
│       │       ├── 07_test_mask.gif
│       │       ├── 08_test_mask.gif
│       │       ├── 09_test_mask.gif
│       │       ├── 10_test_mask.gif
│       │       ├── 11_test_mask.gif
│       │       ├── 12_test_mask.gif
│       │       ├── 13_test_mask.gif
│       │       ├── 14_test_mask.gif
│       │       ├── 15_test_mask.gif
│       │       ├── 16_test_mask.gif
│       │       ├── 17_test_mask.gif
│       │       ├── 18_test_mask.gif
│       │       ├── 19_test_mask.gif
│       │       └── 20_test_mask.gif
│       └── training
│           ├── 1st_manual
│           │   ├── 21_manual1.gif
│           │   ├── 22_manual1.gif
│           │   ├── 23_manual1.gif
│           │   ├── 24_manual1.gif
│           │   ├── 25_manual1.gif
│           │   ├── 26_manual1.gif
│           │   ├── 27_manual1.gif
│           │   ├── 28_manual1.gif
│           │   ├── 29_manual1.gif
│           │   ├── 30_manual1.gif
│           │   ├── 31_manual1.gif
│           │   ├── 32_manual1.gif
│           │   ├── 33_manual1.gif
│           │   ├── 34_manual1.gif
│           │   ├── 35_manual1.gif
│           │   ├── 36_manual1.gif
│           │   ├── 37_manual1.gif
│           │   ├── 38_manual1.gif
│           │   ├── 39_manual1.gif
│           │   └── 40_manual1.gif
│           └── images
│               ├── 21_training.tif
│               ├── 22_training.tif
│               ├── 23_training.tif
│               ├── 24_training.tif
│               ├── 25_training.tif
│               ├── 26_training.tif
│               ├── 27_training.tif
│               ├── 28_training.tif
│               ├── 29_training.tif
│               ├── 30_training.tif
│               ├── 31_training.tif
│               ├── 32_training.tif
│               ├── 33_training.tif
│               ├── 34_training.tif
│               ├── 35_training.tif
│               ├── 36_training.tif
│               ├── 37_training.tif
│               ├── 38_training.tif
│               ├── 39_training.tif
│               └── 40_training.tif
├── readme
│   └── exp_res1.jpg
└── src
    ├── _init_paths.py
    ├── lib
    │   ├── __init__.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── dataset
    │   │   │   ├── Cityscape.py
    │   │   │   ├── HumanSeg.py
    │   │   │   ├── ORIGA_OD.py
    │   │   │   ├── VOC.py
    │   │   │   └── __init__.py
    │   │   ├── dataset_factory.py
    │   │   └── sample
    │   │       ├── __init__.py
    │   │       ├── binarySeg.py
    │   │       └── multiSeg.py
    │   ├── logger.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── data_parallel.py
    │   │   ├── losses.py
    │   │   ├── model.py
    │   │   ├── networks
    │   │   │   ├── __init__.py
    │   │   │   ├── backbones
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── backbone_factory.py
    │   │   │   │   ├── mobilenet
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── basic_module.py
    │   │   │   │   │   ├── build_mobilenet.py
    │   │   │   │   │   └── mobilenet_factory.py
    │   │   │   │   └── resnet
    │   │   │   │       ├── __init__.py
    │   │   │   │       ├── basic_module.py
    │   │   │   │       ├── build_resnet.py
    │   │   │   │       └── resnet_factory.py
    │   │   │   ├── cenet.py
    │   │   │   └── neck_blocks
    │   │   │       ├── __init__.py
    │   │   │       └── attention_modules
    │   │   │           ├── Dense_atrous.py
    │   │   │           └── __init__.py
    │   │   └── scatter_gather.py
    │   ├── opts.py
    │   ├── trains
    │   │   ├── __init__.py
    │   │   ├── base_trainer.py
    │   │   ├── binarySeg.py
    │   │   └── train_factory.py
    │   └── utils
    │       ├── __init__.py
    │       ├── image.py
    │       └── utils.py
    └── main.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
--------------------------------------------------------------------------------
/.idea/CE-Net.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/dictionaries/guzaiwang.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/sshConfigs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Context Encoder Network for 2D Medical Image Segmentation

> [**CE-Net: Context Encoder Network for 2D Medical Image Segmentation**](https://arxiv.org/abs/1903.02740),
> Zaiwang Gu, Jun Cheng, Huazhu Fu, Kang Zhou, Huaying Hao, Yitian Zhao, Tianyang Zhang, Shenghua Gao, Jiang Liu
> *arXiv technical report ([arXiv 1903.02740](https://arxiv.org/abs/1903.02740))*

Contact: [guzw@i2r.a-star.edu.sg](mailto:guzw@i2r.a-star.edu.sg) or [guzaiwang01@gmail.com](mailto:guzaiwang01@gmail.com). Any questions or discussions are welcome!

## Abstract

Medical image segmentation is an important step in medical image analysis. With the rapid development of convolutional neural networks in image processing, deep learning has been applied to medical image segmentation tasks such as optic disc segmentation, blood vessel detection, lung segmentation and cell segmentation. U-Net-based approaches have previously been proposed, but their consecutive pooling and strided convolution operations lose some spatial information. In this paper, we propose a context encoder network (CE-Net) to capture more high-level information while preserving spatial information for 2D medical image segmentation. CE-Net mainly contains three major components: a feature encoder module, a context extractor module and a feature decoder module. We use a pretrained ResNet block as the fixed feature extractor. The context extractor module is formed by a newly proposed dense atrous convolution (DAC) block and a residual multi-kernel pooling (RMP) block. We applied the proposed CE-Net to different 2D medical image segmentation tasks. Comprehensive results show that the proposed method outperforms the original U-Net and other state-of-the-art methods on optic disc segmentation, vessel detection, lung segmentation, cell contour segmentation and retinal optical coherence tomography layer segmentation.

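The DAC and RMP blocks live in `src/lib/models/networks/cenet.py` and `src/lib/models/networks/neck_blocks/attention_modules/Dense_atrous.py` (see the tree above). For orientation, here is a minimal PyTorch sketch of the two context-extractor blocks as the abstract describes them; the branch wiring and pooling sizes are illustrative assumptions, not a verbatim copy of the repository code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DACBlock(nn.Module):
    """Dense atrous convolution (sketch): parallel branches whose receptive
    fields grow with the dilation rate, summed with the input feature map."""
    def __init__(self, channels):
        super().__init__()
        self.d1 = nn.Conv2d(channels, channels, 3, dilation=1, padding=1)
        self.d3 = nn.Conv2d(channels, channels, 3, dilation=3, padding=3)
        self.d5 = nn.Conv2d(channels, channels, 3, dilation=5, padding=5)
        self.conv1x1 = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b1 = F.relu(self.d1(x))
        b2 = F.relu(self.conv1x1(self.d3(x)))
        b3 = F.relu(self.conv1x1(self.d3(self.d1(x))))
        b4 = F.relu(self.conv1x1(self.d5(self.d3(self.d1(x)))))
        return x + b1 + b2 + b3 + b4

class RMPBlock(nn.Module):
    """Residual multi-kernel pooling (sketch): pool at several scales, reduce
    each pooled map to one channel, upsample, and concatenate with the input."""
    def __init__(self, channels, pool_sizes=(2, 3, 5, 6)):
        super().__init__()
        self.pools = nn.ModuleList(nn.MaxPool2d(k, stride=k) for k in pool_sizes)
        self.convs = nn.ModuleList(nn.Conv2d(channels, 1, 1) for _ in pool_sizes)

    def forward(self, x):
        h, w = x.shape[2:]
        feats = [F.interpolate(conv(pool(x)), size=(h, w), mode='bilinear',
                               align_corners=False)
                 for pool, conv in zip(self.pools, self.convs)]
        # Output has channels + len(pool_sizes) channels.
        return torch.cat([x] + feats, dim=1)
```
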
## Use CE-Net

Please start a visdom server (`python -m visdom.server`) before running `main.py`; then run `main.py`.

We have uploaded the DRIVE dataset so that the retinal vessel detection experiment can be run directly. The other medical datasets will be uploaded in a later submission.

The submission mainly contains:
1. the architecture (CE-Net) in networks/cenet.py
2. the multi-class dice loss in loss.py (sketched below)
3. the data augmentation in data.py

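`loss.py` itself is not included in this snapshot. As a reference for item 2, a multi-class soft Dice loss can be written along the following lines (a sketch assuming one-hot targets, not the repository's exact implementation):

```python
import torch
import torch.nn as nn

class MulticlassDiceLoss(nn.Module):
    """Mean soft Dice loss over classes; `logits` is (N, C, H, W) and
    `target` is a one-hot tensor of the same shape."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        probs = torch.softmax(logits, dim=1)
        dims = (0, 2, 3)                      # reduce over batch and space
        inter = (probs * target).sum(dims)
        cards = probs.sum(dims) + target.sum(dims)
        dice = (2.0 * inter + self.smooth) / (cards + self.smooth)
        return 1.0 - dice.mean()             # average over classes
```
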
Update: we have modified the loss function, so the CUDA error (or warning) no longer occurs.

Update: the test code has been uploaded. We have also released a pretrained model, which achieves an AUC score of 0.9819 on the DRIVE dataset.

## Citation

If you find this project useful for your research, please use the following BibTeX entry.

    @article{gu2019net,
      title={CE-Net: Context encoder network for 2D medical image segmentation},
      author={Gu, Zaiwang and Cheng, Jun and Fu, Huazhu and Zhou, Kang and Hao, Huaying and Zhao, Yitian and Zhang, Tianyang and Gao, Shenghua and Liu, Jiang},
      journal={IEEE Transactions on Medical Imaging},
      volume={38},
      number={10},
      pages={2281--2292},
      year={2019},
      publisher={IEEE}
    }

The manuscript has been accepted by IEEE TMI.
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/01_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/01_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/02_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/02_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/03_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/03_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/04_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/04_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/05_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/05_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/06_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/06_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/07_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/07_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/08_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/08_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/09_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/09_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/10_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/10_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/11_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/11_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/12_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/12_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/13_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/13_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/14_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/14_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/15_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/15_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/16_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/16_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/17_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/17_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/18_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/18_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/19_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/19_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/1st_manual/20_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/1st_manual/20_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/01_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/01_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/02_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/02_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/03_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/03_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/04_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/04_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/05_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/05_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/06_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/06_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/07_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/07_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/08_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/08_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/09_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/09_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/10_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/10_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/11_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/11_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/12_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/12_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/13_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/13_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/14_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/14_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/15_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/15_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/16_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/16_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/17_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/17_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/18_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/18_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/19_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/19_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/2nd_manual/20_manual2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/2nd_manual/20_manual2.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/01_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/01_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/02_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/02_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/03_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/03_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/04_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/04_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/05_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/05_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/06_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/06_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/07_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/07_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/08_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/08_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/09_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/09_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/10_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/10_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/11_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/11_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/12_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/12_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/13_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/13_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/14_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/14_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/15_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/15_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/16_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/16_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/17_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/17_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/18_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/18_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/19_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/19_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/images/20_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/images/20_test.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/01_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/01_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/02_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/02_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/03_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/03_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/04_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/04_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/05_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/05_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/06_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/06_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/07_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/07_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/08_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/08_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/09_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/09_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/10_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/10_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/11_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/11_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/12_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/12_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/13_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/13_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/14_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/14_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/15_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/15_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/16_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/16_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/17_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/17_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/18_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/18_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/19_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/19_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/test/mask/20_test_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/test/mask/20_test_mask.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/21_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/21_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/22_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/22_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/23_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/23_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/24_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/24_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/25_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/25_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/26_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/26_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/27_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/27_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/28_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/28_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/29_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/29_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/30_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/30_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/31_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/31_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/32_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/32_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/33_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/33_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/34_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/34_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/35_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/35_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/36_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/36_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/37_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/37_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/38_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/38_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/39_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/39_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/1st_manual/40_manual1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/1st_manual/40_manual1.gif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/21_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/21_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/22_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/22_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/23_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/23_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/24_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/24_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/25_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/25_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/26_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/26_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/27_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/27_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/28_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/28_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/29_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/29_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/30_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/30_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/31_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/31_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/32_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/32_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/33_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/33_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/34_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/34_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/35_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/35_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/36_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/36_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/37_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/37_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/38_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/38_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/39_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/39_training.tif
--------------------------------------------------------------------------------
/dataset/DRIVE/training/images/40_training.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/dataset/DRIVE/training/images/40_training.tif
--------------------------------------------------------------------------------
/readme/exp_res1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/readme/exp_res1.jpg
--------------------------------------------------------------------------------
/src/_init_paths.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os.path as osp
import sys


def add_path(path):
    # Prepend `path` once so modules under it are importable.
    if path not in sys.path:
        sys.path.insert(0, path)


def __init_path():
    this_dir = osp.dirname(__file__)

    # Add lib to PYTHONPATH
    lib_path = osp.join(this_dir, 'lib')
    add_path(lib_path)


# Run at import time so that `import _init_paths` alone sets up the path.
__init_path()

if __name__ == '__main__':
    print("system path:")
    print(sys.path)
--------------------------------------------------------------------------------
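`src/main.py` is not reproduced in this snapshot. The usual way to use a path shim like this (an assumed pattern, with module paths inferred from the tree above) is to import it before anything under src/lib:

# Hypothetical first lines of src/main.py -- importing the shim makes
# everything under src/lib importable as top-level packages.
import _init_paths  # noqa: F401  (side effect: prepends src/lib to sys.path)

from datasets.dataset_factory import get_dataset  # module paths inferred
from trains.train_factory import train_factory    # from the tree above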
/src/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/__init__.py
--------------------------------------------------------------------------------
/src/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/datasets/__init__.py
--------------------------------------------------------------------------------
/src/lib/datasets/dataset/Cityscape.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import torch.utils.data as data


class Cityscape(data.Dataset):
    num_classes = 1
    default_resolution = [448, 448]
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        self.images = []
        self.labels = []
        self.opt = opt

        # Set_A.txt / Set_B.txt list the image names for each split.
        if split == 'train':
            read_files = os.path.join(opt.data_dir, 'Set_A.txt')
        else:
            read_files = os.path.join(opt.data_dir, 'Set_B.txt')

        self.image_root_folder = os.path.join(opt.data_dir, 'crop_image')
        self.gt_root_folder = os.path.join(opt.data_dir, 'crop_mask')

        self._read_img_mask(self.image_root_folder, self.gt_root_folder, read_files)

    def _read_img_mask(self, image_folder, mask_folder, read_files):
        # Pair each listed name with its .jpg image and .png mask.
        with open(read_files) as f:
            for img_name in f:
                stem = img_name.split('.')[0]
                self.images.append(os.path.join(image_folder, stem + '.jpg'))
                self.labels.append(os.path.join(mask_folder, stem + '.png'))

    def __len__(self):
        assert len(self.images) == len(self.labels), \
            'The number of images must equal the number of labels'
        return len(self.images)
--------------------------------------------------------------------------------
/src/lib/datasets/dataset/HumanSeg.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import torch.utils.data as data


class HumanSeg(data.Dataset):
    num_classes = 1
    default_resolution = [448, 448]
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        self.images = []
        self.labels = []
        self.opt = opt

        if split == 'train':
            # NOTE: the training split reads from the 'test' folder in this
            # snapshot; adjust if your copy of the data uses a 'train' folder.
            self.image_root_folder = os.path.join(opt.data_dir, 'test', 'imgs')
            self.gt_root_folder = os.path.join(opt.data_dir, 'test', 'masks')
        else:
            self.image_root_folder = os.path.join(opt.data_dir, 'val', 'imgs')
            self.gt_root_folder = os.path.join(opt.data_dir, 'val', 'masks')

        self._read_img_mask(self.image_root_folder, self.gt_root_folder)

    def _read_img_mask(self, image_folder, mask_folder):
        # Pair every image in the folder with the mask of the same name.
        for img_name in os.listdir(image_folder):
            stem = img_name.split('.')[0]
            self.images.append(os.path.join(image_folder, stem + '.png'))
            self.labels.append(os.path.join(mask_folder, stem + '.png'))

    def __len__(self):
        assert len(self.images) == len(self.labels), \
            'The number of images must equal the number of labels'
        return len(self.images)
--------------------------------------------------------------------------------
/src/lib/datasets/dataset/ORIGA_OD.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 |
8 | import numpy as np
9 | import os
10 |
11 | import torch.utils.data as data
12 |
13 |
14 | class ORIGA_OD(data.Dataset):
15 | num_classes = 1
16 | default_resolution = [448, 448]
17 | mean = np.array([0.40789654, 0.44719302, 0.47026115],
18 | dtype=np.float32).reshape(1, 1, 3)
19 | std = np.array([0.28863828, 0.27408164, 0.27809835],
20 | dtype=np.float32).reshape(1, 1, 3)
21 | def __init__(self, opt, split):
22 | self.images = []
23 | self.labels = []
24 | self.opt = opt
25 |
26 | if split == 'train':
27 | read_files = os.path.join(opt.data_dir, 'Set_A.txt')
28 | else:
29 | read_files = os.path.join(opt.data_dir, 'Set_B.txt')
30 |
31 | self.image_root_folder = os.path.join(opt.data_dir, 'crop_image')
32 | self.gt_root_folder = os.path.join(opt.data_dir, 'crop_mask')
33 |
34 | self._read_img_mask(self.image_root_folder, self.gt_root_folder, read_files)
35 |
36 | def _read_img_mask(self, image_folder, mask_folder, read_files):
37 | for img_name in open(read_files):
38 |             image_path = os.path.join(image_folder, img_name.strip().split('.')[0] + '.jpg')
39 |             label_path = os.path.join(mask_folder, img_name.strip().split('.')[0] + '.png')
40 |
41 | self.images.append(image_path)
42 | self.labels.append(label_path)
43 |
44 | def __len__(self):
45 | assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
46 | return len(self.images)
--------------------------------------------------------------------------------
/src/lib/datasets/dataset/VOC.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 |
8 | import numpy as np
9 | import os
10 |
11 | import torch.utils.data as data
12 |
13 | class VOC(data.Dataset):
14 | num_classes = 1
15 | default_resolution = [448, 448]
16 | mean = np.array([0.40789654, 0.44719302, 0.47026115],
17 | dtype=np.float32).reshape(1, 1, 3)
18 | std = np.array([0.28863828, 0.27408164, 0.27809835],
19 | dtype=np.float32).reshape(1, 1, 3)
20 |
21 | def __init__(self, opt, split):
22 | self.images = []
23 | self.labels = []
24 | self.opt = opt
25 |
26 | if split == 'train':
27 | read_files = os.path.join(opt.data_dir, 'Set_A.txt')
28 | else:
29 | read_files = os.path.join(opt.data_dir, 'Set_B.txt')
30 |
31 | self.image_root_folder = os.path.join(opt.data_dir, 'crop_image')
32 | self.gt_root_folder = os.path.join(opt.data_dir, 'crop_mask')
33 |
34 | self._read_img_mask(self.image_root_folder, self.gt_root_folder, read_files)
35 |
36 | def _read_img_mask(self, image_folder, mask_folder, read_files):
37 | for img_name in open(read_files):
38 |             image_path = os.path.join(image_folder, img_name.strip().split('.')[0] + '.jpg')
39 |             label_path = os.path.join(mask_folder, img_name.strip().split('.')[0] + '.png')
40 |
41 | self.images.append(image_path)
42 | self.labels.append(label_path)
43 |
44 | def __len__(self):
45 | assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
46 | return len(self.images)
47 |
--------------------------------------------------------------------------------
/src/lib/datasets/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/datasets/dataset/__init__.py
--------------------------------------------------------------------------------
/src/lib/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from .dataset.ORIGA_OD import ORIGA_OD
7 | from .dataset.VOC import VOC
8 | from .dataset.Cityscape import Cityscape
9 | from .dataset.HumanSeg import HumanSeg
10 |
11 | from .sample.binarySeg import BinarySegDataset
12 | # from .sample.multiSeg import MultiSegDataset
13 |
14 | dataset_factory = {
15 | 'ORIGA_OD': ORIGA_OD,
16 | 'VOC': VOC,
17 | 'CityScape': Cityscape,
18 | 'humanseg': HumanSeg,
19 | }
20 |
21 | _sample_factory = {
22 | 'binSeg': BinarySegDataset,
23 | # 'multiSeg': MultiSegDataset
24 | }
25 |
26 |
27 | def get_dataset(dataset, task):
28 | class Dataset(dataset_factory[dataset], _sample_factory[task]):
29 | pass
30 |
31 | return Dataset
32 |
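Usage sketch (hedged: `opt` is whatever options object the trainer builds; it must
carry the attributes referenced above, e.g. data_dir, height, width and the
augmentation flags used by binarySeg.py):

    Dataset = get_dataset('ORIGA_OD', 'binSeg')
    # Dataset inherits the path lists from ORIGA_OD and __getitem__/augmentation
    # from BinarySegDataset, composed via multiple inheritance.
    train_set = Dataset(opt, split='train')
    sample = train_set[0]    # {'input': FloatTensor, 'gt': FloatTensor}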
--------------------------------------------------------------------------------
/src/lib/datasets/sample/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/datasets/sample/__init__.py
--------------------------------------------------------------------------------
/src/lib/datasets/sample/binarySeg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.utils.data as data
9 | from torch.autograd import Variable as V
10 |
11 | import cv2
12 | import numpy as np
13 | from PIL import Image
14 |
15 | from utils.image import randomHueSaturationValue
16 | from utils.image import randomShiftScaleRotate
17 | from utils.image import randomHorizontalFlip
18 | from utils.image import randomVerticleFlip, randomRotate90
19 |
20 |
21 | class BinarySegDataset(data.Dataset):
22 |
23 | def __getitem__(self, index):
24 | img = cv2.imread(self.images[index])
25 |         img = cv2.resize(img, (self.opt.width, self.opt.height))  # cv2.resize expects (width, height)
26 |
27 | mask = np.array(Image.open(self.labels[index]))
28 |         mask = cv2.resize(mask, (self.opt.width, self.opt.height))
29 |
30 | # Data augmentation
31 | if self.opt.color_aug:
32 | img = randomHueSaturationValue(img,
33 | hue_shift_limit=(-30, 30),
34 | sat_shift_limit=(-5, 5),
35 | val_shift_limit=(-15, 15))
36 |
37 | if self.opt.shift_scale:
38 | img, mask = randomShiftScaleRotate(img, mask,
39 | shift_limit=(-0.1, 0.1),
40 | scale_limit=(-0.1, 0.1),
41 | aspect_limit=(-0.1, 0.1),
42 | rotate_limit=(-0, 0))
43 |
44 | if self.opt.HorizontalFlip:
45 | img, mask = randomHorizontalFlip(img, mask)
46 |
47 | if self.opt.VerticleFlip:
48 | img, mask = randomVerticleFlip(img, mask)
49 |
50 | if self.opt.rotate_90:
51 | img, mask = randomRotate90(img, mask)
52 |
53 | mask = np.expand_dims(mask, axis=2)
54 | img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
55 | mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
56 | mask[mask >= 0.5] = 1
57 |         mask[mask < 0.5] = 0
58 |
59 | img = torch.Tensor(img)
60 | mask = torch.Tensor(mask)
61 |
62 | ret = {'input': img, 'gt': mask}
63 | return ret
64 |
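The normalization img / 255.0 * 3.2 - 1.6 maps pixels into [-1.6, 1.6] (0 -> -1.6,
127.5 -> 0.0, 255 -> 1.6), and masks are binarized to {0, 1}. A minimal loading
sketch (batch size chosen arbitrarily):

    from torch.utils.data import DataLoader

    loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch['input'].shape)    # torch.Size([8, 3, H, W])
    print(batch['gt'].shape)       # torch.Size([8, 1, H, W])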
--------------------------------------------------------------------------------
/src/lib/datasets/sample/multiSeg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.utils.data as data
9 | from torch.autograd import Variable as V
10 |
11 | import cv2
12 | import numpy as np
13 | from PIL import Image
14 |
15 | from utils.image import randomHueSaturationValue
16 | from utils.image import randomShiftScaleRotate
17 | from utils.image import randomHorizontalFlip
18 | from utils.image import randomVerticleFlip, randomRotate90
19 |
20 |
21 | class MultiSegDataset(data.Dataset):
22 |
23 | def __getitem__(self, index):
24 | img = cv2.imread(self.images[index])
25 |         img = cv2.resize(img, (self.opt.width, self.opt.height))  # cv2.resize expects (width, height)
26 |
27 | mask = np.array(Image.open(self.labels[index]))
28 |         mask = cv2.resize(mask, (self.opt.width, self.opt.height))
29 |
30 | # Data augmentation
31 | img = randomHueSaturationValue(img,
32 | hue_shift_limit=(-30, 30),
33 | sat_shift_limit=(-5, 5),
34 | val_shift_limit=(-15, 15))
35 |
36 | img, mask = randomShiftScaleRotate(img, mask,
37 | shift_limit=(-0.1, 0.1),
38 | scale_limit=(-0.1, 0.1),
39 | aspect_limit=(-0.1, 0.1),
40 | rotate_limit=(-0, 0))
41 | img, mask = randomHorizontalFlip(img, mask)
42 | img, mask = randomVerticleFlip(img, mask)
43 | img, mask = randomRotate90(img, mask)
44 |
45 | mask = np.expand_dims(mask, axis=2)
46 | img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
47 | mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
48 | mask[mask >= 0.5] = 1
49 |         mask[mask < 0.5] = 0
50 |
51 | img = torch.Tensor(img)
52 | mask = torch.Tensor(mask)
53 |
54 | ret = {'input': img, 'gt': mask}
55 | return ret
56 |
--------------------------------------------------------------------------------
/src/lib/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 |
7 | import os
8 | import time
9 | import sys
10 | import torch
11 |
12 | USE_TENSORBOARD = True
13 | try:
14 | import tensorboardX
15 | print("Using tensorboardX")
16 | except ImportError:
17 | USE_TENSORBOARD = False
18 |
19 | class Logger(object):
20 | def __init__(self, opt):
21 | """
22 | Create a summary writer logging to log_dir.
23 | """
24 | if not os.path.exists(opt.save_dir):
25 | os.makedirs(opt.save_dir)
26 | if not os.path.exists(opt.debug_dir):
27 | os.makedirs(opt.debug_dir)
28 |
29 | time_str = time.strftime('%Y-%m-%d-%H-%M')
30 |
31 | args = dict((name, getattr(opt, name)) for name in dir(opt)
32 | if not name.startswith('_'))
33 |
34 | file_name = os.path.join(opt.save_dir, 'opt.txt')
35 |
36 | with open(file_name, 'wt') as opt_file:
37 | opt_file.write('==> torch version: {}\n'.format(torch.__version__))
38 | opt_file.write('==> cudnn version: {}\n'.format(
39 | torch.backends.cudnn.version()))
40 | opt_file.write('==> Cmd:\n')
41 | opt_file.write(str(sys.argv))
42 | opt_file.write('\n==> Opt:\n')
43 | for k, v in sorted(args.items()):
44 | opt_file.write(' %s: %s\n' % (str(k), str(v)))
45 |
46 | log_dir = opt.save_dir + '/logs_{}'.format(time_str)
47 | if USE_TENSORBOARD:
48 | self.writer = tensorboardX.SummaryWriter(log_dir=log_dir)
49 | else:
50 | if not os.path.exists(os.path.dirname(log_dir)):
51 | os.mkdir(os.path.dirname(log_dir))
52 | if not os.path.exists(log_dir):
53 | os.mkdir(log_dir)
54 | self.log = open(log_dir + '/log.txt', 'w')
55 | try:
56 | os.system('cp {}/opt.txt {}/'.format(opt.save_dir, log_dir))
57 |         except Exception:
58 | pass
59 | self.start_line = True
60 |
61 | def write(self, txt):
62 | if self.start_line:
63 | time_str = time.strftime('%Y-%m-%d-%H-%M')
64 | self.log.write('{}: {}'.format(time_str, txt))
65 | else:
66 | self.log.write(txt)
67 | self.start_line = False
68 | if '\n' in txt:
69 | self.start_line = True
70 | self.log.flush()
71 |
72 | def close(self):
73 | self.log.close()
74 |
75 | def scalar_summary(self, tag, value, step):
76 | """Log a scalar variable."""
77 | if USE_TENSORBOARD:
78 | self.writer.add_scalar(tag, value, step)
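
Usage sketch (assumes `opt` provides save_dir and debug_dir, as checked in __init__):

    logger = Logger(opt)
    logger.write('epoch 1 | loss 0.42\n')          # timestamped plain-text log
    logger.scalar_summary('train/loss', 0.42, 1)   # no-op unless tensorboardX imported
    logger.close()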
--------------------------------------------------------------------------------
/src/lib/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/data_parallel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules import Module
3 | from torch.nn.parallel.scatter_gather import gather
4 | from torch.nn.parallel.replicate import replicate
5 | from torch.nn.parallel.parallel_apply import parallel_apply
6 |
7 |
8 | from .scatter_gather import scatter_kwargs
9 |
10 | class _DataParallel(Module):
11 | r"""Implements data parallelism at the module level.
12 |
13 | This container parallelizes the application of the given module by
14 | splitting the input across the specified devices by chunking in the batch
15 | dimension. In the forward pass, the module is replicated on each device,
16 | and each replica handles a portion of the input. During the backwards
17 | pass, gradients from each replica are summed into the original module.
18 |
19 | The batch size should be larger than the number of GPUs used. It should
20 | also be an integer multiple of the number of GPUs so that each chunk is the
21 | same size (so that each GPU processes the same number of samples).
22 |
23 | See also: :ref:`cuda-nn-dataparallel-instead`
24 |
25 | Arbitrary positional and keyword inputs are allowed to be passed into
26 | DataParallel EXCEPT Tensors. All variables will be scattered on dim
27 | specified (default 0). Primitive types will be broadcasted, but all
28 | other types will be a shallow copy and can be corrupted if written to in
29 | the model's forward pass.
30 |
31 | Args:
32 | module: module to be parallelized
33 | device_ids: CUDA devices (default: all devices)
34 | output_device: device location of output (default: device_ids[0])
35 |
36 | Example::
37 |
38 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
39 | >>> output = net(input_var)
40 | """
41 |
42 |
43 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
44 | super(_DataParallel, self).__init__()
45 |
46 | if not torch.cuda.is_available():
47 | self.module = module
48 | self.device_ids = []
49 | return
50 |
51 | if device_ids is None:
52 | device_ids = list(range(torch.cuda.device_count()))
53 | if output_device is None:
54 | output_device = device_ids[0]
55 | self.dim = dim
56 | self.module = module
57 | self.device_ids = device_ids
58 | self.chunk_sizes = chunk_sizes
59 | self.output_device = output_device
60 | if len(self.device_ids) == 1:
61 | self.module.cuda(device_ids[0])
62 |
63 | def forward(self, *inputs, **kwargs):
64 | if not self.device_ids:
65 | return self.module(*inputs, **kwargs)
66 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes)
67 | if len(self.device_ids) == 1:
68 | return self.module(*inputs[0], **kwargs[0])
69 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
70 | outputs = self.parallel_apply(replicas, inputs, kwargs)
71 | return self.gather(outputs, self.output_device)
72 |
73 | def replicate(self, module, device_ids):
74 | return replicate(module, device_ids)
75 |
76 | def scatter(self, inputs, kwargs, device_ids, chunk_sizes):
77 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes)
78 |
79 | def parallel_apply(self, replicas, inputs, kwargs):
80 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
81 |
82 | def gather(self, outputs, output_device):
83 | return gather(outputs, output_device, dim=self.dim)
84 |
85 |
86 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
87 | r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
88 |
89 | This is the functional version of the DataParallel module.
90 |
91 | Args:
92 | module: the module to evaluate in parallel
93 | inputs: inputs to the module
94 | device_ids: GPU ids on which to replicate module
95 | output_device: GPU location of the output Use -1 to indicate the CPU.
96 | (default: device_ids[0])
97 | Returns:
98 | a Variable containing the result of module(input) located on
99 | output_device
100 | """
101 | if not isinstance(inputs, tuple):
102 | inputs = (inputs,)
103 |
104 | if device_ids is None:
105 | device_ids = list(range(torch.cuda.device_count()))
106 |
107 | if output_device is None:
108 | output_device = device_ids[0]
109 |
110 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
111 | if len(device_ids) == 1:
112 | return module(*inputs[0], **module_kwargs[0])
113 | used_device_ids = device_ids[:len(inputs)]
114 | replicas = replicate(module, used_device_ids)
115 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
116 | return gather(outputs, output_device, dim)
117 |
118 | def DataParallel(module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
119 | if chunk_sizes is None:
120 | return torch.nn.DataParallel(module, device_ids, output_device, dim)
121 | standard_size = True
122 | for i in range(1, len(chunk_sizes)):
123 | if chunk_sizes[i] != chunk_sizes[0]:
124 | standard_size = False
125 | if standard_size:
126 | return torch.nn.DataParallel(module, device_ids, output_device, dim)
127 | return _DataParallel(module, device_ids, output_device, dim, chunk_sizes)
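
The chunk_sizes extension is the point of this module: torch.nn.DataParallel always
splits the batch evenly, while _DataParallel lets the caller give each GPU a
different slice. Hypothetical sketch:

    # a 16-sample batch split unevenly across two GPUs
    # (useful when GPU 0 also holds the optimizer state)
    net = DataParallel(model, device_ids=[0, 1], chunk_sizes=[6, 10])
    out = net(batch)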
--------------------------------------------------------------------------------
/src/lib/models/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
11 | from utils.image import mask_to_boundary
12 |
13 | import cv2
14 | import numpy as np
15 |
16 |
17 | class weighted_cross_entropy(nn.Module):
18 | def __init__(self, num_classes=12, batch=True):
19 | super(weighted_cross_entropy, self).__init__()
20 | self.batch = batch
21 | self.weight = torch.Tensor([52.] * num_classes).cuda()
22 | self.ce_loss = nn.CrossEntropyLoss(weight=self.weight)
23 |
24 | def __call__(self, y_true, y_pred):
25 | y_ce_true = y_true.squeeze(dim=1).long()
26 | a = self.ce_loss(y_pred, y_ce_true)
27 |
28 | return a
29 |
30 |
31 | class dice_loss(nn.Module):
32 | def __init__(self, batch=True):
33 | super(dice_loss, self).__init__()
34 | self.batch = batch
35 |
36 | def soft_dice_coeff(self, y_true, y_pred):
37 | smooth = 0.0 # may change
38 | if self.batch:
39 | i = torch.sum(y_true)
40 | j = torch.sum(y_pred)
41 | intersection = torch.sum(y_true * y_pred)
42 | else:
43 | i = y_true.sum(1).sum(1).sum(1)
44 | j = y_pred.sum(1).sum(1).sum(1)
45 | intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
46 | score = (2. * intersection + smooth) / (i + j + smooth)
47 | # score = (intersection + smooth) / (i + j - intersection + smooth)#iou
48 | return score.mean()
49 |
50 | def soft_dice_loss(self, y_true, y_pred):
51 | loss = 1 - self.soft_dice_coeff(y_true, y_pred)
52 | return loss
53 |
54 | def __call__(self, y_true, y_pred):
55 |
56 | b = self.soft_dice_loss(y_true, y_pred)
57 | return b
58 |
59 |
60 | def test_weight_cross_entropy():  # NOTE: needs CUDA; weighted_cross_entropy puts its weight on the GPU
61 | N = 4
62 | C = 12
63 | H, W = 128, 128
64 |
65 | inputs = torch.rand(N, C, H, W)
66 | targets = torch.LongTensor(N, H, W).random_(C)
67 | inputs_fl = Variable(inputs.clone(), requires_grad=True)
68 | targets_fl = Variable(targets.clone())
69 | print(weighted_cross_entropy()(targets_fl, inputs_fl))
70 |
71 |
72 | class boundary_dice_bce_loss(nn.Module):
73 | def __init__(self, batch=True):
74 | super(boundary_dice_bce_loss, self).__init__()
75 | self.batch = batch
76 | self.bce_loss = nn.BCELoss()
77 |
78 | def soft_dice_coeff(self, y_true, y_pred):
79 | smooth = 0.0 # may change
80 | if self.batch:
81 | i = torch.sum(y_true)
82 | j = torch.sum(y_pred)
83 | intersection = torch.sum(y_true * y_pred)
84 | else:
85 | i = y_true.sum(1).sum(1).sum(1)
86 | j = y_pred.sum(1).sum(1).sum(1)
87 | intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
88 | score = (2. * intersection + smooth) / (i + j + smooth)
89 | # score = (intersection + smooth) / (i + j - intersection + smooth)#iou
90 | return score.mean()
91 |
92 | def soft_dice_loss(self, y_true, y_pred):
93 | loss = 1 - self.soft_dice_coeff(y_true, y_pred)
94 | return loss
95 |
96 | def batch_mask_to_boundary(self, y_true, y_pred):
97 | batch, channel, height, width = y_true.size()
98 | y_true_boundary_batch = np.zeros(shape=(batch, height, width))
99 | y_pred_boundary_batch = np.zeros(shape=(batch, height, width))
100 |
101 | for i in range(batch):
102 | y_true_mask = y_true.cpu().detach().numpy()[i, 0, :, :]
103 | y_pred_mask = y_pred.cpu().detach().numpy()[i, 0, :, :]
104 | # shape (448, 448)(448, 448)
105 | y_true_mask_boundary = mask_to_boundary(y_true_mask)
106 | y_pred_mask_boundary = mask_to_boundary(y_pred_mask)
107 |
108 | y_true_boundary_batch[i, :, :] = y_true_mask_boundary
109 | y_pred_boundary_batch[i, :, :] = y_pred_mask_boundary
110 |
111 | y_true_boundary_mask = torch.Tensor(y_true_boundary_batch).cuda()
112 | y_pred_boundary_mask = torch.Tensor(y_pred_boundary_batch).cuda()
113 | # torch.Size([12, 448, 448])
114 | return y_true_boundary_mask, y_pred_boundary_mask
115 |
116 | def __call__(self, y_true, y_pred):
117 | # a = self.bce_loss(y_pred, y_true)
118 | b = self.soft_dice_loss(y_true, y_pred)
119 |
120 | y_true_boundary, y_pred_boundary = self.batch_mask_to_boundary(y_true, y_pred)
121 | b_aux = self.soft_dice_loss(y_true_boundary, y_pred_boundary)
122 | return b + b_aux
123 |
124 |
125 | class dice_bce_loss(nn.Module):
126 | def __init__(self, batch=True):
127 | super(dice_bce_loss, self).__init__()
128 | self.batch = batch
129 | self.bce_loss = nn.BCELoss()
130 |
131 | def soft_dice_coeff(self, y_true, y_pred):
132 | smooth = 0.0 # may change
133 | if self.batch:
134 | i = torch.sum(y_true)
135 | j = torch.sum(y_pred)
136 | intersection = torch.sum(y_true * y_pred)
137 | else:
138 | i = y_true.sum(1).sum(1).sum(1)
139 | j = y_pred.sum(1).sum(1).sum(1)
140 | intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
141 | score = (2. * intersection + smooth) / (i + j + smooth)
142 | # score = (intersection + smooth) / (i + j - intersection + smooth)#iou
143 | return score.mean()
144 |
145 | def soft_dice_loss(self, y_true, y_pred):
146 | loss = 1 - self.soft_dice_coeff(y_true, y_pred)
147 | return loss
148 |
149 | def __call__(self, y_true, y_pred):
150 | a = self.bce_loss(y_pred, y_true)
151 | b = self.soft_dice_loss(y_true, y_pred)
152 |
153 |         return a + b
154 |
155 |
158 |
159 |
160 | class DiceLoss(nn.Module):
161 | def __init__(self):
162 | super(DiceLoss, self).__init__()
163 |
164 | def forward(self, input, target):
165 | N, H, W = target.size(0), target.size(2), target.size(3)
166 | smooth = 1
167 |
168 | input_flat = input.view(N, -1)
169 | target_flat = target.view(N, -1)
170 |
171 | intersection = input_flat * target_flat
172 |
173 | loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
174 | loss = 1 - loss.sum() / N
175 |
176 | return loss
177 |
178 |
179 | class MulticlassDiceLoss(nn.Module):
180 | """
181 | requires one hot encoded target. Applies DiceLoss on each class iteratively.
182 | requires input.shape[0:1] and target.shape[0:1] to be (N, C) where N is
183 | batch size and C is number of classes
184 | """
185 |
186 | def __init__(self):
187 | super(MulticlassDiceLoss, self).__init__()
188 |
189 | def forward(self, input, target, weights=None):
190 |
191 | C = target.shape[1]
192 |
193 | # if weights is None:
194 | # weights = torch.ones(C) #uniform weights for all classes
195 |
196 | dice = DiceLoss()
197 | totalLoss = 0
198 |
199 | for i in range(C):
200 | diceLoss = dice(input[:, i, :, :], target[:, i, :, :])
201 | if weights is not None:
202 | diceLoss *= weights[i]
203 | totalLoss += diceLoss
204 |
205 | return totalLoss
206 |
207 |
208 | class FocalLoss(nn.Module):
209 | def __init__(self, gamma=0, alpha=None, size_average=True):
210 | super(FocalLoss, self).__init__()
211 | self.gamma = gamma
212 | self.alpha = alpha
213 | if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
214 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
215 | self.size_average = size_average
216 |
217 | def forward(self, target, input):
218 | target1 = torch.squeeze(target, dim=1)
219 | if input.dim() > 2:
220 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W
221 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
222 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
223 | target2 = target1.view(-1, 1).long()
224 |
225 | logpt = F.log_softmax(input, dim=1)
226 | # print(logpt.size())
227 | # print(target2.size())
228 | logpt = logpt.gather(1, target2)
229 | logpt = logpt.view(-1)
230 | pt = Variable(logpt.data.exp())
231 |
232 | if self.alpha is not None:
233 | if self.alpha.type() != input.data.type():
234 | self.alpha = self.alpha.type_as(input.data)
235 | at = self.alpha.gather(0, target.data.view(-1))
236 | logpt = logpt * Variable(at)
237 |
238 | loss = -1 * (1 - pt) ** self.gamma * logpt
239 | if self.size_average:
240 | return loss.mean()
241 | else:
242 | return loss.sum()
243 |
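Note the argument order used throughout this file: losses are invoked as
criterion(y_true, y_pred), the reverse of PyTorch's usual criterion(input, target)
convention. A minimal smoke test for dice_bce_loss (shapes follow the 448x448
default resolution used above):

    criterion = dice_bce_loss()
    y_pred = torch.sigmoid(torch.randn(2, 1, 448, 448, requires_grad=True))
    y_true = (torch.rand(2, 1, 448, 448) > 0.5).float()
    loss = criterion(y_true, y_pred)
    loss.backward()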
--------------------------------------------------------------------------------
/src/lib/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 | from .networks.cenet import CE_Net_
6 |
7 | _model_factory = {
8 | 'cenet': CE_Net_
9 | }
10 |
11 |
12 | def create_model(model_name):
13 | get_model = _model_factory[model_name]
14 | model = get_model()
15 | return model
16 |
17 |
18 | def load_model(model, model_path, optimizer=None, resume=False,
19 | lr=None, lr_step=None):
20 | start_epoch = 0
21 | checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
22 | print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
23 | state_dict_ = checkpoint['state_dict']
24 | state_dict = {}
25 |
26 |     # convert a DataParallel checkpoint ('module.*' keys) to a plain model state dict
27 | for k in state_dict_:
28 | if k.startswith('module') and not k.startswith('module_list'):
29 | state_dict[k[7:]] = state_dict_[k]
30 | else:
31 | state_dict[k] = state_dict_[k]
32 | model_state_dict = model.state_dict()
33 |
34 | # check loaded parameters and created model parameters
35 | msg = 'If you see this, your model does not fully load the ' + \
36 | 'pre-trained weight. Please make sure ' + \
37 | 'you have correctly specified --arch xxx ' + \
38 | 'or set the correct --num_classes for your own dataset.'
39 | for k in state_dict:
40 | if k in model_state_dict:
41 | if state_dict[k].shape != model_state_dict[k].shape:
42 |                 print('Skip loading parameter {}, required shape {}, ' \
43 |                       'loaded shape {}. {}'.format(
44 | k, model_state_dict[k].shape, state_dict[k].shape, msg))
45 | state_dict[k] = model_state_dict[k]
46 | else:
47 |             print('Drop parameter {}. '.format(k) + msg)
48 | for k in model_state_dict:
49 | if not (k in state_dict):
50 |             print('No param {}. '.format(k) + msg)
51 | state_dict[k] = model_state_dict[k]
52 | model.load_state_dict(state_dict, strict=False)
53 |
54 | # resume optimizer parameters
55 | if optimizer is not None and resume:
56 | if 'optimizer' in checkpoint:
57 | optimizer.load_state_dict(checkpoint['optimizer'])
58 | start_epoch = checkpoint['epoch']
59 | start_lr = lr
60 | for step in lr_step:
61 | if start_epoch >= step:
62 | start_lr *= 0.1
63 | for param_group in optimizer.param_groups:
64 | param_group['lr'] = start_lr
65 | print('Resumed optimizer with start lr', start_lr)
66 | else:
67 | print('No optimizer parameters in checkpoint.')
68 | if optimizer is not None:
69 | return model, optimizer, start_epoch
70 | else:
71 | return model
72 |
73 |
74 | def save_model(path, epoch, model, optimizer=None):
75 | if isinstance(model, torch.nn.DataParallel):
76 | state_dict = model.module.state_dict()
77 | else:
78 | state_dict = model.state_dict()
79 | data = {'epoch': epoch,
80 | 'state_dict': state_dict}
81 | if not (optimizer is None):
82 | data['optimizer'] = optimizer.state_dict()
83 | torch.save(data, path)
84 |
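Typical round trip (paths, lr and epoch numbers are illustrative):

    model = create_model('cenet')
    save_model('model_last.pth', epoch=10, model=model, optimizer=optimizer)
    model, optimizer, start_epoch = load_model(
        model, 'model_last.pth', optimizer, resume=True, lr=1e-3, lr_step=[40, 80])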
--------------------------------------------------------------------------------
/src/lib/models/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/backbones/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/backbone_factory.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | from .mobilenet.mobilenet_factory import get_mobilenet_backbone
7 | from .resnet.resnet_factory import get_resnet_backbone
8 |
9 |
10 | def get_backbone_architecture(backbone_name):
11 | if "resnet" in backbone_name:
12 | return get_resnet_backbone(backbone_name)
13 |
14 | elif "mobilenet" in backbone_name:
15 | return get_mobilenet_backbone(backbone_name)
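
Usage sketch. The factories return the constructor, not an instance, and the two
families spell the pretrained flag differently (the resnet getters take `pretrain`,
the mobilenet getters take `pretrained`):

    backbone_fn = get_backbone_architecture('resnet34')
    backbone = backbone_fn(pretrain=True)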
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/mobilenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/backbones/mobilenet/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/mobilenet/basic_module.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from torch import nn
6 |
7 | from collections import OrderedDict
8 |
9 | def _make_divisible(v, divisor, min_value=None):
10 | """
11 |     This function is taken from the original PyTorch repo.
12 |     It ensures that all layers have a channel number that is divisible by `divisor`
13 |     (8 throughout this codebase); see:
14 |     https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv2.html
15 |     :param v: the raw channel count (possibly scaled by width_mult)
16 |     :param divisor: the divisibility requirement
17 |     :param min_value: lower bound for the result (defaults to divisor)
18 |     :return: the adjusted channel count
19 |     """
20 |
21 | if min_value is None:
22 | min_value = divisor
23 |
24 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
25 |
26 | if new_v < 0.9 * v:
27 | new_v += divisor
28 |
29 | return new_v
30 |
31 |
32 | class ConvBNReLU(nn.Sequential):
33 |     def __init__(self, in_planes, out_planes, kernel_size=3,
34 |                  stride=1, group=1):
35 |         padding = (kernel_size - 1) // 2
36 |         super(ConvBNReLU, self).__init__(
37 |             nn.Conv2d(in_planes, out_planes, kernel_size,
38 |                       stride, padding, groups=group, bias=False),
39 |             nn.BatchNorm2d(out_planes),
40 |             nn.ReLU6(inplace=True)
41 |         )
42 |
43 |
44 | class InvertedResidual(nn.Module):
45 | def __init__(self, inp, oup, stride, expand_ratio):
46 | super(InvertedResidual, self).__init__()
47 | self.stride = stride
48 | assert stride in [1, 2]
49 |
50 | hidden_dim = int(round(inp * expand_ratio))
51 |
52 | self.use_res_connect = self.stride == 1 and inp == oup
53 |
54 | layers = []
55 |
56 | if expand_ratio != 1:
57 |             # pointwise (1x1 expansion)
58 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
59 |
60 | layers.extend([
61 | # depthwise
62 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride,
63 | group=hidden_dim),
64 | # pw-linear
65 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
66 | nn.BatchNorm2d(oup)
67 | ])
68 |
69 | self.conv = nn.Sequential(*layers)
70 |
71 | def forward(self, x):
72 | if self.use_res_connect:
73 | return x + self.conv(x)
74 |
75 | else:
76 | return self.conv(x)
77 |
78 |
79 | def load_model(model, state_dict):
80 | new_model = model.state_dict()
81 | new_keys = list(new_model.keys())
82 | old_keys = list(state_dict.keys())
83 |
84 | restore_dict = OrderedDict()
85 |
86 | for id in range(len(new_keys)):
87 | restore_dict[new_keys[id]] = state_dict[old_keys[id]]
88 |
89 | model.load_state_dict(restore_dict)
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/mobilenet/build_mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | from .basic_module import ConvBNReLU, InvertedResidual, _make_divisible, load_model
6 |
7 | model_urls = {
8 | 'mobilenetv2_10': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
9 | }
10 |
11 |
12 | class MobileNetV2(nn.Module):
13 | def __init__(self, width_mult=1.0, round_nearest=8):
14 | super(MobileNetV2, self).__init__()
15 | block = InvertedResidual
16 | input_channel = 32
17 |
18 | inverted_residual_setting = [
19 | # t, c, n, s
20 | [1, 16, 1, 1], # 0
21 | [6, 24, 2, 2], # 1
22 | [6, 32, 3, 2], # 2
23 | [6, 64, 4, 2], # 3
24 | [6, 96, 3, 1], # 4
25 | [6, 160, 3, 2], # 5
26 | [6, 320, 1, 1], # 6
27 | ]
28 |
29 | self.feat_id = [1, 2, 4, 6]
30 | self.feat_channel = []
31 |
32 |         # only check the first element, assuming the user knows that
33 |         # (t, c, n, s) entries are required
34 |
35 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
36 |             raise ValueError("inverted_residual_setting should be a non-empty "
37 |                              "list of 4-element lists, got {}".format(inverted_residual_setting))
38 |
39 | # building first layer
40 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
41 |
42 | features = [ConvBNReLU(3, input_channel, stride=2)]
43 |
44 | # building inverted residual blocks
45 |
46 | for id, (t, c, n, s) in enumerate(inverted_residual_setting):
47 | output_channel = _make_divisible(c * width_mult, round_nearest)
48 | for i in range(n):
49 | stride = s if i == 0 else 1
50 | features.append(block(input_channel, output_channel, stride, expand_ratio=t))
51 | input_channel = output_channel
52 | if id in self.feat_id:
53 | self.__setattr__("layer%d" % id, nn.Sequential(*features))
54 | self.feat_channel.append(output_channel)
55 | features = []
56 |
57 | # weight initialization
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
61 | if m.bias is not None:
62 | nn.init.zeros_(m.bias)
63 | elif isinstance(m, nn.BatchNorm2d):
64 | nn.init.ones_(m.weight)
65 | nn.init.zeros_(m.bias)
66 |
67 | def forward(self, x):
68 | y = []
69 | for id in self.feat_id:
70 | x = self.__getattr__("layer%d" % id)(x)
71 | y.append(x)
72 |
73 | return y
74 |
75 | def init_weights(self, model_name):
76 | url = model_urls[model_name]
77 | pretrained_state_dict = model_zoo.load_url(url)
78 | print('=> loading pretrained model {}'.format(url))
79 |         self.load_state_dict(pretrained_state_dict, strict=False)  # NOTE: torchvision keys ('features.*') don't match the 'layer%d' names here; strict=False silently skips them
80 |
81 |
82 | def get_mobilenetv2_10(pretrained=True, **kwargs):
83 | model = MobileNetV2(width_mult=1.0)
84 | if pretrained:
85 | model.init_weights(model_name='mobilenetv2_10')
86 |
87 | return model
88 |
89 |
90 | def get_mobilenetv2_5(pretrained=True, **kwargs):
91 | model = MobileNetV2(width_mult=0.5)
92 | if pretrained:
93 | print("MobilenetV2_5 does not have the pretrained weight")
94 |
95 | return model
96 |
97 | if __name__ == '__main__':
98 |
99 | model = get_mobilenetv2_10(pretrained=True)
100 |
101 | input = torch.zeros([1, 3, 512, 512])
102 | feats = model(input)
103 | print(feats[0].size())
104 | print(feats[1].size())
105 | print(feats[2].size())
106 | print(feats[3].size())
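
For a 512x512 input with width_mult=1.0 these correspond to strides 4, 8, 16 and 32,
so the prints above should show (assuming the standard MobileNetV2 layout):

    torch.Size([1, 24, 128, 128])
    torch.Size([1, 32, 64, 64])
    torch.Size([1, 96, 32, 32])
    torch.Size([1, 320, 16, 16])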
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/mobilenet/mobilenet_factory.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .build_mobilenet import get_mobilenetv2_10
6 | from .build_mobilenet import get_mobilenetv2_5
7 |
8 | _mobilenet_backbone = {
9 | 'mobilenetv2_10': get_mobilenetv2_10,
10 | 'mobilenetv2_5': get_mobilenetv2_5,
11 |
12 | }
13 |
14 |
15 | def get_mobilenet_backbone(model_name):
16 | support_mobile_models = ['mobilenetv2_10', 'mobilenetv2_5']
17 |     assert model_name in support_mobile_models, "Only the following models are supported: {}".format(support_mobile_models)
18 |
19 | model = _mobilenet_backbone[model_name]
20 |
21 | return model
22 |
23 |
24 | if __name__ == '__main__':
25 | str1 = 'mobilenetv2_10'
26 | model = get_mobilenet_backbone(str1)
27 |
28 |     print(model(pretrained=True))
29 |
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/resnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/backbones/resnet/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/resnet/basic_module.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # Modified by Dequan Wang and Xingyi Zhou
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import torch.nn as nn
13 |
14 | BN_MOMENTUM = 0.1
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | """
19 | 3x3 convolution with padding
20 | """
21 |
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None):
30 | super(BasicBlock, self).__init__()
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
33 | self.relu = nn.ReLU(inplace=True)
34 |
35 | self.conv2 = conv3x3(planes, planes)
36 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | residual = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.bn2(self.conv2(out))
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 |
58 | def __init__(self, inplanes, planes, stride=1, downsample=None):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
62 |
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64 | padding=1, bias=False)
65 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
66 |
67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
69 |
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out = self.relu(self.bn1(self.conv1(x)))
78 |
79 |         out = self.relu(self.bn2(self.conv2(out)))
80 |
81 | out = self.bn3(self.conv3(out))
82 |
83 | if self.downsample is not None:
84 | residual = self.downsample(x)
85 |
86 | out += residual
87 | out = self.relu(out)
88 |
89 | return out
90 |
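Channel bookkeeping sketch: `expansion` is what build_resnet.py consults when
deciding whether a 1x1 downsample projection is needed. For example:

    import torch
    block = Bottleneck(256, 64)               # out channels = 64 * Bottleneck.expansion = 256
    y = block(torch.randn(1, 256, 56, 56))
    print(y.shape)                            # torch.Size([1, 256, 56, 56])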
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/resnet/build_resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .basic_module import BasicBlock, Bottleneck, BN_MOMENTUM
6 |
7 | import torch.nn as nn
8 | import torch.utils.model_zoo as model_zoo
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | class build_resnet(nn.Module):
20 | def __init__(self, block, layers):
21 | super(build_resnet, self).__init__()
22 | self.inplanes = 64
23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
24 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
25 | self.relu = nn.ReLU(inplace=True)
26 |
27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
28 |
29 | self.layer1 = self._make_layers(block, 64, layers[0])
30 | self.layer2 = self._make_layers(block, 128, layers[1], stride=2)
31 | self.layer3 = self._make_layers(block, 256, layers[2], stride=2)
32 | self.layer4 = self._make_layers(block, 512, layers[3], stride=2)
33 |
34 | def _make_layers(self, block, planes, blocks, stride=1):
35 | downsample = None
36 |
37 | if stride != 1 or self.inplanes != planes * block.expansion:
38 | downsample = nn.Sequential(
39 | nn.Conv2d(self.inplanes, planes * block.expansion,
40 | kernel_size=1, stride=stride, bias=False),
41 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
42 | )
43 |
44 | layers = []
45 | layers.append(block(self.inplanes, planes, stride, downsample))
46 |
47 | self.inplanes = planes * block.expansion
48 |
49 | for i in range(1, blocks):
50 | layers.append(block(self.inplanes, planes))
51 |
52 | return nn.Sequential(*layers)
53 |
54 | def forward(self, x):
55 | x = self.conv1(x)
56 | x = self.bn1(x)
57 | x = self.relu(x)
58 | x = self.maxpool(x)
59 |
60 | x = self.layer1(x)
61 | stage_1_feature = x
62 | x = self.layer2(x)
63 | stage_2_feature = x
64 | x = self.layer3(x)
65 | stage_3_feature = x
66 | x = self.layer4(x)
67 | stage_4_feature = x
68 |
69 | return x
70 |
71 | def init_weights(self, resnet_model_name):
72 | url = model_urls[resnet_model_name]
73 | pretrained_state_dict = model_zoo.load_url(url)
74 | print('=> loading pretrained model {}'.format(url))
75 | self.load_state_dict(pretrained_state_dict, strict=False)
76 |
77 |
78 | # Resnet specific parameters
79 | # resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
80 | # 34: (BasicBlock, [3, 4, 6, 3]),
81 | # 50: (Bottleneck, [3, 4, 6, 3]),
82 | # 101: (Bottleneck, [3, 4, 23, 3]),
83 | # 152: (Bottleneck, [3, 8, 36, 3])}
84 |
85 | def get_resnet_18(pretrain=True):
86 | model = build_resnet(BasicBlock, [2, 2, 2, 2])
87 | if pretrain:
88 | model.init_weights(resnet_model_name='resnet18')
89 |
90 | return model
91 |
92 |
93 | def get_resnet_34(pretrain=True):
94 | model = build_resnet(BasicBlock, [3, 4, 6, 3])
95 | if pretrain:
96 | model.init_weights(resnet_model_name='resnet34')
97 | return model
98 |
99 |
100 | def get_resnet_50(pretrain=True):
101 | model = build_resnet(Bottleneck, [3, 4, 6, 3])
102 | if pretrain:
103 | model.init_weights(resnet_model_name='resnet50')
104 | return model
105 |
106 |
107 | def get_resnet_101(pretrain=True):
108 | model = build_resnet(Bottleneck, [3, 4, 23, 3])
109 | if pretrain:
110 | model.init_weights(resnet_model_name='resnet101')
111 | return model
112 |
113 |
114 | def get_resnet_152(pretrain=True):
115 | model = build_resnet(Bottleneck, [3, 8, 36, 3])
116 | if pretrain:
117 | model.init_weights(resnet_model_name='resnet152')
118 | return model
119 |
120 |
121 | if __name__ == '__main__':
122 | resnet_model = get_resnet_34(pretrain=True)
123 | print(resnet_model)
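
Note that forward() returns only the final stage; the intermediate stage_*_feature
variables are computed but unused, and CE_Net_ below bypasses forward() entirely,
reusing conv1/bn1/relu/maxpool and layer1..layer4 as individual encoder stages.
A quick shape check:

    import torch
    net = get_resnet_34(pretrain=False)
    feat = net(torch.randn(1, 3, 224, 224))
    print(feat.shape)    # torch.Size([1, 512, 7, 7])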
124 |
--------------------------------------------------------------------------------
/src/lib/models/networks/backbones/resnet/resnet_factory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | from .build_resnet import get_resnet_18
5 | from .build_resnet import get_resnet_34
6 | from .build_resnet import get_resnet_50
7 | from .build_resnet import get_resnet_101
8 | from .build_resnet import get_resnet_152
9 |
10 | _resnet_backbone = {
11 | 'resnet18': get_resnet_18,
12 | 'resnet34': get_resnet_34,
13 | 'resnet50': get_resnet_50,
14 | 'resnet101': get_resnet_101,
15 | 'resnet152': get_resnet_152,
16 |
17 | }
18 |
19 |
20 | def get_resnet_backbone(model_name):
21 | support_resnet_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
22 |     assert model_name in support_resnet_models, "Only the following models are supported: {}".format(support_resnet_models)
23 |
24 | model = _resnet_backbone[model_name]
25 |
26 | return model
27 |
28 | if __name__ == '__main__':
29 | str1 = 'resnet18'
30 | model = get_resnet_backbone(str1)
31 |
32 | print(model(pretrain=True))
--------------------------------------------------------------------------------
/src/lib/models/networks/cenet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | import torch.nn.functional as F
7 |
8 | from .backbones.resnet.resnet_factory import get_resnet_backbone
9 |
10 |
11 | from functools import partial
12 |
13 | nonlinearity = partial(F.relu, inplace=True)
14 |
15 | class DACblock(nn.Module):
16 | def __init__(self, channel):
17 | super(DACblock, self).__init__()
18 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
19 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=3, padding=3)
20 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=5, padding=5)
21 | self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
22 | for m in self.modules():
23 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
24 | if m.bias is not None:
25 | m.bias.data.zero_()
26 |
27 |     def forward(self, x):  # four cascaded atrous branches (dilations 1, 3, 5) plus the identity
28 | dilate1_out = nonlinearity(self.dilate1(x))
29 | dilate2_out = nonlinearity(self.conv1x1(self.dilate2(x)))
30 | dilate3_out = nonlinearity(self.conv1x1(self.dilate2(self.dilate1(x))))
31 | dilate4_out = nonlinearity(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x)))))
32 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out
33 | return out
34 |
35 |
36 | class DACblock_without_atrous(nn.Module):
37 | def __init__(self, channel):
38 | super(DACblock_without_atrous, self).__init__()
39 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
40 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
41 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
42 | self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
45 | if m.bias is not None:
46 | m.bias.data.zero_()
47 |
48 | def forward(self, x):
49 | dilate1_out = nonlinearity(self.dilate1(x))
50 | dilate2_out = nonlinearity(self.conv1x1(self.dilate2(x)))
51 | dilate3_out = nonlinearity(self.conv1x1(self.dilate2(self.dilate1(x))))
52 | dilate4_out = nonlinearity(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x)))))
53 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out
54 |
55 | return out
56 |
57 | class DACblock_with_inception(nn.Module):
58 | def __init__(self, channel):
59 | super(DACblock_with_inception, self).__init__()
60 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
61 |
62 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
63 | self.conv1x1 = nn.Conv2d(2 * channel, channel, kernel_size=1, dilation=1, padding=0)
64 | for m in self.modules():
65 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
66 | if m.bias is not None:
67 | m.bias.data.zero_()
68 |
69 | def forward(self, x):
70 | dilate1_out = nonlinearity(self.dilate1(x))
71 | dilate2_out = nonlinearity(self.dilate3(self.dilate1(x)))
72 | dilate_concat = nonlinearity(self.conv1x1(torch.cat([dilate1_out, dilate2_out], 1)))
73 | dilate3_out = nonlinearity(self.dilate1(dilate_concat))
74 | out = x + dilate3_out
75 | return out
76 |
77 |
78 | class DACblock_with_inception_blocks(nn.Module):
79 | def __init__(self, channel):
80 | super(DACblock_with_inception_blocks, self).__init__()
81 | self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
82 | self.conv3x3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
83 | self.conv5x5 = nn.Conv2d(channel, channel, kernel_size=5, dilation=1, padding=2)
84 | self.pooling = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
85 |
86 | for m in self.modules():
87 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
88 | if m.bias is not None:
89 | m.bias.data.zero_()
90 |
91 | def forward(self, x):
92 | dilate1_out = nonlinearity(self.conv1x1(x))
93 | dilate2_out = nonlinearity(self.conv3x3(self.conv1x1(x)))
94 | dilate3_out = nonlinearity(self.conv5x5(self.conv1x1(x)))
95 | dilate4_out = self.pooling(x)
96 | out = dilate1_out + dilate2_out + dilate3_out + dilate4_out
97 | return out
98 |
99 |
100 |
101 | class PSPModule(nn.Module):
102 | def __init__(self, features, out_features=1024, sizes=(2, 3, 6, 14)):
103 | super().__init__()
104 | self.stages = []
105 | self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
106 | self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
107 | self.relu = nn.ReLU()
108 |
109 | def _make_stage(self, features, size):
110 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
111 | conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
112 | return nn.Sequential(prior, conv)
113 |
114 | def forward(self, feats):
115 | h, w = feats.size(2), feats.size(3)
116 |         priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats]
117 | bottle = self.bottleneck(torch.cat(priors, 1))
118 | return self.relu(bottle)
119 |
120 |
121 | class SPPblock(nn.Module):
122 | def __init__(self, in_channels):
123 | super(SPPblock, self).__init__()
124 | self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=2)
125 | self.pool2 = nn.MaxPool2d(kernel_size=[3, 3], stride=3)
126 | self.pool3 = nn.MaxPool2d(kernel_size=[5, 5], stride=5)
127 | self.pool4 = nn.MaxPool2d(kernel_size=[6, 6], stride=6)
128 |
129 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1, padding=0)
130 |
131 |     def forward(self, x):
132 |         h, w = x.size(2), x.size(3)
133 |         # each pooled map is reduced to one channel and upsampled back to (h, w)
134 |         layer1 = F.interpolate(self.conv(self.pool1(x)), size=(h, w), mode='bilinear')
135 |         layer2 = F.interpolate(self.conv(self.pool2(x)), size=(h, w), mode='bilinear')
136 |         layer3 = F.interpolate(self.conv(self.pool3(x)), size=(h, w), mode='bilinear')
137 |         layer4 = F.interpolate(self.conv(self.pool4(x)), size=(h, w), mode='bilinear')
138 |
139 |         out = torch.cat([layer1, layer2, layer3, layer4, x], 1)
140 |         return out
141 |
142 |
143 | class DecoderBlock(nn.Module):
144 | def __init__(self, in_channels, n_filters):
145 | super(DecoderBlock, self).__init__()
146 |
147 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
148 | self.norm1 = nn.BatchNorm2d(in_channels // 4)
149 | self.relu1 = nonlinearity
150 |
151 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
152 | self.norm2 = nn.BatchNorm2d(in_channels // 4)
153 | self.relu2 = nonlinearity
154 |
155 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
156 | self.norm3 = nn.BatchNorm2d(n_filters)
157 | self.relu3 = nonlinearity
158 |
159 | def forward(self, x):
160 | x = self.conv1(x)
161 | x = self.norm1(x)
162 | x = self.relu1(x)
163 | x = self.deconv2(x)
164 | x = self.norm2(x)
165 | x = self.relu2(x)
166 | x = self.conv3(x)
167 | x = self.norm3(x)
168 | x = self.relu3(x)
169 | return x
170 |
171 |
172 | class CE_Net_(nn.Module):
173 | def __init__(self, num_classes=1, num_channels=3):
174 | super(CE_Net_, self).__init__()
175 | filters = [64, 128, 256, 512]
176 | # resnet = models.resnet34(pretrained=True)
177 | resnet = get_resnet_backbone('resnet34')(pretrain=True)
178 | self.firstconv = resnet.conv1
179 | self.firstbn = resnet.bn1
180 | self.firstrelu = resnet.relu
181 | self.firstmaxpool = resnet.maxpool
182 | self.encoder1 = resnet.layer1
183 | self.encoder2 = resnet.layer2
184 | self.encoder3 = resnet.layer3
185 | self.encoder4 = resnet.layer4
186 |
187 | self.dblock = DACblock(512)
188 | self.spp = SPPblock(512)
189 |
190 |         self.decoder4 = DecoderBlock(516, filters[2])  # 512 encoder channels + 4 appended by SPPblock
191 | self.decoder3 = DecoderBlock(filters[2], filters[1])
192 | self.decoder2 = DecoderBlock(filters[1], filters[0])
193 | self.decoder1 = DecoderBlock(filters[0], filters[0])
194 |
195 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
196 | self.finalrelu1 = nonlinearity
197 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
198 | self.finalrelu2 = nonlinearity
199 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
200 |
201 | def forward(self, x):
202 | # Encoder
203 | x = self.firstconv(x)
204 | x = self.firstbn(x)
205 | x = self.firstrelu(x)
206 | x = self.firstmaxpool(x)
207 | e1 = self.encoder1(x)
208 | e2 = self.encoder2(e1)
209 | e3 = self.encoder3(e2)
210 | e4 = self.encoder4(e3)
211 |
212 | # Center
213 | e4 = self.dblock(e4)
214 | e4 = self.spp(e4)
215 |
216 | # Decoder
217 | d4 = self.decoder4(e4) + e3
218 | d3 = self.decoder3(d4) + e2
219 | d2 = self.decoder2(d3) + e1
220 | d1 = self.decoder1(d2)
221 |
222 |
223 | out = self.finaldeconv1(d1)
224 | out = self.finalrelu1(out)
225 | out = self.finalconv2(out)
226 | out = self.finalrelu2(out)
227 | out = self.finalconv3(out)
228 |
229 | return torch.sigmoid(out)
230 |
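Minimal smoke test for the full network (448x448 matches the datasets'
default_resolution; other sizes should be divisible by 32 and large enough for the
SPP pooling windows):

    net = CE_Net_(num_classes=1)
    y = net(torch.randn(1, 3, 448, 448))
    print(y.shape)    # torch.Size([1, 1, 448, 448]), sigmoid probabilities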
231 |
232 | class CE_Net_backbone_DAC_without_atrous(nn.Module):
233 | def __init__(self, num_classes=1, num_channels=3):
234 | super(CE_Net_backbone_DAC_without_atrous, self).__init__()
235 |
236 | filters = [64, 128, 256, 512]
237 | resnet = models.resnet34(pretrained=True)
238 | self.firstconv = resnet.conv1
239 | self.firstbn = resnet.bn1
240 | self.firstrelu = resnet.relu
241 | self.firstmaxpool = resnet.maxpool
242 | self.encoder1 = resnet.layer1
243 | self.encoder2 = resnet.layer2
244 | self.encoder3 = resnet.layer3
245 | self.encoder4 = resnet.layer4
246 |
247 | self.dblock = DACblock_without_atrous(512)
248 |
249 |
250 | self.decoder4 = DecoderBlock(512, filters[2])
251 | self.decoder3 = DecoderBlock(filters[2], filters[1])
252 | self.decoder2 = DecoderBlock(filters[1], filters[0])
253 | self.decoder1 = DecoderBlock(filters[0], filters[0])
254 |
255 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
256 | self.finalrelu1 = nonlinearity
257 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
258 | self.finalrelu2 = nonlinearity
259 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
260 |
261 | def forward(self, x):
262 | # Encoder
263 | x = self.firstconv(x)
264 | x = self.firstbn(x)
265 | x = self.firstrelu(x)
266 | x = self.firstmaxpool(x)
267 | e1 = self.encoder1(x)
268 | e2 = self.encoder2(e1)
269 | e3 = self.encoder3(e2)
270 | e4 = self.encoder4(e3)
271 |
272 | # Center
273 | e4 = self.dblock(e4)
274 | # e4 = self.spp(e4)
275 |
276 | # Decoder
277 | d4 = self.decoder4(e4) + e3
278 | d3 = self.decoder3(d4) + e2
279 | d2 = self.decoder2(d3) + e1
280 | d1 = self.decoder1(d2)
281 |
282 | out = self.finaldeconv1(d1)
283 | out = self.finalrelu1(out)
284 | out = self.finalconv2(out)
285 | out = self.finalrelu2(out)
286 | out = self.finalconv3(out)
287 |
288 | return torch.sigmoid(out)
289 |
290 | class CE_Net_backbone_DAC_with_inception(nn.Module):
291 | def __init__(self, num_classes=1, num_channels=3):
292 | super(CE_Net_backbone_DAC_with_inception, self).__init__()
293 |
294 | filters = [64, 128, 256, 512]
295 | resnet = models.resnet34(pretrained=True)
296 | self.firstconv = resnet.conv1
297 | self.firstbn = resnet.bn1
298 | self.firstrelu = resnet.relu
299 | self.firstmaxpool = resnet.maxpool
300 | self.encoder1 = resnet.layer1
301 | self.encoder2 = resnet.layer2
302 | self.encoder3 = resnet.layer3
303 | self.encoder4 = resnet.layer4
304 |
305 | self.dblock = DACblock_with_inception(512)
306 |
307 |
308 | self.decoder4 = DecoderBlock(512, filters[2])
309 | self.decoder3 = DecoderBlock(filters[2], filters[1])
310 | self.decoder2 = DecoderBlock(filters[1], filters[0])
311 | self.decoder1 = DecoderBlock(filters[0], filters[0])
312 |
313 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
314 | self.finalrelu1 = nonlinearity
315 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
316 | self.finalrelu2 = nonlinearity
317 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
318 |
319 | def forward(self, x):
320 | # Encoder
321 | x = self.firstconv(x)
322 | x = self.firstbn(x)
323 | x = self.firstrelu(x)
324 | x = self.firstmaxpool(x)
325 | e1 = self.encoder1(x)
326 | e2 = self.encoder2(e1)
327 | e3 = self.encoder3(e2)
328 | e4 = self.encoder4(e3)
329 |
330 | # Center
331 | e4 = self.dblock(e4)
332 | # e4 = self.spp(e4)
333 |
334 | # Decoder
335 | d4 = self.decoder4(e4) + e3
336 | d3 = self.decoder3(d4) + e2
337 | d2 = self.decoder2(d3) + e1
338 | d1 = self.decoder1(d2)
339 |
340 | out = self.finaldeconv1(d1)
341 | out = self.finalrelu1(out)
342 | out = self.finalconv2(out)
343 | out = self.finalrelu2(out)
344 | out = self.finalconv3(out)
345 |
346 | return torch.sigmoid(out)
347 |
348 | class CE_Net_backbone_inception_blocks(nn.Module):
349 | def __init__(self, num_classes=1, num_channels=3):
350 | super(CE_Net_backbone_inception_blocks, self).__init__()
351 |
352 | filters = [64, 128, 256, 512]
353 | resnet = models.resnet34(pretrained=True)
354 | self.firstconv = resnet.conv1
355 | self.firstbn = resnet.bn1
356 | self.firstrelu = resnet.relu
357 | self.firstmaxpool = resnet.maxpool
358 | self.encoder1 = resnet.layer1
359 | self.encoder2 = resnet.layer2
360 | self.encoder3 = resnet.layer3
361 | self.encoder4 = resnet.layer4
362 |
363 | self.dblock = DACblock_with_inception_blocks(512)
364 |
365 |
366 | self.decoder4 = DecoderBlock(512, filters[2])
367 | self.decoder3 = DecoderBlock(filters[2], filters[1])
368 | self.decoder2 = DecoderBlock(filters[1], filters[0])
369 | self.decoder1 = DecoderBlock(filters[0], filters[0])
370 |
371 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
372 | self.finalrelu1 = nonlinearity
373 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
374 | self.finalrelu2 = nonlinearity
375 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
376 |
377 | def forward(self, x):
378 | # Encoder
379 | x = self.firstconv(x)
380 | x = self.firstbn(x)
381 | x = self.firstrelu(x)
382 | x = self.firstmaxpool(x)
383 | e1 = self.encoder1(x)
384 | e2 = self.encoder2(e1)
385 | e3 = self.encoder3(e2)
386 | e4 = self.encoder4(e3)
387 |
388 | # Center
389 | e4 = self.dblock(e4)
390 | # e4 = self.spp(e4)
391 |
392 | # Decoder
393 | d4 = self.decoder4(e4) + e3
394 | d3 = self.decoder3(d4) + e2
395 | d2 = self.decoder2(d3) + e1
396 | d1 = self.decoder1(d2)
397 |
398 | out = self.finaldeconv1(d1)
399 | out = self.finalrelu1(out)
400 | out = self.finalconv2(out)
401 | out = self.finalrelu2(out)
402 | out = self.finalconv3(out)
403 |
404 | return torch.sigmoid(out)
405 |
406 |
407 | class CE_Net_OCT(nn.Module):
408 | def __init__(self, num_classes=12, num_channels=3):
409 | super(CE_Net_OCT, self).__init__()
410 |
411 | filters = [64, 128, 256, 512]
412 | resnet = models.resnet34(pretrained=True)
413 | self.firstconv = resnet.conv1
414 | self.firstbn = resnet.bn1
415 | self.firstrelu = resnet.relu
416 | self.firstmaxpool = resnet.maxpool
417 | self.encoder1 = resnet.layer1
418 | self.encoder2 = resnet.layer2
419 | self.encoder3 = resnet.layer3
420 | self.encoder4 = resnet.layer4
421 |
422 | self.dblock = DACblock(512)
423 | self.spp = SPPblock(512)
424 |
425 | self.decoder4 = DecoderBlock(516, filters[2])
426 | self.decoder3 = DecoderBlock(filters[2], filters[1])
427 | self.decoder2 = DecoderBlock(filters[1], filters[0])
428 | self.decoder1 = DecoderBlock(filters[0], filters[0])
429 |
430 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
431 | self.finalrelu1 = nonlinearity
432 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
433 | self.finalrelu2 = nonlinearity
434 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
435 |
436 | def forward(self, x):
437 | # Encoder
438 | x = self.firstconv(x)
439 | x = self.firstbn(x)
440 | x = self.firstrelu(x)
441 | x = self.firstmaxpool(x)
442 | e1 = self.encoder1(x)
443 | e2 = self.encoder2(e1)
444 | e3 = self.encoder3(e2)
445 | e4 = self.encoder4(e3)
446 |
447 | # Center
448 | e4 = self.dblock(e4)
449 | e4 = self.spp(e4)
450 |
451 | # Decoder
452 | d4 = self.decoder4(e4) + e3
453 | d3 = self.decoder3(d4) + e2
454 | d2 = self.decoder2(d3) + e1
455 | d1 = self.decoder1(d2)
456 |
457 | out = self.finaldeconv1(d1)
458 | out = self.finalrelu1(out)
459 | out = self.finalconv2(out)
460 | out = self.finalrelu2(out)
461 | out = self.finalconv3(out)
462 |
463 | return out
464 |
465 |
466 |
467 | class double_conv(nn.Module):
468 | def __init__(self, in_ch, out_ch):
469 | super(double_conv, self).__init__()
470 | self.conv = nn.Sequential(
471 | nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
472 | nn.BatchNorm2d(out_ch),
473 | nn.ReLU(inplace=True),
474 | nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
475 | nn.BatchNorm2d(out_ch),
476 | nn.ReLU(inplace=True)
477 | )
478 |
479 | def forward(self, x):
480 | x = self.conv(x)
481 | return x
482 |
483 |
484 | class inconv(nn.Module):
485 | def __init__(self, in_ch, out_ch):
486 | super(inconv, self).__init__()
487 | self.conv = double_conv(in_ch, out_ch)
488 |
489 | def forward(self, x):
490 | x = self.conv(x)
491 | return x
492 |
493 |
494 | class down(nn.Module):
495 | def __init__(self, in_ch, out_ch):
496 | super(down, self).__init__()
497 | self.max_pool_conv = nn.Sequential(
498 | nn.MaxPool2d(2),
499 | double_conv(in_ch, out_ch)
500 | )
501 |
502 | def forward(self, x):
503 | x = self.max_pool_conv(x)
504 | return x
505 |
506 |
507 | class up(nn.Module):
508 | def __init__(self, in_ch, out_ch, bilinear=True):
509 | super(up, self).__init__()
510 | if bilinear:
511 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
512 | else:
513 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
514 |
515 | self.conv = double_conv(in_ch, out_ch)
516 |
517 | def forward(self, x1, x2):
518 | x1 = self.up(x1)
519 |         diffY = x1.size()[2] - x2.size()[2]  # height difference
520 |         diffX = x1.size()[3] - x2.size()[3]  # width difference
521 |         x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))  # (left, right, top, bottom)
522 | x = torch.cat([x2, x1], dim=1)
523 | x = self.conv(x)
524 | return x
525 |
526 |
527 | class outconv(nn.Module):
528 | def __init__(self, in_ch, out_ch):
529 | super(outconv, self).__init__()
530 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
531 |
532 | def forward(self, x):
533 | x = self.conv(x)
534 | return x
535 |
536 |
537 | class UNet(nn.Module):
538 | def __init__(self, n_channels=3, n_classes=1):
539 | super(UNet, self).__init__()
540 | self.inc = inconv(n_channels, 64)
541 | self.down1 = down(64, 128)
542 | self.down2 = down(128, 256)
543 | self.down3 = down(256, 512)
544 | self.down4 = down(512, 512)
545 | self.up1 = up(1024, 256)
546 | self.up2 = up(512, 128)
547 | self.up3 = up(256, 64)
548 | self.up4 = up(128, 64)
549 | self.outc = outconv(64, n_classes)
550 | self.relu = nn.ReLU()
551 |
552 | def forward(self, x):
553 | x1 = self.inc(x)
554 | x2 = self.down1(x1)
555 | x3 = self.down2(x2)
556 | x4 = self.down3(x3)
557 | x5 = self.down4(x4)
558 | x = self.up1(x5, x4)
559 | x = self.up2(x, x3)
560 | x = self.up3(x, x2)
561 | x = self.up4(x, x1)
562 | x = self.outc(x)
563 | #x = self.relu(x)
564 | return torch.sigmoid(x)
565 |
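Note: decoder4 takes 516 input channels rather than 512 because SPPblock concatenates four single-channel pooled-and-upsampled maps onto the 512-channel DAC output. A minimal smoke test for the heads above, as a sketch (it assumes the DACblock/SPPblock/DecoderBlock definitions earlier in this file are in scope, and CE_Net_ will try to download ImageNet ResNet-34 weights on first use):

    import torch

    x = torch.randn(1, 3, 512, 512)  # H and W must be multiples of 32

    model = CE_Net_(num_classes=1).eval()
    with torch.no_grad():
        print(model(x).shape)  # torch.Size([1, 1, 512, 512]), sigmoid probabilities

    unet = UNet(n_channels=3, n_classes=1).eval()
    with torch.no_grad():
        print(unet(x).shape)   # torch.Size([1, 1, 512, 512])
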
--------------------------------------------------------------------------------
/src/lib/models/networks/neck_blocks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/neck_blocks/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/networks/neck_blocks/attention_modules/Dense_atrous.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from functools import partial
6 |
7 | nonlinearity = partial(F.relu, inplace=True)
8 |
9 |
10 | class DACblock(nn.Module):
11 | def __init__(self, channel):
12 | super(DACblock, self).__init__()
13 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
14 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=3, padding=3)
15 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=5, padding=5)
16 | self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
17 |
18 | for m in self.modules():
19 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 |
23 | def forward(self, x):
24 |         dilated1_out = nonlinearity(self.dilate1(x))
25 |         dilated2_out = nonlinearity(self.conv1x1(self.dilate2(x)))
26 |         dilated3_out = nonlinearity(self.conv1x1(self.dilate2(self.dilate1(x))))
27 |         dilated4_out = nonlinearity(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x)))))
28 | out = x + dilated1_out + dilated2_out + dilated3_out + dilated4_out
29 | return out
30 |
31 |
32 | class PSPModule(nn.Module):
33 | def __init__(self, features, out_features=1024, sizes=(2, 3, 6, 14)):
34 | super(PSPModule, self).__init__()
35 | self.stages = []
36 | self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
37 | self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
38 | self.relu = nn.ReLU()
39 |
40 | def _make_stage(self, features, size):
41 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
42 | conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
43 |
44 | return nn.Sequential(prior, conv)
45 |
46 | def forward(self, feats):
47 | h, w = feats.size(2), feats.size(3)
48 |         # F.upsample ----> F.interpolate (F.upsample is deprecated since PyTorch 0.4.0).
49 |         # UserWarning: the default behavior for mode='bilinear' changed to
50 |         # align_corners=False in 0.4.0; specify align_corners=True if the old
51 |         # behavior is desired. See the documentation of nn.Upsample for details.
52 |         # align_corners=True below keeps the pre-0.4.0 behavior.
53 | priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in
54 | self.stages] + [feats]
55 | bottle = self.bottleneck(torch.cat(priors, 1))
56 | return self.relu(bottle)
57 |
58 |
59 | if __name__ == '__main__':
60 | inp = torch.zeros(size=(10, 12, 14, 14))
61 | module = PSPModule(12)
62 | oup = module(inp)
63 | print(module)
64 | print(oup.size())
65 |
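The four DACblock branches have growing effective receptive fields (3, 7, 9 and 19, per the CE-Net paper), so the block aggregates multi-scale context while preserving the input resolution. A tiny shape check, as a sketch run with the definitions above in scope:

    block = DACblock(channel=32)
    x = torch.randn(2, 32, 16, 16)
    print(block(x).size())  # torch.Size([2, 32, 16, 16]), shape-preserving
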
--------------------------------------------------------------------------------
/src/lib/models/networks/neck_blocks/attention_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/models/networks/neck_blocks/attention_modules/__init__.py
--------------------------------------------------------------------------------
/src/lib/models/scatter_gather.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.nn.parallel._functions import Scatter, Gather
4 |
5 |
6 | def scatter(inputs, target_gpus, dim=0, chunk_sizes=None):
7 | r"""
8 | Slices variables into approximately equal chunks and
9 | distributes them across given GPUs. Duplicates
10 | references to objects that are not variables. Does not
11 | support Tensors.
12 | """
13 | def scatter_map(obj):
14 | if isinstance(obj, Variable):
15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
16 | assert not torch.is_tensor(obj), "Tensors not supported in scatter."
17 | if isinstance(obj, tuple):
18 | return list(zip(*map(scatter_map, obj)))
19 | if isinstance(obj, list):
20 | return list(map(list, zip(*map(scatter_map, obj))))
21 | if isinstance(obj, dict):
22 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
23 |         return [obj for _ in target_gpus]
24 |
25 | return scatter_map(inputs)
26 |
27 |
28 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None):
29 | r"""Scatter with support for kwargs dictionary"""
30 | inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else []
31 | kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else []
32 | if len(inputs) < len(kwargs):
33 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
34 | elif len(kwargs) < len(inputs):
35 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
36 | inputs = tuple(inputs)
37 | kwargs = tuple(kwargs)
38 | return inputs, kwargs
39 |
--------------------------------------------------------------------------------
/src/lib/opts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import argparse
7 | import os
8 |
9 |
10 | class opts(object):
11 | def __init__(self):
12 | self.parser = argparse.ArgumentParser()
13 | # basic experiments setting
14 |         self.parser.add_argument('--task', default='binSeg',
15 |                                  help='task to run. '
16 |                                       'binSeg: binary segmentation')
17 | self.parser.add_argument('--dataset', default='ORIGA_OD',
18 | help='ORIGA_OD | kitti | coco_hp | pascal')
19 | self.parser.add_argument('--exp_id', default='default')
20 | self.parser.add_argument('--test', action='store_true')
21 |         self.parser.add_argument('--debug', type=int, default=0,
22 |                                  help='level of visualization. '
23 |                                       '1: only show the final detection results; '
24 |                                       '2: show the network output features; '
25 |                                       '3: use matplotlib to display; '  # useful when launching training from an ipython notebook
26 |                                       '4: save all visualizations to disk')
27 | self.parser.add_argument('--demo', default='',
28 | help='path to image/ image folders/ video. '
29 | 'or "webcam"')
30 | self.parser.add_argument('--load_model', default='',
31 | help='path to pretrained model')
32 | self.parser.add_argument('--resume', action='store_true',
33 | help='resume an experiment. '
34 | 'Reloaded the optimizer parameter and '
35 | 'set load_model to model_last.pth '
36 | 'in the exp dir if load_model is empty.')
37 |
38 | # model
39 | self.parser.add_argument('--model_name', default='cenet',
40 | help='which model')
41 |
42 | # data_root
43 | self.parser.add_argument('--data_root', default='/data/zaiwang/Dataset',
44 |                                  help='path to the dataset root directory')
45 |
46 | # loss type
47 | self.parser.add_argument('--loss_type', default='dice_bce_loss',
48 | help='which loss type')
49 |
50 | # system
51 | self.parser.add_argument('--gpus', default='3',
52 | help='-1 for CPU, use comma for multiple gpus')
53 | self.parser.add_argument('--num_workers', type=int, default=4,
54 | help='dataloader threads. 0 for single-thread.')
55 | self.parser.add_argument('--not_cuda_benchmark', action='store_true',
56 | help='disable when the input size is not fixed.')
57 | self.parser.add_argument('--seed', type=int, default=317,
58 | help='random seed') # from CornerNet
59 |
60 | self.parser.add_argument('--print_iter', type=int, default=0,
61 | help='disable progress bar and print to screen.')
62 | self.parser.add_argument('--hide_data_time', action='store_true',
63 | help='not display time during training.')
64 | self.parser.add_argument('--save_all', action='store_true',
65 | help='save model to disk every 5 epochs.')
66 | self.parser.add_argument('--metric', default='loss',
67 | help='main metric to save best model')
68 | self.parser.add_argument('--vis_thresh', type=float, default=0.3,
69 | help='visualization threshold.')
70 | self.parser.add_argument('--debugger_theme', default='white',
71 | choices=['white', 'black'])
72 |
73 | # input
74 | self.parser.add_argument('--input_res', type=int, default=-1,
75 | help='input height and width. -1 for default from '
76 |                                       'dataset. Will be overridden by input_h | input_w')
77 | self.parser.add_argument('--input_h', type=int, default=-1,
78 | help='input height. -1 for default from dataset.')
79 | self.parser.add_argument('--input_w', type=int, default=-1,
80 | help='input width. -1 for default from dataset.')
81 |
82 | # train
83 | self.parser.add_argument('--lr', type=float, default=1.25e-4,
84 | help='learning rate for batch size 32.')
85 | self.parser.add_argument('--lr_step', type=str, default='90,120',
86 | help='drop learning rate by 10.')
87 | self.parser.add_argument('--num_epochs', type=int, default=140,
88 | help='total training epochs.')
89 | self.parser.add_argument('--batch_size', type=int, default=24,
90 | help='batch size')
91 | self.parser.add_argument('--master_batch_size', type=int, default=-1,
92 | help='batch size on the master gpu.')
93 | self.parser.add_argument('--num_iters', type=int, default=-1,
94 | help='default: #samples / batch_size.')
95 | self.parser.add_argument('--val_intervals', type=int, default=5,
96 | help='number of epochs to run validation.')
97 |
98 | self.parser.add_argument('--root_dir', type=str, default='/data/zaiwang/output',
99 | help='the path to save training model and loggers')
100 |
101 | # dataset augmentation
102 | self.parser.add_argument('--color_aug', default=True,
103 | help='HSV color augmentation')
104 |
105 | self.parser.add_argument('--shift_scale', default=True,
106 |                                  help='when not using random crop, '
107 | 'apply shift augmentation.')
108 |
109 | self.parser.add_argument('--HorizontalFlip', default=True,
110 | help='Horizontal Flip')
111 |
112 | self.parser.add_argument('--VerticleFlip', default=True,
113 |                                  help='Vertical Flip')
114 |
115 | self.parser.add_argument('--rotate_90', default=True,
116 |                                  help='when not using random crop, '
117 | 'apply rotation augmentation.')
118 |
119 | def parse(self, args=''):
120 | if args == '':
121 | opt = self.parser.parse_args()
122 | else:
123 | opt = self.parser.parse_args(args)
124 |
125 | opt.gpus_str = opt.gpus
126 | opt.gpus = [int(gpu) for gpu in opt.gpus.split(',')]
127 | opt.gpus = [i for i in range(len(opt.gpus))] if opt.gpus[0] >= 0 else [-1]
128 | print("GPUS device is {}".format(opt.gpus))
129 |
130 | opt.lr_step = [int(i) for i in opt.lr_step.split(',')]
131 |
132 | if opt.debug > 0:
133 | opt.num_workers = 0
134 | opt.batch_size = 1
135 | opt.gpus = [opt.gpus[0]]
136 | opt.master_batch_size = -1
137 |
138 | if opt.master_batch_size == -1:
139 | opt.master_batch_size = opt.batch_size // len(opt.gpus)
140 | rest_batch_size = (opt.batch_size - opt.master_batch_size)
141 | opt.chunk_sizes = [opt.master_batch_size]
142 | for i in range(len(opt.gpus) - 1):
143 | slave_chunk_size = rest_batch_size // (len(opt.gpus) - 1)
144 | if i < rest_batch_size % (len(opt.gpus) - 1):
145 | slave_chunk_size += 1
146 | opt.chunk_sizes.append(slave_chunk_size)
147 | print('training chunk_sizes:', opt.chunk_sizes)
148 |
149 | opt.data_dir = os.path.join(opt.data_root, opt.dataset)
150 | opt.exp_dir = os.path.join(opt.root_dir, 'UBT_Seg', opt.task)
151 | opt.exp_id = opt.dataset + '_' + opt.model_name + '_' + opt.loss_type
152 | opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
153 | opt.debug_dir = os.path.join(opt.save_dir, 'debug')
154 | print('The output will be saved to ', opt.save_dir)
155 |
156 | if opt.resume and opt.load_model == '':
157 | model_path = opt.save_dir[:-4] if opt.save_dir.endswith('TEST') \
158 | else opt.save_dir
159 | opt.load_model = os.path.join(model_path, 'model_last.pth')
160 | return opt
161 |
162 | def update_dataset_info_and_set_heads(self, opt, dataset):
163 | opt.height, opt.width = dataset.default_resolution
164 | opt.mean, opt.std = dataset.mean, dataset.std
165 | opt.num_classes = dataset.num_classes
166 |
167 | return opt
168 |
169 | def init(self, args=''):
170 | default_dataset_info = {
171 | 'binSeg': {'default_resolution': [512, 512], 'num_classes': 1,
172 | 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],
173 | 'dataset': 'ORIGA_OD'},
174 |
175 | }
176 |
177 | class Struct:
178 | def __init__(self, entries):
179 | for k, v in entries.items():
180 | self.__setattr__(k, v)
181 |
182 | opt = self.parse(args)
183 | dataset = Struct(default_dataset_info[opt.task])
184 | opt.dataset = dataset.dataset
185 | opt = self.update_dataset_info_and_set_heads(opt, dataset)
186 | return opt
187 |
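The chunk_sizes computed in parse() control how DataParallel splits each batch across GPUs (see models/data_parallel.py): the master GPU gets master_batch_size samples and the remainder is spread evenly over the others. A standalone sketch of the same arithmetic (chunk_sizes here is a hypothetical helper for checking configurations by hand):

    def chunk_sizes(batch_size, num_gpus, master_batch_size=-1):
        # mirrors opts.parse(): -1 means an even share for the master GPU
        if master_batch_size == -1:
            master_batch_size = batch_size // num_gpus
        rest = batch_size - master_batch_size
        chunks = [master_batch_size]
        for i in range(num_gpus - 1):
            size = rest // (num_gpus - 1)
            if i < rest % (num_gpus - 1):
                size += 1
            chunks.append(size)
        return chunks

    print(chunk_sizes(24, 3))      # [8, 8, 8]
    print(chunk_sizes(24, 3, 4))   # [4, 10, 10] -- lighten the master GPU
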
--------------------------------------------------------------------------------
/src/lib/trains/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/trains/__init__.py
--------------------------------------------------------------------------------
/src/lib/trains/base_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import time
8 | import torch
9 | from progress.bar import Bar
10 | from models.data_parallel import DataParallel
11 | from utils.utils import AverageMeter
12 |
13 |
14 | class ModelWithLoss(torch.nn.Module):
15 | def __init__(self, model, loss):
16 | super(ModelWithLoss, self).__init__()
17 | self.model = model
18 | self.loss = loss
19 |
20 | def forward(self, batch):
21 | outputs = self.model(batch['input'])
22 | loss, loss_stats = self.loss(outputs, batch)
23 |
24 | return outputs[-1], loss, loss_stats
25 |
26 |
27 | class BaseTrainer(object):
28 | def __init__(self, opt, model, optimizer=None):
29 | self.opt = opt
30 | self.optimizer = optimizer
31 | self.loss_stats, self.loss = self._get_losses(opt)
32 | self.model_with_loss = ModelWithLoss(model, self.loss)
33 |
34 | # 2021-09-11
35 |     # DataParallel(): device_ids=gpus (indexed from 0); used for single-machine, multi-GPU training
36 | def set_device(self, gpus, chunk_sizes, device):
37 | if len(gpus) > 1:
38 | self.model_with_loss = DataParallel(
39 | self.model_with_loss, device_ids=gpus,
40 | chunk_sizes=chunk_sizes).to(device)
41 | else:
42 | self.model_with_loss = self.model_with_loss.to(device)
43 |
44 | for state in self.optimizer.state.values():
45 | for k, v in state.items():
46 | if isinstance(v, torch.Tensor):
47 | state[k] = v.to(device=device, non_blocking=True)
48 |
49 | def run_epoch(self, phase, epoch, data_loader):
50 | model_with_loss = self.model_with_loss
51 | if phase == 'train':
52 | model_with_loss.train()
53 | else:
54 | if len(self.opt.gpus) > 1:
55 | model_with_loss = self.model_with_loss.module
56 | model_with_loss.eval()
57 | torch.cuda.empty_cache()
58 |
59 | opt = self.opt
60 | results = {}
61 | data_time, batch_time = AverageMeter(), AverageMeter()
62 | avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
63 |
64 | num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
65 | bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
66 | end = time.time()
67 |
68 | for iter_id, batch in enumerate(data_loader):
69 | if iter_id >= num_iters:
70 | break
71 | data_time.update(time.time() - end)
72 |
73 | for k in batch:
74 | if k != 'meta':
75 | batch[k] = batch[k].to(device=opt.device, non_blocking=True)
76 |
77 | output, loss, loss_stats = model_with_loss(batch)
78 |
79 | loss = loss.mean()
80 |
81 | if phase == 'train':
82 | self.optimizer.zero_grad()
83 | loss.backward()
84 | self.optimizer.step()
85 | batch_time.update(time.time() - end)
86 | end = time.time()
87 |
88 | Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
89 | epoch, iter_id, num_iters, phase=phase,
90 | total=bar.elapsed_td, eta=bar.eta_td)
91 |
92 | for l in avg_loss_stats:
93 | avg_loss_stats[l].update(
94 | loss_stats[l].mean().item(), batch['input'].size(0))
95 | Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
96 | if not opt.hide_data_time:
97 | Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
98 | '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
99 | if opt.print_iter > 0:
100 | if iter_id % opt.print_iter == 0:
101 | print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
102 | else:
103 | bar.next()
104 |
105 | if opt.debug > 0:
106 | self.debug(batch, output, iter_id)
107 |
108 | if opt.test:
109 | self.save_result(output, batch, results)
110 |
111 | del output, loss, loss_stats
112 |
113 | bar.finish()
114 | ret = {k: v.avg for k, v in avg_loss_stats.items()}
115 | ret['time'] = bar.elapsed_td.total_seconds() / 60.
116 | return ret, results
117 |
118 | def debug(self, batch, output, iter_id):
119 | raise NotImplementedError
120 |
121 | def save_result(self, output, batch, results):
122 | raise NotImplementedError
123 |
124 | def _get_losses(self, opt):
125 | raise NotImplementedError
126 |
127 | def val(self, epoch, data_loader):
128 | return self.run_epoch('val', epoch, data_loader)
129 |
130 | def train(self, epoch, data_loader):
131 | return self.run_epoch('train', epoch, data_loader)
132 |
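Wrapping the criterion inside ModelWithLoss lets DataParallel replicate model and loss together, so each GPU computes its own loss before the means are gathered. Note that forward returns outputs[-1]: for models that return a plain tensor (like the CE-Net heads) this indexes the last sample of the batch, a convention inherited from models that return a list of stage outputs. A toy wiring sketch (ToyLoss and the Conv2d model are hypothetical, just to show the contract):

    import torch
    import torch.nn.functional as F

    class ToyLoss(torch.nn.Module):
        def forward(self, outputs, batch):
            loss = F.mse_loss(outputs, batch['gt'])
            return loss, {'loss': loss}

    mwl = ModelWithLoss(torch.nn.Conv2d(3, 1, 1), ToyLoss())
    batch = {'input': torch.randn(2, 3, 8, 8), 'gt': torch.randn(2, 1, 8, 8)}
    out, loss, loss_stats = mwl(batch)
    print(loss.item(), out.shape)  # scalar loss; out is outputs[-1]
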
--------------------------------------------------------------------------------
/src/lib/trains/binarySeg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from models.losses import dice_bce_loss
7 | from .base_trainer import BaseTrainer
8 |
9 | class binarySegLoss(torch.nn.Module):
10 | def __init__(self, opt):
11 | super(binarySegLoss, self).__init__()
12 | self.crit = dice_bce_loss()
13 | self.opt = opt
14 |
15 | def forward(self, outputs, batch):
16 |
17 | loss = self.crit(batch['gt'], outputs)
18 |
19 | loss_stats = {'loss': loss}
20 |
21 | return loss, loss_stats
22 |
23 |
24 | class BinarySegTrainer(BaseTrainer):
25 | def __init__(self, opt, model, optimizer=None):
26 | super(BinarySegTrainer, self).__init__(opt, model, optimizer=optimizer)
27 |
28 | def _get_losses(self, opt):
29 | loss_stats = ['loss']
30 | loss = binarySegLoss(opt)
31 | return loss_stats, loss
32 |
33 |
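dice_bce_loss itself lives in models/losses.py (earlier in this dump); note the trainer passes ground truth first, crit(batch['gt'], outputs). In case that file is not at hand, a typical dice + BCE combination looks roughly like the following sketch (dice_bce_loss_sketch is a hypothetical stand-in; the real implementation may differ):

    import torch

    def dice_bce_loss_sketch(y_true, y_pred, smooth=1.0):
        # y_pred comes sigmoid-activated from the CE-Net heads
        bce = torch.nn.functional.binary_cross_entropy(y_pred, y_true)
        inter = (y_true * y_pred).sum()
        dice = (2.0 * inter + smooth) / (y_true.sum() + y_pred.sum() + smooth)
        return bce + (1.0 - dice)

    y_true = torch.randint(0, 2, (2, 1, 8, 8)).float()
    y_pred = torch.sigmoid(torch.randn(2, 1, 8, 8))
    print(dice_bce_loss_sketch(y_true, y_pred))
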
--------------------------------------------------------------------------------
/src/lib/trains/train_factory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from .binarySeg import BinarySegTrainer
7 |
8 | train_factory = {
9 | 'binSeg': BinarySegTrainer,
10 | }
--------------------------------------------------------------------------------
/src/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guzaiwang/CE-Net/d7f0865d70d3ba5bc3c17e7c76f75ae52d2d36b7/src/lib/utils/__init__.py
--------------------------------------------------------------------------------
/src/lib/utils/image.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Based on https://github.com/asanakoy/kaggle_carvana_segmentation
5 | """
6 | import torch
7 | import torch.utils.data as data
8 | from torch.autograd import Variable as V
9 | from PIL import Image
10 |
11 |
12 |
13 | import cv2
14 | import numpy as np
15 | import os
16 |
17 |
18 |
19 | def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
20 | sat_shift_limit=(-255, 255),
21 | val_shift_limit=(-255, 255), u=0.5):
22 | if np.random.random() < u:
23 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
24 | h, s, v = cv2.split(image)
25 | hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
26 | hue_shift = np.uint8(hue_shift)
27 | h += hue_shift
28 | sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
29 | s = cv2.add(s, sat_shift)
30 | val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
31 | v = cv2.add(v, val_shift)
32 | image = cv2.merge((h, s, v))
33 | # image = cv2.merge((s, v))
34 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
35 |
36 | return image
37 |
38 |
39 | def randomShiftScaleRotate(image, mask,
40 | shift_limit=(-0.0, 0.0),
41 | scale_limit=(-0.0, 0.0),
42 | rotate_limit=(-0.0, 0.0),
43 | aspect_limit=(-0.0, 0.0),
44 | borderMode=cv2.BORDER_CONSTANT, u=0.5):
45 | if np.random.random() < u:
46 | height, width, channel = image.shape
47 |
48 | angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
49 | scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
50 | aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
51 | sx = scale * aspect / (aspect ** 0.5)
52 | sy = scale / (aspect ** 0.5)
53 | dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
54 | dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)
55 |
56 |         cc = np.cos(angle / 180 * np.pi) * sx
57 |         ss = np.sin(angle / 180 * np.pi) * sy
58 | rotate_matrix = np.array([[cc, -ss], [ss, cc]])
59 |
60 | box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
61 | box1 = box0 - np.array([width / 2, height / 2])
62 | box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
63 |
64 | box0 = box0.astype(np.float32)
65 | box1 = box1.astype(np.float32)
66 | mat = cv2.getPerspectiveTransform(box0, box1)
67 |         image = cv2.warpPerspective(image, mat, (width, height),
68 |                                     flags=cv2.INTER_LINEAR,
69 |                                     borderMode=borderMode,
70 |                                     borderValue=(0, 0, 0))
71 |         mask = cv2.warpPerspective(mask, mat, (width, height),
72 |                                     flags=cv2.INTER_LINEAR,
73 |                                     borderMode=borderMode,
74 |                                     borderValue=(0, 0, 0))
75 |
76 | return image, mask
77 |
78 |
79 | def randomHorizontalFlip(image, mask, u=0.5):
80 | if np.random.random() < u:
81 | image = cv2.flip(image, 1)
82 | mask = cv2.flip(mask, 1)
83 |
84 | return image, mask
85 |
86 |
87 | def randomVerticleFlip(image, mask, u=0.5):
88 | if np.random.random() < u:
89 | image = cv2.flip(image, 0)
90 | mask = cv2.flip(mask, 0)
91 |
92 | return image, mask
93 |
94 |
95 | def randomRotate90(image, mask, u=0.5):
96 | if np.random.random() < u:
97 | image = np.rot90(image)
98 | mask = np.rot90(mask)
99 |
100 | return image, mask
101 |
102 |
103 | def mask_to_boundary(mask, dilation_ratio=0.02):
104 | """
105 | Convert binary mask to boundary mask.
106 | :param mask (numpy array, uint8): binary mask
107 | :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
108 | :return: boundary mask (numpy array)
109 | """
110 | h, w = mask.shape
111 | img_diag = np.sqrt(h ** 2 + w ** 2)
112 | dilation = int(round(dilation_ratio * img_diag))
113 | if dilation < 1:
114 | dilation = 1
115 | # Pad image so mask truncated by the image border is also considered as boundary.
116 | new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
117 | kernel = np.ones((3, 3), dtype=np.uint8)
118 | new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation)
119 | mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1]
120 | # G_d intersects G in the paper.
121 | return mask - mask_erode
122 |
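For example, a filled 32x32 square in a 64x64 mask gives dilation = round(0.02 * sqrt(64**2 + 64**2)) = 2, so the boundary comes out as a two-pixel-wide ring. A sketch run with the definitions above in scope:

    import numpy as np

    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1
    boundary = mask_to_boundary(mask, dilation_ratio=0.02)
    print(mask.sum(), boundary.sum())  # 1024 240  (32*32 vs. 32*32 - 28*28)
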
--------------------------------------------------------------------------------
/src/lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 |
7 | class AverageMeter(object):
8 | """Computes and stores the average and current value"""
9 |
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | if self.count > 0:
24 | self.avg = self.sum / self.count
25 |
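Typical usage, as in base_trainer.run_epoch: pass the batch size as n so avg is a true per-sample mean across unequal batches. A short sketch:

    meter = AverageMeter()
    meter.update(0.5, n=4)  # mean loss 0.5 over a batch of 4
    meter.update(1.0, n=2)  # mean loss 1.0 over a batch of 2
    print(meter.val, meter.avg)  # 1.0 0.666...  (sum=4.0, count=6)
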
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from _init_paths import __init_path
7 |
8 | __init_path()
9 |
10 | import os
11 |
12 | import torch
13 | import torch.utils
14 |
15 | from logger import Logger
16 | from datasets.dataset_factory import get_dataset
17 | from models.model import create_model, load_model, save_model
18 | from trains.train_factory import train_factory
19 |
20 | from opts import opts
21 |
22 |
23 | def main(opt):
24 |     # Completely reproducible results are not guaranteed across PyTorch releases,
25 |     # individual commits, or different platforms. Furthermore, results may not be reproducible
26 |     # between CPU and GPU executions, even when using identical seeds.
27 | # We can use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA):
28 | torch.manual_seed(opt.seed)
29 |
30 |     # Setting torch.backends.cudnn.benchmark = True makes the program spend a little extra time at startup
31 |     # searching for the fastest convolution implementation for each conv layer, which then speeds up training.
32 |     # This applies when the network structure is fixed and the input shape (batch size, image size, channels)
33 |     # does not vary, which covers most setups; if the configuration keeps changing, cuDNN re-tunes repeatedly and wastes time.
34 | torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
35 |
36 | Dataset = get_dataset(opt.dataset, opt.task)
37 |
38 | opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
39 |
40 | logger = Logger(opt)
41 |
42 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
43 | opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
44 |
45 | print('Creating model...')
46 |
47 | model = create_model(opt.model_name)
48 |
49 | optimizer = torch.optim.Adam(model.parameters(), opt.lr)
50 |
51 | start_epoch = 0
52 |
53 | if opt.load_model != '':
54 | model, optimizer, start_epoch = load_model(
55 | model, opt.load_model, optimizer, opt.resume, opt.lr, opt.lr_step)
56 |
57 | Trainer = train_factory[opt.task]
58 |
59 | trainer = Trainer(opt, model, optimizer)
60 | trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
61 |
62 | print("Setting up data...")
63 | val_loader = torch.utils.data.DataLoader(
64 | Dataset(opt, 'val'),
65 | batch_size=1,
66 | shuffle=False,
67 | num_workers=1,
68 | pin_memory=True
69 | )
70 |
71 | if opt.test:
72 | _, preds = trainer.val(0, val_loader)
73 | # run evaluation code
74 | val_loader.dataset.run_eval(preds, opt.save_dir)
75 | return
76 |
77 | train_loader = torch.utils.data.DataLoader(
78 | Dataset(opt, 'train'),
79 | batch_size=opt.batch_size,
80 | shuffle=True,
81 | num_workers=opt.num_workers,
82 | pin_memory=True,
83 | drop_last=True
84 | )
85 |
86 | print("Starting training")
87 | best = 1e10
88 |
89 | for epoch in range(start_epoch + 1, opt.num_epochs + 1):
90 | mark = epoch if opt.save_all else 'last'
91 | log_dict_train, _ = trainer.train(epoch, train_loader)
92 | logger.write('epoch: {} |'.format(epoch))
93 |
94 | for k, v in log_dict_train.items():
95 | logger.scalar_summary('train_{}'.format(k), v, epoch)
96 | logger.write('{} {:8f} | '.format(k, v))
97 | if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
98 | save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(mark)),
99 | epoch, model, optimizer)
100 | with torch.no_grad():
101 | log_dict_val, preds = trainer.val(epoch, val_loader)
102 | for k, v in log_dict_val.items():
103 | logger.scalar_summary('val_{}'.format(k), v, epoch)
104 | logger.write('{} {:8f} | '.format(k, v))
105 | if log_dict_val[opt.metric] < best:
106 | best = log_dict_val[opt.metric]
107 | save_model(os.path.join(opt.save_dir, 'model_best.pth'),
108 | epoch, model)
109 | else:
110 | save_model(os.path.join(opt.save_dir, 'model_last.pth'),
111 | epoch, model, optimizer)
112 |
113 | logger.write('\n')
114 | if epoch in opt.lr_step:
115 | save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
116 | epoch, model, optimizer)
117 | lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
118 | print('Drop LR to', lr)
119 | for param_group in optimizer.param_groups:
120 | param_group['lr'] = lr
121 | logger.close()
122 |
123 |
124 | if __name__ == '__main__':
125 | opt = opts().parse()
126 | main(opt)
127 |
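For reference, training can also be driven programmatically instead of via the shell; a sketch under the assumption that src/lib is on sys.path and the dataset is present under --data_root (flags mirror the CLI, and --gpus -1 selects the CPU):

    # e.g. appended to main.py or run from an interactive session
    opt = opts().parse('--gpus -1 --batch_size 2 --num_workers 0 --num_epochs 1'.split())
    main(opt)
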
--------------------------------------------------------------------------------