├── .gitignore ├── LICENSE ├── README.md ├── images └── augmentations.png ├── requirements.txt └── simclr ├── __init__.py ├── augmentations.py ├── datasets.py ├── get_dataloader.py ├── linear.py ├── main.py ├── model.py ├── optimisers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /training 2 | /.neptune 3 | /.vscode 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Learning Representations with Contrastive Self-supervised Learning for Histopathology Application
2 | This is the official repo of the paper [Learning Representations with Contrastive Self-supervised Learning for Histopathology Application](https://arxiv.org/abs/2112.05760) by Karin Stacke, Jonas Unger, Claes Lundström and Gabriel Eilertsen, 2021, Journal of Machine Learning for Biomedical Imaging (MELBA) [https://www.melba-journal.org/papers/2022:023.html](https://www.melba-journal.org/papers/2022:023.html).
3 |
4 | The code builds upon the SimCLR implementation from: https://github.com/google-research/simclr.
5 |
6 |
7 | ## Requirements
8 | - Python >= 3.6
9 | - [PyTorch](https://pytorch.org) (version 1.7.0 with cuda 11.0 was used for the experiments in the paper)
10 | - Additional packages listed in `requirements.txt`
11 |
12 |
13 | ## Dataset
14 | This code assumes that images are stored in the `image_dir` folder, and that there is a comma-separated .csv file associated with it.
15 | The .csv file should have the following columns (see the example below):
16 |
17 | * `filename` - Image will be loaded as `image_dir/filename`
18 | * `label` - String label matching the label enum associated with the dataset in question, for example TUMOR or NONTUMOR. Read but ignored during unsupervised training.
19 | * `slide_id` - Id of the slide the patch was taken from, used during inference
20 | * `patch_id` - Optional unique id of the image patch. If missing, `filename` is used instead. Used during inference.
21 |
22 | Dataset csv-files used to generate the results in the paper can be found [here](https://computergraphics.on.liu.se/ssl-pathology/datasets/). These also include patch coordinates, so that the dataset patches can be re-sampled from the original slides.
23 |
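For reference, a minimal dataset .csv (with made-up filenames and ids) could look like:

```
filename,label,slide_id,patch_id
tumor/patch_0001.png,TUMOR,slide_017,slide_017_0001
normal/patch_0002.png,NONTUMOR,slide_042,slide_042_0002
```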
24 | ## Usage
25 |
26 | Please see `./simclr/main.py` or `./simclr/linear.py` for a full list of available arguments/parameters.
27 |
28 | Make sure the repo folder is on the Python path.
29 |
30 | ### Train SimCLR with pathology data
31 |
32 | ```
33 | python main.py
34 | required arguments:
35 | --training_data_csv
36 | --test_data_csv
37 | --validation_data_csv
38 | --data_input_dir
39 | --save_dir
40 |
41 | optional arguments:
42 | --feature_dim Feature dim for latent vector [default value is 128]
43 | --temperature Temperature used in softmax [default value is 0.5]
44 | --batch_size Number of images per GPU in each mini-batch [default value is 512]
45 | --epochs Number of sweeps over the dataset to train [default value is 500]
46 | --dataset Dataset to use [default value is cam]
47 | --lr Starting learning rate [default value is 1e-3]
48 | --use_album Bool to use Albumentations instead of Torchvision as augmentation library [default value is false]
49 | --image_size Image size during training [default value is 224]
50 |
51 | ```
52 |
53 | ### Linear Evaluation
54 |
55 | Discards the MLP projection head and adds a single linear layer. By default, the pre-trained weights are kept frozen and only the added linear layer is trained; use the `--finetune` flag to train all weights. A simplified sketch of this setup is shown below.
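Conceptually, the evaluation model looks like the following simplified sketch of `Net` in `simclr/linear.py` (the class name `LinearEvalNet` is ours; `Model` is the SimCLR network from `simclr/model.py`):

```python
import torch.nn as nn
from simclr.model import Model

class LinearEvalNet(nn.Module):
    def __init__(self, num_classes=2, finetune=False):
        super().__init__()
        self.f = Model(pretrained=False).f      # ResNet-50 encoder; projection head g is discarded
        self.fc = nn.Linear(2048, num_classes)  # single linear classifier on the 2048-d features
        if not finetune:                        # default: keep the pre-trained encoder frozen
            for param in self.f.parameters():
                param.requires_grad = False

    def forward(self, x):
        feature = self.f(x).flatten(start_dim=1)
        return self.fc(feature)
```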
56 |
57 | ```
58 | python simclr/linear.py
59 | required arguments:
60 | --model_path Path to pre-trained SimCLR model
61 | OR --pretrained Use ImageNet as pre-trained model
62 | OR --random No pre-training, random initialization
63 | --training_data_csv
64 | --test_data_csv
65 | --validation_data_csv
66 | --data_input_dir
67 | --save_dir
68 |
69 | optional arguments:
70 | --finetune
71 | --batch_size Number of images per GPU in each mini-batch [default value is 512]
72 | --epochs Number of sweeps over the dataset to train [default value is 100]
73 | --dataset Dataset to use [default value is cam]
74 | --lr Starting learning rate [default value is 1e-3]
75 | --use_album Bool to use Albumentations instead of Torchvision as augmentation library [default value is false]
76 | --image_size Image size during training [default value is 224]
77 | ```
78 | ### Augmentations
79 |
80 | ![drawing](images/augmentations.png)
81 |
82 | Use either [Torchvision](https://pytorch.org/vision/stable/transforms.html) or [Albumentations](https://albumentations.ai/) as transformation library; pass the `--use_album` flag to select Albumentations.
83 |
84 | The following augmentations can be controlled from parameters (see the example image above and the pipeline sketch below):
85 |
86 | ```
87 | --image_size A size smaller than the input image means a random crop of this size is taken during training
88 | --scale 0.95 1.0 Random resized crop with the given scale interval
89 | --rgb_gaussian_blur_p Probability of applying gaussian blur
90 | --rgb_jitter_d Color jitter strength
91 | --rgb_jitter_p Probability of applying color jitter
92 | --rgb_contrast Contrast strength
93 | --rgb_contrast_p Probability of applying contrast
94 | --rgb_grid_distort_p Probability of applying grid distort (Albumentations only)
95 | --rgb_grid_shuffle_p Probability of applying random grid shuffle (Albumentations only)
96 | ```
97 |
98 |
99 |
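As an illustration, with the default arguments the Albumentations training pipeline roughly corresponds to the following sketch (simplified from `album_transforms` in `simclr/augmentations.py`; blur, contrast, grid distort and grid shuffle are omitted here because their default probabilities are 0):

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.RandomResizedCrop(height=224, width=224, scale=(0.2, 1.0)),  # --image_size, --scale
    A.Flip(p=0.5),
    A.ColorJitter(0.8, 0.8, 0.8, 0.2, p=0.8),  # --rgb_jitter_d 1, --rgb_jitter_p 0.8
    A.RandomRotate90(p=0.5),
    A.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ToTensorV2(),
])
```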
100 | ## Trained models
101 |
102 | SimCLR models trained for this paper are available for download. Please see the metadata file of each model for the exact parameters used for training (including augmentation and hyper-parameters).
103 |
104 | Below are the SimCLR models used as pre-trained weights in *Section 5.2 Downstream Performance*.
105 |
106 | | Method | Dataset | Batch Size | Epochs | Augmentations | Download |
107 | | ------ | ---------- | ---------- | ------ | ------------------------------ | ---------------- |
108 | | SimCLR | Camelyon16 | 4x256=1024 | 200 | Original | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210711_1302_simclr_org_1.2/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210711_1302_simclr_org_1.2/metadata_train_cleaned.txt) |
109 | | SimCLR | Camelyon16 | 4x256=1024 | 200 | Base + Scale | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210712_1851_simclr_org_1.2_noblur/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210712_1851_simclr_org_1.2_noblur/metadata_train_cleaned.txt) |
110 | | SimCLR | AIDA-LNSK | 4x256=1024 | 200 | Original | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210915_1110_skin_1.2_org/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210915_1110_skin_1.2_org/metadata_train_cleaned.txt) |
111 | | SimCLR | AIDA-LNSK | 4x256=1024 | 200 | Base + GridDist. + GridShuffle | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20211022_1528_skin_1.2_dist_shuffle/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20211022_1528_skin_1.2_dist_shuffle/metadata_train_cleaned.txt) |
112 | | SimCLR | Multidata | 4x256=1024 | 200 | Original | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210714_0910_simclr_multidata_org/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210714_0910_simclr_multidata_org/metadata_train_cleaned.txt) |
113 | | SimCLR | Multidata | 4x256=1024 | 200 | Base + Scale | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210714_0911_simclr_multidata_org_noblur/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20210714_0911_simclr_multidata_org_noblur/metadata_train_cleaned.txt) |
114 | | SimCLR | Multidata | 4x256=1024 | 200 | Base + GridDist. + GridShuffle | [model](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20211022_1529_simclr_multidata_dist_shuffle/128_0.5_200_256_200_model_200.pth) \| [metadata](https://computergraphics.on.liu.se/ssl-pathology/5_2_models/20211022_1529_simclr_multidata_dist_shuffle/metadata_train_cleaned.txt) |
115 |
116 |
117 |
-------------------------------------------------------------------------------- /images/augmentations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/k-stacke/ssl-pathology/f09b9475ccbc3f4b53491bd08f899fc49c339400/images/augmentations.png -------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | matplotlib
4 | pillow
5 | tqdm
6 | lmdb
7 | albumentations
8 | scikit-image # imported by simclr/augmentations.py
9 | h5py # imported by simclr/datasets.py
-------------------------------------------------------------------------------- /simclr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/k-stacke/ssl-pathology/f09b9475ccbc3f4b53491bd08f899fc49c339400/simclr/__init__.py -------------------------------------------------------------------------------- /simclr/augmentations.py: --------------------------------------------------------------------------------
1 |
2 | import random
3 | import numpy as np
4 | from PIL import Image, ImageFilter
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import transforms
8 | import torchvision.transforms.functional as F
9 |
10 | from skimage.color import rgb2hed
11 | import albumentations as A
12 | from albumentations.pytorch import ToTensorV2
13 |
14 |
15 | class GaussianBlur(object):
16 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709.
17 | Borrowed from the MoCo implementation."""
18 |
19 | def __init__(self, sigma=[.1, 2.]):
20 | self.sigma = sigma
21 |
22 | def __call__(self, x):
23 | sigma = random.uniform(self.sigma[0], self.sigma[1])
24 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
25 | return x
26 |
27 | class FixedRandomRotation:
28 | """Rotate by one of the given angles."""
29 | def __init__(self, angles):
30 | self.angles = angles
31 |
32 | def __call__(self, x):
33 | angle = random.choice(self.angles)
34 | return transforms.functional.rotate(x, angle)
35 |
36 | class Denormalize(object):
37 | def __init__(self, mean, std, inplace=False):
38 | self.mean = mean
39 | self.demean = [-m/s for m, s in zip(mean, std)]
40 | self.std = std
41 | self.destd = [1/s for s in std]
42 | self.inplace = inplace
43 |
44 | def __call__(self, tensor):
45 | tensor = F.normalize(tensor, self.demean, self.destd, self.inplace)
46 | # clamp to get rid of numerical errors
47 | return torch.clamp(tensor, 0.0, 1.0)
48 |
49 |
50 | class AlbumentationsTransform:
51 | """Wrapper for Albumentations transforms"""
52 | def __init__(self, aug):
53 | self.aug = aug
54 |
55 | def __call__(self, img):
56 | aug_img = self.aug(image=np.array(img))['image']
57 | return aug_img
58 |
59 |
60 |
61 | def torchvision_transforms(eval=False, aug=None):
62 |
63 | trans = []
64 |
65 | if aug["resize"]:
66 | trans.append(transforms.Resize(aug["resize"]))
67 |
68 | if aug["randcrop"] and aug["scale"] and not eval:
69 | trans.append(transforms.RandomResizedCrop(aug["randcrop"], scale=aug["scale"]))
70 |
71 | if aug["randcrop"] and eval:
72 | trans.append(transforms.CenterCrop(aug["randcrop"]))
73 |
74 | if aug["flip"] and not eval:
75 | trans.append(transforms.RandomHorizontalFlip(p=0.5))
76 | trans.append(transforms.RandomVerticalFlip(p=0.5))
77 |
78 | if aug["jitter_d"] and not eval:
79 | trans.append(transforms.RandomApply(
80 | [transforms.ColorJitter(0.8*aug["jitter_d"], 0.8*aug["jitter_d"], 0.8*aug["jitter_d"], 0.2*aug["jitter_d"])],
81 | p=aug["jitter_p"]))
82 |
83 | if aug["gaussian_blur"] and not eval:
84 | trans.append(transforms.RandomApply([GaussianBlur([.1, 2.])], p=aug["gaussian_blur"]))
85 |
86 | if aug["rotation"] and not eval:
87 | # rotation_transform = FixedRandomRotation(angles=[0, 90, 180, 270])
88 | trans.append(FixedRandomRotation(angles=[0, 90, 180, 270]))
89 |
90 | if aug["grayscale"]:
91 | trans.append(transforms.Grayscale())
92 | trans.append(transforms.ToTensor())
93 | trans.append(transforms.Normalize(mean=aug["bw_mean"], std=aug["bw_std"]))
94 | elif aug["mean"]:
95 | trans.append(transforms.ToTensor())
96 | trans.append(transforms.Normalize(mean=aug["mean"], std=aug["std"]))
97 | else:
98 | trans.append(transforms.ToTensor())
99 |
100 | # trans = transforms.Compose(trans)
101 | return trans
102 |
103 | def album_transforms(eval=False, aug=None):
104 | trans = []
105 |
106 | if aug["resize"]:
107 | trans.append(A.Resize(aug["resize"], aug["resize"], always_apply=True))
108 |
109 | if aug["randcrop"] and not eval:
110 | #trans.append(A.PadIfNeeded(min_height=aug["randcrop"], min_width=aug["randcrop"]))
111 | trans.append(A.RandomResizedCrop(width=aug["randcrop"], height=aug["randcrop"], scale=aug["scale"]))
112 |
113 | if aug["randcrop"] and eval:
114 | #trans.append(A.PadIfNeeded(min_height=aug["randcrop"], min_width=aug["randcrop"]))
115 | trans.append(A.CenterCrop(width=aug["randcrop"], height=aug["randcrop"]))
116 |
117 | if aug["flip"] and not eval:
118 | trans.append(A.Flip(p=0.5))
119 | #trans.append(A.HorizontalFlip(p=0.5))
120 |
121 | if aug["jitter_d"] and not eval:
122 | trans.append(A.ColorJitter(0.8*aug["jitter_d"], 0.8*aug["jitter_d"], 0.8*aug["jitter_d"], 0.2*aug["jitter_d"],
123 | p=aug["jitter_p"]))
124 |
125 | if aug["gaussian_blur"] and not eval:
126 | trans.append(A.GaussianBlur(blur_limit=(3,7), sigma_limit=(0.1, 2), p=aug["gaussian_blur"]))
127 |
128 | if aug["rotation"] and not eval:
129 | trans.append(A.RandomRotate90(p=0.5))
130 |
131 | if aug["mean"]:
132 | trans.append(A.Normalize(mean=aug["mean"], std=aug["std"]))
133 |
134 | # Pathology specific augmentation
135 | if aug["grid_distort"] and not eval:
136 |
trans.append(A.GridDistortion(num_steps=9, distort_limit=0.2, interpolation=1, border_mode=2, p=aug["grid_distort"])) 137 | if aug["contrast"] and not eval: 138 | trans.append(A.RandomContrast(limit=aug["contrast"], p=aug["contrast_p"])) 139 | if aug["grid_shuffle"] and not eval: 140 | trans.append(A.RandomGridShuffle(grid=(3, 3), p=aug["grid_shuffle"])) 141 | 142 | trans.append(ToTensorV2()) 143 | 144 | return trans 145 | 146 | def get_rgb_transforms(opt, eval=False): 147 | aug = { 148 | "resize": None, 149 | "randcrop": opt.image_size, 150 | "scale": opt.scale, 151 | "flip": True, 152 | "jitter_d": opt.rgb_jitter_d, 153 | "jitter_p": opt.rgb_jitter_p, 154 | "grayscale": False, 155 | "gaussian_blur": opt.rgb_gaussian_blur_p, 156 | "rotation": True, 157 | "contrast": opt.rgb_contrast, 158 | "contrast_p": opt.rgb_contrast_p, 159 | "grid_distort": opt.rgb_grid_distort_p, 160 | "grid_shuffle": opt.rgb_grid_shuffle_p, 161 | "mean": [0.4914, 0.4822, 0.4465], # values for train+unsupervised combined 162 | "std": [0.2023, 0.1994, 0.2010], 163 | "bw_mean": [0.4120], # values for train+unsupervised combined 164 | "bw_std": [0.2570], 165 | } 166 | if opt.use_album: 167 | return transforms.Compose([AlbumentationsTransform(A.Compose(album_transforms(eval=eval, aug=aug)))]) 168 | 169 | return transforms.Compose(torchvision_transforms(eval=eval, aug=aug)) 170 | 171 | 172 | def get_transforms(opt, eval=False): 173 | return get_rgb_transforms(opt, eval) 174 | 175 | 176 | -------------------------------------------------------------------------------- /simclr/datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | from PIL import Image 3 | import lmdb 4 | import h5py 5 | import numpy as np 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision 10 | from torchvision.transforms import transforms 11 | 12 | class ImagePatchesDataset(Dataset): 13 | def __init__(self, opt, dataframe, image_dir, transform=None, label_enum=None): 14 | self.opt = opt 15 | self.dataframe = dataframe 16 | self.image_dir = image_dir 17 | self.transform = transform 18 | self.image_size = opt.image_size 19 | 20 | self.label_enum = {'TUMOR': 1, 'NONTUMOR': 0} if label_enum is None else label_enum 21 | print(self.label_enum) 22 | 23 | def __len__(self): 24 | return len(self.dataframe.index) 25 | 26 | 27 | def get_views(self, image): 28 | pos_1 = self.transform(image) 29 | pos_2 = torch.zeros_like(pos_1) if self.opt.train_supervised else self.transform(image) 30 | return pos_1, pos_2 31 | 32 | 33 | def __getitem__(self, index): 34 | row = self.dataframe.iloc[index] 35 | path = f"{self.image_dir}/{row.filename}" 36 | try: 37 | image = Image.open(path) # pil image 38 | except IOError: 39 | print(f"could not open {path}") 40 | return None 41 | 42 | pos_1, pos_2 = self.get_views(image) 43 | 44 | label = self.label_enum[row.label] 45 | 46 | try: 47 | id_ = row.patch_id 48 | except AttributeError: 49 | id_ = row.filename 50 | return pos_1, pos_2, label, id_, row.slide_id 51 | 52 | 53 | class LmdbDataset(torch.utils.data.Dataset): 54 | def __init__(self, lmdb_path, transform): 55 | self.cursor_access = False 56 | self.lmdb_path = lmdb_path 57 | self.image_dimensions = (224, 224, 3) # size of files in lmdb 58 | self.transform = transform 59 | 60 | self._init_db() 61 | 62 | def __len__(self): 63 | return self.length 64 | 65 | def _init_db(self): 66 | num_readers = 999 67 | 68 | self.env = lmdb.open(self.lmdb_path, 69 | max_readers=num_readers, 70 | readonly=1, 
lock=0,
72 | readahead=0,
73 | meminit=0)
74 |
75 | self.txn = self.env.begin(write=False)
76 | self.cursor = self.txn.cursor()
77 |
78 | self.length = self.txn.stat()['entries']
79 | print('Generating keys to lmdb dataset, this takes a while...')
80 | self.keys = [key for key, _ in self.txn.cursor()] # not so fast...
81 |
82 | def close(self):
83 | self.env.close()
84 |
85 | def __getitem__(self, index):
86 | """Sequential cursor access in lmdb is much faster than random access."""
87 | if self.cursor_access:
88 | if not self.cursor.next():
89 | self.cursor.first()
90 | image = self.cursor.value()
91 | else:
92 | image = self.txn.get(self.keys[index])
93 |
94 | image = np.frombuffer(image, dtype=np.uint8)
95 | image = image.reshape(self.image_dimensions)
96 | image = Image.fromarray(image)
97 |
98 | pos_1 = self.transform(image)
99 | pos_2 = self.transform(image)
100 |
101 | return pos_1, pos_2
102 |
103 |
-------------------------------------------------------------------------------- /simclr/get_dataloader.py: --------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import pandas as pd
5 |
6 | import torch
7 | import torchvision
8 | from torchvision.transforms import transforms
9 |
10 | from torch.utils.data import WeightedRandomSampler
11 |
12 | from augmentations import get_transforms, torchvision_transforms, album_transforms, AlbumentationsTransform, get_rgb_transforms
13 | from datasets import ImagePatchesDataset, LmdbDataset
14 | import albumentations as A
15 |
16 |
17 | def get_dataloader(opt):
18 | if opt.dataset == 'cam' or opt.dataset == 'skin':
19 | train_loader, train_dataset, val_loader, val_dataset, test_loader, test_dataset = get_camelyon_dataloader(
20 | opt
21 | )
22 | elif opt.dataset == 'multidata':
23 | train_loader, train_dataset, val_loader, val_dataset, test_loader, test_dataset = get_multidata_dataloader(opt)
24 | else: raise Exception(f"Invalid dataset option: {opt.dataset}")
25 |
26 | return (
27 | train_loader,
28 | train_dataset,
29 | val_loader,
30 | val_dataset,
31 | test_loader,
32 | test_dataset,
33 | )
34 |
35 |
36 | def get_weighted_sampler(dataset, num_samples):
37 | df = dataset.dataframe
38 | # Get the number of samples per label; weight = 1 / number of samples
39 | class_weights = { row.label: 1/row[0] for _, row in df.groupby(['label']).size().reset_index().iterrows()}
40 | print(class_weights)
41 | # Set weights per sample in dataset
42 | weights = [class_weights[row.label] for _, row in df.iterrows()]
43 | return WeightedRandomSampler(weights=weights, num_samples=num_samples)
44 |
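# --- Illustrative sketch (added for documentation, not part of the original code) ---
# With a dataframe holding 90 NONTUMOR and 10 TUMOR patches, class_weights becomes
# {'NONTUMOR': 1/90, 'TUMOR': 1/10}, so batches drawn through the sampler are
# roughly class-balanced:
#   sampler = get_weighted_sampler(dataset, num_samples=len(dataset))
#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=32)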
45 | def clean_data(img_dir, dataframe):
46 | """ Remove rows whose image file is missing or empty """
47 | for idx, row in dataframe.iterrows():
48 | if not os.path.isfile(f'{img_dir}/{row.filename}') or (os.stat(f'{img_dir}/{row.filename}').st_size == 0):
49 | print(f"Removing non-existing file from dataset: {img_dir}/{row.filename}")
50 | dataframe = dataframe.drop(idx)
51 | return dataframe
52 |
53 |
54 | def get_dataframes(opt):
55 | if os.path.isfile(opt.training_data_csv):
56 | print("reading csv file: ", opt.training_data_csv)
57 | train_df = pd.read_csv(opt.training_data_csv)
58 | else:
59 | raise Exception(f'Cannot find file: {opt.training_data_csv}')
60 |
61 | if os.path.isfile(opt.test_data_csv):
62 | print("reading csv file: ", opt.test_data_csv)
63 | test_df = pd.read_csv(opt.test_data_csv)
64 | else:
65 | raise Exception(f'Cannot find file: {opt.test_data_csv}')
66 |
67 | if opt.trainingset_split:
68 | # Split train_df into train and val, on slide level
69 | slide_ids = train_df.slide_id.unique()
70 | random.shuffle(slide_ids)
71 | train_req_ids = []
72 | valid_req_ids = []
73 | # Take the first share of slides for training, the rest for validation
74 | training_size = int(len(slide_ids)*opt.trainingset_split)
75 | validation_size = len(slide_ids) - training_size
76 | train_req_ids.extend([slide_id for slide_id in slide_ids[:training_size]]) # take first
77 | valid_req_ids.extend([
78 | slide_id for slide_id in slide_ids[training_size:training_size+validation_size]]) # take last
79 |
80 | print("train / valid / total")
81 | print(f"{len(train_req_ids)} / {len(valid_req_ids)} / {len(slide_ids)}")
82 |
83 | val_df = train_df[train_df.slide_id.isin(valid_req_ids)] # First, take the slides for validation
84 | train_df = train_df[train_df.slide_id.isin(train_req_ids)] # Update train_df
85 |
86 | else:
87 | if os.path.isfile(opt.validation_data_csv):
88 | print("reading csv file: ", opt.validation_data_csv)
89 | val_df = pd.read_csv(opt.validation_data_csv)
90 | else:
91 | raise Exception(f'Cannot find file: {opt.validation_data_csv}')
92 |
93 | if opt.balanced_training_set:
94 | print('Use uniform training set')
95 | samples_to_take = train_df.groupby('label').size().min()
96 | train_df = pd.concat([train_df[train_df.label == label].sample(samples_to_take) for label in train_df.label.unique()])
97 |
98 | if opt.balanced_validation_set:
99 | print('Use uniform validation set')
100 | samples_to_take = val_df.groupby('label').size().min()
101 | val_df = pd.concat([val_df[val_df.label == label].sample(samples_to_take) for label in val_df.label.unique()])
102 |
103 | print('Use uniform test set')
104 | samples_to_take = test_df.groupby('label').size().min()
105 | test_df = pd.concat([test_df[test_df.label == label].sample(samples_to_take) for label in test_df.label.unique()])
106 |
107 | if not opt.train_supervised:
108 | val_df = val_df.sample(1000)
109 | test_df = test_df.sample(1000)
110 |
111 | if not opt.dataset == 'patchcam':
112 | train_df = clean_data(opt.data_input_dir, train_df)
113 | val_df = clean_data(opt.data_input_dir, val_df)
114 | test_df = clean_data(opt.data_input_dir, test_df)
115 |
116 | return train_df, val_df, test_df
117 |
118 |
119 |
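# Note on the split above (added for clarity): splitting on slide_id rather than on
# individual patches keeps all patches from one slide on the same side of the split,
# avoiding leakage of near-duplicate patches between train and validation. For
# example, with --trainingset_split 0.8 and 10 unique slides, patches from 8 slides
# go to training and patches from the remaining 2 slides go to validation.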
120 | def get_camelyon_dataloader(opt):
121 | base_folder = opt.data_input_dir
122 | print('opt.data_input_dir: ', opt.data_input_dir)
123 | print('opt.data_input_dir_test: ', opt.data_input_dir_test)
124 |
125 | train_df, val_df, test_df = get_dataframes(opt)
126 |
127 | print("training patches: ", train_df.groupby('label').size())
128 | print("Validation patches: ", val_df.groupby('label').size())
129 | print("Test patches: ", test_df.groupby('label').size())
130 |
131 | print("Saving training/val set to file")
132 | train_df.to_csv(f'{opt.log_path}/training_patches.csv', index=False)
133 | val_df.to_csv(f'{opt.log_path}/val_patches.csv', index=False)
134 |
135 | transform_train = get_transforms(opt, eval=False)
136 | transform_valid = get_transforms(opt, eval=True if opt.train_supervised else False) # keep augmentations for validation during SSL training
137 | transform_test = get_transforms(opt, eval=True)
138 |
139 | if opt.dataset == 'cam':
140 | train_dataset = ImagePatchesDataset(opt, train_df, image_dir=base_folder, transform=transform_train)
141 | val_dataset = ImagePatchesDataset(opt, val_df, image_dir=base_folder, transform=transform_valid)
142 | test_dataset = ImagePatchesDataset(opt, test_df, image_dir=opt.data_input_dir_test, transform=transform_test)
143 | elif opt.dataset == 'skin':
144 | label_enum = {
145 | 'normal_dermis': 0,
146 | 'normal_epidermis': 1,
147 | 'normal_skinapp': 2,
148 | 'normal_subcut': 3,
149 | 'abnormal': 4,
150 | }
151 | train_dataset = ImagePatchesDataset(opt, train_df, image_dir=base_folder, transform=transform_train, label_enum=label_enum)
152 | val_dataset = ImagePatchesDataset(opt, val_df, image_dir=base_folder, transform=transform_valid, label_enum=label_enum)
153 | test_dataset = ImagePatchesDataset(opt, test_df, image_dir=opt.data_input_dir_test, transform=transform_test, label_enum=label_enum)
154 |
155 | # Weighted sampler to handle class imbalance
156 | print('Weighted validation sampler')
157 | val_sampler = get_weighted_sampler(val_dataset, num_samples=len(val_dataset))
158 |
159 | if opt.train_supervised:
160 | print('Weighted training sampler')
161 | train_sampler = get_weighted_sampler(train_dataset, num_samples=len(train_dataset))
162 |
163 | # default dataset loaders
164 | train_loader = torch.utils.data.DataLoader(
165 | train_dataset,
166 | batch_size=opt.batch_size_multiGPU,
167 | shuffle=True,
168 | sampler=None,
169 | num_workers=opt.num_workers,
170 | drop_last=True,
171 | )
172 |
173 | if opt.train_supervised:
174 | train_loader = torch.utils.data.DataLoader(
175 | train_dataset,
176 | batch_size=opt.batch_size_multiGPU,
177 | sampler=train_sampler,
178 | shuffle=True if train_sampler is None else False,
179 | num_workers=opt.num_workers,
180 | drop_last=True,
181 | worker_init_fn=lambda id: np.random.seed((torch.initial_seed() + id) % 2**32) # keep the numpy seed within its valid range and distinct per worker
182 | )
183 |
184 | val_loader = torch.utils.data.DataLoader(
185 | val_dataset,
186 | batch_size=opt.batch_size_multiGPU//2,
187 | sampler=val_sampler,
188 | num_workers=opt.num_workers,
189 | drop_last=True,
190 | )
191 |
192 | test_loader = torch.utils.data.DataLoader(
193 | test_dataset,
194 | batch_size=opt.batch_size_multiGPU//2,
195 | shuffle=False,
196 | num_workers=opt.num_workers,
197 | drop_last=True,
198 | )
199 |
200 | return (
201 | train_loader,
202 | train_dataset,
203 | val_loader,
204 | val_dataset,
205 | test_loader,
206 | test_dataset,
207 | )
208 |
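# --- Illustrative usage (added for documentation; `opt` is assumed to be the
# argparse namespace built in simclr/main.py or simclr/linear.py) ---
#   train_loader, train_dataset, val_loader, *_ = get_dataloader(opt)
#   for pos_1, pos_2, label, patch_id, slide_id in train_loader:
#       ...  # pos_1 and pos_2 are two augmented views of the same patch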
209 | def get_multidata_dataloader(opt):
210 | '''
211 | Loads a pathology dataset from lmdb files, applying augmentations.
212 | Only supports pre-training, as no labels are available.
213 | args:
214 | opt (dict): Program/commandline arguments.
215 | Returns:
216 | dataloaders (): pretrain dataloaders.
217 | '''
218 |
219 | # Base train and test augmentations (kept for reference; get_transforms below builds the actual pipeline)
220 | aug = {
221 | "multidata": {
222 | "resize": None,
223 | "randcrop": None,
224 | "scale": opt.scale,
225 | "flip": True,
226 | "jitter_d": opt.rgb_jitter_d,
227 | "jitter_p": opt.rgb_jitter_p,
228 | "grayscale": opt.grayscale,
229 | "gaussian_blur": opt.rgb_gaussian_blur_p,
230 | "contrast": opt.rgb_contrast,
231 | "contrast_p": opt.rgb_contrast_p,
232 | "grid_distort": opt.rgb_grid_distort_p,
233 | "grid_shuffle": opt.rgb_grid_shuffle_p,
234 | "rotation": True,
235 | "mean": [0.4914, 0.4822, 0.4465], # values for train+unsupervised combined
236 | "std": [0.2023, 0.1994, 0.2010],
237 | "bw_mean": [0.4120], # values for train+unsupervised combined
238 | "bw_std": [0.2570],
239 | },
240 | }
241 |
242 | #transform_train = transforms.Compose(torchvision_transforms(eval=False, aug=aug['multidata']))
243 | transform_train = get_transforms(opt, eval=False)
244 |
245 | train_dataset = LmdbDataset(lmdb_path=opt.data_input_dir,
246 | transform=transform_train)
247 |
248 | train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=opt.num_workers,
249 | pin_memory=True, drop_last=True,
250 | shuffle=True,
251 | batch_size=opt.batch_size_multiGPU)
252 |
253 | return (
254 | train_loader,
255 | train_dataset,
256 | None,
257 | None,
258 | None,
259 | None,
260 | )
261 |
-------------------------------------------------------------------------------- /simclr/linear.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import random
5 | import numpy as np
6 |
7 | import pandas as pd
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from torch.nn.modules.loss import _WeightedLoss
13 | from torch.utils.data import DataLoader
14 | from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
15 | from tqdm import tqdm
16 |
17 | from simclr.utils import distribute_over_GPUs, validate_arguments
18 | from simclr.model import Model, Identity
19 | from simclr.get_dataloader import get_dataloader
20 |
21 | torch.backends.cudnn.benchmark=True
22 | class Net(nn.Module):
23 | def __init__(self, opt):
24 | super(Net, self).__init__()
25 |
26 | # Load pre-trained model
27 | base_model = Model(pretrained=opt.pretrained)
28 | if (not opt.random) and (not opt.pretrained):
29 | print('Loading model from ', opt.model_path)
30 | base_model.load_state_dict(torch.load(opt.model_path, map_location=opt.device.type)['model'], strict=True)
31 |
32 | self.f = base_model.f
33 |
34 | # classifier
35 | self.fc = nn.Linear(opt.output_dims, opt.num_classes, bias=True)
36 |
37 | def forward(self, x):
38 | x = self.f(x)
39 | feature = torch.flatten(x, start_dim=1)
40 | out = self.fc(feature)
41 | return out
42 |
43 |
44 | # train or test for one epoch
45 | def train_val(net, data_loader, train_optimizer):
46 | is_train = train_optimizer is not None
47 | net.eval() # kept in eval mode even while training: in the default (frozen) setting only the linear layer is updated
48 | #net.train() if is_train else net.eval()
49 |
50 | total_loss, total_correct, total_num, data_bar = 0.0, 0.0, 0, tqdm(data_loader)
51 |
52 | all_preds, all_labels, all_slides, all_outputs0, all_outputs1, all_patches = [], [], [], [], [], []
53 |
54 | with (torch.enable_grad() if is_train else torch.no_grad()):
55 | for data, _, target, patch_id, slide_id in data_bar:
56 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
57 | out = net(data)
58 | loss = loss_criterion(out, target)
59 |
60 | if is_train:
61 | train_optimizer.zero_grad()
62 | loss.backward()
63 | train_optimizer.step()
64 |
65 | _, preds = torch.max(out.data, 1)
66 |
67 | all_preds.extend(preds.cpu().numpy())
68 | all_labels.extend(target.cpu().data.numpy())
69 | all_patches.extend(patch_id)
70 | all_slides.extend(slide_id)
71 |
72 | probs = torch.nn.functional.softmax(out.data, dim=1).cpu().numpy()
73 | all_outputs0.extend(probs[:, 0])
74 | all_outputs1.extend(probs[:, 1])
75 |
76 | total_num += data.size(0)
77 | total_loss += loss.item() * data.size(0)
78 | prediction = torch.argsort(out, dim=-1, descending=True)
79 | total_correct += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
80 |
81 | data_bar.set_description(f'{"Train" if is_train else "Test"} Epoch: [{epoch}/{epochs}] Loss: {total_loss / total_num:.4f} ACC: {total_correct / total_num * 100:.2f}% ')
82 |
83 |
84 | df = pd.DataFrame({
85 | 'label': all_labels,
86 | 'prediction': all_preds,
87 | 'slide_id': all_slides,
88 | 'patch_id': all_patches,
89 | 'probabilities_0': all_outputs0,
90 | 'probabilities_1': all_outputs1,
91 | })
92 |
93 | return total_loss / total_num, total_correct / total_num * 100, df
94 |
95 |
96 | if __name__ == '__main__':
97 | parser = argparse.ArgumentParser(description='Linear Evaluation')
98 | group_modelset = parser.add_mutually_exclusive_group(required=True)
99 | group_modelset.add_argument('--model_path', type=str, default='results/128_0.5_512_500_model.pth',
100 | help='The pretrained model path')
101 | group_modelset.add_argument("--random", action="store_true", default=False, help="No pre-training, use random weights")
102 | group_modelset.add_argument("--pretrained", action="store_true", default=False, help="Use Imagenet pretrained Resnet")
103 |
104 | parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch')
105 | parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train')
106 |
107 | parser.add_argument('--training_data_csv', required=True, type=str, help='Path to file to use to read training data')
108 | parser.add_argument('--test_data_csv', required=True, type=str, help='Path to file to use to read test data')
109 | # For validation set, need to specify either csv or train/val split ratio
110 | group_validationset = parser.add_mutually_exclusive_group(required=True)
111 | group_validationset.add_argument('--validation_data_csv', type=str, help='Path to file to use to read validation data')
112 | group_validationset.add_argument('--trainingset_split', type=float, help='If not none, training csv will be split in train/val.
Value between 0-1') 113 | parser.add_argument("--num_classes", type=int, default=2, help="Number of classes") 114 | 115 | parser.add_argument('--dataset', choices=['cam', 'patchcam', 'cam_rgb_hed', 'ovary', 'skin'], default='cam', type=str, help='Dataset') 116 | parser.add_argument('--data_input_dir', type=str, help='Base folder for images') 117 | parser.add_argument('--data_input_dir_test', type=str, required=False, help='Base folder for images') 118 | parser.add_argument('--save_dir', type=str, help='Path to save log') 119 | parser.add_argument('--save_after', type=int, default=1, help='Save model after every Nth epoch, default every epoch') 120 | parser.add_argument("--balanced_validation_set", action="store_true", default=False, help="Equal size of classes in validation AND test set",) 121 | 122 | parser.add_argument("--finetune", action="store_true", default=False, help="If true, pre-trained model weights will not be frozen.") 123 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') 124 | parser.add_argument('--weight_decay', type=float, default=1e-6, help='Weight decay (l2 reg)') 125 | parser.add_argument("--model_to_save", choices=['best', 'latest'], default='latest', type=str, help='Save latest or best (based on val acc)') 126 | parser.add_argument('--seed', type=int, default=44, help='seed') 127 | 128 | 129 | parser.add_argument("--use_album", action="store_true", default=False, help="use Albumentations as augmentation lib",) 130 | parser.add_argument("--balanced_training_set", action="store_true", default=False, help="Equal size of classes in train - SUPERVISED!") 131 | 132 | 133 | parser.add_argument("--optimizer", type=str, default='adam', choices=['adam', 'sgd'], help="Choice of optimizer") 134 | 135 | # Common augmentations 136 | parser.add_argument("--image_size", type=int, default=224) 137 | parser.add_argument("--scale", nargs=2, type=float, default=[0.2, 1.0]) 138 | # RGB augmentations 139 | parser.add_argument("--rgb_gaussian_blur_p", type=float, default=0, help="probability of using gaussian blur (only on rgb)" ) 140 | parser.add_argument("--rgb_jitter_d", type=float, default=1, help="color jitter 0.8*d, val 0.2*d (only on rgb)" ) 141 | parser.add_argument("--rgb_jitter_p", type=float, default=0.8, help="probability of using color jitter(only on rgb)" ) 142 | parser.add_argument("--rgb_contrast", type=float, default=0.2, help="value of contrast (rgb only)") 143 | parser.add_argument("--rgb_contrast_p", type=float, default=0, help="prob of using contrast (rgb only)") 144 | parser.add_argument("--rgb_grid_distort_p", type=float, default=0, help="probability of using grid distort (only on rgb)" ) 145 | parser.add_argument("--rgb_grid_shuffle_p", type=float, default=0, help="probability of using grid shuffle (only on rgb)" ) 146 | 147 | 148 | opt = validate_arguments(parser.parse_args()) 149 | 150 | 151 | opt.output_dims = 2048 152 | 153 | 154 | is_windows = True if os.name == 'nt' else False 155 | opt.num_workers = 0 if is_windows else 16 156 | 157 | opt.train_supervised = True 158 | opt.grayscale = False 159 | 160 | if not os.path.exists(opt.save_dir): 161 | os.makedirs(opt.save_dir, exist_ok=True) 162 | opt.log_path = opt.save_dir 163 | 164 | # Write the parameters used to run experiment to file 165 | with open(f'{opt.log_path}/metadata_train.txt', 'w') as metadata_file: 166 | metadata_file.write(json.dumps(vars(opt))) 167 | 168 | opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 169 | print('Device:', opt.device) 170 | 
171 | seed = opt.seed
172 | random.seed(seed)
173 | np.random.seed(seed)
174 | torch.manual_seed(seed)
175 | torch.cuda.manual_seed(seed)
176 | torch.cuda.manual_seed_all(seed)
177 |
178 | model_path, batch_size, epochs = opt.model_path, opt.batch_size, opt.epochs
179 |
180 | model = Net(opt)
181 | model, num_GPU = distribute_over_GPUs(opt, model)
182 |
183 | train_loader, train_data, val_loader, val_data, test_loader, test_data = get_dataloader(opt)
184 |
185 | if not opt.finetune:
186 | for param in model.module.f.parameters():
187 | param.requires_grad = False
188 |
189 | if opt.optimizer == 'adam':
190 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
191 | elif opt.optimizer == 'sgd':
192 | optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay,
193 | momentum=0.9, nesterov=True)
194 |
195 | scheduler = CosineAnnealingLR(optimizer, opt.epochs)
196 |
197 | loss_criterion = nn.CrossEntropyLoss()
198 | results = {'train_loss': [], 'train_acc': [],
199 | 'val_loss': [], 'val_acc': []}
200 |
201 | best_acc = 0.0
202 | for epoch in range(1, epochs + 1):
203 | train_loss, train_acc, _ = train_val(model, train_loader, optimizer)
204 | results['train_loss'].append(train_loss)
205 | results['train_acc'].append(train_acc)
206 | val_loss, val_acc, _ = train_val(model, val_loader, None)
207 | results['val_loss'].append(val_loss)
208 | results['val_acc'].append(val_acc)
209 |
210 | scheduler.step()
211 |
212 | # save statistics
213 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
214 | data_frame.to_csv(f'{opt.log_path}/linear_statistics.csv', index_label='epoch')
215 |
216 | if opt.model_to_save == 'best' and val_acc > best_acc:
217 | # Save only if the accuracy exceeds the previous best accuracy
218 | best_acc = val_acc
219 | torch.save(model.state_dict(), f'{opt.log_path}/linear_model.pth')
220 | elif opt.model_to_save == 'latest':
221 | # Always save the latest model
222 | best_acc = val_acc
223 | torch.save(model.state_dict(), f'{opt.log_path}/linear_model.pth')
224 |
225 | # training finished, run test
226 | print('Training finished, testing started...')
227 | # Load saved model
228 | model.load_state_dict(torch.load(f'{opt.log_path}/linear_model.pth'))
229 | model.eval()
230 | test_loss, test_acc, df = train_val(model, test_loader, None)
231 |
232 | df.to_csv(
233 | f"{opt.log_path}/inference_result_model.csv")
234 |
235 |
236 |
237 |
238 |
239 |
-------------------------------------------------------------------------------- /simclr/main.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import torch.optim as optim
8 | import torch.nn as nn
9 | from torch.cuda import amp
10 |
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 |
15 | from simclr.utils import distribute_over_GPUs, reload_weights, validate_arguments
16 | from simclr.get_dataloader import get_dataloader
17 | from simclr.model import Model
18 | from simclr.optimisers import LARS
19 | from simclr.augmentations import Denormalize
20 |
21 | torch.backends.cudnn.benchmark=True
22 |
23 | # train for one epoch to learn unique features
24 | def train(net, data_loader, train_optimizer, scaler, opt, epoch):
25 | denom_transform = Denormalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
26 | net.train()
27 | total_loss, total_num, train_bar =
0.0, 0, tqdm(data_loader) 28 | 29 | neg_cosine_sim = {0: [], 1:[]} 30 | pos_cosine_sim = {0: [], 1:[]} 31 | 32 | 33 | for step, data in enumerate(train_bar): 34 | pos_1, pos_2 = data[0].cuda(non_blocking=True), data[1].cuda(non_blocking=True) 35 | B, C, W, H = pos_1.shape 36 | 37 | with amp.autocast(): 38 | feature_1, out_1 = net(pos_1) 39 | feature_2, out_2 = net(pos_2) 40 | 41 | # [2*B, D] 42 | out = torch.cat([out_1, out_2], dim=0) 43 | # [2*B, 2*B] 44 | cosine_sim_neg = torch.mm(out, out.t().contiguous()) 45 | sim_matrix = torch.exp(cosine_sim_neg / temperature) 46 | 47 | mask_negatives = (torch.ones_like(sim_matrix) - torch.eye(opt.batch_size_multiGPU, device=sim_matrix.device).repeat(2,2)).bool() # for logging 48 | # [2*B, 2*B-2] 49 | sim_matrix_neg = sim_matrix.masked_select(mask_negatives).view(2 * opt.batch_size_multiGPU, -1) 50 | 51 | # log negative cosine dist 52 | cosine_mask = cosine_sim_neg.masked_select(torch.triu(mask_negatives)) 53 | negatives = cosine_mask[torch.randint(2 * B, (B,))].detach().cpu().numpy().ravel() 54 | neg_cosine_sim[0].extend(negatives) # randomly take B samples 55 | 56 | # compute loss 57 | cosine_sim = torch.sum(out_1 * out_2, dim=-1) 58 | pos_sim = torch.exp(cosine_sim / temperature) 59 | # log positive cosine dist 60 | pos_cosine_sim[0].extend(cosine_sim.detach().cpu().numpy()) 61 | 62 | # [2*B] 63 | pos_sim = torch.cat([pos_sim, pos_sim], dim=0) 64 | # loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() 65 | loss = (- torch.log(pos_sim / (pos_sim + sim_matrix_neg.sum(dim=-1)))).mean() 66 | 67 | train_optimizer.zero_grad() 68 | scaler.scale(loss).backward() 69 | scaler.step(train_optimizer) 70 | scaler.update() 71 | 72 | total_num += opt.batch_size_multiGPU 73 | total_loss += loss.item() * opt.batch_size_multiGPU 74 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num)) 75 | 76 | return total_loss / total_num 77 | 78 | 79 | def test(net, test_data_loader, opt, epoch): 80 | denom_transform = Denormalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) 81 | 82 | neg_cosine_sim = {0: [], 1:[]} 83 | pos_cosine_sim = {0: [], 1:[]} 84 | 85 | net.eval() 86 | with torch.no_grad(): 87 | test_bar = tqdm(test_data_loader) 88 | 89 | for step, data in enumerate(test_bar): 90 | pos_1 = data[0].cuda(non_blocking=True) 91 | pos_2 = data[1].cuda(non_blocking=True) 92 | 93 | B, C, W, H = pos_1.shape 94 | 95 | feat_1, out_1 = net(pos_1) 96 | feat_2, out_2 = net(pos_2) 97 | 98 | # log negative cosine dist 99 | out = torch.cat([out_1, out_2], dim=0) 100 | cosine_sim_neg = torch.mm(out, out.t().contiguous()) 101 | mask_negatives = (torch.ones_like(cosine_sim_neg) - torch.eye(B, device=cosine_sim_neg.device).repeat(2,2)).bool() # for logging 102 | 103 | cosine_mask = cosine_sim_neg.masked_select(torch.triu(mask_negatives)) 104 | negatives = cosine_mask[torch.randint(2 * B, (B,))].detach().cpu().numpy().ravel() 105 | neg_cosine_sim[0].extend(negatives) # randomly take B samples 106 | 107 | # log positive cosine dist 108 | cosine_sim = torch.sum(out_1 * out_2, dim=-1) 109 | pos_sim = torch.exp(cosine_sim / opt.temperature) 110 | pos_cosine_sim[0].extend(cosine_sim.detach().cpu().numpy()) 111 | 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser(description='Train SimCLR') 116 | parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector') 117 | parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in 
softmax')
118 | parser.add_argument('--batch_size', default=512, type=int, help='Number of images in each mini-batch')
119 | parser.add_argument('--epochs', default=500, type=int, help='Number of sweeps over the dataset to train')
120 | parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
121 |
122 | parser.add_argument('--load_checkpoint_dir', default=None,
123 | help='Path to Load Pre-trained Model From.')
124 | parser.add_argument('--start_epoch', default=0, type=int,
125 | help='Epoch to start from when cont. training (affects optimizer)')
126 |
127 | parser.add_argument('--training_data_csv', required=True, type=str, help='Path to file to use to read training data')
128 | parser.add_argument('--test_data_csv', required=True, type=str, help='Path to file to use to read test data')
129 | # For validation set, need to specify either csv or train/val split ratio
130 | group_validationset = parser.add_mutually_exclusive_group(required=True)
131 | group_validationset.add_argument('--validation_data_csv', type=str, help='Path to file to use to read validation data')
132 | group_validationset.add_argument('--trainingset_split', type=float, help='If not none, training csv will be split in train/val. Value between 0-1')
133 |
134 | parser.add_argument('--dataset', choices=['cam', 'multidata', 'skin'], default='cam', type=str, help='Dataset')
135 | parser.add_argument('--data_input_dir', type=str, help='Base folder for images')
136 | parser.add_argument('--data_input_dir_test', type=str, required=False, help='Base folder for test images')
137 | parser.add_argument('--save_dir', type=str, help='Path to save log')
138 | parser.add_argument('--save_after', type=int, default=1, help='Save model after every Nth epoch, default every epoch')
139 | parser.add_argument("--balanced_validation_set", action="store_true", default=False, help="Equal size of classes in validation AND test set",)
140 |
141 | parser.add_argument("--optimizer", choices=['adam', 'lars'], default='adam', help="Optimizer to use",)
142 |
143 | parser.add_argument("--use_album", action="store_true", default=False, help="use Albumentations as augmentation lib",)
144 | parser.add_argument("--balanced_training_set", action="store_true", default=False, help="Equal size of classes in train - SUPERVISED!")
145 | parser.add_argument("--pretrained", action="store_true", default=False, help="If true, use Imagenet pretrained resnet backbone")
146 |
147 | # Common augmentations
148 | parser.add_argument("--image_size", type=int, default=224)
149 | parser.add_argument("--scale", nargs=2, type=float, default=[0.2, 1.0])
150 |
151 | # RGB augmentations
152 | parser.add_argument("--rgb_gaussian_blur_p", type=float, default=0, help="probability of using gaussian blur (only on rgb)" )
153 | parser.add_argument("--rgb_jitter_d", type=float, default=1, help="color jitter 0.8*d, val 0.2*d (only on rgb)" )
154 | parser.add_argument("--rgb_jitter_p", type=float, default=0.8, help="probability of using color jitter (only on rgb)" )
155 | parser.add_argument("--rgb_contrast", type=float, default=0.2, help="value of contrast (rgb only)")
156 | parser.add_argument("--rgb_contrast_p", type=float, default=0, help="prob of using contrast (rgb only)")
157 | parser.add_argument("--rgb_grid_distort_p", type=float, default=0, help="probability of using grid distort (only on rgb)" )
158 | parser.add_argument("--rgb_grid_shuffle_p", type=float, default=0, help="probability of using grid shuffle (only on rgb)" )
159 |
160 | # args parse
161 | opt =
    feature_dim, temperature = opt.feature_dim, opt.temperature
    batch_size, epochs = opt.batch_size, opt.epochs - opt.start_epoch

    # use no worker processes on Windows (multiprocessing quirks)
    is_windows = os.name == 'nt'
    opt.num_workers = 0 if is_windows else 16

    opt.train_supervised = False
    opt.grayscale = False

    os.makedirs(opt.save_dir, exist_ok=True)
    opt.log_path = opt.save_dir

    # Write the parameters used to run the experiment to file
    with open(f'{opt.log_path}/metadata_train.txt', 'w') as metadata_file:
        metadata_file.write(json.dumps(vars(opt)))

    opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('Device:', opt.device)

    # model setup and optimizer config
    scaler = amp.GradScaler()

    model = Model(feature_dim, pretrained=opt.pretrained)
    if opt.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-6)
    elif opt.optimizer == 'lars':
        # Exclude batch-norm and bias parameters from weight decay; LARS expects
        # the decayed parameters first, followed by the excluded ones.
        params_models = []
        reduced_params = []
        removed_params = []

        skip_lists = ['bn', 'bias']

        m_skip = []
        m_noskip = []
        params_models += list(model.parameters())

        for name, param in model.named_parameters():
            if any(skip_name in name for skip_name in skip_lists):
                m_skip.append(param)
            else:
                m_noskip.append(param)
        reduced_params += list(m_noskip)
        removed_params += list(m_skip)
        print("reduced_params len: {}".format(len(reduced_params)))
        print("removed_params len: {}".format(len(removed_params)))
        optimizer = LARS(reduced_params + removed_params, lr=opt.lr,
                         weight_decay=1e-6, eta=0.001, use_nesterov=False, len_reduced=len(reduced_params))

    model.to(opt.device)

    if opt.load_checkpoint_dir:
        print('Loading model from: ', opt.load_checkpoint_dir)
        model, optimizer = reload_weights(
            opt, model, optimizer
        )

    model, num_GPU = distribute_over_GPUs(opt, model)

    train_loader, train_dataset, val_loader, val_dataset, _, _ = get_dataloader(opt)

    # Note: T_max is len(train_loader) while the scheduler is stepped once per epoch,
    # so the cosine period spans len(train_loader) epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

    if opt.start_epoch > 0:
        print('Moving scheduler ahead')
        for _ in range(opt.start_epoch):
            scheduler.step()

    # training loop
    results = {'train_loss': []}
    save_name_pre = f'{feature_dim}_{temperature}_{batch_size}_{epochs}'
    best_acc = 0.0
    for epoch in range(1, epochs + 1):

        train_loss = train(model, train_loader, optimizer, scaler, opt, epoch)
        scheduler.step()
        results['train_loss'].append(train_loss)

        if val_loader is not None:
            test(model, val_loader, opt, epoch)

        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv(f'{opt.save_dir}/{save_name_pre}_statistics.csv', index_label='epoch')

        ## Save model
        if epoch % opt.save_after == 0:
            state = {
                'model': model.module.state_dict(),
                'optimiser': optimizer.state_dict(),
                'epoch': epoch,
            }
            torch.save(state, f'{opt.log_path}/{save_name_pre}_model_{epoch}.pth')
            # Delete old checkpoints: keep the latest and every 10th
            if (epoch - 1) % 10 != 0:
                try:
                    os.remove(f'{opt.log_path}/{save_name_pre}_model_{epoch - 1}.pth')
                except FileNotFoundError:
                    print("No previous checkpoint to delete yet")
--------------------------------------------------------------------------------
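The objective computed in train() above is the NT-Xent loss from SimCLR: for each positive pair, -log(exp(sim/t) / (exp(sim/t) + sum of exp(sim/t) over the negatives)). The following is a minimal, CPU-only sketch of the same masking and loss computation on random projections; the batch size and feature dimension are illustrative, and a plain B stands in for opt.batch_size_multiGPU.

import torch
import torch.nn.functional as F

B, D, temperature = 4, 8, 0.5
out_1 = F.normalize(torch.randn(B, D), dim=-1)  # projections of view 1
out_2 = F.normalize(torch.randn(B, D), dim=-1)  # projections of view 2

out = torch.cat([out_1, out_2], dim=0)                        # [2B, D]
sim_matrix = torch.exp(torch.mm(out, out.t()) / temperature)  # [2B, 2B]

# True only for negatives: self-similarities and the positive pair are masked out
mask_negatives = (torch.ones_like(sim_matrix) - torch.eye(B).repeat(2, 2)).bool()
sim_matrix_neg = sim_matrix.masked_select(mask_negatives).view(2 * B, -1)  # [2B, 2B-2]

pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
pos_sim = torch.cat([pos_sim, pos_sim], dim=0)  # [2B]

loss = (-torch.log(pos_sim / (pos_sim + sim_matrix_neg.sum(dim=-1)))).mean()
print(loss.item())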
/simclr/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp
from torchvision.models import resnet50


class Identity(nn.Module):
    """Drop-in replacement for the ResNet classification head."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Model(nn.Module):
    def __init__(self, feature_dim=128, pretrained=False):
        super(Model, self).__init__()

        # ResNet-50 backbone with the final fc layer removed
        self.f = resnet50(pretrained=pretrained)
        self.f.fc = Identity()
        # projection head: 2048 -> 512 -> feature_dim
        self.g = nn.Sequential(nn.Linear(2048, 512, bias=False),
                               nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True),
                               nn.Linear(512, feature_dim, bias=True))

    @amp.autocast()
    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        # L2-normalised backbone feature and projection (the projection feeds the loss)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
--------------------------------------------------------------------------------
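A quick shape check for the model above: the backbone feature is 2048-d (ResNet-50) and the projection is feature_dim-d, both L2-normalised. A minimal sketch, assuming the simclr package is importable from the repository root; the input resolution and batch size are illustrative.

import torch
from simclr.model import Model

net = Model(feature_dim=128, pretrained=False)
net.eval()
with torch.no_grad():
    x = torch.randn(2, 3, 224, 224)  # (B, C, H, W)
    feature, out = net(x)

print(feature.shape)  # torch.Size([2, 2048])
print(out.shape)      # torch.Size([2, 128])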
/simclr/optimisers.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import torch.optim as optim
from torch.optim.optimizer import Optimizer, required


class LARS(Optimizer):
    """
    Layer-wise adaptive rate scaling.

    - Converted from TensorFlow to PyTorch from:
      https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    - Based on:
      https://github.com/noahgolmant/pytorch-lars

    - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg,
      "Large Batch Training of Convolutional Networks":
      https://arxiv.org/abs/1708.03888

    (A toy usage sketch follows this file.)

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups. Parameters subject to weight decay must come first.
        lr (float): base learning rate (\gamma_0)
        len_reduced (int): number of leading parameters to apply weight decay to;
            the remaining parameters (e.g. batch-norm and bias) are excluded
        momentum (float, optional): momentum factor (default: 0.9)
        use_nesterov (bool, optional): flag to use nesterov momentum (default: False)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) ("\beta")
        eta (float, optional): LARS coefficient (default: 0.001)
    """

    def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):

        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            classic_momentum=classic_momentum,
            eta=eta,
            len_reduced=len_reduced
        )

        super(LARS, self).__init__(params, defaults)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eta = eta
        self.len_reduced = len_reduced

    def step(self, epoch=None, closure=None):

        loss = None

        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            eta = group['eta']
            learning_rate = group['lr']

            # TODO: Hacky -- weight decay is applied positionally: only the first
            # len_reduced parameters are decayed, so the excluded parameters
            # (bn/bias) must be passed at the end of the list.
            counter = 0
            for p in group['params']:
                if p.grad is None:
                    continue

                param = p.data
                grad = p.grad.data

                param_state = self.state[p]

                if counter < group['len_reduced']:
                    grad += weight_decay * param

                # Create the momentum buffer if it does not exist yet
                if "momentum_var" not in param_state:
                    next_v = param_state["momentum_var"] = torch.zeros_like(p.data)
                else:
                    next_v = param_state["momentum_var"]

                if self.classic_momentum:
                    # trust ratio: eta * ||w|| / ||g||, falling back to 1.0 when
                    # either norm is zero (gt, not ge, to avoid division by zero)
                    w_norm = torch.norm(param)
                    g_norm = torch.norm(grad)

                    trust_ratio = torch.where(
                        w_norm.gt(0),
                        torch.where(g_norm.gt(0), eta * w_norm / g_norm, torch.ones_like(w_norm)),
                        torch.ones_like(w_norm),
                    ).item()

                    scaled_lr = learning_rate * trust_ratio

                    next_v.mul_(momentum).add_(grad, alpha=scaled_lr)

                    if self.use_nesterov:
                        update = (momentum * next_v) + (scaled_lr * grad)
                    else:
                        update = next_v

                    p.data.add_(-update)

                # Not classic_momentum: the trust ratio scales the momentum update
                else:
                    next_v.mul_(momentum).add_(grad)

                    if self.use_nesterov:
                        update = (momentum * next_v) + grad
                    else:
                        update = next_v

                    w_norm = torch.norm(param)
                    v_norm = torch.norm(update)

                    trust_ratio = torch.where(
                        w_norm.gt(0),
                        torch.where(v_norm.gt(0), eta * w_norm / v_norm, torch.ones_like(w_norm)),
                        torch.ones_like(w_norm),
                    ).item()

                    scaled_lr = learning_rate * trust_ratio

                    p.data.add_(-scaled_lr * update)

                counter += 1

        return loss
--------------------------------------------------------------------------------
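The LARS update scales the base learning rate per tensor by the trust ratio eta * ||w|| / ||g|| before the momentum step. A toy run of the optimiser above, assuming it is importable as simclr.optimisers; the layer, data, and hyperparameters are illustrative. The weight is listed first with len_reduced=1, so only the weight receives weight decay, mirroring how main.py orders reduced_params before removed_params.

import torch
import torch.nn as nn
from simclr.optimisers import LARS

layer = nn.Linear(8, 2)
params = [layer.weight, layer.bias]  # decayed parameters first, excluded ones last
optimizer = LARS(params, lr=0.1, len_reduced=1, weight_decay=1e-6, eta=0.001)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = ((layer(x) - y) ** 2).mean()
loss.backward()
optimizer.step()       # per-tensor lr = 0.1 * eta * ||w|| / ||g||
optimizer.zero_grad()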
/simclr/utils.py:
--------------------------------------------------------------------------------
import random
from PIL import Image, ImageFilter
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.transforms.functional as F


def reload_weights(args, model, optimizer):
    # Load the pretrained model
    print('Loading checkpoint on device:', args.device.type)
    checkpoint = torch.load(args.load_checkpoint_dir, map_location=args.device.type)

    ## reload weights for training of the linear classifier
    if 'model' in checkpoint.keys():
        model.load_state_dict(checkpoint['model'])
    else:
        model.load_state_dict(checkpoint)

    ## reload the optimizer state when continuing training
    if args.start_epoch > 0:
        print("Continuing training from epoch ", args.start_epoch)
        try:
            optimizer.load_state_dict(checkpoint['optimiser'])
        except KeyError:
            raise KeyError('No optimizer state saved. Set start_epoch=0 to start from pretrained weights')

    return model, optimizer


def distribute_over_GPUs(opt, model):
    ## distribute over GPUs
    if opt.device.type != "cpu":
        model = nn.DataParallel(model)
        num_GPU = torch.cuda.device_count()
        opt.batch_size_multiGPU = opt.batch_size * num_GPU
    else:
        model = nn.DataParallel(model)
        num_GPU = 1  # previously undefined on CPU, which raised a NameError below
        opt.batch_size_multiGPU = opt.batch_size

    model = model.to(opt.device)
    print("Let's use", num_GPU, "GPUs!")

    return model, num_GPU


def validate_arguments(opt):
    if not opt.use_album:
        # Albumentations is needed if these augmentations are to be used
        if opt.rgb_grid_distort_p > 0:
            raise ValueError('Grid distort needs use_album to be true')
        if opt.rgb_grid_shuffle_p > 0:
            raise ValueError('Grid shuffle needs use_album to be true')
        if opt.rgb_contrast_p > 0:
            raise ValueError('Contrast needs use_album to be true')

    if not opt.data_input_dir_test:
        opt.data_input_dir_test = opt.data_input_dir

    return opt
--------------------------------------------------------------------------------
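Finally, a sketch of the checkpoint round trip these helpers support: main.py saves a dict with 'model', 'optimiser', and 'epoch' keys, and reload_weights restores it. The path and the argparse namespace below are illustrative stand-ins, and the model here is unwrapped (main.py saves model.module.state_dict() from a DataParallel model, which loads the same way).

import argparse
import torch
import torch.optim as optim
from simclr.model import Model
from simclr.utils import reload_weights

model = Model(feature_dim=128)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Save in the same format as main.py
state = {'model': model.state_dict(), 'optimiser': optimizer.state_dict(), 'epoch': 5}
torch.save(state, 'checkpoint.pth')  # illustrative path

args = argparse.Namespace(
    device=torch.device('cpu'),
    load_checkpoint_dir='checkpoint.pth',
    start_epoch=5,  # > 0, so the optimiser state is restored as well
)
model, optimizer = reload_weights(args, model, optimizer)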