├── README.md ├── data ├── annotations.csv ├── labels_histo_train_fold_1.csv ├── labels_histo_train_fold_10.csv ├── labels_histo_train_fold_2.csv ├── labels_histo_train_fold_3.csv ├── labels_histo_train_fold_4.csv ├── labels_histo_train_fold_5.csv ├── labels_histo_train_fold_6.csv ├── labels_histo_train_fold_7.csv ├── labels_histo_train_fold_8.csv ├── labels_histo_train_fold_9.csv ├── labels_histo_valid_fold_1.csv ├── labels_histo_valid_fold_10.csv ├── labels_histo_valid_fold_2.csv ├── labels_histo_valid_fold_3.csv ├── labels_histo_valid_fold_4.csv ├── labels_histo_valid_fold_5.csv ├── labels_histo_valid_fold_6.csv ├── labels_histo_valid_fold_7.csv ├── labels_histo_valid_fold_8.csv └── labels_histo_valid_fold_9.csv ├── dataset ├── colon_cancer_dataset.py └── custom_dataset.py ├── eval.py ├── main.py ├── models ├── attention_models.py ├── classifier.py └── feature_extractors.py ├── train.py ├── utils ├── dataset_utils.py └── utils.py └── zoom_in ├── core_utils.py ├── layers.py ├── regularizer.py └── zoom_in.py /README.md: -------------------------------------------------------------------------------- 1 | This is the official repository for "[Efficient Classification of Very Large Images with Tiny Objects](https://arxiv.org/abs/2106.02694)". 2 | 3 | # Overview 4 | An increasing number of applications in computer vision, specially, in medical imaging and remote sensing, become challenging when the goal is to classify very large images with tiny informative objects. 5 | Specifically, these classification tasks face two key challenges: $i$) the size of the input image is usually in the order of mega- or giga-pixels, however, existing deep architectures do not easily operate on such big images due to memory constraints, consequently, we seek a memory-efficient method to process these images; and $ii$) only a very small fraction of the input images are informative of the label of interest, resulting in low region of interest (ROI) to image ratio. 6 | However, most of the current convolutional neural networks (CNNs) are designed for image classification datasets that have relatively large ROIs and small image sizes (sub-megapixel). 7 | Existing approaches have addressed these two challenges in isolation. 8 | We present an end-to-end CNN model termed Zoom-In network that leverages hierarchical attention sampling for classification of large images with tiny objects using a single GPU. 9 | We evaluate our method on four large-image histopathology, road-scene and satellite imaging datasets, and one gigapixel pathology dataset. 10 | Experimental results show that our model achieves higher accuracy than existing methods while requiring less memory resources. 11 | 12 | ## Major Dependencies 13 | 14 | pytorch-gpu == 1.6.0 15 | 16 | torchvision == 0.7.0 17 | 18 | ## Training 19 | training on colon cancer dataset: 20 | 21 | python main.py /path_to_your_dataset/ /path_to_your_output/ --mode 10CrossValidation --model_name yourModelName 22 | 23 | ## Evaluation 24 | 25 | python main.py /path_to_your_dataset/ /path_to_your_model_directory/ --mode Evaluation --model_name yourModelName 26 | 27 | ## Acknowledgemennt 28 | This work was supported by NIH (R44-HL140794), DARPA (FA8650-18-2-7832-P00009-12) and ONR (N00014-18-1-2871-P00002-3). 
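## Data Format
Each cross-validation split under data/ (labels_histo_train_fold_k.csv and labels_histo_valid_fold_k.csv, with k from 1 to 10) lists one image per row in filename,label form, where the label is the binary image-level class. For example, the first rows of labels_histo_train_fold_1.csv are:

img11.bmp,0

img12.bmp,0

img13.bmp,1

annotations.csv stores the corresponding labels for img1.bmp through img100.bmp as a single comma-separated row, from which the per-fold files are drawn.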
29 | 30 | ## Research 31 | 32 | If you would like to cite our work, 33 | 34 | @inproceedings{kong2022efficient, 35 | title={Efficient Classification of Very Large Images with Tiny Objects}, 36 | author={Kong, Fanjie and Henao, Ricardo}, 37 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 38 | pages={2384--2394}, 39 | year={2022} 40 | } 41 | 42 | ##### Thanks to the following repositories inpired our work: 43 | - https://github.com/idiap/attention-sampling.git 44 | -------------------------------------------------------------------------------- /data/annotations.csv: -------------------------------------------------------------------------------- 1 | 0,0,0,0,1,1,1,1,1,1,0,0,1,1,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,1,1,1,0,0,0,1,0,1,0,1,1,1,1,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,1,1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,1,1,1,1,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0 2 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_1.csv: -------------------------------------------------------------------------------- 1 | img11.bmp,0 2 | img12.bmp,0 3 | img13.bmp,1 4 | img14.bmp,1 5 | img15.bmp,0 6 | img16.bmp,0 7 | img17.bmp,1 8 | img18.bmp,1 9 | img19.bmp,1 10 | img20.bmp,1 11 | img21.bmp,1 12 | img22.bmp,1 13 | img23.bmp,0 14 | img24.bmp,0 15 | img25.bmp,0 16 | img26.bmp,0 17 | img27.bmp,0 18 | img28.bmp,0 19 | img29.bmp,0 20 | img30.bmp,1 21 | img31.bmp,1 22 | img32.bmp,1 23 | img33.bmp,0 24 | img34.bmp,0 25 | img35.bmp,0 26 | img36.bmp,1 27 | img37.bmp,0 28 | img38.bmp,1 29 | img39.bmp,0 30 | img40.bmp,1 31 | img41.bmp,1 32 | img42.bmp,1 33 | img43.bmp,1 34 | img44.bmp,0 35 | img45.bmp,0 36 | img46.bmp,1 37 | img47.bmp,1 38 | img48.bmp,1 39 | img49.bmp,1 40 | img50.bmp,1 41 | img51.bmp,1 42 | img52.bmp,1 43 | img53.bmp,1 44 | img54.bmp,1 45 | img55.bmp,1 46 | img56.bmp,0 47 | img57.bmp,0 48 | img58.bmp,0 49 | img59.bmp,0 50 | img60.bmp,1 51 | img61.bmp,1 52 | img62.bmp,0 53 | img63.bmp,0 54 | img64.bmp,0 55 | img65.bmp,1 56 | img66.bmp,1 57 | img67.bmp,1 58 | img68.bmp,1 59 | img69.bmp,1 60 | img70.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_10.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img31.bmp,1 32 | img32.bmp,1 33 | img33.bmp,0 34 | img34.bmp,0 35 | img35.bmp,0 36 | img36.bmp,1 37 | img37.bmp,0 38 | img38.bmp,1 39 | img39.bmp,0 40 | img40.bmp,1 41 | img41.bmp,1 42 | img42.bmp,1 43 | img43.bmp,1 44 | 
img44.bmp,0 45 | img45.bmp,0 46 | img46.bmp,1 47 | img47.bmp,1 48 | img48.bmp,1 49 | img49.bmp,1 50 | img50.bmp,1 51 | img51.bmp,1 52 | img52.bmp,1 53 | img53.bmp,1 54 | img54.bmp,1 55 | img55.bmp,1 56 | img56.bmp,0 57 | img57.bmp,0 58 | img58.bmp,0 59 | img59.bmp,0 60 | img60.bmp,1 61 | img61.bmp,1 62 | img62.bmp,0 63 | img63.bmp,0 64 | img64.bmp,0 65 | img65.bmp,1 66 | img66.bmp,1 67 | img67.bmp,1 68 | img68.bmp,1 69 | img69.bmp,1 70 | img70.bmp,1 71 | img71.bmp,1 72 | img72.bmp,1 73 | img73.bmp,0 74 | img74.bmp,0 75 | img75.bmp,0 76 | img76.bmp,0 77 | img77.bmp,1 78 | img78.bmp,1 79 | img79.bmp,1 80 | img80.bmp,1 81 | img81.bmp,0 82 | img82.bmp,0 83 | img83.bmp,0 84 | img84.bmp,0 85 | img85.bmp,1 86 | img86.bmp,1 87 | img87.bmp,1 88 | img88.bmp,1 89 | img89.bmp,0 90 | img90.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_2.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img21.bmp,1 12 | img22.bmp,1 13 | img23.bmp,0 14 | img24.bmp,0 15 | img25.bmp,0 16 | img26.bmp,0 17 | img27.bmp,0 18 | img28.bmp,0 19 | img29.bmp,0 20 | img30.bmp,1 21 | img31.bmp,1 22 | img32.bmp,1 23 | img33.bmp,0 24 | img34.bmp,0 25 | img35.bmp,0 26 | img36.bmp,1 27 | img37.bmp,0 28 | img38.bmp,1 29 | img39.bmp,0 30 | img40.bmp,1 31 | img41.bmp,1 32 | img42.bmp,1 33 | img43.bmp,1 34 | img44.bmp,0 35 | img45.bmp,0 36 | img46.bmp,1 37 | img47.bmp,1 38 | img48.bmp,1 39 | img49.bmp,1 40 | img50.bmp,1 41 | img51.bmp,1 42 | img52.bmp,1 43 | img53.bmp,1 44 | img54.bmp,1 45 | img55.bmp,1 46 | img56.bmp,0 47 | img57.bmp,0 48 | img58.bmp,0 49 | img59.bmp,0 50 | img60.bmp,1 51 | img61.bmp,1 52 | img62.bmp,0 53 | img63.bmp,0 54 | img64.bmp,0 55 | img65.bmp,1 56 | img66.bmp,1 57 | img67.bmp,1 58 | img68.bmp,1 59 | img69.bmp,1 60 | img70.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_3.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img31.bmp,1 22 | img32.bmp,1 23 | img33.bmp,0 24 | img34.bmp,0 25 | img35.bmp,0 26 | img36.bmp,1 27 | img37.bmp,0 28 | img38.bmp,1 29 | img39.bmp,0 30 | img40.bmp,1 31 | img41.bmp,1 32 | img42.bmp,1 33 | img43.bmp,1 34 | img44.bmp,0 35 | img45.bmp,0 36 | img46.bmp,1 37 | img47.bmp,1 38 | img48.bmp,1 39 | img49.bmp,1 40 | img50.bmp,1 41 | img51.bmp,1 42 | img52.bmp,1 43 | img53.bmp,1 44 | img54.bmp,1 45 | img55.bmp,1 46 | img56.bmp,0 47 | img57.bmp,0 48 | img58.bmp,0 49 | img59.bmp,0 50 | img60.bmp,1 51 | 
img61.bmp,1 52 | img62.bmp,0 53 | img63.bmp,0 54 | img64.bmp,0 55 | img65.bmp,1 56 | img66.bmp,1 57 | img67.bmp,1 58 | img68.bmp,1 59 | img69.bmp,1 60 | img70.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_4.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img41.bmp,1 32 | img42.bmp,1 33 | img43.bmp,1 34 | img44.bmp,0 35 | img45.bmp,0 36 | img46.bmp,1 37 | img47.bmp,1 38 | img48.bmp,1 39 | img49.bmp,1 40 | img50.bmp,1 41 | img51.bmp,1 42 | img52.bmp,1 43 | img53.bmp,1 44 | img54.bmp,1 45 | img55.bmp,1 46 | img56.bmp,0 47 | img57.bmp,0 48 | img58.bmp,0 49 | img59.bmp,0 50 | img60.bmp,1 51 | img61.bmp,1 52 | img62.bmp,0 53 | img63.bmp,0 54 | img64.bmp,0 55 | img65.bmp,1 56 | img66.bmp,1 57 | img67.bmp,1 58 | img68.bmp,1 59 | img69.bmp,1 60 | img70.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_5.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img31.bmp,1 32 | img32.bmp,1 33 | img33.bmp,0 34 | img34.bmp,0 35 | img35.bmp,0 36 | img36.bmp,1 37 | img37.bmp,0 38 | img38.bmp,1 39 | img39.bmp,0 40 | img40.bmp,1 41 | img51.bmp,1 42 | img52.bmp,1 43 | img53.bmp,1 44 | img54.bmp,1 45 | img55.bmp,1 46 | img56.bmp,0 47 | img57.bmp,0 48 | img58.bmp,0 49 | img59.bmp,0 50 | img60.bmp,1 51 | img61.bmp,1 52 | img62.bmp,0 53 | img63.bmp,0 54 | img64.bmp,0 55 | img65.bmp,1 56 | img66.bmp,1 57 | img67.bmp,1 58 | 
img68.bmp,1 59 | img69.bmp,1 60 | img70.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_6.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img31.bmp,1 32 | img32.bmp,1 33 | img33.bmp,0 34 | img34.bmp,0 35 | img35.bmp,0 36 | img36.bmp,1 37 | img37.bmp,0 38 | img38.bmp,1 39 | img39.bmp,0 40 | img40.bmp,1 41 | img41.bmp,1 42 | img42.bmp,1 43 | img43.bmp,1 44 | img44.bmp,0 45 | img45.bmp,0 46 | img46.bmp,1 47 | img47.bmp,1 48 | img48.bmp,1 49 | img49.bmp,1 50 | img50.bmp,1 51 | img61.bmp,1 52 | img62.bmp,0 53 | img63.bmp,0 54 | img64.bmp,0 55 | img65.bmp,1 56 | img66.bmp,1 57 | img67.bmp,1 58 | img68.bmp,1 59 | img69.bmp,1 60 | img70.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_7.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img31.bmp,1 32 | img32.bmp,1 33 | img33.bmp,0 34 | img34.bmp,0 35 | img35.bmp,0 36 | img36.bmp,1 37 | img37.bmp,0 38 | img38.bmp,1 39 | img39.bmp,0 40 | img40.bmp,1 41 | img41.bmp,1 42 | img42.bmp,1 43 | img43.bmp,1 44 | img44.bmp,0 45 | img45.bmp,0 46 | img46.bmp,1 47 | img47.bmp,1 48 | img48.bmp,1 49 | img49.bmp,1 50 | img50.bmp,1 51 | img51.bmp,1 52 | img52.bmp,1 53 | img53.bmp,1 54 | img54.bmp,1 55 | img55.bmp,1 56 | img56.bmp,0 57 | img57.bmp,0 58 | img58.bmp,0 59 | img59.bmp,0 60 | img60.bmp,1 61 | img71.bmp,1 62 | img72.bmp,1 63 | img73.bmp,0 64 | img74.bmp,0 65 | 
img75.bmp,0 66 | img76.bmp,0 67 | img77.bmp,1 68 | img78.bmp,1 69 | img79.bmp,1 70 | img80.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_8.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img31.bmp,1 32 | img32.bmp,1 33 | img33.bmp,0 34 | img34.bmp,0 35 | img35.bmp,0 36 | img36.bmp,1 37 | img37.bmp,0 38 | img38.bmp,1 39 | img39.bmp,0 40 | img40.bmp,1 41 | img41.bmp,1 42 | img42.bmp,1 43 | img43.bmp,1 44 | img44.bmp,0 45 | img45.bmp,0 46 | img46.bmp,1 47 | img47.bmp,1 48 | img48.bmp,1 49 | img49.bmp,1 50 | img50.bmp,1 51 | img51.bmp,1 52 | img52.bmp,1 53 | img53.bmp,1 54 | img54.bmp,1 55 | img55.bmp,1 56 | img56.bmp,0 57 | img57.bmp,0 58 | img58.bmp,0 59 | img59.bmp,0 60 | img60.bmp,1 61 | img61.bmp,1 62 | img62.bmp,0 63 | img63.bmp,0 64 | img64.bmp,0 65 | img65.bmp,1 66 | img66.bmp,1 67 | img67.bmp,1 68 | img68.bmp,1 69 | img69.bmp,1 70 | img70.bmp,1 71 | img81.bmp,0 72 | img82.bmp,0 73 | img83.bmp,0 74 | img84.bmp,0 75 | img85.bmp,1 76 | img86.bmp,1 77 | img87.bmp,1 78 | img88.bmp,1 79 | img89.bmp,0 80 | img90.bmp,0 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_train_fold_9.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | img11.bmp,0 12 | img12.bmp,0 13 | img13.bmp,1 14 | img14.bmp,1 15 | img15.bmp,0 16 | img16.bmp,0 17 | img17.bmp,1 18 | img18.bmp,1 19 | img19.bmp,1 20 | img20.bmp,1 21 | img21.bmp,1 22 | img22.bmp,1 23 | img23.bmp,0 24 | img24.bmp,0 25 | img25.bmp,0 26 | img26.bmp,0 27 | img27.bmp,0 28 | img28.bmp,0 29 | img29.bmp,0 30 | img30.bmp,1 31 | img31.bmp,1 32 | img32.bmp,1 33 | img33.bmp,0 34 | img34.bmp,0 35 | img35.bmp,0 36 | img36.bmp,1 37 | img37.bmp,0 38 | img38.bmp,1 39 | img39.bmp,0 40 | img40.bmp,1 41 | img41.bmp,1 42 | img42.bmp,1 43 | img43.bmp,1 44 | img44.bmp,0 45 | img45.bmp,0 46 | img46.bmp,1 47 | img47.bmp,1 48 | img48.bmp,1 49 | img49.bmp,1 50 | img50.bmp,1 51 | img51.bmp,1 52 | img52.bmp,1 53 | img53.bmp,1 54 | img54.bmp,1 55 | img55.bmp,1 56 | img56.bmp,0 57 | img57.bmp,0 58 | img58.bmp,0 59 | img59.bmp,0 60 | img60.bmp,1 61 | img61.bmp,1 62 | img62.bmp,0 63 | img63.bmp,0 64 | img64.bmp,0 65 | img65.bmp,1 66 | img66.bmp,1 67 | img67.bmp,1 68 | img68.bmp,1 69 | img69.bmp,1 70 | img70.bmp,1 71 | img71.bmp,1 72 | 
img72.bmp,1 73 | img73.bmp,0 74 | img74.bmp,0 75 | img75.bmp,0 76 | img76.bmp,0 77 | img77.bmp,1 78 | img78.bmp,1 79 | img79.bmp,1 80 | img80.bmp,1 81 | img91.bmp,0 82 | img92.bmp,0 83 | img93.bmp,0 84 | img94.bmp,0 85 | img95.bmp,0 86 | img96.bmp,0 87 | img97.bmp,0 88 | img98.bmp,0 89 | img99.bmp,0 90 | img100.bmp,0 91 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_1.csv: -------------------------------------------------------------------------------- 1 | img1.bmp,0 2 | img2.bmp,0 3 | img3.bmp,0 4 | img4.bmp,0 5 | img5.bmp,1 6 | img6.bmp,1 7 | img7.bmp,1 8 | img8.bmp,1 9 | img9.bmp,1 10 | img10.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_10.csv: -------------------------------------------------------------------------------- 1 | img91.bmp,0 2 | img92.bmp,0 3 | img93.bmp,0 4 | img94.bmp,0 5 | img95.bmp,0 6 | img96.bmp,0 7 | img97.bmp,0 8 | img98.bmp,0 9 | img99.bmp,0 10 | img100.bmp,0 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_2.csv: -------------------------------------------------------------------------------- 1 | img11.bmp,0 2 | img12.bmp,0 3 | img13.bmp,1 4 | img14.bmp,1 5 | img15.bmp,0 6 | img16.bmp,0 7 | img17.bmp,1 8 | img18.bmp,1 9 | img19.bmp,1 10 | img20.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_3.csv: -------------------------------------------------------------------------------- 1 | img21.bmp,1 2 | img22.bmp,1 3 | img23.bmp,0 4 | img24.bmp,0 5 | img25.bmp,0 6 | img26.bmp,0 7 | img27.bmp,0 8 | img28.bmp,0 9 | img29.bmp,0 10 | img30.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_4.csv: -------------------------------------------------------------------------------- 1 | img31.bmp,1 2 | img32.bmp,1 3 | img33.bmp,0 4 | img34.bmp,0 5 | img35.bmp,0 6 | img36.bmp,1 7 | img37.bmp,0 8 | img38.bmp,1 9 | img39.bmp,0 10 | img40.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_5.csv: -------------------------------------------------------------------------------- 1 | img41.bmp,1 2 | img42.bmp,1 3 | img43.bmp,1 4 | img44.bmp,0 5 | img45.bmp,0 6 | img46.bmp,1 7 | img47.bmp,1 8 | img48.bmp,1 9 | img49.bmp,1 10 | img50.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_6.csv: -------------------------------------------------------------------------------- 1 | img51.bmp,1 2 | img52.bmp,1 3 | img53.bmp,1 4 | img54.bmp,1 5 | img55.bmp,1 6 | img56.bmp,0 7 | img57.bmp,0 8 | img58.bmp,0 9 | img59.bmp,0 10 | img60.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_7.csv: -------------------------------------------------------------------------------- 1 | img61.bmp,1 2 | img62.bmp,0 3 | img63.bmp,0 4 | img64.bmp,0 5 | img65.bmp,1 6 | img66.bmp,1 7 | img67.bmp,1 8 | img68.bmp,1 9 | img69.bmp,1 10 | img70.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_8.csv: -------------------------------------------------------------------------------- 1 | img71.bmp,1 2 | img72.bmp,1 3 | img73.bmp,0 4 | img74.bmp,0 5 | img75.bmp,0 6 | img76.bmp,0 7 | img77.bmp,1 8 | 
img78.bmp,1 9 | img79.bmp,1 10 | img80.bmp,1 11 | -------------------------------------------------------------------------------- /data/labels_histo_valid_fold_9.csv: -------------------------------------------------------------------------------- 1 | img81.bmp,0 2 | img82.bmp,0 3 | img83.bmp,0 4 | img84.bmp,0 5 | img85.bmp,1 6 | img86.bmp,1 7 | img87.bmp,1 8 | img88.bmp,1 9 | img89.bmp,0 10 | img90.bmp,0 11 | -------------------------------------------------------------------------------- /dataset/colon_cancer_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms, datasets 2 | from cv2 import imread, imwrite, resize, INTER_LINEAR 3 | from skimage.io import imsave 4 | from skimage.draw import polygon as ski_polygon 5 | from sklearn.metrics import roc_auc_score 6 | from skimage.filters import threshold_otsu 7 | from itertools import zip_longest 8 | from math import floor 9 | import openslide as ops 10 | import warnings as ws 11 | import numpy as np 12 | import random 13 | import torch 14 | import h5py 15 | import math 16 | import cv2 17 | import os 18 | import time 19 | import csv 20 | import logging 21 | import os 22 | import xml.etree.ElementTree as Xml 23 | from collections import OrderedDict, defaultdict, namedtuple 24 | from typing import Sequence, Any, Tuple 25 | 26 | import fnmatch 27 | import logging 28 | import os 29 | from collections import namedtuple 30 | from typing import Dict 31 | 32 | from PIL import Image 33 | from PIL import ImageDraw 34 | from progress.bar import IncrementalBar 35 | 36 | import openslide 37 | from PIL import Image 38 | class OpenHisto: 39 | ''' 40 | Read one histopathology image 41 | ''' 42 | def __init__(self, directory, 43 | data_transform=transforms.Compose([ 44 | transforms.ToPILImage(), 45 | transforms.ToTensor(), 46 | ])): 47 | self.img = np.asarray(imread(directory)) 48 | self.data_transform = data_transform 49 | 50 | def read_region(self, pos, level, size): 51 | ''' 52 | x, y are the cardinality axis: x for column and y for rows 53 | :param x: x location 54 | :param y: y location 55 | :param level: the view we are looking right now 56 | :param size: size of patch 57 | :return: 58 | ''' 59 | x, y = pos 60 | factor = np.around(2 ** level) 61 | #print("in read_region", self.img.shape) 62 | patch = self.img[max(y, 0) : min(self.img.shape[0]-1, y+int(size[1]*factor)), 63 | max(x, 0) : min(self.img.shape[1]-1, x+int(size[0]*factor)), :] 64 | #print("in read_region patch 1", patch.shape) 65 | if patch.shape != size: 66 | patch = np.pad(patch, ((max(0-y, 0), max(min(y+int(size[1]*factor)+1-self.img.shape[0], int(size[1]*factor)), 0)), 67 | (max(0-x, 0), max(min(x+int(size[0]*factor)+1-self.img.shape[1], int(size[0]*factor)), 0)), (0, 0))) 68 | if level != 0: 69 | patch = resize(patch, (int(patch.shape[0]//factor), int(patch.shape[1]//factor)), INTER_LINEAR) 70 | #print("in read_region", patch.shape) 71 | return patch 72 | 73 | def extract_patches(self, x, y, level=0, size=(50, 50), show_mode='channel_first'): 74 | ''' 75 | Read patches from one gigapixel image 76 | Should read the patches from the center 77 | :return: 78 | ''' 79 | ''' 80 | return img.read_region((int(x), int(y)), level, size) 81 | ''' 82 | 83 | this_patch = np.asarray(self.read_region((int(x), int(y)), level, size))[:, :, :3].astype(np.uint8) 84 | if self.data_transform is not None: 85 | this_patch = self.data_transform(this_patch) 86 | return this_patch 87 | 88 | 89 | def get_mask(self, level=0): 90 | assert 
False, "get pixel-level annotationsn is not implemented" 91 | return 92 | 93 | def get_patches(self, x, y, level=0, size=(0, 0), show_mode='channel_first'): 94 | ''' 95 | :param x: a list of x-axis for patches. (batch_dim, [num_of_patches for one image]) 96 | :param y: a list of y-axis for patches. (batch_dim, [num_of_patches for one image]) 97 | :param level: 98 | :param size: 99 | :return: 100 | ''' 101 | return self.extract_patches(x, y, level=level, size=size, show_mode=show_mode) 102 | 103 | def get_size(self): 104 | return self.img.shape 105 | 106 | class HistoPatchwiseReader: 107 | ''' 108 | Reading NCAM Images by Fanjie Kong 109 | ''' 110 | # Now we only support batch size 1 just for simplicity 111 | def __init__(self, directory, annotation_name='annotations.csv', batch_size=1, train=False): 112 | self.batch_size = batch_size 113 | self.directory = directory 114 | self.annotations = os.path.join(directory, annotation_name) 115 | self.cur_batch = None 116 | self.__cur_name = None 117 | self.train = train 118 | self.img_gen = self.generator() 119 | self.data_list = self._get_data() 120 | 121 | 122 | def __len__(self): return int(len(self.data_list)//self.batch_size) 123 | 124 | def _get_data(self): 125 | 126 | data_list = [] 127 | annotations = np.loadtxt(self.annotations, delimiter='\n', dtype=str) 128 | for e_ann in annotations: 129 | img_dir, label = e_ann.split(',') 130 | data_list.append((img_dir, int(label))) 131 | return data_list 132 | 133 | def _batcher(self, inputs, batch_size=1, fillvalue=None): 134 | inputs = iter(inputs) 135 | args = [iter(inputs)] * batch_size 136 | return zip_longest(fillvalue=fillvalue, *args) 137 | 138 | def _get_cur_name(self): 139 | return self.__cur_name 140 | 141 | def generator(self): 142 | if self.train: 143 | np.random.shuffle(self.data_list) 144 | dataGenerator = self._batcher(self.data_list, self.batch_size) 145 | while dataGenerator: 146 | try: 147 | self.cur_batch = next(dataGenerator) 148 | self.__cur_name = self.cur_batch[0][0] 149 | self.cur_batch = self._DataPathtoData(self.cur_batch) 150 | yield self.cur_batch 151 | except StopIteration: 152 | print('Finished this epoch') 153 | break 154 | 155 | def _DataPathtoData(self, batch): 156 | ''' 157 | :param batch: 158 | :return: reading data from the file path in batch 159 | ''' 160 | new_batch = [] 161 | #print(os.path.join(self.directory, batch[0][0])) 162 | for e_b in batch: 163 | if e_b is not None: 164 | new_batch.append((OpenHisto(os.path.join(self.directory, e_b[0])), e_b[1])) 165 | return new_batch -------------------------------------------------------------------------------- /dataset/custom_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms, datasets 2 | from cv2 import imread, imwrite, resize, INTER_LINEAR 3 | from skimage.io import imsave 4 | from skimage.draw import polygon as ski_polygon 5 | from sklearn.metrics import roc_auc_score 6 | from skimage.filters import threshold_otsu 7 | from itertools import zip_longest 8 | from math import floor 9 | import openslide as ops 10 | import warnings as ws 11 | import numpy as np 12 | import random 13 | import torch 14 | import h5py 15 | import math 16 | import cv2 17 | import os 18 | import time 19 | import csv 20 | import logging 21 | import os 22 | import xml.etree.ElementTree as Xml 23 | from collections import OrderedDict, defaultdict, namedtuple 24 | from typing import Sequence, Any, Tuple 25 | 26 | import fnmatch 27 | import logging 28 | import os 
29 | from collections import namedtuple 30 | from typing import Dict 31 | 32 | from PIL import Image 33 | from PIL import ImageDraw 34 | from progress.bar import IncrementalBar 35 | 36 | import openslide 37 | from PIL import Image 38 | class OpenImage: 39 | ''' 40 | Read one histopathology image 41 | ''' 42 | def __init__(self, directory, 43 | data_transform=transforms.Compose([ 44 | transforms.ToPILImage(), 45 | transforms.ToTensor(), 46 | ])): 47 | self.img = np.asarray(imread(directory)) 48 | self.data_transform = data_transform 49 | 50 | def read_region(self, pos, level, size): 51 | ''' 52 | x, y are the cardinality axis: x for column and y for rows 53 | :param x: x location 54 | :param y: y location 55 | :param level: the view we are looking right now 56 | :param size: size of patch 57 | :return: 58 | ''' 59 | x, y = pos 60 | factor = np.around(2 ** level) 61 | #print("in read_region", self.img.shape) 62 | patch = self.img[max(y, 0) : min(self.img.shape[0]-1, y+int(size[1]*factor)), 63 | max(x, 0) : min(self.img.shape[1]-1, x+int(size[0]*factor)), :] 64 | #print("in read_region patch 1", patch.shape) 65 | if patch.shape != size: 66 | patch = np.pad(patch, ((max(0-y, 0), max(min(y+int(size[1]*factor)+1-self.img.shape[0], int(size[1]*factor)), 0)), 67 | (max(0-x, 0), max(min(x+int(size[0]*factor)+1-self.img.shape[1], int(size[0]*factor)), 0)), (0, 0))) 68 | if level != 0: 69 | patch = resize(patch, (int(patch.shape[0]//factor), int(patch.shape[1]//factor)), INTER_LINEAR) 70 | #print("in read_region", patch.shape) 71 | return patch 72 | 73 | def extract_patches(self, x, y, level=0, size=(50, 50), show_mode='channel_first'): 74 | ''' 75 | Read patches from one gigapixel image 76 | Should read the patches from the center 77 | :return: 78 | ''' 79 | ''' 80 | return img.read_region((int(x), int(y)), level, size) 81 | ''' 82 | 83 | this_patch = np.asarray(self.read_region((int(x), int(y)), level, size))[:, :, :3].astype(np.uint8) 84 | if self.data_transform is not None: 85 | this_patch = self.data_transform(this_patch) 86 | return this_patch 87 | 88 | 89 | def get_mask(self, level=0): 90 | assert False, "get pixel-level annotationsn is not implemented" 91 | return 92 | 93 | def get_patches(self, x, y, level=0, size=(0, 0), show_mode='channel_first'): 94 | ''' 95 | :param x: a list of x-axis for patches. (batch_dim, [num_of_patches for one image]) 96 | :param y: a list of y-axis for patches. 
(batch_dim, [num_of_patches for one image]) 97 | :param level: 98 | :param size: 99 | :return: 100 | ''' 101 | return self.extract_patches(x, y, level=level, size=size, show_mode=show_mode) 102 | 103 | def get_size(self): 104 | return self.img.shape 105 | 106 | class OpenGigapixel: 107 | ''' 108 | Read one Camelyon16 reader 109 | ''' 110 | def __init__(self, directory, 111 | data_transform=transforms.Compose([ 112 | transforms.ToPILImage(), 113 | transforms.ToTensor(), 114 | ])): 115 | self.img = ops.OpenSlide(os.path.join(directory)) 116 | self.cur_name = directory 117 | self.data_transform = data_transform 118 | 119 | def extract_patches(self, x, y, level=0, size=(50, 50), show_mode='channel_first'): 120 | ''' 121 | Read patches from one gigapixel image 122 | Should read the patches from the center 123 | :return: 124 | ''' 125 | ''' 126 | return img.read_region((int(x), int(y)), level, size) 127 | ''' 128 | 129 | this_patch = np.asarray(self.img.read_region((int(x), int(y)), level, size))[:, :, :3].astype(np.uint8) 130 | if self.data_transform is not None: 131 | this_patch = self.data_transform(this_patch) 132 | return this_patch 133 | 134 | 135 | def get_mask(self, level=0): 136 | assert False, "get pixel-level annotationsn is not implemented" 137 | return 138 | 139 | def get_patches(self, x, y, level=0, size=(0, 0), show_mode='channel_first'): 140 | ''' 141 | :param x: a list of x-axis for patches. (batch_dim, [num_of_patches for one image]) 142 | :param y: a list of y-axis for patches. (batch_dim, [num_of_patches for one image]) 143 | :param level: 144 | :param size: 145 | :return: 146 | ''' 147 | return self.extract_patches(x, y, level=level, size=size, show_mode=show_mode) 148 | 149 | def get_size(self): 150 | return self.img.dimensions 151 | 152 | def _get_cur_name(self): 153 | return self.cur_name 154 | 155 | def find_roi_normal(self, rgb_image): 156 | # self.mask = cv2.cvtColor(self.mask, cv2.CV_32SC1) 157 | hsv = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2HSV) 158 | # [20, 20, 20] 159 | lower_red = np.array([30, 30, 30]) 160 | # [255, 255, 255] 161 | upper_red = np.array([200, 200, 200]) 162 | mask = cv2.inRange(hsv, lower_red, upper_red) 163 | res = cv2.bitwise_and(rgb_image, rgb_image, mask=mask) 164 | 165 | # (50, 50) 166 | close_kernel = np.ones((50, 50), dtype=np.uint8) 167 | #close_kernel_tmp = np.ones((30, 30), dtype=np.uint8) 168 | image_close = Image.fromarray(cv2.morphologyEx(np.array(mask), cv2.MORPH_CLOSE, close_kernel)) 169 | # (30, 30) 170 | open_kernel = np.ones((30, 30), dtype=np.uint8) 171 | image_open = Image.fromarray(cv2.morphologyEx(np.array(image_close), cv2.MORPH_OPEN, open_kernel)) 172 | return image_open 173 | 174 | def get_preprocessed_locations_v2(self, frame_size, mask_level=4): 175 | ''' 176 | :param frame_size: size of one frame 177 | :param mask_level: the shrink level of the binary mask 178 | :return: 179 | ''' 180 | # Binary the map 181 | # Get the location 182 | name = self._get_cur_name() 183 | binary_mask_dir = os.path.join(self._get_cur_name().replace('.tif', 'frame_size_'+str(frame_size[0])+'level_'+str(mask_level)+'_binary.png')) 184 | 185 | if os.path.isfile(binary_mask_dir): 186 | b_mask = np.asarray(imread(binary_mask_dir)) 187 | else: 188 | img = self.img 189 | lowimg = np.asarray(img.read_region((int(0), int(0)), mask_level, img.level_dimensions[mask_level])) 190 | mask = np.asarray(self.find_roi_normal(lowimg)) 191 | imsave(binary_mask_dir, mask.astype(np.int), check_contrast=False) 192 | b_mask = mask 193 | size_img = self.get_size() 194 
| x = np.arange(0, size_img[0], frame_size[0]) 195 | y = np.arange(0, size_img[1], frame_size[1]) 196 | xx, yy = np.meshgrid(x, y) 197 | loc_list = [(xx.flatten()[i], yy.flatten()[i]) for i in range(len(xx.flatten()[:]))] 198 | locs_array = np.array(loc_list) 199 | # filter 200 | filtered_locs_array = [] 201 | for each_loc in locs_array: 202 | mapped_loc = np.floor(each_loc // 2**mask_level) 203 | ROI_exist = b_mask[int(mapped_loc[1]): int(mapped_loc[1])+int(np.floor(frame_size[0] // 2**mask_level)) - 1, 204 | int(mapped_loc[0]): int(mapped_loc[0]) + int(np.floor(frame_size[1] // 2**mask_level)) - 1].sum() > 0 205 | if ROI_exist: 206 | filtered_locs_array.append(each_loc) 207 | filtered_locs_array = np.array(filtered_locs_array) 208 | return filtered_locs_array 209 | 210 | class CustomDataReader: 211 | ''' 212 | 213 | ''' 214 | # Now we only support batch size 1 just for simplicity 215 | def __init__(self, directory, annotation_name='annotations.csv', batch_size=1, train=False, reader_type=None): 216 | self.batch_size = batch_size 217 | self.directory = directory 218 | self.annotations = os.path.join(directory, annotation_name) 219 | self.cur_batch = None 220 | self.__cur_name = None 221 | self.train = train 222 | self.reader = OpenImage if reader_type is None else OpenGigapixel 223 | self.img_gen = self.generator() 224 | self.data_list = self._get_data() 225 | 226 | 227 | def __len__(self): return int(len(self.data_list)//self.batch_size) 228 | 229 | def _get_data(self): 230 | 231 | data_list = [] 232 | annotations = np.loadtxt(self.annotations, delimiter='\n', dtype=str) 233 | for e_ann in annotations: 234 | img_dir, label = e_ann.split(',') 235 | data_list.append((img_dir, int(label))) 236 | return data_list 237 | 238 | def _batcher(self, inputs, batch_size=1, fillvalue=None): 239 | inputs = iter(inputs) 240 | args = [iter(inputs)] * batch_size 241 | return zip_longest(fillvalue=fillvalue, *args) 242 | 243 | def _get_cur_name(self): 244 | return self.__cur_name 245 | 246 | def generator(self): 247 | if self.train: 248 | np.random.shuffle(self.data_list) 249 | dataGenerator = self._batcher(self.data_list, self.batch_size) 250 | while dataGenerator: 251 | try: 252 | self.cur_batch = next(dataGenerator) 253 | self.__cur_name = self.cur_batch[0][0] 254 | self.cur_batch = self._DataPathtoData(self.cur_batch) 255 | yield self.cur_batch 256 | except StopIteration: 257 | print('Finished this epoch') 258 | break 259 | 260 | def _DataPathtoData(self, batch): 261 | ''' 262 | :param batch: 263 | :return: reading data from the file path in batch 264 | ''' 265 | new_batch = [] 266 | #print(os.path.join(self.directory, batch[0][0])) 267 | for e_b in batch: 268 | if e_b is not None: 269 | new_batch.append((OpenHisto(os.path.join(self.directory, e_b[0])), e_b[1])) 270 | return new_batch -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.utils import save_checkpoint, load_checkpoint, save_metrics, load_metrics, get_activation 4 | 5 | def eval(): 6 | sample_count = 0.0 7 | running_acc = 0.0 8 | running_loss = 0.0 9 | valid_running_acc = 0.0 10 | valid_running_loss = 0.0 11 | global_step = 0 12 | best_ACC = 0 13 | load_checkpoint(os.path.join(saved_path, 'con_histo_cross_'+str(i+1)+'_test1.pt'), model) 14 | model.eval() 15 | with torch.no_grad(): 16 | 17 | # validation loop 18 | print("Begin Validation") 19 | valid_sample_count = 0.0 
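# Note: each val_batch yielded by valid_loader.generator() in the loop below is a list of
# (image reader, label) tuples; the model receives the reader objects themselves and
# extracts the patches it needs from them, rather than being fed pre-loaded tensors.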
20 | valid_running_loss = 0.0 21 | valid_running_acc = 0.0 22 | for val_batch in tqdm.tqdm(valid_loader.generator()): 23 | val_input_reader = [e_reader for e_reader, _ in val_batch] 24 | labels = torch.stack([torch.Tensor([e_label]) for _, e_label in val_batch]).squeeze(1) 25 | labels = labels.type(torch.LongTensor).to(device) 26 | output, _ = model(val_input_reader) 27 | loss = criterion(output, labels) 28 | valid_running_loss += loss.item() 29 | _, pred_labels = output.data.cpu().topk(1, dim=1) 30 | valid_running_acc += torch.sum(pred_labels.t().squeeze() == labels.data.cpu().squeeze()).item() 31 | valid_sample_count += labels.shape[0] 32 | # evaluation 33 | average_valid_loss = valid_running_loss / len(valid_loader) 34 | average_valid_acc = valid_running_acc / valid_sample_count 35 | 36 | # print progress 37 | print('Valid Loss: {:.4f}, Valid Acc: {:.4f}' 38 | .format(average_valid_loss, average_valid_acc)) 39 | print("Max memory used: {} Mb ".format(torch.cuda.memory_allocated(device=0)/ (1000 * 1000))) 40 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 2022-03-12 3 | @author: Fanjie Kong 4 | ''' 5 | #!/usr/bin/env python 6 | # -*- coding: utf-8 -*- 7 | import os 8 | import time 9 | import torch 10 | import shutil 11 | import argparse 12 | import warnings 13 | import numpy as np 14 | import pandas as pd 15 | from math import floor, sqrt 16 | from skimage.io import imsave 17 | import torch.nn as nn 18 | from torch.nn import functional as F 19 | from zoom_in.zoom_in import ZoomInNet 20 | from utils.dataset_utils import load_annotations 21 | from utils.utils import get_optim 22 | from utils.utils import save_checkpoint, load_checkpoint, save_metrics, load_metrics, get_activation 23 | 24 | from dataset.custom_dataset import CustomDataReader 25 | from dataset.colon_cancer_dataset import HistoPatchwiseReader 26 | 27 | from models.attention_models import Attention, AttentionOnAttention 28 | from models.feature_extractors import FeatureExtractor 29 | from models.classifier import Classifier 30 | 31 | from train import train 32 | from eval import eval 33 | 34 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 35 | 36 | 37 | def get_models(args): 38 | 39 | if args.models_set == 'base': 40 | attention = Attention() 41 | aoa = AttentionOnAttention() 42 | fe = FeatureExtractor() 43 | clf = Classifier(args.num_classes) 44 | 45 | elif args.models_set == 'ResNet': 46 | attention = Attention() 47 | aoa = AttentionOnAttention() 48 | fe = FeatureExtractor() 49 | clf = Classifier(args.num_classes) 50 | 51 | model = ZoomInNet( 52 | attention, aoa, fe, clf, 53 | batch_size=args.batch_size, 54 | tile_size= args.tile_size, 55 | patch_size= args.patch_size, 56 | stage_1_level = np.log2(10), 57 | stage_2_level = np.log2(5), 58 | original_image_level = 0, 59 | num_classes=args.num_classes, 60 | reg_strength = args.reg_strength, 61 | n_patches = args.n_patches, 62 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 63 | return attention, aoa, fe, clf, model 64 | 65 | def load_configuratiton(model, args): 66 | config = dict() 67 | config['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 68 | config['model_name'] =args.model_name 69 | config['optimizer'] = get_optim(model, args) 70 | config['scheduler'] = torch.optim.lr_scheduler.StepLR(config['optimizer'], step_size=args.lr_decay_steps, gamma=args.lr_decay_ratio) 71 | 
config['criterion'] = nn.CrossEntropyLoss() 72 | config['num_epochs'] = args.num_epochs 73 | config['eval_every_epochs'] = args.valid_step 74 | config['save_path'] = args.output 75 | config['clip'] = args.clipnorm 76 | config['contrastive_learning'] = args.contrastive_learning 77 | config['apply_con_epochs'] = args.apply_con_epochs 78 | config['device'] = "cuda" if args.use_gpu else "cpu" 79 | return config 80 | 81 | def get_custom_dataset(args): 82 | 83 | train_at, valid_at, test_at = load_annotations(args.dataset) 84 | train_loader = CustomDataReader(dataset_dir, annotation_name=train_at, 85 | batch_size=args.batch_size, train=True) 86 | valid_loader = CustomDataReader(dataset_dir, annotation_name=valid_at, 87 | batch_size=1, train=False) 88 | test_loader = CustomDataReader(dataset_dir, annotation_name=test_at, 89 | batch_size=1, train=False) 90 | return train_loader, valid_loader, test_loader 91 | 92 | def main(argv): 93 | parser = argparse.ArgumentParser( 94 | description=("Train a model with attention sampling on the " 95 | "artificial mnist dataset") 96 | ) 97 | parser.add_argument( 98 | "dataset", 99 | help="The directory that contains the dataset ", 100 | ) 101 | parser.add_argument( 102 | "output", 103 | help="An output directory" 104 | ) 105 | 106 | parser.add_argument( 107 | "--TenCrossValidation", default=True, action='store_true') 108 | 109 | parser.add_argument( 110 | "--optimizer", 111 | choices=["sgd", "adam"], 112 | default="adam", 113 | help="Choose the optimizer for Q1" 114 | ) 115 | parser.add_argument( 116 | "--models_set", 117 | choices=["base", "ResNet"], 118 | default="base", 119 | help="Choose the different architecture of the zoom-in modules" 120 | ) 121 | parser.add_argument( 122 | "--mode", 123 | choices=["10CrossValidation", "Training", "Evaluation"], 124 | default="10CrossValidation", 125 | help="working mode of the program" 126 | ) 127 | parser.add_argument( 128 | "--lr", 129 | type=float, 130 | default=0.0001, 131 | help="Set the optimizer's learning rate" 132 | ) 133 | parser.add_argument( 134 | "--weight_decay", 135 | type=float, 136 | default=1e-5, 137 | help="Set the optimizer's weight_decay" 138 | ) 139 | parser.add_argument( 140 | "--lr_decay_steps", 141 | type=int, 142 | default=30, 143 | help="Set the decay steps of learning rate scheduler" 144 | ) 145 | parser.add_argument( 146 | "--lr_decay_ratio", 147 | type=float, 148 | default=0.9, 149 | help="Set the decay ratio of learning rate scheduler" 150 | ) 151 | parser.add_argument( 152 | "--clipnorm", 153 | type=float, 154 | default=5.0, 155 | help=("Clip the gradient norm to avoid exploding gradients " 156 | "towards the end of convergence") 157 | ) 158 | 159 | parser.add_argument( 160 | "--tile_size", 161 | type=lambda x: tuple(int(xi) for xi in x.split("x")), 162 | default="250x250", 163 | help="Choose the size of the first-level tile" 164 | ) 165 | 166 | parser.add_argument( 167 | "--patch_size", 168 | type=lambda x: tuple(int(xi) for xi in x.split("x")), 169 | default="27x27", 170 | help="Choose the size of the patch(sub-tile) to extract from the high resolution" 171 | ) 172 | 173 | parser.add_argument( 174 | "--n_patches", 175 | type=int, 176 | default=10, 177 | help="How many patches to sample" 178 | ) 179 | parser.add_argument( 180 | "--regularizer_strength", 181 | type=float, 182 | default=0.01, 183 | help="How strong should the regularization be for the attention" 184 | ) 185 | 186 | parser.add_argument( 187 | "--batch_size", 188 | type=int, 189 | default=5, 190 | help="Choose the batch size 
for SGD" 191 | ) 192 | parser.add_argument( 193 | "--num_epochs", 194 | type=int, 195 | default=100, 196 | help="How many epochs to train" 197 | ) 198 | parser.add_argument( 199 | "--scale", 200 | type=float, 201 | default=0.2, 202 | help="Scale for downsampling images" 203 | ) 204 | parser.add_argument( 205 | "--contrastive_learning", 206 | default=False, 207 | action='store_true') 208 | 209 | parser.add_argument( 210 | "--apply_con_epochs", 211 | type=int, 212 | default=10, 213 | help="when to apply contrastive learning" 214 | ) 215 | 216 | parser.add_argument( 217 | "--num_classes", 218 | default=2, 219 | help="Number of classes of targets" 220 | ) 221 | parser.add_argument( 222 | "--model_name", 223 | default=None, 224 | help="the name of the training/testing model" 225 | ) 226 | parser.add_argument( 227 | "--overlap", 228 | type=int, 229 | default=1, 230 | help="Overlap for eliminating boundary effects" 231 | ) 232 | parser.add_argument( 233 | "--reg_strength", 234 | type=float, 235 | default=0, 236 | help="whether regularize the coefficients for dynamic sampling" 237 | ) 238 | parser.add_argument( 239 | "--valid_step", 240 | default=1, 241 | help="how often we do the validation " 242 | ) 243 | parser.add_argument( 244 | "--load_model", 245 | default='best', 246 | help="load which model" 247 | ) 248 | 249 | parser.add_argument( 250 | "--use_gpu", 251 | default=True, 252 | action='store_true') 253 | 254 | parser.add_argument( 255 | "--gpu", 256 | default=0 257 | ) 258 | parser.add_argument( 259 | "--num_works", 260 | default=4 261 | ) 262 | args = parser.parse_args(argv) 263 | if args.gpu is not None: 264 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 265 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 266 | if not os.path.exists(args.output): 267 | os.makedirs(args.output) 268 | 269 | if args.mode == "10CrossValidation": 270 | 271 | total_cross = 10 272 | ten_cross_acc_list = [] 273 | for i in range(total_cross): 274 | train_loader = HistoPatchwiseReader(args.dataset, annotation_name='labels_histo_train_fold_'+str(i+1)+'.csv', 275 | batch_size=args.batch_size, train=True) 276 | valid_loader = HistoPatchwiseReader(args.dataset, annotation_name='labels_histo_valid_fold_'+str(i+1)+'.csv', 277 | batch_size=1, train=False) 278 | _, _, _, _, model = get_models(args) 279 | configs = load_configuratiton(model, args) 280 | train_loss, train_acc, valid_loss, valid_acc = train(model, train_loader, valid_loader, configs) 281 | ten_cross_acc_list.append(valid_acc) 282 | print("Final 10 Cross Validation Results: ") 283 | print(ten_cross_acc_list) 284 | print("Average Accuracy: ", np.mean(ten_cross_acc_list)) 285 | elif args.mode == "Training": 286 | train_loader, valid_loader, test_loader = get_custom_dataset(args) 287 | _, _, _, _, model = get_models(args) 288 | train(model, train_loader, valid_loader, configs, args) 289 | eval(model, test_loader, configs, args) 290 | elif args.mode == "Evaluation": 291 | test_loader = HistoPatchwiseReader(args.dataset, annotation_name='labels_histo_valid_fold_'+str(i+1)+'.csv', 292 | batch_size=1, train=False) 293 | _, _, _, _, model = get_models(args) 294 | load_checkpoint() 295 | eval(model, test_loader, configs, args) 296 | 297 | if __name__ == "__main__": 298 | main(None) 299 | 300 | -------------------------------------------------------------------------------- /models/attention_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import 
functional as F 5 | 6 | padding = 0 7 | 8 | 9 | 10 | 11 | class Attention(torch.nn.Sequential): 12 | def __init__(self, no_softmax=False): 13 | super(Attention, self).__init__() 14 | # self.padding = 2 15 | self.no_softmax = no_softmax 16 | self.layers = \ 17 | torch.nn.Sequential( 18 | torch.nn.Conv2d(3, 8, kernel_size=3, padding=(1, 1)), 19 | torch.nn.ReLU(inplace=True), 20 | torch.nn.Conv2d(8, 8, kernel_size=3, padding=(1, 1)), 21 | torch.nn.ReLU(inplace=True), 22 | torch.nn.Conv2d(8, 1, kernel_size=3, padding=(1, 1)), 23 | ) 24 | self.layers2 = \ 25 | torch.nn.Sequential( 26 | torch.nn.Conv2d(1, 8, kernel_size=3, padding=(1, 1)), 27 | torch.nn.ReLU(inplace=True), 28 | torch.nn.Conv2d(8, 8, kernel_size=3, padding=(1, 1)), 29 | torch.nn.ReLU(inplace=True), 30 | torch.nn.Conv2d(8, 1, kernel_size=3, padding=(1, 1)), 31 | ) 32 | 33 | def forward(self, x): 34 | if x.shape[1] == 3: 35 | x = self.layers(x) 36 | elif x.shape[1] == 1: 37 | x = self.layers2(x) 38 | xs = x.shape 39 | if self.no_softmax: 40 | return x.view(xs) 41 | else: 42 | x = F.softmax(x.view(x.shape[0], -1), dim=1) 43 | return x.view(xs) 44 | 45 | 46 | 47 | 48 | class AttentionOnAttention(torch.nn.Sequential): 49 | def __init__(self): 50 | super(AttentionOnAttention, self).__init__() 51 | self.layers1 = \ 52 | torch.nn.Sequential( 53 | torch.nn.Conv2d(3, 8, kernel_size=3, padding=(1, 1)), 54 | torch.nn.ReLU(), 55 | torch.nn.Conv2d(8, 8, kernel_size=3, padding=(1, 1)), 56 | torch.nn.ReLU(), 57 | torch.nn.Conv2d(8, 1, kernel_size=3, padding=(1, 1)), 58 | ) 59 | 60 | 61 | def forward(self, x): 62 | 63 | x = self.layers1(x) 64 | x = F.adaptive_avg_pool2d(x, (1, 1)) 65 | x = x.squeeze(1).squeeze(1).squeeze(1) 66 | return x 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class Classifier(torch.nn.Sequential): 6 | def __init__(self, num_classes=10): 7 | super(Classifier, self).__init__() 8 | self.layers = \ 9 | torch.nn.Sequential( 10 | torch.nn.Linear(48, num_classes) 11 | ) 12 | 13 | def forward(self, x): 14 | x = self.layers(x) 15 | prediction = F.softmax(x) 16 | return x, prediction -------------------------------------------------------------------------------- /models/feature_extractors.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | padding = 0 7 | 8 | def conv_layer(in_channels, out_channels, kernel, strides, padding=1): 9 | return nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=strides, padding_mode="zeros", bias=False, 10 | padding=padding) 11 | 12 | 13 | def batch_norm(filters): 14 | return nn.BatchNorm2d(filters) 15 | 16 | 17 | def relu(): 18 | return nn.ReLU() 19 | 20 | class Block(nn.Module): 21 | 22 | def __init__(self, in_channels, out_channels, stride, kernel_size, short): 23 | super(Block, self).__init__() 24 | 25 | self.short = short 26 | self.bn1 = batch_norm(in_channels) 27 | self.relu1 = relu() 28 | self.conv1 = conv_layer(in_channels, out_channels, 1, stride, padding=0) 29 | 30 | self.conv2 = conv_layer(in_channels, out_channels, kernel_size, stride) 31 | self.bn2 = batch_norm(out_channels) 32 | self.relu2 = relu() 33 | self.conv3 = conv_layer(out_channels, out_channels, kernel_size, 1) 34 | 35 | def 
forward(self, x): 36 | x = self.bn1(x) 37 | x = self.relu1(x) 38 | 39 | x_short = x 40 | if self.short: 41 | x_short = self.conv1(x) 42 | 43 | x = self.conv2(x) 44 | x = self.bn2(x) 45 | x = self.relu2(x) 46 | x = self.conv3(x) 47 | 48 | out = x + x_short 49 | return out 50 | 51 | class FeatureModelResNet(nn.Module): 52 | 53 | def __init__(self, in_channels, strides=[1, 2, 2, 2], filters=[48, 48, 48, 48]): 54 | super(FeatureModelResNet, self).__init__() 55 | 56 | stride_prev = strides.pop(0) 57 | filters_prev = filters.pop(0) 58 | 59 | self.conv1 = conv_layer(in_channels, filters_prev, 3, stride_prev) 60 | 61 | module_list = nn.ModuleList() 62 | for s, f in zip(strides, filters): 63 | module_list.append(Block(filters_prev, f, s, 3, s != 1 or f != filters_prev)) 64 | 65 | stride_prev = s 66 | filters_prev = f 67 | 68 | self.module_list = nn.Sequential(*module_list) 69 | 70 | self.bn1 = batch_norm(filters_prev) 71 | self.relu1 = relu() 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.module_list(out) 76 | out = self.bn1(out) 77 | out = self.relu1(out) 78 | out = F.adaptive_avg_pool2d(out, (1, 1)) 79 | out = out.view(out.shape[0], out.shape[1]) 80 | out = F.normalize(out, p=2, dim=-1) 81 | return out 82 | 83 | 84 | 85 | 86 | class FeatureExtractor(torch.nn.Sequential): 87 | def __init__(self): 88 | super(FeatureExtractor, self).__init__() 89 | self.layers = \ 90 | torch.nn.Sequential( 91 | torch.nn.Conv2d(3, 36, kernel_size=5, padding=(2, 2)), 92 | torch.nn.ReLU(inplace=True), 93 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 94 | torch.nn.Conv2d(36, 48, kernel_size=5, padding=(2, 2)), 95 | torch.nn.ReLU(inplace=True), 96 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 97 | ) 98 | self.layers2 = \ 99 | torch.nn.Sequential( 100 | torch.nn.Conv2d(1, 36, kernel_size=5, padding=(2, 2)), 101 | torch.nn.ReLU(inplace=True), 102 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 103 | torch.nn.Conv2d(36, 48, kernel_size=3, padding=(1, 1)), 104 | torch.nn.ReLU(inplace=True), 105 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 106 | ) 107 | def forward(self, x): 108 | if x.shape[1] == 3: 109 | x = self.layers(x) 110 | elif x.shape[1] == 1: 111 | x = self.layers2(x) 112 | 113 | x = F.adaptive_avg_pool2d(x, (1, 1)) 114 | x = x.view(x.shape[0], -1) 115 | x = F.normalize(x, p=2, dim=1) # force the features have the same L2 norm. 
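# x is now a (batch, 48) matrix of unit-L2-norm feature vectors, one row per input patch.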
116 | return x 117 | 118 | 119 | class Classifier(torch.nn.Sequential): 120 | def __init__(self, output_num=2): 121 | super(Classifier, self).__init__() 122 | self.layers = \ 123 | torch.nn.Sequential( 124 | torch.nn.Linear(48, 512), 125 | torch.nn.ReLU(inplace=True), 126 | torch.nn.Dropout(0.5), 127 | torch.nn.Linear(512, 512), 128 | torch.nn.ReLU(inplace=True), 129 | torch.nn.Dropout(0.5), 130 | torch.nn.Linear(512, output_num), 131 | ) 132 | 133 | def forward(self, x): 134 | x = self.layers(x) 135 | prediction = F.softmax(x) 136 | return x, prediction 137 | 138 | 139 | 140 | class Bottleneck(nn.Module): 141 | expansion = 4 142 | 143 | def __init__(self, inplanes, planes, stride=1, downsample=None, 144 | norm=lambda x: nn.InstanceNorm2d(x, affine=True), 145 | kernel_size=1, dropout=False, compensate=False): 146 | super(Bottleneck, self).__init__() 147 | mid_planes = planes 148 | if compensate: 149 | mid_planes = int(2.5 * planes) 150 | 151 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 152 | self.bn1 = norm(planes) 153 | 154 | self.conv2 = nn.Conv2d(planes, mid_planes, kernel_size=kernel_size, 155 | stride=stride, 156 | padding=(kernel_size - 1) // 2, # used to be 0 157 | bias=False) # changed padding from (kernel_size - 1) // 2 158 | self.bn2 = norm(mid_planes) 159 | 160 | self.drop = nn.Dropout2d(p=0.2) if dropout else lambda x: x 161 | self.conv3 = nn.Conv2d(mid_planes, planes * 4, kernel_size=1, bias=False) 162 | self.bn3 = norm(planes * 4) 163 | self.relu = nn.ReLU(inplace=True) 164 | self.downsample = downsample 165 | self.stride = stride 166 | 167 | def forward(self, x, **kwargs): 168 | residual = x 169 | 170 | out = self.conv1(x) 171 | out = self.bn1(out) 172 | out = self.relu(out) 173 | 174 | out = self.conv2(out) 175 | out = self.bn2(out) 176 | out = self.drop(out) 177 | out = self.relu(out) 178 | 179 | out = self.conv3(out) 180 | out = self.bn3(out) 181 | 182 | if self.downsample is not None: 183 | residual = self.downsample(x) 184 | 185 | if residual.size(-1) != out.size(-1): 186 | diff = residual.size(-1) - out.size(-1) 187 | residual = residual[:, :, :-diff, :-diff] 188 | 189 | out += residual 190 | out = self.relu(out) 191 | 192 | return out 193 | 194 | 195 | class BagNetEncoder(nn.Module): 196 | norms = { 197 | 'in_aff': lambda x: nn.InstanceNorm2d(x, affine=True), 198 | 'in': nn.InstanceNorm2d, 199 | 'bn': nn.BatchNorm2d 200 | } 201 | 202 | def __init__(self, block, layers, strides=[1, 2, 2, 2], wide_factor=1, 203 | kernel3=[0, 0, 0, 0], dropout=False, inp_channels=3, 204 | compensate=False, norm='in_aff'): 205 | self.planes = int(64 * wide_factor) 206 | self.inplanes = int(64 * wide_factor) 207 | self.compensate = compensate 208 | self.dropout = dropout 209 | self.norm = norm 210 | super(BagNetEncoder, self).__init__() 211 | self.conv1 = nn.Conv2d(inp_channels, self.planes, kernel_size=1, 212 | stride=1, padding=0, bias=False) 213 | self.conv2 = nn.Conv2d(self.planes, self.planes, kernel_size=3, 214 | stride=1, padding=0, bias=False) 215 | self.bn1 = self.norms[self.norm](self.planes) 216 | self.relu = nn.ReLU(inplace=True) 217 | self.layer1 = self._make_layer(block, self.planes, layers[0], 218 | stride=strides[0], kernel3=kernel3[0], 219 | prefix='layer1') 220 | self.layer2 = self._make_layer(block, self.planes * 2, layers[1], 221 | stride=strides[1], kernel3=kernel3[1], 222 | prefix='layer2') 223 | self.layer3 = self._make_layer(block, self.planes * 4, layers[2], 224 | stride=strides[2], kernel3=kernel3[2], 225 | prefix='layer3') 226 | 
self.layer4 = self._make_layer(block, self.planes * 8, layers[3], 227 | stride=strides[3], kernel3=kernel3[3], 228 | prefix='layer4') 229 | self.block = block 230 | 231 | for m in self.modules(): 232 | if isinstance(m, nn.Conv2d): 233 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 234 | m.weight.data.normal_(0, math.sqrt(2. / n)) 235 | elif isinstance(m, nn.InstanceNorm2d) and self.norm == 'in_aff': 236 | m.weight.data.fill_(1) 237 | m.bias.data.zero_() 238 | elif isinstance(m, nn.BatchNorm2d): 239 | m.weight.data.fill_(1) 240 | m.bias.data.zero_() 241 | 242 | def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, 243 | prefix=''): 244 | downsample = None 245 | if stride != 1 or self.inplanes != planes * block.expansion: 246 | downsample = nn.Sequential( 247 | nn.Conv2d(self.inplanes, planes * block.expansion, 248 | kernel_size=1, stride=stride, bias=False), 249 | self.norms[self.norm](planes * block.expansion), 250 | ) 251 | 252 | layers = [] 253 | kernel = 1 if kernel3 == 0 else 3 254 | layers.append(block(self.inplanes, planes, stride, downsample, 255 | kernel_size=kernel, dropout=self.dropout, 256 | norm=self.norms[self.norm], 257 | compensate=(self.compensate and kernel == 1))) 258 | self.inplanes = planes * block.expansion 259 | for i in range(1, blocks): 260 | kernel = 1 if kernel3 <= i else 3 261 | layers.append(block(self.inplanes, planes, kernel_size=kernel, 262 | norm=self.norms[self.norm], 263 | compensate=(self.compensate and kernel == 1))) 264 | 265 | return nn.Sequential(*layers) 266 | 267 | def forward(self, x): 268 | x = self.conv1(x) 269 | x = self.conv2(x) 270 | x = self.bn1(x) 271 | x = self.relu(x) 272 | 273 | x = self.layer1(x) 274 | x = self.layer2(x) 275 | x = self.layer3(x) 276 | x = self.layer4(x) 277 | 278 | return x 279 | 280 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import time 4 | import torch 5 | import numpy as np 6 | from utils.utils import save_checkpoint, load_checkpoint, save_metrics, load_metrics, get_activation 7 | 8 | 9 | def train(model, train_loader, valid_loader, config): 10 | 11 | optimizer = config['optimizer'] 12 | training_show_every = len(train_loader) * 0.1 13 | eval_every = len(train_loader) * config['eval_every_epochs'] 14 | device = config['device'] 15 | criterion = config['criterion'] 16 | num_epochs = config['num_epochs'] 17 | scheduler = config['scheduler'] 18 | global_step = 0 19 | best_ACC = 0.0 20 | train_acc_list = [] 21 | valid_acc_list = [] 22 | train_loss_list = [] 23 | valid_loss_list = [] 24 | global_steps_list = [] 25 | for epoch in range(config['num_epochs']): 26 | sample_count = 0.0 27 | running_acc = 0.0 28 | running_loss = 0.0 29 | model.train() 30 | for train_batch in tqdm.tqdm(train_loader.generator()): 31 | train_input_reader = [e_reader for e_reader, _ in train_batch] 32 | labels = torch.stack([torch.Tensor([e_label]) for _, e_label in train_batch]).squeeze(1) 33 | labels = labels.type(torch.LongTensor).to(device) 34 | if config['contrastive_learning'] and epoch >= config['apply_con_epochs'] and torch.sum(labels) > 0: 35 | 36 | pos_reader = np.array(train_input_reader)[labels.data.cpu().numpy() == 1] 37 | pos_output, pos_sparse_loss = model(pos_reader) 38 | 39 | con_output, con_sparse_loss = model._compute_constrastive_predictions(pos_reader) 40 | 41 | pos_loss = criterion(pos_output, labels[labels == 1]) + pos_sparse_loss 42 | 
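# The contrastive branch scores each positive bag twice: pos_loss above uses the true labels on
# patches sampled from high-attention regions, while con_loss (computed next via
# model._compute_constrastive_predictions) targets an all-zero label on patches drawn from the
# low-attention regions of the same positive bags. Together with neg_loss for the negative bags,
# the three terms are combined a few lines below as
#     loss = p * pos_loss + (1 - p) * (q * neg_loss + (1 - q) * con_loss)
# where p is the fraction of positive readers in the batch and q is the fraction of negative
# readers among the positive-plus-negative readers.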
con_loss = criterion(con_output, torch.zeros_like(labels[labels == 1]).to(device)) + con_sparse_loss 43 | 44 | pos_weight = len(pos_reader)/len(train_input_reader) 45 | if torch.sum(1-labels) == 0: 46 | neg_loss = 0 47 | neg_weight = 0 48 | neg_output = None 49 | output = pos_output 50 | labels = labels[labels == 1] 51 | else: 52 | neg_reader = np.array(train_input_reader)[labels.data.cpu().numpy() == 0] 53 | neg_output, neg_sparse_loss = model(neg_reader) 54 | neg_loss = criterion(neg_output, labels[labels == 0]) + neg_sparse_loss 55 | neg_weight = len(neg_reader)/(len(neg_reader) + len(pos_reader)) 56 | output = torch.cat([pos_output, neg_output], dim=0) 57 | labels = torch.cat([labels[labels == 1], labels[labels == 0]], dim=0) 58 | 59 | loss = pos_weight * pos_loss + (1-pos_weight)* (neg_weight * neg_loss + (1-neg_weight) *con_loss) 60 | 61 | else: 62 | output, sparse_loss = model(train_input_reader) 63 | loss = criterion(output, labels) + sparse_loss 64 | 65 | optimizer.zero_grad() 66 | loss.backward() 67 | torch.nn.utils.clip_grad_norm(model.parameters(), config['clip']) 68 | optimizer.step() 69 | 70 | # update running values 71 | _, pred_labels = output.data.cpu().topk(1, dim=1) 72 | 73 | running_acc += torch.sum(pred_labels.t().squeeze() == labels.data.cpu().squeeze()).item() 74 | sample_count += labels.shape[0] 75 | running_loss += loss.item() 76 | global_step += 1 77 | 78 | # training stats 79 | if global_step % training_show_every == 0: 80 | average_running_acc = running_acc / sample_count 81 | average_train_loss = running_loss / sample_count 82 | print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}' 83 | .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader), 84 | average_train_loss, average_running_acc)) 85 | # evaluation step 86 | if global_step % eval_every == 0: 87 | model.eval() 88 | with torch.no_grad(): 89 | # validation loop 90 | print("Begin Validation") 91 | valid_sample_count = 0.0 92 | valid_running_loss = 0.0 93 | valid_running_acc = 0.0 94 | used_time = 0.0 95 | for val_batch in tqdm.tqdm(valid_loader.generator()): 96 | val_input_reader = [e_reader for e_reader, _ in val_batch] 97 | labels = torch.stack([torch.Tensor([e_label]) for _, e_label in val_batch]).squeeze(1) 98 | labels = labels.type(torch.LongTensor).to(device) 99 | start_time = time.time() 100 | output, _ = model(val_input_reader) 101 | used_time += time.time() - start_time 102 | loss = criterion(output, labels) 103 | 104 | valid_running_loss += loss.item() 105 | _, pred_labels = output.data.cpu().topk(1, dim=1) 106 | valid_running_acc += torch.sum(pred_labels.t().squeeze() == labels.data.cpu().squeeze()).item() 107 | valid_sample_count += labels.shape[0] 108 | # evaluation 109 | 110 | average_train_loss = running_loss / eval_every 111 | average_train_acc = running_acc / sample_count 112 | average_valid_loss = valid_running_loss / len(valid_loader) 113 | average_valid_acc = valid_running_acc / valid_sample_count 114 | train_loss_list.append(average_train_loss) 115 | train_acc_list.append(average_train_acc) 116 | valid_loss_list.append(average_valid_loss) 117 | valid_acc_list.append(average_valid_acc) 118 | global_steps_list.append(global_step) 119 | 120 | # resetting running values 121 | running_loss = 0.0 122 | valid_running_loss = 0.0 123 | model.train() 124 | # print progress 125 | print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Valid Loss: {:.4f}, Valid Acc: {:.4f}' 126 | .format(epoch+1, num_epochs, global_step, 
num_epochs*len(train_loader), 127 | average_train_loss, average_train_acc, average_valid_loss, average_valid_acc)) 128 | print("Max memory used: {} Mb ".format(torch.cuda.memory_allocated(device=0)/ (1024 * 1024))) 129 | print("Average Time per sample {} sec".format(used_time/(valid_sample_count))) 130 | # checkpoint 131 | if best_ACC <= average_valid_acc: 132 | best_ACC = average_valid_acc 133 | save_checkpoint(os.path.join(config['save_path'], 134 | config['model_name'] + '.pt'), model, best_ACC) 135 | save_metrics(os.path.join(config['save_path'],config['model_name'] + '.pt'), 136 | train_loss_list,valid_loss_list, global_steps_list) 137 | scheduler.step() 138 | train_loss = train_loss_list[-1] 139 | train_acc = train_acc_list[-1] 140 | valid_loss = valid_loss_list[-1] 141 | valid_acc = best_ACC 142 | return train_loss, train_acc, valid_loss, valid_acc -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def load_annotations(data_dir): 5 | ''' 6 | load the path of annotations 7 | ''' 8 | at_dirs = np.load(os.path.join(data_dir, 'annotation_names.npy')) 9 | train_at, valid_at, test_at = at_dirs[0], at_dirs[1], at_dirs[2] 10 | return train_at, valid_at, test_at -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # Need kind of clean up my code 3 | import torch 4 | import os 5 | import time 6 | import torch 7 | import random 8 | import shutil 9 | import argparse 10 | import warnings 11 | import numpy as np 12 | import pandas as pd 13 | from math import floor, sqrt 14 | from skimage.io import imsave 15 | from torch.nn import functional as F 16 | from torch.distributions import Multinomial 17 | 18 | import six 19 | import torch 20 | import pandas as pd 21 | import torch.nn as nn 22 | import seaborn as sns 23 | from itertools import chain 24 | import matplotlib.pyplot as plt 25 | import torch.nn as nn 26 | from torch.nn import Linear, Identity 27 | import torch.optim as optim 28 | from sklearn.metrics import accuracy_score, classification_report, confusion_matrix 29 | from collections import Counter, OrderedDict 30 | # Save and Load Functions 31 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 32 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 33 | np.random.seed(1) 34 | torch.manual_seed(1) 35 | 36 | 37 | def get_optim(model, args): 38 | if args.optimizer == "adam": 39 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay) 40 | elif args.optimizer == 'sgd': 41 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 42 | else: 43 | raise NotImplementedError 44 | return optimizer 45 | 46 | 47 | def read_image(nargs): 48 | img_path, transform, device = nargs 49 | image = imread(img_path) 50 | if transform: 51 | image = transform(image) 52 | else: 53 | image = torch.tensor(image, dtype=torch.float, device=device) 54 | image = image.permute(2, 0, 1) 55 | image = image.unsqueeze(0) 56 | return image 57 | 58 | def image_path_processor_base(image_paths, source_path = '', transform=None, device="cuda"): 59 | img_list = [] 60 | for img_path in image_paths: 61 | image = 
imread(os.path.join(source_path, img_path)) 62 | if transform: 63 | image = transform(image) 64 | else: 65 | image = torch.tensor(image, dtype=torch.float, device=device) 66 | image = image.permute(2, 0, 1) 67 | image = image.unsqueeze(0) 68 | img_list.append(image) 69 | img = torch.stack(img_list) 70 | return img 71 | 72 | def image_path_processor_parallizer(image_paths, source_path = '', transform=None, device="cuda", num_of_threads=8): 73 | image_paths = [(os.path.join(source_path, e_ip), transform, device) for e_ip in image_paths] 74 | pool = multiprocessing.Pool(num_of_threads) 75 | img = pool.map(read_image, image_paths) 76 | pool.close() 77 | pool.join() 78 | return img 79 | 80 | def save_checkpoint(save_path, model, valid_loss): 81 | 82 | if save_path == None: 83 | return 84 | 85 | state_dict = {'model_state_dict': model.state_dict(), 86 | 'valid_loss': valid_loss} 87 | 88 | torch.save(state_dict, save_path) 89 | print(f'Model saved to ==> {save_path}') 90 | 91 | def load_checkpoint(load_path, model): 92 | 93 | if load_path==None: 94 | return 95 | 96 | state_dict = torch.load(load_path, map_location=device) 97 | print(f'Model loaded from <== {load_path}') 98 | 99 | model.load_state_dict(state_dict['model_state_dict']) 100 | return state_dict['valid_loss'] 101 | 102 | 103 | def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list): 104 | 105 | if save_path == None: 106 | return 107 | 108 | state_dict = {'train_loss_list': train_loss_list, 109 | 'valid_loss_list': valid_loss_list, 110 | 'global_steps_list': global_steps_list} 111 | 112 | torch.save(state_dict, save_path) 113 | print(f'Model saved to ==> {save_path}') 114 | 115 | 116 | def load_metrics(load_path): 117 | 118 | if load_path==None: 119 | return 120 | 121 | state_dict = torch.load(load_path, map_location=device) 122 | print(f'Model loaded from <== {load_path}') 123 | 124 | return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list'] 125 | 126 | 127 | def get_activation(name, activation): 128 | def hook(model, input, output): 129 | activation[name] = output.detach() 130 | return hook -------------------------------------------------------------------------------- /zoom_in/core_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def SamplingPatches(location, frame_location, source_image, sample_space, low_img_level, high_img_level, patch_size): 5 | ''' 6 | 7 | :param location: (x,y) 8 | :param source_image: 9 | :param sample_space: shape of attention, like [100, 100] 10 | :return: shape [bs, C, H, W] 11 | ''' 12 | # sampled location transform horizon expansion 13 | if len(location.shape) == 2: 14 | location = location[0] 15 | row = np.floor(location // sample_space[1]) 16 | col = location % sample_space[1] 17 | # location - axis switch! 
and linear map back 18 | x = col * (np.around(2**low_img_level)) + frame_location[0] -int(np.around(2**high_img_level)* patch_size[1]//2) 19 | y = row * (np.around(2**low_img_level)) + frame_location[1] -int(np.around(2**high_img_level)* patch_size[0]//2) 20 | patch_list = [] 21 | for idx in range(x.size): 22 | patch = source_image.extract_patches(x[idx], y[idx], level=high_img_level, size=patch_size, show_mode='channel_first') 23 | patch_list.append(patch) 24 | patch_list = torch.stack(patch_list) 25 | return patch_list 26 | 27 | -------------------------------------------------------------------------------- /zoom_in/layers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | def mean_forward(accum, sub): 6 | return accum + sub 7 | 8 | def max_forward(accum, sub): 9 | return torch.where(sub > accum, sub, accum) 10 | 11 | def logsumexp_forward(accum, sub): 12 | return accum + torch.exp(sub) 13 | 14 | def torch_exclusive_cumsum_in_dim_1_shape2(x): 15 | x = torch.cumsum(x, dim=1) 16 | x = torch.cat((torch.zeros((x.shape[0], 1)), x[:, :-1]), dim=1) 17 | return x 18 | 19 | 20 | def torch_exclusive_cumsum_in_dim_1_shape3(x): 21 | x = torch.cumsum(x, dim=1) 22 | x = torch.cat((torch.zeros((x.shape[0], 1, x.shape[2])), x[:, :-1, :]), dim=1) 23 | return x 24 | 25 | def to_tensor(x, dtype=torch.int32, device=None): 26 | """If x is a Tensor return it as is otherwise return a constant tensor of 27 | type dtype.""" 28 | device = torch.device('cpu') if device is None else device 29 | if torch.is_tensor(x): 30 | return x.to(device) 31 | 32 | return torch.tensor(x, dtype=dtype, device=device) 33 | 34 | 35 | def to_dtype(x, dtype): 36 | """Cast Tensor x to the dtype """ 37 | return x.type(dtype) 38 | 39 | 40 | to_float16 = partial(to_dtype, dtype=torch.float16) 41 | to_float32 = partial(to_dtype, dtype=torch.float32) 42 | to_float64 = partial(to_dtype, dtype=torch.float64) 43 | to_double = to_float64 44 | to_int8 = partial(to_dtype, dtype=torch.int8) 45 | to_int16 = partial(to_dtype, dtype=torch.int16) 46 | to_int32 = partial(to_dtype, dtype=torch.int32) 47 | to_int64 = partial(to_dtype, dtype=torch.int64) 48 | 49 | 50 | def expand_many(x, axes): 51 | """Call expand_dims many times on x once for each item in axes.""" 52 | for ax in axes: 53 | x = torch.unsqueeze(x, ax) 54 | return x 55 | 56 | class ExpectationWithoutReplacement(torch.autograd.Function): 57 | """ Custom pytorch layer for calculating the expectation of the sampled patches 58 | without replacement. 
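The forward pass combines the sampled patch features into a single estimate of the attention-weighted feature expectation, using the probability mass still available at each draw (1 - cumsum(attention)); the custom backward returns gradients with respect to `attention` and `features`, and None for the fixed `weights`.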
59 | """ 60 | 61 | @staticmethod 62 | def forward(ctx, weights, attention, features): 63 | # Reshape the passed weights and attention in feature compatible shapes 64 | axes = [-1] * (len(features.shape) - 2) 65 | wf = expand_many(weights, axes) 66 | af = expand_many(attention, axes) 67 | 68 | # Compute how much of the probablity mass was available for each sample 69 | pm = 1 - torch.cumsum(attention, axis=1) 70 | pmf = expand_many(pm, axes) 71 | 72 | # Compute the features 73 | Fa = af * features 74 | Fpm = pmf * features 75 | Fa_cumsum = torch.cumsum(Fa, axis=1) 76 | F_estimator = Fa_cumsum + Fpm 77 | 78 | F = torch.sum(wf * F_estimator, axis=1) 79 | 80 | ctx.save_for_backward(weights, attention, features, pm, pmf, Fa, Fpm, Fa_cumsum, F_estimator) 81 | 82 | return F 83 | 84 | @staticmethod 85 | def backward(ctx, grad_output): 86 | weights, attention, features, pm, pmf, Fa, Fpm, Fa_cumsum, F_estimator = ctx.saved_tensors 87 | device = weights.device 88 | 89 | axes = [-1] * (len(features.shape) - 2) 90 | wf = expand_many(weights, axes) 91 | af = expand_many(attention, axes) 92 | 93 | N = attention.shape[1] 94 | probs = attention / pm 95 | probsf = expand_many(probs, axes) 96 | grad = torch.unsqueeze(grad_output, 1) 97 | 98 | # Gradient wrt to the attention 99 | ga1 = F_estimator / probsf 100 | ga2 = ( 101 | torch.cumsum(features, axis=1) - 102 | expand_many(to_float32(torch.arange(0, N, device=device)), [0] + axes) * features 103 | ) 104 | ga = grad * (ga1 + ga2) 105 | ga = torch.sum(ga, axis=list(range(2, len(ga.shape)))) 106 | ga = ga * weights 107 | 108 | # Gradient wrt to the features 109 | gf = expand_many(to_float32(torch.arange(N-1, -1, -1, device=device)), [0] + axes) 110 | gf = pmf + gf * af 111 | gf = wf * gf 112 | gf = gf * grad 113 | 114 | return None, ga, gf 115 | 116 | 117 | class ExpectationWithReplacement(torch.autograd.Function): 118 | """ Custom pytorch layer for calculating the expectation of the sampled patches 119 | with replacement. 
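Here the estimate is simply the `weights`-weighted sum of the sampled features; the backward routes the incoming gradient to both `attention` (scaled by weights / attention) and `features`.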
120 | """ 121 | @staticmethod 122 | def forward(ctx, weights, attention, features): 123 | 124 | axes = [-1] * (len(features.shape) - 2) 125 | wf = expand_many(weights, axes) 126 | 127 | F = torch.sum(wf * features, dim=1) 128 | 129 | ctx.save_for_backward(weights, attention, features, F) 130 | return F 131 | 132 | @staticmethod 133 | def backward(ctx, grad_output): 134 | weights, attention, features, F = ctx.saved_tensors 135 | axes = [-1] * (len(features.shape) - 2) 136 | wf = expand_many(weights, axes) 137 | 138 | grad = torch.unsqueeze(grad_output, 1) 139 | 140 | # Gradient wrt to the attention 141 | ga = grad * features 142 | ga = torch.sum(ga, axis=list(range(2, len(ga.shape)))) 143 | ga = ga * weights / attention 144 | 145 | # Gradient wrt to the features 146 | gf = wf * grad 147 | 148 | return None, ga, gf -------------------------------------------------------------------------------- /zoom_in/regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def MultinomialRegularizer(x, strength, eps=1e-6): 4 | logx = torch.log(x + eps) 5 | return strength * torch.sum(x * logx) / float(x.shape[0]) 6 | -------------------------------------------------------------------------------- /zoom_in/zoom_in.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import torch.nn as nn 4 | from matplotlib import pyplot as plt 5 | # Need kind of clean up my code 6 | import os 7 | import time 8 | import torch 9 | import random 10 | import shutil 11 | import argparse 12 | import warnings 13 | import numpy as np 14 | import pandas as pd 15 | from cv2 import imread 16 | from math import floor, sqrt 17 | from skimage.io import imsave 18 | from itertools import zip_longest 19 | from torch.nn import functional as F 20 | from new_utils import SamplingPatchesV2 21 | import torch.distributions as dist 22 | from torch.distributions import Multinomial 23 | from utils import MultinomialRegularizer, NCAMPatchwiseReader 24 | from layers import ExpectationWithoutReplacement, ExpectationWithReplacement 25 | from sampling import _sample_without_replacement, _sample_with_replacement 26 | from networks import Attention, AttentionOnAttention, FeatureExtractor, Classifier 27 | 28 | 29 | def attention_inference_weights(attention): 30 | attention = (attention - attention.min())/attention.max() 31 | attention[attention < torch.quantile(attention, 0.3, dim=0, keepdim=True, interpolation='lower')] = 0 32 | return attention 33 | 34 | 35 | class ZoomInNet(nn.Module): 36 | ''' 37 | Zoom-in Net in a NN NutShell 38 | ''' 39 | def __init__(self, 40 | attention, aoa, fe, clf, 41 | batch_size=32, 42 | frame_size= (250, 250), 43 | patch_size= (50, 50), 44 | low_low_img_level = 2, 45 | low_img_level = 1, 46 | high_img_level = 0, 47 | num_classes=10, 48 | reg_strength = 1e-4, 49 | n_patches = 15, 50 | weights = None, 51 | contrasitve_learning=False, 52 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 53 | super().__init__() 54 | self.batch_size = batch_size 55 | self.attention = attention 56 | self.aoa = aoa 57 | self.fe = fe 58 | self.clf = clf 59 | self.expected_accum = ExpectationWithReplacement.apply 60 | self.expected = ExpectationWithoutReplacement.apply 61 | self.low_low_img_level = low_low_img_level 62 | self.low_img_level = low_img_level 63 | self.high_img_level = high_img_level 64 | self.frame_size= frame_size 65 | self.patch_size= patch_size 66 | self.device = device 67 | 
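# reg_strength scales the MultinomialRegularizer entropy penalty applied to the attention maps,
# n_patches is the total per-image patch budget (stored as desired_patches), and
# low_low_img_level / low_img_level / high_img_level select the resolution level the DataReader
# is queried at for the coarse attention-on-attention map, the per-frame attention map, and the
# final high-resolution patches, respectively.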
self.reg_strength = reg_strength 68 | self.desired_patches = n_patches 69 | self.contrasitve_learning = contrasitve_learning 70 | self.weights = weights 71 | 72 | def _sample_with_replacement(self, logits, n_samples=10): 73 | ''' 74 | Helper function to sample with replacement 75 | ''' 76 | # distribution = dist.categorical.Categorical(logits=logits) 77 | # return distribution.sample(sample_shape=torch.Size([n_samples])) 78 | if self.training: 79 | return Multinomial(n_samples, logits).sample() 80 | else: 81 | #torch.topk(logits, n_samples) 82 | logits = attention_inference_weights(logits) 83 | #return torch.floor(n_samples * logits) 84 | return Multinomial(n_samples, logits).sample() 85 | 86 | 87 | def _sample_without_replacement(self, logits, n_samples=10): 88 | ''' 89 | Helper function to sample without replacement 90 | ''' 91 | if self.training: 92 | z = -torch.log(-torch.log(torch.rand_like(logits))) 93 | return torch.topk(logits+z, k=n_samples)[1] 94 | else: 95 | return torch.topk(logits, k=n_samples)[1] 96 | 97 | def _compute_loc_array(self, DataReader): 98 | ''' 99 | Helper function to sample without replacement 100 | ''' 101 | size_img = DataReader.get_size() 102 | frame_size = self.frame_size 103 | patch_size = self.patch_size 104 | x = np.arange(0, size_img[0], frame_size[0]) 105 | y = np.arange(0, size_img[1], frame_size[1]) 106 | xx, yy = np.meshgrid(x, y) 107 | low_img_shape = xx.shape 108 | high_img_shape = size_img 109 | loc_list = [(xx.flatten()[i], yy.flatten()[i]) for i in range(len(xx.flatten()[:]))] 110 | locs_array = np.array(loc_list) 111 | return locs_array 112 | 113 | def _computing_aoa(self, locs_array, dataReader): 114 | ''' 115 | Helper funcntion to compute the attention-on-attention 116 | ''' 117 | aoa_list = [] 118 | for idx, loc in enumerate(locs_array): 119 | ## time - memory trading 120 | x, y = loc 121 | sub_image = dataReader.get_patches(x, y, level=self.low_low_img_level, 122 | size=np.asarray(self.frame_size) // 2 ** self.low_low_img_level, 123 | show_mode='channel_first') 124 | sub_image = sub_image.to(self.device) 125 | sub_image_tensor = torch.Tensor(sub_image).to(self.device) 126 | sub_image_tensor = sub_image_tensor.unsqueeze(0) 127 | sub_aoa = self.aoa(sub_image_tensor) 128 | aoa_list.append(sub_aoa) 129 | aoa_list = torch.cat(aoa_list) 130 | aoa_list = F.softmax(aoa_list) 131 | return aoa_list 132 | 133 | def _computing_attention(self, loc, dataReader): 134 | x, y = loc 135 | sub_image = dataReader.get_patches(x, y, level=self.low_img_level, 136 | size=np.asarray(self.frame_size) // 2 ** self.low_img_level, 137 | show_mode='channel_first') 138 | sub_image = sub_image.to(self.device) 139 | sub_image_tensor = torch.Tensor(sub_image).to(self.device) 140 | sub_image_tensor = sub_image_tensor.unsqueeze(0) 141 | sub_att = self.attention(sub_image_tensor) 142 | return sub_att 143 | 144 | def _compute_constrastive_predictions(self, x_reader): 145 | accum_feature_list = [] 146 | # if self.contrasitve_learning: 147 | # con_accum_feature_list = [] 148 | for e_reader in x_reader: 149 | now_dataReader = e_reader 150 | ## compute locs_array 151 | #free of size 152 | locs_array = self._compute_loc_array(now_dataReader) 153 | aoa_list = self._computing_aoa(locs_array, now_dataReader) 154 | reg_aoa = MultinomialRegularizer(aoa_list, self.reg_strength) 155 | # Sampling again implementation 156 | target_samples = self._sample_with_replacement(1 - aoa_list, self.desired_patches) 157 | target_idx = (target_samples != 0).nonzero() 158 | num_of_region = target_idx 159 
| target_samples = target_samples[target_idx] 160 | target_aoa = aoa_list[target_idx].squeeze(dim=1) 161 | target_locs_array = locs_array[target_idx.cpu().numpy().astype(np.int)].squeeze() 162 | ############################## 163 | if len(target_locs_array.shape) == 1: 164 | # add an axis 165 | target_locs_array = np.expand_dims(target_locs_array, axis=0) 166 | 167 | expected_sub_features_list = [] 168 | for idx, loc in enumerate(target_locs_array): 169 | sub_att = self._computing_attention(loc, now_dataReader) 170 | reg_aoa += MultinomialRegularizer(sub_att.view(-1), self.reg_strength) 171 | #sub_att = 1 - sub_att 172 | sampled_location = self._sample_without_replacement(1 - sub_att.view(sub_att.shape[0], -1), 173 | int(target_samples[idx].cpu().numpy())) 174 | sampled_attention = sub_att.view(-1)[sampled_location] 175 | sampled_patches = SamplingPatches(sampled_location.cpu().numpy(), loc, 176 | now_dataReader, sub_att.shape[2:], 177 | self.low_img_level, self.high_img_level, self.patch_size) 178 | 179 | sub_feature = self.fe(sampled_patches.to(self.device)) 180 | if self.weights is None: 181 | self.weights = torch.ones_like(sampled_attention) / sampled_attention.shape[1] 182 | expected_sub_feature = self.expected(self.weights, sampled_attention, sub_feature.unsqueeze(0)) 183 | expected_sub_features_list.append(expected_sub_feature.unsqueeze(0)) 184 | # Re-init weights para 185 | self.weights = None 186 | # Double expectations 187 | # Repeat aoa based on the target samples 188 | expected_sub_feature_list = torch.cat(expected_sub_features_list, dim=1) 189 | weights_accum_f = target_samples.squeeze(1).unsqueeze(0) / self.desired_patches 190 | accum_features = self.expected_accum(weights_accum_f, target_aoa.unsqueeze(0), expected_sub_feature_list) 191 | accum_feature_list.append(accum_features) 192 | accum_feature_list = torch.cat(accum_feature_list, dim=0) 193 | con_prediction, _ = self.clf(accum_feature_list.to(self.device)) 194 | return con_prediction, reg_aoa 195 | 196 | def forward(self, x_reader): 197 | ''' 198 | x_reader : a list of datareader which has a length equal to batch size 199 | 200 | return 201 | - Predictions 202 | ''' 203 | accum_feature_list = [] 204 | for e_reader in x_reader: 205 | now_dataReader = e_reader 206 | ## compute locs_array 207 | #free of size 208 | locs_array = self._compute_loc_array(now_dataReader) 209 | aoa_list = self._computing_aoa(locs_array, now_dataReader) 210 | reg_aoa = MultinomialRegularizer(aoa_list, self.reg_strength) 211 | # Sampling again implementation 212 | 213 | target_samples = self._sample_with_replacement(aoa_list, self.desired_patches) 214 | #print(target_samples) 215 | target_idx = (target_samples != 0).nonzero() 216 | #print(target_idx) 217 | #print(locs_array.shape) 218 | num_of_region = target_idx 219 | target_samples = target_samples[target_idx] 220 | #print(target_samples) 221 | target_aoa = aoa_list[target_idx] 222 | #print(len(locs_array)) 223 | #print(target_idx.cpu().data.numpy()) 224 | target_locs_array = locs_array[target_idx.detach().cpu().numpy().astype(np.int)].squeeze() 225 | ############################## 226 | if len(target_locs_array.shape) == 1: 227 | # add an axis 228 | target_locs_array = np.expand_dims(target_locs_array, axis=0) 229 | 230 | expected_sub_features_list = [] 231 | for idx, loc in enumerate(target_locs_array): 232 | sub_att = self._computing_attention(loc, now_dataReader) 233 | reg_aoa += MultinomialRegularizer(sub_att.view(-1), self.reg_strength) 234 | sampled_location = 
self._sample_without_replacement(sub_att.view(sub_att.shape[0], -1), 235 | int(target_samples[idx].cpu().numpy())) 236 | sampled_attention = sub_att.view(-1)[sampled_location] 237 | sampled_patches = SamplingPatches(sampled_location.cpu().numpy(), loc, 238 | now_dataReader, sub_att.shape[2:], 239 | self.low_img_level, self.high_img_level, self.patch_size) 240 | 241 | sub_feature = self.fe(sampled_patches.to(self.device)) 242 | if self.weights is None: 243 | self.weights = torch.ones_like(sampled_attention) / sampled_attention.shape[1] 244 | expected_sub_feature = self.expected(self.weights, sampled_attention, sub_feature.unsqueeze(0)) 245 | expected_sub_features_list.append(expected_sub_feature.unsqueeze(0)) 246 | # Re-init weights para 247 | self.weights = None 248 | # Double expectations 249 | # Repeat aoa based on the target samples 250 | expected_sub_feature_list = torch.cat(expected_sub_features_list, dim=1) 251 | weights_accum_f = target_samples.squeeze(1).unsqueeze(0) / self.desired_patches 252 | accum_features = self.expected_accum(weights_accum_f, target_aoa.unsqueeze(0), expected_sub_feature_list) 253 | accum_feature_list.append(accum_features) 254 | accum_feature_list = torch.cat(accum_feature_list, dim=0) 255 | prediction, _ = self.clf(accum_feature_list.to(self.device)) 256 | return prediction, reg_aoa --------------------------------------------------------------------------------
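The end-to-end flow in zoom_in.py is spread across several helper methods and a DataReader interface, so it can be hard to trace from the dump above. The stand-alone toy script below sketches the same two-level idea on an in-memory tensor: a coarse score per frame, a Multinomial split of a fixed patch budget across frames, Gumbel-top-k sampling of patches inside each selected frame (the trick used in _sample_without_replacement), and an attention-weighted average of patch features (the with-replacement expectation). It is an illustration only, not the repository's pipeline: the 256x256 image, the 64/16 frame and patch sizes, the budget of 8 patches, and the mean-colour "feature extractor" are all assumptions made for the sketch.

import torch
import torch.nn.functional as F
from torch.distributions import Multinomial

torch.manual_seed(0)

image = torch.rand(3, 256, 256)      # stand-in for a large input image
frame, patch, budget = 64, 16, 8     # frame size, patch size, total patch budget (all assumptions)

# 1) Coarse "attention on attention": one score per 64x64 frame (here just mean brightness).
frame_scores = F.adaptive_avg_pool2d(image.mean(0, keepdim=True), (256 // frame, 256 // frame))
frame_probs = torch.softmax(frame_scores.flatten(), dim=0)

# 2) Split the patch budget across frames with a Multinomial draw, so frames with
#    higher coarse attention receive more patches.
counts = Multinomial(budget, probs=frame_probs).sample()

features, weights = [], []
for idx in counts.nonzero(as_tuple=True)[0]:
    k = int(counts[idx])
    r, c = divmod(int(idx), 256 // frame)
    crop = image[:, r * frame:(r + 1) * frame, c * frame:(c + 1) * frame]

    # 3) Fine attention inside the frame: one score per 16x16 patch.
    logits = F.adaptive_avg_pool2d(crop.mean(0, keepdim=True), (frame // patch, frame // patch)).flatten()

    # 4) Sample k patches without replacement via the Gumbel-top-k trick
    #    (the same trick ZoomInNet._sample_without_replacement uses at training time).
    gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
    top = torch.topk(logits + gumbel, k=k).indices

    # 5) Stand-in "feature extractor": the mean colour of each sampled high-resolution patch.
    for j in top:
        pr, pc = divmod(int(j), frame // patch)
        p = crop[:, pr * patch:(pr + 1) * patch, pc * patch:(pc + 1) * patch]
        features.append(p.mean(dim=(1, 2)))
        weights.append(frame_probs[idx] * torch.softmax(logits, dim=0)[j])

# 6) Attention-weighted average of the sampled patch features (cf. ExpectationWithReplacement).
features = torch.stack(features)
weights = torch.stack(weights)
weights = weights / weights.sum()
bag_feature = (weights.unsqueeze(1) * features).sum(dim=0)
print(bag_feature.shape)             # torch.Size([3])

The design choice this mirrors is that only the frames picked by the coarse map, and only the patches picked inside them, are ever touched at the finest resolution, which is what keeps memory bounded for megapixel inputs.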