├── CPTAC.xlsx
├── IvYGAP.xlsx
├── README.md
├── ROC1.zip
├── ROC2.zip
├── TCGA.xlsx
├── __pycache__
│   ├── dataset.cpython-38.pyc
│   ├── dataset_mine.cpython-38.pyc
│   ├── evaluation.cpython-38.pyc
│   ├── model.cpython-38.pyc
│   ├── net.cpython-38.pyc
│   ├── saver.cpython-38.pyc
│   └── utils.cpython-38.pyc
├── basic_net
│   ├── CDKN_main.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── alexnet.cpython-38.pyc
│   │   ├── densenet.cpython-38.pyc
│   │   ├── inception.cpython-38.pyc
│   │   ├── mnasnet.cpython-38.pyc
│   │   ├── nystrom_atten.cpython-38.pyc
│   │   └── resnet.cpython-38.pyc
│   ├── alexnet.py
│   ├── densenet.py
│   ├── inception.py
│   ├── mnasnet.py
│   ├── nystrom_atten.py
│   └── resnet.py
├── config
│   ├── CDKN.yml
│   ├── cifar_pretrain.yml
│   ├── miccai.yml
│   └── mine.yml
├── data_process.py
├── dataset.py
├── dataset_mine copy 2.py
├── dataset_mine copy.py
├── dataset_mine.py
├── debug.py
├── docs
│   ├── 1748968628246.png
│   └── framework图.png
├── evaluation.py
├── feature_generation.py
├── logs.py
├── main _noGrad_guide.py
├── main miccai.py
├── main.py
├── main_DCC-IDH.py
├── main_dis_loss1.py
├── main_dis_loss2.py
├── main_dis_noloss.py
├── main_noDCC.py
├── main_noGrad.py
├── main_noGrad_longth copy.py
├── main_noGrad_longth.py
├── main_noLCloss.py
├── main_no_both.py
├── main_no_his.py
├── main_no_maker.py
├── main_nograph pre.py
├── main_nograph.py
├── mainpre.py
├── merge_who.xlsx
├── model copy.py
├── model.py
├── net.py
├── post_processing.py
├── roc_plot mu.py
├── roc_plot.py
├── roc_util
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── _demo.cpython-38.pyc
│   │   ├── _plot.cpython-38.pyc
│   │   ├── _roc.cpython-38.pyc
│   │   ├── _sampling.cpython-38.pyc
│   │   ├── _stats.cpython-38.pyc
│   │   └── _types.cpython-38.pyc
│   ├── _demo.py
│   ├── _plot.py
│   ├── _roc.py
│   ├── _sampling.py
│   ├── _stats.py
│   └── _types.py
├── saver.py
├── test_forroc.py
├── test_forroc_new.py
├── test_main.py
├── test_main_new.py
├── transform
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── functional.cpython-38.pyc
│   │   └── transforms_group.cpython-38.pyc
│   ├── functional.py
│   └── transforms_group.py
├── utils copy.py
├── utils.py
└── utils_finetune.py
/CPTAC.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/CPTAC.xlsx
--------------------------------------------------------------------------------
/IvYGAP.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/IvYGAP.xlsx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
M3C2
4 |
5 |
6 | Joint Modelling Histology and Molecular Markers for Glioma Classification
7 |
8 |
9 | Xiaofei Wang (a,1),
10 | Hanyu Liu (b,1),
11 | Yupei Zhang (a),
12 | Boyang Zhao (b),
13 | Hao Duan (c),
14 | Wanming Hu (d),
15 | Yonggao Mou (c),
16 | Stephen Price (a),
17 | Chao Li (a,b,e,f)
18 | (a) Department of Clinical Neurosciences, University of Cambridge, UK
19 | (b) School of Science and Engineering, University of Dundee, UK
20 | (c) Department of Neurosurgery, State Key Laboratory of Oncology in South China, Guangdong Provincial Clinical Research Center for Cancer, Sun Yat-sen University Cancer Center, China
22 | (d) Department of Pathology, State Key Laboratory of Oncology in South China, Guangdong Provincial Clinical Research Center for Cancer, Sun Yat-sen University Cancer Center, China
24 | (e) Department of Applied Mathematics and Theoretical Physics, University of Cambridge, UK
25 | (f) School of Medicine, University of Dundee, UK
26 |
27 |
28 |
29 |
30 |
34 |
35 |
36 | ## 📣 Latest Updates
37 |
38 | - **[2025-03-25]** 📊 *The M3C2 code has been released!*
39 | - **[2025-02-04]** 📝 *The M3C2 paper is now available in [Medical Image Analysis](https://www.sciencedirect.com/science/article/pii/S1361841525000532).*
40 | - **[2024-12-04]** 🎉 *M3C2 has been accepted to Medical Image Analysis!*
41 | - **[2024-08-02]** 📝 *M3C2 has been submitted to Medical Image Analysis.*
42 |
43 | ## Key Takeaways
44 |
45 | - **M3C2** presents a framework for cancer classification that integrates histology and molecular markers.
46 | 🧠 **Key Innovation**: The framework employs **multi-scale disentangling modules** to extract both **high-magnification cellular-level** and **low-magnification tissue-level** features, which are then used to predict histology and molecular markers simultaneously.
47 |
48 | - The method introduces a **Co-occurrence Probability-based Label-Correlation Graph (CPLC-Graph)** to model the relationships between multiple molecular markers.
49 | This enhancement leads to better classification accuracy by capturing **intrinsic marker co-occurrences** and their impact on cancer classification.
50 |
51 | - **Cross-Modal Interaction** is key to the model’s success.
52 | 🔄 **Interaction Mechanism**: By using **dynamic confidence constraints** and a **cross-modal gradient modulation strategy**, M3C2 efficiently aligns the prediction tasks for histology and molecular markers, ensuring both tasks complement each other for more accurate results.
53 |
54 | - **Validation Across Diverse Datasets**: M3C2 outperforms existing state-of-the-art methods in **glioma classification** and **molecular marker prediction**, showcasing its robustness in **internal** and **external validation datasets**.
55 | 📊 **Performance**: The method achieves significant improvements, with **accuracy** and **AUC scores** surpassing previous models by as much as **5.6%** in certain tasks.
56 |
57 | - **Clinical Implications**: The ability to predict molecular markers directly from **whole-slide images (WSIs)**, combined with the model's capacity to understand the interactions between histology and molecular data, offers strong potential for **precision oncology**.
58 | 🏥 **Impact**: M3C2’s approach aligns with the latest **WHO glioma classification criteria**, making it a promising tool for clinical decision-making and personalized cancer treatment.
59 |
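
The CPLC-Graph bullet above describes modelling marker co-occurrence. The README does not give the exact formula, so the following is only a minimal sketch under an assumption: the edge weight from marker `i` to marker `j` is estimated as the conditional frequency P(j positive | i positive) over a binary label table. The function name `cooccurrence_probabilities` and the toy labels are hypothetical and are not taken from the M3C2 code.

```python
from typing import List


def cooccurrence_probabilities(labels: List[List[int]]) -> List[List[float]]:
    """Return P where P[i][j] estimates P(marker j positive | marker i positive).

    `labels` is a list of samples; each sample is a list of 0/1 marker labels.
    This is an illustrative assumption, not the authors' exact graph definition.
    """
    num_markers = len(labels[0])
    positives = [0] * num_markers
    joint = [[0] * num_markers for _ in range(num_markers)]
    for row in labels:
        for i in range(num_markers):
            if row[i] != 1:
                continue
            positives[i] += 1
            for j in range(num_markers):
                if row[j] == 1:
                    joint[i][j] += 1
    # Rows for markers that are never positive are left at 0.0
    return [
        [joint[i][j] / positives[i] if positives[i] else 0.0
         for j in range(num_markers)]
        for i in range(num_markers)
    ]


# Hypothetical toy labels: each column stands for one binary molecular marker
toy_labels = [
    [1, 1, 0],
    [1, 0, 0],
    [0, 0, 1],
    [1, 1, 1],
]
probs = cooccurrence_probabilities(toy_labels)
```

The resulting matrix is generally asymmetric (`probs[0][1]` differs from `probs[1][0]`), which is why a directed graph is a natural fit for this kind of label correlation.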
60 |
61 | 
62 |
63 | ## About this code
64 |
65 | The M3C2 codebase is written in Python and focuses on integrating histology features and molecular markers for cancer classification. It uses deep learning techniques to analyze whole-slide images (WSIs) and predict cancer types, particularly gliomas. The core module structure is as follows:
66 |
67 | ```
68 | M3C2-main/
69 | ├── CPTAC.xlsx # Dataset containing clinical and molecular data for cancer classification (CPTAC).
70 | ├── IvYGAP.xlsx # Dataset containing data for glioma diagnosis and treatment (IvYGAP).
71 | ├── TCGA.xlsx # Dataset from The Cancer Genome Atlas (TCGA) for glioma classification.
72 | ├── merge_who.xlsx # Merged dataset based on the latest WHO glioma classification.
73 | ├── README.md # Overview of the project, its purpose, and instructions for usage.
74 | ├── data_process.py # Script to process and clean datasets for model training.
75 | ├── dataset.py # Defines the dataset structure and loading mechanism for WSIs.
76 | ├── dataset_mine.py # Alternative dataset processing script with additional feature extraction.
77 | ├── feature_generation.py # Generates features required for classification from raw data.
78 | ├── model.py # Defines the neural network architecture used for cancer classification.
79 | ├── net.py # Contains code for building the network layers of the model.
80 | ├── evaluation.py # Script to evaluate the model’s performance on various tasks.
81 | ├── post_processing.py # Post-processing of model predictions such as filtering or formatting.
82 | ├── main.py # Main script for running the model with specified parameters.
83 | ├── main _noGrad_guide.py # Variant of the main script with gradient updates disabled for specific tasks.
84 | ├── main miccai.py # Main script variant used for experiments aligned with the MICCAI conference.
85 | ├── main_noLCloss.py # Variant without the label correlation loss for model training.
86 | ├── model copy.py # Another copy of the model code, likely with experimental variations.
87 | ├── roc_plot.py # Script for plotting ROC curves to evaluate model performance.
88 | ├── test_forroc.py # Testing script specifically for evaluating model performance via ROC.
89 | ├── transform # Folder containing data transformation functions for WSIs.
90 | │ ├── augmentations.py # Defines data augmentation methods (e.g., rotation, zoom).
91 | │ ├── normalize.py # Script for normalizing the input data (WSIs).
92 | ├── utils.py # Contains helper functions used throughout the project.
93 | ├── utils_finetune.py # Utility functions specifically for fine-tuning the model.
94 | ├── logs.py # Handles logging of training and testing results.
95 | ├── __pycache__ # Folder containing Python bytecode files for faster execution.
96 | ```
97 |
98 | ## How to apply the work
99 | ### 1. Environment
100 | - Python >= 3.7
101 | - PyTorch >= 1.12 is recommended
102 | - opencv-python
103 | - scikit-learn
104 | - matplotlib
105 |
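
As a quick sanity check of the requirements above, a small script along these lines could report which packages are importable. This is a sketch, not part of the repository: the helper name `check_environment` is hypothetical, and it relies on the facts that opencv-python is imported as `cv2` and scikit-learn as `sklearn`.

```python
import importlib
import sys

# Import names for the listed requirements:
# PyTorch -> torch, opencv-python -> cv2, scikit-learn -> sklearn
REQUIRED_MODULES = ["torch", "cv2", "sklearn", "matplotlib"]


def check_environment(min_python=(3, 7), modules=REQUIRED_MODULES):
    """Return a dict with the Python version check and each module's version.

    A module maps to None when it cannot be imported, and to "unknown"
    when it is importable but exposes no __version__ attribute.
    """
    report = {"python_ok": sys.version_info[:2] >= tuple(min_python)}
    for name in modules:
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            report[name] = None
    return report
```

Calling `check_environment()` and printing the returned dict before running `main.py` makes a missing dependency visible up front rather than partway through training.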
106 |
107 | ### 2. Train
108 | Use the command below to train the model on our dataset.
109 | ```
110 | python ./main.py
111 | ```
112 |
113 | ### 3. Test
114 | Use the command below to test the model on our dataset.
115 | ```
116 | python ./test_main.py
117 | ```
118 |
119 | ### 4. Datasets
120 | ```
121 | https://www.kaggle.com/datasets/liuhanyu1007/m3c2-data
122 | ```
123 |
124 | ### 5. Model
125 | ```
126 | https://www.kaggle.com/models/liuhanyu1007/m3c2_model
127 | ```
128 |
129 | ## Contact
130 | - Xiaofei Wang: xw405@cam.ac.uk
131 | - Hanyu Liu: 2485644@dundee.ac.uk
132 |
133 |
134 | Please open an issue or submit a pull request to report problems or contribute.
135 |
136 | ## 💼 License
137 |
138 |
139 |
140 |
141 |
142 | ## Citation
143 |
144 | If you find our work helpful, please cite our paper:
145 |
146 | ```
147 | @article{wang2025joint,
148 | title={Joint Modelling Histology and Molecular Markers for Cancer Classification},
149 | author={Wang, Xiaofei and Liu, Hanyu and Zhang, Yupei and Zhao, Boyang and Duan, Hao and Hu, Wanming and Mou, Yonggao and Price, Stephen and Li, Chao},
150 | journal={arXiv preprint arXiv:2502.07979},
151 | year={2025}
152 | }
153 | ```
154 |
--------------------------------------------------------------------------------
/ROC1.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/ROC1.zip
--------------------------------------------------------------------------------
/ROC2.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/ROC2.zip
--------------------------------------------------------------------------------
/TCGA.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/TCGA.xlsx
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/dataset_mine.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/dataset_mine.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/evaluation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/evaluation.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/net.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/saver.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/saver.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/CDKN_main.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/CDKN_main.py
--------------------------------------------------------------------------------
/basic_net/__init__.py:
--------------------------------------------------------------------------------
from .alexnet import *
from .resnet import *
from .inception import *
from .densenet import *
from .mnasnet import *
from .nystrom_atten import *
--------------------------------------------------------------------------------
/basic_net/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/alexnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/alexnet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/densenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/densenet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/inception.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/inception.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/mnasnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/mnasnet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/nystrom_atten.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/nystrom_atten.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/basic_net/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/basic_net/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['AlexNet', 'alexnet']


model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # Convolutional feature extractor
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Pool to a fixed 6x6 map so the classifier input size does not depend on image size
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten to (N, 256 * 6 * 6) for the fully connected layers
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
    return model
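
To see why the classifier above expects `256 * 6 * 6` inputs, the spatial size can be traced through `AlexNet.features` with the standard convolution/pooling output formula, `floor((size + 2 * padding - kernel) / stride) + 1`. The sketch below uses only the kernel, stride, and padding values from the listing; the helper names `conv_out` and `alexnet_feature_size` are hypothetical and exist only for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial axis (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each spatial layer in AlexNet.features, in order
ALEXNET_FEATURE_LAYERS = [
    (11, 4, 2),  # Conv2d(3, 64)
    (3, 2, 0),   # MaxPool2d
    (5, 1, 2),   # Conv2d(64, 192)
    (3, 2, 0),   # MaxPool2d
    (3, 1, 1),   # Conv2d(192, 384)
    (3, 1, 1),   # Conv2d(384, 256)
    (3, 1, 1),   # Conv2d(256, 256)
    (3, 2, 0),   # MaxPool2d
]


def alexnet_feature_size(input_size=224):
    """Side length of the feature map that AlexNet.features produces."""
    size = input_size
    for kernel, stride, padding in ALEXNET_FEATURE_LAYERS:
        size = conv_out(size, kernel, stride, padding)
    return size


side = alexnet_feature_size(224)
# Classifier input width after AdaptiveAvgPool2d((6, 6)) with 256 channels
flattened = 256 * 6 * 6
```

For a 224 x 224 input the feature map is already 6 x 6, so the adaptive pooling layer leaves its size unchanged; for other input sizes it resizes the map to 6 x 6 so the hard-coded `256 * 6 * 6` in `forward` still holds.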
--------------------------------------------------------------------------------
/basic_net/densenet.py:
--------------------------------------------------------------------------------
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']


model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        # Bottleneck: BN -> ReLU -> 1x1 conv, then BN -> ReLU -> 3x3 conv
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Dense connectivity: concatenate the input with the new features along channels
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            # Each layer sees the block input plus the outputs of all earlier layers
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
        out = self.classifier(out)
        return out


def densenet121(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # state_dict = {k: v for k, v in state_dict.items() if ('conv0' in k or 'norm0' in k)}
        # strict=False: keys that do not match between checkpoint and model are skipped
        model.load_state_dict(state_dict, strict=False)
    return model


def densenet169(pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet169'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model


def densenet201(pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet201'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model


def densenet161(pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet161'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model
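
The key renaming done by the pretrained loaders in densenet.py can be tried without downloading a checkpoint. The sketch below applies the same regular expression to a small hand-made dictionary. The sample keys and placeholder values are hypothetical stand-ins for real tensors, and `remap_keys` is an illustrative helper that builds a new dict instead of mutating the original in place as the loaders do.

```python
import re

# Same pattern as in the densenet loaders: it captures the layer prefix
# (e.g. "...denselayer1.norm") and the legacy suffix (e.g. "1.weight")
PATTERN = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')


def remap_keys(state_dict):
    """Return a copy of state_dict with legacy 'norm.1'-style keys renamed to 'norm1'."""
    remapped = {}
    for key, value in state_dict.items():
        match = PATTERN.match(key)
        # Joining the two groups drops the '.' between the module name and its index
        new_key = match.group(1) + match.group(2) if match else key
        remapped[new_key] = value
    return remapped


# Hypothetical legacy-style keys; the string values stand in for tensors
legacy = {
    "features.denseblock1.denselayer1.norm.1.weight": "w0",
    "features.denseblock1.denselayer1.conv.2.weight": "w1",
    "features.conv0.weight": "w2",
}
remapped = remap_keys(legacy)
```

Keys outside a `denselayer`, such as `features.conv0.weight`, do not match the pattern and pass through unchanged.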
--------------------------------------------------------------------------------
/basic_net/inception.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


__all__ = ['Inception3', 'inception_v3']


model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}


def inception_v3(pretrained=False, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        model = Inception3(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
        return model

    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, num_classes)

        # Initialize conv/linear weights from a truncated normal distribution
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.Tensor(X.rvs(m.weight.numel()))
                values = values.view(m.weight.size())
                m.weight.data.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # if self.transform_input:
        #     x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        #     x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        #     x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        #     x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return x, aux
        return x


class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)
162 |
163 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
164 | return torch.cat(outputs, 1)
165 |
166 |
167 | class InceptionB(nn.Module):
168 |
169 | def __init__(self, in_channels):
170 | super(InceptionB, self).__init__()
171 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
172 |
173 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
174 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
175 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
176 |
177 | def forward(self, x):
178 | branch3x3 = self.branch3x3(x)
179 |
180 | branch3x3dbl = self.branch3x3dbl_1(x)
181 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
182 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
183 |
184 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
185 |
186 | outputs = [branch3x3, branch3x3dbl, branch_pool]
187 | return torch.cat(outputs, 1)
188 |
189 |
190 | class InceptionC(nn.Module):
191 |
192 | def __init__(self, in_channels, channels_7x7):
193 | super(InceptionC, self).__init__()
194 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
195 |
196 | c7 = channels_7x7
197 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
198 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
199 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
200 |
201 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
202 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
203 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
204 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
205 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
206 |
207 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
208 |
209 | def forward(self, x):
210 | branch1x1 = self.branch1x1(x)
211 |
212 | branch7x7 = self.branch7x7_1(x)
213 | branch7x7 = self.branch7x7_2(branch7x7)
214 | branch7x7 = self.branch7x7_3(branch7x7)
215 |
216 | branch7x7dbl = self.branch7x7dbl_1(x)
217 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
218 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
219 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
220 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
221 |
222 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
223 | branch_pool = self.branch_pool(branch_pool)
224 |
225 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
226 | return torch.cat(outputs, 1)
227 |
228 |
229 | class InceptionD(nn.Module):
230 |
231 | def __init__(self, in_channels):
232 | super(InceptionD, self).__init__()
233 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
234 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)
235 |
236 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
237 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
238 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
239 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)
240 |
241 | def forward(self, x):
242 | branch3x3 = self.branch3x3_1(x)
243 | branch3x3 = self.branch3x3_2(branch3x3)
244 |
245 | branch7x7x3 = self.branch7x7x3_1(x)
246 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
247 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
248 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
249 |
250 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
251 | outputs = [branch3x3, branch7x7x3, branch_pool]
252 | return torch.cat(outputs, 1)
253 |
254 |
255 | class InceptionE(nn.Module):
256 |
257 | def __init__(self, in_channels):
258 | super(InceptionE, self).__init__()
259 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
260 |
261 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
262 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
263 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
264 |
265 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
266 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
267 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
268 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
269 |
270 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
271 |
272 | def forward(self, x):
273 | branch1x1 = self.branch1x1(x)
274 |
275 | branch3x3 = self.branch3x3_1(x)
276 | branch3x3 = [
277 | self.branch3x3_2a(branch3x3),
278 | self.branch3x3_2b(branch3x3),
279 | ]
280 | branch3x3 = torch.cat(branch3x3, 1)
281 |
282 | branch3x3dbl = self.branch3x3dbl_1(x)
283 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
284 | branch3x3dbl = [
285 | self.branch3x3dbl_3a(branch3x3dbl),
286 | self.branch3x3dbl_3b(branch3x3dbl),
287 | ]
288 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
289 |
290 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
291 | branch_pool = self.branch_pool(branch_pool)
292 |
293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294 | return torch.cat(outputs, 1)
295 |
296 |
297 | class InceptionAux(nn.Module):
298 |
299 | def __init__(self, in_channels, num_classes):
300 | super(InceptionAux, self).__init__()
301 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
302 | self.conv1 = BasicConv2d(128, 768, kernel_size=5)
303 | self.conv1.stddev = 0.01
304 | self.fc = nn.Linear(768, num_classes)
305 | self.fc.stddev = 0.001
306 |
307 | def forward(self, x):
308 | # N x 768 x 17 x 17
309 | x = F.avg_pool2d(x, kernel_size=5, stride=3)
310 | # N x 768 x 5 x 5
311 | x = self.conv0(x)
312 | # N x 128 x 5 x 5
313 | x = self.conv1(x)
314 | # N x 768 x 1 x 1
315 | # Adaptive average pooling
316 | x = F.adaptive_avg_pool2d(x, (1, 1))
317 | # N x 768 x 1 x 1
318 | x = x.view(x.size(0), -1)
319 | # N x 768
320 | x = self.fc(x)
321 | # N x 1000
322 | return x
323 |
324 |
325 | class BasicConv2d(nn.Module):
326 |
327 | def __init__(self, in_channels, out_channels, **kwargs):
328 | super(BasicConv2d, self).__init__()
329 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
330 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
331 |
332 | def forward(self, x):
333 | x = self.conv(x)
334 | x = self.bn(x)
335 | return F.relu(x, inplace=True)
336 |
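The channel counts in the `forward` comments above follow from how each block concatenates its branches. The sketch below recomputes them in plain Python from the branch widths in the listing, along with the spatial sizes after the stride-2 reductions and inside `InceptionAux`. The helper names (`conv_out`, `inception_a`, and so on) are hypothetical and exist only for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool window along one spatial axis (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# Concatenated output channels per block, read off the branch definitions.
def inception_a(pool_features):
    return 64 + 64 + 96 + pool_features   # 1x1, 5x5, double 3x3, pooled 1x1


def inception_b(in_channels):
    return 384 + 96 + in_channels         # 3x3, double 3x3, max-pool passthrough


def inception_c():
    return 192 * 4                        # every branch ends in 192 channels


def inception_d(in_channels):
    return 320 + 192 + in_channels        # 3x3, 7x7x3, max-pool passthrough


def inception_e():
    return 320 + 2 * 384 + 2 * 384 + 192  # two branches concatenate a/b halves


channels = {
    "Mixed_5b": inception_a(32),
    "Mixed_5c": inception_a(64),
    "Mixed_5d": inception_a(64),
    "Mixed_6a": inception_b(288),
    "Mixed_6b": inception_c(),
    "Mixed_7a": inception_d(768),
    "Mixed_7b": inception_e(),
}

# Spatial sizes: InceptionB and InceptionD both reduce with kernel 3, stride 2.
after_6a = conv_out(35, kernel=3, stride=2)
after_7a = conv_out(after_6a, kernel=3, stride=2)

# InceptionAux: avg_pool(kernel 5, stride 3) followed by an unpadded 5x5 conv.
aux_pooled = conv_out(17, kernel=5, stride=3)
aux_conv = conv_out(aux_pooled, kernel=5)

print(channels)
print(after_6a, after_7a, aux_pooled, aux_conv)
```

The results match the input widths passed to the next block in `__init__` (256, 288, 288, 768, 1280, 2048), which is a quick way to check a modified branch layout before building the network.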
--------------------------------------------------------------------------------
/basic_net/mnasnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
9 |
10 | _MODEL_URLS = {
11 | "mnasnet0_5":
12 | "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
13 | "mnasnet0_75": None,
14 | "mnasnet1_0":
15 | "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
16 | "mnasnet1_3": None
17 | }
18 |
19 | # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
20 | # 1.0 - tensorflow.
21 | _BN_MOMENTUM = 1 - 0.9997
22 |
23 |
24 | class _InvertedResidual(nn.Module):
25 |
26 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
27 | bn_momentum=0.1):
28 | super(_InvertedResidual, self).__init__()
29 | assert stride in [1, 2]
30 | assert kernel_size in [3, 5]
31 | mid_ch = in_ch * expansion_factor
32 | self.apply_residual = (in_ch == out_ch and stride == 1)
33 | self.layers = nn.Sequential(
34 | # Pointwise
35 | nn.Conv2d(in_ch, mid_ch, 1, bias=False),
36 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
37 | nn.ReLU(inplace=True),
38 | # Depthwise
39 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
40 | stride=stride, groups=mid_ch, bias=False),
41 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
42 | nn.ReLU(inplace=True),
43 | # Linear pointwise. Note that there's no activation.
44 | nn.Conv2d(mid_ch, out_ch, 1, bias=False),
45 | nn.BatchNorm2d(out_ch, momentum=bn_momentum))
46 |
47 | def forward(self, input):
48 | if self.apply_residual:
49 | return self.layers(input) + input
50 | else:
51 | return self.layers(input)
52 |
53 |
54 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
55 | bn_momentum):
56 | """ Creates a stack of inverted residuals. """
57 | assert repeats >= 1
58 | # First one has no skip, because feature map size changes.
59 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
60 | bn_momentum=bn_momentum)
61 | remaining = []
62 | for _ in range(1, repeats):
63 | remaining.append(
64 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
65 | bn_momentum=bn_momentum))
66 | return nn.Sequential(first, *remaining)
67 |
68 |
69 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
70 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default
71 | bias, will round up, unless the number is no more than 10% greater than the
72 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
73 | assert 0.0 < round_up_bias < 1.0
74 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
75 | return new_val if new_val >= round_up_bias * val else new_val + divisor
76 |
77 |
78 | def _get_depths(alpha):
79 |     """ Scales tensor depths as in reference MobileNet code, prefers rounding up
80 | rather than down. """
81 | depths = [32, 16, 24, 40, 80, 96, 192, 320]
82 | return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
83 |
84 |
85 | class MNASNet(torch.nn.Module):
86 | """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
87 | implements the B1 variant of the model.
88 |     >>> model = MNASNet(1.0, num_classes=1000)
89 |     >>> x = torch.rand(1, 3, 224, 224)
90 |     >>> y = model(x)
91 |     >>> y.dim()
92 |     2
93 |     >>> y.nelement()
94 |     1000
95 | """
96 | # Version 2 adds depth scaling in the initial stages of the network.
97 | _version = 2
98 |
99 | def __init__(self, alpha, num_classes=1000, dropout=0.2):
100 | super(MNASNet, self).__init__()
101 | assert alpha > 0.0
102 | self.alpha = alpha
103 | self.num_classes = num_classes
104 | depths = _get_depths(alpha)
105 | layers = [
106 | # First layer: regular conv.
107 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
108 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
109 | nn.ReLU(inplace=True),
110 | # Depthwise separable, no skip.
111 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
112 | groups=depths[0], bias=False),
113 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
114 | nn.ReLU(inplace=True),
115 | nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
116 | nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
117 | # MNASNet blocks: stacks of inverted residuals.
118 | _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
119 | _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
120 | _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
121 | _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
122 | _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
123 | _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
124 | # Final mapping to classifier input.
125 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
126 | nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
127 | nn.ReLU(inplace=True),
128 | ]
129 | self.layers = nn.Sequential(*layers)
130 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
131 | nn.Linear(1280, num_classes))
132 | self._initialize_weights()
133 |
134 | def forward(self, x):
135 | x = self.layers(x)
136 | # Equivalent to global avgpool and removing H and W dimensions.
137 | x = x.mean([2, 3])
138 | return self.classifier(x)
139 |
140 | def _initialize_weights(self):
141 | for m in self.modules():
142 | if isinstance(m, nn.Conv2d):
143 | nn.init.kaiming_normal_(m.weight, mode="fan_out",
144 | nonlinearity="relu")
145 | if m.bias is not None:
146 | nn.init.zeros_(m.bias)
147 | elif isinstance(m, nn.BatchNorm2d):
148 | nn.init.ones_(m.weight)
149 | nn.init.zeros_(m.bias)
150 | elif isinstance(m, nn.Linear):
151 | nn.init.kaiming_uniform_(m.weight, mode="fan_out",
152 | nonlinearity="sigmoid")
153 | nn.init.zeros_(m.bias)
154 |
155 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
156 | missing_keys, unexpected_keys, error_msgs):
157 | version = local_metadata.get("version", None)
158 | assert version in [1, 2]
159 |
160 | if version == 1 and not self.alpha == 1.0:
161 | # In the initial version of the model (v1), stem was fixed-size.
162 | # All other layer configurations were the same. This will patch
163 | # the model so that it's identical to v1. Model with alpha 1.0 is
164 | # unaffected.
165 | depths = _get_depths(self.alpha)
166 | v1_stem = [
167 | nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
168 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
169 | nn.ReLU(inplace=True),
170 | nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
171 | bias=False),
172 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
173 | nn.ReLU(inplace=True),
174 | nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
175 | nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
176 | _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
177 | ]
178 | for idx, layer in enumerate(v1_stem):
179 | self.layers[idx] = layer
180 |
181 | # The model is now identical to v1, and must be saved as such.
182 | self._version = 1
183 | warnings.warn(
184 | "A new version of MNASNet model has been implemented. "
185 | "Your checkpoint was saved using the previous version. "
186 | "This checkpoint will load and work as before, but "
187 | "you may want to upgrade by training a newer model or "
188 | "transfer learning from an updated ImageNet checkpoint.",
189 | UserWarning)
190 |
191 | super(MNASNet, self)._load_from_state_dict(
192 | state_dict, prefix, local_metadata, strict, missing_keys,
193 | unexpected_keys, error_msgs)
194 |
195 |
196 | def _load_pretrained(model_name, model, progress):
197 | if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
198 | raise ValueError(
199 | "No checkpoint is available for model type {}".format(model_name))
200 | checkpoint_url = _MODEL_URLS[model_name]
201 |     model.load_state_dict(model_zoo.load_url(url=checkpoint_url, model_dir='./', progress=progress))
202 |
203 |
204 | def mnasnet0_5(pretrained=False, progress=True, **kwargs):
205 | """MNASNet with depth multiplier of 0.5 from
206 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
207 |     <https://arxiv.org/pdf/1807.11626.pdf>`_.
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | progress (bool): If True, displays a progress bar of the download to stderr
211 | """
212 | model = MNASNet(0.5, **kwargs)
213 | if pretrained:
214 | _load_pretrained("mnasnet0_5", model, progress)
215 | return model
216 |
217 |
218 | def mnasnet0_75(pretrained=False, progress=True, **kwargs):
219 | """MNASNet with depth multiplier of 0.75 from
220 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
221 |     <https://arxiv.org/pdf/1807.11626.pdf>`_.
222 | Args:
223 | pretrained (bool): If True, returns a model pre-trained on ImageNet
224 | progress (bool): If True, displays a progress bar of the download to stderr
225 | """
226 | model = MNASNet(0.75, **kwargs)
227 | if pretrained:
228 | _load_pretrained("mnasnet0_75", model, progress)
229 | return model
230 |
231 |
232 | def mnasnet1_0(pretrained=False, progress=True, **kwargs):
233 | """MNASNet with depth multiplier of 1.0 from
234 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
235 |     <https://arxiv.org/pdf/1807.11626.pdf>`_.
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | progress (bool): If True, displays a progress bar of the download to stderr
239 | """
240 | model = MNASNet(1.0, **kwargs)
241 | if pretrained:
242 | _load_pretrained("mnasnet1_0", model, progress)
243 | return model
244 |
245 |
246 | def mnasnet1_3(pretrained=False, progress=True, **kwargs):
247 | """MNASNet with depth multiplier of 1.3 from
248 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
249 |     <https://arxiv.org/pdf/1807.11626.pdf>`_.
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | progress (bool): If True, displays a progress bar of the download to stderr
253 | """
254 | model = MNASNet(1.3, **kwargs)
255 | if pretrained:
256 | _load_pretrained("mnasnet1_3", model, progress)
257 | return model
258 |
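The width scaling in `_get_depths` is easier to follow with concrete numbers. The sketch below restates `_round_to_multiple_of` and `_get_depths` in plain Python, without the leading underscores, and prints the resulting channel widths for the four `alpha` values used by the factory functions. It needs no PyTorch install and is meant only as an illustration of the rounding rule.

```python
def round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """Round `val` to a multiple of `divisor`, bumping up if too much is lost."""
    assert 0.0 < round_up_bias < 1.0
    # Nearest multiple of `divisor`, but never below `divisor` itself.
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    # If rounding fell below 90% of the requested value, go one step up.
    return new_val if new_val >= round_up_bias * val else new_val + divisor


BASE_DEPTHS = [32, 16, 24, 40, 80, 96, 192, 320]


def get_depths(alpha):
    """Channel widths for a given width multiplier `alpha`."""
    return [round_to_multiple_of(depth * alpha, 8) for depth in BASE_DEPTHS]


for alpha in (0.5, 0.75, 1.0, 1.3):
    print(alpha, get_depths(alpha))
```

With `alpha = 1.0` every base depth is already a multiple of 8, so the list is unchanged. With `alpha = 0.5` the third entry is 16 rather than 8, because 12 rounds to the nearest multiple of 8.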
--------------------------------------------------------------------------------
/basic_net/nystrom_atten.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 |
6 | from einops import rearrange, reduce
7 |
8 | # helper functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | def moore_penrose_iter_pinv(x, iters = 6):
14 | device = x.device
15 |
16 | abs_x = torch.abs(x)
17 | col = abs_x.sum(dim = -1)
18 | row = abs_x.sum(dim = -2)
19 | z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))
20 |
21 | I = torch.eye(x.shape[-1], device = device)
22 | I = rearrange(I, 'i j -> () i j')
23 |
24 | for _ in range(iters):
25 | xz = x @ z
26 | z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))
27 |
28 | return z
29 |
30 | # main attention class
31 |
32 | class NystromAttention(nn.Module):
33 | def __init__(
34 | self,
35 | dim,
36 | dim_head = 64,
37 | heads = 8,
38 | num_landmarks = 256,
39 | pinv_iterations = 6,
40 | residual = True,
41 | residual_conv_kernel = 33,
42 | eps = 1e-8,
43 | dropout = 0.
44 | ):
45 | super().__init__()
46 | self.eps = eps
47 | inner_dim = heads * dim_head
48 |
49 | self.num_landmarks = num_landmarks
50 | self.pinv_iterations = pinv_iterations
51 |
52 | self.heads = heads
53 | self.scale = dim_head ** -0.5
54 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
55 |
56 | self.to_out = nn.Sequential(
57 | nn.Linear(inner_dim, dim),
58 | nn.Dropout(dropout)
59 | )
60 |
61 | self.residual = residual
62 | if residual:
63 | kernel_size = residual_conv_kernel
64 | padding = residual_conv_kernel // 2
65 | self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)
66 |
67 | def forward(self, x, mask = None, return_attn = False):
68 | b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps
69 |
70 | # pad so that sequence can be evenly divided into m landmarks
71 |
72 | remainder = n % m
73 | if remainder > 0:
74 | padding = m - (n % m)
75 | x = F.pad(x, (0, 0, padding, 0), value = 0)
76 |
77 | if exists(mask):
78 | mask = F.pad(mask, (padding, 0), value = False)
79 |
80 | # derive query, keys, values
81 |
82 | q, k, v = self.to_qkv(x).chunk(3, dim = -1)
83 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
84 |
85 | # set masked positions to 0 in queries, keys, values
86 |
87 | if exists(mask):
88 | mask = rearrange(mask, 'b n -> b () n')
89 | q, k, v = map(lambda t: t * mask[..., None], (q, k, v))
90 |
91 | q = q * self.scale
92 |
93 | # generate landmarks by sum reduction, and then calculate mean using the mask
94 |
95 | l = ceil(n / m)
96 | landmark_einops_eq = '... (n l) d -> ... n d'
97 | q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
98 | k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)
99 |
100 | # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean
101 |
102 | divisor = l
103 | if exists(mask):
104 | mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
105 | divisor = mask_landmarks_sum[..., None] + eps
106 | mask_landmarks = mask_landmarks_sum > 0
107 |
108 | # masked mean (if mask exists)
109 |
110 | q_landmarks /= divisor
111 | k_landmarks /= divisor
112 |
113 | # similarities
114 |
115 | einops_eq = '... i d, ... j d -> ... i j'
116 | sim1 = einsum(einops_eq, q.float(), k_landmarks.float())
117 | sim2 = einsum(einops_eq, q_landmarks.float(), k_landmarks.float())
118 | sim3 = einsum(einops_eq, q_landmarks.float(), k.float())
119 |
120 | # masking
121 |
122 | if exists(mask):
123 | mask_value = -torch.finfo(q.dtype).max
124 | sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
125 | sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
126 | sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)
127 |
128 | # eq (15) in the paper and aggregate values
129 |
130 | attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
131 | attn2_inv = moore_penrose_iter_pinv(attn2, iters)
132 |
133 | out = (attn1 @ attn2_inv) @ (attn3 @ v)
134 |
135 | # add depth-wise conv residual of values
136 |
137 | if self.residual:
138 | out += self.res_conv(v)
139 |
140 | # merge and combine heads
141 |
142 | out = rearrange(out, 'b h n d -> b n (h d)', h = h)
143 | out = self.to_out(out)
144 | out = out[:, -n:]
145 |
146 | if return_attn:
147 | attn = attn1 @ attn2_inv @ attn3
148 | return out, attn
149 |
150 | return out
151 |
152 | # transformer
153 |
154 | class PreNorm(nn.Module):
155 | def __init__(self, dim, fn):
156 | super().__init__()
157 | self.norm = nn.LayerNorm(dim)
158 | self.fn = fn
159 |
160 | def forward(self, x, **kwargs):
161 | x = self.norm(x)
162 | return self.fn(x, **kwargs)
163 |
164 | class FeedForward(nn.Module):
165 | def __init__(self, dim, mult = 4, dropout = 0.):
166 | super().__init__()
167 | self.net = nn.Sequential(
168 | nn.Linear(dim, dim * mult),
169 | nn.GELU(),
170 | nn.Dropout(dropout),
171 | nn.Linear(dim * mult, dim)
172 | )
173 |
174 | def forward(self, x):
175 | return self.net(x)
176 |
177 | class Nystromformer(nn.Module):
178 | def __init__(
179 | self,
180 | *,
181 | dim,
182 | depth,
183 | dim_head = 64,
184 | heads = 8,
185 | num_landmarks = 256,
186 | pinv_iterations = 6,
187 | attn_values_residual = True,
188 | attn_values_residual_conv_kernel = 33,
189 | attn_dropout = 0.,
190 | ff_dropout = 0.
191 | ):
192 | super().__init__()
193 |
194 | self.layers = nn.ModuleList([])
195 | for _ in range(depth):
196 | self.layers.append(nn.ModuleList([
197 | PreNorm(dim, NystromAttention(dim = dim, dim_head = dim_head, heads = heads, num_landmarks = num_landmarks, pinv_iterations = pinv_iterations, residual = attn_values_residual, residual_conv_kernel = attn_values_residual_conv_kernel, dropout = attn_dropout)),
198 | PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout))
199 | ]))
200 |
201 | def forward(self, x, mask = None):
202 | for attn, ff in self.layers:
203 | x = attn(x, mask = mask) + x
204 | x = ff(x) + x
205 | return x
206 |
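The update rule in `moore_penrose_iter_pinv` can be checked by hand on a small matrix. The sketch below is a pure-Python restatement for a single square matrix stored as nested lists; it is not the batched tensor version above. The helper names (`matmul`, `transpose`, `scaled_identity_minus`, `iter_pinv`) are hypothetical. It uses the same initial guess, the transpose divided by the product of the largest absolute row sum and the largest absolute column sum, and the same default of six iterations.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def scaled_identity_minus(c, m):
    """Return c * I - m for a square matrix m."""
    n = len(m)
    return [[(c if i == j else 0.0) - m[i][j] for j in range(n)]
            for i in range(n)]


def iter_pinv(x, iters=6):
    """Iterative pseudo-inverse of a square matrix, mirroring the listing."""
    max_row_sum = max(sum(abs(v) for v in row) for row in x)
    max_col_sum = max(sum(abs(v) for v in col) for col in zip(*x))
    scale = max_row_sum * max_col_sum
    z = [[v / scale for v in row] for row in transpose(x)]
    for _ in range(iters):
        xz = matmul(x, z)
        inner = scaled_identity_minus(7.0, xz)                   # 7I - xz
        inner = scaled_identity_minus(15.0, matmul(xz, inner))   # 15I - xz(...)
        inner = scaled_identity_minus(13.0, matmul(xz, inner))   # 13I - xz(...)
        z = [[0.25 * v for v in row] for row in matmul(z, inner)]
    return z


# A row-stochastic matrix, similar in kind to a softmax attention block.
x = [[0.6, 0.4], [0.3, 0.7]]
z = iter_pinv(x)
product = matmul(x, z)
print(z)
print(product)
```

For this well-conditioned input the product `x @ z` is numerically the identity after six iterations. Matrices with very small singular values may need more iterations, which is what the `pinv_iterations` argument of `NystromAttention` controls.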
--------------------------------------------------------------------------------
/basic_net/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = ['ResNet', 'resnet18','resnet18_stem', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152']
7 |
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | def conv1x1(in_planes, out_planes, stride=1):
25 | """1x1 convolution"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | expansion = 1
31 |
32 | def __init__(self, inplanes, planes, stride=1, downsample=None):
33 | super(BasicBlock, self).__init__()
34 | self.conv1 = conv3x3(inplanes, planes, stride)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 |
42 | def forward(self, x):
43 | identity = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | identity = self.downsample(x)
54 |
55 | out += identity
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None):
65 | super(Bottleneck, self).__init__()
66 | self.conv1 = conv1x1(inplanes, planes)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = conv3x3(planes, planes, stride)
69 | self.bn2 = nn.BatchNorm2d(planes)
70 | self.conv3 = conv1x1(planes, planes * self.expansion)
71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | identity = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv3(out)
88 | out = self.bn3(out)
89 |
90 | if self.downsample is not None:
91 | identity = self.downsample(x)
92 |
93 | out += identity
94 | out = self.relu(out)
95 |
96 | return out
97 |
98 |
99 | class ResNet(nn.Module):
100 |
101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
102 | super(ResNet, self).__init__()
103 | self.inplanes = 64
104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
105 | bias=False)
106 | self.bn1 = nn.BatchNorm2d(64)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
109 | self.layer1 = self._make_layer(block, 64, layers[0])
110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
114 | self.fc = nn.Linear(512 * block.expansion, num_classes)
115 |
116 | for m in self.modules():
117 | if isinstance(m, nn.Conv2d):
118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
119 | elif isinstance(m, nn.BatchNorm2d):
120 | nn.init.constant_(m.weight, 1)
121 | nn.init.constant_(m.bias, 0)
122 |
123 | # Zero-initialize the last BN in each residual branch,
124 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
125 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
126 | if zero_init_residual:
127 | for m in self.modules():
128 | if isinstance(m, Bottleneck):
129 | nn.init.constant_(m.bn3.weight, 0)
130 | elif isinstance(m, BasicBlock):
131 | nn.init.constant_(m.bn2.weight, 0)
132 |
133 | def _make_layer(self, block, planes, blocks, stride=1):
134 | downsample = None
135 | if stride != 1 or self.inplanes != planes * block.expansion:
136 | downsample = nn.Sequential(
137 | conv1x1(self.inplanes, planes * block.expansion, stride),
138 | nn.BatchNorm2d(planes * block.expansion),
139 | )
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride, downsample))
143 | self.inplanes = planes * block.expansion
144 | for _ in range(1, blocks):
145 | layers.append(block(self.inplanes, planes))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def forward(self, x):
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.relu(x)
153 | x = self.maxpool(x)
154 |
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x = self.layer4(x)
159 |
160 | x = self.avgpool(x)
161 | x = x.view(x.size(0), -1)
162 | x = self.fc(x)
163 |
164 | return x
165 |
166 | class ResNet_stem(nn.Module):
167 |
168 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
169 | super(ResNet_stem, self).__init__()
170 | self.inplanes = 64
171 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
172 | bias=False)
173 | self.bn1 = nn.BatchNorm2d(64)
174 | self.relu = nn.ReLU(inplace=True)
175 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
176 |
177 |
178 |         # The stem contains no residual blocks, so the loop below never matches
179 |         # a Bottleneck or BasicBlock and is a no-op. `block`, `layers`, `num_classes`
180 |         # and `zero_init_residual` are unused here and only mirror ResNet's signature.
181 | if zero_init_residual:
182 | for m in self.modules():
183 | if isinstance(m, Bottleneck):
184 | nn.init.constant_(m.bn3.weight, 0)
185 | elif isinstance(m, BasicBlock):
186 | nn.init.constant_(m.bn2.weight, 0)
187 |
188 |
189 |
190 | def forward(self, x):
191 | x = self.conv1(x)
192 | x = self.bn1(x)
193 | x = self.relu(x)
194 | x = self.maxpool(x)
195 |
196 | return x
197 |
198 |
199 |
200 | def resnet18(pretrained=False, **kwargs):
201 | """Constructs a ResNet-18 model.
202 |
203 | Args:
204 | pretrained (bool): If True, returns a model pre-trained on ImageNet
205 | """
206 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
207 | if pretrained:
208 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']),strict=False)
209 | return model
210 |
211 | def resnet18_stem(pretrained=False, **kwargs):
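212 |     """Constructs only the ResNet-18 stem (conv1, bn1, relu, maxpool).
213 |
214 |     Args:
215 |         pretrained (bool): If True, loads the ImageNet ResNet-18 weights into the stem layers; strict=False ignores the remaining checkpoint keys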
216 | """
217 | model = ResNet_stem(BasicBlock, [2, 2, 2, 2], **kwargs)
218 | if pretrained:
219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']),strict=False)
220 | return model
221 |
222 |
223 | def resnet34(pretrained=False, **kwargs):
224 | """Constructs a ResNet-34 model.
225 |
226 | Args:
227 | pretrained (bool): If True, returns a model pre-trained on ImageNet
228 | """
229 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
230 | if pretrained:
231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
232 | return model
233 |
234 |
235 | def resnet50(pretrained=False, **kwargs):
236 | """Constructs a ResNet-50 model.
237 |
238 | Args:
239 | pretrained (bool): If True, returns a model pre-trained on ImageNet
240 | """
241 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
242 | if pretrained:
243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
244 | return model
245 |
246 |
247 | def resnet101(pretrained=False, **kwargs):
248 | """Constructs a ResNet-101 model.
249 |
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | """
253 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
254 | if pretrained:
255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
256 | return model
257 |
258 |
259 | def resnet152(pretrained=False, **kwargs):
260 | """Constructs a ResNet-152 model.
261 |
262 | Args:
263 | pretrained (bool): If True, returns a model pre-trained on ImageNet
264 | """
265 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
266 | if pretrained:
267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
268 | return model
269 |
--------------------------------------------------------------------------------
/config/CDKN.yml:
--------------------------------------------------------------------------------
1 | name: TransMIL_fea # task-specific 1p19q CDKN Diag Grade His MLC fea img
2 | #
3 | command: Train # Test
4 | gpus: [2]
5 |
6 |
7 | dataDir: /mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ #ali8k_3
8 | #dataDir: /mnt/disk10T/fyb/wxf_data/TCGA/brain/ #ali8k_1
9 |
10 |
11 | #### Network setting main
12 | Network:
13 | BasicMIL: # simple instance-level aggregation: majority voting or the standard bag-label definition
14 | lr: 0.0003
15 | AMIL:
16 | lr: 0.0005
17 | TransMIL:
18 | lr: 0.0002
19 | CLAM:
20 | lr: 0.0002
21 | UACNN:
22 | lr: 0.0003
23 | PatchGCN:
24 | lr: 0.0002
25 | Mine:
26 | lr: 0.0003
27 |
28 |
29 | #### Training setting
30 |
31 | n_ep: 200
32 | n_ep_decay: 50
33 | decayType: step # step, linear, exp
34 | n_ep_save: 5
35 | resume_epoch: 0 # 100
36 | eva_cm: False
37 | batchSize: 1
38 | Val_batchSize: 1
39 | Test_batchSize: 1
40 |
41 | #### Directories
42 | logDir: ./logs
43 | saveDir: ./outs
44 | modelDir: ./models
45 | cm_saveDir: ./cm
46 | label_path: ./merge_who.xlsx
47 |
48 | #### Meta setting main
49 | dataset: TCGA
50 | nThreads: 16
51 | seed: 124
52 | imgSize: [224,224]
53 |
--------------------------------------------------------------------------------
/config/cifar_pretrain.yml:
--------------------------------------------------------------------------------
1 | name: CLAM # model
2 | gpus: [1]
3 | batchSize: 1
4 | Val_batchSize: 1
5 | Test_batchSize: 1
6 | #dataDir: /mnt/disk10T_2/fuyibing/wxf_data/TCGA/brain/ #ali8k_2
7 | dataDir: /mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ #ali8k_3
8 | #dataDir: /mnt/disk10T/fyb/wxf_data/TCGA/brain/ #ali8k_1
9 | #dataDir: /mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ #ali4k
10 |
11 |
12 | #### Network setting main
13 | Network:
14 | BasicMIL: # simple instance-level aggregation: majority voting or the standard bag-label definition
15 | lr: 0.0003
16 | AMIL:
17 | lr: 0.0005
18 | TransMIL:
19 | lr: 0.0002
20 | CLAM:
21 | lr: 0.003
22 | UACNN:
23 | lr: 0.0003
24 | PatchGCN:
25 | lr: 0.0002
26 |
27 |
28 |
29 | #### Training setting
30 |
31 | n_ep: 200
32 | n_ep_decay: 50
33 | decayType: exp # step, linear, exp, cos
34 | n_ep_save: 5
35 | resume_epoch: 0 # 100
36 |
37 |
38 | #### Directories
39 | logDir: ./logs
40 | saveDir: ./outs
41 | modelDir: ./models
42 | cm_saveDir: ./cm
43 | label_path: ./merge_who.xlsx
44 |
45 | #### Meta setting main
46 | dataset: TCGA
47 | nThreads: 16
48 | seed: 124
49 | imgSize: [224,224]
50 |
--------------------------------------------------------------------------------
/config/miccai.yml:
--------------------------------------------------------------------------------
1 | name: CLAM_IDH_fea # mna incep dense alex res
2 | gpus: [0]
3 | batchSize: 1
4 | fixdim: 2500
5 | dataDir: /home/zeiler/WSI_proj/data/ #Adden_linux1
6 | #dataDir: /home/cbtil2/WSI_proj/data/ #Adden_linux2
7 |
8 | #### Network setting main
9 | Network:
10 | BasicMIL: # simple instance-level aggregation: majority voting or the standard bag-label definition
11 | lr: 0.001
12 | AMIL:
13 | lr: 0.0002
14 | TransMIL:
15 | lr: 0.0002
16 | CLAM:
17 | lr: 0.0001
18 | UACNN:
19 | lr: 0.0003
20 | PatchGCN:
21 | lr: 0.0002
22 |
23 |
24 | #### Training setting
25 |
26 | n_ep: 70
27 | decay_cos_warmup_steps: 35
28 | n_ep_decay: 15
29 | decayType: exp # step, linear, exp, cos
30 | n_ep_save: 5
31 | resume_epoch: 70 # 100
32 | eva_cm: False
33 | dataLabels: ['G2_O', 'G3_O', 'G2_A', 'G3_A', 'G4_A', 'GBM'] #['G2_O', 'G3_O', 'G2_A', 'G3_A', 'G2_OA', 'G3_OA', 'GBM']
34 |
35 | #### Directories
36 | logDir: ./logs
37 | saveDir: ./outs
38 | modelDir: ./models
39 | cm_saveDir: ./cm
40 | label_path: ./merge_who.xlsx
41 |
42 | #### Meta setting main
43 | dataset: TCGA
44 | nThreads: 16
45 | seed: 124
46 | Val_batchSize: 1
47 | Test_batchSize: 1
48 | imgSize: [224,224]
49 | command: Train # Test
50 |
--------------------------------------------------------------------------------
/config/mine.yml:
--------------------------------------------------------------------------------
1 | name: Mine # Mine_dim2500_seed124_pretrain_exp_45
2 | command: Train # Test
3 | gpus: [0,1]
4 | fixdim: 2500
5 | batchSize: 2
6 | decayType: cos # step, linear, exp, cos
7 | #dataDir: /raid/qiaominglang/brain/ #hyy3
8 | # dataDir: /home/zeiler/WSI_proj/data/ #Adden_linux1
9 | #dataDir: /home/cbtil2/WSI_proj/data/ #Adden_linux2
10 | # dataDir: /home/cbtil3/WSI_proj/data/ #Adden_linux3
11 | # dataDir: /home/cbtil/ST_proj/LDH/data/
12 | dataDir: /home/hanyu/
13 |
14 | #### Network setting main
15 | Network:
16 |
17 | lr: 0.003
18 | dropout_rate: 0.1
19 | IDH_layers: 3
20 | 1p19q_layers: 2
21 | CDKN_layers: 2
22 | His_layers: 3
23 | Grade_layers: 1
24 | Trans_block: 'full' #'full' 'simple'
25 | graph_alpha: 0.1
26 | corre_loss_ratio: 0.1
27 |
28 |
29 | #### Training setting
30 | n_ep: 80
31 | decay_cos_warmup_steps: 25
32 |
33 |
34 |
35 | #### Directories
36 | logDir: ./writer/logs
37 | saveDir: ./writer/outs
38 | modelDir: ./writer/models
39 | cm_saveDir: ./writer/cm
40 | label_path: ./merge_who.xlsx
41 | TCGA_label_path: ./TCGA.xlsx
42 | IvYGAP_label_path: ./IvYGAP.xlsx
43 | CPTAC_label_path: ./CPTAC.xlsx
44 | #### Meta setting main
45 | dataset: TCGA
46 | nThreads: 16
47 | seed: 124
48 | imgSize: [224,224]
49 | eva_cm: False
50 | n_ep_save: 1
51 | fp16: False
52 |
53 | Val_batchSize: 1
54 | Test_batchSize: 1
55 | n_ep_decay: 15
56 | top_K_patch: 300
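
The `decayType: cos`, `n_ep: 80`, and `decay_cos_warmup_steps: 25` keys above suggest a cosine learning-rate schedule with a warmup phase. The training code that consumes these keys is not shown in this part of the listing, so the following is only a sketch of one common formulation (linear warmup followed by cosine decay to zero). The function name `lr_at_epoch` and the exact formula are assumptions, not the repository's implementation.

```python
import math


def lr_at_epoch(epoch, base_lr=0.003, n_ep=80, warmup=25):
    """Hypothetical schedule: linear warmup for `warmup` epochs, then cosine decay.

    Defaults are taken from config/mine.yml (lr, n_ep, decay_cos_warmup_steps).
    """
    if epoch < warmup:
        # Ramp linearly so that the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup
    # Fraction of the post-warmup phase completed, in [0, 1].
    progress = (epoch - warmup) / max(1, n_ep - warmup)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for ep in (0, 24, 25, 52, 79):
        print(ep, lr_at_epoch(ep))
```

Under this assumed formulation the learning rate peaks at `base_lr` at the end of warmup and then falls monotonically; whether the project decays to zero or to a floor value would have to be checked against the training script.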
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import argparse, time
4 |
5 | import gc
6 | # root=r'/mnt/disk10T/fuyibing/wxf_data/TCGA/brain/npy/'
7 | root=r'/mnt/disk10T/fyb/wxf_data/TCGA/brain/npy/'
8 |
9 | import SimpleITK as sitk
10 |
11 | #####itk
12 | img=np.zeros(shape=(1200,3,224,224), dtype=np.uint8)
13 | out = sitk.GetImageFromArray(img)
14 | sitk.WriteImage(out, root+'simpleitk_save.nii.gz')
15 |
16 |
17 | #####npy
18 | img=np.zeros(shape=(1200,3,224,224), dtype=np.uint8)
19 | np.save(root+'img.npy', img)
20 |
21 | #####h5
22 | imgData=np.zeros(shape=(1200,3,224,224), dtype=np.uint8)
23 | with h5py.File(root+'test.h5','w') as f:
24 | f['data'] = imgData
25 |
26 | k=h5py.File(root+'test.h5')['data'][:]
27 | a=1
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader,Dataset
5 | import numpy as np
6 | import os
7 | from PIL import Image
8 | from skimage import io,transform
9 | import cv2
10 | import torch
11 | import platform
12 | import pandas as pd
13 | import argparse, time, random
14 | import yaml
15 | from yaml.loader import SafeLoader
16 | from tqdm import tqdm
17 | import h5py
18 | import gc
19 | import math
20 | import scipy.interpolate
21 | from PIL import Image
22 | import cv2
23 | from matplotlib import pyplot as plt
24 | from torchvision.transforms import Compose
25 | import transform.transforms_group as our_transform
26 | from torchvision.transforms import Compose, ToTensor, ToPILImage, CenterCrop, Resize
27 | def train_transform(degree=180):
28 | return Compose([
29 | our_transform.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05),
30 | ])
31 | class Our_Dataset(Dataset):
32 | def __init__(self, phase,opt):
33 | super(Our_Dataset, self).__init__()
34 | self.opt = opt
35 | self.patc_bs=64
36 | self.phase=phase
37 | # self.test_mode=opt['test_mode'] # WSI Patient
38 | self.name = opt['name'].split('_')
39 |
40 | excel_label_wsi = pd.read_excel(opt['label_path'],sheet_name='wsi_level',header=0)
41 | excel_wsi =excel_label_wsi.values
42 | PATIENT_LIST=excel_wsi[:,0]
43 | np.random.seed(self.opt['seed'])
44 | random.seed(self.opt['seed'])
45 | PATIENT_LIST=list(PATIENT_LIST)
46 |
47 |
48 |
49 |
50 | PATIENT_LIST=np.unique(PATIENT_LIST)
51 | np.random.shuffle(PATIENT_LIST)
52 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
53 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL* 0.8)]
54 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.8):int(NUM_PATIENT_ALL * 0.9)]
55 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):]
56 | self.TRAIN_LIST=[]
57 | self.VAL_LIST = []
58 | self.TEST_LIST = []
59 | self.My_transform=train_transform()
60 | for i in range(excel_wsi.shape[0]):# 2612
61 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST:
62 | self.TRAIN_LIST.append(excel_wsi[i,:])
63 | elif excel_wsi[:,0][i] in VAL_PATIENT_LIST:
64 | self.VAL_LIST.append(excel_wsi[i,:])
65 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST:
66 | self.TEST_LIST.append(excel_wsi[i,:])
67 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else (np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST))
68 | # df = pd.DataFrame(self.LIST, columns=list(excel_label_wsi))
69 | # df.to_excel("vis/Val.xlsx", index=False)
70 |
71 | self.train_iter_count=0
72 | self.Flat=0
73 | self.WSI_all=[]
74 |
75 | def __getitem__(self, index):
76 |
77 |
78 |
79 | if self.name[2]=='img':
80 | patch_all,coor_all=self.read_img(index)
81 | elif self.name[2]=='fea':
82 | patch_all,coor_all=self.read_feature(index)
83 | label=self.label_gene(index)
84 |
85 | return torch.from_numpy(np.array(patch_all)).float(),torch.from_numpy(np.array(label)).long(), self.LIST[index, 1],coor_all
86 |
87 | def read_feature(self, index):
88 | root = self.opt['dataDir']+'TCGA/Res50_feature_'+str(self.opt['fixdim'])+'_fixdim0/'
89 | patch_all=h5py.File(root+self.LIST[index, 1]+'.h5')['Res_feature'][:] #(1,N,1024)
90 | coor_all = h5py.File(root + self.LIST[index, 1] + '.h5')['patches_coor'][:]
91 | return patch_all ,coor_all
92 | def read_feature1(self, index,k):
93 | root = self.opt['dataDir']+'Res50_feature_1200_fixdim0_aug/aug_set'+str(k)+'/'
94 | patch_all=h5py.File(root+self.LIST[index, 1]+'.h5')['Res_feature'][:] #(1,N,1024)
95 | coor_all = h5py.File(root + self.LIST[index, 1] + '.h5')['patches_coor'][:]
96 | return patch_all ,coor_all
97 |
98 |
99 |
100 | def read_img(self,index):
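101 | wsi_path = self.opt['dataDir'] + ('extract_224/' if self.opt['imgSize'][0]==224 else 'extract_512/') + self.LIST[index, 1] # self.dataDir is never set in this class; path built the same way as in Our_Dataset_vis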
102 | patch_all = []
103 | patch_all_ori=[]
104 | coor_all=[]
105 | coor_all_ori = []
106 | self.img_dir = os.listdir(wsi_path)
107 |
108 | read_details=np.load(self.opt['dataDir']+'read_details/'+self.LIST[index, 1]+'.npy',allow_pickle=True)[0]
109 | num_patches = read_details.shape[0]
110 | print(num_patches)
111 | max_num=2500
112 | Use_patch_num = num_patches if num_patches <= max_num else max_num
113 | if num_patches <= max_num:
114 | times=int(np.floor(max_num/num_patches))
115 | remaining=max_num % num_patches
116 | for i in range(Use_patch_num):
117 | img_temp=io.imread(wsi_path + '/' + str(read_details[i][0]) + '_' + str(read_details[i][1]) + '.jpg')
118 | img_temp = cv2.resize(img_temp, (224, 224))
119 | patch_all_ori.append(img_temp)
120 | coor_all_ori.append(read_details[i])
121 | patch_all=patch_all_ori
122 | coor_all = coor_all_ori
123 |
124 | ####### fixdim0
125 | if times>1:
126 | for k in range(times-1):
127 | patch_all=patch_all+patch_all_ori
128 | coor_all=coor_all+coor_all_ori
129 | if not remaining==0:
130 | patch_all = patch_all + patch_all_ori[0:remaining]
131 | coor_all = coor_all + coor_all_ori[0:remaining]
132 |
133 | else:
134 | for i in range(Use_patch_num):
135 | img_temp = io.imread(wsi_path + '/' + str(read_details[int(np.around(i*(num_patches/max_num)))][0])+'_'+str(read_details[int(np.around(i*(num_patches/max_num)))][1])+'.jpg')
136 | img_temp = cv2.resize(img_temp, (224, 224))
137 | patch_all.append(img_temp)
138 | coor_all.append(read_details[int(np.around(i*(num_patches/max_num)))])
139 |
140 | patch_all = np.asarray(patch_all)
141 |
142 | # data augmentation
143 | patch_all = patch_all.reshape(-1, 224, 3) # (num_patches*224,224,3)
144 | patch_all = patch_all.reshape(-1, 224, 224, 3) # (num_patches,224,224,3)
145 |
146 |
147 | patch_all = patch_all / 255.0
148 | patch_all = np.transpose(patch_all, (0, 3, 1, 2))
149 | patch_all = patch_all.astype(np.float32)
150 |
151 |
152 | return patch_all,coor_all
153 |
154 | def label_gene(self,index):
155 | his_label_map= {'glioblastoma'}
156 | grade_label_map = {'2':0,'3':1,'4':2}
157 | #grade 2021 = {0: 'G2', 1: 'G3_O', 2: 'G4'}
158 | # His 2021 = {0: 'A', 1: 'O', 2: 'GBM'}
159 | #label 2021={ 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G4_A', 5:'GBM'}
160 | if self.name[1]=='IDH':
161 | if self.LIST[index, 4]=='WT':
162 | label=0
163 | elif self.LIST[index, 4]=='Mutant':
164 | label=1
165 | elif self.name[1] == '1p19q':
166 | if self.LIST[index, 5] == 'non-codel':
167 | label = 0
168 | elif self.LIST[index, 5] == 'codel':
169 | label = 1
170 | elif self.name[1] == 'CDKN':
171 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
172 | label = 1
173 | else:
174 | label = 0
175 |
176 | elif self.name[1] == 'Diag':
177 | if self.LIST[index, 4] == 'WT':
178 | label = 0
179 | elif self.LIST[index, 5] == 'codel':
180 | label = 3
181 | else:
182 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] == 'G4':
183 | label = 1
184 | else:
185 | label = 2
186 | elif self.name[1] == 'Grade':
187 | if self.LIST[index, 4] == 'WT':
188 | label = 2
189 | elif self.LIST[index, 5] == 'codel':
190 | label = 0 if self.LIST[index, 3] =='G2' else 1
191 | else:
192 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
193 | label = 2
194 | else:
195 | label = 0 if self.LIST[index, 3] == 'G2' else 1
196 | elif self.name[1] == 'His':
197 | # if self.LIST[index, 2]=='astrocytoma':
198 | # label = 0
199 | # elif self.LIST[index, 2] == 'oligoastrocytoma':
200 | # label = 1
201 | # elif self.LIST[index, 2] == 'oligodendroglioma':
202 | # label = 2
203 | # elif self.LIST[index, 2] == 'glioblastoma':
204 | # label = 3
205 | if self.LIST[index, 2] == 'glioblastoma':
206 | label = 1
207 | else:
208 | label = 0
209 |
210 |
211 |
212 | return label
213 |
214 |
215 | def shuffle_list(self, seed):
216 | np.random.seed(seed)
217 | random.seed(seed)
218 | np.random.shuffle(self.LIST)
219 |
220 |
221 |
222 | def __len__(self):
223 | return self.LIST.shape[0]
224 |
225 |
226 |
227 | if __name__ == '__main__':
228 |
229 | parser = argparse.ArgumentParser()
230 | parser.add_argument('--opt', type=str, default='config/miccai.yml')
231 | args = parser.parse_args()
232 | with open(args.opt) as f:
233 | opt = yaml.load(f, Loader=SafeLoader)
234 | trainDataset = Our_Dataset(phase='Val', opt=opt)
235 | for i in range(100):
236 | trainDataset.__getitem__(i)
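
The constructor of `Our_Dataset` above splits at the patient level: unique patient IDs are shuffled with a fixed seed, cut at 80% and 90%, and every WSI row then follows its patient into Train, Val, or Test. This keeps slides from one patient out of more than one subset. The sketch below illustrates that idea with the standard library only. The name `split_patients`, the tuple layout, and the use of `random.Random` instead of `np.random.shuffle` are assumptions made for illustration, so the resulting assignment will not match the repository's split.

```python
import random


def split_patients(rows, seed=124, train_end=0.8, val_end=0.9):
    """Split (patient_id, wsi_id) rows into phases without patient overlap.

    `rows` is a hypothetical list of tuples; column 0 is the patient ID,
    as in the label spreadsheet used by the dataset class.
    """
    patients = sorted({pid for pid, _ in rows})  # unique IDs, stable order
    random.Random(seed).shuffle(patients)        # seeded, reproducible shuffle
    n = len(patients)
    n_train, n_val = int(n * train_end), int(n * val_end)

    # Map each patient to exactly one phase.
    phase_of = {}
    for i, pid in enumerate(patients):
        phase_of[pid] = "Train" if i < n_train else ("Val" if i < n_val else "Test")

    # Every slide follows its patient.
    out = {"Train": [], "Val": [], "Test": []}
    for row in rows:
        out[phase_of[row[0]]].append(row)
    return out


if __name__ == "__main__":
    demo = [(f"P{i}", f"P{i}-slide{j}") for i in range(10) for j in range(3)]
    parts = split_patients(demo)
    print({k: len(v) for k, v in parts.items()})
```

Building a dictionary from patient to phase also avoids the repeated `in` checks against NumPy arrays that the loop in `__init__` performs for every row.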
--------------------------------------------------------------------------------
/dataset_mine copy 2.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader,Dataset
5 | import numpy as np
6 | import os
7 | from PIL import Image
8 | from skimage import io,transform
9 | import cv2
10 | import torch
11 | import platform
12 | import pandas as pd
13 | import argparse, time, random
14 | import yaml
15 | from yaml.loader import SafeLoader
16 | from tqdm import tqdm
17 | import h5py
18 | import gc
19 | import math
20 | import scipy.interpolate
21 | from PIL import Image
22 | import cv2
23 | from matplotlib import pyplot as plt
24 | from torchvision.transforms import Compose
25 | import transform.transforms_group as our_transform
26 |
27 | class Our_Dataset(Dataset):
28 | def __init__(self, phase,opt,if_end2end=False):
29 | super(Our_Dataset, self).__init__()
30 | self.opt = opt
31 | self.patc_bs=64
32 | self.phase=phase
33 | self.if_end2end=if_end2end
34 |
35 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0)
36 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0)
37 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
38 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True)
39 | excel_wsi = combined_labels.values
40 |
41 | PATIENT_LIST=excel_wsi[:,0]
42 | np.random.seed(self.opt['seed'])
43 | random.seed(self.opt['seed'])
44 | PATIENT_LIST=list(PATIENT_LIST)
45 | # IvYGAP_label
46 | IvYGAP_label = IvYGAP_label.values
47 |
48 | PATIENT_LIST=np.unique(PATIENT_LIST)
49 | np.random.shuffle(PATIENT_LIST)
50 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
51 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)]
52 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):]
53 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)]
54 | self.TRAIN_LIST=[]
55 | self.VAL_LIST = []
56 | self.TEST_LIST = []
57 | self.I_TEST_LIST = []
58 | for i in range(excel_wsi.shape[0]):# 2612
59 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST:
60 | self.TRAIN_LIST.append(excel_wsi[i,:])
61 | elif excel_wsi[:,0][i] in VAL_PATIENT_LIST:
62 | self.VAL_LIST.append(excel_wsi[i,:])
63 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST:
64 | self.TEST_LIST.append(excel_wsi[i,:])
65 |
66 | for i in range(IvYGAP_label.shape[0]):# 2612
67 | self.I_TEST_LIST.append(IvYGAP_label[i,:])
68 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST) if self.phase == 'Test' else np.asarray(self.I_TEST_LIST)
69 |
70 | self.train_iter_count=0
71 | self.Flat=0
72 | self.WSI_all=[]
73 |
74 | def __getitem__(self, index):
75 | feature_all_20,feature_all_10, = self.read_feature(index)
76 |
77 | label=self.label_gene(index)
78 |
79 | return torch.from_numpy(np.array(feature_all_20)).float(),torch.from_numpy(np.array(feature_all_10)).float(),\
80 | torch.from_numpy(label)
81 |
82 | def read_feature(self, index):
83 |
84 | root = '/Res50_feature_2500_fixdim0_norm'
85 |
86 | patient_id = self.LIST[index, 0]
87 |
88 |
89 | if patient_id[0].startswith('T'):
90 | base_path = self.opt['dataDir'] + 'TCGA'
91 | elif patient_id[0].startswith('W'):
92 | base_path = self.opt['dataDir'] + 'IvYGAP'
93 | elif patient_id[0].startswith('C'):
94 | base_path = self.opt['dataDir'] + 'CPTAC'
95 | else:
96 | raise ValueError("Unknown data source")
97 |
98 | patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
99 | patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
100 | return patch_20[0], patch_10[0]#, patch_1_25[0]
101 |
102 |
103 | def label_gene(self,index):
104 |
105 |
106 | if self.LIST[index, 4]=='WT':
107 | label_IDH=0
108 | elif self.LIST[index, 4]=='Mutant':
109 | label_IDH=1
110 | if self.LIST[index, 5] == 'non-codel':
111 | label_1p19q = 0
112 | elif self.LIST[index, 5] == 'codel':
113 | label_1p19q = 1
114 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
115 | label_CDKN = 1
116 | else:
117 | label_CDKN = 0
118 |
119 | if self.LIST[index, 2]=='oligoastrocytoma':
120 | label_His = 0
121 | elif self.LIST[index, 2] == 'astrocytoma':
122 | label_His = 1
123 | elif self.LIST[index, 2] == 'oligodendroglioma':
124 | label_His = 2
125 | elif self.LIST[index, 2] == 'glioblastoma':
126 | label_His = 3
127 |
128 | if self.LIST[index, 2]=='glioblastoma':
129 | label_His_2class = 1
130 | else:
131 | label_His_2class = 0
132 |
133 | if self.LIST[index, 3]=='G2':
134 | label_Grade=0
135 | elif self.LIST[index, 3] == 'G3':
136 | label_Grade = 1
137 | else:
138 | label_Grade=2 #### Useless
139 |
140 |
141 | if self.LIST[index, 4]=='WT':
142 | label_Diag = 0
143 | elif self.LIST[index, 5] == 'codel':
144 | label_Diag = 3
145 | else:
146 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
147 | label_Diag = 1
148 | else:
149 | label_Diag = 2
150 |
151 |
152 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class])
153 |
154 | return label
155 |
156 |
157 | def shuffle_list(self, seed):
158 | np.random.seed(seed)
159 | random.seed(seed)
160 | np.random.shuffle(self.LIST)
161 |
162 |
163 |
164 | def __len__(self):
165 | return self.LIST.shape[0]
166 |
167 | class Our_Dataset_vis(Dataset):
168 | def __init__(self, phase,opt,if_end2end=False):
169 | super(Our_Dataset_vis, self).__init__()
170 | self.opt = opt
171 | self.patc_bs=64
172 | self.phase=phase
173 | self.if_end2end=if_end2end
174 | self.dataDir = (opt['dataDir']+'extract_224/') if opt['imgSize'][0]==224 else (opt['dataDir']+'extract_512/')
175 |
176 | excel_label_wsi = pd.read_excel(opt['label_path'],sheet_name='wsi_level',header=0)
177 | excel_wsi =excel_label_wsi.values
178 | PATIENT_LIST=excel_wsi[:,0]
179 | np.random.seed(self.opt['seed'])
180 | random.seed(self.opt['seed'])
181 | PATIENT_LIST=list(PATIENT_LIST)
182 |
183 |
184 | PATIENT_LIST=np.unique(PATIENT_LIST)
185 | np.random.shuffle(PATIENT_LIST)
186 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
187 | TEST_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL)]
188 | TEST_WSI_LIST=os.listdir(r'/home/zeiler/WSI_proj/miccai/vis_results/set0/')
189 | self.TRAIN_LIST=[]
190 | self.VAL_LIST = []
191 | self.TEST_LIST = []
192 | for i in range(excel_wsi.shape[0]):# 2612
193 | if excel_wsi[:,1][i]+'.h5' in TEST_WSI_LIST:
194 | self.TEST_LIST.append(excel_wsi[i,:])
195 | self.LIST= np.asarray(self.TEST_LIST)
196 |
197 |
198 | self.train_iter_count=0
199 | self.Flat=0
200 | self.WSI_all=[]
201 |
202 | def __getitem__(self, index):
203 | feature_all ,read_details,self.LIST[index, 1]= self.read_feature(index)
204 |
205 | label=self.label_gene(index)
206 |
207 | return torch.from_numpy(np.array(feature_all)).float(),torch.from_numpy(label),read_details,self.LIST[index, 1]
208 |
209 | def read_feature(self, index):
210 | read_details = np.load(self.opt['dataDir'] + 'read_details/' + self.LIST[index, 1] + '.npy', allow_pickle=True)[
211 | 0]
212 | num_patches = read_details.shape[0]
213 | root = self.opt['dataDir']+'Res50_feature_'+str(self.opt['fixdim'])+'_fixdim0/'
214 | patch_all = h5py.File(root + self.LIST[index, 1] + '.h5')['Res_feature'][:] # (1,N,1024)
215 | return patch_all[0],read_details,self.LIST[index, 1]
216 |
217 |
218 | def label_gene(self,index):
219 |
220 |
221 | if self.LIST[index, 4]=='WT':
222 | label_IDH=0
223 | elif self.LIST[index, 4]=='Mutant':
224 | label_IDH=1
225 | if self.LIST[index, 5] == 'non-codel':
226 | label_1p19q = 0
227 | elif self.LIST[index, 5] == 'codel':
228 | label_1p19q = 1
229 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
230 | label_CDKN = 1
231 | else:
232 | label_CDKN = 0
233 |
234 | if self.LIST[index, 2]=='oligoastrocytoma':
235 | label_His = 0
236 | elif self.LIST[index, 2] == 'astrocytoma':
237 | label_His = 1
238 | elif self.LIST[index, 2] == 'oligodendroglioma':
239 | label_His = 2
240 | elif self.LIST[index, 2] == 'glioblastoma':
241 | label_His = 3
242 |
243 | if self.LIST[index, 2]=='glioblastoma':
244 | label_His_2class = 1
245 | else:
246 | label_His_2class = 0
247 |
248 | if self.LIST[index, 3]=='G2':
249 | label_Grade=0
250 | elif self.LIST[index, 3] == 'G3':
251 | label_Grade = 1
252 | else:
253 | label_Grade=2 #### Useless
254 |
255 |
256 | if self.LIST[index, 4]=='WT':
257 | label_Diag = 0
258 | elif self.LIST[index, 5] == 'codel':
259 | label_Diag = 3
260 | else:
261 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
262 | label_Diag = 1
263 | else:
264 | label_Diag = 2
265 |
266 |
267 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class])
268 |
269 | return label
270 |
271 |
272 | def shuffle_list(self, seed):
273 | np.random.seed(seed)
274 | random.seed(seed)
275 | np.random.shuffle(self.LIST)
276 |
277 |
278 |
279 | def __len__(self):
280 | return self.LIST.shape[0]
281 |
282 | if __name__ == '__main__':
283 | # epoch_seed1 = np.arange(1000)
284 | # np.random.seed(100)
285 | # random.seed(100)
286 | # epoch_seed = np.arange(5)
287 | # np.random.shuffle(epoch_seed)
288 | # for i in range(5):
289 | # np.random.seed(epoch_seed[i])
290 | # random.seed(epoch_seed[i])
291 | # np.random.shuffle(epoch_seed1)
292 | # print(epoch_seed1[:20])
293 |
294 |
295 | parser = argparse.ArgumentParser()
296 | parser.add_argument('--opt', type=str, default='config/mine.yml')
297 | args = parser.parse_args()
298 | with open(args.opt) as f:
299 | opt = yaml.load(f, Loader=SafeLoader)
300 | trainDataset = Our_Dataset(phase='Train', opt=opt)
301 | for i in range(100):
302 | _,x_,y_=trainDataset.__getitem__(index=2000-i)
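
The `label_gene` methods above turn one spreadsheet row into seven task labels (IDH, 1p19q, CDKN, histology, grade, integrated diagnosis, and two-class histology). The following is a compact sketch of the same mapping as a standalone function. The name `encode_labels` and its named arguments are hypothetical stand-ins for columns 2 to 6 of the row. Unlike the original, which would fail with an unbound local variable on an unexpected value, this sketch raises `ValueError` explicitly.

```python
def encode_labels(his, grade, idh, codel, cdkn):
    """Return [IDH, 1p19q, CDKN, His, Grade, Diag, His_2class] as integers."""
    idh_map = {"WT": 0, "Mutant": 1}
    codel_map = {"non-codel": 0, "codel": 1}
    his_map = {"oligoastrocytoma": 0, "astrocytoma": 1,
               "oligodendroglioma": 2, "glioblastoma": 3}
    for value, table in ((idh, idh_map), (codel, codel_map), (his, his_map)):
        if value not in table:
            raise ValueError(f"unexpected label value: {value!r}")

    label_cdkn = 1 if cdkn in (-2, -1) else 0      # -2 / -1 are encoded as class 1
    label_grade = {"G2": 0, "G3": 1}.get(grade, 2)  # anything else maps to 2

    # Integrated diagnosis, following the branch order used in label_gene.
    if idh == "WT":
        label_diag = 0
    elif codel == "codel":
        label_diag = 3
    elif label_cdkn == 1 or grade == "G4":
        label_diag = 1
    else:
        label_diag = 2

    label_his_2class = 1 if his == "glioblastoma" else 0
    return [idh_map[idh], codel_map[codel], label_cdkn, his_map[his],
            label_grade, label_diag, label_his_2class]


if __name__ == "__main__":
    print(encode_labels("astrocytoma", "G3", "Mutant", "non-codel", -2))
```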
--------------------------------------------------------------------------------
/dataset_mine copy.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader,Dataset
5 | import numpy as np
6 | import os
7 | from PIL import Image
8 | from skimage import io,transform
9 | import cv2
10 | import torch
11 | import platform
12 | import pandas as pd
13 | import argparse, time, random
14 | import yaml
15 | from yaml.loader import SafeLoader
16 | from tqdm import tqdm
17 | import h5py
18 | import gc
19 | import math
20 | import scipy.interpolate
21 | from PIL import Image
22 | import cv2
23 | from matplotlib import pyplot as plt
24 | from torchvision.transforms import Compose
25 | import transform.transforms_group as our_transform
26 |
27 | class Our_Dataset(Dataset):
28 | def __init__(self, phase,opt,if_end2end=False):
29 | super(Our_Dataset, self).__init__()
30 | self.opt = opt
31 | self.patc_bs=64
32 | self.phase=phase
33 | self.if_end2end=if_end2end
34 |
35 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0)
36 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0)
37 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='wsi_level', header=0)
38 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True)
39 | excel_wsi = combined_labels.values
40 |
41 | PATIENT_LIST=excel_wsi[:,0]
42 | np.random.seed(self.opt['seed'])
43 | random.seed(self.opt['seed'])
44 | PATIENT_LIST=list(PATIENT_LIST)
45 | # IvYGAP_label
46 | IvYGAP_label = IvYGAP_label.values
47 |
48 | PATIENT_LIST=np.unique(PATIENT_LIST)
49 | np.random.shuffle(PATIENT_LIST)
50 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
51 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)]
52 | # VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):]
53 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)]
54 | self.TRAIN_LIST=[]
55 | self.VAL_LIST = []
56 | self.TEST_LIST = []
57 | self.I_TEST_LIST = []
58 |
59 | for i in range(excel_wsi.shape[0]):# 2612
60 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST:
61 | self.TRAIN_LIST.append(excel_wsi[i,:])
62 | # elif excel_wsi[:,0][i] in VAL_PATIENT_LIST:
63 | # self.VAL_LIST.append(excel_wsi[i,:])
64 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST:
65 | self.TEST_LIST.append(excel_wsi[i,:])
66 |
67 | for i in range(IvYGAP_label.shape[0]):# 2612
68 | self.I_TEST_LIST.append(IvYGAP_label[i,:])
69 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST) if self.phase == 'Test' else np.asarray(self.I_TEST_LIST)
70 |
71 |
72 | self.train_iter_count=0
73 | self.Flat=0
74 | self.WSI_all=[]
75 |
76 | def __getitem__(self, index):
77 | feature_all_20,feature_all_10, = self.read_feature(index)
78 |
79 | label=self.label_gene(index)
80 |
81 | return torch.from_numpy(np.array(feature_all_20)).float(),torch.from_numpy(np.array(feature_all_10)).float(),\
82 | torch.from_numpy(label)
83 |
84 | def read_feature(self, index):
85 |
86 | root = '/Res50_feature_2500_fixdim0_norm'
87 |
88 | patient_id = self.LIST[index, 0]
89 |
90 |
91 | if patient_id[0].startswith('T'):
92 | base_path = self.opt['dataDir'] + 'TCGA'
93 | elif patient_id[0].startswith('W'):
94 | base_path = self.opt['dataDir'] + 'IvYGAP'
95 | elif patient_id[0].startswith('C'):
96 | base_path = self.opt['dataDir'] + 'CPTAC'
97 | else:
98 | raise ValueError("Unknown data source")
99 |
100 | patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
101 | patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
102 | return patch_20[0], patch_10[0]#, patch_1_25[0]
103 |
104 |
105 | def label_gene(self,index):
106 |
107 |
108 | if self.LIST[index, 4]=='WT':
109 | label_IDH=0
110 | elif self.LIST[index, 4]=='Mutant':
111 | label_IDH=1
112 | if self.LIST[index, 5] == 'non-codel':
113 | label_1p19q = 0
114 | elif self.LIST[index, 5] == 'codel':
115 | label_1p19q = 1
116 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
117 | label_CDKN = 1
118 | else:
119 | label_CDKN = 0
120 |
121 | if self.LIST[index, 2]=='oligoastrocytoma':
122 | label_His = 0
123 | elif self.LIST[index, 2] == 'astrocytoma':
124 | label_His = 1
125 | elif self.LIST[index, 2] == 'oligodendroglioma':
126 | label_His = 2
127 | elif self.LIST[index, 2] == 'glioblastoma':
128 | label_His = 3
129 |
130 | if self.LIST[index, 2]=='glioblastoma':
131 | label_His_2class = 1
132 | else:
133 | label_His_2class = 0
134 |
135 | if self.LIST[index, 3]=='G2':
136 | label_Grade=0
137 | elif self.LIST[index, 3] == 'G3':
138 | label_Grade = 1
139 | else:
140 | label_Grade=2 #### Useless
141 |
142 |
143 | if self.LIST[index, 4]=='WT':
144 | label_Diag = 0
145 | elif self.LIST[index, 5] == 'codel':
146 | label_Diag = 3
147 | else:
148 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
149 | label_Diag = 1
150 | else:
151 | label_Diag = 2
152 |
153 |
154 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class])
155 |
156 | return label
157 |
158 |
159 | def shuffle_list(self, seed):
160 | np.random.seed(seed)
161 | random.seed(seed)
162 | np.random.shuffle(self.LIST)
163 |
164 |
165 |
166 | def __len__(self):
167 | return self.LIST.shape[0]
168 |
169 | class Our_Dataset_vis(Dataset):
170 | def __init__(self, phase,opt,if_end2end=False):
171 | super(Our_Dataset_vis, self).__init__()
172 | self.opt = opt
173 | self.patc_bs=64
174 | self.phase=phase
175 | self.if_end2end=if_end2end
176 | self.dataDir = (opt['dataDir']+'extract_224/') if opt['imgSize'][0]==224 else (opt['dataDir']+'extract_512/')
177 |
178 | excel_label_wsi = pd.read_excel(opt['label_path'],sheet_name='wsi_level',header=0)
179 | excel_wsi =excel_label_wsi.values
180 | PATIENT_LIST=excel_wsi[:,0]
181 | np.random.seed(self.opt['seed'])
182 | random.seed(self.opt['seed'])
183 | PATIENT_LIST=list(PATIENT_LIST)
184 |
185 |
186 | PATIENT_LIST=np.unique(PATIENT_LIST)
187 | np.random.shuffle(PATIENT_LIST)
188 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
189 | TEST_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL)]
190 | TEST_WSI_LIST=os.listdir(r'/home/zeiler/WSI_proj/miccai/vis_results/set0/')
191 | self.TRAIN_LIST=[]
192 | self.VAL_LIST = []
193 | self.TEST_LIST = []
194 | for i in range(excel_wsi.shape[0]):# 2612
195 | if excel_wsi[:,1][i]+'.h5' in TEST_WSI_LIST:
196 | self.TEST_LIST.append(excel_wsi[i,:])
197 | self.LIST= np.asarray(self.TEST_LIST)
198 |
199 |
200 | self.train_iter_count=0
201 | self.Flat=0
202 | self.WSI_all=[]
203 |
204 | def __getitem__(self, index):
205 | feature_all ,read_details,self.LIST[index, 1]= self.read_feature(index)
206 |
207 | label=self.label_gene(index)
208 |
209 | return torch.from_numpy(np.array(feature_all)).float(),torch.from_numpy(label),read_details,self.LIST[index, 1]
210 |
211 | def read_feature(self, index):
212 | read_details = np.load(self.opt['dataDir'] + 'read_details/' + self.LIST[index, 1] + '.npy', allow_pickle=True)[
213 | 0]
214 | num_patches = read_details.shape[0]
215 | root = self.opt['dataDir']+'Res50_feature_'+str(self.opt['fixdim'])+'_fixdim0/'
216 | patch_all = h5py.File(root + self.LIST[index, 1] + '.h5')['Res_feature'][:] # (1,N,1024)
217 | return patch_all[0],read_details,self.LIST[index, 1]
218 |
219 |
220 | def label_gene(self,index):
221 |
222 |
223 | if self.LIST[index, 4]=='WT':
224 | label_IDH=0
225 | elif self.LIST[index, 4]=='Mutant':
226 | label_IDH=1
227 | if self.LIST[index, 5] == 'non-codel':
228 | label_1p19q = 0
229 | elif self.LIST[index, 5] == 'codel':
230 | label_1p19q = 1
231 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
232 | label_CDKN = 1
233 | else:
234 | label_CDKN = 0
235 |
236 | if self.LIST[index, 2]=='oligoastrocytoma':
237 | label_His = 0
238 | elif self.LIST[index, 2] == 'astrocytoma':
239 | label_His = 1
240 | elif self.LIST[index, 2] == 'oligodendroglioma':
241 | label_His = 2
242 | elif self.LIST[index, 2] == 'glioblastoma':
243 | label_His = 3
244 |
245 | if self.LIST[index, 2]=='glioblastoma':
246 | label_His_2class = 1
247 | else:
248 | label_His_2class = 0
249 |
250 | if self.LIST[index, 3]=='G2':
251 | label_Grade=0
252 | elif self.LIST[index, 3] == 'G3':
253 | label_Grade = 1
254 | else:
255 | label_Grade=2 #### Useless
256 |
257 |
258 | if self.LIST[index, 4]=='WT':
259 | label_Diag = 0
260 | elif self.LIST[index, 5] == 'codel':
261 | label_Diag = 3
262 | else:
263 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
264 | label_Diag = 1
265 | else:
266 | label_Diag = 2
267 |
268 |
269 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class])
270 |
271 | return label
272 |
273 |
274 | def shuffle_list(self, seed):
275 | np.random.seed(seed)
276 | random.seed(seed)
277 | np.random.shuffle(self.LIST)
278 |
279 |
280 |
281 | def __len__(self):
282 | return self.LIST.shape[0]
283 |
284 | if __name__ == '__main__':
285 | # epoch_seed1 = np.arange(1000)
286 | # np.random.seed(100)
287 | # random.seed(100)
288 | # epoch_seed = np.arange(5)
289 | # np.random.shuffle(epoch_seed)
290 | # for i in range(5):
291 | # np.random.seed(epoch_seed[i])
292 | # random.seed(epoch_seed[i])
293 | # np.random.shuffle(epoch_seed1)
294 | # print(epoch_seed1[:20])
295 |
296 |
297 | parser = argparse.ArgumentParser()
298 | parser.add_argument('--opt', type=str, default='config/mine.yml')
299 | args = parser.parse_args()
300 | with open(args.opt) as f:
301 | opt = yaml.load(f, Loader=SafeLoader)
302 | trainDataset = Our_Dataset(phase='Train', opt=opt)
303 | for i in range(100):
304 | _,x_,y_=trainDataset.__getitem__(index=2000-i)
--------------------------------------------------------------------------------
/dataset_mine.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader,Dataset
5 | import numpy as np
6 | import os
7 | from PIL import Image
8 | from skimage import io,transform
9 | import cv2
10 | import torch
11 | import platform
12 | import pandas as pd
13 | import argparse, time, random
14 | import yaml
15 | from yaml.loader import SafeLoader
16 | from tqdm import tqdm
17 | import h5py
18 | import gc
19 | import math
20 | import scipy.interpolate
21 | from PIL import Image
22 | import cv2
23 | from matplotlib import pyplot as plt
24 | from torchvision.transforms import Compose
25 | import transform.transforms_group as our_transform
26 |
27 | class Our_Dataset(Dataset):
28 | def __init__(self, phase,opt,if_end2end=False):
29 | super(Our_Dataset, self).__init__()
30 | self.opt = opt
31 | self.patc_bs=64
32 | self.phase=phase
33 | self.if_end2end=if_end2end
34 |
35 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0)
36 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0)
37 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
38 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True)
39 | excel_wsi = combined_labels.values
40 |
41 | PATIENT_LIST=excel_wsi[:,0]
42 | np.random.seed(self.opt['seed'])
43 | random.seed(self.opt['seed'])
44 | PATIENT_LIST=list(PATIENT_LIST)
45 | # IvYGAP_label
46 | IvYGAP_label = IvYGAP_label.values
47 |
48 | PATIENT_LIST=np.unique(PATIENT_LIST)
49 | np.random.shuffle(PATIENT_LIST)
50 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
51 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)]
52 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):]
53 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)]
54 | self.TRAIN_LIST=[]
55 | self.VAL_LIST = []
56 | self.TEST_LIST = []
57 | self.I_TEST_LIST = []
58 | for i in range(excel_wsi.shape[0]):# 2612
59 | if excel_wsi[:,0][i] in TRAIN_PATIENT_LIST:
60 | self.TRAIN_LIST.append(excel_wsi[i,:])
61 | # elif excel_wsi[:,0][i] in VAL_PATIENT_LIST:
62 | # self.VAL_LIST.append(excel_wsi[i,:])
63 | elif excel_wsi[:,0][i] in TEST_PATIENT_LIST:
64 | self.TEST_LIST.append(excel_wsi[i,:])
65 |
66 | for i in range(IvYGAP_label.shape[0]):# 2612
67 | self.I_TEST_LIST.append(IvYGAP_label[i,:])
68 | self.LIST= np.asarray(self.TRAIN_LIST) if self.phase == 'Train' else np.asarray(self.VAL_LIST) if self.phase == 'Val' else np.asarray(self.TEST_LIST) if self.phase == 'Test' else np.asarray(self.I_TEST_LIST)
69 |
70 | self.train_iter_count=0
71 | self.Flat=0
72 | self.WSI_all=[]
73 |
74 | def __getitem__(self, index):
75 | feature_all_20,feature_all_10, = self.read_feature(index)
76 |
77 | label=self.label_gene(index)
78 |
79 | return torch.from_numpy(np.array(feature_all_20)).float(),torch.from_numpy(np.array(feature_all_10)).float(),\
80 | torch.from_numpy(label)
81 |
82 | def read_feature(self, index):
83 |
84 | root = '/Res50_feature_2500_fixdim0_norm'
85 |
86 | patient_id = self.LIST[index, 0]
87 |
88 |
89 | if patient_id[0].startswith('T'):
90 | base_path = self.opt['dataDir'] + 'TCGA'
91 | elif patient_id[0].startswith('W'):
92 | base_path = self.opt['dataDir'] + 'IvYGAP'
93 | elif patient_id[0].startswith('C'):
94 | base_path = self.opt['dataDir'] + 'CPTAC'
95 | else:
96 | raise ValueError("Unknown data source")
97 |
98 | patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
99 | patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
100 | return patch_20[0], patch_10[0]#, patch_1_25[0]
101 |
102 |
103 | def label_gene(self,index):
104 |
105 |
106 | if self.LIST[index, 4]=='WT':
107 | label_IDH=0
108 | elif self.LIST[index, 4]=='Mutant':
109 | label_IDH=1
110 | if self.LIST[index, 5] == 'non-codel':
111 | label_1p19q = 0
112 | elif self.LIST[index, 5] == 'codel':
113 | label_1p19q = 1
114 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
115 | label_CDKN = 1
116 | else:
117 | label_CDKN = 0
118 |
119 | if self.LIST[index, 2]=='oligoastrocytoma':
120 | label_His = 0
121 | elif self.LIST[index, 2] == 'astrocytoma':
122 | label_His = 1
123 | elif self.LIST[index, 2] == 'oligodendroglioma':
124 | label_His = 2
125 | elif self.LIST[index, 2] == 'glioblastoma':
126 | label_His = 3
127 |
128 | if self.LIST[index, 2]=='glioblastoma':
129 | label_His_2class = 1
130 | else:
131 | label_His_2class = 0
132 |
133 | if self.LIST[index, 3]=='G2':
134 | label_Grade=0
135 | elif self.LIST[index, 3] == 'G3':
136 | label_Grade = 1
137 | else:
138 | label_Grade=2 #### Useless
139 |
140 |
141 | if self.LIST[index, 4]=='WT':
142 | label_Diag = 0
143 | elif self.LIST[index, 5] == 'codel':
144 | label_Diag = 3
145 | else:
146 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
147 | label_Diag = 1
148 | else:
149 | label_Diag = 2
150 |
151 |
152 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class])
153 |
154 | return label
155 |
156 |
157 | def shuffle_list(self, seed):
158 | np.random.seed(seed)
159 | random.seed(seed)
160 | np.random.shuffle(self.LIST)
161 |
162 |
163 |
164 | def __len__(self):
165 | return self.LIST.shape[0]
166 |
167 | class Our_Dataset_vis(Dataset):
168 | def __init__(self, phase,opt,if_end2end=False):
169 | super(Our_Dataset_vis, self).__init__()
170 | self.opt = opt
171 | self.patc_bs=64
172 | self.phase=phase
173 | self.if_end2end=if_end2end
174 | self.dataDir = (opt['dataDir']+'extract_224/') if opt['imgSize'][0]==224 else (opt['dataDir']+'extract_512/')
175 |
176 |
177 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
178 | excel_wsi = TCGA_label.values
179 |
180 | PATIENT_LIST=excel_wsi[:,0]
181 | np.random.seed(opt['seed'])
182 | random.seed(opt['seed'])
183 | PATIENT_LIST=list(PATIENT_LIST)
184 |
185 | PATIENT_LIST=np.unique(PATIENT_LIST)
186 | np.random.shuffle(PATIENT_LIST)
187 | TEST_LIST = []
188 |
189 | for i in range(excel_wsi.shape[0]):# 2612
190 |
191 | if excel_wsi[:,0][i] in PATIENT_LIST:
192 | TEST_LIST.append(excel_wsi[i,:])
193 | self.LIST = np.asarray(TEST_LIST) # __getitem__, label_gene and __len__ all read self.LIST
194 |
195 |
196 | self.train_iter_count=0
197 | self.Flat=0
198 | self.WSI_all=[]
199 |
200 | def __getitem__(self, index):
201 | feature_all_20,feature_all_10,read_details,self.LIST[index, 1] = self.read_feature(index)
202 |
203 | label=self.label_gene(index)
204 |
205 | return torch.from_numpy(np.array(feature_all_20)).float(),torch.from_numpy(np.array(feature_all_10)).float(),\
206 | torch.from_numpy(label),read_details,self.LIST[index, 1]
207 |
208 | def read_feature(self, index):
209 |
210 |
211 | root = '/Res50_feature_2500_fixdim0_norm'
212 |
213 | patient_id = self.LIST[index, 0]
214 |
215 |
216 | if patient_id[0].startswith('T'):
217 | base_path = self.opt['dataDir'] + 'TCGA'
218 | elif patient_id[0].startswith('W'):
219 | base_path = self.opt['dataDir'] + 'IvYGAP'
220 | elif patient_id[0].startswith('C'):
221 | base_path = self.opt['dataDir'] + 'CPTAC'
222 | else:
223 | raise ValueError("Unknown data source")
224 | read_details = np.load(self.opt['dataDir'] + 'read_details/' + self.LIST[index, 1] + '.npy', allow_pickle=True)[0]
225 | num_patches = read_details.shape[0]
226 | patch_20 = h5py.File(base_path + root + '_20x/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
227 | patch_10 = h5py.File(base_path + root + '/' + self.LIST[index, 1] + '.h5')['Res_feature'][:]
228 | return patch_20[0], patch_10[0],num_patches,self.LIST[index, 1]
229 |
230 |
231 | def label_gene(self,index):
232 |
233 |
234 | if self.LIST[index, 4]=='WT':
235 | label_IDH=0
236 | elif self.LIST[index, 4]=='Mutant':
237 | label_IDH=1
238 | if self.LIST[index, 5] == 'non-codel':
239 | label_1p19q = 0
240 | elif self.LIST[index, 5] == 'codel':
241 | label_1p19q = 1
242 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1:
243 | label_CDKN = 1
244 | else:
245 | label_CDKN = 0
246 |
247 | if self.LIST[index, 2]=='oligoastrocytoma':
248 | label_His = 0
249 | elif self.LIST[index, 2] == 'astrocytoma':
250 | label_His = 1
251 | elif self.LIST[index, 2] == 'oligodendroglioma':
252 | label_His = 2
253 | elif self.LIST[index, 2] == 'glioblastoma':
254 | label_His = 3
255 |
256 | if self.LIST[index, 2]=='glioblastoma':
257 | label_His_2class = 1
258 | else:
259 | label_His_2class = 0
260 |
261 | if self.LIST[index, 3]=='G2':
262 | label_Grade=0
263 | elif self.LIST[index, 3] == 'G3':
264 | label_Grade = 1
265 | else:
266 | label_Grade=2 #### Useless
267 |
268 |
269 | if self.LIST[index, 4]=='WT':
270 | label_Diag = 0
271 | elif self.LIST[index, 5] == 'codel':
272 | label_Diag = 3
273 | else:
274 | if self.LIST[index, 6] == -2 or self.LIST[index, 6] == -1 or self.LIST[index, 3] =='G4':
275 | label_Diag = 1
276 | else:
277 | label_Diag = 2
278 |
279 |
280 | label=np.asarray([label_IDH,label_1p19q,label_CDKN,label_His,label_Grade,label_Diag,label_His_2class])
281 |
282 | return label
283 |
284 |
285 | def shuffle_list(self, seed):
286 | np.random.seed(seed)
287 | random.seed(seed)
288 | np.random.shuffle(self.LIST)
289 |
290 |
291 |
292 | def __len__(self):
293 | return self.LIST.shape[0]
294 |
295 | if __name__ == '__main__':
296 | # epoch_seed1 = np.arange(1000)
297 | # np.random.seed(100)
298 | # random.seed(100)
299 | # epoch_seed = np.arange(5)
300 | # np.random.shuffle(epoch_seed)
301 | # for i in range(5):
302 | # np.random.seed(epoch_seed[i])
303 | # random.seed(epoch_seed[i])
304 | # np.random.shuffle(epoch_seed1)
305 | # print(epoch_seed1[:20])
306 |
307 |
308 | parser = argparse.ArgumentParser()
309 | parser.add_argument('--opt', type=str, default='config/mine.yml')
310 | args = parser.parse_args()
311 | with open(args.opt) as f:
312 | opt = yaml.load(f, Loader=SafeLoader)
313 | trainDataset = Our_Dataset(phase='Train', opt=opt)
314 | for i in range(100):
315 | _,x_,y_=trainDataset.__getitem__(index=2000-i)
--------------------------------------------------------------------------------
/debug.py:
--------------------------------------------------------------------------------
1 | import os, math, random
2 | import numpy as np
3 | import pandas as pd
4 | from skimage import io
5 | ####### find out backward propagation difference of img and fea in TransMIL
6 |
7 | # seed=100
8 | # torch.manual_seed(seed)
9 | # torch.cuda.manual_seed(seed)
10 | # torch.cuda.manual_seed_all(seed)
11 | # np.random.seed(seed)
12 | # random.seed(seed)
13 | # parser = argparse.ArgumentParser()
14 | # parser.add_argument('--opt', type=str, default='config/miccai.yml')
15 | # args = parser.parse_args()
16 | # with open(args.opt) as f:
17 | # opt = yaml.load(f, Loader=SafeLoader)
18 | # gpuID = opt['gpus']
19 | # TransMIL = model.TransMIL(opt)
20 | # model.init_weights(TransMIL, init_type='xavier', init_gain=1)
21 | # assert opt['name'].split('_')[0]=='TransMIL'
22 | # device = torch.device('cuda:{}'.format(gpuID[0])) if gpuID else torch.device('cpu')
23 | # if opt['name'].split('_')[2] == 'img':
24 | # Res_pretrain= net.Res50_pretrain()
25 | # Res_pretrain.to(device)
26 | #
27 | # TransMIL.to(device)
28 | # TransMIL_opt = torch.optim.Adam(filter(lambda p: p.requires_grad, TransMIL.parameters()), 0.01, weight_decay=0.00001)
29 | # trainDataset = dataset.Our_Dataset(phase='Train',opt=opt)
30 | #
31 | # img,label,img_path=trainDataset._getitem__(0)
32 | #
33 | # label = torch.from_numpy(np.asarray([label.detach().numpy()])).long()
34 | #
35 | # if torch.cuda.is_available():
36 | # img = img.cuda(gpuID[0])
37 | # label = label.cuda(gpuID[0])
38 | # TransMIL.train()
39 | # TransMIL.zero_grad()
40 | # print(img_path)
41 | # if opt['name'].split('_')[2] == 'img':
42 | # Res_pretrain.train()
43 | # img=Res_pretrain(img)
44 | #
45 | # results_dict = TransMIL(img)
46 | # pred=results_dict['logits']
47 | # loss_all = TransMIL.calculateLoss(pred,label)
48 | # loss_all.backward()
49 | # TransMIL_opt.step()
50 |
51 | # coords_all = h5py.File('temp/TCGA-DU-6404-01Z-00-DX1.93c15688-f5b2-40bc-85eb-ff2661e16d4e.h5')['coords'][:]
52 | """
53 | Use only one object per slide
54 | """
55 | excel_label_wsi = pd.read_excel('./merge_who.xlsx',sheet_name='wsi_level',header=0)
56 | excel_wsi =excel_label_wsi.values
57 | WSI_used_names=list(excel_wsi[:,1])
58 | np.random.seed(1)
59 | random.seed(1)
60 | random.shuffle(WSI_used_names)
61 |
62 | root_reading_list=r'/mnt/disk10T/fyb/wxf_data/TCGA/brain/reading_list_extract224/'
63 | files_reading_list=os.listdir(root_reading_list)
64 |
65 | for i in range(excel_wsi.shape[0]):
66 |
67 | # print(WSI_used_names[i])
68 | reading_list = np.load(root_reading_list + WSI_used_names[i]+'.npy')
69 | # reading_list = np.load(root_reading_list + 'TCGA-E1-A7Z4-01Z-00-DX2' + '.npy')
70 | Num_cluster = 0
71 | point_num = reading_list.shape[0]
72 | points_center_corrd = reading_list
73 |
74 | FLAT_w = int(points_center_corrd[0].split('_')[0])
75 | FLAT_h = int(points_center_corrd[0].split('_')[1])
76 | claster_corrds = {'0': []}
77 | for k in range(point_num):
78 | w_point = int(points_center_corrd[k].split('_')[0])
79 | h_point = int(points_center_corrd[k].split('_')[1])
80 | if w_point == FLAT_w:
81 | claster_corrds[str(Num_cluster)].append([w_point, h_point])
82 | continue
83 | # FLAT_w=w_point
84 | if w_point > FLAT_w and (w_point-FLAT_w)<=6:
85 |
86 | FLAT_w = w_point
87 | claster_corrds[str(Num_cluster)].append([w_point, h_point])
88 | continue
89 | if w_point < FLAT_w or (w_point-FLAT_w)>6:
90 | FLAT_w = w_point
91 | Num_cluster += 1
92 | claster_corrds[str(Num_cluster)] = []
93 | claster_corrds[str(Num_cluster)].append([w_point, h_point])
94 | del_value=[]
95 | for key,value in enumerate(claster_corrds):
96 | claster_corrds[value]=np.asarray(claster_corrds[value])
97 | if claster_corrds[value].shape[0] < 150:
98 | del_value.append(value)
99 | for nn in range(len(del_value)):
100 | del claster_corrds[del_value[nn]]
101 | patch_length=[]
102 | value_name=[]
103 | for key, value in enumerate(claster_corrds):
104 | patch_length.append(claster_corrds[value].shape[0])
105 | value_name.append(value)
106 | save_dict=[]
107 | sort_0=np.argsort(np.asarray(patch_length))
108 | for nn in range(len(value_name)):
109 | save_dict.append(claster_corrds[value_name[sort_0[len(value_name)-nn-1]]])
110 |
111 | np.save('/mnt/disk10T/fyb/wxf_data/TCGA/brain/read_details/' + WSI_used_names[i]+ '.npy', save_dict)
112 | print(i)
113 | a=1
114 | a=[[[1,2],[11,21]],[[12,23],[14,24]],[[15,25],[16,26]],[[17,27],[18,28]]]
115 | a=np.asarray(a)
116 | a=a.reshape(2,2,2,2)
117 | from PIL import Image
118 | import cv2
119 | from matplotlib import pyplot as plt
120 | from torchvision.transforms import Compose
121 | import transform.transforms_group as our_transform
122 | def train_transform(degree=180):
123 | return Compose([
124 | our_transform.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05),
125 | ])
126 | My_transform=train_transform()
127 | root=r'D:\PhD\Project_WSI\data\a/'
128 | files=os.listdir(root)
129 | imgs=[]
130 | max_num=1200
131 | read_details=np.load(r'D:\PhD\Project_WSI\data\TCGA-HT-8104-01A-01-TS1.npy',allow_pickle=True)[0]
132 | for i in range(len(files)):
133 | imgs.append(io.imread(root + '/' + str(read_details[i][0]) + '_' + str(read_details[i][1]) + '.jpg'))
134 | imgs = np.asarray(imgs)#(num_patches,224,224,3)
135 |
136 |
137 |
138 | imgs=imgs.reshape(-1,224,3) #(num_patches*224,224,3)
139 | imgs = Image.fromarray(imgs.astype('uint8')).convert('RGB')
140 | imgs=My_transform(imgs)
141 | imgs=np.array(imgs)#(num_patches*224,224,3)
142 | imgs=imgs.reshape(-1,224,224,3)#(num_patches,224,224,3)
143 |
144 | N_adj=int(math.sqrt(len(files)))
145 | imgs=imgs[0:N_adj*N_adj]
146 | imgs=imgs.reshape(N_adj,N_adj,224,224,3) #(Na,Na,224,224,3)
147 | imgs=np.transpose(imgs,(0,2,1,3,4)) #(Na,224,Na,224,3)
148 | imgs=imgs.reshape(N_adj*224,N_adj*224,3)#(Na*224,Na*224,3)
149 |
150 | plt.imshow(imgs)
151 | plt.show()
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
--------------------------------------------------------------------------------
/docs/1748968628246.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/docs/1748968628246.png
--------------------------------------------------------------------------------
/docs/framework图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/docs/framework图.png
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['QT_QPA_PLATFORM']='offscreen'
3 | import torch
4 | import torchvision
5 | # from tensorboardX import SummaryWriter
6 | import numpy as np
7 | from PIL import Image
8 | #import cv2
9 | import matplotlib
10 | matplotlib.use('agg')
11 | import matplotlib.pyplot as plt
12 |
13 |
14 | def plot_confusion_matrix(cm, savename, title='Confusion Matrix', classes=['G2_O', 'G3_O', 'G2_A', 'G3_A', 'G4_A', 'GBM']):
15 | plt.figure(figsize=(11, 11), dpi=100)
16 | np.set_printoptions(precision=2) # number of decimal places to print; 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G4_A', 5:'GBM'
17 |
18 | # probability value shown in each cell of the confusion matrix
19 | # classes = ['P_MN', 'S_MN', 'P_IgAN', 'S_IgAN', 'LN', 'DN', 'ANCA', 'MPGN']
20 | ind_array = np.arange(len(classes))
21 | x, y = np.meshgrid(ind_array, ind_array)
22 | thresh = cm.max() / 2.
23 |
24 |
25 | for x_val, y_val in zip(x.flatten(), y.flatten()):
26 | c = cm[y_val][x_val]
27 | if c > 0.001:
28 | plt.text(x_val, y_val, "%0.2f" % (c,), color='white' if c > thresh else 'black',
29 | fontsize=20, va='center', ha='center')
30 |
31 | # plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
32 | plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
33 | plt.title(title, fontsize=36, pad=20)
34 |
35 | # plt.matshow(cm, cmap=plt.cm.Blues) # background color
36 |
37 | plt.colorbar()
38 | xlocations = np.array(range(len(classes)))
39 | # plt.xticks(xlocations, classes, rotation=90)
40 | # plt.yticks(xlocations, classes)
41 | plt.xticks(xlocations, classes, size=16)
42 | plt.yticks(xlocations, classes, size=16)
43 | plt.ylabel('Actual label', fontsize=22, labelpad=12)
44 | plt.xlabel('Predicted label', fontsize=22, labelpad=12)
45 |
46 | # offset the tick
47 | tick_marks = np.array(range(len(classes))) + 0.5
48 | plt.gca().set_xticks(tick_marks, minor=True)
49 | plt.gca().set_yticks(tick_marks, minor=True)
50 | plt.gca().xaxis.set_ticks_position('none')
51 | plt.gca().yaxis.set_ticks_position('none')
52 | plt.grid(True, which='minor', linestyle='-')
53 | plt.gcf().subplots_adjust(bottom=0.15)
54 |
55 | # show confusion matrix
56 | plt.savefig(savename, format='png')
57 | # plt.show()
58 | plt.close('all')
--------------------------------------------------------------------------------
/feature_generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | # import curve
5 | import platform
6 | import dataset
7 | from torch.utils.data import DataLoader,Dataset
8 | import numpy as np
9 | import argparse, time, random
10 | import yaml
11 | from yaml.loader import SafeLoader
12 | from evaluation import *
13 | import glob
14 | import model
15 | import net
16 | from saver import Saver
17 | from tqdm import tqdm
18 | import matplotlib.pyplot as plt
19 | from matplotlib.ticker import MultipleLocator
20 | from sklearn.metrics import roc_curve, auc
21 | from tensorboardX import SummaryWriter
22 | from sklearn.metrics import cohen_kappa_score
23 | from sklearn.metrics import accuracy_score
24 | from sklearn.metrics import precision_score, recall_score, f1_score
25 | from sklearn.metrics import confusion_matrix
26 | import gc
27 | from sklearn import metrics
28 | import h5py
29 | from utils import *
30 | import pandas as pd
31 | from skimage import io
32 |
33 | # from apex import amp # unused in this script
34 | """
35 | label 2016={ 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G2_OA', 5:'G3_OA', 6:'GBM'}
36 | label 2021={ 0:'G2_O', 1:'G3_O', 2:'G2_A', 3:'G3_A', 4:'G4_A', 5:'GBM'}
37 | """
38 |
39 |
40 | def train(opt):
41 |
42 | gpuID = opt['gpus']
43 |
44 | ############## Init #####################################
45 |
46 |
47 | device = torch.device('cuda:{}'.format(gpuID[0])) if gpuID else torch.device('cpu')
48 |
49 |
50 | Res_pretrain= net.Res50_pretrain().cuda(gpuID[0])
51 | # Res_pretrain.to(device)
52 | Res_pretrain = nn.DataParallel(Res_pretrain, device_ids=gpuID)
53 | Res_pretrain.eval()
54 |
55 | ############### Datasets #######################
56 |
57 |
58 |
59 | trainDataset = dataset.Our_Dataset(phase='Train',opt=opt)
60 | valDataset = dataset.Our_Dataset(phase='Val', opt=opt)
61 | testDataset = dataset.Our_Dataset(phase='Test',opt=opt)
62 | trainLoader = DataLoader(trainDataset, batch_size=opt['batchSize'],
63 | num_workers=opt['nThreads'], shuffle=True)
64 | valLoader = DataLoader(valDataset, batch_size=opt['Val_batchSize'],
65 | num_workers=opt['nThreads'], shuffle=True)
66 | testLoader = DataLoader(testDataset, batch_size=opt['Test_batchSize'],
67 | num_workers=opt['nThreads'], shuffle=True)
68 |
69 |
70 | train_bar = tqdm(trainLoader)
71 | for packs in train_bar:
72 | img = packs[0][0] #(N,3,224,224)
73 | imgPath = packs[2][0]
74 | patches_coor = packs[3][0] # list N,2
75 | if torch.cuda.is_available():
76 | img = img.cuda(gpuID[0])
77 | N=img.detach().cpu().numpy().shape[0]
78 |
79 |
80 | feature = Res_pretrain(img) # N 1024
81 | feature = torch.unsqueeze(feature, dim=0)
82 | feature_save = feature.detach().cpu().numpy()
83 | feature_save=np.float16(feature_save)
84 | print(feature_save.shape)
85 | if not os.path.exists(opt['dataDir'] +'Res50_feature_2000_fixdim0_512/'):
86 | os.makedirs(opt['dataDir'] +'Res50_feature_2000_fixdim0_512/')
87 | with h5py.File(opt['dataDir'] +'Res50_feature_2000_fixdim0_512/'+ imgPath+'.h5', 'w') as f:
88 | f['Res_feature'] = feature_save
89 | f['patches_coor'] = patches_coor
90 |
91 | train_bar = tqdm(valLoader)
92 | for packs in train_bar:
93 | img = packs[0][0] # (N,3,224,224)
94 | imgPath = packs[2][0]
95 | patches_coor = packs[3][0] # list N,2
96 | if torch.cuda.is_available():
97 | img = img.cuda(gpuID[0])
98 | N = img.detach().cpu().numpy().shape[0]
99 |
100 | feature = Res_pretrain(img) # N 1024
101 | feature = torch.unsqueeze(feature, dim=0)
102 | feature_save = feature.detach().cpu().numpy()
103 | feature_save = np.float16(feature_save)
104 | print(feature_save.shape)
105 | if not os.path.exists(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/'):
106 | os.makedirs(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/')
107 | with h5py.File(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/' + imgPath + '.h5', 'w') as f:
108 | f['Res_feature'] = feature_save
109 | f['patches_coor'] = patches_coor
110 | #
111 | train_bar = tqdm(testLoader)
112 | for packs in train_bar:
113 | img = packs[0][0] # (N,3,224,224)
114 | imgPath = packs[2][0]
115 | patches_coor = packs[3][0] # list N,2
116 | if torch.cuda.is_available():
117 | img = img.cuda(gpuID[0])
118 | N = img.detach().cpu().numpy().shape[0]
119 |
120 | feature = Res_pretrain(img) # N 1024
121 | feature = torch.unsqueeze(feature, dim=0)
122 | feature_save = feature.detach().cpu().numpy()
123 | feature_save = np.float16(feature_save)
124 | print(feature_save.shape)
125 | if not os.path.exists(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/'):
126 | os.makedirs(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/')
127 | with h5py.File(opt['dataDir'] + 'Res50_feature_2000_fixdim0_512/' + imgPath + '.h5', 'w') as f:
128 | f['Res_feature'] = feature_save
129 | f['patches_coor'] = patches_coor
130 | # features_WSI=[]
131 | # for i in range(N):
132 | # feature=Res_pretrain(torch.unsqueeze(img[i],dim=0))
133 | # feature=feature.detach().cpu().numpy()
134 | # features_WSI.append(feature)
135 | # features_WSI=np.asarray(features_WSI) #(N,1024)
136 | # feature_save = np.expand_dims(features_WSI, axis=0)
137 | # feature_save = np.float16(feature_save)
138 | # print(feature_save.shape)
139 | # if not os.path.exists(opt['dataDir'] + 'Res50_feature_1200_iso_512/'):
140 | # os.makedirs(opt['dataDir'] + 'Res50_feature_1200_iso_512/')
141 | # with h5py.File(opt['dataDir'] + 'Res50_feature_1200_iso_512/' + imgPath + '.h5', 'w') as f:
142 | # f['Res_feature'] = feature_save
143 | # f['patches_coor'] = patches_coor
144 |
145 |
146 | a=1
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 | if __name__ == '__main__':
157 |
158 |
159 |
160 |
161 |
162 | parser = argparse.ArgumentParser()
163 | parser.add_argument('--opt', type=str, default='config/miccai.yml')
164 | args = parser.parse_args()
165 | with open(args.opt) as f:
166 | opt = yaml.load(f, Loader=SafeLoader)
167 | #
168 | # k1 = h5py.File(opt['dataDir'] + 'Res50_feature_1/TCGA-DH-5143-01Z-00-DX1.h5')['Res50_feature'][:][0,100,:]
169 | # k2 = h5py.File( './TCGA-DH-5143-01Z-00-DX1.h5')['Res50_feature'][:][0,100,:]
170 |
171 | # k3 = h5py.File(opt['dataDir'] + 'Res50_feature_1/TCGA-06-0876-01Z-00-DX1.h5')['Res50_feature'][:][0,0,100,:]
172 | # k4 = h5py.File(opt['dataDir'] + 'Res50_feature_2/TCGA-06-0876-01Z-00-DX1.h5')['Res50_feature'][:][0,100,:]
173 | #
174 | # k5 = h5py.File(opt['dataDir'] + 'Res50_feature_1/TCGA-CS-6290-01A-01-TS1.h5')['Res50_feature'][:][0,0,100,:]
175 | # k6 = h5py.File(opt['dataDir'] + 'Res50_feature_2/TCGA-CS-6290-01A-01-TS1.h5')['Res50_feature'][:][0,100,:]
176 |
177 |
178 | a = 1
179 | train(opt)
180 |
181 |
182 |
183 | a=1
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
--------------------------------------------------------------------------------
/logs.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/logs.py
--------------------------------------------------------------------------------
/mainpre.py:
--------------------------------------------------------------------------------
1 | # from apex import amp
2 | from utils import *
3 | import dataset_mine
4 | from net import init_weights,get_scheduler,WarmupCosineSchedule
5 | def setup_seed(seed):
6 | torch.manual_seed(seed)
7 | torch.cuda.manual_seed(seed)
8 | torch.cuda.manual_seed_all(seed)
9 | np.random.seed(seed)
10 | random.seed(seed)
11 | if seed == 0:
12 | torch.backends.cudnn.deterministic = True
13 | torch.backends.cudnn.benchmark = False
14 |
15 |
16 | def train(opt):
17 |     opt['gpus'] = [5]
18 |     gpuID = opt['gpus']
19 |     opt['batchSize'] = 1
20 |
21 |     ############### Mine_model #######################
22 |     Mine_model_init,Mine_model_IDH,Mine_model_1p19q,Mine_model_CDKN,Mine_model_Graph,Mine_model_His,Mine_model_Cls,Mine_model_Task\
23 |     ,opt_init,opt_IDH,opt_1p19q,opt_CDKN,opt_Graph,opt_His,opt_Cls,opt_Task=get_model(opt)
24 |
25 |     if opt['decayType']=='exp' or opt['decayType']=='step':
26 |         Mine_model_sch_init = get_scheduler(opt_init, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
27 |         Mine_model_sch_IDH = get_scheduler(opt_IDH, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
28 |         Mine_model_sch_1p19q = get_scheduler(opt_1p19q, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
29 |         Mine_model_sch_CDKN = get_scheduler(opt_CDKN, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
30 |         Mine_model_sch_Graph = get_scheduler(opt_Graph, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
31 |         Mine_model_sch_His = get_scheduler(opt_His, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
32 |         Mine_model_sch_Cls = get_scheduler(opt_Cls, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
33 |         Mine_model_sch_Task = get_scheduler(opt_Task, opt['n_ep'], opt['n_ep_decay'], opt['decayType'], -1)
34 |     elif opt['decayType']=='cos':
35 |         Mine_model_sch_init = WarmupCosineSchedule(opt_init, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
36 |         Mine_model_sch_IDH = WarmupCosineSchedule(opt_IDH, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
37 |         Mine_model_sch_1p19q = WarmupCosineSchedule(opt_1p19q, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
38 |         Mine_model_sch_CDKN = WarmupCosineSchedule(opt_CDKN, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
39 |         Mine_model_sch_Graph = WarmupCosineSchedule(opt_Graph, warmup_steps=opt['decay_cos_warmup_steps'],t_total=opt['n_ep'])
40 |         Mine_model_sch_His = WarmupCosineSchedule(opt_His, warmup_steps=opt['decay_cos_warmup_steps'], t_total=opt['n_ep'])
41 |         Mine_model_sch_Cls = WarmupCosineSchedule(opt_Cls, warmup_steps=opt['decay_cos_warmup_steps'],t_total=opt['n_ep'])
42 |         Mine_model_sch_Task = WarmupCosineSchedule(opt_Task, warmup_steps=opt['decay_cos_warmup_steps'],t_total=opt['n_ep'])
43 |
44 |     print('%d GPUs are working with the id of %s' % (torch.cuda.device_count(), str(gpuID)))
45 |
46 |
47 |
48 |     ############### Checkpoints #######################
49 |     root_init = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'  # every branch below currently points at this same file
50 |     ckptdir_init = os.path.join(root_init)
51 |     checkpoint_init = torch.load(ckptdir_init)
52 |     root_IDH = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
53 |     ckptdir_IDH = os.path.join(root_IDH)
54 |     checkpoint_IDH = torch.load(ckptdir_IDH)
55 |     root_1p19q = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
56 |     ckptdir_1p19q = os.path.join(root_1p19q)
57 |     checkpoint_1p19q = torch.load(ckptdir_1p19q)
58 |     root_CDKN = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
59 |     ckptdir_CDKN = os.path.join(root_CDKN)
60 |     checkpoint_CDKN = torch.load(ckptdir_CDKN)
61 |     root_Task = r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
62 |     ckptdir_Task = os.path.join(root_Task)
63 |     checkpoint_Task = torch.load(ckptdir_Task)
64 |
65 |     related_params = {k: v for k, v in checkpoint_init['init'].items()}
66 |     Mine_model_init.load_state_dict(related_params)
67 |     related_params = {k: v for k, v in checkpoint_IDH['IDH'].items()}
68 |     Mine_model_IDH.load_state_dict(related_params)
69 |     related_params = {k: v for k, v in checkpoint_1p19q['1p19q'].items()}
70 |     Mine_model_1p19q.load_state_dict(related_params)
71 |     related_params = {k: v for k, v in checkpoint_CDKN['CDKN'].items()}
72 |     Mine_model_CDKN.load_state_dict(related_params)
73 |     related_params = {k: v for k, v in checkpoint_IDH['Graph'].items()}
74 |     Mine_model_Graph.load_state_dict(related_params)
75 |     related_params = {k: v for k, v in checkpoint_IDH['His'].items()}
76 |     Mine_model_His.load_state_dict(related_params)
77 |     related_params = {k: v for k, v in checkpoint_IDH['Cls'].items()}
78 |     Mine_model_Cls.load_state_dict(related_params)
79 |     related_params = {k: v for k, v in checkpoint_Task['Task'].items()}
80 |     Mine_model_Task.load_state_dict(related_params)
81 |
82 |     Mine_model_init.eval()
83 |     Mine_model_IDH.eval()
84 |     Mine_model_1p19q.eval()
85 |     Mine_model_CDKN.eval()
86 |     Mine_model_Graph.eval()
87 |     Mine_model_His.eval()
88 |     Mine_model_Cls.eval()
89 |     Mine_model_Task.eval()
90 |
91 |
92 |     for i in range(1):
93 |         testDataset = dataset_mine.Our_Dataset(phase='Test',opt=opt)
94 |         testLoader = DataLoader(testDataset, batch_size=opt['Test_batchSize'],
95 |                                 num_workers=opt['nThreads'] if (sysstr == "Linux") else 1, shuffle=False)
96 |         ItestDataset = dataset_mine.Our_Dataset(phase='ITest',opt=opt)
97 |         ItestLoader = DataLoader(ItestDataset, batch_size=opt['Test_batchSize'],
98 |                                  num_workers=opt['nThreads'] if (sysstr == "Linux") else 1, shuffle=False)
99 |
100 |     last_ep = 0
101 |     saver = Saver(opt)
102 |     alleps = opt['n_ep'] - last_ep
103 |     curep=0
104 |     print('-------------------------------------Val and Test--------------------------------------')
105 |     if (curep + 1) > (2):  # never true while curep is fixed at 0, so no checkpoint is re-saved
106 |         save_dir = os.path.join(opt['modelDir'], 'Mine_model-%04d.pth' % (curep + 1))
107 |         state = {
108 |             'init': Mine_model_init.state_dict(),
109 |             'IDH': Mine_model_IDH.state_dict(),
110 |             '1p19q': Mine_model_1p19q.state_dict(),
111 |             'CDKN': Mine_model_CDKN.state_dict(),
112 |             'Graph': Mine_model_Graph.state_dict(),
113 |             'His': Mine_model_His.state_dict(),
114 |             'Cls': Mine_model_Cls.state_dict(),
115 |             'Task': Mine_model_Task.state_dict(),
116 |         }
117 |         torch.save(state, save_dir)
118 |
119 |     print("----------Test-------------")
120 |     list_WSI_IDH,list_WSI_1p19q,list_WSI_CDKN,list_WSI_His_2class,list_WSI_Diag= \
121 |         validation_All(opt, Mine_model_init, Mine_model_IDH, Mine_model_1p19q,Mine_model_CDKN,Mine_model_Graph,Mine_model_His,Mine_model_Cls,Mine_model_Task,testLoader, saver, curep + 1, opt['eva_cm'], gpuID, task = '')
122 |     print('test in epoch: %d/%d, acc_IDH:%.3f,acc_1p19q:%.3f,acc_CDKN:%.3f,acc_His_2class:%.3f, acc_Diag:%.3f' % (
123 |         curep + 1, alleps, list_WSI_IDH[0], list_WSI_1p19q[0], list_WSI_CDKN[0], list_WSI_His_2class[0], list_WSI_Diag[0]))
124 |     test_dict = {'test/acc_IDH': list_WSI_IDH[0], 'test/sen_IDH': list_WSI_IDH[3], 'test/spec_IDH': list_WSI_IDH[4],
125 |                  'test/auc_IDH': list_WSI_IDH[5],'test/f1_IDH': list_WSI_IDH[2], 'test/prec_IDH': list_WSI_IDH[6],
126 |                  'test/acc_1p19q': list_WSI_1p19q[0], 'test/sen_1p19q': list_WSI_1p19q[3], 'test/spec_1p19q': list_WSI_1p19q[4],
127 |                  'test/auc_1p19q': list_WSI_1p19q[5], 'test/f1_1p19q': list_WSI_1p19q[2], 'test/prec_1p19q': list_WSI_1p19q[6],
128 |                  'test/acc_CDKN': list_WSI_CDKN[0], 'test/sen_CDKN': list_WSI_CDKN[3], 'test/spec_CDKN': list_WSI_CDKN[4],
129 |                  'test/auc_CDKN': list_WSI_CDKN[5], 'test/f1_CDKN': list_WSI_CDKN[2], 'test/prec_CDKN': list_WSI_CDKN[6],
130 |                  'test/acc_His_2class': list_WSI_His_2class[0], 'test/sen_His_2class': list_WSI_His_2class[3], 'test/spec_His_2class': list_WSI_His_2class[4],
131 |                  'test/auc_His_2class': list_WSI_His_2class[5], 'test/f1_His_2class': list_WSI_His_2class[2], 'test/prec_His_2class': list_WSI_His_2class[6],
132 |                  'test/acc_Diag': list_WSI_Diag[0], 'test/sen_Diag': list_WSI_Diag[3], 'test/spec_Diag': list_WSI_Diag[4],
133 |                  'test/auc_Diag': list_WSI_Diag[5], 'test/f1_Diag': list_WSI_Diag[2], 'test/prec_Diag': list_WSI_Diag[6],
134 |
135 |                  }
136 |     test_dict_IDH = {'test/acc_IDH': list_WSI_IDH[0], 'test/sen_IDH': list_WSI_IDH[3], 'test/spec_IDH': list_WSI_IDH[4],
137 |                      'test/auc_IDH': list_WSI_IDH[5], 'test/f1_IDH': list_WSI_IDH[2], 'test/prec_IDH': list_WSI_IDH[6],}
138 |     test_dict_1p19q = {'test/acc_1p19q': list_WSI_1p19q[0], 'test/sen_1p19q': list_WSI_1p19q[3],'test/spec_1p19q': list_WSI_1p19q[4],
139 |                        'test/auc_1p19q': list_WSI_1p19q[5], 'test/f1_1p19q': list_WSI_1p19q[2],'test/prec_1p19q': list_WSI_1p19q[6],}
140 |     test_dict_CDKN = {'test/acc_CDKN': list_WSI_CDKN[0], 'test/sen_CDKN': list_WSI_CDKN[3],'test/spec_CDKN': list_WSI_CDKN[4],
141 |                       'test/auc_CDKN': list_WSI_CDKN[5], 'test/f1_CDKN': list_WSI_CDKN[2],'test/prec_CDKN': list_WSI_CDKN[6],}
142 |     test_dict_His_2class = {'test/acc_His_2class': list_WSI_His_2class[0], 'test/sen_His_2class': list_WSI_His_2class[3],'test/spec_His_2class': list_WSI_His_2class[4],
143 |                             'test/auc_His_2class': list_WSI_His_2class[5], 'test/f1_His_2class': list_WSI_His_2class[2],'test/prec_His_2class': list_WSI_His_2class[6],}
144 |     test_dict_Diag = {'test/acc_Diag': list_WSI_Diag[0], 'test/sen_Diag': list_WSI_Diag[3],'test/spec_Diag': list_WSI_Diag[4],
145 |                       'test/auc_Diag': list_WSI_Diag[5], 'test/f1_Diag': list_WSI_Diag[2],'test/prec_Diag': list_WSI_Diag[6], }
146 |     saver.write_scalars(curep + 1, test_dict)
147 |     saver.write_log(curep + 1, test_dict_IDH, 'test_IDH')
148 |     saver.write_log(curep + 1, test_dict_1p19q, 'test_1p19q')
149 |     saver.write_log(curep + 1, test_dict_CDKN, 'test_CDKN')
150 |     saver.write_log(curep + 1, test_dict_His_2class, 'test_His_2class')
151 |     saver.write_log(curep + 1, test_dict_Diag, 'test_Diag')
152 |
153 | # print("----------ITest-------------")
154 | # list_WSI_IDH,list_WSI_1p19q,list_WSI_CDKN,list_WSI_His_2class,list_WSI_Diag= \
155 | # validation_All(opt, Mine_model_init, Mine_model_IDH, Mine_model_1p19q,Mine_model_CDKN,Mine_model_Graph,Mine_model_His,Mine_model_Cls,Mine_model_Task,ItestLoader, saver, curep + 1, opt['eva_cm'], gpuID, task = 'Itest')
156 | # print('Itest in epoch: %d/%d, acc_IDH:%.3f,acc_1p19q:%.3f,acc_CDKN:%.3f,acc_His_2class:%.3f, acc_Diag:%.3f' % (
157 | # curep + 1, alleps, list_WSI_IDH[0], list_WSI_1p19q[0], list_WSI_CDKN[0], list_WSI_His_2class[0], list_WSI_Diag[0]))
158 | # Itest_dict = {'Itest/acc_IDH': list_WSI_IDH[0], 'Itest/sen_IDH': list_WSI_IDH[3], 'Itest/spec_IDH': list_WSI_IDH[4],
159 | # 'Itest/auc_IDH': list_WSI_IDH[5],'Itest/f1_IDH': list_WSI_IDH[2], 'Itest/prec_IDH': list_WSI_IDH[6],
160 | # 'Itest/acc_1p19q': list_WSI_1p19q[0], 'Itest/sen_1p19q': list_WSI_1p19q[3], 'Itest/spec_1p19q': list_WSI_1p19q[4],
161 | # 'Itest/auc_1p19q': list_WSI_1p19q[5], 'Itest/f1_1p19q': list_WSI_1p19q[2], 'Itest/prec_1p19q': list_WSI_1p19q[6],
162 | # 'Itest/acc_CDKN': list_WSI_CDKN[0], 'Itest/sen_CDKN': list_WSI_CDKN[3], 'Itest/spec_CDKN': list_WSI_CDKN[4],
163 | # 'Itest/auc_CDKN': list_WSI_CDKN[5], 'Itest/f1_CDKN': list_WSI_CDKN[2], 'Itest/prec_CDKN': list_WSI_CDKN[6],
164 | # 'Itest/acc_His_2class': list_WSI_His_2class[0], 'Itest/sen_His_2class': list_WSI_His_2class[3], 'Itest/spec_His_2class': list_WSI_His_2class[4],
165 | # 'Itest/auc_His_2class': list_WSI_His_2class[5], 'Itest/f1_His_2class': list_WSI_His_2class[2], 'Itest/prec_His_2class': list_WSI_His_2class[6],
166 | # 'Itest/acc_Diag': list_WSI_Diag[0], 'Itest/sen_Diag': list_WSI_Diag[3], 'Itest/spec_Diag': list_WSI_Diag[4],
167 | # 'Itest/auc_Diag': list_WSI_Diag[5], 'Itest/f1_Diag': list_WSI_Diag[2], 'Itest/prec_Diag': list_WSI_Diag[6],
168 | # }
169 | # Itest_dict_IDH = {'Itest/acc_IDH': list_WSI_IDH[0], 'Itest/sen_IDH': list_WSI_IDH[3], 'Itest/spec_IDH': list_WSI_IDH[4],
170 | # 'Itest/auc_IDH': list_WSI_IDH[5], 'Itest/f1_IDH': list_WSI_IDH[2], 'Itest/prec_IDH': list_WSI_IDH[6],}
171 | # Itest_dict_1p19q = {'Itest/acc_1p19q': list_WSI_1p19q[0], 'Itest/sen_1p19q': list_WSI_1p19q[3],'Itest/spec_1p19q': list_WSI_1p19q[4],
172 | # 'Itest/auc_1p19q': list_WSI_1p19q[5], 'Itest/f1_1p19q': list_WSI_1p19q[2],'Itest/prec_1p19q': list_WSI_1p19q[6],}
173 | # Itest_dict_CDKN = {'Itest/acc_CDKN': list_WSI_CDKN[0], 'Itest/sen_CDKN': list_WSI_CDKN[3],'Itest/spec_CDKN': list_WSI_CDKN[4],
174 | # 'Itest/auc_CDKN': list_WSI_CDKN[5], 'Itest/f1_CDKN': list_WSI_CDKN[2],'Itest/prec_CDKN': list_WSI_CDKN[6],}
175 | # Itest_dict_His_2class = {'Itest/acc_His_2class': list_WSI_His_2class[0], 'Itest/sen_His_2class': list_WSI_His_2class[3],'Itest/spec_His_2class': list_WSI_His_2class[4],
176 | # 'Itest/auc_His_2class': list_WSI_His_2class[5], 'Itest/f1_His_2class': list_WSI_His_2class[2],'Itest/prec_His_2class': list_WSI_His_2class[6],}
177 | # Itest_dict_Diag = {'Itest/acc_Diag': list_WSI_Diag[0], 'Itest/sen_Diag': list_WSI_Diag[3],'Itest/spec_Diag': list_WSI_Diag[4],
178 | # 'Itest/auc_Diag': list_WSI_Diag[5], 'Itest/f1_Diag': list_WSI_Diag[2],'Itest/prec_Diag': list_WSI_Diag[6], }
179 | # saver.write_scalars(curep + 1, Itest_dict)
180 | # saver.write_log(curep + 1, Itest_dict_IDH, 'Itest_IDH')
181 | # saver.write_log(curep + 1, Itest_dict_1p19q, 'Itest_1p19q')
182 | # saver.write_log(curep + 1, Itest_dict_CDKN, 'Itest_CDKN')
183 | # saver.write_log(curep + 1, Itest_dict_His_2class, 'Itest_His_2class')
184 | # saver.write_log(curep + 1, Itest_dict_Diag, 'Itest_Diag')
185 |
186 | def remove_all_file(path):
187 |     if os.path.isdir(path):
188 |         for i in os.listdir(path):
189 |             path_file = os.path.join(path, i)
190 |             os.remove(path_file)
191 | def remove_all_dir(path):
192 |     if os.path.isdir(path):
193 |         for i in os.listdir(path):
194 |             path_file = os.path.join(path, i)
195 |             for j in os.listdir(path_file):
196 |                 path_file1 = os.path.join(path_file, j)
197 |                 os.remove(path_file1)
198 |             os.rmdir(path_file)
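
The two helpers above remove files exactly one level (`remove_all_file`) or two levels (`remove_all_dir`) below `path`, and `os.remove` raises if it meets a directory where it expects a file. A minimal sketch of a depth-independent variant follows, assuming it is acceptable to delete everything below the path; `clear_directory` is a hypothetical name and is not part of this repository.

```python
import os
import shutil
import tempfile


def clear_directory(path):
    """Delete everything inside `path` but keep `path` itself (hypothetical helper)."""
    if not os.path.isdir(path):
        return
    for name in os.listdir(path):
        entry = os.path.join(path, name)
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)  # nested directory of any depth
        else:
            os.remove(entry)      # regular file or symlink


# Exercise the helper on a throwaway directory so nothing real is touched.
demo_root = tempfile.mkdtemp()
os.makedirs(os.path.join(demo_root, "a", "b"))
with open(os.path.join(demo_root, "a", "b", "x.txt"), "w") as f:
    f.write("x")
with open(os.path.join(demo_root, "y.txt"), "w") as f:
    f.write("y")
clear_directory(demo_root)
remaining = os.listdir(demo_root)
os.rmdir(demo_root)
```

`shutil.rmtree` handles arbitrary nesting, so a single function can stand in for both helpers when the fixed-depth behaviour is not required.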
199 |
200 |
201 |
202 |
203 |
204 | if __name__ == '__main__':
205 |     parser = argparse.ArgumentParser()
206 |     parser.add_argument('--opt', type=str, default='config/mine.yml')
207 |     args = parser.parse_args()
208 |     with open(args.opt) as f:
209 |         opt = yaml.load(f, Loader=SafeLoader)
210 |
211 |
212 |     sysstr = platform.system()
213 |
214 |
215 |     setup_seed(opt['seed'])
216 |     if opt['command']=='Train':
217 |         cur_time = time.strftime('%m%d-%H%M', time.localtime())
218 |
219 |
220 |         opt['name'] = 'Pretrain'+'_{}'.format(cur_time)
221 |         opt['logDir'] = os.path.join(opt['logDir'], opt['name'])
222 |         opt['modelDir'] = os.path.join(opt['modelDir'], opt['name'])
223 |         opt['saveDir'] = os.path.join(opt['saveDir'], opt['name'])
224 |         opt['cm_saveDir'] = os.path.join(opt['cm_saveDir'], opt['name'])
225 |         if not os.path.exists(opt['logDir']):
226 |             os.makedirs(opt['logDir'])
227 |         if not os.path.exists(opt['modelDir']):
228 |             os.makedirs(opt['modelDir'])
229 |         if not os.path.exists(opt['saveDir']):
230 |             os.makedirs(opt['saveDir'])
231 |         if not os.path.exists(opt['cm_saveDir']):
232 |             os.makedirs(opt['cm_saveDir'])
233 |
234 |         para_log = os.path.join(opt['modelDir'], 'params.yml')
235 |         if os.path.exists(para_log):
236 |             os.remove(para_log)
237 |         with open(para_log, 'w') as f:
238 |             data = yaml.dump(opt, f, sort_keys=False, default_flow_style=False)
239 |
240 |         print("\n\n============> begin training <=======")
241 |         train(opt)
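
The metric dictionaries in `train()` read each task's result list at fixed positions: accuracy at index 0, F1 at 2, sensitivity at 3, specificity at 4, AUC at 5 and precision at 6. A minimal sketch of building the same keys from one table is shown below. `METRIC_INDEX` and `build_metric_dict` are hypothetical names that do not exist in this repository, and the numbers are made up.

```python
# Position of each metric inside a per-task result list, as read off the
# dictionaries in train(). Index 1 is not used there, so it is omitted here.
METRIC_INDEX = {"acc": 0, "f1": 2, "sen": 3, "spec": 4, "auc": 5, "prec": 6}


def build_metric_dict(results, prefix="test"):
    """Map {task: metric_list} to {'<prefix>/<metric>_<task>': value}."""
    return {
        "%s/%s_%s" % (prefix, metric, task): values[idx]
        for task, values in results.items()
        for metric, idx in METRIC_INDEX.items()
    }


# Made-up numbers, for illustration only.
example = {"IDH": [0.9, 0.0, 0.8, 0.7, 0.6, 0.95, 0.5]}
metrics = build_metric_dict(example)
```

Passing a single-task mapping gives the per-task dictionary (the counterpart of `test_dict_IDH`), and passing all five tasks gives the combined one, so the index layout is written down once.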
--------------------------------------------------------------------------------
/merge_who.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/merge_who.xlsx
--------------------------------------------------------------------------------
/post_processing.py:
--------------------------------------------------------------------------------
1 |
2 | # root=r'D:\PhD\Project_WSI\Others_code\GBM_WSSM-master\OneDrive-2023-03-15/'
3 | # dirs=os.listdir(root)
4 | # save_root=r'D:\PhD\Project_WSI\Others_code\GBM_WSSM-master/GBM_TCGA_SEG/'
5 | # for i in range(len(dirs)):
6 | # files=os.listdir(root+dirs[i]+'/')
7 | # for j in range(len(files)):
8 | # shutil.copy(root+dirs[i]+'/'+files[j],save_root+files[j])
9 | import os
10 | # from openslide import OpenSlide
11 | import h5py
12 | import numpy as np
13 | import cv2
14 | from skimage import io,transform
15 | root=r'D:\PhD\Project_WSI\Others_code\GBM_WSSM-master\GBM_Test_Images/'
16 | imgs=os.listdir(root)
17 | for i in range(len(imgs)):
18 |     img_temp = io.imread(root+imgs[i])
19 |     img_temp = cv2.resize(img_temp, (1024, 1024))
20 |     io.imsave(root+imgs[i][0:-4]+'resize.jpg',img_temp)
21 |
22 |
23 | # def get_filename(root_path,file_path):
24 | # return_file = []
25 | # files=os.listdir(root_path+file_path)
26 | # for i in range(len(files)):
27 | # get_path = os.path.join(root_path, file_path,files[i])
28 | # if get_path.endswith('.svs') or get_path.endswith('.partial'):
29 | # return_file.append(get_path)
30 | # return return_file
31 | #
32 | # root=r'/mnt/disk10T/fuyibing/wxf_data/TCGA/brain/ori_wsi/ffpe_GBM/'
33 | # slide_path = []
34 | # WSI_path_list = os.listdir(root)
35 | # for i in range(len(WSI_path_list)):
36 | # get_file = get_filename(root, WSI_path_list[i])
37 | # slide_path.append(get_file[0])
38 | #
39 | #
40 | # for i in range(len(slide_path)):
41 | # wsi_obj = OpenSlide(slide_path[i])
42 | #
43 | # pro = dict(wsi_obj.properties)
44 | # MPP=np.float(pro['aperio.MPP'])
45 | # wsi_w=wsi_obj.dimensions[0]
46 | # wsi_h=wsi_obj.dimensions[1]
47 | #
48 | #
49 | #
50 | # with h5py.File('vis_results/set0/'+slide_path[i].split('/')[-1].split('.')[0]+'.h5', 'w') as f:
51 | # f['wsi_w'] = wsi_w
52 | # f['wsi_h'] = wsi_h
53 | # f['MPP'] = MPP
54 | # print(i)
55 |
--------------------------------------------------------------------------------
/roc_plot mu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | from sklearn.metrics import roc_curve, auc, roc_auc_score
5 | import torch.nn.functional as F
7 | import torch
8 | import ast
9 | import re
13 | from scipy.special import softmax
15 | from scipy.special import expit # Sigmoid function
16 | def ROC(df):
17 |     n_classes = 4
18 |     fpr = dict()
19 |     tpr = dict()
20 |     roc_auc = dict()
21 |     l = np.zeros((len(df), n_classes))
22 |     p = np.zeros((len(df), n_classes))
23 |     for i in range(len(df)):
24 |         l[i, df.iloc[i, 0]] = 1
25 |     label = df['label'].tolist()
26 |     scores = df['score'].tolist()
27 |     converted_data = [list(map(float, re.findall(r'-?\d+\.\d+(?:[eE][-+]?\d+)?', item))) for item in scores]  # the exponent part keeps values such as 1.5e-05 intact
28 |     score = np.array(converted_data)
29 |
30 |     for i in range(n_classes):
31 |
32 |         p[:, i] = score[:, i]
33 |         fpr[i], tpr[i], _ = roc_curve(label, score[:, i], pos_label=i)
34 |         roc_auc[i] = auc(fpr[i], tpr[i])
35 |
36 |     all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
37 |     mean_tpr = np.zeros_like(all_fpr)
38 |     for i in range(n_classes):
39 |         mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
40 |     mean_tpr /= n_classes
41 |     fpr["macro"] = all_fpr
42 |     tpr["macro"] = mean_tpr
43 |     weighted_auc = roc_auc_score(l, p, average='macro')  # macro average, despite the variable name
44 |
45 |     return fpr['macro'],tpr['macro'], weighted_auc
46 |
47 | if __name__ == "__main__":
48 |     name='IN_Diag'
49 |     # print('\033[1;35;0mColored text, no background color \033[0m')
50 |     np.random.seed(2)
51 |     plt.rcParams["font.family"] = "ARIAL"
52 |     plt.rcParams["font.weight"] = "bold"
53 |     plt.rcParams["axes.labelweight"] = "bold"
54 |
55 |     COLOR_LIST = [[139 / 255, 20 / 255, 8 / 255],
56 |                   [188 / 255, 189 / 255, 34 / 255],
57 |                   [52 / 255, 193 / 255, 52 / 255],
58 |                   [150 / 255, 150 / 255, 190 / 255],
59 |                   [139 / 255, 101 / 255, 8 / 255],
60 |                   [68 / 255, 114 / 255, 236 / 255],
61 |                   [100 / 255, 114 / 255, 196 / 255],
62 |                   [214 / 255 + 0.1, 39 / 255 + 0.2, 40 / 255 + 0.2],
63 |                   [52 / 255, 163 / 255, 152 / 255],
64 |                   [139 / 255 * 1.1, 20 / 255 * 1.1, 8 / 255 * 1.1],
65 |                   [188 / 255 * 0.9, 189 / 255 * 0.9, 34 / 255 * 0.9],
66 |                   [52 / 255 * 1.1, 193 / 255 * 1.1, 52 / 255 * 1.1],
67 |                   [150 / 255 * 0.9, 150 / 255 * 0.9, 190 / 255 * 0.9],
68 |                   [139 / 255 * 1.1, 101 / 255 * 1.1, 8 / 255 * 1.1]]
69 |
70 |     LINE_WIDTH_LIST = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
71 |
72 |     i = 0
73 |     plt.figure(figsize=[10.5, 10])
74 |
75 |     LABEL_LIST =['Ours', 'Wang et al.','ABMIL', 'TransMIL', 'CLAM',
76 |                  'Charm', 'Deepglioma', 'MCAT', 'CMTA',
77 |                  'AlexNet','DenseNet','InceptionNet','ResNet-50','VGG-18']
78 |     EXCEL_LIST = ['plot/Mine_'+name+'.xlsx', 'plot/MICCAI_'+name+'.xlsx','plot/ABMIL_'+name+'.xlsx','plot/TransMIL_'+name+'.xlsx',\
79 |                   'plot/CLAM_'+name+'.xlsx','plot/Charm_'+name+'.xlsx','plot/Deepglioma_'+name+'.xlsx','plot/MCAT_'+name+'.xlsx',\
80 |                   'plot/CMTA_'+name+'.xlsx','plot/AlexNet_'+name+'.xlsx','plot/DenseNet_'+name+'.xlsx','plot/InceptionNet_'+name+'.xlsx',\
81 |                   'plot/ResNet-50_'+name+'.xlsx','plot/VGG-18_'+name+'.xlsx']  # same order as LABEL_LIST
82 |
83 |     LABEL_LIST = ['Ours']  # overrides the full lists above: only our method is plotted
84 |     EXCEL_LIST = ['plot/Mine_'+name+'.xlsx']
85 |
86 |
87 |     fpr = dict()
88 |     tpr = dict()
89 |     roc_auc = dict()
90 |     num = len(LABEL_LIST)
91 |
92 |
93 |     for i in range(num):
94 |         df = pd.read_excel(EXCEL_LIST[i])
95 |         label = df['label'].tolist()
96 |         score = df['score'].tolist()
97 |
98 |         fpr[i], tpr[i], _ = ROC(df)
99 |         plt.plot(fpr[i], tpr[i],
100 |                  label=LABEL_LIST[i], # add the label argument
101 |                  linewidth= LINE_WIDTH_LIST[i] , color=np.array(COLOR_LIST[i]))
102 |         plt.plot(1-df['spec'].tolist()[0], df['sen'].tolist()[0], marker="o", markersize=15, markerfacecolor=np.array(COLOR_LIST[i]), markeredgecolor=np.array(COLOR_LIST[i]))
103 |         print(df['sen'].tolist()[0])
104 |         print(df['spec'].tolist()[0])
105 |
106 |     plt.plot([0, 1], [0, 1], 'k--', lw=2)
107 |     plt.grid(color=[0.85, 0.85, 0.85])
108 |
109 |     plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold')
110 |     plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold')
111 |
112 |     font_axis_name = {'fontsize': 34, 'weight': 'bold'}
113 |     plt.xlabel('1-Specificity', font_axis_name)
114 |     plt.ylabel('Sensitivity', font_axis_name)
115 |     plt.xlim((0, 1))
116 |     plt.ylim((0, 1))
117 |     plt.legend(framealpha=1, fontsize=30, loc='lower right')
118 |     plt.tight_layout()
119 |
120 |     plt.savefig("plot/"+name+".tiff")
121 |     plt.show()
122 |
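
`ROC()` in this file recovers the per-class scores from the text stored in the `score` column with a regular expression. The sketch below isolates that parsing step. `parse_scores` is a hypothetical helper and the sample string is made up. The pattern still requires a decimal point, so stray integers in the cell are ignored, and it also accepts an optional exponent, because a pattern limited to `digits.digits` would read `1.5e-05` as `1.5`.

```python
import re

# A decimal number with an optional sign and an optional exponent part.
_FLOAT = re.compile(r'-?\d+\.\d+(?:[eE][-+]?\d+)?')


def parse_scores(cell):
    """Return every decimal number found in `cell` as a list of floats."""
    return [float(token) for token in _FLOAT.findall(cell)]


row = parse_scores("[0.91, 1.5e-05, -0.25, 0.0]")
```

If the spreadsheet cells were always valid Python list literals, `ast.literal_eval` would be a stricter alternative. Whether they are is not known from this file, so the regex route is kept here.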
--------------------------------------------------------------------------------
/roc_plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import platform
7 | import matplotlib.pyplot as plt
8 | from matplotlib.patches import Polygon
9 | from itertools import cycle
10 | from sklearn.metrics import precision_score, recall_score, f1_score
11 | from sklearn import svm, datasets
12 | from sklearn.metrics import roc_curve, auc
13 | from sklearn.model_selection import train_test_split
14 | from sklearn.preprocessing import label_binarize
15 | from sklearn.multiclass import OneVsRestClassifier
17 | import scipy
19 | from sklearn.metrics import roc_auc_score
20 | import pandas as pd
21 |
22 | plt.rcParams["font.family"] = "Times New Roman"
23 | plt.rcParams["font.weight"] = "bold"
24 | plt.rcParams["axes.labelweight"] = "bold"
25 |
26 | def ROC(df):
27 |     n_classes = 9
28 |     fpr = dict()
29 |     tpr = dict()
30 |     roc_auc = dict()
31 |     l = np.zeros((len(df), n_classes))
32 |     p = np.zeros((len(df), n_classes))
33 |     for i in range(len(df)):
34 |         l[i, df.iloc[i, 0]] = 1
35 |
36 |     for i in range(n_classes):
37 |         label = df['label'].tolist()
38 |         score = df['score' + str(i)].tolist()
39 |         p[:, i] = score
40 |         fpr[i], tpr[i], _ = roc_curve(label, score, pos_label=i)
41 |         roc_auc[i] = auc(fpr[i], tpr[i])
42 |
43 |     all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
44 |     mean_tpr = np.zeros_like(all_fpr)
45 |     for i in range(n_classes):
46 |         mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
47 |     mean_tpr /= n_classes
48 |     fpr["macro"] = all_fpr
49 |     tpr["macro"] = mean_tpr
50 |     weighted_auc = roc_auc_score(l, p, average='macro')
51 |
52 |     return fpr['macro'],tpr['macro'], weighted_auc
53 |
54 |
55 | if __name__ == "__main__":
56 |
57 |     name='1p19q'
58 |     print('\033[1;35;0mColored text, no background color \033[0m')
59 |     np.random.seed(2)
60 |     plt.rcParams["font.family"] = "Arial"
61 |     plt.rcParams["font.weight"] = "bold"
62 |     plt.rcParams["axes.labelweight"] = "bold"
63 |
64 |     COLOR_LIST = [ [139 / 255, 20 / 255, 8 / 255],[188 / 255, 189 / 255, 34 / 255],
65 |                    [52 / 255, 193 / 255, 52 / 255],
66 |                    [150 / 255, 150 / 255, 190 / 255], [139 / 255, 101 / 255, 8 / 255],
67 |                    [68 / 255, 114 / 255, 236 / 255],
68 |                    [100 / 255, 114 / 255, 196 / 255], [214 / 255 + 0.1, 39 / 255 + 0.2, 40 / 255 + 0.2],
69 |                    [52 / 255, 163 / 255, 152 / 255]]
70 |
71 |     LINE_WIDTH_LIST = [3, 3, 3, 3, 3, 3, 3,3,3]
72 |
73 |
74 |     i = 0
75 |     plt.figure(figsize=[10.5, 10])
76 |
77 |
78 |     LABEL_LIST =['Ours', 'CLAM', 'TransMIL', 'ResNet-18',
79 |                  'DenseNet-121', 'VGG-16', 'W/O Graph', 'W/O LC loss',
80 |                  'W/O DCC']
81 |     EXCEL_LIST = ['plot/Mine_'+name+'.xlsx', 'plot/CLAM_'+name+'_fea.xlsx', 'plot/TransMIL_'+name+'_fea.xlsx',
82 |                   'plot/Basic_'+name+'_img_res.xlsx', 'plot/Basic_'+name+'_img_dense.xlsx', 'plot/Basic_'+name+'_img_VGG.xlsx',
83 |                   'plot/Mine_Graph_'+name+'.xlsx', 'plot/Mine_Graphloss_'+name+'.xlsx', 'plot/Mine_DCC_'+name+'.xlsx', ]
84 |
85 |     fpr = dict()
86 |     tpr = dict()
87 |     roc_auc = dict()
88 |
89 |
90 |     for i in range(len(EXCEL_LIST)):
91 |         df = pd.read_excel(EXCEL_LIST[i])
92 |         label = df['label'].tolist()
93 |         score = df['score'].tolist()
94 |         fpr[i], tpr[i], _ = roc_curve(label, score)
95 |         # roc_auc[i] = auc(fpr[i], tpr[i])
96 |         plt.plot(fpr[i], tpr[i],
97 |                  label=LABEL_LIST[i],  # needed so that plt.legend() below has entries to show
98 |                  linewidth= LINE_WIDTH_LIST[i] , color=np.array(COLOR_LIST[i]))
99 |         plt.plot(1-df['spec'].tolist()[0], df['sen'].tolist()[0], marker="o", markersize=15, markerfacecolor=np.array(COLOR_LIST[i]), markeredgecolor=np.array(COLOR_LIST[i]))
100 |         print(df['sen'].tolist()[0])
101 |         print(df['spec'].tolist()[0])
102 |     plt.plot([0, 1], [0, 1], 'k--', lw=2)
103 |     plt.grid(color=[0.85, 0.85, 0.85])
104 |
105 |     plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold')
106 |     plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], fontsize=24, weight='semibold')
107 |
108 |     font_axis_name = {'fontsize': 34, 'weight': 'bold'}
109 |     plt.xlabel('1-Specificity',font_axis_name)
110 |     plt.ylabel('Sensitivity',font_axis_name)
111 |     plt.xlim((0, 0.5))
112 |     plt.ylim((0.5, 1))
113 |     plt.legend(framealpha=1, fontsize=30, loc='lower right')
114 |     plt.tight_layout()
115 |
116 |     plt.savefig("plot/"+name+".tiff")
117 |
118 |     plt.show()
119 |
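
`ROC(df)` in this file builds a macro-average curve: one one-vs-rest ROC curve per class, each interpolated onto the union of all false-positive rates, then averaged point by point. The sketch below reproduces that idea without NumPy or scikit-learn so the steps are visible. It is an illustration under simplifying assumptions, not the scikit-learn implementation. `roc_points`, `interp` and `macro_roc` are hypothetical names, and the toy data is made up.

```python
def roc_points(labels, scores, pos_label):
    """One-vs-rest ROC points for `pos_label`, from (0, 0) to (1, 1)."""
    pairs = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    n_pos = sum(1 for lab in labels if lab == pos_label)
    n_neg = len(labels) - n_pos
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for k, (score, lab) in enumerate(pairs):
        if lab == pos_label:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last sample of a group of tied scores.
        if k + 1 == len(pairs) or pairs[k + 1][0] != score:
            fpr.append(fp / n_neg)
            tpr.append(tp / n_pos)
    return fpr, tpr


def interp(x, xs, ys):
    """Piecewise-linear value of the curve (xs, ys) at x; xs is non-decreasing."""
    # Scan from the right so a vertical step yields its top value.
    for j in range(len(xs) - 1, -1, -1):
        if xs[j] <= x:
            if j == len(xs) - 1 or xs[j] == x:
                return ys[j]
            t = (x - xs[j]) / (xs[j + 1] - xs[j])
            return ys[j] + t * (ys[j + 1] - ys[j])
    return ys[0]


def macro_roc(labels, score_rows, n_classes):
    """Average the per-class curves on the union of their FPR values."""
    curves = [roc_points(labels, [row[c] for row in score_rows], c)
              for c in range(n_classes)]
    grid = sorted({x for fpr, _ in curves for x in fpr})
    mean_tpr = [sum(interp(x, fpr, tpr) for fpr, tpr in curves) / n_classes
                for x in grid]
    return grid, mean_tpr


# Toy data: three classes, each perfectly separated by its own score column.
labels = [0, 0, 1, 1, 2, 2]
score_rows = [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1],
              [0.1, 0.8, 0.1], [0.2, 0.7, 0.1],
              [0.1, 0.1, 0.8], [0.2, 0.1, 0.7]]
grid, mean_tpr = macro_roc(labels, score_rows, n_classes=3)
```

The sketch assumes every class has at least one positive and one negative sample; otherwise the rate computations would divide by zero.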
--------------------------------------------------------------------------------
/roc_util/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.2"
2 | __author__ = "Norman Juchler"
3 |
4 | from ._roc import (get_objective,
5 | compute_roc,
6 | compute_mean_roc,
7 | compute_roc_bootstrap)
8 | from ._plot import (plot_roc,
9 | plot_mean_roc,
10 | plot_roc_simple,
11 | plot_roc_bootstrap)
12 | from ._demo import (demo_basic,
13 | demo_bootstrap,
14 | demo_sample_data)
15 |
--------------------------------------------------------------------------------
/roc_util/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/__pycache__/_demo.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_demo.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/__pycache__/_plot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_plot.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/__pycache__/_roc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_roc.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/__pycache__/_sampling.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_sampling.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/__pycache__/_stats.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_stats.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/__pycache__/_types.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/roc_util/__pycache__/_types.cpython-38.pyc
--------------------------------------------------------------------------------
/roc_util/_demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ._plot import plot_roc_bootstrap, plot_roc
3 | from ._roc import compute_roc
4 |
5 |
6 | def demo_sample_data(n1, mu1, std1, n2, mu2, std2, seed=42):
7 |     """
8 |     Construct binary classification problem with n1 and n2
9 |     samples per class, respectively.
10 |
11 |     Returns two np.ndarrays x and y of length (n1+n2).
12 |     x represents the predictor, y the binary response.
13 |     """
14 |     rng = np.random.RandomState(seed)
15 |     x1 = rng.normal(mu1, std1, n1)
16 |     x2 = rng.normal(mu2, std2, n2)
17 |     y1 = np.zeros(n1, dtype=bool)
18 |     y2 = np.ones(n2, dtype=bool)
19 |     x = np.concatenate([x1, x2])
20 |     y = np.concatenate([y1, y2])
21 |     return x, y
22 |
23 |
24 | def demo_basic(n_samples=600, seed=42):
25 |     """
26 |     Demonstrate basic usage of compute_roc() and plot_roc().
27 |     """
28 |     import matplotlib.pyplot as plt
29 |     pos_label = True
30 |     x, y = demo_sample_data(n1=n_samples//2, mu1=0.0, std1=0.5,
31 |                             n2=n_samples//2, mu2=1.0, std2=0.7,
32 |                             seed=seed)
33 |     roc = compute_roc(X=x, y=y, pos_label=pos_label)
34 |     plot_roc(roc, label="Dataset", color="red")
35 |     plt.title("Basic demo")
36 |     plt.show()
37 |
38 |
39 | def demo_bootstrap(n_samples=600, n_bootstrap=50, seed=42):
40 |     """
41 |     Demonstrate a ROC analysis for a bootstrapped dataset.
42 |     """
43 |     import matplotlib.pyplot as plt
44 |     assert(n_samples > 2)
45 |     pos_label = True
46 |     x, y = demo_sample_data(n1=n_samples//2, mu1=0.0, std1=0.5,
47 |                             n2=n_samples//2, mu2=1.0, std2=0.7,
48 |                             seed=seed)
49 |     plot_roc_bootstrap(X=x, y=y, pos_label=pos_label,
50 |                        n_bootstrap=n_bootstrap,
51 |                        random_state=seed+1,
52 |                        show_boots=False,
53 |                        title="Bootstrap demo")
54 |     plt.show()
55 |
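
`demo_sample_data` draws two Gaussian classes, and the demos feed them to `compute_roc`. As a rough standard-library illustration of what such a dataset looks like to a ROC analysis, the sketch below draws a comparable sample and estimates the AUC as the fraction of positive/negative pairs ranked correctly. `sample_two_classes` and `auc_by_pairs` are hypothetical names and are not part of `roc_util`; the draws come from `random.Random`, so they will not match NumPy's `RandomState`.

```python
import random


def sample_two_classes(n1, mu1, std1, n2, mu2, std2, seed=42):
    """Return predictor values x and boolean responses y, n1 negatives then n2 positives."""
    rng = random.Random(seed)
    x = [rng.gauss(mu1, std1) for _ in range(n1)]
    x += [rng.gauss(mu2, std2) for _ in range(n2)]
    y = [False] * n1 + [True] * n2
    return x, y


def auc_by_pairs(x, y):
    """Fraction of (positive, negative) pairs ranked correctly; ties count one half."""
    pos = [v for v, lab in zip(x, y) if lab]
    neg = [v for v, lab in zip(x, y) if not lab]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Same class parameters as the demos in this file: N(0, 0.5) versus N(1, 0.7).
x, y = sample_two_classes(300, 0.0, 0.5, 300, 1.0, 0.7)
auc_estimate = auc_by_pairs(x, y)
```

The pairwise count is quadratic in the sample size, which is fine for a few hundred points; a rank-based formula would be the usual choice for larger data.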
--------------------------------------------------------------------------------
/roc_util/_plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ._roc import (compute_roc,
3 | compute_mean_roc,
4 | compute_roc_bootstrap,
5 | _DEFAULT_OBJECTIVE)
6 |
7 |
8 | def plot_roc(roc,
9 | color="red",
10 | label=None,
11 | show_opt=False,
12 | show_details=False,
13 | format_axes=True,
14 | ax=None,
15 | **kwargs):
16 | """
17 | Plot the ROC curve given the output of compute_roc.
18 |
19 | Arguments:
20 | roc: Output of compute_roc() with the following keys:
21 | - fpr: false positive rates fpr(thr)
22 | - tpr: true positive rates tpr(thr)
23 | - opd: optimal point(s).
24 | - inv: true if predictor is inverted (predicts ~y)
25 | label: Label used for legend.
26 | show_opt: Show optimal point.
27 | show_details: Show additional information.
28 | format_axes: Apply axes settings, show legend, etc.
29 | kwargs: A dictionary with detail settings not exposed
30 | explicitly in the function signature. The following
31 | options are available:
32 | - zorder:
33 | - legend_out: Place legend outside (default: False)
34 | - legend_label_inv: Use 1-AUC if roc.inv=True (True)
35 | Additional kwargs are forwarded to ax.plot().
36 | """
37 | import matplotlib.pyplot as plt
38 | import matplotlib.colors as mplc
39 |
40 | def _format_axes(loc, margin):
41 | ax.axis("square")
42 | ax.set_xlim([0 - margin, 1. + margin])
43 | ax.set_ylim([0 - margin, 1. + margin])
44 | ax.set_xlabel("FPR (false positive rate)")
45 | ax.set_ylabel("TPR (true positive rate)")
46 | ax.grid(True)
47 | if legend_out:
48 | ax.legend(loc=loc,
49 | bbox_to_anchor=(1.05, 1),
50 | borderaxespad=0.)
51 | else:
52 | ax.legend(loc=loc)
53 |
54 | def _plot_opt_point(key, opt, color, marker, zorder, label, ax):
55 | # Some objectives can be visualized.
56 | # Plot these optional things first.
57 | if key == "minopt":
58 | ax.plot([0, opt.opp[0]], [1, opt.opp[1]], ":ok",
59 | alpha=0.3,
60 | zorder=zorder + 1)
61 | if key == "minoptsym":
62 | d2_ul = (opt.opp[0]-0)**2+(opt.opp[1]-1)**2
63 | d2_ll = (opt.opp[0]-1)**2+(opt.opp[1]-0)**2
64 | ref = (0, 1) if (d2_ul < d2_ll) else (1, 0)
65 | ax.plot([ref[0], opt.opp[0]], [ref[1], opt.opp[1]], ":ok",
66 | alpha=0.3,
67 | zorder=zorder + 1)
68 | if key == "youden":
69 | # Vertical line between optimal point and diagonal.
70 | ax.plot([opt.opp[0], opt.opp[0]],
71 | [opt.opp[0], opt.opp[1]],
72 | color=color,
73 | zorder=zorder + 1)
74 | if key == "cost":
75 | # Line parallel to diagonal (shrunk by m if m≠1).
76 | ax.plot(opt.opq[0], opt.opq[1], ":k",
77 | alpha=0.3,
78 | zorder=zorder + 1)
79 | if key == "concordance":
80 | # Rectangle illustrating the area tpr*(1-fpr)
81 | from matplotlib import patches
82 | ll = [opt.opp[0], 0]
83 | w = 1 - opt.opp[0]
84 | h = opt.opp[1]
85 | rect = patches.Rectangle(ll, w, h,
86 | facecolor=color,
87 | alpha=0.2,
88 | zorder=zorder + 1)
89 | ax.add_patch(rect)
90 |
91 | face_color = mplc.to_rgba(color, alpha=0.3)
92 | ax.plot(opt.opp[0], opt.opp[1],
93 | linestyle="None",
94 | marker=marker,
95 | markerfacecolor=face_color,
96 | markeredgecolor=color,
97 | label=label,
98 | zorder=zorder + 3)
99 |
100 | if ax is None:
101 | ax = plt.gca()
102 |
103 | # Pop the options handled here; the remaining kwargs go to ax.plot().
104 | # Set some defaults.
105 | label = label if label else "Feature"
106 | zorder = kwargs.pop("zorder", 1)
107 | legend_out = kwargs.pop("legend_out", False)
108 | legend_label_inv = kwargs.pop("legend_label_inv", True)
109 | if legend_label_inv:
110 | auc_disp, auc_val = "1-AUC" if roc.inv else "AUC", roc.auc
111 | else:
112 | auc_disp, auc_val = "AUC", roc.auc
113 | label = "%s (%s=%.3f)" % (label, auc_disp, auc_val)
114 |
115 | # Plot the ROC curve.
116 | ax.plot(
117 | roc.fpr,
118 | roc.tpr,
119 | color=color,
120 | zorder=zorder + 2,
121 | # label=label,
122 | **kwargs)
123 |
124 | # Plot the no-discrimination line.
125 | label_diag = "No discrimination" if show_details else None
126 | ax.plot([0, 1], [0, 1], ":k", label=label_diag,
127 | zorder=zorder, linewidth=1)
128 |
129 | # Visualize the optimal point.
130 | if show_opt:
131 | from itertools import cycle
132 | markers = cycle(["o", "*", "^", "s", "P", "D"])
133 | for key, opt in roc.opd.items():
134 | pa_str = (", PA=%.3f" % opt.opa) if opt.opa else ""
135 | if show_details:
136 | legend_entry_opt = ("Optimal point (%s, thr=%.3g%s)"
137 | % (key, opt.opt, pa_str))
138 | else:
139 | legend_entry_opt = "Optimal point (thr=%.3g)" % opt.opt
140 | _plot_opt_point(key=key, opt=opt, color=color,
141 | marker=next(markers), zorder=zorder,
142 | label=legend_entry_opt, ax=ax)
143 |
144 | if format_axes:
145 | margin = 0.02
146 | loc = "upper left" if (roc.inv or legend_out) else "lower right"
147 | _format_axes(loc=loc, margin=margin)
148 |
149 |
150 | def plot_mean_roc(rocs, auto_flip=True, show_all=False, ax=None, **kwargs):
151 | """
152 | Compute and plot the mean ROC curve for a sequence of ROC containers.
153 |
154 | rocs: List of ROC containers created by compute_roc().
155 | auto_flip: See compute_roc(), applies only to mean ROC curve.
156 | show_all: If True, show the single ROC curves.
157 | If an integer, show the rocs[:show_all] roc curves.
158 | show_ci: Show confidence interval
159 | show_ti: Show tolerance interval
160 | kwargs: Forwarded to plot_roc(), applies only to mean ROC curve.
161 |
162 | Optional kwargs argument show_opt can be either False, True or a string
163 | specifying the particular objective function to be used to plot the
164 | optimal point. See get_objective() for details. Default choice is the
165 | "minopt" objective.
166 | """
167 | import matplotlib.pyplot as plt
168 | if ax is None:
169 | ax = plt.gca()
170 |
171 | n_samples = len(rocs)
172 |
173 | # Some default values.
174 | zorder = kwargs.get("zorder", 1)
175 | label = kwargs.pop("label", "Mean ROC curve")
176 | # kwargs for plot_roc()...
177 | show_details = kwargs.get("show_details", False)
178 | show_opt = kwargs.pop("show_opt", False)
179 | show_ti = kwargs.pop("show_ti", True)
180 | show_ci = kwargs.pop("show_ci", True)
181 | color = kwargs.pop("color", "red")
182 | is_opt_str = isinstance(show_opt, (str, list, tuple))
183 | # Defaults for mean-ROC.
184 | resolution = kwargs.pop("resolution", 101)
185 | objective = show_opt if is_opt_str else _DEFAULT_OBJECTIVE
186 |
187 | # Compute average ROC.
188 | ret_mean = compute_mean_roc(rocs=rocs,
189 | resolution=resolution,
190 | auto_flip=auto_flip,
191 | objective=objective)
192 |
193 | # Plot ROC curve for single bootstrap samples.
194 | if show_all:
195 | def isint(x):
196 | return isinstance(x, int) and not isinstance(x, bool)
197 | n_loops = show_all if isint(show_all) else np.inf
198 | n_loops = min(n_loops, len(rocs))
199 | for ret in rocs[:n_loops]:
200 | ax.plot(ret.fpr, ret.tpr,
201 | color="gray",
202 | alpha=0.2,
203 | zorder=zorder + 2)
204 | if show_ti:
205 | # 95% interval
206 | tpr_sort = np.sort(ret_mean.tpr_all, axis=0)
207 | tpr_lower = tpr_sort[int(0.025 * n_samples), :]
208 | tpr_upper = tpr_sort[int(0.975 * n_samples), :]
209 | label_int = "95% of all samples" if show_details else None
210 | ax.fill_between(ret_mean.fpr, tpr_lower, tpr_upper,
211 | color="gray", alpha=.2,
212 | label=label_int,
213 | zorder=zorder + 1)
214 | if show_ci:
215 | # 95% confidence interval
216 | tpr_std = np.std(ret_mean.tpr_all, axis=0, ddof=1)
217 | tpr_lower = ret_mean.tpr - 1.96 * tpr_std / np.sqrt(n_samples)
218 | tpr_upper = ret_mean.tpr + 1.96 * tpr_std / np.sqrt(n_samples)
219 | label_ci = "95% CI of mean curve" if show_details else None
220 | ax.fill_between(ret_mean.fpr, tpr_lower, tpr_upper,
221 | color=color, alpha=.3,
222 | label=label_ci,
223 | zorder=zorder)
224 |
225 | # Last but not least, plot the average ROC curve on top of everything.
226 | # plot_roc(roc=ret_mean, label=label, show_opt=show_opt,
227 | # color=color, ax=ax, zorder=zorder + 3, **kwargs)
228 | return ret_mean
229 |
230 |
231 | def plot_roc_simple(X, y, pos_label, auto_flip=True,
232 | title=None, ax=None, **kwargs):
233 | """
234 | Compute and plot the receiver-operator characteristic curve for X and y.
235 | kwargs are forwarded to plot_roc(), see there for details.
236 |
237 | Optional kwargs argument show_opt can be either False, True or a string
238 | specifying the particular objective function to be used to plot the
239 | optimal point. See get_objective() for details. Default choice is the
240 | "minopt" objective.
241 | """
242 | import matplotlib.pyplot as plt
243 | if ax is None:
244 | ax = plt.gca()
245 | show_opt = kwargs.pop("show_opt", False)
246 | is_opt_str = isinstance(show_opt, (str, list, tuple))
247 | objective = show_opt if is_opt_str else _DEFAULT_OBJECTIVE
248 | ret = compute_roc(X=X, y=y, pos_label=pos_label,
249 | objective=objective,
250 | auto_flip=auto_flip)
251 | plot_roc(roc=ret, show_opt=show_opt, ax=ax, **kwargs)
252 | title = "ROC curve" if title is None else title
253 | ax.get_figure().suptitle(title)
254 | return ret
255 |
256 |
257 | def plot_roc_bootstrap(X, y, pos_label,
258 | objective=_DEFAULT_OBJECTIVE,
259 | auto_flip=True,
260 | n_bootstrap=100,
261 | random_state=None,
262 | stratified=False,
263 | show_boots=False,
264 | title=None,
265 | ax=None,
266 | **kwargs):
267 | """
268 | Similar to plot_roc_simple(), but estimates an average ROC curve from
269 | multiple bootstrap samples.
270 |
271 | See compute_roc_bootstrap() for the meaning of the arguments.
272 |
273 | Optional kwargs argument show_opt can be either False, True or a string
274 | specifying the particular objective function to be used to plot the
275 | optimal point. See get_objective() for details. Default choice is the
276 | "minopt" objective.
277 | """
278 | import matplotlib.pyplot as plt
279 | if ax is None:
280 | ax = plt.gca()
281 |
282 | # 1) Collect the data.
283 | rocs = compute_roc_bootstrap(X=X, y=y,
284 | pos_label=pos_label,
285 | objective=objective,
286 | auto_flip=auto_flip,
287 | n_bootstrap=n_bootstrap,
288 | random_state=random_state,
289 | stratified=stratified,
290 | return_mean=False)
291 | # 2) Plot the average ROC curve.
292 | ret_mean = plot_mean_roc(rocs=rocs, auto_flip=auto_flip,
293 | show_all=show_boots, ax=ax, **kwargs)
294 |
295 | title = "ROC curve" if title is None else title
296 | ax.get_figure().suptitle(title)
297 | ax.set_title("Bootstrap reps: %d, sample size: %d" %
298 | (n_bootstrap, len(y)), fontsize=10)
299 | return ret_mean
300 |
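The optimal-point markers drawn by `_plot_opt_point()` correspond to simple geometric criteria on the ROC curve. The following is a minimal, dependency-free sketch of three of them, assuming the definitions implied by the plotting code (vertical distance to the diagonal for "youden", distance to the corner (0, 1) for "minopt", and the rectangle area tpr*(1-fpr) for "concordance"). The function name `optimal_points` and the sample data are hypothetical; the actual objectives live in `_roc.py`, which is not shown here.

```python
from math import hypot


def optimal_points(fpr, tpr):
    """Return the index of the best operating point for three objectives."""
    idx = range(len(fpr))
    return {
        # Youden's J: vertical distance between the curve and the diagonal.
        "youden": max(idx, key=lambda i: tpr[i] - fpr[i]),
        # Closest point to the perfect classifier at (fpr, tpr) = (0, 1).
        "minopt": min(idx, key=lambda i: hypot(fpr[i], 1.0 - tpr[i])),
        # Largest rectangle tpr * (1 - fpr), i.e. sensitivity * specificity.
        "concordance": max(idx, key=lambda i: tpr[i] * (1.0 - fpr[i])),
    }


# Hypothetical ROC samples, ordered by increasing false positive rate.
fpr = [0.0, 0.1, 0.3, 0.6, 1.0]
tpr = [0.0, 0.55, 0.85, 0.95, 1.0]
best = optimal_points(fpr, tpr)
print(best)
```

On this toy curve all three criteria select the same point; on real data they can disagree, which is presumably why the plot cycles through a different marker for each objective.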
--------------------------------------------------------------------------------
/roc_util/_sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def resample_data(*arrays, **kwargs):
5 | """
6 | Similar to sklearn's resample function, with a few more extras.
7 |
8 | arrays: Arrays with consistent first dimension.
9 | kwargs:
10 | replace: Sample with replacement. Default: True
11 | n_samples: Number of samples. Default: len(arrays[0])
12 | frac: Compute the number of samples as a fraction of the
13 | array length: n_samples=frac*len(arrays[0])
14 | Overrides the value for n_samples if provided.
15 | random_state: Determines the random number generation. Can be None,
16 | an int or np.random.RandomState. Default: None
17 | stratify: An iterable containing the class labels by which
18 | the arrays should be stratified. Default: None
19 | axis: Sampling axis. Note: axis!=0 is slow! Also, stratify
20 | is currently not supported if axis!=0. Default: axis=0
21 | squeeze: Flatten the output array if only one array is provided.
22 | Default: True
23 | """
24 | def _resample(*arrays, replace, n_samples, stratify, rng, axis=0):
25 | lens = [x.shape[axis] for x in arrays]
26 | equal_length = (lens.count(lens[0]) == len(lens))
27 | if not equal_length:
28 | msg = "Input arrays don't have equal length: %s"
29 | raise ValueError(msg % lens)
30 | if stratify is not None:
31 | msg = "Stratification is not supported yet."
32 | raise ValueError(msg)
33 | if not isinstance(rng, np.random.RandomState):
34 | rng = np.random.RandomState(rng)
35 |
36 | n = lens[0]
37 | if replace:
38 | indices = rng.randint(0, n, n_samples)
39 | else:
40 | indices = rng.choice(np.arange(0, n), n_samples, replace=False)
41 | # Sampling along an axis!=0 is not very clever.
42 | arrays = [x.take(indices, axis=axis) for x in arrays]
43 | # Flatten the output if only one input array was provided.
44 | return arrays if len(arrays) > 1 else arrays[0]
45 | try:
46 | from sklearn.utils import resample
47 | has_sklearn = True
48 | except ModuleNotFoundError:
49 | has_sklearn = False
50 |
51 | replace = kwargs.pop("replace", True)
52 | n_samples = kwargs.pop("n_samples", None)
53 | frac = kwargs.pop("frac", None)
54 | rng = kwargs.pop("random_state", None)
55 | stratify = kwargs.pop("stratify", None)
56 | squeeze = kwargs.pop("squeeze", True)
57 | axis = kwargs.pop("axis", 0)
58 | if kwargs:
59 | msg = "Received unexpected argument(s): %s" % kwargs
60 | raise ValueError(msg)
61 |
62 | arrays = [np.asarray(x) if not hasattr(x, "shape") else x for x in arrays]
63 | lens = [x.shape[axis] for x in arrays]
64 | if frac:
65 | n_samples = int(np.round(frac * lens[0]))
66 | if n_samples is None:
67 | n_samples = lens[0]
68 | if axis > 0 or not has_sklearn:
69 | ret = _resample(*arrays, replace=replace, n_samples=n_samples,
70 | stratify=stratify, rng=rng, axis=axis)
71 | else:
72 | ret = resample(*arrays,
73 | replace=replace,
74 | n_samples=n_samples,
75 | stratify=stratify,
76 | random_state=rng,)
77 | # Undo the squeezing, which is done by resample (and _resample).
78 | if not squeeze and len(arrays) == 1:
79 | ret = [ret]
80 | return ret
81 |
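The difference between sampling with and without replacement, and the need to apply one shared index set to all arrays, can be sketched without NumPy or scikit-learn. This is an illustration only, assuming plain Python lists; `resample_indices` and `take` are hypothetical helpers and not part of `roc_util`.

```python
import random


def resample_indices(n, n_samples=None, replace=True, seed=None):
    """Draw sample indices from range(n), with or without replacement."""
    rng = random.Random(seed)
    k = n if n_samples is None else n_samples
    if replace:
        # Bootstrap-style draw: the same index may appear several times.
        return [rng.randrange(n) for _ in range(k)]
    if k > n:
        raise ValueError("Cannot draw more samples than items "
                         "without replacement.")
    # Each index appears at most once.
    return rng.sample(range(n), k)


def take(indices, *arrays):
    """Apply the same indices to every array so that rows stay paired."""
    return [[a[i] for i in indices] for a in arrays]


x = [10, 11, 12, 13, 14]
y = ["a", "b", "c", "d", "e"]
idx = resample_indices(len(x), n_samples=3, replace=False, seed=0)
xs, ys = take(idx, x, y)
boot = resample_indices(len(x), seed=1)
```

Using one index list for all arrays is what keeps features and labels aligned after resampling.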
--------------------------------------------------------------------------------
/roc_util/_stats.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as st
3 |
4 |
5 | def mean_intervals(data, confidence=0.95, axis=None):
6 | """
7 | Compute the mean, the confidence interval of the mean, and the tolerance
8 | interval. Note that the confidence interval is often misinterpreted [3].
9 |
10 | References:
11 | [1] https://en.wikipedia.org/wiki/Confidence_interval
12 | [2] https://en.wikipedia.org/wiki/Tolerance_interval
13 | [3] https://en.wikipedia.org/wiki/Confidence_interval#Meaning_and_interpretation
14 | """
15 | confidence = confidence / 100.0 if confidence > 1.0 else confidence
16 | assert(0 < confidence < 1)
17 | a = 1.0 * np.array(data)
18 | n = len(a)
19 | # Both s=std() and se=sem() use unbiased estimators (ddof=1).
20 | m = np.mean(a, axis=axis)
21 | s = np.std(a, ddof=1, axis=axis)
22 | se = st.sem(a, axis=axis)
23 | t = st.t.ppf((1 + confidence) / 2., n - 1)
24 | ci = np.c_[m - se * t, m + se * t]
25 | ti = np.c_[m - s * t, m + s * t]
26 | assert(ci.shape[1] == 2 and ci.shape[0] ==
27 | np.size(m, axis=None if axis is None else 0))
28 | assert(ti.shape[1] == 2 and ti.shape[0] ==
29 | np.size(m, axis=None if axis is None else 0))
30 | return m, ci, ti
31 |
32 |
33 | def mean_confidence_interval(data, confidence=0.95, axis=None):
34 | """
35 | Compute the mean and the confidence interval of the mean.
36 | """
37 | m, ci, _ = mean_intervals(data, confidence, axis=axis)
38 | return m, ci
39 |
40 |
41 | def mean_tolerance_interval(data, confidence=0.95, axis=None):
42 | """
43 | Compute the mean and the tolerance interval for the data.
44 | """
45 | m, _, ti = mean_intervals(data, confidence, axis=axis)
46 | return m, ti
47 |
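For comparison, here is a minimal standard-library sketch of a confidence interval of the mean. Note the assumption: it uses the normal quantile (about 1.96 for 95%), as `plot_mean_roc()` does, rather than the Student t quantile used by `mean_intervals()`, so the interval is slightly narrower for small samples. The name `mean_ci_normal` is hypothetical.

```python
from math import sqrt
from statistics import NormalDist, mean, stdev


def mean_ci_normal(data, confidence=0.95):
    """Mean and normal-approximation confidence interval of the mean."""
    n = len(data)
    m = mean(data)
    # Standard error of the mean; stdev() uses the unbiased (ddof=1) estimator.
    se = stdev(data) / sqrt(n)
    # Two-sided quantile of the standard normal distribution.
    z = NormalDist().inv_cdf((1 + confidence) / 2)
    return m, (m - z * se, m + z * se)


m, (lo, hi) = mean_ci_normal([1.0, 2.0, 3.0, 4.0, 5.0])
print(m, lo, hi)
```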
--------------------------------------------------------------------------------
/roc_util/_types.py:
--------------------------------------------------------------------------------
1 | class StructContainer():
2 | """
3 | Build a type that behaves similarly to a struct.
4 |
5 | Usage:
6 | # Construction from named arguments.
7 | settings = StructContainer(option1 = False,
8 | option2 = True)
9 | # Construction from dictionary.
10 | settings = StructContainer({"option1": False,
11 | "option2": True})
12 | print(settings.option1)
13 | settings.option2 = False
14 | for k,v in settings.items():
15 | print(k,v)
16 | """
17 |
18 | def __init__(self, dictionary=None, **kwargs):
19 | if dictionary is not None:
20 | assert(isinstance(dictionary, (dict, StructContainer)))
21 | self.__dict__.update(dictionary)
22 | self.__dict__.update(kwargs)
23 |
24 | def __iter__(self):
25 | for i in self.__dict__:
26 | yield i
27 |
28 | def __getitem__(self, key):
29 | return self.__dict__[key]
30 |
31 | def __setitem__(self, key, value):
32 | self.__dict__[key] = value
33 |
34 | def __len__(self):
35 | return sum(1 for k in self.keys())
36 |
37 | def __repr__(self):
38 | return "struct(**%s)" % str(self.__dict__)
39 |
40 | def __str__(self):
41 | return str(self.__dict__)
42 |
43 | def items(self):
44 | for k, v in self.__dict__.items():
45 | if not k.startswith("_"):
46 | yield (k, v)
47 |
48 | def keys(self):
49 | for k in self.__dict__:
50 | if not k.startswith("_"):
51 | yield k
52 |
53 | def values(self):
54 | for k, v in self.__dict__.items():
55 | if not k.startswith("_"):
56 | yield v
57 |
58 | def update(self, data):
59 | self.__dict__.update(data)
60 |
61 | def asdict(self):
62 | return dict(self.items())
63 |
64 | def first(self):
65 | # Assumption: __dict__ is ordered (python>=3.6).
66 | key, value = next(self.items())
67 | return key, value
68 |
69 | def last(self):
70 | # Assumption: __dict__ is ordered (python>=3.6).
71 | # See also: https://stackoverflow.com/questions/58413076
72 | key = list(self.keys())[-1]
73 | return key, self[key]
74 |
75 | def get(self, key, default=None):
76 | return self.__dict__.get(key, default)
77 |
78 | def setdefault(self, key, default=None):
79 | return self.__dict__.setdefault(key, default)
80 |
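One detail of `StructContainer` that is easy to miss: `items()`, `keys()` and `values()` skip attributes whose names start with an underscore, while `__getitem__` still reaches them. The cut-down class below is a hypothetical stand-in (`MiniStruct`) that reproduces only this behavior for illustration.

```python
class MiniStruct:
    """Cut-down, illustrative stand-in for StructContainer."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getitem__(self, key):
        # Item access sees every attribute, including private ones.
        return self.__dict__[key]

    def items(self):
        # Only public attributes (no leading underscore) are exposed.
        return ((k, v) for k, v in self.__dict__.items()
                if not k.startswith("_"))


s = MiniStruct(fpr=[0.0, 1.0], auc=0.5)
s._cache = "hidden"
public = dict(s.items())
print(public)
```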
--------------------------------------------------------------------------------
/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | from tensorboardX import SummaryWriter
5 | import numpy as np
6 | from PIL import Image
7 | from evaluation import *
8 | import cv2
9 |
10 |
11 |
12 | class Saver():
13 | def __init__(self, opt):
14 | self.logDir = opt['logDir']
15 | self.n_ep_save = opt['n_ep_save']
16 | self.writer = SummaryWriter(logdir=self.logDir)
17 |
18 | def write_scalars(self, ep, lossdict):
19 | # Todo Save images
20 | for loss_key, loss_value in lossdict.items():
21 | self.writer.add_scalar(loss_key, loss_value, ep)
22 |
23 |
24 | def write_maps(self, ep, map_dict):
25 | for name,map in map_dict.items():
26 | if len(map.shape)==2:
27 | map = map[np.newaxis,...]
28 | if map.shape[0] == 1:
29 | map = np.concatenate((map, map, map), axis=0)
30 | self.writer.add_image('map/'+name, map, ep)
31 |
32 |
33 | def write_log(self, ep, lossdict, Name):
34 | logpath = os.path.join(self.logDir, Name + '.log')
35 | title = 'epochs,'
36 | vals = '%d,'%(ep)
37 | for loss_key, loss_value in lossdict.items():
38 | title = title + loss_key + ','
39 | vals = vals + '%.4f,' % (loss_value)
40 | title = title[:-1] + '\n'
41 | vals = vals[:-1] + '\n'
42 | if ep==self.n_ep_save-1:
43 | saveFile = open(logpath, "w")
44 | saveFile.write(title)
45 | saveFile.write(vals)
46 | else:
47 | saveFile = open(logpath, "a")
48 | saveFile.write(vals)
49 | saveFile.close()
50 |
51 |
52 |
53 |
54 | def write_imagegroup(self, ep, images, basename, key):
55 | # images: tensor Bx3xHxW or Bx1xHxW or BxHxW
56 | if len(images.shape) == 3:
57 | images = torch.unsqueeze(images, 1)
58 | images = torch.cat([images, images, images], 1)
59 | elif images.shape[1] == 1:
60 | images = torch.cat([images, images, images], 1)
61 | image_dis = torchvision.utils.make_grid(images, nrow=7)
62 | self.writer.add_image('map/' + key, image_dis, ep)
63 | image_dis2 = image_dis
64 |
65 | ndarr = image_dis2.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
66 | savename = os.path.join(self.logDir, key + '_' + basename + '_' +str(ep) + '.png')
67 | if key == 'SAmap':
68 | ndarr = cv2.applyColorMap(ndarr, cv2.COLORMAP_JET)
69 | cv2.imwrite(savename, ndarr)
70 |
71 | return ndarr
72 |
73 | def write_cm_maps(self, ep, cm, class_list, savename='cm.png'):
74 | savename = os.path.join(self.logDir, savename)
75 | plot_confusion_matrix(cm, savename, title='Confusion Matrix',
76 | classes=class_list)
77 | cmimg = cv2.imread(savename)
78 | cmimg = np.transpose(cmimg, (2, 0, 1))
79 | self.writer.add_image('map/cm', cmimg, ep)
80 |
81 |
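The comma-separated log written by `write_log()` could also be produced with the `csv` module and a context manager, so the file is closed even if a write fails. This is a sketch under one stated assumption: the header is written when the file does not exist yet, which differs from the original trigger (`ep == n_ep_save - 1`). The helper name `append_log_row` is hypothetical.

```python
import csv
import os
import tempfile


def append_log_row(path, epoch, lossdict):
    """Append one epoch to a CSV log, writing the header on first use."""
    write_header = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["epochs", *lossdict.keys()])
        # Fixed four-decimal formatting keeps the columns easy to compare.
        writer.writerow([epoch, *("%.4f" % v for v in lossdict.values())])


log_path = os.path.join(tempfile.mkdtemp(), "train.log")
append_log_row(log_path, 0, {"loss": 0.5, "acc": 0.75})
append_log_row(log_path, 1, {"loss": 0.25, "acc": 0.875})
with open(log_path, newline="") as f:
    rows = list(csv.reader(f))
print(rows)
```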
--------------------------------------------------------------------------------
/test_main_new.py:
--------------------------------------------------------------------------------
1 | # from apex import amp
2 | import numpy as np
3 |
4 | from utils import *
5 | import dataset_mine
6 | from net import init_weights,get_scheduler,WarmupCosineSchedule
7 | from matplotlib import pyplot as plt
8 | import matplotlib as mpl
9 | import cmaps
10 | import h5py
11 | def train(opt):
12 | root=r'/home/hanyu/LHY/miccai7.22/best_model/Best3_0720-0755-0003.pth'
13 |
14 | opt['gpus']=[6,7]
15 | gpuID = opt['gpus']
16 | valDataset = dataset_mine.Our_Dataset_vis(phase='Test', opt=opt)
17 | valLoader = DataLoader(valDataset, batch_size=opt['Val_batchSize'],
18 | num_workers=opt['nThreads'] if (sysstr == "Linux") else 1, shuffle=False)
19 |
20 | ############## initialize #######################
21 |
22 | last_ep = 0
23 | total_it = 0
24 | saver = Saver(opt)
25 | print('%d epochs and %d iterations have been trained' % (last_ep, total_it))
26 | alleps = opt['n_ep'] - last_ep
27 | curep=0
28 | if 1:
29 | ## IDH
30 | model_init = Mine_init(opt).cuda(gpuID[0])
31 | model_IDH = Mine_IDH(opt).cuda(gpuID[0])
32 | model_1p19q = Mine_1p19q(opt).cuda(gpuID[0])
33 | model_CDKN = Mine_CDKN(opt).cuda(gpuID[0])
34 | model_Graph = Label_correlation_Graph(opt).cuda(gpuID[0])
35 | model_His = Mine_His(opt).cuda(gpuID[0])
36 | model_Cls = Cls_His_Grade(opt).cuda(gpuID[0])
37 | model_Task = Mine_Task(opt).cuda(gpuID[0])
38 |
39 | model_init = torch.nn.DataParallel(model_init, device_ids=gpuID)
40 | model_IDH = torch.nn.DataParallel(model_IDH, device_ids=gpuID)
41 | model_1p19q = torch.nn.DataParallel(model_1p19q, device_ids=gpuID)
42 | model_CDKN = torch.nn.DataParallel(model_CDKN, device_ids=gpuID)
43 | model_Graph = torch.nn.DataParallel(model_Graph, device_ids=gpuID)
44 | model_His = torch.nn.DataParallel(model_His, device_ids=gpuID)
45 | model_Cls = torch.nn.DataParallel(model_Cls, device_ids=gpuID)
46 | model_Task = torch.nn.DataParallel(model_Task, device_ids=gpuID)
47 |
48 | ckptdir = os.path.join(root)
49 | checkpoint = torch.load(ckptdir)
50 | related_params = {k: v for k, v in checkpoint['init'].items()}
51 | model_init.load_state_dict(related_params)
52 | related_params = {k: v for k, v in checkpoint['IDH'].items()}
53 | model_IDH.load_state_dict(related_params)
54 | related_params = {k: v for k, v in checkpoint['1p19q'].items()}
55 | model_1p19q.load_state_dict(related_params)
56 | related_params = {k: v for k, v in checkpoint['CDKN'].items()}
57 | model_CDKN.load_state_dict(related_params)
58 | related_params = {k: v for k, v in checkpoint['Graph'].items()}
59 | model_Graph.load_state_dict(related_params)
60 | related_params = {k: v for k, v in checkpoint['His'].items()}
61 | model_His.load_state_dict(related_params)
62 | related_params = {k: v for k, v in checkpoint['Cls'].items()}
63 | model_Cls.load_state_dict(related_params,strict=False)
64 | related_params = {k: v for k, v in checkpoint['Task'].items()}
65 | model_Task.load_state_dict(related_params,strict=False)
66 |
67 | model_init.eval()
68 | model_IDH.eval()
69 | model_1p19q.eval()
70 | model_CDKN.eval()
71 | model_Graph.eval()
72 | model_His.eval()
73 | model_Cls.eval()
74 | model_Task.eval()
75 | model = [model_init, model_IDH, model_1p19q, model_CDKN, model_Graph, model_His, model_Cls,model_Task]
76 |
77 |
78 |
79 | print("----------Val-------------")
80 | # validation_test_vis(opt, model,valLoader, gpuID)
81 | vis_reconstruct(opt, model,valLoader, gpuID)
82 |
83 |
84 |
85 | def vis_reconstruct(opt, model,dataloader, gpuID):
86 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0)
87 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0)
88 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
89 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True)
90 | excel_wsi = combined_labels.values
91 |
92 | PATIENT_LIST=excel_wsi[:,0]
93 | np.random.seed(opt['seed'])
94 | random.seed(opt['seed'])
95 | PATIENT_LIST=list(PATIENT_LIST)
96 | # IvYGAP_label
97 | IvYGAP_label = IvYGAP_label.values
98 |
99 | PATIENT_LIST=np.unique(PATIENT_LIST)
100 | np.random.shuffle(PATIENT_LIST)
101 | NUM_PATIENT_ALL=len(PATIENT_LIST) # 952
102 | TRAIN_PATIENT_LIST=PATIENT_LIST[0:int(NUM_PATIENT_ALL * 0.8)]
103 | VAL_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.9):]
104 | TEST_PATIENT_LIST = PATIENT_LIST[int(NUM_PATIENT_ALL * 0.80):int(NUM_PATIENT_ALL * 0.90)]
105 | TEST_LIST = []
106 | I_TEST_LIST = []
107 |
108 | for i in range(excel_wsi.shape[0]):# 2612
109 |
110 | if excel_wsi[:,0][i] in PATIENT_LIST:
111 | TEST_LIST.append(excel_wsi[i,:])
112 | TEST_LIST = np.asarray(TEST_LIST)
113 |
114 | for i in range(TEST_LIST.shape[0]):
115 | root = '/Res50_feature_2500_fixdim0_norm'
116 |
117 | patient_id = TEST_LIST[i, 0]
118 |
119 |
120 | if patient_id[0].startswith('T'):
121 | base_path = opt['dataDir'] + 'TCGA'
122 | elif patient_id[0].startswith('W'):
123 | base_path = opt['dataDir'] + 'IvYGAP'
124 | elif patient_id[0].startswith('C'):
125 | base_path = opt['dataDir'] + 'CPTAC'
126 | else:
127 | raise ValueError("Unknown data source")
128 |
129 | patch_20 = h5py.File(base_path + root + '_20x/' + TEST_LIST[i, 1] + '.h5')['Res_feature'][:]
130 | patch_10 = h5py.File(base_path + root + '/' + TEST_LIST[i, 1] + '.h5')['Res_feature'][:]
131 | img20 = torch.from_numpy(np.array(patch_20[0])).float()
132 | img10 = torch.from_numpy(np.array(patch_10[0])).float()
133 |
134 | read_details = np.load(opt['dataDir'] + 'read_details/' + TEST_LIST[i, 1] + '.npy', allow_pickle=True)[0]
135 | WSI_name=TEST_LIST[i, 1]
136 |
137 |
138 | if torch.cuda.is_available():
139 | img20 = img20.cuda(gpuID[0])
140 | img10 = img10.cuda(gpuID[0])
141 | # label = label.cuda(gpuID[0])
142 | # # # # IDH
143 |
144 | init_feature_his, init_feature_mark, _, _, _, _ = model[0](img20,img10) # (BS,2500,1024)
145 | hidden_states, encoded_IDH = model[1](init_feature_mark)
146 | hidden_states, encoded_1p19q = model[2](hidden_states)
147 | encoded_CDKN = model[3](hidden_states)
148 |
149 | results_dict,weight_IDH_wt,weight_IDH_mut,weight_1p19q_codel,weight_CDKN_HOMDEL,encoded_IDH0,encoded_1p19q0,encoded_CDKN0,Mark_output = \
150 | model[4](encoded_IDH, encoded_1p19q, encoded_CDKN)
151 | pred_IDH = results_dict['logits_IDH']
152 | pred_1p19q = results_dict['logits_1p19q']
153 | pred_CDKN = results_dict['logits_CDKN']
154 |
155 | hidden_states, encoded_His = model[5](init_feature_his)
156 | results_dict, weight_His_GBM, weight_His_GBM_Cls2,weight_His_O,His_output = model[6](encoded_His)
157 | pred_His_2class = results_dict['logits_His_2class']
158 | pred_His = results_dict['logits_His']
159 |
160 |
161 | # wsi_w = h5py.File('vis_results/set0/' + WSI_name + '.h5')['wsi_w'][()]
162 | # wsi_h = h5py.File('vis_results/set0/' + WSI_name + '.h5')['wsi_h'][()]
163 | # MPP = h5py.File('vis_results/set0/' + WSI_name + '.h5')['MPP'][()]
164 |
165 | weight_IDH_wt = norm(np.array(weight_IDH_wt.tolist()),read_details.shape[0])
166 | weight_His_GBM = norm(np.array(weight_His_GBM.tolist()),read_details.shape[0])
167 | weight_IDH_mut = norm(np.array(weight_IDH_mut.tolist()),read_details.shape[0])
168 | weight_1p19q_codel = norm(np.array(weight_1p19q_codel.tolist()),read_details.shape[0])
169 | weight_CDKN_HOMDEL = norm(np.array(weight_CDKN_HOMDEL.tolist()),read_details.shape[0])
170 |
171 |
172 |
173 | ################################################################
174 |
175 |
176 |
177 | #
178 |
179 | # relative_MPP = MPP / 0.5
180 | # PATCH_SIZE_revise = np.int(512 / relative_MPP)
181 |
182 | # wsi_w = np.int(wsi_w * (224 / PATCH_SIZE_revise)) + 1
183 | # wsi_h = np.int(wsi_h * (224 / PATCH_SIZE_revise)) + 1
184 | wsi_h=224
185 | wsi_w=224
186 | wsi_reconstruct = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255
187 | wsi_reconstruct_nmp = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255
188 | wsi_reconstruct_mut = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255
189 | wsi_reconstruct_pq = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255
190 | wsi_reconstruct_cdkn = np.ones(shape=(wsi_h, wsi_w, 3), dtype=np.uint8) * 255
191 |
192 | for j in range(len(weight_IDH_wt)):
193 | width_index = int(read_details[j][0])
194 | height_index = int(read_details[j][1])
195 | wsi_reconstruct[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224,
196 | :] = weight_IDH_wt[j]
197 | wsi_reconstruct_nmp[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224,
198 | :] = weight_His_GBM[j]
199 | wsi_reconstruct_mut[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224,
200 | :] = weight_IDH_mut[j]
201 | wsi_reconstruct_pq[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224,
202 | :] = weight_1p19q_codel[j]
203 | wsi_reconstruct_cdkn[height_index * 224:(height_index + 1) * 224, width_index * 224:(width_index + 1) * 224,
204 | :] = weight_CDKN_HOMDEL[j]
205 |
206 | #
207 |
208 | wsi_reconstruct = Image.fromarray(wsi_reconstruct)
209 | wsi_reconstruct = wsi_reconstruct.resize((2000, int(wsi_h / wsi_w * 2000)))
210 | wsi_reconstruct.save('vis_results/set2/' + WSI_name + '_vis_IDHwt.jpg')
211 |
212 | wsi_reconstruct_nmp = Image.fromarray(wsi_reconstruct_nmp)
213 | wsi_reconstruct_nmp = wsi_reconstruct_nmp.resize((2000, int(wsi_h / wsi_w * 2000)))
214 | wsi_reconstruct_nmp.save('vis_results/set2/' + WSI_name + '_vis_nmp.jpg')
215 |
216 | wsi_reconstruct_mut = Image.fromarray(wsi_reconstruct_mut)
217 | wsi_reconstruct_mut = wsi_reconstruct_mut.resize((2000, int(wsi_h / wsi_w * 2000)))
218 | wsi_reconstruct_mut.save('vis_results/set2/' + WSI_name + '_vis_IDHmut.jpg')
219 |
220 | wsi_reconstruct_pq = Image.fromarray(wsi_reconstruct_pq)
221 | wsi_reconstruct_pq = wsi_reconstruct_pq.resize((2000, int(wsi_h / wsi_w * 2000)))
222 | wsi_reconstruct_pq.save('vis_results/set2/' + WSI_name + '_vis_pq.jpg')
223 |
224 | wsi_reconstruct_cdkn = Image.fromarray(wsi_reconstruct_cdkn)
225 | wsi_reconstruct_cdkn = wsi_reconstruct_cdkn.resize((2000, int(wsi_h / wsi_w * 2000)))
226 | wsi_reconstruct_cdkn.save('vis_results/set2/' + WSI_name + '_vis_cdkn.jpg')
227 |
228 |
229 |
230 | print(i)
231 |
232 |
233 |
234 |
235 |
236 | def validation_test_vis(opt,model, dataloader,gpuID):
237 |
238 | CPTAC_label = pd.read_excel(opt['CPTAC_label_path'], header=0)
239 | IvYGAP_label = pd.read_excel(opt['IvYGAP_label_path'], sheet_name='Sheet1', header=0)
240 | TCGA_label = pd.read_excel(opt['TCGA_label_path'], sheet_name='Sheet1', header=0)
241 | combined_labels = pd.concat([TCGA_label, CPTAC_label], ignore_index=True)
242 | excel_label_wsi = combined_labels # keep the DataFrame: .values and the column names are used below
243 |
    excel_wsi = list(excel_label_wsi.values)
    excel_wsi_new = []
    test_bar = tqdm(dataloader)
    bs = 1
    count = 0
    for packs in test_bar:
        img20, img10, label = packs[0]

        if torch.cuda.is_available():
            img20 = img20.cuda(gpuID[0])
            img10 = img10.cuda(gpuID[0])
            # label = label.cuda(gpuID[0])

        # Molecular-marker branch: IDH -> 1p19q -> CDKN
        init_feature_his, init_feature_mark, _, _, _, _ = model[0](img20, img10)  # (BS,2500,1024)
        hidden_states, encoded_IDH = model[1](init_feature_mark)
        hidden_states, encoded_1p19q = model[2](hidden_states)
        encoded_CDKN = model[3](hidden_states)

        results_dict, weight_IDH_wt, weight_IDH_mut, weight_1p19q_codel, weight_CDKN_HOMDEL, \
            encoded_IDH0, encoded_1p19q0, encoded_CDKN0, Mark_output = \
            model[4](encoded_IDH, encoded_1p19q, encoded_CDKN)
        pred_IDH = results_dict['logits_IDH']
        pred_1p19q = results_dict['logits_1p19q']
        pred_CDKN = results_dict['logits_CDKN']

        # Histology branch
        hidden_states, encoded_His = model[5](init_feature_his)
        results_dict, weight_His_GBM, weight_His_GBM_Cls2, weight_His_O, His_output = model[6](encoded_His)
        pred_His_2class = results_dict['logits_His_2class']
        pred_His = results_dict['logits_His']
        his_mark = model[7](His_output.float(), Mark_output.float())
        weight_IDH_wt = weight_IDH_wt.tolist()

        # Convert logits to predicted class indices
        _, pred_His = torch.max(pred_His.data, 1)
        pred_His = pred_His.tolist()
        _, pred_IDH = torch.max(pred_IDH.data, 1)
        pred_IDH = pred_IDH.tolist()
        _, pred_1p19q = torch.max(pred_1p19q.data, 1)
        pred_1p19q = pred_1p19q.tolist()
        _, pred_CDKN = torch.max(pred_CDKN.data, 1)
        pred_CDKN = pred_CDKN.tolist()
        _, pred_His_2class = torch.max(pred_His_2class.data, 1)
        pred_His_2class = pred_His_2class.tolist()
        _, pred_Task = torch.max(his_mark.data, 1)
        pred_Task = pred_Task.tolist()

        # Keep the slide when the His/IDH/1p19q predictions are classes 3/0/0
        # and the CDKN prediction matches column 2 of the label.
        if pred_His[0] == 3 and pred_IDH[0] == 0 and pred_1p19q[0] == 0 and (pred_CDKN[0] == label[:, 2].tolist()[0]):
            excel_wsi_new.append(excel_wsi[count])
        count += 1
    excel_wsi_new = np.array(excel_wsi_new)
    df = pd.DataFrame(excel_wsi_new, columns=list(excel_label_wsi))
    df.to_excel("vis/Test_TPall.xlsx", index=False)


def norm(weight, num_patch):
    """Map per-patch weights to viridis RGB colours in the range 0-255."""
    # Sum the weights over the N_biorepet repeats of the num_patch patches
    # (2500 slots in total).
    N_biorepet = int(2500 / num_patch)
    # Copy first, so the in-place accumulation does not modify the caller's array.
    weight_0 = np.array(weight[0:num_patch], dtype=float)
    weight_color = []
    for j in range(N_biorepet - 1):
        weight_0 += weight[(j + 1) * num_patch:(j + 2) * num_patch]

    # Min-max normalise to [0, 1]; the guard avoids 0/0 for constant weights.
    min_w = np.min(weight_0)
    max_w = np.max(weight_0)
    weight_0 = (weight_0 - min_w) / max(max_w - min_w, 1e-12)

    cmap = cmaps.MPL_viridis  # use the NCL colormap
    newcolors = cmap(np.linspace(0, 1, 256)) * 255
    newcolors = np.trunc(newcolors).astype(int)
    ref_array = np.arange(256) / 256

    for k in range(weight_0.shape[0]):
        # Look up the colour bin closest to the squared normalised weight.
        delta = np.abs(ref_array - weight_0[k] * weight_0[k])
        min_del_index = int(np.argmin(delta))
        weight_color.append(newcolors[min_del_index][0:3])

    return weight_color


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    if seed == 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default='config/mine.yml')
    args = parser.parse_args()
    with open(args.opt) as f:
        opt = yaml.load(f, Loader=SafeLoader)

    setup_seed(opt['seed'])
    sysstr = platform.system()
    opt['logDir'] = os.path.join(opt['logDir'], 'Mine')
    # exist_ok avoids a race between checking for and creating the directory.
    os.makedirs(opt['logDir'], exist_ok=True)
    train(opt)
--------------------------------------------------------------------------------
/transform/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# **************************************
# @Author  : Qiqi Xiao
# @Email   : xiaoqiqi177@gmail.com
# @File    : __init__.py
# **************************************
from .transforms_group import *
--------------------------------------------------------------------------------
/transform/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/transform/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/transform/__pycache__/functional.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/transform/__pycache__/functional.cpython-38.pyc
--------------------------------------------------------------------------------
/transform/__pycache__/transforms_group.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LHY1007/M3C2/ef66915bd3e9671312205c424fe6017915f1ed81/transform/__pycache__/transforms_group.cpython-38.pyc
--------------------------------------------------------------------------------