├── LICENSE ├── Overview.png ├── README.md ├── __init__.py ├── __pycache__ ├── file_and_folder_operations.cpython-36.pyc ├── file_and_folder_operations.cpython-38.pyc ├── util.cpython-36.pyc └── util.cpython-38.pyc ├── base ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── base_dataloader.cpython-36.pyc │ ├── base_dataset.cpython-36.pyc │ └── base_model.cpython-36.pyc ├── base_dataloader.py ├── base_dataset.py ├── base_model.py └── base_trainer.py ├── config.yaml ├── config_mmwhs.yaml ├── configs ├── Config.py ├── Config_mmwhs.py ├── __init__.py └── __pycache__ │ ├── Config.cpython-38.pyc │ ├── Config_mmwhs.cpython-36.pyc │ ├── Config_mmwhs.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-38.pyc ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── data_loader.cpython-36.pyc │ ├── data_loader.cpython-38.pyc │ ├── downsanpling_data.cpython-36.pyc │ └── downsanpling_data.cpython-38.pyc ├── data_loader.py ├── downsanpling_data.py ├── prepare_dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── create_splits.cpython-36.pyc │ │ ├── create_splits.cpython-38.pyc │ │ ├── file_and_folder_operations.cpython-36.pyc │ │ ├── preprocessing.cpython-36.pyc │ │ ├── preprocessing.cpython-38.pyc │ │ ├── rearrange_dir.cpython-36.pyc │ │ └── rearrange_dir.cpython-38.pyc │ ├── create_splits.py │ ├── download_dataset.py │ ├── file_and_folder_operations.py │ ├── preprocessing.py │ └── rearrange_dir.py └── two_dim │ ├── NumpyDataLoader.py │ ├── __init__.py │ ├── __pycache__ │ ├── NumpyDataLoader.cpython-36.pyc │ ├── NumpyDataLoader.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── data_augmentation.cpython-36.pyc │ └── data_augmentation.cpython-38.pyc │ └── data_augmentation.py ├── experiments ├── MixExperiment.py ├── SegExperiment.py ├── __init__.py ├── __pycache__ │ ├── MixExperiment.cpython-36.pyc │ 
├── MixExperiment.cpython-38.pyc │ ├── SegExperiment.cpython-36.pyc │ ├── SegExperiment.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── simclr_experiment.cpython-36.pyc │ └── simclr_experiment.cpython-38.pyc ├── simclr_experiment.py └── simclr_experiment_my.py ├── file_and_folder_operations.py ├── inference.py ├── loss_functions ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── dice_loss.cpython-36.pyc │ ├── dice_loss.cpython-38.pyc │ ├── metrics.cpython-36.pyc │ ├── metrics.cpython-38.pyc │ ├── nt_xent.cpython-36.pyc │ ├── nt_xent.cpython-38.pyc │ ├── supcon_loss.cpython-36.pyc │ └── supcon_loss.cpython-38.pyc ├── dice_loss.py ├── metrics.py ├── nt_xent.py └── supcon_loss.py ├── main_coseg.py ├── main_simclr.py ├── networks ├── RecursiveUNet.py ├── UNET.py ├── __pycache__ │ ├── RecursiveUNet.cpython-36.pyc │ ├── RecursiveUNet.cpython-38.pyc │ ├── unet_con.cpython-36.pyc │ └── unet_con.cpython-38.pyc └── unet_con.py ├── pallete.py ├── requirements.txt ├── result.txt ├── run_coseg.sh ├── run_mix_pipeline.py ├── run_seg.sh ├── run_seg_pipeline.py ├── run_simclr.sh ├── run_supcon.sh ├── supcon_loss.py ├── third-stage.png ├── util.py └── utilities ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── file_and_folder_operations.cpython-36.pyc └── file_and_folder_operations.cpython-38.pyc └── file_and_folder_operations.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/Overview.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semi-supervised Medical Image Segmentation via Hard Positives oriented Contrastive Learning 2 | This is the pytorch implementation of paper "Semi-supervised Medical Image Segmentation via Hard Positives oriented Contrastive Learning". 
3 | 4 | ![workflow of our methods](./Overview.png) 5 | 6 | Email addresses: tangcheng1@stu.scu.edu.cn & perperstudy@gmail.com (Joint First Authors) & wangyanscu@hotmail.com (Corresponding Author) 7 | 8 | ## Setup 9 | ### Environment 10 | ``` 11 | python=3.7.10 12 | torch==1.8.1 13 | torchvision=0.9.1 14 | ``` 15 | ### Dataset 16 | We will take the [Hippocampus dataset](https://drive.google.com/file/d/1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C/view?usp=sharing) as the example to illustrate how to do the preprocessing. Put the image .nii.gz files in the ./data/Hippocampus/imgs folder and the label files in the ./data/Hippocampus/labels folder. 17 | ``` 18 | cd datasets/prepare_dataset 19 | python preprocessing.py 20 | python create_splits.py 21 | ``` 22 | 23 | Afterwards, the images and their respective labels will be combined and saved in a .npy file. The shape of the images will be normalized to match the target size. 24 | 25 | ## Run the code 26 | To run Stage I: Unsupervised Image-level Contrastive Learning, 27 | ``` 28 | bash run_simclr.sh 29 | ``` 30 | To run Stage II: Supervised Pixel-level Contrastive Learning, 31 | ``` 32 | bash run_coseg.sh 33 | ``` 34 | To combine the two pretraining stages above, run run_simclr.sh first; the pretrained model will be saved at save/simclr/Hippocampus/. Then set --pretrained_model_path ${the saved model path} in run_coseg.sh so that the saved pre-trained model is loaded. 35 | 36 | To run Stage III: Semi-supervised Segmentation, 37 | ``` 38 | bash run_seg.sh 39 | ``` 40 | 41 | ## Important Notes 42 | * In all three of the aforementioned scripts, "train_sample" denotes the percentage of labeled data to be utilized.
43 | 44 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/__init__.py -------------------------------------------------------------------------------- /__pycache__/file_and_folder_operations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/__pycache__/file_and_folder_operations.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/file_and_folder_operations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/__pycache__/file_and_folder_operations.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataloader import * 2 | from .base_dataset import * 3 | from .base_model import * 4 | # from .base_trainer import * 5 | 6 | 7 | 
class BaseDataLoader(DataLoader):
    """DataLoader with an optional seeded train/validation split.

    When ``val_split`` is non-zero, the dataset indices are shuffled with a
    fixed seed and divided into a training subset (used by this loader) and a
    validation subset (exposed through :meth:`get_val_loader`).  A custom
    sampler, e.g. a distributed sampler, can be supplied via ``dist_sampler``
    and then replaces the training sampler.

    Args:
        dataset: Any object accepted by ``torch.utils.data.DataLoader``.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle; forced to ``False`` whenever a sampler
            is in use, because ``DataLoader`` rejects ``sampler`` combined
            with ``shuffle=True``.
        num_workers: Number of worker processes.
        val_split: Fraction of the dataset reserved for validation.
        dist_sampler: Optional sampler that overrides the training sampler.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, val_split=0.0, dist_sampler=None):
        self.shuffle = shuffle
        self.dataset = dataset
        self.nbr_examples = len(dataset)
        if val_split:
            self.train_sampler, self.val_sampler = self._split_sampler(val_split)
        else:
            self.train_sampler, self.val_sampler = None, None

        # Compare against None rather than relying on truthiness: a sampler
        # that defines __len__ and happens to be empty is falsy but still a
        # deliberate choice by the caller.
        if dist_sampler is not None:
            self.train_sampler = dist_sampler
            # DataLoader raises ValueError when both `sampler` and
            # `shuffle=True` are given; the sampler decides the order.
            self.shuffle = False

        # Shared by the training loader and the validation loader.
        # NOTE(review): `drop_last=True` therefore also applies to validation,
        # so a trailing partial validation batch is discarded.
        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'num_workers': num_workers,
            'pin_memory': True,
            'drop_last': True
        }

        super(BaseDataLoader, self).__init__(sampler=self.train_sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """Return ``(train_sampler, val_sampler)`` for the given fraction.

        Returns ``(None, None)`` when ``split`` is ``0.0``.  Otherwise the
        first ``int(nbr_examples * split)`` indices of a seeded permutation
        become the validation set and the rest the training set;
        ``self.nbr_examples`` is updated to the training-set size.
        """
        if split == 0.0:
            return None, None

        # Samplers and shuffle=True are mutually exclusive in DataLoader.
        self.shuffle = False

        split_indx = int(self.nbr_examples * split)
        # Fixed seed so the split is the same on every run.  This reseeds
        # NumPy's global generator, which callers may rely on.
        np.random.seed(0)

        indxs = np.arange(self.nbr_examples)
        np.random.shuffle(indxs)
        train_indxs = indxs[split_indx:]
        val_indxs = indxs[:split_indx]
        self.nbr_examples = len(train_indxs)

        train_sampler = SubsetRandomSampler(train_indxs)
        val_sampler = SubsetRandomSampler(val_indxs)
        return train_sampler, val_sampler

    def get_val_loader(self):
        """Return a DataLoader over the validation subset, or ``None`` if no split was made."""
        if self.val_sampler is None:
            return None
        return DataLoader(sampler=self.val_sampler, **self.init_kwargs)
def _rotate(self, image, label):
    """Rotate image and label together by a random whole-degree angle.

    The angle is drawn uniformly from the integers in [-10, 10] and the
    rotation is performed about the image centre without rescaling.  The
    image is resampled with cubic interpolation, the label with
    nearest-neighbour so that no new label values are introduced.

    Returns:
        The rotated ``(image, label)`` pair, both with the input's width
        and height.
    """
    # The unpacking requires a 3-dimensional image array.
    height, width, _ = image.shape
    degrees = random.randint(-10, 10)
    pivot = (width / 2, height / 2)
    transform = cv2.getRotationMatrix2D(pivot, degrees, 1.0)
    out_size = (width, height)
    rotated_image = cv2.warpAffine(image, transform, out_size, flags=cv2.INTER_CUBIC)
    rotated_label = cv2.warpAffine(label, transform, out_size, flags=cv2.INTER_NEAREST)
    return rotated_image, rotated_label
raise ValueError 74 | 75 | h, w, _ = image.shape 76 | pad_h = max(crop_h - h, 0) 77 | pad_w = max(crop_w - w, 0) 78 | pad_kwargs = { 79 | "top": 0, 80 | "bottom": pad_h, 81 | "left": 0, 82 | "right": pad_w, 83 | "borderType": cv2.BORDER_CONSTANT,} 84 | if pad_h > 0 or pad_w > 0: 85 | image = cv2.copyMakeBorder(image, value=self.image_padding, **pad_kwargs) 86 | label = cv2.copyMakeBorder(label, value=self.ignore_index, **pad_kwargs) 87 | 88 | # Cropping 89 | h, w, _ = image.shape 90 | start_h = random.randint(0, h - crop_h) 91 | start_w = random.randint(0, w - crop_w) 92 | end_h = start_h + crop_h 93 | end_w = start_w + crop_w 94 | image = image[start_h:end_h, start_w:end_w] 95 | label = label[start_h:end_h, start_w:end_w] 96 | return image, label 97 | 98 | def _blur(self, image, label): 99 | # Gaussian Blud (sigma between 0 and 1.5) 100 | sigma = random.random() * 1.5 101 | ksize = int(3.3 * sigma) 102 | ksize = ksize + 1 if ksize % 2 == 0 else ksize 103 | image = cv2.GaussianBlur(image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT_101) 104 | return image, label 105 | 106 | def _flip(self, image, label): 107 | # Random H flip 108 | if random.random() > 0.5: 109 | image = np.fliplr(image).copy() 110 | label = np.fliplr(label).copy() 111 | return image, label 112 | 113 | def _resize(self, image, label, bigger_side_to_base_size=True): 114 | if isinstance(self.base_size, int): 115 | h, w, _ = image.shape 116 | if self.scale: 117 | longside = random.randint(int(self.base_size*0.5), int(self.base_size*2.0)) 118 | #longside = random.randint(int(self.base_size*0.5), int(self.base_size*1)) 119 | else: 120 | longside = self.base_size 121 | 122 | if bigger_side_to_base_size: 123 | h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h > w else (int(1.0 * longside * h / w + 0.5), longside) 124 | else: 125 | h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h < w else (int(1.0 * longside * h / w + 0.5), longside) 126 | image = 
def _val_augmentation(self, image, label):
    """Prepare a validation sample: optional resize, then tensor + normalize.

    If ``self.base_size`` is set, image and label are first passed through
    ``self._resize``.  The image is then cast to uint8, wrapped in a PIL
    image, converted with ``self.to_tensor`` and normalized with
    ``self.normalize``.  The label is returned as produced by the resize
    step (or unchanged when no resize happens).
    """
    if self.base_size is not None:
        image, label = self._resize(image, label)

    # Both paths of the original ended in the same conversion; do it once.
    pil_image = Image.fromarray(np.uint8(image))
    tensor_image = self.normalize(self.to_tensor(pil_image))
    return tensor_image, label
class BaseModel(nn.Module):
    """Base class for models: adds a per-class logger and a parameter count.

    Subclasses must override :meth:`forward`.  ``str(model)`` deliberately
    returns only the trainable-parameter count instead of ``nn.Module``'s
    full layer listing.
    """

    def __init__(self):
        super(BaseModel, self).__init__()
        # Logger named after the concrete subclass.
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self):
        """Compute the model output; must be implemented by subclasses."""
        raise NotImplementedError

    def _count_trainable_params(self):
        """Return the total number of elements in parameters that require grad.

        ``numel()`` always yields a Python int, including for 0-dim
        parameters, so the result never degrades to a float.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def summary(self):
        """Log the number of trainable parameters at INFO level."""
        nbr_params = self._count_trainable_params()
        self.logger.info(f'Nbr of trainable parameters: {nbr_params}')

    def __str__(self):
        """Return a short description containing the trainable-parameter count."""
        nbr_params = self._count_trainable_params()
        return f'\nNbr of trainable parameters: {nbr_params}'
__init__(self, model, resume, config, iters_per_epoch, train_logger=None, gpu=None, test=False): 14 | self.model = model 15 | self.config = config 16 | 17 | if gpu == 0: 18 | self.train_logger = train_logger 19 | self.logger = logging.getLogger(self.__class__.__name__) 20 | self.logger.setLevel(logging.INFO) 21 | log_dir = os.path.join(config['trainer']['log_dir'], config['experim_name']) 22 | log_path = os.path.join(log_dir, '{}.log'.format(time.time())) 23 | dir_exists(log_dir) 24 | fh = logging.FileHandler(log_path) 25 | fh.setLevel(logging.INFO) 26 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 27 | fh.setFormatter(formatter) 28 | self.logger.addHandler(fh) 29 | self.logger.info("config: {}".format(self.config)) 30 | self.do_validation = self.config['trainer']['val'] 31 | self.start_epoch = 1 32 | self.improved = False 33 | self.gpu = gpu 34 | torch.cuda.set_device(self.gpu) 35 | 36 | self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model) 37 | 38 | trainable_params = [{'params': list(filter(lambda p:p.requires_grad, self.model.get_other_params()))}, 39 | {'params': list(filter(lambda p:p.requires_grad, self.model.get_backbone_params())), 40 | 'lr': config['optimizer']['args']['lr'] / 10}] 41 | 42 | self.model = torch.nn.parallel.DistributedDataParallel(self.model.cuda(), device_ids=[gpu], find_unused_parameters=True) 43 | 44 | # CONFIGS 45 | cfg_trainer = self.config['trainer'] 46 | self.epochs = cfg_trainer['epochs'] 47 | self.save_period = cfg_trainer['save_period'] 48 | 49 | # OPTIMIZER 50 | self.optimizer = get_instance(torch.optim, 'optimizer', config, trainable_params) # trainable_params should be obtained before wraping the model with DistributedDataParallel 51 | 52 | model_params = sum([i.shape.numel() for i in list(filter(lambda p: p.requires_grad, model.parameters()))]) 53 | opt_params = sum([i.shape.numel() for j in self.optimizer.param_groups for i in j['params']]) 54 | 55 | assert opt_params 
def _get_available_devices(self, n_gpu):
    """Clamp the requested GPU count to what the machine actually offers.

    Args:
        n_gpu: Number of GPUs the caller would like to use.

    Returns:
        A ``(device, available_gpus)`` tuple: ``device`` is ``cuda:0`` when at
        least one GPU will be used and ``cpu`` otherwise; ``available_gpus``
        is ``list(range(k))`` for the clamped count ``k``.
    """
    detected = torch.cuda.device_count()

    if detected == 0:
        self.logger.warning('No GPUs detected, using the CPU')
        usable = 0
    elif n_gpu > detected:
        self.logger.warning(f'Nbr of GPU requested is {n_gpu} but only {detected} are available')
        usable = detected
    else:
        usable = n_gpu

    # The log line reports the count *after* clamping, like the original.
    self.logger.info(f'Detected GPUs: {detected} Requested: {usable}')

    device_name = 'cuda:0' if usable > 0 else 'cpu'
    return torch.device(device_name), list(range(usable))
108 | for k, v in results.items(): 109 | self.logger.info(f' {str(k):15s}: {v}') 110 | return 111 | 112 | for epoch in range(self.start_epoch, self.epochs+1): 113 | self._train_epoch(epoch) 114 | if self.do_validation and epoch % self.config['trainer']['val_per_epochs'] == 0: 115 | results = self._valid_epoch(epoch) 116 | if self.gpu == 0: 117 | self.logger.info('\n\n Epoch {}:'.format(epoch)) 118 | for k, v in results.items(): 119 | self.logger.info(f' {str(k):15s}: {v}') 120 | 121 | log = {'epoch' : epoch, **results} 122 | if self.gpu == 0: 123 | if self.train_logger is not None: 124 | self.train_logger.add_entry(log) 125 | 126 | # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL) 127 | if self.mnt_mode != 'off' and epoch % self.config['trainer']['val_per_epochs'] == 0: 128 | try: 129 | if self.mnt_mode == 'min': self.improved = (log[self.mnt_metric] < self.mnt_best) 130 | else: self.improved = (log[self.mnt_metric] > self.mnt_best) 131 | except KeyError: 132 | self.logger.warning(f'The metrics being tracked ({self.mnt_metric}) has not been calculated. 
def _save_checkpoint(self, epoch, save_best=False):
    """Write the current training state to ``<checkpoint_dir>/checkpoint.pth``.

    Args:
        epoch: Epoch number stored in the checkpoint.
        save_best: When true, the same state is also written to
            ``best_model.pth`` in the same directory.
    """
    state = {
        'arch': type(self.model).__name__,
        'epoch': epoch,
        'state_dict': self.model.state_dict(),
        'monitor_best': self.mnt_best,
        'config': self.config
    }

    filename = os.path.join(self.checkpoint_dir, 'checkpoint.pth')
    self.logger.info(f'\nSaving a checkpoint: {filename} ...')
    torch.save(state, filename)

    if save_best:
        filename = os.path.join(self.checkpoint_dir, 'best_model.pth')
        torch.save(state, filename)
        self.logger.info("Saving current best: best_model.pth {} at {} epoch".format(self.mnt_best, epoch))


def _resume_checkpoint(self, resume_path):
    """Restore epoch counter, best monitored value and model weights.

    Sets ``start_epoch`` to the saved epoch + 1, restores ``mnt_best`` and
    resets ``not_improved_count`` to 0.

    Args:
        resume_path: Path of a checkpoint written by ``_save_checkpoint``.
    """
    if self.gpu == 0:
        self.logger.info(f'Loading checkpoint : {resume_path}')

    # Deserialize onto the CPU first.  Without ``map_location`` every process
    # would materialize the tensors on the device they were saved from
    # (all ranks piling onto the same GPU), and loading would fail outright
    # on a host that lacks that device.  ``load_state_dict`` copies the
    # values into the model's own parameters, wherever those live.
    checkpoint = torch.load(resume_path, map_location='cpu')
    self.start_epoch = checkpoint['epoch'] + 1
    self.mnt_best = checkpoint['monitor_best']
    self.not_improved_count = 0

    try:
        self.model.load_state_dict(checkpoint['state_dict'])
    except Exception as e:
        # Deliberate best effort: report the mismatch, then load whatever
        # keys do match.
        print(f'Error when loading: {e}')
        self.model.load_state_dict(checkpoint['state_dict'], strict=False)

    if self.gpu == 0:
        if "logger" in checkpoint.keys():
            self.train_logger = checkpoint['logger']
        self.logger.info(f'Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded')


def _train_epoch(self, epoch):
    """Run one training epoch; must be overridden by subclasses."""
    raise NotImplementedError


def _valid_epoch(self, epoch):
    """Run validation; must be overridden by subclasses.

    ``train`` iterates the return value with ``.items()``, so an
    implementation has to return a mapping of metric name to value.
    """
    raise NotImplementedError


def _eval_metrics(self, output, target):
    """Compute metrics for a prediction; must be overridden by subclasses."""
    raise NotImplementedError
def get_config():
    """Return the trixi ``Config`` used for the Hippocampus experiments.

    All dataset paths are derived from ``./data`` (made absolute); experiment
    output goes to ``./output_experiment``.
    """
    # Set your own path, if needed: this is where the downloaded dataset lives.
    data_root_dir = os.path.abspath('data')

    settings = {
        'stage': "mix_train",
        'update_from_argv': True,

        # Train parameters
        'num_classes': 3,
        'in_channels': 1,
        'batch_size': 8,
        'patch_size': 64,
        'n_epochs': 120,
        'learning_rate': 0.00001,
        'fold': 1,              # 'splits.pkl' may hold several folds; pick one.
        'train_sample': 0.4,    # sample rate for the training set

        # 'cuda' is the default CUDA device; 'cpu' also works.
        'device': "cuda",

        # Logging parameters
        'name': 'Unet_hippo',
        'plot_freq': 10,        # how often stuff is shown in visdom
        'append_rnd_string': False,
        'start_visdom': True,

        # Whether the UNet applies instance normalization in the contracting path.
        'do_instancenorm': True,
        'do_load_checkpoint': False,
        'load_model': False,
        'checkpoint_dir': '',
        'saved_model_path': None,
        'freeze': False,

        # Adapt to your own path, if needed.
        'google_drive_id': '1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
        'dataset_name': 'Hippocampus',
        'img_size': 64,
        'base_dir': os.path.abspath('output_experiment'),   # experiment output

        'data_root_dir': data_root_dir,
        # Training/validation data and test data share one folder.
        'data_dir': os.path.join(data_root_dir, 'Hippocampus/preprocessed'),
        'data_test_dir': os.path.join(data_root_dir, 'Hippocampus/preprocessed'),

        # Folder that contains 'splits.pkl'.
        'split_dir': os.path.join(data_root_dir, 'Hippocampus'),
    }

    return Config(**settings)
40 | google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C', 41 | dataset_name='mmwhs', 42 | img_size=160, 43 | base_dir=os.path.abspath('output_experiment'), # Where to log the output of the experiment. 44 | 45 | data_root_dir=data_root_dir, # The path where the downloaded dataset is stored. 46 | data_dir=os.path.join(data_root_dir, 'mmwhs/preprocessed'), # This is where your training and validation data is stored 47 | data_test_dir=os.path.join(data_root_dir, 'mmwhs/preprocessed'), # This is where your test data is stored 48 | 49 | split_dir=os.path.join(data_root_dir, 'mmwhs'), # This is where the 'splits.pkl' file is located, that holds your splits. 50 | # test_result_dir=os.path.join(data_root_dir, 'mmwhs/test_result') 51 | 52 | ) 53 | 54 | return c 55 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/configs/__init__.py -------------------------------------------------------------------------------- /configs/__pycache__/Config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/configs/__pycache__/Config.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/Config_mmwhs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/configs/__pycache__/Config_mmwhs.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/Config_mmwhs.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/configs/__pycache__/Config_mmwhs.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/configs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/configs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_loader.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__pycache__/data_loader.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/downsanpling_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__pycache__/downsanpling_data.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/downsanpling_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/__pycache__/downsanpling_data.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/data_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | from trixi.util.pytorchutils import set_seed 3 | 4 | 5 | class WrappedDataset(Dataset): 6 | def __init__(self, dataset, transform): 7 | self.transform = transform 8 | self.dataset = dataset 9 | 10 | self.is_indexable = False 11 | if hasattr(self.dataset, "__getitem__") and not (hasattr(self.dataset, "use_next") and self.dataset.use_next is True): 12 | self.is_indexable = True 13 | 14 | def __getitem__(self, index): 15 | 16 | if not self.is_indexable: 17 | 
def reshape_array(numpy_array, axis=1):
    """Cut a 4-D volume into a batch of 2-D slices along one spatial axis.

    The chosen axis becomes the leading (batch) axis, so an input of shape
    ``(C, d1, d2, d3)`` sliced along ``axis=3`` yields ``(d3, C, d1, d2)``
    with ``out[k] == numpy_array[:, :, :, k]``.  For the 2-channel cubic
    volumes the previous implementation supported, the result is identical;
    other channel counts and non-cubic volumes now work as well.

    Args:
        numpy_array: 4-D array whose first axis is the channel axis.
        axis: Spatial axis to slice along: 1, 2 or 3.

    Returns:
        A new array (never a view of the input) with the slices stacked
        along axis 0.

    Raises:
        ValueError: If ``axis`` is not 1, 2 or 3.  The old code fell through
            and returned ``None``, which only failed later at the caller.
    """
    if axis not in (1, 2, 3):
        raise ValueError('axis must be 1, 2 or 3, got {!r}'.format(axis))
    # One axis move replaces the per-slice reshape + repeated concatenate,
    # which was quadratic in the number of slices.
    return np.moveaxis(numpy_array, axis, 0).copy()


def downsampling_image(data_dir, output_dir):
    """Write a bilinearly down-scaled copy of every ``.npy`` volume.

    Each file in ``data_dir`` is sliced along its last axis with
    ``reshape_array``, halved in resolution ``ceil(log2(size)) - 5`` times,
    and saved under the same base name in ``output_dir``.  Files that already
    exist in ``output_dir`` are skipped.

    Args:
        data_dir: Folder containing the source ``.npy`` volumes.
        output_dir: Destination folder; created if missing.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    for file in subfiles(data_dir, suffix=".npy", join=False):
        save_path = os.path.join(output_dir, file.split('.')[0] + '.npy')
        if os.path.exists(save_path):
            continue

        numpy_array = reshape_array(np.load(os.path.join(data_dir, file)), axis=3)
        size = numpy_array.shape[3]
        # log2 is exact for powers of two, whereas math.log(x, 2) can land a
        # hair above the integer and make ceil() overshoot by one.
        num_of_pooling = math.ceil(math.log2(size)) - 4
        # Experimental setting kept from the original: one halving fewer.
        num_of_pooling = num_of_pooling - 1

        slice_data = torch.from_numpy(numpy_array).to(device)
        # NOTE(review): every channel is interpolated bilinearly; if one of
        # them is a label map this blends class ids -- confirm intent.
        for _ in range(num_of_pooling):
            slice_data = F.interpolate(slice_data, scale_factor=1/2, mode='bilinear')

        np.save(save_path, slice_data.cpu().numpy())
        print(file)
/datasets/prepare_dataset/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/prepare_dataset/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/prepare_dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/prepare_dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/prepare_dataset/__pycache__/create_splits.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/prepare_dataset/__pycache__/create_splits.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/prepare_dataset/__pycache__/create_splits.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/prepare_dataset/__pycache__/create_splits.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/prepare_dataset/__pycache__/file_and_folder_operations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/prepare_dataset/__pycache__/file_and_folder_operations.cpython-36.pyc -------------------------------------------------------------------------------- 
def _random_partition(items, sizes):
    """Shuffle *items* and cut the result into consecutive chunks of *sizes*."""
    order = np.random.permutation(len(items))
    shuffled = [items[i] for i in order]    # elements stay plain Python objects
    parts = []
    start = 0
    for size in sizes:
        parts.append(shuffled[start:start + size])
        start += size
    return parts


def create_splits(output_dir, image_dir, num_splits=5):
    """Create random train/val/test splits (about 3:1:1) and pickle them.

    Writes ``<output_dir>/splits.pkl`` containing a list of ``num_splits``
    dicts with the keys ``'train'``, ``'val'`` and ``'test'``; each value is
    a list of ``.npy`` file names found in ``image_dir``.

    Train and validation take 60% and 20% of the files (rounded down); the
    test set takes everything that is left, so no file is silently dropped
    when the count is not a multiple of five.

    Args:
        output_dir: Folder that receives ``splits.pkl``.
        image_dir: Folder that is scanned for ``.npy`` files.
        num_splits: Number of independent random splits to generate.
    """
    npy_files = subfiles(image_dir, suffix=".npy", join=False)
    total = len(npy_files)
    trainset_size = total * 60 // 100
    valset_size = total * 20 // 100
    testset_size = total - trainset_size - valset_size

    splits = []
    for _ in range(num_splits):
        trainset, valset, testset = _random_partition(
            npy_files, (trainset_size, valset_size, testset_size))
        splits.append({'train': trainset, 'val': valset, 'test': testset})

    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)


# some dataset may include an independent test set
def create_splits_1(output_dir, image_dir, test_dir, num_splits=5):
    """Create random 3:1 train/val splits with a fixed, separate test set.

    Same output format as ``create_splits``; ``'test'`` is the list of
    ``.npy`` file names found in ``test_dir`` and is identical in every split.

    Args:
        output_dir: Folder that receives ``splits.pkl``.
        image_dir: Folder scanned for training/validation ``.npy`` files.
        test_dir: Folder scanned for test ``.npy`` files.
        num_splits: Number of independent random splits to generate.
    """
    npy_files = subfiles(image_dir, suffix=".npy", join=False)
    test_files = subfiles(test_dir, suffix=".npy", join=False)

    trainset_size = len(npy_files) * 3 // 4
    valset_size = len(npy_files) - trainset_size

    splits = []
    for _ in range(num_splits):
        trainset, valset = _random_partition(npy_files, (trainset_size, valset_size))
        splits.append({'train': trainset, 'val': valset, 'test': test_files})

    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)
image_dir = "../../data/mmwhs/preprocessed" 80 | root_dir = "../../data/Hippocampus" 81 | image_dir = "../../data/Hippocampus/preprocessed" 82 | create_splits(root_dir, image_dir) 83 | 84 | 85 | -------------------------------------------------------------------------------- /datasets/prepare_dataset/download_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import exists 3 | import tarfile 4 | 5 | from google_drive_downloader import GoogleDriveDownloader as gdd 6 | 7 | dataset_id = { 8 | "Task01_BrainTumour": '1A2IU8Sgea1h3fYLpYtFb2v7NYdMjvEhU', 9 | "Task02_Heart": '1wEB2I6S6tQBVEPxir8cA5kFB8gTQadYY', 10 | "Task03_Liver": '1jyVGUGyxKBXV6_9ivuZapQS8eUJXCIpu', 11 | "Task04_Hippocampus": '1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C', 12 | "Task05_Prostate": "1Ff7c21UksxyT4JfETjaarmuKEjdqe1-a", 13 | "Task07_Pancreas": "1YZQFSonulXuagMIfbJkZeTFJ6qEUuUxL", 14 | "Task10_Colon": "1m7tMpE9qEcQGQjL_BdMD-Mvgmc44hG1Y" 15 | } 16 | 17 | def download_dataset(dest_path, dataset): 18 | tar_path = os.path.join(dest_path, dataset) + '.tar' 19 | id = dataset_id[dataset] 20 | gdd.download_file_from_google_drive(file_id=id, 21 | dest_path=tar_path, overwrite=False, 22 | unzip=False) 23 | 24 | if not exists(os.path.join(dest_path, dataset)): 25 | print('Extracting data [STARTED]') 26 | tar = tarfile.open(tar_path) 27 | tar.extractall(dest_path) 28 | print('Extracting data [DONE]') 29 | else: 30 | print('Data already downloaded. 
def _list_entries(folder, keep, join, prefix, suffix, sort):
    """List names in *folder* accepted by *keep* and the prefix/suffix filters."""
    names = [name for name in os.listdir(folder)
             if keep(os.path.join(folder, name))
             and (prefix is None or name.startswith(prefix))
             and (suffix is None or name.endswith(suffix))]
    result = [os.path.join(folder, name) for name in names] if join else names
    if sort:
        result.sort()
    return result


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """Return the sub-directories of *folder*.

    Args:
        folder: Directory to scan (not recursive).
        join: Return ``folder``-joined paths when true, bare names otherwise.
        prefix: Keep only names starting with this string, if given.
        suffix: Keep only names ending with this string, if given.
        sort: Sort the result when true; otherwise keep ``os.listdir`` order.
    """
    return _list_entries(folder, os.path.isdir, join, prefix, suffix, sort)


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """Return the regular files in *folder*.

    Args:
        folder: Directory to scan (not recursive).
        join: Return ``folder``-joined paths when true, bare names otherwise.
        prefix: Keep only names starting with this string, if given.
        suffix: Keep only names ending with this string, if given.
        sort: Sort the result when true; otherwise keep ``os.listdir`` order.
    """
    return _list_entries(folder, os.path.isfile, join, prefix, suffix, sort)


def maybe_mkdir_p(directory):
    """Create *directory* and any missing parents; do nothing if it exists.

    The previous implementation split the path on ``"/"`` and dropped the
    first component, which is only correct for absolute POSIX paths: a
    relative path such as ``"a/b"`` ended up creating ``/b``.  Relative
    paths are now created relative to the current working directory.

    Args:
        directory: Path to create.  An empty string is a no-op, as before.
    """
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
def preprocess_data(root_dir):
    """Convert every ``.nii.gz`` image/label pair into one stacked ``.npy`` file.

    Images are read from ``<root_dir>/imgs`` and labels from
    ``<root_dir>/labels`` (same file name with ``image`` replaced by
    ``label``).  Each image is min-max normalized, stacked with its label and
    saved to ``<root_dir>/orig`` with the last volume axis moved to the
    front.  Outputs that already exist are skipped.

    Args:
        root_dir: Dataset folder containing ``imgs`` and ``labels``.
    """
    # NOTE(review): rearrange_dir.py creates 'images', this reads 'imgs' --
    # confirm which folder name is intended.
    image_dir = os.path.join(root_dir, 'imgs')
    label_dir = os.path.join(root_dir, 'labels')
    output_dir = os.path.join(root_dir, 'orig')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    total = 0
    nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)

    for f in nii_files:
        if f.startswith("."):
            # Hidden files matching *.nii.gz are deleted from the input folder.
            os.remove(os.path.join(image_dir, f))
            continue

        target_path = os.path.join(output_dir, f.split('.')[0] + '.npy')
        if os.path.exists(target_path):
            continue

        image, _ = load(os.path.join(image_dir, f))
        label, _ = load(os.path.join(label_dir, f.replace('image', 'label')))

        # normalize images; a constant image would otherwise divide by zero
        # and fill the volume with NaN.
        value_range = image.max() - image.min()
        if value_range > 0:
            image = (image - image.min()) / value_range
        else:
            image = np.zeros(image.shape, dtype=float)

        print(label.max())
        print(label.min())
        total += image.shape[2]

        print(image.shape, label.shape)

        result = np.stack((image, label)).transpose((3, 0, 1, 2))
        print(result.shape)

        np.save(target_path, result)
        print(f)

    print(total)


def _resize_slices(slices, size, mode):
    """Resize an (N, H, W) stack of slices to (N, *size) with F.interpolate."""
    batch = torch.from_numpy(slices)[None]      # -> (1, N, H, W)
    resized = F.interpolate(batch, size=size, mode=mode)
    # Index instead of squeeze() so a single-slice stack keeps its N axis.
    return resized[0].cpu().numpy()


def reshape_2d_data(input_dir, output_dir, target_size=(160, 160)):
    """Resize every slice of each ``.npy`` volume to ``target_size``.

    Input arrays are indexed as ``data[:, 0]`` (image) and ``data[:, 1]``
    (label); the output keeps that layout with the two trailing axes resized.
    Images use bilinear interpolation.  Labels use nearest-neighbour, because
    blending neighbouring class ids produces values that are not valid
    labels.  Outputs that already exist are skipped, as are directory
    entries that are not ``.npy`` files.

    Args:
        input_dir: Folder with the source ``.npy`` files.
        output_dir: Destination folder; created if missing.
        target_size: ``(height, width)`` of the resized slices.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    for f in os.listdir(input_dir):
        if not f.endswith('.npy'):
            continue
        target_dir = os.path.join(output_dir, f)
        if os.path.exists(target_dir):
            continue

        data = np.load(os.path.join(input_dir, f))
        new_image = _resize_slices(data[:, 0], target_size, "bilinear")
        new_label = _resize_slices(data[:, 1], target_size, "nearest")

        new_data = np.stack((new_image, new_label), axis=1)

        print(new_data.shape)
        np.save(target_dir, new_data)


def reshape_three_dim_data(input_dir, output_dir):
    """Resize the image/label pair of each ``.npy`` file to 160 x 160.

    NOTE(review): the 2-D interpolation used here requires ``data[:, 0]`` and
    ``data[:, 1]`` to be 2-D, i.e. 3-D input arrays -- confirm the expected
    layout against the caller.  The output has shape ``(2, 160, 160)``.
    Labels are resized with nearest-neighbour for the same reason as in
    ``reshape_2d_data``.

    Args:
        input_dir: Folder with the source ``.npy`` files.
        output_dir: Destination folder; created if missing.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    for f in os.listdir(input_dir):
        target_dir = os.path.join(output_dir, f)
        if os.path.exists(target_dir):
            continue

        data = np.load(os.path.join(input_dir, f))
        image_tensor = torch.from_numpy(data[:, 0])
        label_tensor = torch.from_numpy(data[:, 1])

        new_image = F.interpolate(image_tensor[None, None], size=(160, 160), mode="bilinear")
        new_image = new_image.squeeze().cpu().numpy()

        new_label = F.interpolate(label_tensor[None, None], size=(160, 160), mode="nearest")
        new_label = new_label.squeeze().cpu().numpy()

        new_data = np.concatenate((new_image[None], new_label[None]))

        print(new_data.shape)
        np.save(target_dir, new_data)
def rearrange_dir(root_dir):
    """Sort the ``*.nii.gz`` files in ``root_dir`` into ``images``/``labels``.

    Files whose name contains ``image`` are moved to ``<root_dir>/images`` and
    files whose name contains ``label`` to ``<root_dir>/labels`` (checked in
    that order).  Both target directories are created when missing.  Files
    matching neither pattern stay where they are; previously they were still
    reported as "moving", which was misleading.

    NOTE(review): ``preprocess_data`` in preprocessing.py reads images from a
    directory called ``imgs`` whereas this function writes to ``images`` --
    confirm which name the pipeline expects.

    Args:
        root_dir: Directory containing the unsorted ``.nii.gz`` files.
    """
    image_dir = os.path.join(root_dir, 'images')
    label_dir = os.path.join(root_dir, 'labels')

    for directory in (image_dir, label_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
            print('Created' + directory + '...')

    for name in subfiles(root_dir, suffix=".nii.gz", join=False):
        if 'image' in name:
            destination = image_dir
        elif 'label' in name:
            destination = label_dir
        else:
            print('skipping' + name + '...')
            continue

        shutil.move(os.path.join(root_dir, name), os.path.join(destination, name))
        print('moving' + name + '...')

    # Only report success when nothing is left behind in the root directory.
    if not subfiles(root_dir, suffix=".nii.gz", join=False):
        print("rearrange directory finished")
def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
    """Index every slice of the ``.npy`` volumes found below ``base_dir``.

    Args:
        base_dir: Directory that is walked recursively.
        pattern: ``fnmatch`` pattern a file name has to match.
        slice_offset: Number of slices ignored at the start and at the end of
            each volume.
        keys: Optional collection of file names to keep.  ``None`` keeps every
            matching file (previously ``None`` silently selected nothing).

    Returns:
        A tuple ``(files, files_len, slices)`` where ``files`` holds the file
        paths, ``files_len`` the size of the first axis of each file and
        ``slices`` a list of ``(file_index, slice_index)`` pairs.
        ``file_index`` refers to a position in ``files``.
    """
    fls = []
    files_len = []
    slices_ax = []

    for root, dirs, files in os.walk(base_dir):
        for filename in sorted(fnmatch.filter(files, pattern)):
            if keys is not None and filename not in keys:
                continue

            npy_file = os.path.join(root, filename)
            # Only the shape is needed, so a read-only memory map is enough.
            numpy_array = np.load(npy_file, mmap_mode="r")

            # The index must be the position in ``fls``.  A counter restarted
            # for every walked directory would point at the wrong file as soon
            # as the data is spread over several directories.
            file_idx = len(fls)
            fls.append(npy_file)
            files_len.append(numpy_array.shape[0])

            slices_ax.extend((file_idx, j) for j in range(slice_offset, files_len[-1] - slice_offset))

    return fls, files_len, slices_ax,


class NumpyDataSet(object):
    """Iterable pairing a :class:`NumpyDataLoader` with the augmentation pipeline.

    The loader yields raw batches; ``get_transforms(mode, target_size)``
    provides the transforms and ``MultiThreadedDataLoader`` applies them.
    Iterating the object reshuffles the slice order (if ``do_reshuffle``) and
    renews the augmenter.
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=0,
                 num_cached_per_queue=8 * 4, target_size=128,
                 file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None):
        data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches,
                                      seed=seed, file_pattern=file_pattern,
                                      input_slice=input_slice, label_slice=label_slice, keys=keys)

        self.data_loader = data_loader
        self.batch_size = batch_size
        self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1

        self.transforms = get_transforms(mode=mode, target_size=target_size)
        self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
                                                 num_cached_per_queue=num_cached_per_queue, seeds=seed,
                                                 shuffle=do_reshuffle)
        self.augmenter.restart()

    def __len__(self):
        """Return the number of batches of the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Reshuffle (if enabled), renew the augmenter and return it."""
        if self.do_reshuffle:
            self.data_loader.reshuffle()
        self.augmenter.renew()
        return self.augmenter

    def __next__(self):
        """Return the next augmented batch."""
        return next(self.augmenter)


class NumpyDataLoader():
    """Batch loader that reads single slices from memory-mapped ``.npy`` files.

    A batch is a dict with ``data`` (array), ``fnames`` (list of file paths),
    ``slice_idxs`` (list of slice indices) and, when ``label_slice`` is not
    ``None``, ``seg`` (label array).
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):
        # Imported here so that the module can be imported without torchvision
        # and so that the name is bound before it is used.
        from torchvision import transforms

        self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern,
                                                              slice_offset=0, keys=keys)

        self.batch_size = batch_size
        # TODO: data augmentation -- this pipeline is built but not applied
        # anywhere in this class.
        self.train_transform = transforms.Compose([
            transforms.ToPILImage(),
            RandomGaussianBlur(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])
        self.use_next = False

        # One entry per (file, slice) pair; reshuffle() permutes this list.
        self.slice_idxs = list(range(0, len(self.slices)))

        self.data_len = len(self.slices)

        self.num_batches = min((self.data_len // self.batch_size) + 10, num_batches)

        if isinstance(label_slice, int):
            label_slice = (label_slice,)
        self.input_slice = input_slice
        self.label_slice = label_slice

        self.np_data = np.asarray(self.slices)

    def reshuffle(self):
        """Shuffle the order in which slices are served, in place."""
        print("Reshuffle...")
        random.shuffle(self.slice_idxs)
        print("Initializing... this might take a while...")

    def generate_train_batch(self):
        """Return a batch of ``batch_size`` slices sampled without replacement.

        The samples are drawn from ``self.slices``; the attribute read here
        before (``self._data``) is never assigned by this class.
        """
        open_arr = random.sample(self.slices, self.batch_size)
        return self.get_data_from_array(open_arr)

    def __len__(self):
        """Return the number of full batches, capped by ``num_batches``."""
        n_items = min(self.data_len // self.batch_size, self.num_batches)
        return n_items

    def __getitem__(self, item):
        """Return batch number ``item``.

        Raises:
            StopIteration: if ``item`` is outside ``[0, len(self))``.  The
                exception type is kept from the previous implementation.
        """
        if item < 0 or item >= len(self):
            raise StopIteration()

        # item < len(self) guarantees that the whole batch lies inside the data.
        start_idx = item * self.batch_size
        stop_idx = start_idx + self.batch_size

        idxs = self.slice_idxs[start_idx:stop_idx]
        open_arr = self.np_data[idxs]  # rows of (file_index, slice_index)

        return self.get_data_from_array(open_arr)

    def get_data_from_array(self, open_array):
        """Load the slices named by ``open_array`` and assemble a batch dict.

        Args:
            open_array: Iterable of ``(file_index, slice_index)`` pairs.

        Returns:
            Dict with ``data``, ``fnames`` and ``slice_idxs`` and, when
            ``label_slice`` is not ``None``, ``seg``.  Label values greater
            than 7 are set to 0.
        """
        data = []
        fnames = []
        slice_idxs = []
        labels = []

        # Memory maps opened for this batch, so that a file contributing
        # several slices is opened only once.
        opened = {}

        for file_idx, slice_idx in open_array:
            fn_name = self.files[file_idx]

            numpy_array = opened.get(fn_name)
            if numpy_array is None:
                numpy_array = opened[fn_name] = np.load(fn_name, mmap_mode="r")

            numpy_slice = numpy_array[slice_idx]
            # Indexing with None keeps a leading axis of length one.
            data.append(numpy_slice[None, self.input_slice[0]])

            if self.label_slice is not None:
                labels.append(numpy_slice[None, self.label_slice[0]])

            fnames.append(fn_name)
            slice_idxs.append(slice_idx)

        labels = np.asarray(labels)
        labels[labels > 7] = 0

        ret_dict = {'data': np.asarray(data), 'fnames': fnames,
                    'slice_idxs': slice_idxs}
        if self.label_slice is not None:
            ret_dict['seg'] = labels

        return ret_dict


class RandomGaussianBlur(object):
    """Blur an image with probability 0.5 and return it as a PIL image."""

    def __init__(self, radius=5):
        # Used for both dimensions of the Gaussian kernel size.
        self.radius = radius

    def __call__(self, image):
        # Local imports: cv2 and torchvision are only needed by this transform.
        import cv2
        from torchvision import transforms

        image = np.asarray(image)
        if random.random() < 0.5:
            image = cv2.GaussianBlur(image, (self.radius, self.radius), 0)
        image = transforms.functional.to_pil_image(image)
        return image
-------------------------------------------------------------------------------- /datasets/two_dim/__pycache__/NumpyDataLoader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/two_dim/__pycache__/NumpyDataLoader.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/two_dim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/two_dim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/two_dim/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/two_dim/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/two_dim/__pycache__/data_augmentation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/two_dim/__pycache__/data_augmentation.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/two_dim/__pycache__/data_augmentation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/datasets/two_dim/__pycache__/data_augmentation.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/two_dim/data_augmentation.py: 
def get_transforms(mode="train", target_size=128):
    """Build the batchgenerators transform pipeline for ``mode``.

    * ``"train"``: resize, mirror and spatial augmentation, then tensor
      conversion.
    * ``"val"`` / ``"test"``: resize, then tensor conversion.
    * ``"supcon"`` / ``"simclr"``: intensity (and, for ``"simclr"``, spatial)
      augmentation wrapped in :class:`TwoCropTransform`.
    * any other value: tensor conversion only.

    Args:
        mode: Name of the pipeline to build.
        target_size: Edge length used by the resize and spatial transforms.

    Returns:
        A ``Compose`` object, or a ``TwoCropTransform`` around one for the
        contrastive modes.
    """
    if mode == "supcon":
        contrastive_steps = [
            BrightnessTransform(mu=1, sigma=1, p_per_sample=0.9),
            GammaTransform(p_per_sample=0.5),
            GaussianNoiseTransform(p_per_sample=0.5),
            GaussianBlurTransform(p_per_sample=0.9),
            NumpyToTensor(),
        ]
        return TwoCropTransform(Compose(contrastive_steps))

    if mode == "simclr":
        contrastive_steps = [
            BrightnessTransform(mu=1, sigma=1, p_per_sample=0.5),
            GammaTransform(p_per_sample=0.5),
            GaussianNoiseTransform(p_per_sample=0.5),
            SpatialTransform(patch_size=(target_size, target_size), random_crop=True,
                             do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                             do_rotation=True, p_rot_per_sample=0.5,
                             angle_z=(0, 2 * np.pi),
                             scale=(0.7, 1.25), p_scale_per_sample=0.5,
                             border_mode_data="nearest", border_mode_seg="nearest"),
            NumpyToTensor(),
        ]
        return TwoCropTransform(Compose(contrastive_steps))

    if mode == "train":
        steps = [
            ResizeTransform(target_size=(target_size, target_size), order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(patch_size=(target_size, target_size), random_crop=False,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                             do_rotation=True, p_rot_per_sample=0.5,
                             angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.5, 1.9), p_scale_per_sample=0.5,
                             border_mode_data="nearest", border_mode_seg="nearest"),
        ]
    elif mode == "val" or mode == "test":
        # Evaluation pipelines only resize.
        steps = [ResizeTransform(target_size=target_size, order=1)]
    else:
        steps = []

    steps.append(NumpyToTensor())
    return Compose(steps)


class TwoCropTransform:
    """Apply the same transform twice to obtain two views of one input."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, **x):
        first_view = self.transform(**x)
        second_view = self.transform(**x)
        return [first_view, second_view]
def save_images(segs, names, root, mode, iter):
    """Write every entry of ``segs`` to ``root`` as ``<iter><mode><name>.png``.

    ``segs`` and ``names`` are iterated in lockstep; each element of ``segs``
    is converted with ``.cpu().numpy()`` before being passed to ``imsave``.
    """
    for seg, name in zip(segs, names):
        save_path = os.path.join(root, str(iter) + mode + name + '.png')
        imsave(str(save_path), seg.cpu().numpy())


class MixExperiment(PytorchExperiment):
    """Segmentation experiment that trains a ``SupConUnet`` with mixup.

    It follows the ``PytorchExperiment`` life cycle::

        setup()
        (--> automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()

    ``config.stage`` selects the training variant: ``"train"`` feeds one
    mixed batch through the network, ``"mix_train"`` feeds two independently
    mixed versions of the same batch.
    """

    def set_loader(self, opt):
        """Return a two-view (``mode="supcon"``) training loader built from ``opt``.

        When ``opt.train_sample == 1`` the train, val and test keys of the fold
        are all used; otherwise only the leading ``train_sample`` fraction of
        the train keys.
        """
        # NOTE: pickle must only be used on trusted split files.
        with open(os.path.join(opt.split_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold = splits[opt.fold]
        if opt.train_sample == 1:
            tr_keys = fold['train'] + fold['val'] + fold['test']
        else:
            tr_keys = fold['train']
            tr_size = int(len(tr_keys) * opt.train_sample)
            tr_keys = tr_keys[0:tr_size]

        train_loader = NumpyDataSet(opt.data_dir, target_size=160, batch_size=opt.batch_size,
                                    keys=tr_keys, do_reshuffle=True, mode="supcon")

        return train_loader

    def setup(self):
        """Create data loaders, model, losses, optimizer and summary writer."""
        # NOTE: pickle must only be used on trusted split files.
        with open(os.path.join(self.config.split_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold = splits[self.config.fold]
        tr_keys = fold['train']
        tr_size = int(len(tr_keys) * self.config.train_sample)
        tr_keys = tr_keys[0:tr_size]
        val_keys = fold['val']
        self.test_keys = fold['test']

        self.device = torch.device(self.config.device if torch.cuda.is_available() else 'cpu')

        # The former guard ``stage == "train" or "test"`` was always true, so
        # this loader has always been used for every stage and the
        # ``set_loader`` branch behind it was unreachable.  train() reads plain
        # dict batches for "mix_train" as well, therefore that behaviour is
        # kept and only the misleading condition is gone.
        self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.img_size,
                                              batch_size=self.config.batch_size,
                                              keys=tr_keys, do_reshuffle=True)

        self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.img_size,
                                            batch_size=self.config.batch_size,
                                            keys=val_keys, mode="val", do_reshuffle=True)
        self.test_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.img_size,
                                             batch_size=self.config.batch_size,
                                             keys=self.test_keys, mode="test", do_reshuffle=False)

        self.model = SupConUnet(num_classes=self.config.num_classes)

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        # A combination of Dice loss and cross-entropy loss is used.
        # The Dice loss is fed softmax probabilities ...
        self.dice_loss = SoftDiceLoss(batch_dice=True, do_bg=False)
        # ... whereas CrossEntropyLoss applies the softmax itself and is fed
        # the raw network output.
        self.ce_loss = torch.nn.CrossEntropyLoss()

        # If a checkpoint directory is provided, load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('checkpoint_dir is empty, please provide directory to load checkpoint.')
            else:
                # ("model") without the comma is a plain string, not a tuple.
                self.load_checkpoint(name=self.config.checkpoint_dir, save_types=("model",))

        if self.config.saved_model_path is not None:
            self.set_model()

        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        self.save_checkpoint(name="checkpoint_start")
        self.writter = SummaryWriter(self.elog.work_dir)
        self.elog.print('Experiment set up.')
        self.elog.print(self.elog.work_dir)

    def train(self, epoch):
        """Run one training epoch for the configured stage.

        Any stage other than ``"mix_train"`` and ``"train"`` is a no-op.
        """
        if self.config.stage == "mix_train":
            print('=====MIX -- TRAIN=====')
            self.model.train()

            for batch_counter, data_batch in enumerate(self.train_data_loader):
                self.optimizer.zero_grad()
                data1 = data_batch['data'][0].float().to(self.device)
                target1 = data_batch['seg'][0].long().to(self.device)

                # Two independent mixups of the same batch.
                inputs_1, target_a_1, target_b_1, lam_1 = self.mixup_data(data1, target1, 1.0)
                inputs_2, target_a_2, target_b_2, lam_2 = self.mixup_data(data1, target1, 1.0)

                _features_1, _features_2, pred_1, pred_2 = self.model(inputs_1, inputs_2, stage="mix_train")

                # TODO: KL terms between the two feature lists / predictions
                # used to be computed here but were never added to the loss;
                # the dead computation was removed.
                loss_seg_1 = self.mixup_criterian(pred_1, target_a_1, target_b_1, lam_1)
                loss_seg_2 = self.mixup_criterian(pred_2, target_a_2, target_b_2, lam_2)
                loss = loss_seg_1 + loss_seg_2
                loss.backward()
                self.optimizer.step()

                # Some logging and plotting
                if (batch_counter % self.config.plot_freq) == 0:
                    print('Train: [{0}][{1}/{2}]\t'
                          'loss {loss:.4f}'.format(epoch, batch_counter, len(self.train_data_loader),
                                                   loss=loss.item()))
                    self.writter.add_scalar("train_loss", loss.item(),
                                            epoch * len(self.train_data_loader) + batch_counter)

        elif self.config.stage == "train":
            self.elog.print('=====TRAIN=====')
            self.model.train()

            data = None
            for batch_counter, data_batch in enumerate(self.train_data_loader):
                self.optimizer.zero_grad()

                # Index 0 drops the leading axis added by the data pipeline;
                # data and target are then moved to the target device.
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                inputs, target_a, target_b, lam = self.mixup_data(data, target, 1.0)
                pred = self.model(inputs, stage="train")

                loss = self.mixup_criterian(pred, target_a, target_b, lam)
                loss.backward()
                self.optimizer.step()

                # Some logging and plotting
                if (batch_counter % self.config.plot_freq) == 0:
                    self.elog.print('Train: [{0}][{1}/{2}]\t'
                                    'loss {loss:.4f}'.format(epoch, batch_counter, len(self.train_data_loader),
                                                             loss=loss.item()))
                    self.writter.add_scalar("train_loss", loss.item(),
                                            epoch * len(self.train_data_loader) + batch_counter)

            assert data is not None, 'data is None. Please check if your dataloader works properly'

    def test(self, epoch=120):
        """Evaluate on the test loader and append the Dice score to ``result.txt``."""
        metric_val = SegmentationMetric(self.config.num_classes)
        metric_val.reset()
        self.model.eval()

        num_of_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("number of parameters:", num_of_parameters)

        with torch.no_grad():
            for i, data_batch in enumerate(self.test_data_loader):
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch["seg"][0].long().to(self.device)

                output = self.model(data, stage="test")
                pred_softmax = F.softmax(output, dim=1)
                metric_val.update(target.squeeze(dim=1), pred_softmax)
                pixAcc, mIoU, Dice = metric_val.get()
                if (i % self.config.plot_freq) == 0:
                    # %d: the batch index is an integer.
                    self.elog.print("Index:%d, mean Dice:%.4f" % (i, Dice))

        _, _, Dice = metric_val.get()
        print("Overall mean dice score is:", Dice)
        with open("result.txt", 'a') as f:
            f.write("epoch:" + str(epoch) + " " + "dice socre:" + str(Dice) + "\n")
        print("Finished test")

    def validate(self, epoch):
        """Compute the mean Dice loss on the validation loader and log it.

        From epoch 90 onwards the test set is evaluated as well.
        """
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []
        dice_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data, stage="test")
                # Explicit class axis, matching train() and test().  Presumably
                # pred is (batch, classes, h, w), for which the deprecated
                # implicit choice was the same axis -- confirm.
                pred_softmax = F.softmax(pred, dim=1)

                pred_image = torch.argmax(pred_softmax, dim=1)
                dice_result = dice_pytorch(outputs=pred_image, labels=target, N_class=self.config.num_classes)
                dice_list.append(dice_result)  # collected, currently not reported

                loss = self.dice_loss(pred_softmax, target.squeeze())
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'

        self.elog.print('Epoch: %2d Loss: %.4f' % (self._epoch_idx, np.mean(loss_list)))

        self.writter.add_scalar("val_loss", np.mean(loss_list), epoch)
        if epoch >= 90:
            self.test(epoch)

    def set_model(self):
        """Load weights from ``config.saved_model_path`` into the model.

        Entries whose key contains ``"head"`` are dropped and the remaining
        state is loaded non-strictly.  With ``config.freeze`` every model
        parameter whose name appears in the loaded state is frozen.
        """
        print("====> start loading model:", self.config.saved_model_path)
        # map_location lets a checkpoint saved on a GPU be restored on the
        # device selected in setup(); None keeps torch.load's default.
        checkpoint = torch.load(self.config.saved_model_path,
                                map_location=getattr(self, "device", None))
        if "model" not in checkpoint.keys():
            state_dict = checkpoint
        else:
            state_dict = checkpoint["model"]
        for k in list(state_dict.keys()):
            if "head" in k:
                del state_dict[k]
        self.model.load_state_dict(state_dict, strict=False)
        print("checkpoint state dict:", state_dict.keys())
        print("model state dict:", self.model.state_dict().keys())
        if self.config.freeze:
            frozen_names = set(state_dict.keys())
            for name, param in self.model.named_parameters():
                if name in frozen_names:
                    param.requires_grad = False

    def mixup_data(self, x, y, alpha=1.0, use_cuda=True):
        """Return mixed inputs, the two target sets and the mixing weight.

        Args:
            x: Input batch; samples are mixed along the first axis.
            y: Targets belonging to ``x``.
            alpha: Parameter of the ``Beta(alpha, alpha)`` distribution the
                weight is drawn from; ``alpha <= 0`` disables mixing (weight 1).
            use_cuda: Kept for backward compatibility.  The permutation is now
                placed on the device of ``x``, so CPU runs work as well.

        Returns:
            ``(mixed_x, y_a, y_b, lam)`` with
            ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and
            ``y_b = y[perm]``.
        """
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterian(self, pred, target_a, target_b, lam):
        """Return the ``lam``-weighted sum of CE + Dice loss for both target sets."""
        # Explicit class axis; see validate() for the assumption on the shape.
        pred_softmax = F.softmax(pred, dim=1)
        loss1 = self.ce_loss(pred, target_a.squeeze()) + self.dice_loss(pred_softmax, target_a.squeeze())
        loss2 = self.ce_loss(pred, target_b.squeeze()) + self.dice_loss(pred_softmax, target_b.squeeze())
        return lam * loss1 + (1 - lam) * loss2
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/experiments/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /experiments/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/experiments/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /experiments/__pycache__/simclr_experiment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/experiments/__pycache__/simclr_experiment.cpython-36.pyc -------------------------------------------------------------------------------- /experiments/__pycache__/simclr_experiment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/experiments/__pycache__/simclr_experiment.cpython-38.pyc -------------------------------------------------------------------------------- /experiments/simclr_experiment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import torch.distributed as dist 6 | from loss_functions.nt_xent import NTXentLoss 7 | import os 8 | import shutil 9 | import sys 10 | import pickle 11 | import torch.optim as optim 12 | 13 | from datasets.two_dim.NumpyDataLoader import NumpyDataSet 14 | from networks.unet_con import GlobalConUnet, MLP 15 | 16 | apex_support = False 17 | 18 | import numpy as 
def _save_config_file(model_checkpoints_folder):
    """Create the checkpoint folder and copy ``./config.yaml`` into it.

    Nothing happens when the folder already exists (the config is only
    copied for a freshly created folder).
    """
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))


class SimCLR(object):
    """Contrastive pre-training driver built around an NT-Xent criterion.

    ``config`` is a mapping; the keys read here are ``save_dir``, ``loss``,
    ``base_dir``, ``img_size``, ``batch_size``, ``val_batch_size``,
    ``weight_decay``, ``epochs``, ``log_every_n_steps``,
    ``eval_every_n_epochs`` and ``fine_tune_from``.
    """

    def __init__(self, config):
        self.config = config
        self.device = self._get_device()
        self.writer = SummaryWriter(os.path.join(self.config['save_dir'], 'tensorboard'))
        self.nt_xent_criterion = NTXentLoss(self.device, **config['loss'])

        split_dir = os.path.join(self.config["base_dir"], "splits.pkl")
        data_dir = os.path.join(self.config["base_dir"], "preprocessed")
        print(data_dir)
        with open(split_dir, "rb") as f:
            splits = pickle.load(f)
        # Pre-training uses every key of the first split; the validation
        # keys are therefore also part of the training set.
        tr_keys = splits[0]['train'] + splits[0]['val'] + splits[0]['test']
        val_keys = splits[0]['val']
        self.train_loader = NumpyDataSet(data_dir, target_size=self.config["img_size"],
                                         batch_size=self.config["batch_size"],
                                         keys=tr_keys, do_reshuffle=True, mode='simclr')
        self.val_loader = NumpyDataSet(data_dir, target_size=self.config["img_size"],
                                       batch_size=self.config["val_batch_size"],
                                       keys=val_keys, do_reshuffle=True, mode='simclr')

        print(len(self.train_loader))
        self.model = GlobalConUnet()
        self.head = MLP(num_class=256)

        if torch.cuda.device_count() > 1:
            print("Let's use %d GPUs" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)
            self.head = nn.DataParallel(self.head)

        self.model.to(self.device)
        self.head.to(self.device)

        self.model = self._load_pre_trained_weights(self.model)

        # float() parses values such as "10e-6" without evaluating arbitrary
        # code from the config, and also accepts an already-numeric value.
        # NOTE(review): only self.model's parameters are optimized; the
        # projection head is never updated -- confirm this is intended.
        self.optimizer = torch.optim.Adam(self.model.parameters(), 3e-4,
                                          weight_decay=float(self.config['weight_decay']))

    def _get_device(self):
        """Return ``'cuda'`` when CUDA is available, otherwise ``'cpu'``."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Running on:", device)
        return device

    def _step(self, model, head, xis, xjs, n_iter):
        """Compute the contrastive loss for one pair of augmented batches.

        ``n_iter`` is accepted for interface compatibility and is not used.
        """
        # representations -> projections for both views
        zis = head(model(xis))
        zjs = head(model(xjs))

        # normalize projection feature vectors
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        return self.nt_xent_criterion(zis, zjs)

    def train(self):
        """Run the training loop with periodic validation and checkpointing."""
        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=len(self.train_loader), eta_min=0, last_epoch=-1)

        for epoch_counter in range(self.config['epochs']):
            print("=====Training Epoch: %d =====" % epoch_counter)
            for i, (xis, xjs) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                xis = xis['data'][0].float().to(self.device)
                xjs = xjs['data'][0].float().to(self.device)

                loss = self._step(self.model, self.head, xis, xjs, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
                    print("Train:[{0}][{1}][{2}] loss: {loss:.4f}".format(
                        epoch_counter, i, len(self.train_loader), loss=loss.item()))

                loss.backward()
                self.optimizer.step()
                n_iter += 1

            print("===== Validation =====")
            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(self.val_loader)
                print("Val:[{0}] loss: {loss:.4f}".format(epoch_counter, loss=valid_loss))
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(self.model.state_dict(),
                               os.path.join(self.config['save_dir'],
                                            'b_{}_model.pth'.format(self.config["batch_size"])))

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
            # get_last_lr() is the supported accessor outside of step().
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter)

    def _load_pre_trained_weights(self, model):
        """Load ``./runs/<fine_tune_from>/checkpoints/model.pth`` if it exists.

        A missing file is not an error; the model is returned unchanged and
        training starts from scratch.
        """
        try:
            checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'),
                                    map_location=self.device)
            model.load_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model

    def _validate(self, valid_loader):
        """Return the mean contrastive loss over ``valid_loader``.

        The model is switched to eval mode for the pass and afterwards put
        back into whatever mode it was in, so that the following training
        epochs do not silently run in eval mode.
        """
        was_training = self.model.training
        self.model.eval()

        valid_loss = 0.0
        counter = 0
        try:
            with torch.no_grad():
                for (xis, xjs) in valid_loader:
                    xis = xis['data'][0].float().to(self.device)
                    xjs = xjs['data'][0].float().to(self.device)

                    loss = self._step(self.model, self.head, xis, xjs, counter)
                    valid_loss += loss.item()
                    counter += 1
        finally:
            self.model.train(was_training)

        valid_loss /= counter
        return valid_loss
class projection_MLP(nn.Module):
    """Two-layer projection head: Linear -> ReLU -> Linear.

    The hidden width equals ``in_dim``; the output width is ``out_dim``.
    """

    def __init__(self, in_dim, out_dim=256):
        super(projection_MLP, self).__init__()
        width = in_dim
        # Creation order (hidden layer first, output layer second) is kept so
        # that seeded initialization yields the same weights.
        hidden = nn.Linear(in_dim, width)
        activation = nn.ReLU(inplace=True)
        self.layer1 = nn.Sequential(hidden, activation)
        self.layer2 = nn.Linear(width, out_dim)

    def forward(self, x):
        """Project ``x`` through the hidden block and the output layer."""
        return self.layer2(self.layer1(x))
63 | 64 | print(len(self.train_loader)) 65 | self.model = GlobalConUnet() 66 | self.head = MLP(num_class=256) 67 | 68 | self.nt_xent_criterion = NTXentLoss(self.device, **config['loss']) 69 | 70 | # dist.init_process_group(backend='nccl') 71 | if torch.cuda.device_count() > 1: 72 | print("Let's use %d GPUs" % torch.cuda.device_count()) 73 | # print("Let's use %d GPUs" % self.device) 74 | self.model = nn.DataParallel(self.model) 75 | self.head = nn.DataParallel(self.head) 76 | 77 | self.model.to(self.device) 78 | self.head.to(self.device) 79 | 80 | self.model = self._load_pre_trained_weights(self.model) 81 | 82 | self.optimizer = torch.optim.Adam(self.model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay'])) 83 | # self.optimizer = torch.optim.SGD(self.model.parameters(), 1e-4,momentum=0.9, weight_decay=eval(self.config['weight_decay'])) 84 | 85 | def set_optimizer(opt, model): 86 | if opt.optimizer == "sgd": 87 | optimizer = optim.SGD(model.parameters(), 88 | lr=opt.learning_rate, 89 | momentum=opt.momentum, 90 | weight_decay=opt.weight_decay 91 | ) 92 | elif opt.optimizer == "adam": 93 | optimizer = optim.Adam(model.parameters(), 94 | lr=opt.learning_rate, 95 | weight_decay=opt.weight_decay) 96 | else: 97 | raise NotImplementedError("The optimizer is not supported.") 98 | return optimizer 99 | def _get_device(self): 100 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 101 | 102 | print("Running on:", device) 103 | return device 104 | #TODO:simsam network implement 105 | def D(self,p,z,queue): #negative cosine similarity 106 | z = z.detach() # stop gradient 107 | return self.nt_xent_criterion(p, z,queue) 108 | 109 | 110 | def _step(self, model, head, xis, xjs, queue): 111 | model.zero_grad() 112 | # get the representations and the projections 113 | ris = model(xis) # [N,C] 114 | zis = head(ris) 115 | # get the representations and the projections 116 | rjs = model(xjs) # [N,C] 117 | zjs = head(rjs) 118 | #norm 119 | p = 
F.normalize(zis,dim=1) 120 | z = F.normalize(zjs, dim=1) 121 | # update the queue 122 | queue = torch.cat((queue,z ), 0) 123 | if queue.shape[0] > 20*160: #memory size 20 *batch_size 124 | queue = queue[160:, :] 125 | loss = self.D(p,z,queue)/2+self.D(p,z,queue)/2 126 | # loss = self.nt_xent_criterion(p,z,queue) 127 | return loss 128 | 129 | def train(self): 130 | model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints') 131 | 132 | # save config file 133 | _save_config_file(model_checkpoints_folder) 134 | 135 | n_iter = 0 136 | valid_n_iter = 0 137 | best_valid_loss = np.inf 138 | 139 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(self.train_loader), eta_min=0, 140 | last_epoch=-1) 141 | flag = 0 142 | K = 20*160 # the most biggest number of queue 143 | queue = None 144 | if queue is None: 145 | while True: 146 | with torch.no_grad(): 147 | for i, (xis, xjs) in enumerate(self.train_loader): 148 | xjs = xjs['data'][0].float().to(self.device) 149 | self.model.zero_grad() 150 | # get the representations and the projections 151 | rjs = self.model(xjs) # [N,C] 152 | zjs = self.head(rjs) 153 | zjs.detach() 154 | k = torch.div(zjs,torch.norm(zjs,dim=1).reshape(-1,1)) 155 | if queue is None: 156 | queue = k 157 | else: 158 | if queue.shape[0]= 10: 206 | scheduler.step() 207 | self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter) 208 | 209 | 210 | def _load_pre_trained_weights(self, model): 211 | try: 212 | checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints') 213 | state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth')) 214 | model.load_state_dict(state_dict) 215 | print("Loaded pre-trained model with success.") 216 | except FileNotFoundError: 217 | print("Pre-trained weights not found. 
def _list_entries(folder, keep, join, prefix, suffix, sort):
    """List names in *folder* whose full path satisfies *keep* and whose name
    matches the optional *prefix* / *suffix* filters."""
    selected = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not keep(full_path):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        selected.append(full_path if join else entry)
    if sort:
        selected.sort()
    return selected


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """Return the sub-directories of *folder*.

    :param join: return paths joined with *folder* instead of bare names.
    :param prefix: keep only names starting with this string (if given).
    :param suffix: keep only names ending with this string (if given).
    :param sort: sort the result.
    """
    return _list_entries(folder, os.path.isdir, join, prefix, suffix, sort)


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """Return the regular files of *folder*.

    :param join: return paths joined with *folder* instead of bare names.
    :param prefix: keep only names starting with this string (if given).
    :param suffix: keep only names ending with this string (if given).
    :param sort: sort the result.
    """
    return _list_entries(folder, os.path.isfile, join, prefix, suffix, sort)


def maybe_mkdir_p(directory):
    """Create *directory* and any missing parents; existing ones are fine.

    The previous implementation split the path on ``"/"``, discarded the
    first component and rebuilt everything below the filesystem root, so a
    relative path such as ``"a/b"`` ended up as ``"/b"``.  ``os.makedirs``
    handles absolute and relative paths alike.  An empty string stays a
    no-op, as before.
    """
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
class InferenceExperiment(object):
    """Runs a trained ``SupConUnetInfer`` over a few volumes and reports Dice.

    ``config`` is read by attribute: ``split_dir``, ``fold``, ``data_dir``,
    ``img_size``, ``num_classes``, ``device``, ``base_dir``, ``name`` and
    ``saved_model_path``.
    """

    def __init__(self, config):
        self.config = config
        pkl_dir = self.config.split_dir
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        # Only the first two keys of each subset are used.
        self.train_keys = splits[self.config.fold]['train'][0:2]
        self.val_keys = splits[self.config.fold]['val'][0:2]

        self.test_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.img_size,
                                             batch_size=2, keys=self.train_keys, do_reshuffle=False,
                                             mode="test")
        self.model = SupConUnetInfer(num_classes=self.config.num_classes)
        self.criterion = SupConSegLoss(temperature=0.7)
        self.criterion1 = LocalConLoss(temperature=0.7)
        self.criterion2 = BlockConLoss(temperature=0.7)

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            self.model = nn.DataParallel(self.model)

        self.device = torch.device(self.config.device if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.criterion.to(self.device)
        self.criterion1.to(self.device)
        self.criterion2.to(self.device)

        self.save_folder = os.path.join(self.config.base_dir,
                                        "infer_" + self.config.name + str(datetime.now())[0:16])
        if not os.path.exists(self.save_folder):
            os.mkdir(self.save_folder)

    def load_checkpoint(self):
        """Load ``config.saved_model_path`` (key ``'model'``) into the model.

        Terminates the process with a non-zero status when no path is set;
        previously this error path exited with status 0.
        """
        if self.config.saved_model_path is None:
            print('checkpoint_dir is empty, please provide directory to load checkpoint.')
            raise SystemExit(1)
        # map_location lets a GPU-written checkpoint load on the active device.
        state_dict = torch.load(self.config.saved_model_path, map_location=self.device)['model']
        self.model.load_state_dict(state_dict, strict=False)

    def binfer(self):
        """Evaluate Dice and the three contrastive losses on paired batches."""
        self.model.eval()
        co_losses = AverageMeter()
        local_co_losses = AverageMeter()
        block_co_losses = AverageMeter()
        metric_val = SegmentationMetric(self.config.num_classes)
        metric_val.reset()
        bsz = 2

        with torch.no_grad():
            for (i, data_batch) in enumerate(self.test_data_loader):
                # Inputs must live on the model's device; they used to stay
                # on the CPU while the labels were moved with a hard .cuda().
                data1 = data_batch[0]['data'][0].float().to(self.device)
                target1 = data_batch[0]['seg'][0].long().to(self.device)

                data2 = data_batch[1]['data'][0].float().to(self.device)
                target2 = data_batch[1]['seg'][0].long().to(self.device)

                data = torch.cat([data1, data2], dim=0)
                labels = torch.cat([target1, target2], dim=0).squeeze(dim=1)  # of shape [2B, 512, 512]

                features, output = self.model(data)
                output_softmax = F.softmax(output, dim=1)
                metric_val.update(labels, output_softmax)

                features = F.normalize(features, p=2, dim=1)
                f1, f2 = torch.split(features, [bsz, bsz], dim=0)
                features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [bsz, n_view, c, img_size, img_size]
                l1, l2 = torch.split(labels, [bsz, bsz], dim=0)
                labels = torch.cat([l1.unsqueeze(1), l2.unsqueeze(1)], dim=1)

                co_loss = self.criterion(features, labels)
                local_co_loss = self.criterion1(features, labels)
                block_co_loss = self.criterion2(features, labels)
                # A zero loss skips the remaining updates and the progress
                # print for this batch (flow kept as in the original).
                if co_loss == 0:
                    continue
                co_losses.update(co_loss, bsz)
                if local_co_loss == 0:
                    continue
                local_co_losses.update(local_co_loss, bsz)
                if block_co_loss == 0:
                    continue
                block_co_losses.update(block_co_loss, bsz)

                if i % 10 == 0:
                    _, _, Dice = metric_val.get()
                    print("Index:%d, mean Dice:%.4f" % (i, Dice))
                    print("Index:%d, mean contrastive loss:%.4f" % (i, co_losses.avg))

        print("=====Inference Finished=====")
        _, _, Dice = metric_val.get()
        print("mean Dice:", Dice)
        print("mean contrastive loss:", co_losses.avg.item())
        print("mean local contrastive loss:", local_co_losses.avg.item())
        print("mean block contrastive loss:", block_co_losses.avg.item())

    def inference(self):
        """Predict each validation volume slice-wise, save it, report Dice."""
        self.model.eval()
        metric_val = SegmentationMetric(self.config.num_classes)
        metric_val.reset()
        bsz = 4

        with torch.no_grad():
            # val_keys holds at most two entries; iterating over what is
            # there avoids an IndexError when the split is shorter.
            for k in range(len(self.val_keys)):
                key = self.val_keys[k:k + 1]
                data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.img_size,
                                           batch_size=bsz, keys=key, do_reshuffle=False, mode="test")
                prediction = []
                for (i, data_batch) in enumerate(data_loader):
                    data = data_batch['data'][0].float().to(self.device)
                    labels = data_batch['seg'][0].long().to(self.device)

                    features, output = self.model(data)
                    output_softmax = F.softmax(output, dim=1)
                    pred = torch.argmax(output_softmax, dim=1)
                    # Drop only the channel axis (as binfer does); a bare
                    # squeeze() would also drop a batch axis of size 1.
                    metric_val.update(labels.squeeze(dim=1), output_softmax)

                    for j in range(pred.shape[0]):
                        prediction.append(pred[j].cpu().numpy())

                    if i % 10 == 0:
                        _, _, Dice = metric_val.get()
                        print("Index:%d, mean Dice:%.4f" % (i, Dice))

                prediction = np.stack(prediction)
                self.save_data(prediction, key[0], 'prediction')

        print("=====Inference Finished=====")
        _, _, Dice = metric_val.get()
        print("mean Dice:", Dice)

    def save_data(self, data, key, mode):
        """Save ``data`` with ``np.save`` as ``<save_folder>/<mode>_<key>``.

        A ``<save_folder>/<mode>`` directory is created as well.
        NOTE(review): the array is written next to that directory, not into
        it -- confirm which location downstream tools expect.
        """
        os.makedirs(os.path.join(self.save_folder, mode), exist_ok=True)

        save_path = os.path.join(self.save_folder, mode + '_' + key)
        np.save(save_path, data)


if __name__ == "__main__":
    c = get_config()
    c.saved_model_path = os.path.abspath("output_experiment") + "/20210227-065712_Unet_mmwhs/" \
                         + "checkpoint/" + "checkpoint_last.pth"
    # c.saved_model_path = os.path.abspath('save') + '/SupCon/mmwhs_models/' \
    #     + 'SupCon_mmwhs_adam_fold_0_lr_0.0001_decay_0.0001_bsz_4_temp_0.1_train_0.4_mlp_block_pretrained/' \
    #     + 'ckpt.pth'
    c.fold = 0
    print(c)
    exp = InferenceExperiment(config=c)
    exp.load_checkpoint()
    exp.inference()
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/dice_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss_functions/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /loss_functions/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /loss_functions/__pycache__/nt_xent.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/nt_xent.cpython-36.pyc -------------------------------------------------------------------------------- /loss_functions/__pycache__/nt_xent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/nt_xent.cpython-38.pyc -------------------------------------------------------------------------------- /loss_functions/__pycache__/supcon_loss.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/supcon_loss.cpython-36.pyc -------------------------------------------------------------------------------- /loss_functions/__pycache__/supcon_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/loss_functions/__pycache__/supcon_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss_functions/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | 5 | 6 | def sum_tensor(input, axes, keepdim=False): 7 | axes = np.unique(axes).astype(int) 8 | if keepdim: 9 | for ax in axes: 10 | input = input.sum(int(ax), keepdim=True) 11 | else: 12 | for ax in sorted(axes, reverse=True): 13 | input = input.sum(int(ax)) 14 | return input 15 | 16 | 17 | def mean_tensor(input, axes, keepdim=False): 18 | axes = np.unique(axes).astype(int) 19 | if keepdim: 20 | for ax in axes: 21 | input = input.mean(int(ax), keepdim=True) 22 | else: 23 | for ax in sorted(axes, reverse=True): 24 | input = input.mean(int(ax)) 25 | return input 26 | 27 | 28 | class SoftDiceLoss(nn.Module): 29 | def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, 30 | background_weight=1, rebalance_weights=None): 31 | """ 32 | hahaa no documentation for you today 33 | :param smooth: 34 | :param apply_nonlin: 35 | :param batch_dice: 36 | :param do_bg: 37 | :param smooth_in_nom: 38 | :param background_weight: 39 | :param rebalance_weights: 40 | """ 41 | super(SoftDiceLoss, self).__init__() 42 | if not do_bg: 43 | assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy" 44 | self.rebalance_weights = rebalance_weights 45 | 
self.background_weight = background_weight 46 | self.smooth_in_nom = smooth_in_nom 47 | self.do_bg = do_bg 48 | self.batch_dice = batch_dice 49 | self.apply_nonlin = apply_nonlin 50 | self.smooth = smooth 51 | self.y_onehot = None 52 | if not smooth_in_nom: 53 | self.nom_smooth = 0 54 | else: 55 | self.nom_smooth = smooth 56 | 57 | def forward(self, x, y): 58 | with torch.no_grad(): 59 | y = y.long() 60 | shp_x = x.shape 61 | shp_y = y.shape 62 | if self.apply_nonlin is not None: 63 | x = self.apply_nonlin(x) 64 | if len(shp_x) != len(shp_y): 65 | y = y.view((shp_y[0], 1, *shp_y[1:])) 66 | # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively 67 | y_max = torch.max(y) 68 | y_onehot = torch.zeros(shp_x) 69 | if x.device.type == "cuda": 70 | y_onehot = y_onehot.cuda(x.device.index) 71 | # this is really fancy 72 | y_onehot.scatter_(1, y, 1) 73 | if not self.do_bg: 74 | x = x[:, 1:] 75 | y_onehot = y_onehot[:, 1:] 76 | if not self.batch_dice: 77 | if self.background_weight != 1 or (self.rebalance_weights is not None): 78 | raise NotImplementedError("nah son") 79 | l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom) 80 | else: 81 | l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom, 82 | background_weight=self.background_weight, 83 | rebalance_weights=self.rebalance_weights) 84 | return l 85 | 86 | 87 | def soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1): 88 | axes = tuple([0] + list(range(2, len(net_output.size())))) 89 | intersect = sum_tensor(net_output * gt, axes, keepdim=False) 90 | denom = sum_tensor(net_output + gt, axes, keepdim=False) 91 | weights = torch.ones(intersect.shape) 92 | weights[0] = background_weight 93 | if net_output.device.type == "cuda": 94 | weights = weights.cuda(net_output.device.index) 95 | result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth)) * weights).mean() 96 | return result 97 | 98 | 99 | def soft_dice_per_batch_2(net_output, 
gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None): 100 | if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]: 101 | rebalance_weights = rebalance_weights[1:] # this is the case when use_bg=False 102 | axes = tuple([0] + list(range(2, len(net_output.size())))) 103 | intersect = sum_tensor(net_output * gt, axes, keepdim=False) 104 | net_output_sqaure = sum_tensor(net_output*net_output, axes, keepdim=False) 105 | gt_square = sum_tensor(gt*gt, axes, keepdim=False) 106 | #fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False) 107 | # fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False) 108 | weights = torch.ones(intersect.shape) 109 | weights[0] = background_weight 110 | if net_output.device.type == "cuda": 111 | weights = weights.cuda(net_output.device.index) 112 | if rebalance_weights is not None: 113 | rebalance_weights = torch.from_numpy(rebalance_weights).float() 114 | if net_output.device.type == "cuda": 115 | rebalance_weights = rebalance_weights.cuda(net_output.device.index) 116 | intersect = intersect * rebalance_weights 117 | # fn = fn * rebalance_weights 118 | result = (1 - (2*intersect + smooth_in_nom)/(net_output_sqaure + gt_square + smooth) * weights) 119 | result = result[result > 0] # ensure that when there is no target class, the dice loss is not too large 120 | result = result.mean() 121 | return result 122 | 123 | 124 | def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.): 125 | axes = tuple(range(2, len(net_output.size()))) 126 | intersect = sum_tensor(net_output * gt, axes, keepdim=False) 127 | denom = sum_tensor(net_output + gt, axes, keepdim=False) 128 | result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean() #TODO: Was ist weights and er Stelle? 
class MultipleOutputLoss(nn.Module):
    """Apply one loss to several outputs that share the same target ``y``.

    The result is the weighted sum of ``loss(output, y)`` over all outputs.
    """

    def __init__(self, loss, weight_factors=None):
        """
        :param loss: callable invoked as ``loss(output, y)`` for each output
        :param weight_factors: optional multipliers indexed like the outputs;
            ``None`` weights every output by 1
        """
        super().__init__()
        self.weight_factors = weight_factors
        self.loss = loss

    def forward(self, x, y):
        """Return the weighted sum of the wrapped loss over all entries of ``x``.

        :param x: tuple or list of outputs
        :param y: target passed unchanged to every loss call
        """
        assert isinstance(x, (tuple, list)), "x must be either tuple or list"
        factors = self.weight_factors
        if factors is None:
            factors = [1] * len(x)

        # The first output seeds the accumulator; the rest are added in place.
        total = factors[0] * self.loss(x[0], y)
        for position, output in enumerate(x[1:], start=1):
            total += factors[position] * self.loss(output, y)
        return total
def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Return a thresholded IoU score per batch element.

    The masks are combined with bitwise ``&`` / ``|`` and summed over axes
    1 and 2. The smoothed IoU is then mapped to
    ``ceil(clamp(20 * (iou - 0.5), 0, 10)) / 10``.

    :param outputs: predicted masks; a size-1 axis 1 is squeezed away
        (BATCH x 1 x H x W => BATCH x H x W)
    :param labels: ground-truth masks combinable with the squeezed outputs
    :return: one value per batch element; call ``.mean()`` on it for the
        batch average
    """
    preds = outputs.squeeze(1)

    # Zero if either the truth or the prediction is empty.
    overlap = (preds & labels).float().sum((1, 2))
    # Zero only if both are empty.
    combined = (preds | labels).float().sum((1, 2))

    # SMOOTH keeps the 0/0 case well defined.
    iou = (overlap + SMOOTH) / (combined + SMOOTH)

    scaled = torch.clamp(20 * (iou - 0.5), 0, 10)
    return scaled.ceil() / 10
class SegmentationMetric(object):
    """Computes pixAcc and mIoU metric scores.

    ``update`` accumulates running totals (correct / labeled pixels and
    per-class intersection / union areas); ``get`` turns them into pixel
    accuracy, IoU and Dice.
    """

    def __init__(self, nclass):
        """
        :param nclass: number of classes, forwarded to ``batch_intersection_union``
        """
        self.nclass = nclass
        # Guards the running totals when update() fans out over threads.
        self.lock = threading.Lock()
        self.reset()

    def _accumulate(self, label, pred):
        """Evaluate one (label, pred) pair and add it to the running totals."""
        correct, labeled = batch_pix_accuracy(pred, label)
        inter, union = batch_intersection_union(pred, label, self.nclass)
        with self.lock:
            self.total_correct += correct
            self.total_label += labeled
            self.total_inter += inter
            self.total_union += union

    def update(self, labels, preds):
        """Add a batch, or a list/tuple of batches, to the running totals.

        :param labels: target tensor, or a sequence of them matching ``preds``
        :param preds: prediction tensor, or a list/tuple of prediction tensors
            (each pair is evaluated in its own thread)
        :raises NotImplementedError: if ``preds`` is neither a tensor nor a
            list/tuple
        """
        if isinstance(preds, torch.Tensor):
            self._accumulate(labels, preds)
        elif isinstance(preds, (list, tuple)):
            threads = [threading.Thread(target=self._accumulate, args=(label, pred))
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
        else:
            # `NotImplemented` is a comparison sentinel, not an exception;
            # raising it produces a TypeError instead of the intended error.
            raise NotImplementedError(
                "preds must be a torch.Tensor or a list/tuple of tensors, got {}".format(
                    type(preds).__name__))

    def get(self, mode='mean'):
        """Return ``(pixAcc, IoU, Dice)`` computed from the running totals.

        :param mode: ``'mean'`` averages IoU and Dice over their entries; any
            other value returns them un-averaged
        """
        # np.spacing(1) avoids a division by zero before the first update.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        Dice = 2.0 * self.total_inter / (np.spacing(1) + self.total_union + self.total_inter)
        if mode == 'mean':
            return pixAcc, IoU.mean(), Dice.mean()
        return pixAcc, IoU, Dice

    def reset(self):
        """Zero all running totals."""
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0
def batch_intersection_union(output, target, nclass):
    """Batch Intersection of Union

    Args:
        output: model output; the class scores are expected on dim 1
        target: label tensor holding class indices (0 is background)
        nclass: number of categories (int); with only background and organ
            this is 2, which yields a single histogram bin for label 1

    Returns:
        ``(area_inter, area_union)``: per-foreground-class pixel counts as
        NumPy arrays with ``nclass - 1`` entries.
    """
    # The predicted label is the index of the highest score along dim 1.
    predict = torch.max(output, dim=1).indices.cpu().numpy().astype('int64')
    target = target.cpu().numpy().astype('int64')

    # Histogram over the foreground labels 1 .. nclass-1; background (0)
    # falls outside the range and is never counted.
    first_label = 1
    last_label = nclass - 1
    bin_count = nclass - 1

    def count_per_class(label_map):
        counts, _ = np.histogram(label_map, bins=bin_count, range=(first_label, last_label))
        return counts

    predict = predict * (target >= 0).astype(predict.dtype)
    # Pixels where prediction and label agree keep their label; every other
    # pixel becomes 0 and is ignored by the histogram.
    intersection = predict * (predict == target)

    # areas of intersection and union
    area_inter = count_per_class(intersection)
    area_pred = count_per_class(predict)
    area_lab = count_per_class(target)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
NTXentLoss(torch.nn.Module): 6 | 7 | def __init__(self, device, temperature, use_cosine_similarity,beta=0.1): 8 | super(NTXentLoss, self).__init__() 9 | self.temperature = temperature 10 | self.device = device 11 | self.softmax = torch.nn.Softmax(dim=-1) 12 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 13 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 14 | self.beta = beta 15 | 16 | def _get_similarity_function(self, use_cosine_similarity): 17 | if use_cosine_similarity: 18 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 19 | return self._cosine_simililarity 20 | else: 21 | return self._dot_simililarity 22 | 23 | def _get_correlated_mask(self, batch_size): 24 | diag = np.eye(2 * batch_size) 25 | l1 = np.eye((2 * batch_size), 2 * batch_size, k=-batch_size) 26 | l2 = np.eye((2 * batch_size), 2 * batch_size, k=batch_size) 27 | mask = torch.from_numpy((diag + l1 + l2)) 28 | mask = (1 - mask).type(torch.bool) 29 | return mask.to(self.device) 30 | 31 | @staticmethod 32 | def _dot_simililarity(x, y): 33 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 34 | # x shape: (N, 1, C) 35 | # y shape: (1, C, 2N) 36 | # v shape: (N, 2N) 37 | return v 38 | 39 | def _cosine_simililarity(self, x, y): 40 | # x shape: (N, 1, C) 41 | # y shape: (1, 2N, C) 42 | # v shape: (N, 2N) 43 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 44 | return v 45 | 46 | def forward(self, zis, zjs): 47 | representations = torch.cat([zjs, zis], dim=0) 48 | 49 | similarity_matrix = self.similarity_function(representations, representations) 50 | 51 | batch_size = zis.shape[0] 52 | mask_samples_from_same_repr = self._get_correlated_mask(batch_size).type(torch.bool) 53 | # filter out the scores from the positive samples 54 | l_pos = torch.diag(similarity_matrix, batch_size) 55 | r_pos = torch.diag(similarity_matrix, -batch_size) 56 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 57 | 58 | negatives = 
def parse_option():
    """Parse the command line and derive the run configuration.

    Besides the raw arguments, the returned namespace carries the derived
    fields ``data_folder``, ``split_dir``, ``model_path``, ``tb_path``,
    ``lr_decay_epochs`` (list of int), ``model_name``, ``tb_folder`` and
    ``save_folder``. The tensorboard and checkpoint folders are created.

    :return: populated ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--pretrained_model_path', type=str, default=None,
                        help='where to find the pretrained model')
    parser.add_argument('--head', type=str, default="cls",
                        help='head mode, cls or mlp')
    parser.add_argument('--stride', type=int, default=4,
                        help='number of stride when doing downsampling')
    parser.add_argument('--mode', type=str, default="block",
                        help='how to downsample the feature maps, stride or block')

    # optimization
    parser.add_argument('--optimizer', type=str, default="adam",
                        help='optimization method')
    parser.add_argument('--learning_rate', type=float, default=0.0001,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.1,
                        help='momentum')

    # model dataset
    parser.add_argument('--dataset', type=str, default='mmwhs',
                        help='dataset')
    parser.add_argument('--resume', type=str, default=None,
                        help="path to the stored checkpoint")
    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')
    parser.add_argument('--split_dir', type=str, default=None, help='path to split pickle file')
    parser.add_argument('--fold', type=int, default=0, help='parameter for splits')
    parser.add_argument('--train_sample', type=float, default=1.0,
                        help='parameter for sampling rate of training set')

    # method
    parser.add_argument('--method', type=str, default='SupCon',
                        choices=['SupCon', 'SimCLR'], help='choose method')

    # temperature
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')

    # The block size is parsed as an int so that an explicit
    # `--block_size 16` yields the same value (and the same model name) as
    # the default, rather than 16.0.
    parser.add_argument('--block_size', type=int, default=16,
                        help='block size used by the block-wise contrastive loss')
    opt = parser.parse_args()
    # NOTE(review): this unconditionally overrides --mode, so the "stride"
    # option can never take effect. Kept as-is because existing runs may
    # depend on it - confirm whether the override is still wanted.
    opt.mode = 'block'

    # check if dataset is path that passed required arguments
    if opt.dataset == 'path':
        assert opt.data_folder is not None \
            and opt.mean is not None \
            and opt.std is not None

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = 'data'
    else:
        opt.data_folder = os.path.join(opt.data_folder, opt.dataset, 'preprocessed')

    if opt.split_dir is None:
        opt.split_dir = os.path.join('./data', opt.dataset)
    opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

    # "700,800,900" -> [700, 800, 900]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_{}_fold_{}_lr_{}_decay_{}_bsz_{}_temp_{}_train_{}_{}'.\
        format(opt.method, opt.dataset, opt.optimizer, opt.fold, opt.learning_rate,
               opt.weight_decay, opt.batch_size, opt.temp, opt.train_sample, opt.mode)

    if opt.mode == "stride":
        opt.model_name = '{}_stride_{}'.format(opt.model_name, opt.stride)
    elif opt.mode == "block":
        opt.model_name = '{}_block_{}'.format(opt.model_name, opt.block_size)

    if opt.pretrained_model_path is not None:
        opt.model_name = '{}_pretrained'.format(opt.model_name)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt
def train(train_loader, model, criterion, logger, optimizer, epoch, opt):
    """one epoch training

    Each batch provides two views. Their images and labels are concatenated,
    passed through ``model``, L2-normalised along dim 1 and regrouped per
    view before being handed to ``criterion``. Zero entries of the returned
    loss are excluded from the average that is back-propagated.

    :param train_loader: iterable yielding a pair of dicts with ``'data'``
        and ``'seg'`` entries (plus ``'fnames'`` / ``'slice_idx'`` used for
        diagnostics)
    :param model: network producing the feature maps
    :param criterion: loss called as ``criterion(features, labels)``
    :param logger: object with ``add_scalar(tag, value, step)``
    :param optimizer: optimizer stepping ``model``
    :param epoch: 1-based epoch index, used for logging
    :param opt: options; only ``print_freq`` is read here
    :return: running average of the loss over the epoch
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, data_batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        data1 = data_batch[0]['data'][0].float()
        target1 = data_batch[0]['seg'][0].long()

        data2 = data_batch[1]['data'][0].float()
        # (original note: apply a data perturbation to data2)
        target2 = data_batch[1]['seg'][0].long()

        imgs = torch.cat([data1, data2], dim=0)
        labels = torch.cat([target1, target2], dim=0).squeeze(dim=1)  # [2B, H, W]

        if torch.cuda.is_available():
            imgs = imgs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0] // 2
        img_size = imgs.shape[-1]

        # compute loss
        features = model(imgs)  # [2B, C, H, W]

        features = F.normalize(features, p=2, dim=1)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)

        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [bsz, n_view, c, img_size, img_size]
        l1, l2 = torch.split(labels, [bsz, bsz], dim=0)
        labels = torch.cat([l1.unsqueeze(1), l2.unsqueeze(1)], dim=1)
        loss = criterion(features, labels)

        if loss.mean() == 0:
            # Nothing to learn from this batch; restart the timer so the
            # skipped work is not counted as data-loading time.
            end = time.time()
            continue
        # Average only over non-zero loss entries. The mask is derived from
        # `loss`, so it already lives on the right device; the former
        # unconditional `.cuda()` crashed on CPU-only machines.
        mask = (loss != 0).int()
        loss = (loss * mask).sum() / mask.sum()

        if torch.isinf(loss):
            # Diagnostics for tracking down the offending samples.
            print(data_batch[0]['fnames'])
            print(data_batch[0]['slice_idx'])
            print(imgs.max().item(), imgs.min().item())
        # NOTE(review): the meter is weighted by the image width rather than
        # the batch size - confirm this is intended.
        losses.update(loss.item(), img_size)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            num_iteration = idx + 1 + (epoch - 1) * len(train_loader)
            logger.add_scalar("train_loss", losses.avg, num_iteration)
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg
if __name__ == "__main__":
    # Maps every supported --dataset value to its YAML configuration file.
    config_files = {
        "mmwhs": "config_mmwhs.yaml",
        "hippo": "config.yaml",
    }

    args = parse_option()
    if args.dataset not in config_files:
        # Fail here with a clear message; previously an unknown dataset left
        # `config` undefined and surfaced later as a NameError.
        raise ValueError("unsupported dataset '{}', expected one of {}".format(
            args.dataset, sorted(config_files)))

    with open(config_files[args.dataset], "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Command-line values take precedence over the YAML file.
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epoch
    print(config)

    simclr = SimCLR(config)
    simclr.train()
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One nesting level of the recursive U-Net.

    Outermost block: contracting convs, ``submodule``, expanding convs and a
    1x1 classification conv; its forward returns that result directly.
    Other blocks start with a 2x2 max-pool and end with a transposed
    convolution; their forward concatenates the input with the
    centre-cropped result along the channel axis.
    """

    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        """
        :param in_channels: channels entering this level
        :param out_channels: channels produced by the contracting convs
        :param num_classes: output channels of the final 1x1 conv (outermost only)
        :param kernel_size: kernel size of the 3 contracting/expanding convs
        :param submodule: the nested, deeper block (unused when innermost)
        :param outermost: build the top-level block
        :param innermost: build the deepest block
        :param norm_layer: normalisation layer factory for the contracting convs
        :param use_dropout: append ``Dropout(0.5)``; only honoured by
            intermediate blocks
        """
        super().__init__()
        self.outermost = outermost

        # All five layers are instantiated for every block type, in this
        # order, even though the innermost block does not use the expanding
        # convs.
        downsample = nn.MaxPool2d(2, stride=2)
        enc_first = self.contract(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, norm_layer=norm_layer)
        enc_second = self.contract(in_channels=out_channels, out_channels=out_channels,
                                   kernel_size=kernel_size, norm_layer=norm_layer)
        dec_first = self.expand(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=kernel_size)
        dec_second = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            head = nn.Conv2d(out_channels, num_classes, kernel_size=1)
            layers = [enc_first, enc_second, submodule, dec_first, dec_second, head]
        elif innermost:
            upsample = nn.ConvTranspose2d(in_channels * 2, in_channels,
                                          kernel_size=2, stride=2)
            layers = [downsample, enc_first, enc_second, upsample]
        else:
            upsample = nn.ConvTranspose2d(in_channels * 2, in_channels, kernel_size=2, stride=2)
            layers = [downsample, enc_first, enc_second, submodule, dec_first, dec_second, upsample]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
        """Conv (padding 1) -> norm -> LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        """Conv (padding 1) -> LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    @staticmethod
    def center_crop(layer, target_width, target_height):
        """Crop the last two axes of a 4D tensor around their centre."""
        _, _, layer_width, layer_height = layer.size()
        start_w = (layer_width - target_width) // 2
        start_h = (layer_height - target_height) // 2
        return layer[:, :, start_w:(start_w + target_width), start_h:(start_h + target_height)]

    def forward(self, x):
        """Run the block; non-outermost blocks append their result to ``x``."""
        result = self.model(x)
        if self.outermost:
            return result
        cropped = self.center_crop(result, x.size()[2], x.size()[3])
        return torch.cat([x, cropped], 1)
initial_filter_size*2, kernel_size, instancenorm=do_instancenorm) 15 | self.contr_2_2 = self.contract(initial_filter_size*2, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm) 16 | # self.pool2 = nn.MaxPool2d(2, stride=2) 17 | 18 | self.contr_3_1 = self.contract(initial_filter_size*2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm) 19 | self.contr_3_2 = self.contract(initial_filter_size*2**2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm) 20 | # self.pool3 = nn.MaxPool2d(2, stride=2) 21 | 22 | self.contr_4_1 = self.contract(initial_filter_size*2**2, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm) 23 | self.contr_4_2 = self.contract(initial_filter_size*2**3, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm) 24 | # self.pool4 = nn.MaxPool2d(2, stride=2) 25 | 26 | self.center = nn.Sequential( 27 | nn.Conv2d(initial_filter_size*2**3, initial_filter_size*2**4, 3, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(initial_filter_size*2**4, initial_filter_size*2**4, 3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.ConvTranspose2d(initial_filter_size*2**4, initial_filter_size*2**3, 2, stride=2), 32 | nn.ReLU(inplace=True), 33 | ) 34 | 35 | self.expand_4_1 = self.expand(initial_filter_size*2**4, initial_filter_size*2**3) 36 | self.expand_4_2 = self.expand(initial_filter_size*2**3, initial_filter_size*2**3) 37 | self.upscale4 = nn.ConvTranspose2d(initial_filter_size*2**3, initial_filter_size*2**2, kernel_size=2, stride=2) 38 | 39 | self.expand_3_1 = self.expand(initial_filter_size*2**3, initial_filter_size*2**2) 40 | self.expand_3_2 = self.expand(initial_filter_size*2**2, initial_filter_size*2**2) 41 | self.upscale3 = nn.ConvTranspose2d(initial_filter_size*2**2, initial_filter_size*2, 2, stride=2) 42 | 43 | self.expand_2_1 = self.expand(initial_filter_size*2**2, initial_filter_size*2) 44 | self.expand_2_2 = self.expand(initial_filter_size*2, initial_filter_size*2) 45 | 
def forward(self, x, enable_concat=True, print_layer_shapes=False):
    """Run the encoder/decoder and return the output of the selected head.

    Args:
        x: Input tensor fed to ``self.contr_1_1``.
        enable_concat: When truthy, the skip connections are concatenated
            unchanged and ``self.final`` produces the output. When falsy,
            every skip tensor is multiplied by 0 before concatenation and
            ``self.output_reconstruction_map`` produces the output.
        print_layer_shapes: Accepted for interface compatibility; it is not
            read anywhere in this method.

    Returns:
        The tensor produced by the selected output layer.
    """
    skip_scale = 1 if enable_concat else 0

    def merge(decoded, skip):
        # Crop the encoder feature map to the decoder's spatial size and
        # append it (scaled by 1 or 0) along the channel axis.
        cropped = self.center_crop(skip, decoded.size()[2], decoded.size()[3])
        return torch.cat([decoded, cropped * skip_scale], 1)

    # Contracting path: two conv blocks per stage, pooling between stages.
    enc1 = self.contr_1_2(self.contr_1_1(x))
    enc2 = self.contr_2_2(self.contr_2_1(self.pool(enc1)))
    enc3 = self.contr_3_2(self.contr_3_1(self.pool(enc2)))
    enc4 = self.contr_4_2(self.contr_4_1(self.pool(enc3)))

    bottleneck = self.center(self.pool(enc4))

    # Expanding path: merge with the matching encoder stage, refine, upscale.
    dec = self.expand_4_2(self.expand_4_1(merge(bottleneck, enc4)))
    dec = self.upscale4(dec)

    dec = self.expand_3_2(self.expand_3_1(merge(dec, enc3)))
    dec = self.upscale3(dec)

    dec = self.expand_2_2(self.expand_2_1(merge(dec, enc2)))
    dec = self.upscale2(dec)

    dec = self.expand_1_2(self.expand_1_1(merge(dec, enc1)))

    if enable_concat:
        head = self.final
    else:
        head = self.output_reconstruction_map
    return head(dec)
def get_voc_pallete(num_classes):
    """Build a flat colour table with three entries per class label.

    For every label ``j`` in ``range(num_classes)`` the label's bits are
    consumed three at a time (bit 0, 1, 2 go to the first, second and third
    channel respectively) and written from the most significant bit (bit 7)
    of each channel downwards.

    Args:
        num_classes: Number of labels to generate colours for.

    Returns:
        A list of length ``3 * num_classes``; entries ``3*j .. 3*j+2`` hold
        the three channel values of label ``j``.
    """
    pallete = []
    for label in range(num_classes):
        channels = [0, 0, 0]
        bit_position = 7
        remaining = label
        while remaining > 0:
            for channel in range(3):
                channels[channel] |= ((remaining >> channel) & 1) << bit_position
            bit_position -= 1
            remaining >>= 3
        pallete.extend(channels)
    return pallete
socre:0.18185657435492578 7 | epoch:94 dice socre:0.020556566622363468 8 | epoch:95 dice socre:0.1985564188573966 9 | epoch:96 dice socre:0.04547603745167831 10 | epoch:97 dice socre:0.0023100010562356723 11 | epoch:98 dice socre:0.0056058855592176215 12 | epoch:99 dice socre:0.04113621264744132 13 | epoch:100 dice socre:0.0001233829410770183 14 | epoch:101 dice socre:0.0430133957448039 15 | epoch:120 dice socre:0.004583200243105888 16 | epoch:90 dice socre:0.0 17 | epoch:91 dice socre:0.0 18 | epoch:92 dice socre:0.0 19 | epoch:93 dice socre:0.0 20 | epoch:94 dice socre:0.0 21 | epoch:95 dice socre:0.0 22 | epoch:96 dice socre:0.0 23 | epoch:97 dice socre:0.0 24 | epoch:98 dice socre:0.0 25 | epoch:99 dice socre:0.0 26 | epoch:100 dice socre:0.0 27 | epoch:101 dice socre:0.0 28 | epoch:102 dice socre:0.0 29 | epoch:103 dice socre:0.0 30 | epoch:104 dice socre:0.0 31 | epoch:105 dice socre:0.0 32 | epoch:106 dice socre:0.0 33 | epoch:107 dice socre:0.0 34 | epoch:108 dice socre:0.0 35 | epoch:109 dice socre:0.0 36 | epoch:110 dice socre:0.0 37 | epoch:111 dice socre:0.0 38 | epoch:112 dice socre:0.0 39 | epoch:113 dice socre:0.0 40 | epoch:114 dice socre:0.0 41 | epoch:115 dice socre:0.0 42 | epoch:116 dice socre:0.0 43 | epoch:117 dice socre:0.0 44 | epoch:118 dice socre:0.0 45 | epoch:119 dice socre:0.0 46 | epoch:120 dice socre:0.0 47 | epoch:90 dice socre:0.0 48 | epoch:91 dice socre:0.0 49 | epoch:92 dice socre:0.0 50 | epoch:90 dice socre:0.0 51 | epoch:93 dice socre:0.0 52 | epoch:91 dice socre:0.0 53 | epoch:94 dice socre:0.0 54 | epoch:95 dice socre:0.0 55 | epoch:92 dice socre:0.0 56 | epoch:96 dice socre:0.0 57 | epoch:93 dice socre:0.0 58 | epoch:97 dice socre:0.0 59 | epoch:94 dice socre:0.0 60 | epoch:98 dice socre:0.0 61 | epoch:95 dice socre:0.0 62 | epoch:99 dice socre:0.0 63 | epoch:96 dice socre:0.0 64 | epoch:100 dice socre:0.0 65 | epoch:97 dice socre:0.0 66 | epoch:101 dice socre:0.0 67 | epoch:98 dice socre:0.0 68 | epoch:102 dice socre:0.0 
69 | epoch:99 dice socre:0.0 70 | epoch:103 dice socre:0.0 71 | epoch:100 dice socre:0.0 72 | epoch:104 dice socre:0.0 73 | epoch:101 dice socre:0.0 74 | epoch:105 dice socre:0.0 75 | epoch:102 dice socre:0.0 76 | epoch:106 dice socre:0.0 77 | epoch:103 dice socre:0.0 78 | epoch:107 dice socre:0.0 79 | epoch:104 dice socre:0.0 80 | epoch:108 dice socre:0.0 81 | epoch:105 dice socre:0.0 82 | epoch:109 dice socre:0.0 83 | epoch:106 dice socre:0.0 84 | epoch:110 dice socre:0.0 85 | epoch:107 dice socre:0.0 86 | epoch:111 dice socre:0.0 87 | epoch:108 dice socre:0.0 88 | epoch:112 dice socre:0.0 89 | epoch:109 dice socre:0.0 90 | epoch:113 dice socre:0.0 91 | epoch:110 dice socre:0.0 92 | epoch:114 dice socre:0.0 93 | epoch:111 dice socre:0.0 94 | epoch:115 dice socre:0.0 95 | epoch:112 dice socre:0.0 96 | epoch:116 dice socre:0.0 97 | epoch:117 dice socre:0.0 98 | epoch:113 dice socre:0.0 99 | epoch:118 dice socre:0.0 100 | epoch:114 dice socre:0.0 101 | epoch:119 dice socre:0.0 102 | epoch:115 dice socre:0.0 103 | epoch:120 dice socre:0.0 104 | epoch:116 dice socre:0.0 105 | epoch:117 dice socre:0.0 106 | epoch:118 dice socre:0.0 107 | epoch:119 dice socre:0.0 108 | epoch:120 dice socre:0.0 109 | epoch:90 dice socre:0.0 110 | epoch:91 dice socre:0.0 111 | epoch:92 dice socre:0.0 112 | epoch:93 dice socre:0.0 113 | epoch:94 dice socre:0.0 114 | epoch:95 dice socre:0.0 115 | epoch:96 dice socre:0.0 116 | epoch:97 dice socre:0.0 117 | epoch:98 dice socre:0.0 118 | epoch:99 dice socre:0.0 119 | epoch:100 dice socre:0.0 120 | epoch:101 dice socre:0.0 121 | epoch:102 dice socre:0.0 122 | epoch:103 dice socre:0.0 123 | epoch:104 dice socre:0.0 124 | epoch:105 dice socre:0.0 125 | epoch:106 dice socre:0.0 126 | epoch:107 dice socre:0.0 127 | epoch:108 dice socre:0.0 128 | epoch:109 dice socre:0.0 129 | epoch:110 dice socre:0.0 130 | epoch:111 dice socre:0.0 131 | epoch:112 dice socre:0.0 132 | epoch:113 dice socre:0.0 133 | epoch:114 dice socre:0.0 134 | epoch:115 dice 
socre:0.0 135 | epoch:116 dice socre:0.0 136 | epoch:117 dice socre:0.0 137 | epoch:118 dice socre:0.0 138 | epoch:119 dice socre:0.0 139 | epoch:120 dice socre:0.0 140 | -------------------------------------------------------------------------------- /run_coseg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | source activate py37 5 | 6 | batch=4 7 | dataset=Hippocampus 8 | #dataset=mmwhs 9 | fold=3 10 | head=mlp 11 | mode=block 12 | temp=0.1 13 | 14 | train_sample=1 15 | 16 | python main_coseg.py --batch_size 4 --dataset Hippocampus \ 17 | --data_folder ./data \ 18 | --learning_rate 0.0001 \ 19 | --epochs 70 \ 20 | --head mlp \ 21 | --mode block\ 22 | --fold 3 \ 23 | --save_freq 1 \ 24 | --print_freq 10 \ 25 | --temp 0.1 \ 26 | --train_sample 1 \ 27 | --pretrained_model_path save/simclr/Hippocampus/b_120_model.pth\ 28 | # --pretrained_model_path save/simclr/Hippocampus/b_80_model.pth \ 29 | # --pretrained_model_path save/simclr/Hippocampus/b_80_model.pth \ 30 | 31 | -------------------------------------------------------------------------------- /run_mix_pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # witten by: Xinrong Hu 3 | 4 | import os 5 | import argparse 6 | import torch 7 | from os.path import exists 8 | 9 | from trixi.util import Config 10 | 11 | from configs.Config import get_config 12 | import configs.Config_mmwhs as config_mmwhs 13 | from datasets.prepare_dataset.preprocessing import preprocess_data 14 | from datasets.prepare_dataset.create_splits import create_splits 15 | from experiments.SegExperiment import SegExperiment 16 | from experiments.MixExperiment import MixExperiment 17 | from datasets.downsanpling_data import downsampling_image 18 | 19 | import datetime 20 | import time 21 | 22 | import matplotlib 23 | import matplotlib.pyplot as plt 24 | 25 | from datasets.prepare_dataset.rearrange_dir import 
def parse_option():
    """Parse the command-line options of the segmentation pipeline script.

    The options are read from ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace with the attributes ``dataset`` (str, default
        ``"hippo"``), ``train_sample`` (float, default 0.4), ``batch_size``
        (int, default 8), ``fold`` (int, default 1), ``saved_model_path``
        (str or None), ``freeze_model`` (bool flag) and ``load_saved_model``
        (bool flag).
    """
    parser = argparse.ArgumentParser("argument for run segmentation pipeline")

    parser.add_argument("--dataset", type=str, default="hippo")
    parser.add_argument("--train_sample", type=float, default=0.4)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("-f", "--fold", type=int, default=1)
    parser.add_argument("--saved_model_path", type=str, default=None)
    # The two help texts below were attached to the wrong flags originally;
    # each flag now carries the description that matches its name.
    parser.add_argument("--freeze_model", action='store_true',
                        help='whether freeze encoder of the segmenter')
    parser.add_argument("--load_saved_model", action='store_true',
                        help="whether load saved model from saved_model_path")

    args = parser.parse_args()
    return args
"hippo" or args.dataset == "Hippocampus": 78 | c = get_config() 79 | else: 80 | exit('the dataset is not supoorted currently') 81 | c.fold = args.fold 82 | c.batch_size = args.batch_size 83 | c.train_sample = args.train_sample 84 | if args.load_saved_model: 85 | c.saved_model_path = os.path.abspath('save') + '/SupCon/mmwhs_models/' \ 86 | + 'SupCon_mmwhs_adam_fold_1_lr_0.0001_decay_0.0001_bsz_4_temp_0.1_train_0.4_block/' \ 87 | + 'ckpt.pth' 88 | 89 | c.saved_model_path = args.saved_model_path 90 | c.freeze = args.freeze_model 91 | training(config=c) 92 | 93 | -------------------------------------------------------------------------------- /run_seg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #source activate py37 4 | #dataset=mmwhs 5 | dataset=hippo 6 | fold=1 7 | train_sample=0.4 8 | method=mix 9 | #model_path=SupCon_Hippocampus_adam_fold_2_lr_0.0001_decay_0.0001_bsz_4_temp_0.1_train_1.0_mlp_stride_pretrained 10 | # notice: when load saved models, remember to check whether true model is loaded 11 | python run_mix_pipeline.py --dataset hippo --train_sample 0.2 --fold 1 --batch_size 8 \ 12 | --load_saved_model 13 | #--saved_model_path data/checkpoint_last.pth.tar\ 14 | -------------------------------------------------------------------------------- /run_seg_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from os.path import exists 5 | 6 | from trixi.util import Config 7 | 8 | from configs.Config import get_config 9 | import configs.Config_mmwhs as config_mmwhs 10 | from datasets.prepare_dataset.preprocessing import preprocess_data 11 | from datasets.prepare_dataset.create_splits import create_splits 12 | from experiments.SegExperiment import SegExperiment 13 | from datasets.downsanpling_data import downsampling_image 14 | 15 | import datetime 16 | import time 17 | 18 | import matplotlib 19 | import 
def testing(config, run_name='20210202-064334_Unet_mmwhs'):
    """Run the test phase of a segmentation experiment from a saved checkpoint.

    The checkpoint settings are written onto ``config`` itself. The original
    code wrote them onto the module-level global ``c``, which only exists
    when the file is executed as a script, so calling this function after an
    import raised ``NameError``.

    Args:
        config: Experiment configuration. ``base_dir`` and ``n_epochs`` are
            read; ``do_load_checkpoint`` and ``checkpoint_dir`` are set.
        run_name: Name of the run directory below ``config.base_dir`` whose
            ``checkpoint/checkpoint_current`` is loaded. Defaults to the
            directory that was previously hard-coded.
    """
    config.do_load_checkpoint = True
    config.checkpoint_dir = config.base_dir + '/' + run_name + '/checkpoint/checkpoint_current'

    exp = SegExperiment(config=config, name='unet_test', n_epochs=config.n_epochs,
                        seed=42, globs=globals())
    exp.run_test(setup=True)
69 | if __name__ == "__main__": 70 | args = parse_option() 71 | if args.dataset == "mmwhs": 72 | c = config_mmwhs.get_config() 73 | elif args.dataset == "hippo" or args.dataset == "Hippocampus": 74 | c = get_config() 75 | else: 76 | exit('the dataset is not supoorted currently') 77 | c.fold = args.fold 78 | c.batch_size = args.batch_size 79 | c.train_sample = args.train_sample 80 | if args.load_saved_model: 81 | c.saved_model_path = os.path.abspath('save') + '/SupCon/mmwhs_models/' \ 82 | + 'SupCon_mmwhs_adam_fold_1_lr_0.0001_decay_0.0001_bsz_4_temp_0.7_train_0.4_block/' \ 83 | + 'ckpt.pth' 84 | 85 | c.saved_model_path = args.saved_model_path 86 | c.freeze = args.freeze_model 87 | # c.learning_rate = 0.00001 88 | training(config=c) 89 | 90 | -------------------------------------------------------------------------------- /run_simclr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate py37 4 | #dataset=mmwhs 5 | dataset=hippo 6 | 7 | 8 | # notice: when load saved models, remember to check whether true model is loaded 9 | python main_simclr.py --batch_size 120 --dataset hippo -e 100 \ 10 | # --load_saved_model \ 11 | 12 | #python main_simclr.py --batch_size 80 --dataset hippo -e 100 13 | # 14 | #python main_simclr.py --batch_size 40 --dataset ${dataset} -e 100 -------------------------------------------------------------------------------- /run_supcon.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate py37 4 | 5 | batch=1 6 | dataset=mmwhs 7 | 8 | CUDA_VISIBLE_DEVICES=1,3 python main_supcon.py --batch_size ${batch} --dataset ${dataset} \ 9 | --data_folder ./data \ 10 | --learning_rate 0.01 \ 11 | --epochs 10 \ 12 | --save_freq 5 \ 13 | --cosine \ 14 | -------------------------------------------------------------------------------- /supcon_loss.py: 
class SupConLoss(nn.Module):
    """Pixel-wise contrastive loss, optionally supervised by per-pixel labels.

    Modified SupCon loss for segmentation: the labels of the different views
    may differ (e.g. after a spatial transformation), so positives are
    determined per pixel.

    Every spatial position of every view is an anchor. Its similarity to all
    other positions is the dot product of the channel vectors divided by
    ``temperature``. With labels, positives are all other positions carrying
    the same label and only non-zero-label anchors contribute. Without
    labels, positives are the positions at the same row index in the other
    views.

    Differences from the previous implementation:
      * tensors are created on ``features.device`` instead of the bare
        ``cuda`` device, so inputs on a non-default GPU work;
      * the log-denominator is computed with ``torch.logsumexp``, so large
        logits no longer overflow to ``inf``/``nan``;
      * anchors without any positive are excluded from the average instead
        of producing 0/0 = ``nan``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        # Kept for interface compatibility; forward() does not read them.
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None):
        """Compute the loss.

        Args:
            features: Tensor documented by the original code as
                ``[bsz, n_views, c, h, w]``.
            labels: Optional tensor documented by the original code as
                ``[bsz, n_views, h, w]``; label 0 is treated as background.

        Returns:
            A scalar tensor. It is 0 when no anchor is both eligible and has
            at least one positive.

        Raises:
            ValueError: If ``features`` has fewer than 3 dimensions.
        """
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        device = features.device

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: (bsz * n_views, c, h, w).
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        # Use every pixel's channel vector as a 1x1 convolution kernel so a
        # single conv2d yields all pairwise dot products.
        kernels = contrast_feature.permute(0, 2, 3, 1)
        kernels = kernels.reshape(-1, contrast_feature.shape[1], 1, 1)
        logits = torch.div(F.conv2d(contrast_feature, kernels), self.temperature)
        logits = logits.permute(1, 0, 2, 3)
        logits = logits.reshape(logits.shape[0], -1)  # (n_anchor, n_anchor)
        n_anchor = logits.shape[0]

        if labels is not None:
            labels = torch.cat(torch.unbind(labels, dim=1), dim=0)
            labels = labels.contiguous().view(-1, 1)
            mask = torch.eq(labels, labels.T).float().to(device)
            # Only non-background pixels act as anchors.
            anchor_weight = (labels.view(-1) != 0).float().to(device)
        else:
            mask = torch.eye(n_anchor // contrast_count).float().to(device)
            mask = mask.repeat(contrast_count, contrast_count)
            anchor_weight = torch.ones(n_anchor, device=device)

        # Mask out self-contrast cases.
        self_mask = torch.eye(n_anchor, dtype=torch.bool, device=device)
        mask = mask.masked_fill(self_mask, 0)

        # log-softmax over all non-self entries, computed stably.
        log_denominator = torch.logsumexp(
            logits.masked_fill(self_mask, float('-inf')), dim=1, keepdim=True)
        log_prob = logits - log_denominator

        # Mean log-likelihood over the positives of each anchor. The clamp
        # only avoids 0/0; anchors without positives are dropped below.
        pos_count = mask.sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count.clamp(min=1)
        loss = -mean_log_prob_pos

        anchor_weight = anchor_weight * (pos_count > 0).to(anchor_weight.dtype)
        return (loss * anchor_weight).sum() / anchor_weight.sum().clamp(min=1)
torch.cuda.memory_allocated() / 1024**2) 105 | for j in range(img_size): 106 | x = features[b:b + 1, :, i:i + 1, j:j + 1] # [1,c, 1, 1, 1] 107 | x_label = labels[b, i, j] + 1 # avoid cases when label=0 108 | if x_label == 1: # ignore background 109 | continue 110 | cos_dst = F.conv2d(features, x) # [2b, 1, 512, 512] 111 | cos_dst = torch.div(cos_dst.squeeze(dim=1), self.temp) 112 | # print("cos_dst:", cos_dst.max(), cos_dst.min()) 113 | self_contrast_dst = torch.div((x * x).sum(), self.temp) 114 | 115 | mask = labels + 1 116 | mask[mask != x_label] = 0 117 | # if mask.sum() < 5: 118 | # print("Not enough same label pixel") 119 | # continue 120 | mask = torch.div(mask, x_label) 121 | numerator = (mask * cos_dst).sum() - self_contrast_dst 122 | denominator = torch.exp(cos_dst).sum() - torch.exp(self_contrast_dst) 123 | # print("denominator:", denominator.item()) 124 | # print("numerator:", numerator.max(), numerator.min()) 125 | loss_tmp = torch.log(denominator) - numerator / (mask.sum() - 1) 126 | if loss_tmp != loss_tmp: 127 | print(numerator.item(), denominator.item()) 128 | 129 | loss.append(loss_tmp) 130 | if len(loss) == 0: 131 | loss = torch.tensor(0).float().to(self.device) 132 | return loss 133 | loss = torch.stack(loss).mean() 134 | return loss 135 | 136 | else: 137 | bsz = features.shape[0] 138 | loss = [] 139 | for b in range(bsz): 140 | # print("Iteration index:", idx, "Batch_size:", b) 141 | tmp_feature = features[b] 142 | for n in range(tmp_feature.shape[0]): 143 | for i in range(img_size): 144 | # print("before ith iteration, the consumption memory is:", torch.cuda.memory_allocated() / 1024**2) 145 | for j in range(img_size): 146 | x = tmp_feature[n:n+1, :, i:i + 1, j:j + 1] # [c, 1, 1, 1] 147 | cos_dst = F.conv2d(tmp_feature, x) # [2b, 1, 512, 512] 148 | cos_dst = torch.div(cos_dst.squeeze(dim=1), self.temp) 149 | # print("cos_dst:", cos_dst.max(), cos_dst.min()) 150 | self_contrast_dst = torch.div((x * x).sum(), self.temp) 151 | 152 | mask = 
class LocalConLoss(nn.Module):
    """Contrastive loss evaluated on a strided sub-sampling of the feature map.

    The feature maps (and labels, when given) are sub-sampled with step
    ``stride`` along both spatial axes to reduce memory consumption and
    running time, then passed to ``SupConLoss``.
    """

    def __init__(self, temperature=0.7, stride=4):
        super(LocalConLoss, self).__init__()
        self.temp = temperature
        # Kept as a public attribute for compatibility; forward() places its
        # results on the device of its input instead.
        self.device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
        self.supconloss = SupConLoss(temperature=self.temp)
        self.stride = stride

    def forward(self, features, labels=None):
        """Compute the loss.

        Args:
            features: Tensor documented by the original code as
                ``[bsz, num_view, c, h, w]``.
            labels: Optional label tensor, indexed here as
                ``[bsz, num_view, h, w]``.

        Returns:
            A scalar tensor. When labels are given and the sub-sampled
            labels sum to 0, a constant 0 on ``features.device`` is returned
            without evaluating the contrastive loss.
        """
        # Resample feature maps to reduce memory consumption and running time.
        features = features[:, :, :, ::self.stride, ::self.stride]

        if labels is None:
            return self.supconloss(features)

        labels = labels[:, :, ::self.stride, ::self.stride]
        if labels.sum() == 0:
            # Previously created on a device chosen by CUDA availability,
            # which could differ from the device the inputs live on.
            return torch.tensor(0).float().to(features.device)

        return self.supconloss(features, labels)
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one tensor of shape ``(1,)`` per entry of ``topk``,
        holding the percentage of rows whose target is among the k
        highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of
            # `pred`, so `.view(-1)` raises for k > 1; `reshape` copies when
            # needed and works for every k.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
images.shape[0] 112 | embeds = model(images) 113 | for i in range(batch_size): 114 | embedding[str(labels[i])].append(embeds[None, i]) 115 | 116 | avg_emb = [] 117 | for key in embedding.keys(): 118 | avg = torch.cat(embedding[key]).mean(0, keepdim=True) 119 | avg_emb.append(avg) 120 | 121 | avg_emb = torch.cat(avg_emb) 122 | torch.save(avg_emb, os.path.join(opt.save_dir, 'embedding.pth')) 123 | 124 | return avg_emb 125 | 126 | 127 | # def get_gpu_memory_map(): 128 | # """Get the current gpu usage. 129 | # 130 | # Returns 131 | # ------- 132 | # usage: dict 133 | # Keys are device ids as integers. 134 | # Values are memory usage as integers in MB. 135 | # """ 136 | # result = subprocess.check_output( 137 | # [ 138 | # 'nvidia-smi', '--query-gpu=memory.used', 139 | # '--format=csv,nounits,noheader' 140 | # ], encoding='utf-8') 141 | # # Convert lines into a dictionary 142 | # gpu_memory = [int(x) for x in result.strip().split('\n')] 143 | # gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 144 | # print(gpu_memory_map) 145 | # return gpu_memory_map -------------------------------------------------------------------------------- /utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/utilities/__init__.py -------------------------------------------------------------------------------- /utilities/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PerPerZXY/BHPC/1d0fd50c497a0f6cfe4c2263c6dc11132f96b08e/utilities/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utilities/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
def maybe_mkdir_p(directory):
    """Create ``directory`` together with any missing parent directories.

    Behaves like ``mkdir -p``: directories that already exist are left
    untouched, and an empty path is a no-op.

    Relative paths are created relative to the current working directory.
    The earlier hand-rolled implementation discarded the first path
    component and re-rooted the remainder under ``/``, so it was only
    correct for absolute POSIX paths.

    Args:
        directory: path of the directory to create.

    Raises:
        OSError: if a component of the path exists but is not a directory,
            or the directory cannot be created (e.g. missing permissions).
    """
    if not directory:
        # Nothing to create; os.makedirs("") would raise FileNotFoundError.
        return
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs(directory, exist_ok=True)