├── CPTAC.xlsx
├── IvYGAP.xlsx
├── README.md
├── ROC1.zip
├── ROC2.zip
├── TCGA.xlsx
├── __pycache__
│   ├── dataset.cpython-38.pyc
│   ├── dataset_mine.cpython-38.pyc
│   ├── evaluation.cpython-38.pyc
│   ├── model.cpython-38.pyc
│   ├── net.cpython-38.pyc
│   ├── saver.cpython-38.pyc
│   └── utils.cpython-38.pyc
├── basic_net
│   ├── CDKN_main.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── alexnet.cpython-38.pyc
│   │   ├── densenet.cpython-38.pyc
│   │   ├── inception.cpython-38.pyc
│   │   ├── mnasnet.cpython-38.pyc
│   │   ├── nystrom_atten.cpython-38.pyc
│   │   └── resnet.cpython-38.pyc
│   ├── alexnet.py
│   ├── densenet.py
│   ├── inception.py
│   ├── mnasnet.py
│   ├── nystrom_atten.py
│   └── resnet.py
├── config
│   ├── CDKN.yml
│   ├── cifar_pretrain.yml
│   ├── miccai.yml
│   └── mine.yml
├── data_process.py
├── dataset.py
├── dataset_mine copy 2.py
├── dataset_mine copy.py
├── dataset_mine.py
├── debug.py
├── docs
│   ├── 1748968628246.png
│   └── framework图.png
├── evaluation.py
├── feature_generation.py
├── logs.py
├── main _noGrad_guide.py
├── main miccai.py
├── main.py
├── main_DCC-IDH.py
├── main_dis_loss1.py
├── main_dis_loss2.py
├── main_dis_noloss.py
├── main_noDCC.py
├── main_noGrad.py
├── main_noGrad_longth copy.py
├── main_noGrad_longth.py
├── main_noLCloss.py
├── main_no_both.py
├── main_no_his.py
├── main_no_maker.py
├── main_nograph pre.py
├── main_nograph.py
├── mainpre.py
├── merge_who.xlsx
├── model copy.py
├── model.py
├── net.py
├── post_processing.py
├── roc_plot mu.py
├── roc_plot.py
├── roc_util
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── _demo.cpython-38.pyc
│   │   ├── _plot.cpython-38.pyc
│   │   ├── _roc.cpython-38.pyc
│   │   ├── _sampling.cpython-38.pyc
│   │   ├── _stats.cpython-38.pyc
│   │   └── _types.cpython-38.pyc
│   ├── _demo.py
│   ├── _plot.py
│   ├── _roc.py
│   ├── _sampling.py
│   ├── _stats.py
│   └── _types.py
├── saver.py
├── test_forroc.py
├── test_forroc_new.py
├── test_main.py
├── test_main_new.py
├── transform
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── functional.cpython-38.pyc
│   │   └── transforms_group.cpython-38.pyc
│   ├── functional.py
│   └── transforms_group.py
├── utils copy.py
├── utils.py
└── utils_finetune.py

/CPTAC.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/CPTAC.xlsx
--------------------------------------------------------------------------------
/IvYGAP.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/IvYGAP.xlsx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# M3C2

**Joint Modelling Histology and Molecular Markers for Glioma Classification**

Xiaofei Wang<sup>a,1</sup>, Hanyu Liu<sup>b,1</sup>, Yupei Zhang<sup>a</sup>, Boyang Zhao<sup>b</sup>, Hao Duan<sup>c</sup>, Wanming Hu<sup>d</sup>, Yonggao Mou<sup>c</sup>, Stephen Price<sup>a</sup>, Chao Li<sup>a,b,e,f</sup>

<sup>a</sup> Department of Clinical Neurosciences, University of Cambridge, UK

<sup>b</sup> School of Science and Engineering, University of Dundee, UK

<sup>c</sup> Department of Neurosurgery, State Key Laboratory of Oncology in South China, Guangdong Provincial Clinical Research Center for Cancer, Sun Yat-sen University Cancer Center, China

<sup>d</sup> Department of Pathology, State Key Laboratory of Oncology in South China, Guangdong Provincial Clinical Research Center for Cancer, Sun Yat-sen University Cancer Center, China

<sup>e</sup> Department of Applied Mathematics and Theoretical Physics, University of Cambridge, UK

<sup>f</sup> School of Medicine, University of Dundee, UK

Paper: arXiv:2502.07979

## 📣 Latest Updates

- **[2025-03-25]** 📊 *M3C2 code has been released!*
- **[2025-02-04]** 📝 *The M3C2 paper is now available in [Medical Image Analysis](https://www.sciencedirect.com/science/article/pii/S1361841525000532).*
- **[2024-12-04]** 🎉 *M3C2 has been accepted to Medical Image Analysis!*
- **[2024-08-02]** 📝 *M3C2 has been submitted to Medical Image Analysis.*

## Key Takeaways

- **M3C2** presents a framework for cancer classification that integrates histology and molecular markers.
  🧠 **Key Innovation**: The framework employs **multi-scale disentangling modules** to extract both **high-magnification cellular-level** and **low-magnification tissue-level** features, which are then used to predict histology and molecular markers simultaneously.

- The method introduces a **Co-occurrence Probability-based Label-Correlation Graph (CPLC-Graph)** to model the relationships between multiple molecular markers.
  Capturing these **intrinsic marker co-occurrences** and their impact on cancer classification improves classification accuracy.

- **Cross-Modal Interaction** is key to the model’s success.
  🔄 **Interaction Mechanism**: Using **dynamic confidence constraints** and a **cross-modal gradient modulation strategy**, M3C2 aligns the histology and molecular-marker prediction tasks so that each task complements the other.

- **Validation Across Diverse Datasets**: M3C2 outperforms existing state-of-the-art methods in **glioma classification** and **molecular marker prediction**, demonstrating robustness on both **internal** and **external validation datasets**.
  📊 **Performance**: The method achieves significant improvements, with **accuracy** and **AUC scores** surpassing previous models by up to **5.6%** on certain tasks.

- **Clinical Implications**: The ability to predict molecular markers directly from **whole-slide images (WSIs)**, combined with the model's capacity to capture interactions between histology and molecular data, offers strong potential for **precision oncology**.
  🏥 **Impact**: M3C2’s approach aligns with the latest **WHO glioma classification criteria**, making it a promising tool for clinical decision-making and personalized cancer treatment.

![M3C2 framework](docs/framework图.png)

## About this code

The M3C2 codebase is written in Python and integrates histology features and molecular markers for cancer classification. It applies deep learning to whole-slide images (WSIs) to predict cancer types, particularly gliomas. The core module structure is as follows:

```
M3C2-main/
├── CPTAC.xlsx            # Clinical and molecular data from CPTAC used for cancer classification.
├── IvYGAP.xlsx           # Glioma diagnosis and treatment data from IvYGAP.
├── TCGA.xlsx             # Glioma classification data from The Cancer Genome Atlas (TCGA).
├── merge_who.xlsx        # Merged dataset based on the latest WHO glioma classification.
├── README.md             # Project overview, purpose, and usage instructions.
├── data_process.py       # Processes and cleans the datasets for model training.
├── dataset.py            # Dataset structure and loading mechanism for WSIs.
├── dataset_mine.py       # Alternative dataset script with additional feature extraction.
├── feature_generation.py # Generates the features required for classification from raw data.
├── model.py              # Neural network architecture used for cancer classification.
├── net.py                # Network layers used to build the model.
├── evaluation.py         # Evaluates model performance on the various tasks.
├── post_processing.py    # Post-processes model predictions (e.g., filtering or formatting).
├── main.py               # Main script for running the model with the specified parameters.
├── main _noGrad_guide.py # Variant of the main script with gradient updates disabled for specific tasks.
├── main miccai.py        # Main-script variant used for the MICCAI conference experiments.
├── main_noLCloss.py      # Variant trained without the label-correlation loss.
├── model copy.py         # Copy of the model code, likely with experimental variations.
├── roc_plot.py           # Plots ROC curves to evaluate model performance.
├── test_forroc.py        # Test script that evaluates model performance via ROC analysis.
├── transform             # Data transformation functions for WSIs (e.g., augmentation, normalization).
│   ├── functional.py
│   └── transforms_group.py
├── utils.py              # Helper functions used throughout the project.
├── utils_finetune.py     # Utility functions for fine-tuning the model.
├── logs.py               # Logs training and testing results.
├── __pycache__           # Python bytecode cache.
```

## How to apply the work
### 1. Environment
- Python >= 3.7
- PyTorch >= 1.12 is recommended
- opencv-python
- scikit-learn
- matplotlib

### 2. Train
Use the following command to train the model on our dataset:
```
python ./main.py
```

### 3. Test
Use the following command to test the model on our dataset:
```
python ./test_main.py
```

### 4. Datasets
```
https://www.kaggle.com/datasets/liuhanyu1007/m3c2-data
```

### 5. Model
```
https://www.kaggle.com/models/liuhanyu1007/m3c2_model
```

## Contact
- Xiaofei Wang: xw405@cam.ac.uk
- Hanyu Liu: 2485644@dundee.ac.uk

Please open an issue to report problems, or submit a pull request to contribute.

## 💼 License

License: MIT

## Citation

If you find our work helpful, please cite our paper:

```
@article{wang2025joint,
  title={Joint Modelling Histology and Molecular Markers for Cancer Classification},
  author={Wang, Xiaofei and Liu, Hanyu and Zhang, Yupei and Zhao, Boyang and Duan, Hao and Hu, Wanming and Mou, Yonggao and Price, Stephen and Li, Chao},
  journal={arXiv preprint arXiv:2502.07979},
  year={2025}
}
```

--------------------------------------------------------------------------------
/ROC1.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/ROC1.zip
--------------------------------------------------------------------------------
/ROC2.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/ROC2.zip
--------------------------------------------------------------------------------
/TCGA.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/TCGA.xlsx
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/dataset.cpython-38.pyc
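The README describes the Co-occurrence Probability-based Label-Correlation Graph (CPLC-Graph) only at a high level and does not show how the co-occurrence probabilities are computed. A common construction in ML-GCN-style multi-label classifiers is the conditional probability matrix A[i][j] = P(marker j positive | marker i positive), estimated from training labels. The sketch below assumes that construction and a binary N x K label matrix; the function name `cooccurrence_probabilities` and the toy marker columns are hypothetical, and the repository's actual graph may add thresholding, reweighting, or other steps.

```python
from typing import List, Sequence


def cooccurrence_probabilities(labels: Sequence[Sequence[int]]) -> List[List[float]]:
    """Estimate A[i][j] = P(marker j positive | marker i positive).

    `labels` is a binary N x K matrix: one row per case, one column per
    molecular marker (1 = positive/altered, 0 = negative).
    Hypothetical helper for illustration; not part of the M3C2 repository.
    """
    if not labels:
        return []
    k = len(labels[0])
    # joint[i][j] counts cases where markers i and j are both positive;
    # the diagonal joint[i][i] is the number of cases where marker i is positive.
    joint = [[0] * k for _ in range(k)]
    for row in labels:
        if len(row) != k:
            raise ValueError("every label row must have the same number of markers")
        positives = [idx for idx, value in enumerate(row) if value]
        for i in positives:
            for j in positives:
                joint[i][j] += 1
    # Normalise each row by the count of its conditioning marker; a marker
    # that never occurs gets an all-zero row instead of a division by zero.
    return [
        [joint[i][j] / joint[i][i] if joint[i][i] else 0.0 for j in range(k)]
        for i in range(k)
    ]


if __name__ == "__main__":
    # Toy labels (hypothetical); columns: IDH mutation, 1p/19q codeletion, CDKN2A/B deletion.
    toy_labels = [
        [1, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
        [1, 1, 0],
    ]
    for row in cooccurrence_probabilities(toy_labels):
        print([round(p, 3) for p in row])
```

With these toy labels every 1p/19q-codeleted case is also IDH-mutant, so A[1][0] is 1.0 while A[0][1] is only 2/3. The matrix is asymmetric by design, which is why such graphs are usually treated as directed.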
--------------------------------------------------------------------------------
/__pycache__/dataset_mine.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/dataset_mine.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/evaluation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/evaluation.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/net.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/saver.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/saver.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/utils.cpython-38.pyc
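The `densenet*` constructors in `basic_net/densenet.py` (later in this listing) rename legacy checkpoint keys such as `norm.1` to `norm1` before loading the torchvision weights. The regular expression is easier to follow on concrete keys, so here is a standalone sketch that applies the same pattern to a plain dictionary. Unlike the repository code, which mutates the loaded `state_dict` in place, this hypothetical `rename_legacy_keys` helper returns a new dict so it can run without PyTorch; the example keys and values are illustrative.

```python
import re
from typing import Dict, TypeVar

V = TypeVar("V")

# Same pattern as in basic_net/densenet.py: group 1 captures the prefix up to
# "norm"/"relu"/"conv"; group 2 captures "1.weight", "2.running_var", etc.
LEGACY_KEY = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')


def rename_legacy_keys(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Return a copy of `state_dict` with keys like 'norm.1.weight' renamed to 'norm1.weight'."""
    renamed: Dict[str, V] = {}
    for key, value in state_dict.items():
        match = LEGACY_KEY.match(key)
        # Concatenating the two groups drops the dot, turning "norm.1" into "norm1".
        new_key = match.group(1) + match.group(2) if match else key
        renamed[new_key] = value
    return renamed


if __name__ == "__main__":
    legacy = {
        "features.conv0.weight": "left unchanged",
        "features.denseblock1.denselayer1.norm.1.weight": "renamed",
        "features.denseblock2.denselayer3.conv.2.weight": "renamed",
    }
    for key in rename_legacy_keys(legacy):
        print(key)
```

Keys outside the pattern, such as `features.conv0.weight` or BatchNorm `num_batches_tracked` buffers, pass through unchanged. Building a new dict also avoids mutating a dictionary while iterating over it, which the repository code handles by iterating over `list(state_dict.keys())`.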
--------------------------------------------------------------------------------
/basic_net/CDKN_main.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/CDKN_main.py
--------------------------------------------------------------------------------
/basic_net/__init__.py:
--------------------------------------------------------------------------------
from .alexnet import *
from .resnet import *
from .inception import *
from .densenet import *
from .mnasnet import *
from .nystrom_atten import *

--------------------------------------------------------------------------------
/basic_net/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/alexnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/alexnet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/densenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/densenet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/inception.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/inception.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/mnasnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/mnasnet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/nystrom_atten.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/nystrom_atten.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['AlexNet', 'alexnet']


model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    "One weird trick..." paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
    return model

--------------------------------------------------------------------------------
/basic_net/densenet.py:
--------------------------------------------------------------------------------
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']


model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    "Densely Connected Convolutional Networks"

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
        out = self.classifier(out)
        return out


def densenet121(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # state_dict = {k: v for k, v in state_dict.items() if ('conv0' in k or'norm0' in k ) }
        model.load_state_dict(state_dict, strict=False)
    return model


def densenet169(pretrained=False, **kwargs):
    r"""Densenet-169 model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet169'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model


def densenet201(pretrained=False, **kwargs):
    r"""Densenet-201 model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet201'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model


def densenet161(pretrained=False, **kwargs):
    r"""Densenet-161 model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet161'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

--------------------------------------------------------------------------------
/basic_net/inception.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


__all__ = ['Inception3', 'inception_v3']


model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}


def inception_v3(pretrained=False, **kwargs):
    r"""Inception v3 model architecture from
    "Rethinking the Inception Architecture for Computer Vision".

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        model = Inception3(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
        return model

    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.Tensor(X.rvs(m.weight.numel()))
                values = values.view(m.weight.size())
                m.weight.data.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # if self.transform_input:
        #     x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        #     x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        #     x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        #     x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return x, aux
        return x


class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = BasicConv2d(in_channels, 192,
kernel_size=1) 237 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 238 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 239 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 240 | 241 | def forward(self, x): 242 | branch3x3 = self.branch3x3_1(x) 243 | branch3x3 = self.branch3x3_2(branch3x3) 244 | 245 | branch7x7x3 = self.branch7x7x3_1(x) 246 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 247 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 248 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 249 | 250 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 251 | outputs = [branch3x3, branch7x7x3, branch_pool] 252 | return torch.cat(outputs, 1) 253 | 254 | 255 | class InceptionE(nn.Module): 256 | 257 | def __init__(self, in_channels): 258 | super(InceptionE, self).__init__() 259 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 260 | 261 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 262 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 263 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 264 | 265 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 266 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 267 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 268 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 269 | 270 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 271 | 272 | def forward(self, x): 273 | branch1x1 = self.branch1x1(x) 274 | 275 | branch3x3 = self.branch3x3_1(x) 276 | branch3x3 = [ 277 | self.branch3x3_2a(branch3x3), 278 | self.branch3x3_2b(branch3x3), 279 | ] 280 | branch3x3 = torch.cat(branch3x3, 1) 281 | 282 | branch3x3dbl = self.branch3x3dbl_1(x) 283 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 284 | branch3x3dbl = [ 285 | 
self.branch3x3dbl_3a(branch3x3dbl), 286 | self.branch3x3dbl_3b(branch3x3dbl), 287 | ] 288 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 289 | 290 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class InceptionAux(nn.Module): 298 | 299 | def __init__(self, in_channels, num_classes): 300 | super(InceptionAux, self).__init__() 301 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 302 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 303 | self.conv1.stddev = 0.01 304 | self.fc = nn.Linear(768, num_classes) 305 | self.fc.stddev = 0.001 306 | 307 | def forward(self, x): 308 | # N x 768 x 17 x 17 309 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 310 | # N x 768 x 5 x 5 311 | x = self.conv0(x) 312 | # N x 128 x 5 x 5 313 | x = self.conv1(x) 314 | # N x 768 x 1 x 1 315 | # Adaptive average pooling 316 | x = F.adaptive_avg_pool2d(x, (1, 1)) 317 | # N x 768 x 1 x 1 318 | x = x.view(x.size(0), -1) 319 | # N x 768 320 | x = self.fc(x) 321 | # N x 1000 322 | return x 323 | 324 | 325 | class BasicConv2d(nn.Module): 326 | 327 | def __init__(self, in_channels, out_channels, **kwargs): 328 | super(BasicConv2d, self).__init__() 329 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 330 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 331 | 332 | def forward(self, x): 333 | x = self.conv(x) 334 | x = self.bn(x) 335 | return F.relu(x, inplace=True) 336 | -------------------------------------------------------------------------------- /basic_net/mnasnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] 9 | 10 | 
_MODEL_URLS = { 11 | "mnasnet0_5": 12 | "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", 13 | "mnasnet0_75": None, 14 | "mnasnet1_0": 15 | "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", 16 | "mnasnet1_3": None 17 | } 18 | 19 | # The paper suggests a BatchNorm momentum of 0.9997 (TensorFlow convention). The equivalent 20 | # PyTorch momentum is 1.0 minus the TensorFlow value. 21 | _BN_MOMENTUM = 1 - 0.9997 22 | 23 | 24 | class _InvertedResidual(nn.Module): 25 | 26 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, 27 | bn_momentum=0.1): 28 | super(_InvertedResidual, self).__init__() 29 | assert stride in [1, 2] 30 | assert kernel_size in [3, 5] 31 | mid_ch = in_ch * expansion_factor 32 | self.apply_residual = (in_ch == out_ch and stride == 1) 33 | self.layers = nn.Sequential( 34 | # Pointwise 35 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 36 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 37 | nn.ReLU(inplace=True), 38 | # Depthwise 39 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, 40 | stride=stride, groups=mid_ch, bias=False), 41 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 42 | nn.ReLU(inplace=True), 43 | # Linear pointwise. Note that there's no activation. 44 | nn.Conv2d(mid_ch, out_ch, 1, bias=False), 45 | nn.BatchNorm2d(out_ch, momentum=bn_momentum)) 46 | 47 | def forward(self, input): 48 | if self.apply_residual: 49 | return self.layers(input) + input 50 | else: 51 | return self.layers(input) 52 | 53 | 54 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, 55 | bn_momentum): 56 | """ Creates a stack of inverted residuals. """ 57 | assert repeats >= 1 58 | # First one has no skip, because feature map size changes.
59 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, 60 | bn_momentum=bn_momentum) 61 | remaining = [] 62 | for _ in range(1, repeats): 63 | remaining.append( 64 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, 65 | bn_momentum=bn_momentum)) 66 | return nn.Sequential(first, *remaining) 67 | 68 | 69 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9): 70 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 71 | bias, will round up, unless the number is no more than 10% greater than the 72 | smaller divisible value, e.g. (83, 8) -> 80, but (84, 8) -> 88. """ 73 | assert 0.0 < round_up_bias < 1.0 74 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 75 | return new_val if new_val >= round_up_bias * val else new_val + divisor 76 | 77 | 78 | def _get_depths(alpha): 79 | """ Scales tensor depths as in reference MobileNet code, prefers rounding up 80 | rather than down. """ 81 | depths = [32, 16, 24, 40, 80, 96, 192, 320] 82 | return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] 83 | 84 | 85 | class MNASNet(torch.nn.Module): 86 | """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This 87 | implements the B1 variant of the model. 88 | >>> model = MNASNet(1.0) 89 | >>> x = torch.rand(1, 3, 224, 224) 90 | >>> y = model(x) 91 | >>> y.dim() 92 | 2 93 | >>> y.nelement() 94 | 1000 95 | """ 96 | # Version 2 adds depth scaling in the initial stages of the network. 97 | _version = 2 98 | 99 | def __init__(self, alpha, num_classes=1000, dropout=0.2): 100 | super(MNASNet, self).__init__() 101 | assert alpha > 0.0 102 | self.alpha = alpha 103 | self.num_classes = num_classes 104 | depths = _get_depths(alpha) 105 | layers = [ 106 | # First layer: regular conv. 
107 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False), 108 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), 109 | nn.ReLU(inplace=True), 110 | # Depthwise separable, no skip.
111 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, 112 | groups=depths[0], bias=False), 113 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), 116 | nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM), 117 | # MNASNet blocks: stacks of inverted residuals. 118 | _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM), 119 | _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM), 120 | _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM), 121 | _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM), 122 | _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM), 123 | _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM), 124 | # Final mapping to classifier input. 125 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False), 126 | nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), 127 | nn.ReLU(inplace=True), 128 | ] 129 | self.layers = nn.Sequential(*layers) 130 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), 131 | nn.Linear(1280, num_classes)) 132 | self._initialize_weights() 133 | 134 | def forward(self, x): 135 | x = self.layers(x) 136 | # Equivalent to global avgpool and removing H and W dimensions. 
137 | x = x.mean([2, 3]) 138 | return self.classifier(x) 139 | 140 | def _initialize_weights(self): 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 144 | nonlinearity="relu") 145 | if m.bias is not None: 146 | nn.init.zeros_(m.bias) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | nn.init.ones_(m.weight) 149 | nn.init.zeros_(m.bias) 150 | elif isinstance(m, nn.Linear): 151 | nn.init.kaiming_uniform_(m.weight, mode="fan_out", 152 | nonlinearity="sigmoid") 153 | nn.init.zeros_(m.bias) 154 | 155 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 156 | missing_keys, unexpected_keys, error_msgs): 157 | version = local_metadata.get("version", None) 158 | assert version in [1, 2] 159 | 160 | if version == 1 and not self.alpha == 1.0: 161 | # In the initial version of the model (v1), stem was fixed-size. 162 | # All other layer configurations were the same. This will patch 163 | # the model so that it's identical to v1. Model with alpha 1.0 is 164 | # unaffected. 165 | depths = _get_depths(self.alpha) 166 | v1_stem = [ 167 | nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), 168 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 169 | nn.ReLU(inplace=True), 170 | nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, 171 | bias=False), 172 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 173 | nn.ReLU(inplace=True), 174 | nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), 175 | nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), 176 | _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM), 177 | ] 178 | for idx, layer in enumerate(v1_stem): 179 | self.layers[idx] = layer 180 | 181 | # The model is now identical to v1, and must be saved as such. 182 | self._version = 1 183 | warnings.warn( 184 | "A new version of MNASNet model has been implemented. " 185 | "Your checkpoint was saved using the previous version. 
" 186 | "This checkpoint will load and work as before, but " 187 | "you may want to upgrade by training a newer model or " 188 | "transfer learning from an updated ImageNet checkpoint.", 189 | UserWarning) 190 | 191 | super(MNASNet, self)._load_from_state_dict( 192 | state_dict, prefix, local_metadata, strict, missing_keys, 193 | unexpected_keys, error_msgs) 194 | 195 | 196 | def _load_pretrained(model_name, model, progress): 197 | if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: 198 | raise ValueError( 199 | "No checkpoint is available for model type {}".format(model_name)) 200 | checkpoint_url = _MODEL_URLS[model_name] 201 | model.load_state_dict(model_zoo.load_url(url=checkpoint_url, model_dir='./', progress=progress)) 202 | 203 | 204 | def mnasnet0_5(pretrained=False, progress=True, **kwargs): 205 | """MNASNet with depth multiplier of 0.5 from 206 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 207 | <https://arxiv.org/pdf/1807.11626.pdf>`_. 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | progress (bool): If True, displays a progress bar of the download to stderr 211 | """ 212 | model = MNASNet(0.5, **kwargs) 213 | if pretrained: 214 | _load_pretrained("mnasnet0_5", model, progress) 215 | return model 216 | 217 | 218 | def mnasnet0_75(pretrained=False, progress=True, **kwargs): 219 | """MNASNet with depth multiplier of 0.75 from 220 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 221 | <https://arxiv.org/pdf/1807.11626.pdf>`_. 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | progress (bool): If True, displays a progress bar of the download to stderr 225 | """ 226 | model = MNASNet(0.75, **kwargs) 227 | if pretrained: 228 | _load_pretrained("mnasnet0_75", model, progress) 229 | return model 230 | 231 | 232 | def mnasnet1_0(pretrained=False, progress=True, **kwargs): 233 | """MNASNet with depth multiplier of 1.0 from 234 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 235 | <https://arxiv.org/pdf/1807.11626.pdf>`_.
236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | model = MNASNet(1.0, **kwargs) 241 | if pretrained: 242 | _load_pretrained("mnasnet1_0", model, progress) 243 | return model 244 | 245 | 246 | def mnasnet1_3(pretrained=False, progress=True, **kwargs): 247 | """MNASNet with depth multiplier of 1.3 from 248 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 249 | `_. 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | model = MNASNet(1.3, **kwargs) 255 | if pretrained: 256 | _load_pretrained("mnasnet1_3", model, progress) 257 | return model 258 | -------------------------------------------------------------------------------- /basic_net/nystrom_atten.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, reduce 7 | 8 | # helper functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def moore_penrose_iter_pinv(x, iters = 6): 14 | device = x.device 15 | 16 | abs_x = torch.abs(x) 17 | col = abs_x.sum(dim = -1) 18 | row = abs_x.sum(dim = -2) 19 | z = rearrange(x, '... i j -> ... 
j i') / (torch.max(col) * torch.max(row)) 20 | 21 | I = torch.eye(x.shape[-1], device = device) 22 | I = rearrange(I, 'i j -> () i j') 23 | 24 | for _ in range(iters): 25 | xz = x @ z 26 | z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz))))) 27 | 28 | return z 29 | 30 | # main attention class 31 | 32 | class NystromAttention(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | dim_head = 64, 37 | heads = 8, 38 | num_landmarks = 256, 39 | pinv_iterations = 6, 40 | residual = True, 41 | residual_conv_kernel = 33, 42 | eps = 1e-8, 43 | dropout = 0. 44 | ): 45 | super().__init__() 46 | self.eps = eps 47 | inner_dim = heads * dim_head 48 | 49 | self.num_landmarks = num_landmarks 50 | self.pinv_iterations = pinv_iterations 51 | 52 | self.heads = heads 53 | self.scale = dim_head ** -0.5 54 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 55 | 56 | self.to_out = nn.Sequential( 57 | nn.Linear(inner_dim, dim), 58 | nn.Dropout(dropout) 59 | ) 60 | 61 | self.residual = residual 62 | if residual: 63 | kernel_size = residual_conv_kernel 64 | padding = residual_conv_kernel // 2 65 | self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False) 66 | 67 | def forward(self, x, mask = None, return_attn = False): 68 | b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps 69 | 70 | # pad so that sequence can be evenly divided into m landmarks 71 | 72 | remainder = n % m 73 | if remainder > 0: 74 | padding = m - (n % m) 75 | x = F.pad(x, (0, 0, padding, 0), value = 0) 76 | 77 | if exists(mask): 78 | mask = F.pad(mask, (padding, 0), value = False) 79 | 80 | # derive query, keys, values 81 | 82 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 83 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 84 | 85 | # set masked positions to 0 in queries, keys, values 86 | 87 | if exists(mask): 88 | mask = rearrange(mask, 'b n -> b () n') 89 | q, k, v 
= map(lambda t: t * mask[..., None], (q, k, v)) 90 | 91 | q = q * self.scale 92 | 93 | # generate landmarks by sum reduction, and then calculate mean using the mask 94 | 95 | l = ceil(n / m) 96 | landmark_einops_eq = '... (n l) d -> ... n d' 97 | q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l) 98 | k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l) 99 | 100 | # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean 101 | 102 | divisor = l 103 | if exists(mask): 104 | mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l) 105 | divisor = mask_landmarks_sum[..., None] + eps 106 | mask_landmarks = mask_landmarks_sum > 0 107 | 108 | # masked mean (if mask exists) 109 | 110 | q_landmarks /= divisor 111 | k_landmarks /= divisor 112 | 113 | # similarities 114 | 115 | einops_eq = '... i d, ... j d -> ... i j' 116 | sim1 = einsum(einops_eq, q.float(), k_landmarks.float()) 117 | sim2 = einsum(einops_eq, q_landmarks.float(), k_landmarks.float()) 118 | sim3 = einsum(einops_eq, q_landmarks.float(), k.float()) 119 | 120 | # masking 121 | 122 | if exists(mask): 123 | mask_value = -torch.finfo(q.dtype).max 124 | sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) 125 | sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) 126 | sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) 127 | 128 | # eq (15) in the paper and aggregate values 129 | 130 | attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3)) 131 | attn2_inv = moore_penrose_iter_pinv(attn2, iters) 132 | 133 | out = (attn1 @ attn2_inv) @ (attn3 @ v) 134 | 135 | # add depth-wise conv residual of values 136 | 137 | if self.residual: 138 | out += self.res_conv(v) 139 | 140 | # merge and combine heads 141 | 142 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 143 | out = self.to_out(out) 144 | out = out[:, -n:] 145 | 146 | if 
return_attn: 147 | attn = attn1 @ attn2_inv @ attn3 148 | return out, attn 149 | 150 | return out 151 | 152 | # transformer 153 | 154 | class PreNorm(nn.Module): 155 | def __init__(self, dim, fn): 156 | super().__init__() 157 | self.norm = nn.LayerNorm(dim) 158 | self.fn = fn 159 | 160 | def forward(self, x, **kwargs): 161 | x = self.norm(x) 162 | return self.fn(x, **kwargs) 163 | 164 | class FeedForward(nn.Module): 165 | def __init__(self, dim, mult = 4, dropout = 0.): 166 | super().__init__() 167 | self.net = nn.Sequential( 168 | nn.Linear(dim, dim * mult), 169 | nn.GELU(), 170 | nn.Dropout(dropout), 171 | nn.Linear(dim * mult, dim) 172 | ) 173 | 174 | def forward(self, x): 175 | return self.net(x) 176 | 177 | class Nystromformer(nn.Module): 178 | def __init__( 179 | self, 180 | *, 181 | dim, 182 | depth, 183 | dim_head = 64, 184 | heads = 8, 185 | num_landmarks = 256, 186 | pinv_iterations = 6, 187 | attn_values_residual = True, 188 | attn_values_residual_conv_kernel = 33, 189 | attn_dropout = 0., 190 | ff_dropout = 0. 
191 | ): 192 | super().__init__() 193 | 194 | self.layers = nn.ModuleList([]) 195 | for _ in range(depth): 196 | self.layers.append(nn.ModuleList([ 197 | PreNorm(dim, NystromAttention(dim = dim, dim_head = dim_head, heads = heads, num_landmarks = num_landmarks, pinv_iterations = pinv_iterations, residual = attn_values_residual, residual_conv_kernel = attn_values_residual_conv_kernel, dropout = attn_dropout)), 198 | PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)) 199 | ])) 200 | 201 | def forward(self, x, mask = None): 202 | for attn, ff in self.layers: 203 | x = attn(x, mask = mask) + x 204 | x = ff(x) + x 205 | return x 206 | -------------------------------------------------------------------------------- /basic_net/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18','resnet18_stem', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, 
self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | identity = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = conv1x1(inplanes, planes) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = conv3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = conv1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | 93 | out += identity 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 102 | super(ResNet, self).__init__() 103 | self.inplanes = 64 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | # Zero-initialize the last BN in each residual branch, 124 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 125 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 126 | if zero_init_residual: 127 | for m in self.modules(): 128 | if isinstance(m, Bottleneck): 129 | nn.init.constant_(m.bn3.weight, 0) 130 | elif isinstance(m, BasicBlock): 131 | nn.init.constant_(m.bn2.weight, 0) 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | conv1x1(self.inplanes, planes * block.expansion, stride), 138 | nn.BatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for _ in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x 
= self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | class ResNet_stem(nn.Module): 167 | 168 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 169 | super(ResNet_stem, self).__init__() 170 | self.inplanes = 64 171 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 172 | bias=False) 173 | self.bn1 = nn.BatchNorm2d(64) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 176 | 177 | # Note: the stem contains no residual blocks, so the zero-init loop below has no effect here. 178 | # Zero-initialize the last BN in each residual branch, 179 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 180 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 181 | if zero_init_residual: 182 | for m in self.modules(): 183 | if isinstance(m, Bottleneck): 184 | nn.init.constant_(m.bn3.weight, 0) 185 | elif isinstance(m, BasicBlock): 186 | nn.init.constant_(m.bn2.weight, 0) 187 | 188 | 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | x = self.maxpool(x) 195 | 196 | return x 197 | 198 | 199 | 200 | def resnet18(pretrained=False, **kwargs): 201 | """Constructs a ResNet-18 model. 202 | 203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | """ 206 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 207 | if pretrained: 208 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']),strict=False) 209 | return model 210 | 211 | def resnet18_stem(pretrained=False, **kwargs): 212 | """Constructs only the stem (conv1, bn1, relu, maxpool) of a ResNet-18 model.
213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = ResNet_stem(BasicBlock, [2, 2, 2, 2], **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']),strict=False) 220 | return model 221 | 222 | 223 | def resnet34(pretrained=False, **kwargs): 224 | """Constructs a ResNet-34 model. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 230 | if pretrained: 231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 232 | return model 233 | 234 | 235 | def resnet50(pretrained=False, **kwargs): 236 | """Constructs a ResNet-50 model. 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 242 | if pretrained: 243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 244 | return model 245 | 246 | 247 | def resnet101(pretrained=False, **kwargs): 248 | """Constructs a ResNet-101 model. 249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | """ 253 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 254 | if pretrained: 255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 256 | return model 257 | 258 | 259 | def resnet152(pretrained=False, **kwargs): 260 | """Constructs a ResNet-152 model. 
261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | """ 265 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 266 | if pretrained: 267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 268 | return model 269 | -------------------------------------------------------------------------------- /config/CDKN.yml: -------------------------------------------------------------------------------- 1 | name: TransMIL_fea # task-specific 1p19q CDKN Diag Grade His MLC fea img 2 | # 3 | command: Train # Test 4 | gpus: [2] 5 | 6 | 7 | dataDir: /mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ #ali8k_3 8 | #dataDir: /mnt/disk10T/fyb/wxf_data/TCGA/brain/ #ali8k_1 9 | 10 | 11 | #### Network setting main 12 | Network: 13 | BasicMIL: # instance level information integration with simply; majority voting or standard bag label definition 14 | lr: 0.0003 15 | AMIL: 16 | lr: 0.0005 17 | TransMIL: 18 | lr: 0.0002 19 | CLAM: 20 | lr: 0.0002 21 | UACNN: 22 | lr: 0.0003 23 | PatchGCN: 24 | lr: 0.0002 25 | Mine: 26 | lr: 0.0003 27 | 28 | 29 | #### Training setting 30 | 31 | n_ep: 200 32 | n_ep_decay: 50 33 | decayType: step # step, linear, exp 34 | n_ep_save: 5 35 | resume_epoch: 0 # 100 36 | eva_cm: False 37 | batchSize: 1 38 | Val_batchSize: 1 39 | Test_batchSize: 1 40 | 41 | #### Directories 42 | logDir: ./logs 43 | saveDir: ./outs 44 | modelDir: ./models 45 | cm_saveDir: ./cm 46 | label_path: ./merge_who.xlsx 47 | 48 | #### Meta setting main 49 | dataset: TCGA 50 | nThreads: 16 51 | seed: 124 52 | imgSize: [224,224] 53 | -------------------------------------------------------------------------------- /config/cifar_pretrain.yml: -------------------------------------------------------------------------------- 1 | name: CLAM # model 2 | gpus: [1] 3 | batchSize: 1 4 | Val_batchSize: 1 5 | Test_batchSize: 1 6 | #dataDir: /mnt/disk10T_2/fuyibing/wxf_data/TCGA/brain/ #ali8k_2 7 | dataDir: /mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ 
#ali8k_3 8 | #dataDir: /mnt/disk10T/fyb/wxf_data/TCGA/brain/ #ali8k_1 9 | #dataDir: /mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ #ali4k 10 | 11 | 12 | #### Network setting main 13 | Network: 14 | BasicMIL: # instance level information integration with simply; majority voting or standard bag label definition 15 | lr: 0.0003 16 | AMIL: 17 | lr: 0.0005 18 | TransMIL: 19 | lr: 0.0002 20 | CLAM: 21 | lr: 0.003 22 | UACNN: 23 | lr: 0.0003 24 | PatchGCN: 25 | lr: 0.0002 26 | 27 | 28 | 29 | #### Training setting 30 | 31 | n_ep: 200 32 | n_ep_decay: 50 33 | decayType: exp # step, linear, exp,cos 34 | n_ep_save: 5 35 | resume_epoch: 0 # 100 36 | 37 | 38 | #### Directories 39 | logDir: ./logs 40 | saveDir: ./outs 41 | modelDir: ./models 42 | cm_saveDir: ./cm 43 | label_path: ./merge_who.xlsx 44 | 45 | #### Meta setting main 46 | dataset: TCGA 47 | nThreads: 16 48 | seed: 124 49 | imgSize: [224,224] 50 | -------------------------------------------------------------------------------- /config/miccai.yml: -------------------------------------------------------------------------------- 1 | name: CLAM_IDH_fea # mna incep dense alex res 2 | gpus: [0] 3 | batchSize: 1 4 | fixdim: 2500 5 | dataDir: /home/zeiler/WSI_proj/data/ #Adden_linux1 6 | #dataDir: /home/cbtil2/WSI_proj/data/ #Adden_linux2 7 | 8 | #### Network setting main 9 | Network: 10 | BasicMIL: # instance level information integration with simply; majority voting or standard bag label definition 11 | lr: 0.001 12 | AMIL: 13 | lr: 0.0002 14 | TransMIL: 15 | lr: 0.0002 16 | CLAM: 17 | lr: 0.0001 18 | UACNN: 19 | lr: 0.0003 20 | PatchGCN: 21 | lr: 0.0002 22 | 23 | 24 | #### Training setting 25 | 26 | n_ep: 70 27 | decay_cos_warmup_steps: 35 28 | n_ep_decay: 15 29 | decayType: exp # step, linear, exp,cos 30 | n_ep_save: 5 31 | resume_epoch: 70 # 100 32 | eva_cm: False 33 | dataLabels: ['G2_O', 'G3_O', 'G2_A', 'G3_A', 'G4_A' ,'GBM'] #['G2_O', 'G3_O', 'G2_A', 'G3_A','G2_OA','G3_OA' 'GBM'] 34 | 35 | #### Directories 36 | logDir: 
./logs 37 | saveDir: ./outs 38 | modelDir: ./models 39 | cm_saveDir: ./cm 40 | label_path: ./merge_who.xlsx 41 | 42 | #### Meta setting main 43 | dataset: TCGA 44 | nThreads: 16 45 | seed: 124 46 | Val_batchSize: 1 47 | Test_batchSize: 1 48 | imgSize: [224,224] 49 | command: Train # Test 50 | -------------------------------------------------------------------------------- /config/mine.yml: -------------------------------------------------------------------------------- 1 | name: Mine # Mine_dim2500_seed124_pretrain_exp_45 2 | command: Train # Test 3 | gpus: [0,1] 4 | fixdim: 2500 5 | batchSize: 2 6 | decayType: cos # step, linear, exp cos 7 | #dataDir: /raid/qiaominglang/brain/ #hyy3 8 | # dataDir: /home/zeiler/WSI_proj/data/ #Adden_linux1 9 | #dataDir: /home/cbtil2/WSI_proj/data/ #Adden_linux2 10 | # dataDir: /home/cbtil3/WSI_proj/data/ #Adden_linux3 11 | # dataDir: /home/cbtil/ST_proj/LDH/data/ 12 | dataDir: /home/hanyu/ 13 | 14 | #### Network setting main 15 | Network: 16 | 17 | lr: 0.003 18 | dropout_rate: 0.1 19 | IDH_layers: 3 20 | 1p19q_layers: 2 21 | CDKN_layers: 2 22 | His_layers: 3 23 | Grade_layers: 1 24 | Trans_block: 'full' #'full' 'simple' 25 | graph_alpha: 0.1 26 | corre_loss_ratio: 0.1 27 | 28 | 29 | #### Training setting 30 | n_ep: 80 31 | decay_cos_warmup_steps: 25 32 | 33 | 34 | 35 | #### Directories 36 | logDir: ./writer/logs 37 | saveDir: ./writer/outs 38 | modelDir: ./writer/models 39 | cm_saveDir: ./writer/cm 40 | label_path: ./merge_who.xlsx 41 | TCGA_label_path: ./TCGA.xlsx 42 | IvYGAP_label_path: ./IvYGAP.xlsx 43 | CPTAC_label_path: ./CPTAC.xlsx 44 | #### Meta setting main 45 | dataset: TCGA 46 | nThreads: 16 47 | seed: 124 48 | imgSize: [224,224] 49 | eva_cm: False 50 | n_ep_save: 1 51 | fp16: False 52 | 53 | Val_batchSize: 1 54 | Test_batchSize: 1 55 | n_ep_decay: 15 56 | top_K_patch: 300 -------------------------------------------------------------------------------- /data_process.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import argparse, time 4 | 5 | import gc 6 | # root=r'/mnt/disk10T/fuyibing/wxf_data/TCGA/brain/npy/' 7 | root=r'/mnt/disk10T/fyb/wxf_data/TCGA/brain/npy/' 8 | 9 | import SimpleITK as sitk 10 | 11 | #####itk 12 | img=np.zeros(shape=(1200,3,224,224), dtype=np.uint8) 13 | out = sitk.GetImageFromArray(img) 14 | sitk.WriteImage(out, root+'simpleitk_save.nii.gz') 15 | 16 | 17 | #####npy 18 | img=np.zeros(shape=(1200,3,224,224), dtype=np.uint8) 19 | np.save(root+'img.npy', img) 20 | 21 | #####h5 22 | imgData=np.zeros(shape=(1200,3,224,224), dtype=np.uint8) 23 | with h5py.File(root+'test.h5','w') as f: 24 | f['data'] = imgData 25 | 26 | k=h5py.File(root+'test.h5')['data'][:] 27 | a=1 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from matplotlib import pyplot as plt 4 | from torch.utils.data import DataLoader,Dataset 5 | import numpy as np 6 | import os 7 | from PIL import Image 8 | from skimage import io,transform 9 | import cv2 10 | import torch 11 | import platform 12 | import pandas as pd 13 | import argparse, time, random 14 | import yaml 15 | from yaml.loader import SafeLoader 16 | from tqdm import tqdm 17 | import h5py 18 | import gc 19 | import math 20 | import scipy.interpolate 21 | from PIL import Image 22 | import cv2 23 | from matplotlib import pyplot as plt 24 | from torchvision.transforms import Compose 25 | import transform.transforms_group as our_transform 26 | from torchvision.transforms import Compose, ToTensor, ToPILImage, CenterCrop, Resize 27 | def train_transform(degree=180): 28 | return Compose([ 29 | our_transform.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05), 30 | ]) 31 | class Our_Dataset(Dataset): 32 | def __init__(self, 
phase,opt): 33 | super(Our_Dataset, self).__init__() 34 | self.opt = opt 35 | self.patc_bs=64 36 | self.phase=phase 37 | # self.test_mode=opt['test_mode'] # WSI Patient 38 | self.name = opt['name'].split('_') 39 | 40 | excel_label_wsi = pd.read_excel(opt['label_path'],sheet_name='wsi_level',header=0) 41 | excel_wsi =excel_label_wsi.values 42 | PATIENT_LIST=excel_wsi[:,0] 43 | np.random.seed(self.opt['seed']) 44 | random.seed(self.opt['seed']) 45 | PATIENT_LIST=list(PATIENT_LIST) 46 | 47 | 48 | 49 | 50 | PATIENT_LIST=np.unique(PATIENT_LIST) 51 | np.random.shuffle(PATIENT_LIST) 52 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952 53 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL* 0.8)] 54 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.8):int(NUM_PATIENT_ALL * 0.9)] 55 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):] 56 | self.TRAIN_LIST=[] 57 | self.VAL_LIST = [] 58 | self.TEST_LIST = [] 59 | self.My_transform=train_transform() 60 | for i in range(excel_wsi.shape[0]):# 2612 61 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST: 62 | self.TRAIN_LIST.append(excel_wsi[i,:]) 63 | elif excel_wsi[:,0][i] in VAL_PATIENT_LIST: 64 | self.VAL_LIST.append(excel_wsi[i,:]) 65 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST: 66 | self.TEST_LIST.append(excel_wsi[i,:]) 67 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else (np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST)) 68 | # df = pd.DataFrame(self.LIST, columns=list(excel_label_wsi)) 69 | # df.to_excel("vis/Val.xlsx", index=False) 70 | 71 | self.train_iter_count=0 72 | self.Flat=0 73 | self.WSI_all=[] 74 | 75 | def __getitem__(self, index): 76 | 77 | 78 | 79 | if self.name[2]=='img': 80 | patch_all,coor_all=self.read_img(index) 81 | elif self.name[2]=='fea': 82 | patch_all,coor_all=self.read_feature(index) 83 | label=self.label_gene(index) 84 | 85 | return torch.from_numpy(np.array(patch_all)).float(),torch.from_numpy(np.array(label)).long(), self.LIST[index, 1],coor_all 86 |
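The `Our_Dataset` constructor splits at the patient level (column 0 of the `wsi_level` sheet), so every slide from one patient falls into exactly one of Train/Val/Test and no patient leaks across subsets. The sketch below restates that logic in isolation, assuming rows are `(patient_id, slide_id)` tuples; `split_by_patient` and its parameters are hypothetical names, not repository code.

```python
import random

import numpy as np


def split_by_patient(slide_rows, seed=124, train_end=0.8, val_end=0.9):
    """Assign slide rows to Train/Val/Test so each patient lands in one subset.

    slide_rows: list of (patient_id, slide_id) tuples, a hypothetical stand-in
    for the rows of the 'wsi_level' sheet (column 0 = patient ID).
    """
    # Seed both RNGs, mirroring the constructor, so the split is reproducible.
    np.random.seed(seed)
    random.seed(seed)
    patients = np.unique([row[0] for row in slide_rows])  # sorted, de-duplicated
    np.random.shuffle(patients)
    n = len(patients)
    train_p = set(patients[: int(n * train_end)])
    val_p = set(patients[int(n * train_end): int(n * val_end)])

    splits = {"Train": [], "Val": [], "Test": []}
    for row in slide_rows:
        if row[0] in train_p:
            splits["Train"].append(row)
        elif row[0] in val_p:
            splits["Val"].append(row)
        else:  # remaining patients form the test subset
            splits["Test"].append(row)
    return splits
```

Membership is checked against Python sets rather than by scanning a NumPy array for every row; this should give the same assignment while scaling better with the number of slides.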
87 | def read_feature(self, index): 88 | root = self.opt['dataDir']+'TCGA/Res50_feature_'+str(self.opt['fixdim'])+'_fixdim0/' 89 | patch_all=h5py.File(root+self.LIST[index, 1]+'.h5')['Res_feature'][:] #(1,N,1024) 90 | coor_all = h5py.File(root + self.LIST[index, 1] + '.h5')['patches_coor'][:] 91 | return patch_all ,coor_all 92 | def read_feature1(self, index,k): 93 | root = self.opt['dataDir']+'Res50_feature_1200_fixdim0_aug/aug_set'+str(k)+'/' 94 | patch_all=h5py.File(root+self.LIST[index, 1]+'.h5')['Res_feature'][:] #(1,N,1024) 95 | coor_all = h5py.File(root + self.LIST[index, 1] + '.h5')['patches_coor'][:] 96 | return patch_all ,coor_all 97 | 98 | 99 | 100 | def read_img(self,index): 101 | wsi_path = self.dataDir + self.LIST[index, 1] 102 | patch_all = [] 103 | patch_all_ori=[] 104 | coor_all=[] 105 | coor_all_ori = [] 106 | self.img_dir = os.listdir(wsi_path) 107 | 108 | read_details=np.load(self.opt['dataDir']+'read_details/'+self.LIST[index, 1]+'.npy',allow_pickle=True)[0] 109 | num_patches = read_details.shape[0] 110 | print(num_patches) 111 | max_num=2500 112 | Use_patch_num = num_patches if num_patches <= max_num else max_num 113 | if num_patches <= max_num: 114 | times=int(np.floor(max_num/num_patches)) 115 | remaining=max_num % num_patches 116 | for i in range(Use_patch_num): 117 | img_temp=io.imread(wsi_path + '/' + str(read_details[i][0]) + '_' + str(read_details[i][1]) + '.jpg') 118 | img_temp = cv2.resize(img_temp, (224, 224)) 119 | patch_all_ori.append(img_temp) 120 | coor_all_ori.append(read_details[i]) 121 | patch_all=patch_all_ori 122 | coor_all = coor_all_ori 123 | 124 | ####### fixdim0 125 | if times>1: 126 | for k in range(times-1): 127 | patch_all=patch_all+patch_all_ori 128 | coor_all=coor_all+coor_all_ori 129 | if not remaining==0: 130 | patch_all = patch_all + patch_all_ori[0:remaining] 131 | coor_all = coor_all + coor_all_ori[0:remaining] 132 | 133 | else: 134 | for i in range(Use_patch_num): 135 | img_temp = io.imread(wsi_path + '/' + 
str(read_details[int(np.around(i*(num_patches/max_num)))][0])+'_'+str(read_details[int(np.around(i*(num_patches/max_num)))][1])+'.jpg') 136 | img_temp = cv2.resize(img_temp, (224, 224)) 137 | patch_all.append(img_temp) 138 | coor_all.append(read_details[int(np.around(i*(num_patches/max_num)))]) 139 | 140 | patch_all = np.asarray(patch_all) 141 | 142 | # data augmentation 143 | patch_all = patch_all.reshape(-1, 224, 3) # (num_patches*224,224,3) 144 | patch_all = patch_all.reshape(-1, 224, 224, 3) # (num_patches,224,224,3) 145 | 146 | 147 | patch_all = patch_all / 255.0 148 | patch_all = np.transpose(patch_all, (0, 3, 1, 2)) 149 | patch_all = patch_all.astype(np.float32) 150 | 151 | 152 | return patch_all,coor_all 153 | 154 | def label_gene(self,index): 155 | his_label_map= {'glioblastoma'} 156 | grade_label_map = {'2':0,'3':1,'4':2} 157 | #grade 2021 = {0: 'G2', 1: 'G3_O', 2: 'G4'} 158 | # His 2021 = {0: 'A', 1: 'O', 2: 'GBM'} 159 | #label 2021={ 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G4_A', 5:'GBM'} 160 | if self.name[1]=='IDH': 161 | if self.LIST[index, 4]=='WT': 162 | label=0 163 | elif self.LIST[index, 4]=='Mutant': 164 | label=1 165 | elif self.name[1] == '1p19q': 166 | if self.LIST[index, 5] == 'non-codel': 167 | label = 0 168 | elif self.LIST[index, 5] == 'codel': 169 | label = 1 170 | elif self.name[1] == 'CDKN': 171 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1: 172 | label = 1 173 | else: 174 | label = 0 175 | 176 | elif self.name[1] == 'Diag': 177 | if self.LIST[index, 4] == 'WT': 178 | label = 0 179 | elif self.LIST[index, 5] == 'codel': 180 | label = 3 181 | else: 182 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] == 'G4': 183 | label = 1 184 | else: 185 | label = 2 186 | elif self.name[1] == 'Grade': 187 | if self.LIST[index, 4] == 'WT': 188 | label = 2 189 | elif self.LIST[index, 5] == 'codel': 190 | label = 0 if self.LIST[index, 3] =='G2' else 1 191 | else: 192 | if self.LIST[index, 6] == -2 or 
self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4': 193 | label = 2 194 | else: 195 | label = 0 if self.LIST[index, 3] == 'G2' else 1 196 | elif self.name[1] == 'His': 197 | # if self.LIST[index, 2]=='astrocytoma': 198 | # label = 0 199 | # elif self.LIST[index, 2] == 'oligoastrocytoma': 200 | # label = 1 201 | # elif self.LIST[index, 2] == 'oligodendroglioma': 202 | # label = 2 203 | # elif self.LIST[index, 2] == 'glioblastoma': 204 | # label = 3 205 | if self.LIST[index, 2] == 'glioblastoma': 206 | label = 1 207 | else: 208 | label = 0 209 | 210 | 211 | 212 | return label 213 | 214 | 215 | def shuffle_list(self, seed): 216 | np.random.seed(seed) 217 | random.seed(seed) 218 | np.random.shuffle(self.LIST) 219 | 220 | 221 | 222 | def __len__(self): 223 | return self.LIST.shape[0] 224 | 225 | 226 | 227 | if __name__ == '__main__': 228 | 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument('--opt', type=str, default='config/miccai.yml') 231 | args = parser.parse_args() 232 | with open(args.opt) as f: 233 | opt = yaml.load(f, Loader=SafeLoader) 234 | trainDataset = Our_Dataset(phase='Val', opt=opt) 235 | for i in range(100): 236 | trainDataset.__getitem__(i) -------------------------------------------------------------------------------- /dataset_mine copy 2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from matplotlib import pyplot as plt 4 | from torch.utils.data import DataLoader,Dataset 5 | import numpy as np 6 | import os 7 | from PIL import Image 8 | from skimage import io,transform 9 | import cv2 10 | import torch 11 | import platform 12 | import pandas as pd 13 | import argparse, time, random 14 | import yaml 15 | from yaml.loader import SafeLoader 16 | from tqdm import tqdm 17 | import h5py 18 | import gc 19 | import math 20 | import scipy.interpolate 21 | from PIL import Image 22 | import cv2 23 | from matplotlib import pyplot as plt 24 | from
torchvision.transforms import Compose 25 | import transform.transforms_group as our_transform 26 | 27 | class Our_Dataset(Dataset): 28 | def __init__(self, phase,opt,if_end2end=False): 29 | super(Our_Dataset, self).__init__() 30 | self.opt = opt 31 | self.patc_bs=64 32 | self.phase=phase 33 | self.if_end2end=if_end2end 34 | 35 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0) 36 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0) 37 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0) 38 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True) 39 | excel_wsi = combined_labels.values 40 | 41 | PATIENT_LIST=excel_wsi[:,0] 42 | np.random.seed(self.opt['seed']) 43 | random.seed(self.opt['seed']) 44 | PATIENT_LIST=list(PATIENT_LIST) 45 | # IvYGAP_label 46 | IvYGAP_label = IvYGAP_label.values 47 | 48 | PATIENT_LIST=np.unique(PATIENT_LIST) 49 | np.random.shuffle(PATIENT_LIST) 50 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952 51 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)] 52 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):] 53 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)] 54 | self.TRAIN_LIST=[] 55 | self.VAL_LIST = [] 56 | self.TEST_LIST = [] 57 | self.I_TEST_LIST = [] 58 | for i in range(excel_wsi.shape[0]):# 2612 59 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST: 60 | self.TRAIN_LIST.append(excel_wsi[i,:]) 61 | elif excel_wsi[:,0][i] in VAL_PATIENT_LIST: 62 | self.VAL_LIST.append(excel_wsi[i,:]) 63 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST: 64 | self.TEST_LIST.append(excel_wsi[i,:]) 65 | 66 | for i in range(IvYGAP_label.shape[0]):# 2612 67 | self.I_TEST_LIST.append(IvYGAP_label[i,:]) 68 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST) if self.phase == 'Test' else np.asarray(self.I_TEST_LIST) 69 
| 70 | self.train_iter_count=0 71 | self.Flat=0 72 | self.WSI_all=[] 73 | 74 | def __getitem__(self, index): 75 | feature_all_20,feature_all_10, = self.read_feature(index) 76 | 77 | label=self.label_gene(index) 78 | 79 | return torch.from_numpy(np.array(feature_all_20)).float(),torch.from_numpy(np.array(feature_all_10)).float(),\ 80 | torch.from_numpy(label) 81 | 82 | def read_feature(self, index): 83 | 84 | root = '/Res50_feature_2500_fixdim0_norm' 85 | 86 | patient_id = self.LIST[index, 0] 87 | 88 | 89 | if patient_id[0].startswith('T'): 90 | base_path = self.opt['dataDir'] + 'TCGA' 91 | elif patient_id[0].startswith('W'): 92 | base_path = self.opt['dataDir'] + 'IvYGAP' 93 | elif patient_id[0].startswith('C'): 94 | base_path = self.opt['dataDir'] + 'CPTAC' 95 | else: 96 | raise ValueError("Unknown data source") 97 | 98 | patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:] 99 | patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:] 100 | return patch_20[0], patch_10[0]#, patch_1_25[0] 101 | 102 | 103 | def label_gene(self,index): 104 | 105 | 106 | if self.LIST[index, 4]=='WT': 107 | label_IDH=0 108 | elif self.LIST[index, 4]=='Mutant': 109 | label_IDH=1 110 | if self.LIST[index, 5] == 'non-codel': 111 | label_1p19q = 0 112 | elif self.LIST[index, 5] == 'codel': 113 | label_1p19q = 1 114 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1: 115 | label_CDKN = 1 116 | else: 117 | label_CDKN = 0 118 | 119 | if self.LIST[index, 2]=='oligoastrocytoma': 120 | label_His = 0 121 | elif self.LIST[index, 2] == 'astrocytoma': 122 | label_His = 1 123 | elif self.LIST[index, 2] == 'oligodendroglioma': 124 | label_His = 2 125 | elif self.LIST[index, 2] == 'glioblastoma': 126 | label_His = 3 127 | 128 | if self.LIST[index, 2]=='glioblastoma': 129 | label_His_2class = 1 130 | else: 131 | label_His_2class = 0 132 | 133 | if self.LIST[index, 3]=='G2': 134 | label_Grade=0 135 | elif 
self.LIST[index, 3] == 'G3': 136 | label_Grade = 1 137 | else: 138 | label_Grade=2 #### Useless 139 | 140 | 141 | if self.LIST[index, 4]=='WT': 142 | label_Diag = 0 143 | elif self.LIST[index, 5] == 'codel': 144 | label_Diag = 3 145 | else: 146 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4': 147 | label_Diag = 1 148 | else: 149 | label_Diag = 2 150 | 151 | 152 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class]) 153 | 154 | return label 155 | 156 | 157 | def shuffle_list(self, seed): 158 | np.random.seed(seed) 159 | random.seed(seed) 160 | np.random.shuffle(self.LIST) 161 | 162 | 163 | 164 | def __len__(self): 165 | return self.LIST.shape[0] 166 | 167 | class Our_Dataset_vis(Dataset): 168 | def __init__(self, phase,opt,if_end2end=False): 169 | super(Our_Dataset_vis, self).__init__() 170 | self.opt = opt 171 | self.patc_bs=64 172 | self.phase=phase 173 | self.if_end2end=if_end2end 174 | self.dataDir = (opt['dataDir']+'extract_224/') if opt['imgSize'][0]==224 else (opt['dataDir']+'extract_512/') 175 | 176 | excel_label_wsi = pd.read_excel(opt['label_path'],sheet_name='wsi_level',header=0) 177 | excel_wsi =excel_label_wsi.values 178 | PATIENT_LIST=excel_wsi[:,0] 179 | np.random.seed(self.opt['seed']) 180 | random.seed(self.opt['seed']) 181 | PATIENT_LIST=list(PATIENT_LIST) 182 | 183 | 184 | PATIENT_LIST=np.unique(PATIENT_LIST) 185 | np.random.shuffle(PATIENT_LIST) 186 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952 187 | TEST_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL)] 188 | TEST_WSI_LIST=os.listdir(r'/home/zeiler/WSI_proj/miccai/vis_results/set0/') 189 | self.TRAIN_LIST=[] 190 | self.VAL_LIST = [] 191 | self.TEST_LIST = [] 192 | for i in range(excel_wsi.shape[0]):# 2612 193 | if excel_wsi[:,1][i]+'.h5' in TEST_WSI_LIST: 194 | self.TEST_LIST.append(excel_wsi[i,:]) 195 | self.LIST= np.asarray(self.TEST_LIST) 196 | 197 | 198 | self.train_iter_count=0 199 | self.Flat=0 200 | 
self.WSI_all=[] 201 | 202 | def __getitem__(self, index): 203 | feature_all ,read_details,self.LIST[index, 1]= self.read_feature(index) 204 | 205 | label=self.label_gene(index) 206 | 207 | return torch.from_numpy(np.array(feature_all)).float(),torch.from_numpy(label),read_details,self.LIST[index, 1] 208 | 209 | def read_feature(self, index): 210 | read_details = np.load(self.opt['dataDir'] + 'read_details/' + self.LIST[index, 1] + '.npy', allow_pickle=True)[ 211 | 0] 212 | num_patches = read_details.shape[0] 213 | root = self.opt['dataDir']+'Res50_feature_'+str(self.opt['fixdim'])+'_fixdim0/' 214 | patch_all = h5py.File(root + self.LIST[index, 1] + '.h5')['Res_feature'][:] # (1,N,1024) 215 | return patch_all[0],read_details,self.LIST[index, 1] 216 | 217 | 218 | def label_gene(self,index): 219 | 220 | 221 | if self.LIST[index, 4]=='WT': 222 | label_IDH=0 223 | elif self.LIST[index, 4]=='Mutant': 224 | label_IDH=1 225 | if self.LIST[index, 5] == 'non-codel': 226 | label_1p19q = 0 227 | elif self.LIST[index, 5] == 'codel': 228 | label_1p19q = 1 229 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1: 230 | label_CDKN = 1 231 | else: 232 | label_CDKN = 0 233 | 234 | if self.LIST[index, 2]=='oligoastrocytoma': 235 | label_His = 0 236 | elif self.LIST[index, 2] == 'astrocytoma': 237 | label_His = 1 238 | elif self.LIST[index, 2] == 'oligodendroglioma': 239 | label_His = 2 240 | elif self.LIST[index, 2] == 'glioblastoma': 241 | label_His = 3 242 | 243 | if self.LIST[index, 2]=='glioblastoma': 244 | label_His_2class = 1 245 | else: 246 | label_His_2class = 0 247 | 248 | if self.LIST[index, 3]=='G2': 249 | label_Grade=0 250 | elif self.LIST[index, 3] == 'G3': 251 | label_Grade = 1 252 | else: 253 | label_Grade=2 #### Useless 254 | 255 | 256 | if self.LIST[index, 4]=='WT': 257 | label_Diag = 0 258 | elif self.LIST[index, 5] == 'codel': 259 | label_Diag = 3 260 | else: 261 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4': 262 | 
label_Diag = 1 263 | else: 264 | label_Diag = 2 265 | 266 | 267 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class]) 268 | 269 | return label 270 | 271 | 272 | def shuffle_list(self, seed): 273 | np.random.seed(seed) 274 | random.seed(seed) 275 | np.random.shuffle(self.LIST) 276 | 277 | 278 | 279 | def __len__(self): 280 | return self.LIST.shape[0] 281 | 282 | if __name__ == '__main__': 283 | # epoch_seed1 = np.arange(1000) 284 | # np.random.seed(100) 285 | # random.seed(100) 286 | # epoch_seed = np.arange(5) 287 | # np.random.shuffle(epoch_seed) 288 | # for i in range(5): 289 | # np.random.seed(epoch_seed[i]) 290 | # random.seed(epoch_seed[i]) 291 | # np.random.shuffle(epoch_seed1) 292 | # print(epoch_seed1[:20]) 293 | 294 | 295 | parser = argparse.ArgumentParser() 296 | parser.add_argument('--opt', type=str, default='config/mine.yml') 297 | args = parser.parse_args() 298 | with open(args.opt) as f: 299 | opt = yaml.load(f, Loader=SafeLoader) 300 | trainDataset = Our_Dataset(phase='Train', opt=opt) 301 | for i in range(100): 302 | _,x_,y_=trainDataset.__getitem__(index=2000-i) -------------------------------------------------------------------------------- /dataset_mine copy.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from matplotlib import pyplot as plt 4 | from torch.utils.data import DataLoader,Dataset 5 | import numpy as np 6 | import os 7 | from PIL import Image 8 | from skimage import io,transform 9 | import cv2 10 | import torch 11 | import platform 12 | import pandas as pd 13 | import argparse, time, random 14 | import yaml 15 | from yaml.loader import SafeLoader 16 | from tqdm import tqdm 17 | import h5py 18 | import gc 19 | import math 20 | import scipy.interpolate 21 | from PIL import Image 22 | import cv2 23 | from matplotlib import pyplot as plt 24 | from torchvision.transforms import Compose 25 | import
transform.transforms_group as our_transform 26 | 27 | class Our_Dataset(Dataset): 28 | def __init__(self, phase,opt,if_end2end=False): 29 | super(Our_Dataset, self).__init__() 30 | self.opt = opt 31 | self.patc_bs=64 32 | self.phase=phase 33 | self.if_end2end=if_end2end 34 | 35 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0) 36 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0) 37 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='wsi_level', header=0) 38 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True) 39 | excel_wsi = combined_labels.values 40 | 41 | PATIENT_LIST=excel_wsi[:,0] 42 | np.random.seed(self.opt['seed']) 43 | random.seed(self.opt['seed']) 44 | PATIENT_LIST=list(PATIENT_LIST) 45 | # IvYGAP_label 46 | IvYGAP_label = IvYGAP_label.values 47 | 48 | PATIENT_LIST=np.unique(PATIENT_LIST) 49 | np.random.shuffle(PATIENT_LIST) 50 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952 51 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)] 52 | # VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):] 53 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)] 54 | self.TRAIN_LIST=[] 55 | self.VAL_LIST = [] 56 | self.TEST_LIST = [] 57 | self.I_TEST_LIST = [] 58 | 59 | for i in range(excel_wsi.shape[0]):# 2612 60 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST: 61 | self.TRAIN_LIST.append(excel_wsi[i,:]) 62 | # elif excel_wsi[:,0][i] in VAL_PATIENT_LIST: 63 | # self.VAL_LIST.append(excel_wsi[i,:]) 64 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST: 65 | self.TEST_LIST.append(excel_wsi[i,:]) 66 | 67 | for i in range(IvYGAP_label.shape[0]):# 2612 68 | self.I_TEST_LIST.append(IvYGAP_label[i,:]) 69 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST) if self.phase == 'Test' else np.asarray(self.I_TEST_LIST) 70 | 71 | 72 | self.train_iter_count=0 
73 | self.Flat=0 74 | self.WSI_all=[] 75 | 76 | def __getitem__(self, index): 77 | feature_all_20,feature_all_10, = self.read_feature(index) 78 | 79 | label=self.label_gene(index) 80 | 81 | return torch.from_numpy(np.array(feature_all_20)).float(),torch.from_numpy(np.array(feature_all_10)).float(),\ 82 | torch.from_numpy(label) 83 | 84 | def read_feature(self, index): 85 | 86 | root = '/Res50_feature_2500_fixdim0_norm' 87 | 88 | patient_id = self.LIST[index, 0] 89 | 90 | 91 | if patient_id[0].startswith('T'): 92 | base_path = self.opt['dataDir'] + 'TCGA' 93 | elif patient_id[0].startswith('W'): 94 | base_path = self.opt['dataDir'] + 'IvYGAP' 95 | elif patient_id[0].startswith('C'): 96 | base_path = self.opt['dataDir'] + 'CPTAC' 97 | else: 98 | raise ValueError("Unknown data source") 99 | 100 | patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:] 101 | patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:] 102 | return patch_20[0], patch_10[0]#, patch_1_25[0] 103 | 104 | 105 | def label_gene(self,index): 106 | 107 | 108 | if self.LIST[index, 4]=='WT': 109 | label_IDH=0 110 | elif self.LIST[index, 4]=='Mutant': 111 | label_IDH=1 112 | if self.LIST[index, 5] == 'non-codel': 113 | label_1p19q = 0 114 | elif self.LIST[index, 5] == 'codel': 115 | label_1p19q = 1 116 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1: 117 | label_CDKN = 1 118 | else: 119 | label_CDKN = 0 120 | 121 | if self.LIST[index, 2]=='oligoastrocytoma': 122 | label_His = 0 123 | elif self.LIST[index, 2] == 'astrocytoma': 124 | label_His = 1 125 | elif self.LIST[index, 2] == 'oligodendroglioma': 126 | label_His = 2 127 | elif self.LIST[index, 2] == 'glioblastoma': 128 | label_His = 3 129 | 130 | if self.LIST[index, 2]=='glioblastoma': 131 | label_His_2class = 1 132 | else: 133 | label_His_2class = 0 134 | 135 | if self.LIST[index, 3]=='G2': 136 | label_Grade=0 137 | elif self.LIST[index, 3] == 'G3': 138 | 
label_Grade = 1 139 | else: 140 | label_Grade=2 #### Useless 141 | 142 | 143 | if self.LIST[index, 4]=='WT': 144 | label_Diag = 0 145 | elif self.LIST[index, 5] == 'codel': 146 | label_Diag = 3 147 | else: 148 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4': 149 | label_Diag = 1 150 | else: 151 | label_Diag = 2 152 | 153 | 154 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class]) 155 | 156 | return label 157 | 158 | 159 | def shuffle_list(self, seed): 160 | np.random.seed(seed) 161 | random.seed(seed) 162 | np.random.shuffle(self.LIST) 163 | 164 | 165 | 166 | def __len__(self): 167 | return self.LIST.shape[0] 168 | 169 | class Our_Dataset_vis(Dataset): 170 | def __init__(self, phase,opt,if_end2end=False): 171 | super(Our_Dataset_vis, self).__init__() 172 | self.opt = opt 173 | self.patc_bs=64 174 | self.phase=phase 175 | self.if_end2end=if_end2end 176 | self.dataDir = (opt['dataDir']+'extract_224/') if opt['imgSize'][0]==224 else (opt['dataDir']+'extract_512/') 177 | 178 | excel_label_wsi = pd.read_excel(opt['label_path'],sheet_name='wsi_level',header=0) 179 | excel_wsi =excel_label_wsi.values 180 | PATIENT_LIST=excel_wsi[:,0] 181 | np.random.seed(self.opt['seed']) 182 | random.seed(self.opt['seed']) 183 | PATIENT_LIST=list(PATIENT_LIST) 184 | 185 | 186 | PATIENT_LIST=np.unique(PATIENT_LIST) 187 | np.random.shuffle(PATIENT_LIST) 188 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952 189 | TEST_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL)] 190 | TEST_WSI_LIST=os.listdir(r'/home/zeiler/WSI_proj/miccai/vis_results/set0/') 191 | self.TRAIN_LIST=[] 192 | self.VAL_LIST = [] 193 | self.TEST_LIST = [] 194 | for i in range(excel_wsi.shape[0]):# 2612 195 | if excel_wsi[:,1][i]+'.h5' in TEST_WSI_LIST: 196 | self.TEST_LIST.append(excel_wsi[i,:]) 197 | self.LIST= np.asarray(self.TEST_LIST) 198 | 199 | 200 | self.train_iter_count=0 201 | self.Flat=0 202 | self.WSI_all=[] 203 | 204 | def 
__getitem__(self, index):
        feature_all, read_details, self.LIST[index, 1] = self.read_feature(index)

        label = self.label_gene(index)

        return torch.from_numpy(np.array(feature_all)).float(), torch.from_numpy(label), read_details, self.LIST[index, 1]

    def read_feature(self, index):
        read_details = np.load(self.opt['dataDir'] + 'read_details/' + self.LIST[index, 1] + '.npy', allow_pickle=True)[0]
        num_patches = read_details.shape[0]
        root = self.opt['dataDir'] + 'Res50_feature_' + str(self.opt['fixdim']) + '_fixdim0/'
        patch_all = h5py.File(root + self.LIST[index, 1] + '.h5')['Res_feature'][:]  # (1,N,1024)
        return patch_all[0], read_details, self.LIST[index, 1]


    def label_gene(self, index):

        if self.LIST[index, 4] == 'WT':
            label_IDH = 0
        elif self.LIST[index, 4] == 'Mutant':
            label_IDH = 1
        if self.LIST[index, 5] == 'non-codel':
            label_1p19q = 0
        elif self.LIST[index, 5] == 'codel':
            label_1p19q = 1
        if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
            label_CDKN = 1
        else:
            label_CDKN = 0

        if self.LIST[index, 2] == 'oligoastrocytoma':
            label_His = 0
        elif self.LIST[index, 2] == 'astrocytoma':
            label_His = 1
        elif self.LIST[index, 2] == 'oligodendroglioma':
            label_His = 2
        elif self.LIST[index, 2] == 'glioblastoma':
            label_His = 3

        if self.LIST[index, 2] == 'glioblastoma':
            label_His_2class = 1
        else:
            label_His_2class = 0

        if self.LIST[index, 3] == 'G2':
            label_Grade = 0
        elif self.LIST[index, 3] == 'G3':
            label_Grade = 1
        else:
            label_Grade = 2  #### Useless

        if self.LIST[index, 4] == 'WT':
            label_Diag = 0
        elif self.LIST[index, 5] == 'codel':
            label_Diag = 3
        else:
            if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] == 'G4':
                label_Diag = 1
            else:
                label_Diag = 2

        label = np.asarray([label_IDH, label_1p19q, label_CDKN, label_His, label_Grade, label_Diag, label_His_2class])

        return label


    def shuffle_list(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        np.random.shuffle(self.LIST)


    def __len__(self):
        return self.LIST.shape[0]

if __name__ == '__main__':
    # epoch_seed1 = np.arange(1000)
    # np.random.seed(100)
    # random.seed(100)
    # epoch_seed = np.arange(5)
    # np.random.shuffle(epoch_seed)
    # for i in range(5):
    #     np.random.seed(epoch_seed[i])
    #     random.seed(epoch_seed[i])
    #     np.random.shuffle(epoch_seed1)
    #     print(epoch_seed1[:20])


    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default='config/mine.yml')
    args = parser.parse_args()
    with open(args.opt) as f:
        opt = yaml.load(f, Loader=SafeLoader)
    trainDataset = Our_Dataset(phase='Train', opt=opt)
    for i in range(100):
        _, x_, y_ = trainDataset._getitem__(index=2000 - i)
--------------------------------------------------------------------------------
/dataset_mine.py:
--------------------------------------------------------------------------------
from __future__ import print_function

from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
from PIL import Image
from skimage import io, transform
import cv2
import torch
import platform
import pandas as pd
import argparse, time, random
import yaml
from yaml.loader import SafeLoader
from tqdm import tqdm
import h5py
import gc
import math
import scipy.interpolate
from PIL import Image
import cv2
from matplotlib import pyplot as plt
from torchvision.transforms import Compose
import transform.transforms_group as our_transform

class Our_Dataset(Dataset):
    def __init__(self, phase, opt, if_end2end=False):
        super(Our_Dataset, self).__init__()
        self.opt = opt
        self.patc_bs = 64
        self.phase = phase
        self.if_end2end = if_end2end

        CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0)
        IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0)
        TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
        combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True)
        excel_wsi = combined_labels.values

        PATIENT_LIST = excel_wsi[:, 0]
        np.random.seed(self.opt['seed'])
        random.seed(self.opt['seed'])
        PATIENT_LIST = list(PATIENT_LIST)
        # IvYGAP_label
        IvYGAP_label = IvYGAP_label.values

        PATIENT_LIST = np.unique(PATIENT_LIST)
        np.random.shuffle(PATIENT_LIST)
        NUM_PATIENT_ALL = len(PATIENT_LIST)  # 952
        TRAIN_PATIENT_LIST = PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)]
        VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):]
        TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)]
        self.TRAIN_LIST = []
        self.VAL_LIST = []
        self.TEST_LIST = []
        self.I_TEST_LIST = []
        for i in range(excel_wsi.shape[0]):  # 2612
            if excel_wsi[:, 0][i] in TRAIN_PATIENT_LIST:
                self.TRAIN_LIST.append(excel_wsi[i, :])
            # elif excel_wsi[:,0][i] in VAL_PATIENT_LIST:
            #     self.VAL_LIST.append(excel_wsi[i,:])
            elif excel_wsi[:, 0][i] in TEST_PATIENT_LIST:
                self.TEST_LIST.append(excel_wsi[i, :])

        for i in range(IvYGAP_label.shape[0]):  # 2612
            self.I_TEST_LIST.append(IvYGAP_label[i, :])
        self.LIST = np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST) if self.phase == 'Test' else np.asarray(self.I_TEST_LIST)

        self.train_iter_count = 0
        self.Flat = 0
        self.WSI_all = []

    def __getitem__(self, index):
        feature_all_20, feature_all_10 = self.read_feature(index)

        label = self.label_gene(index)

        return torch.from_numpy(np.array(feature_all_20)).float(), torch.from_numpy(np.array(feature_all_10)).float(), \
            torch.from_numpy(label)

    def read_feature(self, index):

        root = '/Res50_feature_2500_fixdim0_norm'

        patient_id = self.LIST[index, 0]

        if patient_id[0].startswith('T'):
            base_path = self.opt['dataDir'] + 'TCGA'
        elif patient_id[0].startswith('W'):
            base_path = self.opt['dataDir'] + 'IvYGAP'
        elif patient_id[0].startswith('C'):
            base_path = self.opt['dataDir'] + 'CPTAC'
        else:
            raise ValueError("Unknown data source")

        patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
        patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
        return patch_20[0], patch_10[0]  # , patch_1_25[0]


    def label_gene(self, index):

        if self.LIST[index, 4] == 'WT':
            label_IDH = 0
        elif self.LIST[index, 4] == 'Mutant':
            label_IDH = 1
        if self.LIST[index, 5] == 'non-codel':
            label_1p19q = 0
        elif self.LIST[index, 5] == 'codel':
            label_1p19q = 1
        if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
            label_CDKN = 1
        else:
            label_CDKN = 0

        if self.LIST[index, 2] == 'oligoastrocytoma':
            label_His = 0
        elif self.LIST[index, 2] == 'astrocytoma':
            label_His = 1
        elif self.LIST[index, 2] == 'oligodendroglioma':
            label_His = 2
        elif self.LIST[index, 2] == 'glioblastoma':
            label_His = 3

        if self.LIST[index, 2] == 'glioblastoma':
            label_His_2class = 1
        else:
            label_His_2class = 0

        if self.LIST[index, 3] == 'G2':
            label_Grade = 0
        elif self.LIST[index, 3] == 'G3':
            label_Grade = 1
        else:
            label_Grade = 2  #### Useless

        if self.LIST[index, 4] == 'WT':
            label_Diag = 0
        elif self.LIST[index, 5] == 'codel':
            label_Diag = 3
        else:
            if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] == 'G4':
                label_Diag = 1
            else:
                label_Diag = 2

        label = np.asarray([label_IDH, label_1p19q, label_CDKN, label_His, label_Grade, label_Diag, label_His_2class])

        return label


    def shuffle_list(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        np.random.shuffle(self.LIST)


    def __len__(self):
        return self.LIST.shape[0]

class Our_Dataset_vis(Dataset):
    def __init__(self, phase, opt, if_end2end=False):
        super(Our_Dataset_vis, self).__init__()
        self.opt = opt
        self.patc_bs = 64
        self.phase = phase
        self.if_end2end = if_end2end
        self.dataDir = (opt['dataDir'] + 'extract_224/') if opt['imgSize'][0] == 224 else (opt['dataDir'] + 'extract_512/')

        TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
        excel_wsi = TCGA_label.values

        PATIENT_LIST = excel_wsi[:, 0]
        np.random.seed(opt['seed'])
        random.seed(opt['seed'])
        PATIENT_LIST = list(PATIENT_LIST)

        PATIENT_LIST = np.unique(PATIENT_LIST)
        np.random.shuffle(PATIENT_LIST)
        TEST_LIST = []

        for i in range(excel_wsi.shape[0]):  # 2612
            if excel_wsi[:, 0][i] in PATIENT_LIST:
                TEST_LIST.append(excel_wsi[i, :])
        # __getitem__, read_feature and __len__ all index self.LIST
        self.LIST = np.asarray(TEST_LIST)

        self.train_iter_count = 0
        self.Flat = 0
        self.WSI_all = []

    def __getitem__(self, index):
        # read_feature returns (features_20x, features_10x, num_patches, slide_name)
        feature_all_20, feature_all_10, num_patches, slide_name = self.read_feature(index)

        label = self.label_gene(index)

        return torch.from_numpy(np.array(feature_all_20)).float(), torch.from_numpy(np.array(feature_all_10)).float(), \
            torch.from_numpy(label), num_patches, slide_name

    def read_feature(self, index):

        root = '/Res50_feature_2500_fixdim0_norm'

        patient_id = self.LIST[index, 0]

        if patient_id[0].startswith('T'):
            base_path = self.opt['dataDir'] + 'TCGA'
        elif patient_id[0].startswith('W'):
            base_path = self.opt['dataDir'] + 'IvYGAP'
        elif patient_id[0].startswith('C'):
            base_path = self.opt['dataDir'] + 'CPTAC'
        else:
            raise ValueError("Unknown data source")
        read_details = np.load(self.opt['dataDir'] + 'read_details/' + self.LIST[index, 1] + '.npy', allow_pickle=True)[0]
        num_patches = read_details.shape[0]
        patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
        patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
        return patch_20[0], patch_10[0], num_patches, self.LIST[index, 1]


    def label_gene(self, index):

        if self.LIST[index, 4] == 'WT':
            label_IDH = 0
        elif self.LIST[index, 4] == 'Mutant':
            label_IDH = 1
        if self.LIST[index, 5] == 'non-codel':
            label_1p19q = 0
        elif self.LIST[index, 5] == 'codel':
            label_1p19q = 1
        if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
            label_CDKN = 1
        else:
            label_CDKN = 0

        if self.LIST[index, 2] == 'oligoastrocytoma':
            label_His = 0
        elif self.LIST[index, 2] == 'astrocytoma':
            label_His = 1
        elif self.LIST[index, 2] == 'oligodendroglioma':
            label_His = 2
        elif self.LIST[index, 2] == 'glioblastoma':
            label_His = 3

        if self.LIST[index, 2] == 'glioblastoma':
            label_His_2class = 1
        else:
            label_His_2class = 0

        if self.LIST[index, 3] == 'G2':
            label_Grade = 0
        elif self.LIST[index, 3] == 'G3':
            label_Grade = 1
        else:
            label_Grade = 2  #### Useless

        if self.LIST[index, 4] == 'WT':
            label_Diag = 0
        elif self.LIST[index, 5] == 'codel':
            label_Diag = 3
        else:
            if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] == 'G4':
                label_Diag = 1
            else:
                label_Diag = 2

        label = np.asarray([label_IDH, label_1p19q, label_CDKN, label_His, label_Grade, label_Diag, label_His_2class])

        return label


    def shuffle_list(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        np.random.shuffle(self.LIST)


    def __len__(self):
        return self.LIST.shape[0]

if __name__ == '__main__':
    # epoch_seed1 = np.arange(1000)
    # np.random.seed(100)
    # random.seed(100)
    # epoch_seed = np.arange(5)
    # np.random.shuffle(epoch_seed)
    # for i in range(5):
    #     np.random.seed(epoch_seed[i])
    #     random.seed(epoch_seed[i])
    #     np.random.shuffle(epoch_seed1)
    #     print(epoch_seed1[:20])


    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default='config/mine.yml')
    args = parser.parse_args()
    with open(args.opt) as f:
        opt = yaml.load(f, Loader=SafeLoader)
    trainDataset = Our_Dataset(phase='Train', opt=opt)
    for i in range(100):
        _, x_, y_ = trainDataset.__getitem__(index=2000 - i)
--------------------------------------------------------------------------------
/debug.py:
--------------------------------------------------------------------------------
import os
import random
import numpy as np
import pandas as pd
from skimage import io
import math
####### find out the backward propagation difference of img and fea in TransMIL

# seed=100
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# np.random.seed(seed)
# random.seed(seed)
# parser = argparse.ArgumentParser()
# parser.add_argument('--opt', type=str, default='config/miccai.yml')
# args = parser.parse_args()
# with open(args.opt) as f:
#     opt = yaml.load(f, Loader=SafeLoader)
# gpuID = opt['gpus']
# TransMIL = model.TransMIL(opt)
# model.init_weights(TransMIL, init_type='xavier', init_gain=1)
# assert opt['name'].split('_')[0]=='TransMIL'
# device = torch.device('cuda:{}'.format(gpuID[0])) if gpuID else torch.device('cpu')
# if opt['name'].split('_')[2] == 'img':
#     Res_pretrain= net.Res50_pretrain()
#     Res_pretrain.to(device)
#
# TransMIL.to(device)
# TransMIL_opt = torch.optim.Adam(filter(lambda p: p.requires_grad, TransMIL.parameters()), 0.01, weight_decay=0.00001)
# trainDataset = dataset.Our_Dataset(phase='Train',opt=opt)
#
# img,label,img_path=trainDataset._getitem__(0)
#
# label = torch.from_numpy(np.asarray([label.detach().numpy()])).long()
#
# if torch.cuda.is_available():
#     img = img.cuda(gpuID[0])
#     label = label.cuda(gpuID[0])
# TransMIL.train()
# TransMIL.zero_grad()
# print(img_path)
# if opt['name'].split('_')[2] == 'img':
#     Res_pretrain.train()
#     img=Res_pretrain(img)
#
# results_dict = TransMIL(img)
# pred=results_dict['logits']
# loss_all = TransMIL.calculateLoss(pred,label)
# loss_all.backward()
# TransMIL_opt.step()

# coords_all = h5py.File('temp/TCGA-DU-6404-01Z-00-DX1.93c15688-f5b2-40bc-85eb-ff2661e16d4e.h5')['coords'][:]
"""
Use only one object per slide
"""
excel_label_wsi = pd.read_excel('./merge_who.xlsx', sheet_name='wsi_level', header=0)
excel_wsi = excel_label_wsi.values
WSI_used_names = list(excel_wsi[:, 1])
np.random.seed(1)
random.seed(1)
random.shuffle(WSI_used_names)

root_reading_list = r'/mnt/disk10T/fyb/wxf_data/TCGA/brain/reading_list_extract224/'
files_reading_list = os.listdir(root_reading_list)

for i in range(excel_wsi.shape[0]):

    # print(WSI_used_names[i])
    reading_list = np.load(root_reading_list + WSI_used_names[i] + '.npy')
    # reading_list = np.load(root_reading_list + 'TCGA-E1-A7Z4-01Z-00-DX2' + '.npy')
    Num_cluster = 0
    point_num = reading_list.shape[0]
    points_center_corrd = reading_list

    FLAT_w = int(points_center_corrd[0].split('_')[0])
    FLAT_h = int(points_center_corrd[0].split('_')[1])
    claster_corrds = {'0': []}
    for k in range(point_num):
        w_point = int(points_center_corrd[k].split('_')[0])
        h_point = int(points_center_corrd[k].split('_')[1])
        if w_point == FLAT_w:
            claster_corrds[str(Num_cluster)].append([w_point, h_point])
            continue
        # FLAT_w=w_point
        if w_point > FLAT_w and (w_point - FLAT_w) <= 6:
            FLAT_w = w_point
            claster_corrds[str(Num_cluster)].append([w_point, h_point])
            continue
        if w_point < FLAT_w or (w_point - FLAT_w) > 6:
            FLAT_w = w_point
            Num_cluster += 1
            claster_corrds[str(Num_cluster)] = []
            claster_corrds[str(Num_cluster)].append([w_point, h_point])
    del_value = []
    for key, value in enumerate(claster_corrds):
        claster_corrds[value] = np.asarray(claster_corrds[value])
        if claster_corrds[value].shape[0] < 150:
            del_value.append(value)
    for nn in range(len(del_value)):
        del claster_corrds[del_value[nn]]
    patch_length = []
    value_name = []
    for key, value in enumerate(claster_corrds):
        patch_length.append(claster_corrds[value].shape[0])
        value_name.append(value)
    save_dict = []
    sort_0 = np.argsort(np.asarray(patch_length))
    for nn in range(len(value_name)):
        save_dict.append(claster_corrds[value_name[sort_0[len(value_name) - nn - 1]]])

    np.save('/mnt/disk10T/fyb/wxf_data/TCGA/brain/read_details/' + WSI_used_names[i] + '.npy', save_dict)
    print(i)
a = 1
a = [[[1, 2], [11, 21]], [[12, 23], [14, 24]], [[15, 25], [16, 26]], [[17, 27], [18, 28]]]
a = np.asarray(a)
a = a.reshape(2, 2, 2, 2)
from PIL import Image
import cv2
from matplotlib import pyplot as plt
from torchvision.transforms import Compose
import transform.transforms_group as our_transform
def train_transform(degree=180):
    return Compose([
        our_transform.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05),
    ])
My_transform = train_transform()
root = r'D:\PhD\Project_WSI\data\a/'
files = os.listdir(root)
imgs = []
max_num = 1200
read_details = np.load(r'D:\PhD\Project_WSI\data\TCGA-HT-8104-01A-01-TS1.npy', allow_pickle=True)[0]
for i in range(len(files)):
    imgs.append(io.imread(root + '/' + str(read_details[i][0]) + '_' + str(read_details[i][1]) + '.jpg'))
imgs = np.asarray(imgs)  # (num_patches,224,224,3)

imgs = imgs.reshape(-1, 224, 3)  # (num_patches*224,224,3)
imgs = Image.fromarray(imgs.astype('uint8')).convert('RGB')
imgs = My_transform(imgs)
imgs = np.array(imgs)  # (num_patches*224,224,3)
imgs = imgs.reshape(-1, 224, 224, 3)  # (num_patches,224,224,3)

N_adj = int(math.sqrt(len(files)))
imgs = imgs[0:N_adj * N_adj]
imgs = imgs.reshape(N_adj, N_adj, 224, 224, 3)  # (Na,Na,224,224,3)
imgs = np.transpose(imgs, (0, 2, 1, 3, 4))  # (Na,224,Na,224,3)
imgs = imgs.reshape(N_adj * 224, N_adj * 224, 3)  # (Na*224,Na*224,3)

plt.imshow(imgs)
plt.show()
--------------------------------------------------------------------------------
/docs/1748968628246.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/docs/1748968628246.png
--------------------------------------------------------------------------------
/docs/framework图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/docs/framework图.png
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
import os
os.environ['QT_QPA_PLATFORM'] = 'offscreen'
import torch
import torchvision
# from tensorboardX import SummaryWriter
import numpy as np
from PIL import Image
#import cv2
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt


def plot_confusion_matrix(cm, savename, title='Confusion Matrix', classes=['G2_O', 'G3_O', 'G2_A', 'G3_A', 'G4_A', 'GBM']):
    plt.figure(figsize=(11, 11), dpi=100)
    np.set_printoptions(precision=2)  # number of decimal places to print; 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G4_A', 5:'GBM'

    # probability value shown in each cell of the confusion matrix
    # classes = ['P_MN', 'S_MN', 'P_IgAN', 'S_IgAN', 'LN', 'DN', 'ANCA', 'MPGN']
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    thresh = cm.max() / 2.

    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            # choose the text color from the same cell whose value is printed
            plt.text(x_val, y_val, "%0.2f" % (c,), color='white' if c > thresh else 'black',
                     fontsize=20, va='center', ha='center')

    # plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title, fontsize=36, pad=20)

    # plt.matshow(cm, cmap=plt.cm.Blues)  # background color

    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    # plt.xticks(xlocations, classes, rotation=90)
    # plt.yticks(xlocations, classes)
    plt.xticks(xlocations, classes, size=16)
    plt.yticks(xlocations, classes, size=16)
    plt.ylabel('Actual label', fontsize=22, labelpad=12)
    plt.xlabel('Predicted label', fontsize=22, labelpad=12)

    # offset the tick
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)

    # save the confusion matrix
    plt.savefig(savename, format='png')
    # plt.show()
    plt.close('all')
--------------------------------------------------------------------------------
/feature_generation.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
# import curve
import platform
import dataset
from torch.utils.data import DataLoader, Dataset
import numpy as np
import argparse, time, random
import yaml
from yaml.loader import SafeLoader
from evaluation import *
import glob
import model
import net
from saver import Saver
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from sklearn.metrics import roc_curve, auc
from tensorboardX import SummaryWriter
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
import gc
from sklearn import metrics
import h5py
from utils import *
import pandas as pd
from skimage import io

from apex import amp
"""
label 2016={ 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G2_OA', 5:'G3_OA', 6:'GBM'}
label 2021={ 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G4_A', 5:'GBM'}
"""


def train(opt):

    gpuID = opt['gpus']

    ############## Init #####################################

    device = torch.device('cuda:{}'.format(gpuID[0])) if gpuID else torch.device('cpu')

    Res_pretrain = net.Res50_pretrain().cuda(gpuID[0])
    # Res_pretrain.to(device)
    Res_pretrain = nn.DataParallel(Res_pretrain, device_ids=gpuID)
    Res_pretrain.eval()

    ############### Datasets #######################

    trainDataset = dataset.Our_Dataset(phase='Train', opt=opt)
    valDataset = dataset.Our_Dataset(phase='Val', opt=opt)
    testDataset = dataset.Our_Dataset(phase='Test', opt=opt)
    trainLoader = DataLoader(trainDataset, batch_size=opt['batchSize'],
                             num_workers=opt['nThreads'], shuffle=True)
    valLoader = DataLoader(valDataset, batch_size=opt['Val_batchSize'],
                           num_workers=opt['nThreads'], shuffle=True)
    testLoader = DataLoader(testDataset, batch_size=opt['Test_batchSize'],
                            num_workers=opt['nThreads'], shuffle=True)

    train_bar = tqdm(trainLoader)
    for packs in train_bar:
        img = packs[0][0]  # (N,3,224,224)
        imgPath = packs[2][0]
        patches_coor = packs[3][0]  # list N,2
        if torch.cuda.is_available():
            img = img.cuda(gpuID[0])
        N = img.detach().cpu().numpy().shape[0]

        feature = Res_pretrain(img)  # N 1024
        feature = torch.unsqueeze(feature, dim=0)
        feature_save = feature.detach().cpu().numpy()
        feature_save = np.float16(feature_save)
        print(feature_save.shape)
        if not os.path.exists(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/'):
            os.makedirs(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/')
        with h5py.File(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/' + imgPath + '.h5', 'w') as f:
            f['Res_feature'] = feature_save
            f['patches_coor'] = patches_coor

    train_bar = tqdm(valLoader)
    for packs in train_bar:
        img = packs[0][0]  # (N,3,224,224)
        imgPath = packs[2][0]
        patches_coor = packs[3][0]  # list N,2
        if torch.cuda.is_available():
            img = img.cuda(gpuID[0])
        N = img.detach().cpu().numpy().shape[0]

        feature = Res_pretrain(img)  # N 1024
        feature = torch.unsqueeze(feature, dim=0)
        feature_save = feature.detach().cpu().numpy()
        feature_save = np.float16(feature_save)
        print(feature_save.shape)
        if not os.path.exists(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/'):
            os.makedirs(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/')
        with h5py.File(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/' + imgPath + '.h5', 'w') as f:
            f['Res_feature'] = feature_save
            f['patches_coor'] = patches_coor
    #
    train_bar = tqdm(testLoader)
    for packs in train_bar:
        img = packs[0][0]  # (N,3,224,224)
        imgPath = packs[2][0]
        patches_coor = packs[3][0]  # list N,2
        if torch.cuda.is_available():
            img = img.cuda(gpuID[0])
        N = img.detach().cpu().numpy().shape[0]

        feature = Res_pretrain(img)  # N 1024
        feature = torch.unsqueeze(feature, dim=0)
        feature_save = feature.detach().cpu().numpy()
        feature_save = np.float16(feature_save)
        print(feature_save.shape)
        if not os.path.exists(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/'):
            os.makedirs(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/')
        with h5py.File(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/' + imgPath + '.h5', 'w') as f:
            f['Res_feature'] = feature_save
            f['patches_coor'] = patches_coor
        # features_WSI=[]
        # for i in range(N):
        #     feature=Res_pretrain(torch.unsqueeze(img[i],dim=0))
        #     feature=feature.detach().cpu().numpy()
        #     features_WSI.append(feature)
        # features_WSI=np.asarray(features_WSI) #(N,1024)
        # feature_save = np.expand_dims(features_WSI, axis=0)
        # feature_save = np.float16(feature_save)
        # print(feature_save.shape)
        # if not os.path.exists(opt['dataDir'] + 'Res50_feature_1200_iso_512/'):
        #     os.makedirs(opt['dataDir'] + 'Res50_feature_1200_iso_512/')
        # with h5py.File(opt['dataDir'] + 'Res50_feature_1200_iso_512/' + imgPath + '.h5', 'w') as f:
        #     f['Res_feature'] = feature_save
        #     f['patches_coor'] = patches_coor

    a = 1


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default='config/miccai.yml')
    args = parser.parse_args()
    with open(args.opt) as f:
        opt = yaml.load(f, Loader=SafeLoader)
    #
    # k1 = h5py.File(opt['dataDir'] + 'Res50_feature_1/TCGA-DH-5143-01Z-00-DX1.h5')['Res50_feature'][:][0,100,:]
    # k2 = h5py.File( './TCGA-DH-5143-01Z-00-DX1.h5')['Res50_feature'][:][0,100,:]

    # k3 = h5py.File(opt['dataDir'] + 'Res50_feature_1/TCGA-06-0876-01Z-00-DX1.h5')['Res50_feature'][:][0,0,100,:]
    # k4 = h5py.File(opt['dataDir'] + 'Res50_feature_2/TCGA-06-0876-01Z-00-DX1.h5')['Res50_feature'][:][0,100,:]
    #
    # k5 = h5py.File(opt['dataDir'] + 'Res50_feature_1/TCGA-CS-6290-01A-01-TS1.h5')['Res50_feature'][:][0,0,100,:]
    # k6 = h5py.File(opt['dataDir'] + 'Res50_feature_2/TCGA-CS-6290-01A-01-TS1.h5')['Res50_feature'][:][0,100,:]

    a = 1
    train(opt)

    a = 1
--------------------------------------------------------------------------------
/logs.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/logs.py
--------------------------------------------------------------------------------
/mainpre.py:
--------------------------------------------------------------------------------
# from apex import amp
from utils import *
import dataset_mine
from net import init_weights, get_scheduler, WarmupCosineSchedule
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    if seed == 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train(opt):
    opt['gpus'] = [5]
    gpuID = opt['gpus']
    opt['batchSize'] = 1

    ############### Mine_model #######################
    Mine_model_init, Mine_model_IDH, Mine_model_1p19q, Mine_model_CDKN, Mine_model_Graph, Mine_model_His, Mine_model_Cls, Mine_model_Task \
        , opt_init, opt_IDH, opt_1p19q, opt_CDKN, opt_Graph, opt_His, opt_Cls, opt_Task = get_model(opt)

    if opt['decayType'] == 'exp' or opt['decayType'] == 'step':
        Mine_model_sch_init = get_scheduler(opt_init, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_IDH = get_scheduler(opt_IDH, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_1p19q = get_scheduler(opt_1p19q, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_CDKN = get_scheduler(opt_CDKN, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_Graph = get_scheduler(opt_Graph, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_His = get_scheduler(opt_His, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_Cls = get_scheduler(opt_Cls, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
        Mine_model_sch_Task = get_scheduler(opt_Task, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
    elif opt['decayType'] == 'cos':
        Mine_model_sch_init = WarmupCosineSchedule(opt_init, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_IDH = WarmupCosineSchedule(opt_IDH, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_1p19q = WarmupCosineSchedule(opt_1p19q, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_CDKN = WarmupCosineSchedule(opt_CDKN, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_Graph = WarmupCosineSchedule(opt_Graph, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_His = WarmupCosineSchedule(opt_His, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_Cls = WarmupCosineSchedule(opt_Cls, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
        Mine_model_sch_Task = WarmupCosineSchedule(opt_Task, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])

    print('%d GPUs are working with the id of %s' % (torch.cuda.device_count(), str(gpuID)))


    ############### Datasets #######################
    root_init = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
    ckptdir_init = os.path.join(root_init)
    checkpoint_init = torch.load(ckptdir_init)
    root_IDH = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
    ckptdir_IDH = os.path.join(root_IDH)
    checkpoint_IDH = torch.load(ckptdir_IDH)
    root_1p19q = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
    ckptdir_1p19q = os.path.join(root_1p19q)
    checkpoint_1p19q = torch.load(ckptdir_1p19q)
    root_CDKN = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
    ckptdir_CDKN = os.path.join(root_CDKN)
    checkpoint_CDKN = torch.load(ckptdir_CDKN)
    root_Task = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
    ckptdir_Task = os.path.join(root_Task)
    checkpoint_Task = torch.load(ckptdir_Task)

    related_params = {k: v for k, v in checkpoint_init['init'].items()}
    Mine_model_init.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_IDH['IDH'].items()}
    Mine_model_IDH.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_1p19q['1p19q'].items()}
    Mine_model_1p19q.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_CDKN['CDKN'].items()}
    Mine_model_CDKN.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_IDH['Graph'].items()}
    Mine_model_Graph.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_IDH['His'].items()}
    Mine_model_His.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_IDH['Cls'].items()}
    Mine_model_Cls.load_state_dict(related_params)
    related_params = {k: v for k, v in checkpoint_Task['Task'].items()}
    Mine_model_Task.load_state_dict(related_params)

    Mine_model_init.eval()
    Mine_model_IDH.eval()
    Mine_model_1p19q.eval()
    Mine_model_CDKN.eval()
    Mine_model_Graph.eval()
    Mine_model_His.eval()
    Mine_model_Cls.eval()
    Mine_model_Task.eval()


    for i in range(1):
        testDataset = dataset_mine.Our_Dataset(phase='Test', opt=opt)
        testLoader = DataLoader(testDataset, batch_size=opt['Test_batchSize'],
                                num_workers=opt['nThreads'] if (sysstr == "Linux") else 1, shuffle=False)
        ItestDataset = dataset_mine.Our_Dataset(phase='ITest', opt=opt)
        ItestLoader = DataLoader(ItestDataset, batch_size=opt['Test_batchSize'],
                                 num_workers=opt['nThreads'] if (sysstr == "Linux") else 1, shuffle=False)

        last_ep = 0
        saver = Saver(opt)
        alleps = opt['n_ep'] - last_ep
        curep = 0
        print('-------------------------------------Val and Test--------------------------------------')
        if (curep + 1) > (2):
            save_dir = os.path.join(opt['modelDir'], 'Mine_model-%04d.pth' % (curep + 1))
            state = {
                'init': Mine_model_init.state_dict(),
                'IDH': Mine_model_IDH.state_dict(),
                '1p19q': Mine_model_1p19q.state_dict(),
                'CDKN': Mine_model_CDKN.state_dict(),
                'Graph': Mine_model_Graph.state_dict(),
                'His': Mine_model_His.state_dict(),
                'Cls': Mine_model_Cls.state_dict(),
                'Task': Mine_model_Task.state_dict(),
            }
            torch.save(state, save_dir)

        print("----------Test-------------")
        list_WSI_IDH, list_WSI_1p19q, list_WSI_CDKN, list_WSI_His_2class, list_WSI_Diag = \
            validation_All(opt, Mine_model_init, Mine_model_IDH, Mine_model_1p19q, Mine_model_CDKN, Mine_model_Graph, Mine_model_His, Mine_model_Cls, Mine_model_Task, testLoader, saver, curep + 1, opt['eva_cm'], gpuID, task='')
        print('test in epoch: %d/%d, acc_IDH:%.3f,acc_1p19q:%.3f,acc_CDKN:%.3f,acc_His_2class:%.3f, acc_Diag:%.3f' % (
            curep + 1, alleps, list_WSI_IDH[0], list_WSI_1p19q[0], list_WSI_CDKN[0], list_WSI_His_2class[0], list_WSI_Diag[0]))
        test_dict = {'test/acc_IDH': list_WSI_IDH[0], 'test/sen_IDH': list_WSI_IDH[3], 'test/spec_IDH': list_WSI_IDH[4],
                     'test/auc_IDH': list_WSI_IDH[5], 'test/f1_IDH': list_WSI_IDH[2], 'test/prec_IDH': list_WSI_IDH[6],
                     'test/acc_1p19q': list_WSI_1p19q[0], 'test/sen_1p19q': list_WSI_1p19q[3], 'test/spec_1p19q': list_WSI_1p19q[4],
                     'test/auc_1p19q': list_WSI_1p19q[5], 'test/f1_1p19q': list_WSI_1p19q[2], 'test/prec_1p19q': list_WSI_1p19q[6],
                     'test/acc_CDKN': list_WSI_CDKN[0], 'test/sen_CDKN': list_WSI_CDKN[3], 'test/spec_CDKN': list_WSI_CDKN[4],
                     'test/auc_CDKN': list_WSI_CDKN[5], 'test/f1_CDKN': list_WSI_CDKN[2], 'test/prec_CDKN': list_WSI_CDKN[6],
                     'test/acc_His_2class': list_WSI_His_2class[0], 'test/sen_His_2class': list_WSI_His_2class[3], 'test/spec_His_2class': list_WSI_His_2class[4],
                     'test/auc_His_2class': list_WSI_His_2class[5], 'test/f1_His_2class': list_WSI_His_2class[2], 'test/prec_His_2class': list_WSI_His_2class[6],
                     'test/acc_Diag': list_WSI_Diag[0], 'test/sen_Diag': list_WSI_Diag[3], 'test/spec_Diag': list_WSI_Diag[4],
                     'test/auc_Diag': list_WSI_Diag[5], 'test/f1_Diag': list_WSI_Diag[2], 'test/prec_Diag': list_WSI_Diag[6],

                     }
        test_dict_IDH = {'test/acc_IDH': list_WSI_IDH[0], 'test/sen_IDH': list_WSI_IDH[3], 'test/spec_IDH': list_WSI_IDH[4],
                         'test/auc_IDH': list_WSI_IDH[5], 'test/f1_IDH': list_WSI_IDH[2], 'test/prec_IDH': list_WSI_IDH[6], }
        test_dict_1p19q = {'test/acc_1p19q': list_WSI_1p19q[0], 'test/sen_1p19q': list_WSI_1p19q[3], 'test/spec_1p19q': list_WSI_1p19q[4],
                           'test/auc_1p19q': list_WSI_1p19q[5], 'test/f1_1p19q': list_WSI_1p19q[2], 'test/prec_1p19q': list_WSI_1p19q[6], }
        test_dict_CDKN = {'test/acc_CDKN': list_WSI_CDKN[0], 'test/sen_CDKN': list_WSI_CDKN[3], 'test/spec_CDKN': list_WSI_CDKN[4],
                          'test/auc_CDKN': list_WSI_CDKN[5], 'test/f1_CDKN': list_WSI_CDKN[2], 'test/prec_CDKN': list_WSI_CDKN[6], }
        test_dict_His_2class = {'test/acc_His_2class': list_WSI_His_2class[0], 'test/sen_His_2class': list_WSI_His_2class[3], 'test/spec_His_2class': list_WSI_His_2class[4],
                                'test/auc_His_2class': list_WSI_His_2class[5], 'test/f1_His_2class': list_WSI_His_2class[2], 'test/prec_His_2class': list_WSI_His_2class[6], }
        test_dict_Diag = {'test/acc_Diag': list_WSI_Diag[0], 'test/sen_Diag': list_WSI_Diag[3], 'test/spec_Diag': list_WSI_Diag[4],
145 | 'test/auc_Diag': list_WSI_Diag[5], 'test/f1_Diag': list_WSI_Diag[2],'test/prec_Diag': list_WSI_Diag[6], } 146 | saver.write_scalars(curep + 1, test_dict) 147 | saver.write_log(curep + 1, test_dict_IDH, 'test_IDH') 148 | saver.write_log(curep + 1, test_dict_1p19q, 'test_1p19q') 149 | saver.write_log(curep + 1, test_dict_CDKN, 'test_CDKN') 150 | saver.write_log(curep + 1, test_dict_His_2class, 'test_His_2class') 151 | saver.write_log(curep + 1, test_dict_Diag, 'test_Diag') 152 | 153 | # print("----------ITest-------------") 154 | # list_WSI_IDH,list_WSI_1p19q,list_WSI_CDKN,list_WSI_His_2class,list_WSI_Diag= \ 155 | # validation_All(opt, Mine_model_init, Mine_model_IDH, Mine_model_1p19q,Mine_model_CDKN,Mine_model_Graph,Mine_model_His,Mine_model_Cls,Mine_model_Task,ItestLoader, saver, curep + 1, opt['eva_cm'], gpuID, task = 'Itest') 156 | # print('Itest in epoch: %d/%d, acc_IDH:%.3f,acc_1p19q:%.3f,acc_CDKN:%.3f,acc_His_2class:%.3f, acc_Diag:%.3f' % ( 157 | # curep + 1, alleps, list_WSI_IDH[0], list_WSI_1p19q[0], list_WSI_CDKN[0], list_WSI_His_2class[0], list_WSI_Diag[0])) 158 | # Itest_dict = {'Itest/acc_IDH': list_WSI_IDH[0], 'Itest/sen_IDH': list_WSI_IDH[3], 'Itest/spec_IDH': list_WSI_IDH[4], 159 | # 'Itest/auc_IDH': list_WSI_IDH[5],'Itest/f1_IDH': list_WSI_IDH[2], 'Itest/prec_IDH': list_WSI_IDH[6], 160 | # 'Itest/acc_1p19q': list_WSI_1p19q[0], 'Itest/sen_1p19q': list_WSI_1p19q[3], 'Itest/spec_1p19q': list_WSI_1p19q[4], 161 | # 'Itest/auc_1p19q': list_WSI_1p19q[5], 'Itest/f1_1p19q': list_WSI_1p19q[2], 'Itest/prec_1p19q': list_WSI_1p19q[6], 162 | # 'Itest/acc_CDKN': list_WSI_CDKN[0], 'Itest/sen_CDKN': list_WSI_CDKN[3], 'Itest/spec_CDKN': list_WSI_CDKN[4], 163 | # 'Itest/auc_CDKN': list_WSI_CDKN[5], 'Itest/f1_CDKN': list_WSI_CDKN[2], 'Itest/prec_CDKN': list_WSI_CDKN[6], 164 | # 'Itest/acc_His_2class': list_WSI_His_2class[0], 'Itest/sen_His_2class': list_WSI_His_2class[3], 'Itest/spec_His_2class': list_WSI_His_2class[4], 165 | # 'Itest/auc_His_2class': 
list_WSI_His_2class[5], 'Itest/f1_His_2class': list_WSI_His_2class[2], 'Itest/prec_His_2class': list_WSI_His_2class[6], 166 | # 'Itest/acc_Diag': list_WSI_Diag[0], 'Itest/sen_Diag': list_WSI_Diag[3], 'Itest/spec_Diag': list_WSI_Diag[4], 167 | # 'Itest/auc_Diag': list_WSI_Diag[5], 'Itest/f1_Diag': list_WSI_Diag[2], 'Itest/prec_Diag': list_WSI_Diag[6], 168 | # } 169 | # Itest_dict_IDH = {'Itest/acc_IDH': list_WSI_IDH[0], 'Itest/sen_IDH': list_WSI_IDH[3], 'Itest/spec_IDH': list_WSI_IDH[4], 170 | # 'Itest/auc_IDH': list_WSI_IDH[5], 'Itest/f1_IDH': list_WSI_IDH[2], 'Itest/prec_IDH': list_WSI_IDH[6],} 171 | # Itest_dict_1p19q = {'Itest/acc_1p19q': list_WSI_1p19q[0], 'Itest/sen_1p19q': list_WSI_1p19q[3],'Itest/spec_1p19q': list_WSI_1p19q[4], 172 | # 'Itest/auc_1p19q': list_WSI_1p19q[5], 'Itest/f1_1p19q': list_WSI_1p19q[2],'Itest/prec_1p19q': list_WSI_1p19q[6],} 173 | # Itest_dict_CDKN = {'Itest/acc_CDKN': list_WSI_CDKN[0], 'Itest/sen_CDKN': list_WSI_CDKN[3],'Itest/spec_CDKN': list_WSI_CDKN[4], 174 | # 'Itest/auc_CDKN': list_WSI_CDKN[5], 'Itest/f1_CDKN': list_WSI_CDKN[2],'Itest/prec_CDKN': list_WSI_CDKN[6],} 175 | # Itest_dict_His_2class = {'Itest/acc_His_2class': list_WSI_His_2class[0], 'Itest/sen_His_2class': list_WSI_His_2class[3],'Itest/spec_His_2class': list_WSI_His_2class[4], 176 | # 'Itest/auc_His_2class': list_WSI_His_2class[5], 'Itest/f1_His_2class': list_WSI_His_2class[2],'Itest/prec_His_2class': list_WSI_His_2class[6],} 177 | # Itest_dict_Diag = {'Itest/acc_Diag': list_WSI_Diag[0], 'Itest/sen_Diag': list_WSI_Diag[3],'Itest/spec_Diag': list_WSI_Diag[4], 178 | # 'Itest/auc_Diag': list_WSI_Diag[5], 'Itest/f1_Diag': list_WSI_Diag[2],'Itest/prec_Diag': list_WSI_Diag[6], } 179 | # saver.write_scalars(curep + 1, Itest_dict) 180 | # saver.write_log(curep + 1, Itest_dict_IDH, 'Itest_IDH') 181 | # saver.write_log(curep + 1, Itest_dict_1p19q, 'Itest_1p19q') 182 | # saver.write_log(curep + 1, Itest_dict_CDKN, 'Itest_CDKN') 183 | # saver.write_log(curep + 1, 
Itest_dict_His_2class, 'Itest_His_2class') 184 | # saver.write_log(curep + 1, Itest_dict_Diag, 'Itest_Diag') 185 | 186 | def remove_all_file(path): 187 | if os.path.isdir(path): 188 | for i in os.listdir(path): 189 | path_file = os.path.join(path, i) 190 | os.remove(path_file) 191 | def remove_all_dir(path): 192 | if os.path.isdir(path): 193 | for i in os.listdir(path): 194 | path_file = os.path.join(path, i) 195 | for j in os.listdir(path_file): 196 | path_file1 = os.path.join(path_file, j) 197 | os.remove(path_file1) 198 | os.rmdir(path_file) 199 | 200 | 201 | 202 | 203 | 204 | if __name__ == '__main__': 205 | parser = argparse.ArgumentParser() 206 | parser.add_argument('--opt', type=str, default='config/mine.yml') 207 | args = parser.parse_args() 208 | with open(args.opt) as f: 209 | opt = yaml.load(f, Loader=SafeLoader) 210 | 211 | 212 | sysstr = platform.system() 213 | 214 | 215 | setup_seed(opt['seed']) 216 | if opt['command']=='Train': 217 | cur_time = time.strftime('%m%d-%H%M', time.localtime()) 218 | 219 | 220 | opt['name'] = 'Pretrain'+'_{}'.format(cur_time) 221 | opt['logDir'] = os.path.join(opt['logDir'], opt['name']) 222 | opt['modelDir'] = os.path.join(opt['modelDir'], opt['name']) 223 | opt['saveDir'] = os.path.join(opt['saveDir'], opt['name']) 224 | opt['cm_saveDir'] = os.path.join(opt['cm_saveDir'], opt['name']) 225 | if not os.path.exists(opt['logDir']): 226 | os.makedirs(opt['logDir']) 227 | if not os.path.exists(opt['modelDir']): 228 | os.makedirs(opt['modelDir']) 229 | if not os.path.exists(opt['saveDir']): 230 | os.makedirs(opt['saveDir']) 231 | if not os.path.exists(opt['cm_saveDir']): 232 | os.makedirs(opt['cm_saveDir']) 233 | 234 | para_log = os.path.join(opt['modelDir'], 'params.yml') 235 | if os.path.exists(para_log): 236 | os.remove(para_log) 237 | with open(para_log, 'w') as f: 238 | data = yaml.dump(opt, f, sort_keys=False, default_flow_style=False) 239 | 240 | print("\n\n============> begin training <=======") 241 | train(opt) 242 | 
243 | 244 | 245 | 246 | a=1 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | -------------------------------------------------------------------------------- /merge_who.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/merge_who.xlsx -------------------------------------------------------------------------------- /post_processing.py: -------------------------------------------------------------------------------- 1 | 2 | # root=r'D:\PhD\Project_WSI\Others_code\GBM_WSSM-master\OneDrive-2023-03-15/' 3 | # dirs=os.listdir(root) 4 | # save_root=r'D:\PhD\Project_WSI\Others_code\GBM_WSSM-master/GBM_TCGA_SEG/' 5 | # for i in range(len(dirs)): 6 | # files=os.listdir(root+dirs[i]+'/') 7 | # for j in range(len(files)): 8 | # shutil.copy(root+dirs[i]+'/'+files[j],save_root+files[j]) 9 | import os 10 | # from openslide import OpenSlide 11 | import h5py 12 | import numpy as np 13 | import cv2 14 | from skimage import io,transform 15 | root=r'D:\PhD\Project_WSI\Others_code\GBM_WSSM-master\GBM_Test_Images/' 16 | imgs=os.listdir(root) 17 | for i in range(len(imgs)): 18 | img_temp = io.imread(root+imgs[i]) 19 | img_temp = cv2.resize(img_temp, (1024, 1024)) 20 | io.imsave(root+imgs[i][0:-4]+'resize.jpg',img_temp) 21 | 22 | 23 | # def get_filename(root_path,file_path): 24 | # return_file = [] 25 | # files=os.listdir(root_path+file_path) 26 | # for i in range(len(files)): 27 | # get_path = os.path.join(root_path, file_path,files[i]) 28 | # if get_path.endswith('.svs') or get_path.endswith('.partial'): 29 | # return_file.append(get_path) 30 | # return return_file 31 | # 32 | # root=r'/mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ori_wsi/ffpe_GBM/' 33 | # slide_path = [] 34 | # WSI_path_list = os.listdir(root) 35 | # for i in 
range(len(WSI_path_list)): 36 | # get_file = get_filename(root, WSI_path_list[i]) 37 | # slide_path.append(get_file[0]) 38 | # 39 | # 40 | # for i in range(len(slide_path)): 41 | # wsi_obj = OpenSlide(slide_path[i]) 42 | # 43 | # pro = dict(wsi_obj.properties) 44 | # MPP=np.float(pro['aperio.MPP']) 45 | # wsi_w=wsi_obj.dimensions[0] 46 | # wsi_h=wsi_obj.dimensions[1] 47 | # 48 | # 49 | # 50 | # with h5py.File('vis_results/set0/'+slide_path[i].split('/')[-1].split('.')[0]+'.h5', 'w') as f: 51 | # f['wsi_w'] = wsi_w 52 | # f['wsi_h'] = wsi_h 53 | # f['MPP'] = MPP 54 | # print(i) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /roc_plot mu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | from sklearn.metrics import roc_curve, auc, roc_auc_score 5 | import torch.nn.functional as F 6 | from sklearn.metrics import roc_curve, auc, roc_auc_score 7 | import torch 8 | import ast 9 | import re 10 | import numpy as np 11 | import pandas as pd 12 | import re 13 | from scipy.special import softmax 14 | from sklearn.metrics import roc_curve, auc, roc_auc_score 15 | from scipy.special import expit # Sigmoid function 16 | def ROC(df): 17 | n_classes = 4 18 | fpr = dict() 19 | tpr = dict() 20 | roc_auc = dict() 21 | l = np.zeros((len(df), n_classes)) 22 | p = np.zeros((len(df), n_classes)) 23 | for i in range(len(df)): 24 | l[i, df.iloc[i, 0]] = 1 25 | label = df['label'].tolist() 26 | scores = df['score'].tolist() 27 | converted_data = [list(map(float, re.findall(r'-?\d+\.\d+', item))) for item in scores] 28 | score = np.array(converted_data) 29 | 30 | for i in range(n_classes): 31 | 32 | p[:, i] = score[:, i] 33 | fpr[i], tpr[i], _ = roc_curve(label, score[:, i], pos_label=i) 34 | roc_auc[i] = auc(fpr[i], tpr[i]) 35 | 36 | all_fpr = np.unique(np.concatenate([fpr[i] for i in 
range(n_classes)])) 37 | mean_tpr = np.zeros_like(all_fpr) 38 | for i in range(n_classes): 39 | mean_tpr += np.interp(all_fpr, fpr[i], tpr[i]) 40 | mean_tpr /= n_classes 41 | fpr["macro"] = all_fpr 42 | tpr["macro"] = mean_tpr 43 | macro_auc = roc_auc_score(l, p, average='macro') 44 | 45 | return fpr['macro'],tpr['macro'], macro_auc 46 | 47 | if __name__ == "__main__": 48 | name='IN_Diag' 49 | # print('\033[1;35;0mColored text, no background color \033[0m') 50 | np.random.seed(2) 51 | plt.rcParams["font.family"] = "ARIAL" 52 | plt.rcParams["font.weight"] = "bold" 53 | plt.rcParams["axes.labelweight"] = "bold" 54 | 55 | COLOR_LIST = [[139 / 255, 20 / 255, 8 / 255], 56 | [188 / 255, 189 / 255, 34 / 255], 57 | [52 / 255, 193 / 255, 52 / 255], 58 | [150 / 255, 150 / 255, 190 / 255], 59 | [139 / 255, 101 / 255, 8 / 255], 60 | [68 / 255, 114 / 255, 236 / 255], 61 | [100 / 255, 114 / 255, 196 / 255], 62 | [214 / 255 + 0.1, 39 / 255 + 0.2, 40 / 255 + 0.2], 63 | [52 / 255, 163 / 255, 152 / 255], 64 | [139 / 255 * 1.1, 20 / 255 * 1.1, 8 / 255 * 1.1], 65 | [188 / 255 * 0.9, 189 / 255 * 0.9, 34 / 255 * 0.9], 66 | [52 / 255 * 1.1, 193 / 255 * 1.1, 52 / 255 * 1.1], 67 | [150 / 255 * 0.9, 150 / 255 * 0.9, 190 / 255 * 0.9], 68 | [139 / 255 * 1.1, 101 / 255 * 1.1, 8 / 255 * 1.1]] 69 | 70 | LINE_WIDTH_LIST = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] 71 | 72 | i = 0 73 | plt.figure(figsize=[10.5, 10]) 74 | 75 | LABEL_LIST =['Ours', 'Wang et al.','ABMIL', 'TransMIL', 'CLAM', 76 | 'Charm', 'Deepglioma', 'MCAT', 'CMTA', 77 | 'AlexNet','DenseNet','InceptionNet','ResNet-50','VGG-18'] 78 | EXCEL_LIST = ['plot/Mine_'+name+'.xlsx', 'plot/MICCAI_'+name+'.xlsx','plot/ABMIL_'+name+'.xlsx','plot/TransMIL_'+name+'.xlsx',\ 79 | 'plot/CLAM_'+name+'.xlsx','plot/Charm_'+name+'.xlsx','plot/Deepglioma_'+name+'.xlsx','plot/MCAT_'+name+'.xlsx',\ 80 | 'plot/CMTA_'+name+'.xlsx','plot/AlexNet_'+name+'.xlsx','plot/DenseNet_'+name+'.xlsx','plot/InceptionNet_'+name+'.xlsx',\ 81 | 
'plot/ResNet-50_'+name+'.xlsx','plot/VGG-18_'+name+'.xlsx'] 82 | 83 | LABEL_LIST = ['Ours'] 84 | EXCEL_LIST = ['plot/Mine_'+name+'.xlsx'] 85 | 86 | 87 | fpr = dict() 88 | tpr = dict() 89 | roc_auc = dict() 90 | num = len(LABEL_LIST) 91 | 92 | 93 | for i in range(num): 94 | df = pd.read_excel(EXCEL_LIST[i]) 95 | label = df['label'].tolist() 96 | score = df['score'].tolist() 97 | 98 | fpr[i], tpr[i], _ = ROC(df) 99 | plt.plot(fpr[i], tpr[i], 100 | label=LABEL_LIST[i], # add the label argument (legend entry) 101 | linewidth= LINE_WIDTH_LIST[i] , color=np.array(COLOR_LIST[i])) 102 | plt.plot(1-df['spec'].tolist()[0], df['sen'].tolist()[0], marker="o", markersize=15, markerfacecolor=np.array(COLOR_LIST[i]), markeredgecolor=np.array(COLOR_LIST[i])) 103 | print(df['sen'].tolist()[0]) 104 | print(df['spec'].tolist()[0]) 105 | 106 | plt.plot([0, 1], [0, 1], 'k--', lw=2) 107 | plt.grid(color=[0.85, 0.85, 0.85]) 108 | 109 | plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold') 110 | plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold') 111 | 112 | font_axis_name = {'fontsize': 34, 'weight': 'bold'} 113 | plt.xlabel('1-Specificity', font_axis_name) 114 | plt.ylabel('Sensitivity', font_axis_name) 115 | plt.xlim((0, 1)) 116 | plt.ylim((0, 1)) 117 | plt.legend(framealpha=1, fontsize=30, loc='lower right') 118 | plt.tight_layout() 119 | 120 | plt.savefig("plot/"+name+".tiff") 121 | plt.show() 122 | -------------------------------------------------------------------------------- /roc_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import platform 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from matplotlib.patches import Polygon 9 | from itertools import cycle 10 | from sklearn.metrics import precision_score, recall_score, f1_score 11 | from sklearn import svm, datasets 12 | from sklearn.metrics import roc_curve, auc 13 | from
sklearn.model_selection import train_test_split 14 | from sklearn.preprocessing import label_binarize 15 | from sklearn.multiclass import OneVsRestClassifier 16 | from sklearn.metrics import roc_curve, auc 17 | import scipy 18 | import os 19 | from sklearn.metrics import roc_auc_score 20 | import pandas as pd 21 | 22 | plt.rcParams["font.family"] = "Times New Roman" 23 | plt.rcParams["font.weight"] = "bold" 24 | plt.rcParams["axes.labelweight"] = "bold" 25 | 26 | def ROC(df): 27 | n_classes = 9 28 | fpr = dict() 29 | tpr = dict() 30 | roc_auc = dict() 31 | l = np.zeros((len(df), n_classes)) 32 | p = np.zeros((len(df), n_classes)) 33 | for i in range(len(df)): 34 | l[i, df.iloc[i, 0]] = 1 35 | 36 | for i in range(n_classes): 37 | label = df['label'].tolist() 38 | score = df['score' + str(i)].tolist() 39 | p[:, i] = score 40 | fpr[i], tpr[i], _ = roc_curve(label, score, pos_label=i) 41 | roc_auc[i] = auc(fpr[i], tpr[i]) 42 | 43 | all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) 44 | mean_tpr = np.zeros_like(all_fpr) 45 | for i in range(n_classes): 46 | mean_tpr += np.interp(all_fpr, fpr[i], tpr[i]) 47 | mean_tpr /= n_classes 48 | fpr["macro"] = all_fpr 49 | tpr["macro"] = mean_tpr 50 | macro_auc = roc_auc_score(l, p, average='macro') 51 | 52 | return fpr['macro'],tpr['macro'], macro_auc 53 | 54 | 55 | if __name__ == "__main__": 56 | 57 | name='1p19q' 58 | print('\033[1;35;0mColored text, no background color \033[0m') 59 | np.random.seed(2) 60 | plt.rcParams["font.family"] = "Arial" 61 | plt.rcParams["font.weight"] = "bold" 62 | plt.rcParams["axes.labelweight"] = "bold" 63 | 64 | COLOR_LIST = [ [139 / 255, 20 / 255, 8 / 255],[188 / 255, 189 / 255, 34 / 255], 65 | [52 / 255, 193 / 255, 52 / 255], 66 | [150 / 255, 150 / 255, 190 / 255], [139 / 255, 101 / 255, 8 / 255], 67 | [68 / 255, 114 / 255, 236 / 255], 68 | [100 / 255, 114 / 255, 196 / 255], [214 / 255 + 0.1, 39 / 255 + 0.2, 40 / 255 + 0.2], 69 | [52 / 255, 163 / 255, 152 / 255]] 70 | 71 | LINE_WIDTH_LIST =
[3, 3, 3, 3, 3, 3, 3,3,3] 72 | 73 | 74 | i = 0 75 | plt.figure(figsize=[10.5, 10]) 76 | 77 | 78 | LABEL_LIST =['Ours', 'CLAM', 'TransMIL', 'ResNet-18', 79 | 'DenseNet-121', 'VGG-16', 'W/O Graph', 'W/O LC loss', 80 | 'W/O DCC'] 81 | EXCEL_LIST = ['plot/Mine_'+name+'.xlsx', 'plot/CLAM_'+name+'_fea.xlsx', 'plot/TransMIL_'+name+'_fea.xlsx', 82 | 'plot/Basic_'+name+'_img_res.xlsx', 'plot/Basic_'+name+'_img_dense.xlsx', 'plot/Basic_'+name+'_img_VGG.xlsx', 83 | 'plot/Mine_Graph_'+name+'.xlsx', 'plot/Mine_Graphloss_'+name+'.xlsx', 'plot/Mine_DCC_'+name+'.xlsx', ] 84 | 85 | fpr = dict() 86 | tpr = dict() 87 | roc_auc = dict() 88 | 89 | 90 | for i in range(len(EXCEL_LIST)): 91 | df = pd.read_excel(EXCEL_LIST[i]) 92 | label = df['label'].tolist() 93 | score = df['score'].tolist() 94 | fpr[i], tpr[i], _ = roc_curve(label, score) 95 | # roc_auc[i] = auc(fpr[i], tpr[i]) 96 | plt.plot(fpr[i], tpr[i], 97 | label=LABEL_LIST[i], 98 | linewidth= LINE_WIDTH_LIST[i] , color=np.array(COLOR_LIST[i])) 99 | plt.plot(1-df['spec'].tolist()[0], df['sen'].tolist()[0], marker="o", markersize=15, markerfacecolor=np.array(COLOR_LIST[i]), markeredgecolor=np.array(COLOR_LIST[i])) 100 | print(df['sen'].tolist()[0]) 101 | print(df['spec'].tolist()[0]) 102 | plt.plot([0, 1], [0, 1], 'k--', lw=2) 103 | plt.grid(color=[0.85, 0.85, 0.85]) 104 | 105 | plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold') 106 | plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold') 107 | 108 | font_axis_name = {'fontsize': 34, 'weight': 'bold'} 109 | plt.xlabel('1-Specificity',font_axis_name) 110 | plt.ylabel('Sensitivity',font_axis_name) 111 | plt.xlim((0, 0.5)) 112 | plt.ylim((0.5, 1)) 113 | plt.legend(framealpha=1, fontsize=30, loc='lower right') 114 | plt.tight_layout() 115 | 116 | plt.savefig("plot/"+name+".tiff") 117 | 118 | plt.show() 119 | -------------------------------------------------------------------------------- /roc_util/__init__.py:
-------------------------------------------------------------------------------- 1 | __version__ = "0.2.2" 2 | __author__ = "Norman Juchler" 3 | 4 | from ._roc import (get_objective, 5 | compute_roc, 6 | compute_mean_roc, 7 | compute_roc_bootstrap) 8 | from ._plot import (plot_roc, 9 | plot_mean_roc, 10 | plot_roc_simple, 11 | plot_roc_bootstrap) 12 | from ._demo import (demo_basic, 13 | demo_bootstrap, 14 | demo_sample_data) 15 | -------------------------------------------------------------------------------- /roc_util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/__pycache__/_demo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_demo.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/__pycache__/_plot.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_plot.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/__pycache__/_roc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_roc.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/__pycache__/_sampling.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_sampling.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/__pycache__/_stats.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_stats.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/__pycache__/_types.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_types.cpython-38.pyc -------------------------------------------------------------------------------- /roc_util/_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ._plot import plot_roc_bootstrap, plot_roc 3 | from ._roc import compute_roc 4 | 5 | 6 | def demo_sample_data(n1, mu1, std1, n2, mu2, std2, seed=42): 7 | """ 8 | Construct binary classification problem with n1 and n2 9 | samples per class, respectively. 10 | 11 | Returns two np.ndarrays x and y of length (n1+n2). 12 | x represents the predictor, y the binary response. 13 | """ 14 | rng = np.random.RandomState(seed) 15 | x1 = rng.normal(mu1, std1, n1) 16 | x2 = rng.normal(mu2, std2, n2) 17 | y1 = np.zeros(n1, dtype=bool) 18 | y2 = np.ones(n2, dtype=bool) 19 | x = np.concatenate([x1, x2]) 20 | y = np.concatenate([y1, y2]) 21 | return x, y 22 | 23 | 24 | def demo_basic(n_samples=600, seed=42): 25 | """ 26 | Demonstrate basic usage of compute_roc() and plot_roc(). 
27 | """ 28 | import matplotlib.pyplot as plt 29 | pos_label = True 30 | x, y = demo_sample_data(n1=n_samples//2, mu1=0.0, std1=0.5, 31 | n2=n_samples//2, mu2=1.0, std2=0.7, 32 | seed=seed) 33 | roc = compute_roc(X=x, y=y, pos_label=pos_label) 34 | plot_roc(roc, label="Dataset", color="red") 35 | plt.title("Basic demo") 36 | plt.show() 37 | 38 | 39 | def demo_bootstrap(n_samples=600, n_bootstrap=50, seed=42): 40 | """ 41 | Demonstrate a ROC analysis for a bootstrapped dataset. 42 | """ 43 | import matplotlib.pyplot as plt 44 | assert(n_samples > 2) 45 | pos_label = True 46 | x, y = demo_sample_data(n1=n_samples//2, mu1=0.0, std1=0.5, 47 | n2=n_samples//2, mu2=1.0, std2=0.7, 48 | seed=seed) 49 | plot_roc_bootstrap(X=x, y=y, pos_label=pos_label, 50 | n_bootstrap=n_bootstrap, 51 | random_state=seed+1, 52 | show_boots=False, 53 | title="Bootstrap demo") 54 | plt.show() 55 | -------------------------------------------------------------------------------- /roc_util/_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ._roc import (compute_roc, 3 | compute_mean_roc, 4 | compute_roc_bootstrap, 5 | _DEFAULT_OBJECTIVE) 6 | 7 | 8 | def plot_roc(roc, 9 | color="red", 10 | label=None, 11 | show_opt=False, 12 | show_details=False, 13 | format_axes=True, 14 | ax=None, 15 | **kwargs): 16 | """ 17 | Plot the ROC curve given the output of compute_roc. 18 | 19 | Arguments: 20 | roc: Output of compute_roc() with the following keys: 21 | - fpr: false positive rates fpr(thr) 22 | - tpr: true positive rates tpr(thr) 23 | - opd: optimal point(s). 24 | - inv: true if predictor is inverted (predicts ~y) 25 | label: Label used for legend. 26 | show_opt: Show optimal point. 27 | show_details: Show additional information. 28 | format_axes: Apply axes settings, show legend, etc. 29 | kwargs: A dictionary with detail settings not exposed 30 | explicitly in the function signature. 
The following 31 | options are available: 32 | - zorder: 33 | - legend_out: Place legend outside (default: False) 34 | - legend_label_inv: Use 1-AUC if roc.inv=True (True) 35 | Additional kwargs are forwarded to ax.plot(). 36 | """ 37 | import matplotlib.pyplot as plt 38 | import matplotlib.colors as mplc 39 | 40 | def _format_axes(loc, margin): 41 | ax.axis("square") 42 | ax.set_xlim([0 - margin, 1. + margin]) 43 | ax.set_ylim([0 - margin, 1. + margin]) 44 | ax.set_xlabel("FPR (false positive rate)") 45 | ax.set_ylabel("TPR (true positive rate)") 46 | ax.grid(True) 47 | if legend_out: 48 | ax.legend(loc=loc, 49 | bbox_to_anchor=(1.05, 1), 50 | borderaxespad=0.) 51 | else: 52 | ax.legend(loc=loc) 53 | 54 | def _plot_opt_point(key, opt, color, marker, zorder, label, ax): 55 | # Some objectives can be visualized. 56 | # Plot these optional things first. 57 | if key == "minopt": 58 | ax.plot([0, opt.opp[0]], [1, opt.opp[1]], ":ok", 59 | alpha=0.3, 60 | zorder=zorder + 1) 61 | if key == "minoptsym": 62 | d2_ul = (opt.opp[0]-0)**2+(opt.opp[1]-1)**2 63 | d2_ll = (opt.opp[0]-1)**2+(opt.opp[1]-0)**2 64 | ref = (0, 1) if (d2_ul < d2_ll) else (1, 0) 65 | ax.plot([ref[0], opt.opp[0]], [ref[1], opt.opp[1]], ":ok", 66 | alpha=0.3, 67 | zorder=zorder + 1) 68 | if key == "youden": 69 | # Vertical line between optimal point and diagonal. 70 | ax.plot([opt.opp[0], opt.opp[0]], 71 | [opt.opp[0], opt.opp[1]], 72 | color=color, 73 | zorder=zorder + 1) 74 | if key == "cost": 75 | # Line parallel to diagonal (shrunk by m if m≠1). 
76 | ax.plot(opt.opq[0], opt.opq[1], ":k", 77 | alpha=0.3, 78 | zorder=zorder + 1) 79 | if key == "concordance": 80 | # Rectangle illustrating the area tpr*(1-fpr) 81 | from matplotlib import patches 82 | ll = [opt.opp[0], 0] 83 | w = 1 - opt.opp[0] 84 | h = opt.opp[1] 85 | rect = patches.Rectangle(ll, w, h, 86 | facecolor=color, 87 | alpha=0.2, 88 | zorder=zorder + 1) 89 | ax.add_patch(rect) 90 | 91 | face_color = mplc.to_rgba(color, alpha=0.3) 92 | ax.plot(opt.opp[0], opt.opp[1], 93 | linestyle="None", 94 | marker=marker, 95 | markerfacecolor=face_color, 96 | markeredgecolor=color, 97 | label=label, 98 | zorder=zorder + 3) 99 | 100 | if ax is None: 101 | ax = plt.gca() 102 | 103 | # Copy the kwargs (a shallow copy should be sufficient). 104 | # Set some defaults. 105 | label = label if label else "Feature" 106 | zorder = kwargs.pop("zorder", 1) 107 | legend_out = kwargs.pop("legend_out", False) 108 | legend_label_inv = kwargs.pop("legend_label_inv", True) 109 | if legend_label_inv: 110 | auc_disp, auc_val = "1-AUC" if roc.inv else "AUC", roc.auc 111 | else: 112 | auc_disp, auc_val = "AUC", roc.auc 113 | label = "%s (%s=%.3f)" % (label, auc_disp, auc_val) 114 | 115 | # Plot the ROC curve. 116 | ax.plot( 117 | roc.fpr, 118 | roc.tpr, 119 | color=color, 120 | zorder=zorder + 2, 121 | # label=label, 122 | **kwargs) 123 | 124 | # Plot the no-discrimination line. 125 | label_diag = "No discrimination" if show_details else None 126 | ax.plot([0, 1], [0, 1], ":k", label=label_diag, 127 | zorder=zorder, linewidth=1) 128 | 129 | # Visualize the optimal point. 
130 | if show_opt: 131 | from itertools import cycle 132 | markers = cycle(["o", "*", "^", "s", "P", "D"]) 133 | for key, opt in roc.opd.items(): 134 | pa_str = (", PA=%.3f" % opt.opa) if opt.opa else "" 135 | if show_details: 136 | legend_entry_opt = ("Optimal point (%s, thr=%.3g%s)" 137 | % (key, opt.opt, pa_str)) 138 | else: 139 | legend_entry_opt = "Optimal point (thr=%.3g)" % opt.opt 140 | _plot_opt_point(key=key, opt=opt, color=color, 141 | marker=next(markers), zorder=zorder, 142 | label=legend_entry_opt, ax=ax) 143 | 144 | if format_axes: 145 | margin = 0.02 146 | loc = "upper left" if (roc.inv or legend_out) else "lower right" 147 | _format_axes(loc=loc, margin=margin) 148 | 149 | 150 | def plot_mean_roc(rocs, auto_flip=True, show_all=False, ax=None, **kwargs): 151 | """ 152 | Compute and plot the mean ROC curve for a sequence of ROC containers. 153 | 154 | rocs: List of ROC containers created by compute_roc(). 155 | auto_flip: See compute_roc(), applies only to mean ROC curve. 156 | show_all: If True, show the single ROC curves. 157 | If an integer, show the rocs[:show_all] roc curves. 158 | show_ci: Show confidence interval 159 | show_ti: Show tolerance interval 160 | kwargs: Forwarded to plot_roc(), applies only to mean ROC curve. 161 | 162 | Optional kwargs argument show_opt can be either False, True or a string 163 | specifying the particular objective function to be used to plot the 164 | optimal point. See get_objective() for details. Default choice is the 165 | "minopt" objective. 166 | """ 167 | import matplotlib.pyplot as plt 168 | if ax is None: 169 | ax = plt.gca() 170 | 171 | n_samples = len(rocs) 172 | 173 | # Some default values. 174 | zorder = kwargs.get("zorder", 1) 175 | label = kwargs.pop("label", "Mean ROC curve") 176 | # kwargs for plot_roc()... 
177 | show_details = kwargs.get("show_details", False) 178 | show_opt = kwargs.pop("show_opt", False) 179 | show_ti = kwargs.pop("show_ti", True) 180 | show_ci = kwargs.pop("show_ci", True) 181 | color = kwargs.pop("color", "red") 182 | is_opt_str = isinstance(show_opt, (str, list, tuple)) 183 | # Defaults for mean-ROC. 184 | resolution = kwargs.pop("resolution", 101) 185 | objective = show_opt if is_opt_str else _DEFAULT_OBJECTIVE 186 | 187 | # Compute average ROC. 188 | ret_mean = compute_mean_roc(rocs=rocs, 189 | resolution=resolution, 190 | auto_flip=auto_flip, 191 | objective=objective) 192 | 193 | # Plot ROC curve for single bootstrap samples. 194 | if show_all: 195 | def isint(x): 196 | return isinstance(x, int) and not isinstance(x, bool) 197 | n_loops = show_all if isint(show_all) else np.inf 198 | n_loops = min(n_loops, len(rocs)) 199 | for ret in rocs[:n_loops]: 200 | ax.plot(ret.fpr, ret.tpr, 201 | color="gray", 202 | alpha=0.2, 203 | zorder=zorder + 2) 204 | if show_ti: 205 | # 95% interval 206 | tpr_sort = np.sort(ret_mean.tpr_all, axis=0) 207 | tpr_lower = tpr_sort[int(0.025 * n_samples), :] 208 | tpr_upper = tpr_sort[int(0.975 * n_samples), :] 209 | label_int = "95% of all samples" if show_details else None 210 | ax.fill_between(ret_mean.fpr, tpr_lower, tpr_upper, 211 | color="gray", alpha=.2, 212 | label=label_int, 213 | zorder=zorder + 1) 214 | if show_ci: 215 | # 95% confidence interval 216 | tpr_std = np.std(ret_mean.tpr_all, axis=0, ddof=1) 217 | tpr_lower = ret_mean.tpr - 1.96 * tpr_std / np.sqrt(n_samples) 218 | tpr_upper = ret_mean.tpr + 1.96 * tpr_std / np.sqrt(n_samples) 219 | label_ci = "95% CI of mean curve" if show_details else None 220 | ax.fill_between(ret_mean.fpr, tpr_lower, tpr_upper, 221 | color=color, alpha=.3, 222 | label=label_ci, 223 | zorder=zorder) 224 | 225 | # Last but not least, plot the average ROC curve on top of everything. 
226 | # plot_roc(roc=ret_mean, label=label, show_opt=show_opt, 227 | # color=color, ax=ax, zorder=zorder + 3, **kwargs) 228 | return ret_mean 229 | 230 | 231 | def plot_roc_simple(X, y, pos_label, auto_flip=True, 232 | title=None, ax=None, **kwargs): 233 | """ 234 | Compute and plot the receiver operating characteristic curve for X and y. 235 | kwargs are forwarded to plot_roc(), see there for details. 236 | 237 | Optional kwargs argument show_opt can be either False, True or a string 238 | specifying the particular objective function to be used to plot the 239 | optimal point. See get_objective() for details. Default choice is the 240 | "minopt" objective. 241 | """ 242 | import matplotlib.pyplot as plt 243 | if ax is None: 244 | ax = plt.gca() 245 | show_opt = kwargs.pop("show_opt", False) 246 | is_opt_str = isinstance(show_opt, (str, list, tuple)) 247 | objective = show_opt if is_opt_str else _DEFAULT_OBJECTIVE 248 | ret = compute_roc(X=X, y=y, pos_label=pos_label, 249 | objective=objective, 250 | auto_flip=auto_flip) 251 | plot_roc(roc=ret, show_opt=show_opt, ax=ax, **kwargs) 252 | title = "ROC curve" if title is None else title 253 | ax.get_figure().suptitle(title) 254 | return ret 255 | 256 | 257 | def plot_roc_bootstrap(X, y, pos_label, 258 | objective=_DEFAULT_OBJECTIVE, 259 | auto_flip=True, 260 | n_bootstrap=100, 261 | random_state=None, 262 | stratified=False, 263 | show_boots=False, 264 | title=None, 265 | ax=None, 266 | **kwargs): 267 | """ 268 | Similar to plot_roc_simple(), but estimates an average ROC curve from 269 | multiple bootstrap samples. 270 | 271 | See compute_roc_bootstrap() for the meaning of the arguments. 272 | 273 | Optional kwargs argument show_opt can be either False, True or a string 274 | specifying the particular objective function to be used to plot the 275 | optimal point. See get_objective() for details. Default choice is the 276 | "minopt" objective.
277 | """ 278 | import matplotlib.pyplot as plt 279 | if ax is None: 280 | ax = plt.gca() 281 | 282 | # 1) Collect the data. 283 | rocs = compute_roc_bootstrap(X=X, y=y, 284 | pos_label=pos_label, 285 | objective=objective, 286 | auto_flip=auto_flip, 287 | n_bootstrap=n_bootstrap, 288 | random_state=random_state, 289 | stratified=stratified, 290 | return_mean=False) 291 | # 2) Plot the average ROC curve. 292 | ret_mean = plot_mean_roc(rocs=rocs, auto_flip=auto_flip, 293 | show_all=show_boots, ax=ax, **kwargs) 294 | 295 | title = "ROC curve" if title is None else title 296 | ax.get_figure().suptitle(title) 297 | ax.set_title("Bootstrap reps: %d, sample size: %d" % 298 | (n_bootstrap, len(y)), fontsize=10) 299 | return ret_mean 300 | -------------------------------------------------------------------------------- /roc_util/_sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def resample_data(*arrays, **kwargs): 5 | """ 6 | Similar to sklearn's resample function, with a few extras. 7 | 8 | arrays: Arrays with consistent first dimension. 9 | kwargs: 10 | replace: Sample with replacement. Default: True 11 | n_samples: Number of samples. Default: len(arrays[0]) 12 | frac: Compute the number of samples as a fraction of the 13 | array length: n_samples=frac*len(arrays[0]) 14 | Overrides the value for n_samples if provided. 15 | random_state: Determines the random number generation. Can be None, 16 | an int or np.random.RandomState. Default: None 17 | stratify: An iterable containing the class labels by which the 18 | arrays should be stratified. Default: None 19 | axis: Sampling axis. Note: axis!=0 is slow! Also, stratify 20 | is currently not supported if axis!=0. Default: axis=0 21 | squeeze: Flatten the output array if only one array is provided.
22 | Default: True 23 | """ 24 | def _resample(*arrays, replace, n_samples, stratify, rng, axis=0): 25 | lens = [x.shape[axis] for x in arrays] 26 | equal_length = (lens.count(lens[0]) == len(lens)) 27 | if not equal_length: 28 | msg = "Input arrays don't have equal length: %s" 29 | raise ValueError(msg % lens) 30 | if stratify is not None: 31 | msg = "Stratification is not supported yet." 32 | raise ValueError(msg) 33 | if not isinstance(rng, np.random.RandomState): 34 | rng = np.random.RandomState(rng) 35 | 36 | n = lens[0] 37 | if replace: 38 | indices = rng.randint(0, n, n_samples) 39 | else: 40 | indices = rng.choice(n, n_samples, replace=False) # each index drawn at most once 41 | # Sampling along an axis!=0 is not very clever. 42 | arrays = [x.take(indices, axis=axis) for x in arrays] 43 | # Flatten the output if only one input array was provided. 44 | return arrays if len(arrays) > 1 else arrays[0] 45 | try: 46 | from sklearn.utils import resample 47 | has_sklearn = True 48 | except ModuleNotFoundError: 49 | has_sklearn = False 50 | 51 | replace = kwargs.pop("replace", True) 52 | n_samples = kwargs.pop("n_samples", None) 53 | frac = kwargs.pop("frac", None) 54 | rng = kwargs.pop("random_state", None) 55 | stratify = kwargs.pop("stratify", None) 56 | squeeze = kwargs.pop("squeeze", True) 57 | axis = kwargs.pop("axis", 0) 58 | if kwargs: 59 | msg = "Received unexpected argument(s): %s" % kwargs 60 | raise ValueError(msg) 61 | 62 | arrays = [np.asarray(x) if not hasattr(x, "shape") else x for x in arrays] 63 | lens = [x.shape[axis] for x in arrays] 64 | if frac: 65 | n_samples = int(np.round(frac * lens[0])) 66 | if n_samples is None: 67 | n_samples = lens[0] 68 | if axis > 0 or not has_sklearn: 69 | ret = _resample(*arrays, replace=replace, n_samples=n_samples, 70 | stratify=stratify, rng=rng, axis=axis) 71 | else: 72 | ret = resample(*arrays, 73 | replace=replace, 74 | n_samples=n_samples, 75 | stratify=stratify, 76 | random_state=rng,) 77 | # Undo the squeezing, which is done by resample
(and _resample). 78 | if not squeeze and len(arrays) == 1: 79 | ret = [ret] 80 | return ret 81 | -------------------------------------------------------------------------------- /roc_util/_stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats as st 3 | 4 | 5 | def mean_intervals(data, confidence=0.95, axis=None): 6 | """ 7 | Compute the mean, the confidence interval of the mean, and the tolerance 8 | interval. Note that the confidence interval is often misinterpreted [3]. 9 | 10 | References: 11 | [1] https://en.wikipedia.org/wiki/Confidence_interval 12 | [2] https://en.wikipedia.org/wiki/Tolerance_interval 13 | [3] https://en.wikipedia.org/wiki/Confidence_interval#Meaning_and_interpretation 14 | """ 15 | confidence = confidence / 100.0 if confidence > 1.0 else confidence 16 | assert(0 < confidence < 1) 17 | a = 1.0 * np.array(data) 18 | n = a.size if axis is None else a.shape[axis] # sample count along the reduced axis 19 | # Both s=std() and se=sem() use unbiased estimators (ddof=1). 20 | m = np.mean(a, axis=axis) 21 | s = np.std(a, ddof=1, axis=axis) 22 | se = st.sem(a, axis=axis) 23 | t = st.t.ppf((1 + confidence) / 2., n - 1) 24 | ci = np.c_[m - se * t, m + se * t] 25 | ti = np.c_[m - s * t, m + s * t] 26 | assert(ci.shape[1] == 2 and ci.shape[0] == 27 | np.size(m, axis=None if axis is None else 0)) 28 | assert(ti.shape[1] == 2 and ti.shape[0] == 29 | np.size(m, axis=None if axis is None else 0)) 30 | return m, ci, ti 31 | 32 | 33 | def mean_confidence_interval(data, confidence=0.95, axis=None): 34 | """ 35 | Compute the mean and the confidence interval of the mean. 36 | """ 37 | m, ci, _ = mean_intervals(data, confidence, axis=axis) 38 | return m, ci 39 | 40 | 41 | def mean_tolerance_interval(data, confidence=0.95, axis=None): 42 | """ 43 | Compute the tolerance interval for the data.
44 | """ 45 | m, _, ti = mean_intervals(data, confidence, axis=axis) 46 | return m, ti 47 | -------------------------------------------------------------------------------- /roc_util/_types.py: -------------------------------------------------------------------------------- 1 | class StructContainer(): 2 | """ 3 | Build a type that behaves similar to a struct. 4 | 5 | Usage: 6 | # Construction from named arguments. 7 | settings = StructContainer(option1 = False, 8 | option2 = True) 9 | # Construction from dictionary. 10 | settings = StructContainer({"option1": False, 11 | "option2": True}) 12 | print(settings.option1) 13 | settings.option2 = False 14 | for k,v in settings.items(): 15 | print(k,v) 16 | """ 17 | 18 | def __init__(self, dictionary=None, **kwargs): 19 | if dictionary is not None: 20 | assert(isinstance(dictionary, (dict, StructContainer))) 21 | self.__dict__.update(dictionary) 22 | self.__dict__.update(kwargs) 23 | 24 | def __iter__(self): 25 | for i in self.__dict__: 26 | yield i 27 | 28 | def __getitem__(self, key): 29 | return self.__dict__[key] 30 | 31 | def __setitem__(self, key, value): 32 | self.__dict__[key] = value 33 | 34 | def __len__(self): 35 | return sum(1 for k in self.keys()) 36 | 37 | def __repr__(self): 38 | return "struct(**%s)" % str(self.__dict__) 39 | 40 | def __str__(self): 41 | return str(self.__dict__) 42 | 43 | def items(self): 44 | for k, v in self.__dict__.items(): 45 | if not k.startswith("_"): 46 | yield (k, v) 47 | 48 | def keys(self): 49 | for k in self.__dict__: 50 | if not k.startswith("_"): 51 | yield k 52 | 53 | def values(self): 54 | for k, v in self.__dict__.items(): 55 | if not k.startswith("_"): 56 | yield v 57 | 58 | def update(self, data): 59 | self.__dict__.update(data) 60 | 61 | def asdict(self): 62 | return dict(self.items()) 63 | 64 | def first(self): 65 | # Assumption: __dict__ is ordered (python>=3.6). 
66 | key, value = next(self.items()) 67 | return key, value 68 | 69 | def last(self): 70 | # Assumption: __dict__ is ordered (python>=3.6). 71 | # See also: https://stackoverflow.com/questions/58413076 72 | key = list(self.keys())[-1] 73 | return key, self[key] 74 | 75 | def get(self, key, default=None): 76 | return self.__dict__.get(key, default) 77 | 78 | def setdefault(self, key, default=None): 79 | return self.__dict__.setdefault(key, default) 80 | -------------------------------------------------------------------------------- /saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | from tensorboardX import SummaryWriter 5 | import numpy as np 6 | from PIL import Image 7 | from evaluation import * 8 | import cv2 9 | 10 | 11 | 12 | class Saver(): 13 | def __init__(self, opt): 14 | self.logDir = opt['logDir'] 15 | self.n_ep_save = opt['n_ep_save'] 16 | self.writer = SummaryWriter(logdir=self.logDir) 17 | 18 | def write_scalars(self, ep, lossdict): 19 | # Todo Save images 20 | for loss_key, loss_value in lossdict.items(): 21 | self.writer.add_scalar(loss_key, loss_value, ep) 22 | 23 | 24 | def write_maps(self, ep, map_dict): 25 | for name,map in map_dict.items(): 26 | if len(map.shape)==2: 27 | map = map[np.newaxis,...] 
28 | if map.shape[0] == 1: 29 | map = np.concatenate((map, map, map), axis=0) 30 | self.writer.add_image('map/'+name, map, ep) 31 | 32 | 33 | def write_log(self, ep, lossdict, Name): 34 | logpath = os.path.join(self.logDir, Name + '.log') 35 | title = 'epochs,' 36 | vals = '%d,'%(ep) 37 | for loss_key, loss_value in lossdict.items(): 38 | title = title + loss_key + ',' 39 | vals = vals + '%4f,'% (loss_value) 40 | title = title[:-1] + '\n' 41 | vals = vals[:-1] + '\n' 42 | if ep==self.n_ep_save-1: 43 | saveFile = open(logpath, "w") 44 | saveFile.write(title) 45 | saveFile.write(vals) 46 | else: 47 | saveFile = open(logpath, "a") 48 | saveFile.write(vals) 49 | saveFile.close() 50 | 51 | 52 | 53 | 54 | def write_imagegroup(self, ep, images, basename, key): 55 | # images: tensor Bx3xHxW or Bx1xHxW or BxHxW 56 | if len(images.shape) == 3: 57 | images = torch.unsqueeze(images, 1) 58 | images = torch.cat([images, images, images], 1) 59 | elif images.shape[1] == 1: 60 | images = torch.cat([images, images, images], 1) 61 | image_dis = torchvision.utils.make_grid(images, nrow=7) 62 | self.writer.add_image('map/' + key, image_dis, ep) 63 | image_dis2 = image_dis 64 | 65 | ndarr = image_dis2.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 66 | savename = os.path.join(self.logDir, key + '_' + basename + '_' +str(ep) + '.png') 67 | if key == 'SAmap': 68 | ndarr = cv2.applyColorMap(ndarr, cv2.COLORMAP_JET) 69 | cv2.imwrite(savename, ndarr) 70 | 71 | return ndarr 72 | 73 | def write_cm_maps(self, ep, cm, class_list, savename='cm.png'): 74 | savename = os.path.join(self.logDir, savename) 75 | plot_confusion_matrix(cm, savename, title='Confusion Matrix', 76 | classes=class_list) 77 | cmimg = cv2.imread(savename) 78 | cmimg = np.transpose(cmimg, (2, 0, 1)) 79 | self.writer.add_image('map/cm', cmimg, ep) 80 | 81 | -------------------------------------------------------------------------------- /test_main_new.py: 
-------------------------------------------------------------------------------- 1 | # from apex import amp 2 | import numpy as np 3 | 4 | from utils import * 5 | import dataset_mine 6 | from net import init_weights,get_scheduler,WarmupCosineSchedule 7 | from matplotlib import pyplot as plt 8 | import matplotlib as mpl 9 | import cmaps 10 | import h5py 11 | def train(opt): 12 | root=r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth' 13 | 14 | opt['gpus']=[6,7] 15 | gpuID = opt['gpus'] 16 | valDataset = dataset_mine.Our_Dataset_vis(phase='Test', opt=opt) 17 | valLoader = DataLoader(valDataset, batch_size=opt['Val_batchSize'], 18 | num_workers=opt['nThreads'] if (sysstr == "Linux") else 1, shuffle=False) 19 | 20 | ############## initialize ####################### 21 | 22 | last_ep = 0 23 | total_it = 0 24 | saver = Saver(opt) 25 | print('%d epochs and %d iterations have been trained' % (last_ep, total_it)) 26 | alleps = opt['n_ep'] - last_ep 27 | curep=0 28 | if 1: 29 | ## IDH 30 | model_init = Mine_init(opt).cuda(gpuID[0]) 31 | model_IDH = Mine_IDH(opt).cuda(gpuID[0]) 32 | model_1p19q = Mine_1p19q(opt).cuda(gpuID[0]) 33 | model_CDKN = Mine_CDKN(opt).cuda(gpuID[0]) 34 | model_Graph = Label_correlation_Graph(opt).cuda(gpuID[0]) 35 | model_His = Mine_His(opt).cuda(gpuID[0]) 36 | model_Cls = Cls_His_Grade(opt).cuda(gpuID[0]) 37 | model_Task = Mine_Task(opt).cuda(gpuID[0]) 38 | 39 | model_init = torch.nn.DataParallel(model_init, device_ids=gpuID) 40 | model_IDH = torch.nn.DataParallel(model_IDH, device_ids=gpuID) 41 | model_1p19q = torch.nn.DataParallel(model_1p19q, device_ids=gpuID) 42 | model_CDKN = torch.nn.DataParallel(model_CDKN, device_ids=gpuID) 43 | model_Graph = torch.nn.DataParallel(model_Graph, device_ids=gpuID) 44 | model_His = torch.nn.DataParallel(model_His, device_ids=gpuID) 45 | model_Cls = torch.nn.DataParallel(model_Cls, device_ids=gpuID) 46 | model_Task = torch.nn.DataParallel(model_Task, device_ids=gpuID) 47 | 48 | ckptdir =
os.path.join(root) 49 | checkpoint = torch.load(ckptdir) 50 | related_params = {k: v for k, v in checkpoint['init'].items()} 51 | model_init.load_state_dict(related_params) 52 | related_params = {k: v for k, v in checkpoint['IDH'].items()} 53 | model_IDH.load_state_dict(related_params) 54 | related_params = {k: v for k, v in checkpoint['1p19q'].items()} 55 | model_1p19q.load_state_dict(related_params) 56 | related_params = {k: v for k, v in checkpoint['CDKN'].items()} 57 | model_CDKN.load_state_dict(related_params) 58 | related_params = {k: v for k, v in checkpoint['Graph'].items()} 59 | model_Graph.load_state_dict(related_params) 60 | related_params = {k: v for k, v in checkpoint['His'].items()} 61 | model_His.load_state_dict(related_params) 62 | related_params = {k: v for k, v in checkpoint['Cls'].items()} 63 | model_Cls.load_state_dict(related_params,strict=False) 64 | related_params = {k: v for k, v in checkpoint['Task'].items()} 65 | model_Task.load_state_dict(related_params,strict=False) 66 | 67 | model_init.eval() 68 | model_IDH.eval() 69 | model_1p19q.eval() 70 | model_CDKN.eval() 71 | model_Graph.eval() 72 | model_His.eval() 73 | model_Cls.eval() 74 | model_Task.eval() 75 | model = [model_init, model_IDH, model_1p19q, model_CDKN, model_Graph, model_His, model_Cls,model_Task] 76 | 77 | 78 | 79 | print("----------Val-------------") 80 | # validation_test_vis(opt, model,valLoader, gpuID) 81 | vis_reconstruct(opt, model,valLoader, gpuID) 82 | 83 | 84 | 85 | def vis_reconstruct(opt, model,dataloader, gpuID): 86 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0) 87 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0) 88 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0) 89 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True) 90 | excel_wsi = combined_labels.values 91 | 92 | PATIENT_LIST=excel_wsi[:,0] 93 | np.random.seed(opt['seed']) 94 | 
random.seed(opt['seed']) 95 | PATIENT_LIST=list(PATIENT_LIST) 96 | # IvYGAP_label 97 | IvYGAP_label = IvYGAP_label.values 98 | 99 | PATIENT_LIST=np.unique(PATIENT_LIST) 100 | np.random.shuffle(PATIENT_LIST) 101 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952 102 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)] 103 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):] 104 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)] 105 | TEST_LIST = [] 106 | I_TEST_LIST = [] 107 | 108 | for i in range(excel_wsi.shape[0]):# 2612 109 | 110 | if excel_wsi[:,0][i] in PATIENT_LIST: 111 | TEST_LIST.append(excel_wsi[i,:]) 112 | TEST_LIST = np.asarray(TEST_LIST) 113 | 114 | for i in range(TEST_LIST.shape[0]): 115 | root = '/Res50_feature_2500_fixdim0_norm' 116 | 117 | patient_id = TEST_LIST[i, 0] 118 | 119 | 120 | if patient_id[0].startswith('T'): 121 | base_path = opt['dataDir'] + 'TCGA' 122 | elif patient_id[0].startswith('W'): 123 | base_path = opt['dataDir'] + 'IvYGAP' 124 | elif patient_id[0].startswith('C'): 125 | base_path = opt['dataDir'] + 'CPTAC' 126 | else: 127 | raise ValueError("Unknown data source") 128 | 129 | patch_20 = h5py.File(base_path + root + '_20x/' + TEST_LIST[i, 1] + '.h5')['Res_feature'][:] 130 | patch_10 = h5py.File(base_path + root + '/' + TEST_LIST[i, 1] + '.h5')['Res_feature'][:] 131 | img20 = torch.from_numpy(np.array(patch_20[0])).float() 132 | img10 = torch.from_numpy(np.array(patch_10[0])).float() 133 | 134 | read_details = np.load(opt['dataDir'] + 'read_details/' + TEST_LIST[i, 1] + '.npy', allow_pickle=True)[0] 135 | WSI_name=TEST_LIST[i, 1] 136 | 137 | 138 | if torch.cuda.is_available(): 139 | img20 = img20.cuda(gpuID[0]) 140 | img10 = img10.cuda(gpuID[0]) 141 | # label = label.cuda(gpuID[0]) 142 | # # # # IDH 143 | 144 | init_feature_his, init_feature_mark, _, _, _, _ = model[0](img20,img10) # (BS,2500,1024) 145 | hidden_states, encoded_IDH = model[1](init_feature_mark) 146 | 
hidden_states, encoded_1p19q = model[2](hidden_states) 147 | encoded_CDKN = model[3](hidden_states) 148 | 149 | results_dict,weight_IDH_wt,weight_IDH_mut,weight_1p19q_codel,weight_CDKN_HOMDEL,encoded_IDH0,encoded_1p19q0,encoded_CDKN0,Mark_output = \ 150 | model[4](encoded_IDH, encoded_1p19q, encoded_CDKN) 151 | pred_IDH = results_dict['logits_IDH'] 152 | pred_1p19q = results_dict['logits_1p19q'] 153 | pred_CDKN = results_dict['logits_CDKN'] 154 | 155 | hidden_states, encoded_His = model[5](init_feature_his) 156 | results_dict, weight_His_GBM, weight_His_GBM_Cls2,weight_His_O,His_output = model[6](encoded_His) 157 | pred_His_2class = results_dict['logits_His_2class'] 158 | pred_His = results_dict['logits_His'] 159 | 160 | 161 | # wsi_w = h5py.File('vis_results/set0/' + WSI_name + '.h5')['wsi_w'][()] 162 | # wsi_h = h5py.File('vis_results/set0/' + WSI_name + '.h5')['wsi_h'][()] 163 | # MPP = h5py.File('vis_results/set0/' + WSI_name + '.h5')['MPP'][()] 164 | 165 | weight_IDH_wt = norm(np.array(weight_IDH_wt.tolist()),read_details.shape[0]) 166 | weight_His_GBM = norm(np.array(weight_His_GBM.tolist()),read_details.shape[0]) 167 | weight_IDH_mut = norm(np.array(weight_IDH_mut.tolist()),read_details.shape[0]) 168 | weight_1p19q_codel = norm(np.array(weight_1p19q_codel.tolist()),read_details.shape[0]) 169 | weight_CDKN_HOMDEL = norm(np.array(weight_CDKN_HOMDEL.tolist()),read_details.shape[0]) 170 | 171 | 172 | 173 | ################################################################ 174 | 175 | 176 | 177 | # 178 | 179 | # relative_MPP = MPP / 0.5 180 | # PATCH_SIZE_revise = np.int(512 / relative_MPP) 181 | 182 | # wsi_w = np.int(wsi_w * (224 / PATCH_SIZE_revise)) + 1 183 | # wsi_h = np.int(wsi_h * (224 / PATCH_SIZE_revise)) + 1 184 | wsi_h=224 185 | wsi_w=224 186 | wsi_reconstruct = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255 187 | wsi_reconstruct_nmp = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255 188 | wsi_reconstruct_mut = np.ones(shape=(wsi_h, wsi_w, 
3), dtype=np.uint8) * 255 189 | wsi_reconstruct_pq = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255 190 | wsi_reconstruct_cdkn = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255 191 | 192 | for j in range(len(weight_IDH_wt)): 193 | width_index = int(read_details[j][0]) 194 | height_index = int(read_details[j][1]) 195 | wsi_reconstruct[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224, 196 | :] = weight_IDH_wt[j] 197 | wsi_reconstruct_nmp[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224, 198 | :] = weight_His_GBM[j] 199 | wsi_reconstruct_mut[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224, 200 | :] = weight_IDH_mut[j] 201 | wsi_reconstruct_pq[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224, 202 | :] = weight_1p19q_codel[j] 203 | wsi_reconstruct_cdkn[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224, 204 | :] = weight_CDKN_HOMDEL[j] 205 | 206 | # 207 | 208 | wsi_reconstruct = Image.fromarray(wsi_reconstruct) 209 | wsi_reconstruct = wsi_reconstruct.resize((2000, int(wsi_h / wsi_w * 2000))) 210 | wsi_reconstruct.save('vis_results/set2/' + WSI_name + '_vis_IDHwt.jpg') 211 | 212 | wsi_reconstruct_nmp = Image.fromarray(wsi_reconstruct_nmp) 213 | wsi_reconstruct_nmp = wsi_reconstruct_nmp.resize((2000, int(wsi_h / wsi_w * 2000))) 214 | wsi_reconstruct_nmp.save('vis_results/set2/' + WSI_name + '_vis_nmp.jpg') 215 | 216 | wsi_reconstruct_mut = Image.fromarray(wsi_reconstruct_mut) 217 | wsi_reconstruct_mut = wsi_reconstruct_mut.resize((2000, int(wsi_h / wsi_w * 2000))) 218 | wsi_reconstruct_mut.save('vis_results/set2/' + WSI_name + '_vis_IDHmut.jpg') 219 | 220 | wsi_reconstruct_pq = Image.fromarray(wsi_reconstruct_pq) 221 | wsi_reconstruct_pq = wsi_reconstruct_pq.resize((2000, int(wsi_h / wsi_w * 2000))) 222 | wsi_reconstruct_pq.save('vis_results/set2/' + WSI_name +
'_vis_pq.jpg') 223 | 224 | wsi_reconstruct_cdkn = Image.fromarray(wsi_reconstruct_cdkn) 225 | wsi_reconstruct_cdkn = wsi_reconstruct_cdkn.resize((2000, int(wsi_h / wsi_w * 2000))) 226 | wsi_reconstruct_cdkn.save('vis_results/set2/' + WSI_name + '_vis_cdkn.jpg') 227 | 228 | 229 | 230 | print(i) 231 | 232 | 233 | 234 | 235 | 236 | def validation_test_vis(opt,model, dataloader,gpuID): 237 | 238 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0) 239 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0) 240 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0) 241 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True) 242 | excel_label_wsi = combined_labels.values 243 | 244 | excel_wsi = list(excel_label_wsi) 245 | excel_wsi_new=[] 246 | test_bar = tqdm(dataloader) 247 | bs = 1 248 | count = 0 249 | for packs in test_bar: 250 | img20,img10, label= packs[0] 251 | 252 | 253 | 254 | if torch.cuda.is_available(): 255 | img20 = img20.cuda(gpuID[0]) 256 | img10 = img10.cuda(gpuID[0]) 257 | # label = label.cuda(gpuID[0]) 258 | # # # # IDH 259 | 260 | init_feature_his, init_feature_mark, _, _, _, _ = model[0](img20,img10) # (BS,2500,1024) 261 | hidden_states, encoded_IDH = model[1](init_feature_mark) 262 | hidden_states, encoded_1p19q = model[2](hidden_states) 263 | encoded_CDKN = model[3](hidden_states) 264 | 265 | results_dict,weight_IDH_wt,weight_IDH_mut,weight_1p19q_codel,weight_CDKN_HOMDEL,encoded_IDH0,encoded_1p19q0,encoded_CDKN0,Mark_output = \ 266 | model[4](encoded_IDH, encoded_1p19q, encoded_CDKN) 267 | pred_IDH = results_dict['logits_IDH'] 268 | pred_1p19q = results_dict['logits_1p19q'] 269 | pred_CDKN = results_dict['logits_CDKN'] 270 | 271 | hidden_states, encoded_His = model[5](init_feature_his) 272 | results_dict, weight_His_GBM, weight_His_GBM_Cls2,weight_His_O,His_output = model[6](encoded_His) 273 | pred_His_2class =
results_dict['logits_His_2class'] 274 | pred_His = results_dict['logits_His'] 275 | his_mark = model[7](His_output.float(), Mark_output.float()) 276 | weight_IDH_wt=weight_IDH_wt.tolist() 277 | 278 | _, pred_His = torch.max(pred_His.data, 1) 279 | pred_His = pred_His.tolist() 280 | _, pred_IDH = torch.max(pred_IDH.data, 1) 281 | pred_IDH = pred_IDH.tolist() 282 | _, pred_1p19q = torch.max(pred_1p19q.data, 1) 283 | pred_1p19q = pred_1p19q.tolist() 284 | _, pred_CDKN = torch.max(pred_CDKN.data, 1) 285 | pred_CDKN = pred_CDKN.tolist() 286 | _, pred_His_2class = torch.max(pred_His_2class.data, 1) 287 | pred_His_2class = pred_His_2class.tolist() 288 | _, pred_Task = torch.max(his_mark.data, 1) 289 | pred_Task = pred_Task.tolist() 290 | 291 | 292 | if pred_His[0]==3 and pred_IDH[0]==0 and pred_1p19q[0]==0 and (pred_CDKN[0]==label[:, 2].tolist()[0]): 293 | excel_wsi_new.append(excel_wsi[count]) 294 | count += 1 295 | excel_wsi_new=np.array(excel_wsi_new) 296 | df = pd.DataFrame(excel_wsi_new, columns=list(combined_labels.columns)) 297 | df.to_excel("vis/Test_TPall.xlsx",index=False) 298 | 299 | 300 | def norm(weight,num_patch): 301 | N_biorepet = int(2500 / num_patch) 302 | weight_0=weight[0:num_patch] 303 | weight_color=[] 304 | if N_biorepet>1: 305 | for j in range(N_biorepet-1): 306 | weight_0+=weight[(j+1)*num_patch:(j+2)*num_patch] 307 | 308 | min_w=np.min(weight_0) 309 | max_w = np.max(weight_0) 310 | weight_0=(weight_0-min_w)/(max_w-min_w) 311 | 312 | cmap = cmaps.MPL_viridis # NCL colormap provided by the cmaps package 313 | newcolors = cmap(np.linspace(0, 1, 256))*255 314 | newcolors = np.trunc(newcolors) 315 | newcolors = newcolors.astype(int) 316 | ref_array=np.zeros(shape=[256]) 317 | 318 | for k in range(256): 319 | ref_array[k]=k/256 320 | 321 | for k in range(weight_0.shape[0]): 322 | delta=np.abs(ref_array-weight_0[k]*weight_0[k]) 323 | delta=list(delta) 324 | min_del=min(delta) 325 | min_del_index=delta.index(min_del) 326 | weight_color.append(newcolors[min_del_index][0:3]) 327 | 328 | 329
| 330 | return weight_color 331 | 332 | 333 | 334 | def setup_seed(seed): 335 | torch.manual_seed(seed) 336 | torch.cuda.manual_seed(seed) 337 | torch.cuda.manual_seed_all(seed) 338 | np.random.seed(seed) 339 | random.seed(seed) 340 | if seed == 0: 341 | torch.backends.cudnn.deterministic = True 342 | torch.backends.cudnn.benchmark = False 343 | 344 | if __name__ == '__main__': 345 | parser = argparse.ArgumentParser() 346 | parser.add_argument('--opt', type=str, default='config/mine.yml') 347 | args = parser.parse_args() 348 | with open(args.opt) as f: 349 | opt = yaml.load(f, Loader=SafeLoader) 350 | 351 | setup_seed(opt['seed']) 352 | sysstr = platform.system() 353 | opt['logDir'] = os.path.join(opt['logDir'], 'Mine') 354 | if not os.path.exists(opt['logDir']): 355 | os.makedirs(opt['logDir']) 356 | train(opt) 357 | 358 | 359 | 360 | 361 | a=1 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | -------------------------------------------------------------------------------- /transform/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # ************************************** 4 | # @Author : Qiqi Xiao 5 | # @Email : xiaoqiqi177@gmail.com 6 | # @File : __init__.py 7 | # ************************************** 8 | from .transforms_group import * 9 | -------------------------------------------------------------------------------- /transform/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/transform/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /transform/__pycache__/functional.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/transform/__pycache__/functional.cpython-38.pyc -------------------------------------------------------------------------------- /transform/__pycache__/transforms_group.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/transform/__pycache__/transforms_group.cpython-38.pyc --------------------------------------------------------------------------------
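In `plot_mean_roc()` (`roc_util/_plot.py`), the `show_ci` and `show_ti` options draw two different bands around the mean bootstrap ROC curve. `show_ci` draws a normal-approximation 95% confidence interval of the mean curve (1.96 * std(ddof=1) / sqrt(n)), and `show_ti` draws an empirical band containing roughly 95% of the individual bootstrap curves, read from the sorted TPR values at ranks `int(0.025 * n)` and `int(0.975 * n)`. The sketch below reproduces both computations in isolation with NumPy. It assumes the bootstrap TPRs are already evaluated on a shared FPR grid, which `compute_mean_roc()` appears to provide via `tpr_all`. The helper name `roc_bands` and the synthetic curves are hypothetical and used for illustration only.

```python
import numpy as np


def roc_bands(tpr_all):
    """Summarize bootstrap ROC curves sampled on a common FPR grid.

    Hypothetical helper, not part of roc_util.
    tpr_all: array of shape (n_boot, n_grid); row i is the TPR of bootstrap
             curve i evaluated on the shared FPR grid.
    Returns (mean, (ci_lo, ci_hi), (band_lo, band_hi)).
    """
    tpr_all = np.asarray(tpr_all, dtype=float)
    n = tpr_all.shape[0]
    mean = tpr_all.mean(axis=0)

    # 95% CI of the mean curve (normal approximation, as with show_ci).
    # Its width scales like 1/sqrt(n).
    sem = tpr_all.std(axis=0, ddof=1) / np.sqrt(n)
    ci = (mean - 1.96 * sem, mean + 1.96 * sem)

    # Band holding about 95% of the individual curves (as with show_ti):
    # order statistics of the sorted TPR values at each grid point.
    tpr_sort = np.sort(tpr_all, axis=0)
    band = (tpr_sort[int(0.025 * n), :], tpr_sort[int(0.975 * n), :])
    return mean, ci, band


# Synthetic example: concave curves tpr = fpr**a with a random exponent a.
fpr = np.linspace(0.0, 1.0, 101)
rng = np.random.default_rng(0)
a = rng.uniform(0.3, 0.6, size=200)
tpr_all = fpr[None, :] ** a[:, None]

mean, (ci_lo, ci_hi), (band_lo, band_hi) = roc_bands(tpr_all)
i = 50  # grid point at fpr = 0.5
print("mean=%.3f  CI width=%.4f  band width=%.4f"
      % (mean[i], ci_hi[i] - ci_lo[i], band_hi[i] - band_lo[i]))
```

With 200 bootstrap curves the confidence band of the mean is much narrower than the band of individual curves. The former keeps shrinking as the number of bootstrap repetitions grows, while the latter converges to fixed quantiles, so the two shaded regions answer different questions and should not be read interchangeably.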