├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── ramp-submission
│   ├── estimator.py
│   └── estimator_mse.py
└── src
    ├── data
    │   ├── __init__.py
    │   ├── masks
    │   │   ├── cat12vbm_space-MNI152_desc-gm_TPM.nii.gz
    │   │   └── quasiraw_space-MNI152_desc-brain_T1w.nii.gz
    │   ├── openbhb.py
    │   └── transforms.py
    ├── exp
    │   ├── mae.yaml
    │   ├── supcon_adam_kernel.yaml
    │   └── supcon_sgd_kernel.yaml
    ├── figures
    │   ├── ablation.csv
    │   ├── ablation.pdf
    │   └── ablation.py
    ├── launcher.py
    ├── losses.py
    ├── main_infonce.py
    ├── main_mse.py
    ├── models
    │   ├── __init__.py
    │   ├── alexnet3d.py
    │   ├── densenet3d.py
    │   ├── estimators.py
    │   └── resnet3d.py
    └── util.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 EIDOSLAB

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Contrastive learning for regression in multi-site brain age prediction

Carlo Alberto Barbano<sup>1,2</sup>, Benoit Dufumier<sup>1,3</sup>, Edouard Duchesnay<sup>3</sup>, Marco Grangetto<sup>2</sup>, Pietro Gori<sup>1</sup> | [[pdf](https://arxiv.org/pdf/2211.08326.pdf)] [[poster](https://drive.google.com/file/d/1gr45EamhVVClPbMT5T5b1Gy9V50fgw3c/view)]

<sup>1</sup>LTCI, Télécom Paris, IP Paris<br>
<sup>2</sup>University of Turin, Computer Science dept.<br>
<sup>3</sup>NeuroSpin, CEA, Université Paris-Saclay

![teaser](assets/teaser.png)

Building accurate Deep Learning (DL) models for brain age prediction is a very relevant topic in neuroimaging, as it could help better understand neurodegenerative disorders and find new biomarkers. To estimate accurate and generalizable models, large datasets have been collected, which are often multi-site and multi-scanner. This large heterogeneity negatively affects the generalization performance of DL models, since they are prone to overfit site-related noise. Recently, contrastive learning approaches have been shown to be more robust against noise in data or labels. For this reason, we propose a novel contrastive learning regression loss for robust brain age prediction using MRI scans. Our method achieves state-of-the-art performance on the OpenBHB challenge, yielding the best generalization capability and robustness to site-related noise.


## Running

### Training

The code can be found in the `src` folder. There are two training scripts:

- `main_mse.py`: for training baseline MSE/MAE models
- `main_infonce.py`: for training models with contrastive losses

For ease of use, the script `launcher.py` is provided together with some predefined experiments, which can be found in `src/exp` as YAML templates. To launch one (from inside `src`):

```
python3 launcher.py exp/mae.yaml
```

### Testing on the leaderboard

To test on the official leaderboard of the OpenBHB challenge, you first need to create an account at [https://ramp.studio/](https://ramp.studio/). The source code for submitting to the challenge ([https://ramp.studio/events/brain_age_with_site_removal_open_2022](https://ramp.studio/events/brain_age_with_site_removal_open_2022)) can be found in the `ramp-submission` folder, for both supervised and contrastive models (`estimator_mse.py` and `estimator.py`, respectively).
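
Each subject reaches a submission as one flat feature vector. `FeatureExtractor` in `ramp-submission/estimator.py` finds the requested modality by summing the sizes of all modalities listed before it in `MODALITIES`. The NumPy sketch below reproduces only that offset arithmetic. It skips the unmasking and reshaping that the real class applies to VBM and quasi-raw data. The helper names `modality_slice` and `select_modality` are illustrative and do not exist in the repository. The sketch also assumes the buffer follows exactly the modality order used in `MODALITIES`.

```python
import numpy as np

# Modality order and flattened sizes copied from FeatureExtractor.MODALITIES
# (assumption: the input buffer concatenates modalities in exactly this order).
MODALITY_SIZES = {
    "vbm": 519945,
    "quasiraw": 1827095,
    "xhemi": 1310736,
    "vbm_roi": 284,
    "desikan_roi": 476,
    "destrieux_roi": 1036,
}


def modality_slice(dtype: str) -> slice:
    """Return the [start, stop) slice of `dtype` inside the flat buffer."""
    names = list(MODALITY_SIZES)
    if dtype not in names:
        raise ValueError(f"Invalid input data type: {dtype!r}")
    # Cumulative end offsets: ends[i] is where modality i stops.
    ends = np.cumsum(list(MODALITY_SIZES.values()))
    index = names.index(dtype)
    start = int(ends[index - 1]) if index > 0 else 0
    return slice(start, int(ends[index]))


def select_modality(buffer: np.ndarray, dtype: str) -> np.ndarray:
    """Extract one modality's flat features (no unmasking, no reshape)."""
    return buffer[modality_slice(dtype)]


if __name__ == "__main__":
    total = sum(MODALITY_SIZES.values())
    buffer = np.arange(total)
    roi = select_modality(buffer, "vbm_roi")
    print(roi.shape, roi[0])
```

Running the module prints `(284,) 3657776`: the VBM ROI block starts right after the VBM, quasi-raw and surface (`xhemi`) blocks, which together take 3,657,776 values.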

## Citing

For citing our work, please use the following BibTeX entry:

```bibtex
@inproceedings{barbano2023contrastive,
  author    = {Barbano, Carlo Alberto and Dufumier, Benoit and Duchesnay, Edouard and Grangetto, Marco and Gori, Pietro},
  booktitle = {International Symposium on Biomedical Imaging (ISBI)},
  title     = {Contrastive learning for regression in multi-site brain age prediction},
  year      = {2023}
}
```
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/assets/teaser.png
--------------------------------------------------------------------------------
/ramp-submission/estimator.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
##########################################################################
# Code version 6207dffcc20f461bdb742f5d8a2f6641483b9d83
##########################################################################


"""
Each solution to be tested should be stored in its own directory within
submissions/. The name of this new directory will serve as the ID for
the submission. If you wish to launch a RAMP challenge you will need to
provide an example solution within submissions/starting_kit/. Even if
you are not launching a RAMP challenge on RAMP Studio, it is useful to
have an example submission as it shows which files are required, how they
need to be named and how each file should be structured.
15 | """ 16 | 17 | # Filename: estimator.py 18 | # Run id: 19 | # 20 | import os 21 | ARCHITECTURE = os.environ.get("ARCHITECTURE", "resnet18") 22 | 23 | from collections import OrderedDict 24 | from abc import ABCMeta 25 | import progressbar 26 | import nibabel 27 | import numpy as np 28 | from nilearn.masking import unmask 29 | from sklearn.base import BaseEstimator 30 | from sklearn.base import TransformerMixin 31 | from sklearn.pipeline import Pipeline, make_pipeline 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | import torch.utils.checkpoint as cp 36 | from torchvision import transforms 37 | import math 38 | 39 | 40 | ############################################################################ 41 | # Define here some selectors 42 | ############################################################################ 43 | 44 | class FeatureExtractor(BaseEstimator, TransformerMixin): 45 | """ Select only the requested data associatedd features from the the 46 | input buffered data. 47 | """ 48 | MODALITIES = OrderedDict([ 49 | ("vbm", { 50 | "shape": (1, 121, 145, 121), 51 | "size": 519945}), 52 | ("quasiraw", { 53 | "shape": (1, 182, 218, 182), 54 | "size": 1827095}), 55 | ("xhemi", { 56 | "shape": (8, 163842), 57 | "size": 1310736}), 58 | ("vbm_roi", { 59 | "shape": (1, 284), 60 | "size": 284}), 61 | ("desikan_roi", { 62 | "shape": (7, 68), 63 | "size": 476}), 64 | ("destrieux_roi", { 65 | "shape": (7, 148), 66 | "size": 1036}) 67 | ]) 68 | MASKS = { 69 | "vbm": { 70 | "path": None, 71 | "thr": 0.05}, 72 | "quasiraw": { 73 | "path": None, 74 | "thr": 0} 75 | } 76 | 77 | def __init__(self, dtype, mock=False): 78 | """ Init class. 79 | Parameters 80 | ---------- 81 | dtype: str 82 | the requested data: 'vbm', 'quasiraw', 'vbm_roi', 'desikan_roi', 83 | 'destrieux_roi' or 'xhemi'. 
84 | """ 85 | if dtype not in self.MODALITIES: 86 | raise ValueError("Invalid input data type.") 87 | self.dtype = dtype 88 | 89 | data_types = list(self.MODALITIES.keys()) 90 | index = data_types.index(dtype) 91 | 92 | cumsum = np.cumsum([item["size"] for item in self.MODALITIES.values()]) 93 | 94 | if index > 0: 95 | self.start = cumsum[index - 1] 96 | else: 97 | self.start = 0 98 | self.stop = cumsum[index] 99 | 100 | self.masks = dict((key, val["path"]) for key, val in self.MASKS.items()) 101 | self.masks["vbm"] = os.environ.get("VBM_MASK") 102 | self.masks["quasiraw"] = os.environ.get("QUASIRAW_MASK") 103 | 104 | self.mock = mock 105 | if mock: 106 | return 107 | 108 | for key in self.masks: 109 | if self.masks[key] is None or not os.path.isfile(self.masks[key]): 110 | raise ValueError("Impossible to find mask:", key, self.masks[key]) 111 | arr = nibabel.load(self.masks[key]).get_fdata() 112 | thr = self.MASKS[key]["thr"] 113 | arr[arr <= thr] = 0 114 | arr[arr > thr] = 1 115 | self.masks[key] = nibabel.Nifti1Image(arr.astype(int), np.eye(4)) 116 | 117 | def fit(self, X, y): 118 | return self 119 | 120 | def transform(self, X): 121 | if self.mock: 122 | #print("transforming", X.shape) 123 | data = X.reshape(self.MODALITIES[self.dtype]["shape"]) 124 | #print("mock data:", data.shape) 125 | return data 126 | 127 | # print(X.shape) 128 | select_X = X[self.start:self.stop] 129 | if self.dtype in ("vbm", "quasiraw"): 130 | im = unmask(select_X, self.masks[self.dtype]) 131 | select_X = im.get_fdata() 132 | select_X = select_X.transpose(2, 0, 1) 133 | select_X = select_X.reshape(self.MODALITIES[self.dtype]["shape"]) 134 | return select_X 135 | 136 | class Crop(object): 137 | """ Crop the given n-dimensional array either at a random location or 138 | centered. 
139 | """ 140 | def __init__(self, shape, type="center", keep_dim=False): 141 | assert type in ["center", "random"] 142 | self.shape = shape 143 | self.copping_type = type 144 | self.keep_dim = keep_dim 145 | 146 | def __call__(self, X): 147 | img_shape = np.array(X.shape) 148 | 149 | if type(self.shape) == int: 150 | size = [self.shape for _ in range(len(self.shape))] 151 | else: 152 | size = np.copy(self.shape) 153 | 154 | # print('img_shape:', img_shape, 'size', size) 155 | 156 | indexes = [] 157 | for ndim in range(len(img_shape)): 158 | if size[ndim] > img_shape[ndim] or size[ndim] < 0: 159 | size[ndim] = img_shape[ndim] 160 | 161 | if self.copping_type == "center": 162 | delta_before = int((img_shape[ndim] - size[ndim]) / 2.0) 163 | 164 | elif self.copping_type == "random": 165 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1) 166 | 167 | indexes.append(slice(delta_before, delta_before + size[ndim])) 168 | 169 | if self.keep_dim: 170 | mask = np.zeros(img_shape, dtype=np.bool) 171 | mask[tuple(indexes)] = True 172 | arr_copy = X.copy() 173 | arr_copy[~mask] = 0 174 | return arr_copy 175 | 176 | _X = X[tuple(indexes)] 177 | # print('cropped.shape', _X.shape) 178 | return _X 179 | 180 | class Pad(object): 181 | """ Pad the given n-dimensional array 182 | """ 183 | def __init__(self, shape, **kwargs): 184 | self.shape = shape 185 | self.kwargs = kwargs 186 | 187 | def __call__(self, X): 188 | _X = self._apply_padding(X) 189 | return _X 190 | 191 | def _apply_padding(self, arr): 192 | orig_shape = arr.shape 193 | padding = [] 194 | for orig_i, final_i in zip(orig_shape, self.shape): 195 | shape_i = final_i - orig_i 196 | half_shape_i = shape_i // 2 197 | if shape_i % 2 == 0: 198 | padding.append([half_shape_i, half_shape_i]) 199 | else: 200 | padding.append([half_shape_i, half_shape_i + 1]) 201 | for cnt in range(len(arr.shape) - len(padding)): 202 | padding.append([0, 0]) 203 | fill_arr = np.pad(arr, padding, **self.kwargs) 204 | return 
fill_arr 205 | 206 | ############################################################################ 207 | # Define here your dataset 208 | ############################################################################ 209 | 210 | class Dataset(torch.utils.data.Dataset): 211 | def __init__(self, X, y=None, transforms=None, indices=None): 212 | self.T = transforms 213 | self.X = X 214 | self.y = y 215 | self.indices = indices 216 | if indices is None: 217 | self.indices = range(len(X)) 218 | 219 | def __len__(self): 220 | return len(self.indices) 221 | 222 | def __getitem__(self, i): 223 | real_i = self.indices[i] 224 | x = self.X[real_i] 225 | 226 | if self.T is not None: 227 | x = self.T(x) 228 | 229 | if self.y is not None: 230 | y = self.y[real_i] 231 | return x, y 232 | else: 233 | return x 234 | 235 | 236 | ############################################################################ 237 | # Define here your regression model 238 | ############################################################################ 239 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 240 | """3x3 convolution with padding""" 241 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, 242 | padding=dilation, groups=groups, bias=False, dilation=dilation) 243 | 244 | def conv1x1(in_planes, out_planes, stride=1): 245 | """1x1 convolution""" 246 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 247 | 248 | class BasicBlock(nn.Module): 249 | expansion = 1 250 | 251 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 252 | base_width=64, dilation=1, norm_layer=None): 253 | super(BasicBlock, self).__init__() 254 | if norm_layer is None: 255 | norm_layer = nn.BatchNorm3d 256 | if groups != 1 or base_width != 64: 257 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 258 | if dilation > 1: 259 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 260 | # Both 
self.conv1 and self.downsample layers downsample the input when stride != 1 261 | self.conv1 = conv3x3(inplanes, planes, stride) 262 | self.bn1 = norm_layer(planes) 263 | self.relu = nn.ReLU(inplace=True) 264 | self.conv2 = conv3x3(planes, planes) 265 | self.bn2 = norm_layer(planes) 266 | self.downsample = downsample 267 | self.stride = stride 268 | 269 | def forward(self, x): 270 | identity = x 271 | 272 | out = self.conv1(x) 273 | out = self.bn1(out) 274 | out = self.relu(out) 275 | out = self.conv2(out) 276 | out = self.bn2(out) 277 | 278 | if self.downsample is not None: 279 | identity = self.downsample(x) 280 | 281 | out += identity 282 | out = self.relu(out) 283 | 284 | return out 285 | 286 | 287 | class Bottleneck(nn.Module): 288 | expansion = 4 289 | 290 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 291 | base_width=64, dilation=1, norm_layer=None): 292 | super(Bottleneck, self).__init__() 293 | if norm_layer is None: 294 | norm_layer = nn.BatchNorm3d 295 | width = int(planes * (base_width / 64.)) * groups 296 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 297 | self.conv1 = conv1x1(inplanes, width) 298 | self.bn1 = norm_layer(width) 299 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 300 | self.bn2 = norm_layer(width) 301 | self.conv3 = conv1x1(width, planes * self.expansion) 302 | self.bn3 = norm_layer(planes * self.expansion) 303 | self.relu = nn.ReLU(inplace=True) 304 | self.downsample = downsample 305 | self.stride = stride 306 | 307 | def forward(self, x): 308 | identity = x 309 | 310 | out = self.conv1(x) 311 | out = self.bn1(out) 312 | out = self.relu(out) 313 | 314 | out = self.conv2(out) 315 | out = self.bn2(out) 316 | out = self.relu(out) 317 | 318 | out = self.conv3(out) 319 | out = self.bn3(out) 320 | 321 | if self.downsample is not None: 322 | identity = self.downsample(x) 323 | 324 | out += identity 325 | out = self.relu(out) 326 | 327 | return out 328 | 329 | 
class ResNet(nn.Module):
    """
    Standard 3D-ResNet architecture with a big initial 7x7x7 kernel.
    It is used as an encoder: the forward pass returns the globally
    average-pooled feature vector (512-d for BasicBlock variants, 2048-d for
    Bottleneck variants), independent of the input size. No FC layer is
    included here; the projection head is added on top by SupConResNet.
    """
    def __init__(self, block, layers, in_channels=1,
                 zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, initial_kernel_size=7):
        super(ResNet, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        self._norm_layer = norm_layer

        self.name = "resnet"
        self.inputs = None
        self.inplanes = 64
        self.dilation = 1

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        initial_stride = 2 if initial_kernel_size == 7 else 1
        padding = (initial_kernel_size - initial_stride + 1) // 2
        self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=initial_kernel_size,
                               stride=initial_stride, padding=padding, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        channels = [64, 128, 256, 512]

        self.layer1 = self._make_layer(block, channels[0], layers[0])
        self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d(1)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
                            base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        x5 = self.avgpool(x4)
        return torch.flatten(x5, 1)

def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)

def resnet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

def resnet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)

# name -> [constructor, encoder output dimension]
model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}

class SupConResNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super().__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        feat = F.normalize(self.head(feat), dim=1)
        return feat


class AlexNet3D(nn.Module):
    def __init__(self):
        """
        3D AlexNet-style encoder: the forward pass returns a 128-d feature
        vector (global max pooling over the last convolutional block).
        """
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=5, stride=2, padding=0),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),

            nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),

            nn.Conv3d(128, 192, kernel_size=3, padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),

            nn.Conv3d(192, 192, kernel_size=3, padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),

            nn.Conv3d(192, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool3d(1),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # He init, fan-out computed over the full 3D kernel
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        xp = self.features(x)
        x = xp.view(xp.size(0), -1)
        return x

class SupConAlexNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, head='mlp', feat_dim=128):
        super().__init__()
        self.encoder = AlexNet3D()
        dim_in = 128

        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )

        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        feat = F.normalize(self.head(feat), dim=1)
        return feat

    def features(self, x):
        return self.forward(x)

class DenseNet(nn.Module):
    """3D-DenseNet-BC model class, based on
    "Densely Connected Convolutional Networks".
    The forward pass returns the globally average-pooled feature vector
    (no classification layer).
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
        in_channels (int) - number of input channels
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*.
    """

    def __init__(self, growth_rate=32, block_config=(3, 12, 24, 16),
                 num_init_features=64,
                 bn_size=4, in_channels=1,
                 memory_efficient=False):
        super(DenseNet, self).__init__()
        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv3d(in_channels, num_init_features,
                                kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm3d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        self.num_features = num_features


        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.adaptive_avg_pool3d(features, 1)
        out = torch.flatten(out, 1)
        return out.squeeze(dim=1)


def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, memory_efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv3d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False))
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            # Recompute the bottleneck in the backward pass instead of storing it
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))

        return new_features


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                memory_efficient=memory_efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv3d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))


def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs):
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    return model


def densenet121(**kwargs):
    r"""3D Densenet-121 model from
    "Densely Connected Convolutional Networks".

    Args:
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*.
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs)

class SupConDenseNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, head='mlp', feat_dim=128):
        super().__init__()
        self.encoder = densenet121()
        dim_in = self.encoder.num_features

        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )

        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        feat = F.normalize(self.head(feat), dim=1)
        return feat

    def features(self, x):
        return self.forward(x)


class RegressionModel(metaclass=ABCMeta):
    __model_local_weights__ = os.path.join(os.path.dirname(__file__), os.environ.get("MODEL", "weights.pth"))
    __metadata_local_weights__ = os.path.join(os.path.dirname(__file__), "metadata.pkl")

    def __init__(self, model, batch_size=15, transforms=None):
        self.model = model
        self.batch_size = batch_size
        self.transforms = transforms
        self.indices = None

    def fit(self, X, y):
        """ Restore weights.
741 | """ 742 | if not os.path.isfile(self.__model_local_weights__): 743 | raise ValueError("You must provide the model weigths in your submission folder.") 744 | state = torch.load(self.__model_local_weights__, map_location="cpu") 745 | 746 | if "model" not in state: 747 | raise ValueError("Model weigths are searched in the state dictionary at the 'model' key location.") 748 | self.model.load_state_dict(state["model"], strict=True) 749 | 750 | def predict(self, X: np.ndarray) -> np.ndarray: 751 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 752 | self.model.to(device) 753 | 754 | dataset = Dataset(X, transforms=self.transforms, indices=self.indices) 755 | testloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0) 756 | 757 | self.model.eval() 758 | outputs = [] 759 | 760 | with progressbar.ProgressBar(max_value=len(testloader)) as bar: 761 | for cnt, inputs in enumerate(testloader): 762 | inputs = inputs.float().to(device) 763 | # print("Batch size", inputs.shape) 764 | with torch.no_grad(): 765 | out = self.model(inputs) 766 | # out = torch.randn((inputs.shape[0], 128)) 767 | 768 | outputs.append(out.detach()) 769 | bar.update(cnt) 770 | 771 | outputs = torch.cat(outputs, dim=0) 772 | return outputs.detach().cpu().numpy() 773 | 774 | 775 | ############################################################################ 776 | # Define here your estimator pipeline 777 | ############################################################################ 778 | 779 | def get_estimator(mock=False) -> Pipeline: 780 | """ Build your estimator here. 781 | Notes 782 | ----- 783 | In order to minimize the memory load the first steps of the pipeline 784 | are applied directly as transforms attached to the Torch Dataset. 785 | Notes 786 | ----- 787 | It is recommended to create an instance of sklearn.pipeline.Pipeline. 
788 | """ 789 | print("InfoNCE") 790 | if "resnet" in ARCHITECTURE: 791 | net = SupConResNet(ARCHITECTURE) 792 | elif ARCHITECTURE == "alexnet": 793 | net = SupConAlexNet() 794 | elif "densenet" in ARCHITECTURE: 795 | net = SupConDenseNet() 796 | else: 797 | raise ValueError(f"Unsupported ARCHITECTURE: {ARCHITECTURE}") 798 | 799 | selector = FeatureExtractor("vbm", mock=mock) 800 | preproc = transforms.Compose([ 801 | transforms.Lambda(lambda x: selector.transform(x)), 802 | # Crop((1, 121, 128, 121), type="center"), 803 | # Pad((1, 128, 128, 128)), 804 | transforms.Lambda(lambda x: torch.from_numpy(x).float()), 805 | transforms.Normalize(mean=0.0, std=1.0), 806 | ]) 807 | estimator = make_pipeline( 808 | RegressionModel(net, transforms=preproc)) 809 | return estimator 810 | 811 | 812 | if __name__ == '__main__': 813 | estimator = get_estimator(mock=True).fit(None) 814 | estimator.predict(np.random.random((32, 2122945))) -------------------------------------------------------------------------------- /ramp-submission/estimator_mse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ########################################################################## 3 | # Code version 6207dffcc20f461bdb742f5d8a2f6641483b9d83 4 | ########################################################################## 5 | 6 | 7 | """ 8 | Each solution to be tested should be stored in its own directory within 9 | submissions/. The name of this new directory will serve as the ID for 10 | the submission. If you wish to launch a RAMP challenge you will need to 11 | provide an example solution within submissions/starting_kit/. Even if 12 | you are not launching a RAMP challenge on RAMP Studio, it is useful to 13 | have an example submission as it shows which files are required, how they 14 | need to be named and how each file should be structured.
15 | """ 16 | 17 | # Filename: estimator_mse.py 18 | # Run id: 19 | # 20 | import os 21 | ARCHITECTURE = os.environ.get("ARCHITECTURE", "resnet18") 22 | 23 | 24 | from collections import OrderedDict 25 | from abc import ABCMeta 26 | import progressbar 27 | import nibabel 28 | import numpy as np 29 | from nilearn.masking import unmask 30 | from sklearn.base import BaseEstimator 31 | from sklearn.base import TransformerMixin 32 | from sklearn.pipeline import Pipeline, make_pipeline 33 | import torch 34 | import torch.nn as nn 35 | import torch.nn.functional as F 36 | import torch.utils.checkpoint as cp 37 | from torchvision import transforms 38 | import math 39 | 40 | ############################################################################ 41 | # Define here some selectors 42 | ############################################################################ 43 | 44 | class FeatureExtractor(BaseEstimator, TransformerMixin): 45 | """ Select only the features associated with the requested data type from the 46 | input buffered data. 47 | """ 48 | MODALITIES = OrderedDict([ 49 | ("vbm", { 50 | "shape": (1, 121, 145, 121), 51 | "size": 519945}), 52 | ("quasiraw", { 53 | "shape": (1, 182, 218, 182), 54 | "size": 1827095}), 55 | ("xhemi", { 56 | "shape": (8, 163842), 57 | "size": 1310736}), 58 | ("vbm_roi", { 59 | "shape": (1, 284), 60 | "size": 284}), 61 | ("desikan_roi", { 62 | "shape": (7, 68), 63 | "size": 476}), 64 | ("destrieux_roi", { 65 | "shape": (7, 148), 66 | "size": 1036}) 67 | ]) 68 | MASKS = { 69 | "vbm": { 70 | "path": None, 71 | "thr": 0.05}, 72 | "quasiraw": { 73 | "path": None, 74 | "thr": 0} 75 | } 76 | 77 | def __init__(self, dtype, mock=False): 78 | """ Init class. 79 | Parameters 80 | ---------- 81 | dtype: str 82 | the requested data: 'vbm', 'quasiraw', 'vbm_roi', 'desikan_roi', 83 | 'destrieux_roi' or 'xhemi'.
84 | """ 85 | if dtype not in self.MODALITIES: 86 | raise ValueError("Invalid input data type.") 87 | self.dtype = dtype 88 | 89 | data_types = list(self.MODALITIES.keys()) 90 | index = data_types.index(dtype) 91 | 92 | cumsum = np.cumsum([item["size"] for item in self.MODALITIES.values()]) 93 | 94 | if index > 0: 95 | self.start = cumsum[index - 1] 96 | else: 97 | self.start = 0 98 | self.stop = cumsum[index] 99 | 100 | self.masks = dict((key, val["path"]) for key, val in self.MASKS.items()) 101 | self.masks["vbm"] = os.environ.get("VBM_MASK") 102 | self.masks["quasiraw"] = os.environ.get("QUASIRAW_MASK") 103 | 104 | self.mock = mock 105 | if mock: 106 | return 107 | 108 | for key in self.masks: 109 | if self.masks[key] is None or not os.path.isfile(self.masks[key]): 110 | raise ValueError(f"Cannot find mask '{key}': {self.masks[key]}") 111 | arr = nibabel.load(self.masks[key]).get_fdata() 112 | thr = self.MASKS[key]["thr"] 113 | arr[arr <= thr] = 0 114 | arr[arr > thr] = 1 115 | self.masks[key] = nibabel.Nifti1Image(arr.astype(int), np.eye(4)) 116 | 117 | def fit(self, X, y): 118 | return self 119 | 120 | def transform(self, X): 121 | if self.mock: 122 | #print("transforming", X.shape) 123 | data = X.reshape(self.MODALITIES[self.dtype]["shape"]) 124 | #print("mock data:", data.shape) 125 | return data 126 | 127 | # print(X.shape) 128 | select_X = X[self.start:self.stop] 129 | if self.dtype in ("vbm", "quasiraw"): 130 | im = unmask(select_X, self.masks[self.dtype]) 131 | select_X = im.get_fdata() 132 | select_X = select_X.transpose(2, 0, 1) 133 | select_X = select_X.reshape(self.MODALITIES[self.dtype]["shape"]) 134 | return select_X 135 | 136 | class Crop(object): 137 | """ Crop the given n-dimensional array either at a random location or 138 | centered.
139 | """ 140 | def __init__(self, shape, type="center", keep_dim=False): 141 | assert type in ["center", "random"] 142 | self.shape = shape 143 | self.cropping_type = type 144 | self.keep_dim = keep_dim 145 | 146 | def __call__(self, X): 147 | img_shape = np.array(X.shape) 148 | 149 | if type(self.shape) == int: 150 | size = [self.shape for _ in range(len(img_shape))] 151 | else: 152 | size = np.copy(self.shape) 153 | 154 | # print('img_shape:', img_shape, 'size', size) 155 | 156 | indexes = [] 157 | for ndim in range(len(img_shape)): 158 | if size[ndim] > img_shape[ndim] or size[ndim] < 0: 159 | size[ndim] = img_shape[ndim] 160 | 161 | if self.cropping_type == "center": 162 | delta_before = int((img_shape[ndim] - size[ndim]) / 2.0) 163 | 164 | elif self.cropping_type == "random": 165 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1) 166 | 167 | indexes.append(slice(delta_before, delta_before + size[ndim])) 168 | 169 | if self.keep_dim: 170 | mask = np.zeros(img_shape, dtype=bool) 171 | mask[tuple(indexes)] = True 172 | arr_copy = X.copy() 173 | arr_copy[~mask] = 0 174 | return arr_copy 175 | 176 | _X = X[tuple(indexes)] 177 | # print('cropped.shape', _X.shape) 178 | return _X 179 | 180 | class Pad(object): 181 | """ Pad the given n-dimensional array symmetrically to the target shape. 182 | """ 183 | def __init__(self, shape, **kwargs): 184 | self.shape = shape 185 | self.kwargs = kwargs 186 | 187 | def __call__(self, X): 188 | _X = self._apply_padding(X) 189 | return _X 190 | 191 | def _apply_padding(self, arr): 192 | orig_shape = arr.shape 193 | padding = [] 194 | for orig_i, final_i in zip(orig_shape, self.shape): 195 | shape_i = final_i - orig_i 196 | half_shape_i = shape_i // 2 197 | if shape_i % 2 == 0: 198 | padding.append([half_shape_i, half_shape_i]) 199 | else: 200 | padding.append([half_shape_i, half_shape_i + 1]) 201 | for cnt in range(len(arr.shape) - len(padding)): 202 | padding.append([0, 0]) 203 | fill_arr = np.pad(arr, padding, **self.kwargs) 204 | return
fill_arr 205 | 206 | ############################################################################ 207 | # Define here your dataset 208 | ############################################################################ 209 | 210 | class Dataset(torch.utils.data.Dataset): 211 | def __init__(self, X, y=None, transforms=None, indices=None): 212 | self.T = transforms 213 | self.X = X 214 | self.y = y 215 | self.indices = indices 216 | if indices is None: 217 | self.indices = range(len(X)) 218 | 219 | def __len__(self): 220 | return len(self.indices) 221 | 222 | def __getitem__(self, i): 223 | real_i = self.indices[i] 224 | x = self.X[real_i] 225 | 226 | if self.T is not None: 227 | x = self.T(x) 228 | 229 | if self.y is not None: 230 | y = self.y[real_i] 231 | return x, y 232 | else: 233 | return x 234 | 235 | 236 | ############################################################################ 237 | # Define here your regression model 238 | ############################################################################ 239 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 240 | """3x3 convolution with padding""" 241 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, 242 | padding=dilation, groups=groups, bias=False, dilation=dilation) 243 | 244 | def conv1x1(in_planes, out_planes, stride=1): 245 | """1x1 convolution""" 246 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 247 | 248 | class BasicBlock(nn.Module): 249 | expansion = 1 250 | 251 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 252 | base_width=64, dilation=1, norm_layer=None): 253 | super(BasicBlock, self).__init__() 254 | if norm_layer is None: 255 | norm_layer = nn.BatchNorm3d 256 | if groups != 1 or base_width != 64: 257 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 258 | if dilation > 1: 259 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 260 | # Both 
self.conv1 and self.downsample layers downsample the input when stride != 1 261 | self.conv1 = conv3x3(inplanes, planes, stride) 262 | self.bn1 = norm_layer(planes) 263 | self.relu = nn.ReLU(inplace=True) 264 | self.conv2 = conv3x3(planes, planes) 265 | self.bn2 = norm_layer(planes) 266 | self.downsample = downsample 267 | self.stride = stride 268 | 269 | def forward(self, x): 270 | identity = x 271 | 272 | out = self.conv1(x) 273 | out = self.bn1(out) 274 | out = self.relu(out) 275 | out = self.conv2(out) 276 | out = self.bn2(out) 277 | 278 | if self.downsample is not None: 279 | identity = self.downsample(x) 280 | 281 | out += identity 282 | out = self.relu(out) 283 | 284 | return out 285 | 286 | 287 | class Bottleneck(nn.Module): 288 | expansion = 4 289 | 290 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 291 | base_width=64, dilation=1, norm_layer=None): 292 | super(Bottleneck, self).__init__() 293 | if norm_layer is None: 294 | norm_layer = nn.BatchNorm3d 295 | width = int(planes * (base_width / 64.)) * groups 296 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 297 | self.conv1 = conv1x1(inplanes, width) 298 | self.bn1 = norm_layer(width) 299 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 300 | self.bn2 = norm_layer(width) 301 | self.conv3 = conv1x1(width, planes * self.expansion) 302 | self.bn3 = norm_layer(planes * self.expansion) 303 | self.relu = nn.ReLU(inplace=True) 304 | self.downsample = downsample 305 | self.stride = stride 306 | 307 | def forward(self, x): 308 | identity = x 309 | 310 | out = self.conv1(x) 311 | out = self.bn1(out) 312 | out = self.relu(out) 313 | 314 | out = self.conv2(out) 315 | out = self.bn2(out) 316 | out = self.relu(out) 317 | 318 | out = self.conv3(out) 319 | out = self.bn3(out) 320 | 321 | if self.downsample is not None: 322 | identity = self.downsample(x) 323 | 324 | out += identity 325 | out = self.relu(out) 326 | 327 | return out 328 | 329 | 
class ResNet(nn.Module): 330 | """ 331 | Standard 3D-ResNet encoder with a large initial kernel (7x7x7 by default). 332 | The forward pass returns a flattened latent vector whose size is independent 333 | of the input size: 512 for BasicBlock variants, 2048 for Bottleneck variants. 334 | No classification head is defined here; any FC layer is added by the wrapper. 335 | """ 336 | def __init__(self, block, layers, in_channels=1, 337 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, 338 | norm_layer=None, initial_kernel_size=7): 339 | super(ResNet, self).__init__() 340 | 341 | if norm_layer is None: 342 | norm_layer = nn.BatchNorm3d 343 | self._norm_layer = norm_layer 344 | 345 | self.name = "resnet" 346 | self.inputs = None 347 | self.inplanes = 64 348 | self.dilation = 1 349 | 350 | if replace_stride_with_dilation is None: 351 | # each element in the tuple indicates if we should replace 352 | # the 2x2x2 stride with a dilated convolution instead 353 | replace_stride_with_dilation = [False, False, False] 354 | if len(replace_stride_with_dilation) != 3: 355 | raise ValueError("replace_stride_with_dilation should be None " 356 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 357 | self.groups = groups 358 | self.base_width = width_per_group 359 | initial_stride = 2 if initial_kernel_size==7 else 1 360 | padding = (initial_kernel_size-initial_stride+1)//2 361 | self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=initial_kernel_size, stride=initial_stride, padding=padding, bias=False) 362 | self.bn1 = norm_layer(self.inplanes) 363 | self.relu = nn.ReLU(inplace=True) 364 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 365 | 366 | channels = [64, 128, 256, 512] 367 | 368 | self.layer1 = self._make_layer(block, channels[0], layers[0]) 369 | self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 370 | self.layer3 = self._make_layer(block, channels[2],
layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 371 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 372 | self.avgpool = nn.AdaptiveAvgPool3d(1) 373 | 374 | for m in self.modules(): 375 | if isinstance(m, nn.Conv3d): 376 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 377 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): 378 | nn.init.constant_(m.weight, 1) 379 | nn.init.constant_(m.bias, 0) 380 | elif isinstance(m, nn.Linear): 381 | nn.init.normal_(m.weight, 0, 0.01) 382 | if m.bias is not None: 383 | nn.init.constant_(m.bias, 0) 384 | 385 | # Zero-initialize the last BN in each residual branch, 386 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 387 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 388 | if zero_init_residual: 389 | for m in self.modules(): 390 | if isinstance(m, Bottleneck): 391 | nn.init.constant_(m.bn3.weight, 0) 392 | elif isinstance(m, BasicBlock): 393 | nn.init.constant_(m.bn2.weight, 0) 394 | 395 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 396 | norm_layer = self._norm_layer 397 | downsample = None 398 | previous_dilation = self.dilation 399 | if dilate: 400 | self.dilation *= stride 401 | stride = 1 402 | if stride != 1 or self.inplanes != planes * block.expansion: 403 | downsample = nn.Sequential( 404 | conv1x1(self.inplanes, planes * block.expansion, stride), 405 | norm_layer(planes * block.expansion), 406 | ) 407 | 408 | layers = [] 409 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups, 410 | base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer)) 411 | self.inplanes = planes * block.expansion 412 | for _ in range(1, blocks): 413 | layers.append(block(self.inplanes, planes, groups=self.groups, 414 | base_width=self.base_width, 
dilation=self.dilation, 415 | norm_layer=norm_layer)) 416 | 417 | return nn.Sequential(*layers) 418 | 419 | def forward(self, x): 420 | x = self.conv1(x) 421 | x = self.bn1(x) 422 | x = self.relu(x) 423 | x = self.maxpool(x) 424 | 425 | x1 = self.layer1(x) 426 | x2 = self.layer2(x1) 427 | x3 = self.layer3(x2) 428 | x4 = self.layer4(x3) 429 | 430 | x5 = self.avgpool(x4) 431 | return torch.flatten(x5, 1) 432 | 433 | def resnet18(**kwargs): 434 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 435 | 436 | def resnet34(**kwargs): 437 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 438 | 439 | def resnet50(**kwargs): 440 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 441 | 442 | def resnet101(**kwargs): 443 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 444 | 445 | model_dict = { 446 | 'resnet18': [resnet18, 512], 447 | 'resnet34': [resnet34, 512], 448 | 'resnet50': [resnet50, 2048], 449 | 'resnet101': [resnet101, 2048], 450 | } 451 | 452 | class SupRegResNet(nn.Module): 453 | """encoder + regressor head; forward() returns the encoder features only""" 454 | def __init__(self, name='resnet50'): 455 | super().__init__() 456 | model_fun, dim_in = model_dict[name] 457 | self.encoder = model_fun() 458 | self.fc = nn.Linear(dim_in, 1) 459 | 460 | def forward(self, x): 461 | return self.encoder(x) 462 | # return self.fc(self.encoder(x)) 463 | 464 | class AlexNet3D(nn.Module): 465 | def __init__(self): 466 | """ 467 | 3D AlexNet-style encoder; the forward pass returns a flattened 468 | 128-d feature vector (after global max pooling). 469 | """ 470 | super().__init__() 471 | self.features = nn.Sequential( 472 | nn.Conv3d(1, 64, kernel_size=5, stride=2, padding=0), 473 | nn.BatchNorm3d(64), 474 | nn.ReLU(inplace=True), 475 | nn.MaxPool3d(kernel_size=3, stride=3), 476 | 477 | nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=0), 478 | nn.BatchNorm3d(128), 479 | nn.ReLU(inplace=True), 480 | nn.MaxPool3d(kernel_size=3, stride=3), 481 | 482 | nn.Conv3d(128, 192, 
kernel_size=3, padding=1), 483 | nn.BatchNorm3d(192), 484
| nn.ReLU(inplace=True), 485 | 486 | nn.Conv3d(192, 192, kernel_size=3, padding=1), 487 | nn.BatchNorm3d(192), 488 | nn.ReLU(inplace=True), 489 | 490 | nn.Conv3d(192, 128, kernel_size=3, padding=1), 491 | nn.BatchNorm3d(128), 492 | nn.ReLU(inplace=True), 493 | nn.AdaptiveMaxPool3d(1), 494 | ) 495 | 496 | 497 | for m in self.modules(): 498 | if isinstance(m, nn.Conv3d): 499 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 500 | m.weight.data.normal_(0, math.sqrt(2. / n)) 501 | elif isinstance(m, nn.BatchNorm3d): 502 | m.weight.data.fill_(1) 503 | m.bias.data.zero_() 504 | 505 | def forward(self, x): 506 | xp = self.features(x) 507 | x = xp.view(xp.size(0), -1) 508 | return x 509 | 510 | class SupRegAlexNet(nn.Module): 511 | """encoder + regressor head; forward() returns the encoder features only""" 512 | def __init__(self): 513 | super().__init__() 514 | self.encoder = AlexNet3D() 515 | self.fc = nn.Linear(128, 1) 516 | 517 | def forward(self, x): 518 | feats = self.features(x) 519 | return feats 520 | # return self.fc(feats), feats 521 | 522 | def features(self, x): 523 | return self.encoder(x) 524 | 525 | class DenseNet(nn.Module): 526 | """3D-Densenet-BC model class, based on 527 | `"Densely Connected Convolutional Networks" `_ 528 | Args: 529 | growth_rate (int) - how many filters to add each layer (`k` in paper) 530 | block_config (list of 4 ints) - how many layers in each pooling block 531 | num_init_features (int) - the number of filters to learn in the first convolution layer 532 | bn_size (int) - multiplicative factor for the number of bottleneck layers 533 | (i.e. bn_size * k features in the bottleneck layer) 534 | in_channels (int) - number of input channels 535 | (the forward pass returns pooled features; there is no classification head) 536 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 537 | but slower. Default: *False*.
See `"paper" `_ 538 | """ 539 | 540 | def __init__(self, growth_rate=32, block_config=(3, 12, 24, 16), 541 | num_init_features=64, 542 | bn_size=4, in_channels=1, 543 | memory_efficient=False): 544 | super(DenseNet, self).__init__() 545 | # First convolution 546 | self.features = nn.Sequential(OrderedDict([ 547 | ('conv0', nn.Conv3d(in_channels, num_init_features, 548 | kernel_size=7, stride=2, padding=3, bias=False)), 549 | ('norm0', nn.BatchNorm3d(num_init_features)), 550 | ('relu0', nn.ReLU(inplace=True)), 551 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)), 552 | ])) 553 | 554 | # Each denseblock 555 | num_features = num_init_features 556 | for i, num_layers in enumerate(block_config): 557 | block = _DenseBlock( 558 | num_layers=num_layers, 559 | num_input_features=num_features, 560 | bn_size=bn_size, 561 | growth_rate=growth_rate, 562 | memory_efficient=memory_efficient 563 | ) 564 | self.features.add_module('denseblock%d' % (i + 1), block) 565 | num_features = num_features + num_layers * growth_rate 566 | if i != len(block_config) - 1: 567 | trans = _Transition(num_input_features=num_features, 568 | num_output_features=num_features // 2) 569 | self.features.add_module('transition%d' % (i + 1), trans) 570 | num_features = num_features // 2 571 | 572 | self.num_features = num_features 573 | 574 | 575 | # Official init from torch repo. 
576 | for m in self.modules(): 577 | if isinstance(m, nn.Conv3d): 578 | nn.init.kaiming_normal_(m.weight) 579 | elif isinstance(m, nn.BatchNorm3d): 580 | nn.init.constant_(m.weight, 1) 581 | nn.init.constant_(m.bias, 0) 582 | elif isinstance(m, nn.Linear): 583 | nn.init.constant_(m.bias, 0) 584 | 585 | def forward(self, x): 586 | features = self.features(x) 587 | out = F.adaptive_avg_pool3d(features, 1) 588 | out = torch.flatten(out, 1) 589 | return out.squeeze(dim=1) 590 | 591 | 592 | def _bn_function_factory(norm, relu, conv): 593 | def bn_function(*inputs): 594 | concated_features = torch.cat(inputs, 1) 595 | bottleneck_output = conv(relu(norm(concated_features))) 596 | return bottleneck_output 597 | 598 | return bn_function 599 | 600 | 601 | class _DenseLayer(nn.Sequential): 602 | def __init__(self, num_input_features, growth_rate, bn_size, memory_efficient=False): 603 | super(_DenseLayer, self).__init__() 604 | self.add_module('norm1', nn.BatchNorm3d(num_input_features)) 605 | self.add_module('relu1', nn.ReLU(inplace=True)) 606 | self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * 607 | growth_rate, kernel_size=1, stride=1, 608 | bias=False)) 609 | self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)) 610 | self.add_module('relu2', nn.ReLU(inplace=True)) 611 | self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, 612 | kernel_size=3, stride=1, padding=1, 613 | bias=False)) 614 | self.memory_efficient = memory_efficient 615 | 616 | def forward(self, *prev_features): 617 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 618 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 619 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 620 | else: 621 | bottleneck_output = bn_function(*prev_features) 622 | 623 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 624 | 625 | return new_features 626 | 627 | 628 | class
_DenseBlock(nn.Module): 629 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, memory_efficient=False): 630 | super(_DenseBlock, self).__init__() 631 | for i in range(num_layers): 632 | layer = _DenseLayer( 633 | num_input_features + i * growth_rate, 634 | growth_rate=growth_rate, 635 | bn_size=bn_size, 636 | memory_efficient=memory_efficient, 637 | ) 638 | self.add_module('denselayer%d' % (i + 1), layer) 639 | 640 | def forward(self, init_features): 641 | features = [init_features] 642 | for name, layer in self.named_children(): 643 | new_features = layer(*features) 644 | features.append(new_features) 645 | return torch.cat(features, 1) 646 | 647 | 648 | class _Transition(nn.Sequential): 649 | def __init__(self, num_input_features, num_output_features): 650 | super(_Transition, self).__init__() 651 | self.add_module('norm', nn.BatchNorm3d(num_input_features)) 652 | self.add_module('relu', nn.ReLU(inplace=True)) 653 | self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, 654 | kernel_size=1, stride=1, bias=False)) 655 | self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2)) 656 | 657 | 658 | def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs): 659 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 660 | return model 661 | 662 | 663 | def densenet121(**kwargs): 664 | r"""Densenet-121 model from 665 | `"Densely Connected Convolutional Networks" `_ 666 | 667 | Args: 668 | in_channels (int) - number of input channels. Default: 1. 669 | bn_size (int) - multiplicative factor for the number of bottleneck layers. Default: 4. 670 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 671 | but slower. Default: *False*.
See `"paper" `_ 672 | """ 673 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs) 674 | 675 | class SupRegDenseNet(nn.Module): 676 | """encoder + regressor head; forward() returns the encoder features only""" 677 | def __init__(self): 678 | super().__init__() 679 | self.encoder = densenet121() 680 | self.fc = nn.Linear(self.encoder.num_features, 1) 681 | 682 | def forward(self, x): 683 | feats = self.features(x) 684 | return feats 685 | # return self.fc(feats), feats 686 | 687 | def features(self, x): 688 | return self.encoder(x) 689 | 690 | class RegressionModel(metaclass=ABCMeta): 691 | __model_local_weights__ = os.path.join(os.path.dirname(__file__), os.environ.get("MODEL", "weights.pth")) 692 | __metadata_local_weights__ = os.path.join(os.path.dirname(__file__), "metadata.pkl") 693 | 694 | def __init__(self, model, batch_size=15, transforms=None): 695 | self.model = model 696 | self.batch_size = batch_size 697 | self.transforms = transforms 698 | self.indices = None 699 | 700 | def fit(self, X, y): 701 | """ Restore weights.
702 | """ 703 | if not os.path.isfile(self.__model_local_weights__): 704 | raise ValueError("You must provide the model weights in your submission folder.") 705 | state = torch.load(self.__model_local_weights__, map_location="cpu") 706 | 707 | if "model" not in state: 708 | raise ValueError("Model weights are expected under the 'model' key of the state dictionary.") 709 | self.model.load_state_dict(state["model"], strict=True) 710 | 711 | def predict(self, X: np.ndarray) -> np.ndarray: 712 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 713 | self.model.to(device) 714 | 715 | dataset = Dataset(X, transforms=self.transforms, indices=self.indices) 716 | testloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0) 717 | 718 | self.model.eval() 719 | outputs = [] 720 | 721 | with progressbar.ProgressBar(max_value=len(testloader)) as bar: 722 | for cnt, inputs in enumerate(testloader): 723 | inputs = inputs.float().to(device) 724 | # print("Batch size", inputs.shape) 725 | with torch.no_grad(): 726 | out = self.model(inputs) 727 | # out = torch.randn((inputs.shape[0], 128)) 728 | 729 | outputs.append(out.detach()) 730 | bar.update(cnt) 731 | 732 | outputs = torch.cat(outputs, dim=0) 733 | return outputs.detach().cpu().numpy() 734 | 735 | 736 | ############################################################################ 737 | # Define here your estimator pipeline 738 | ############################################################################ 739 | 740 | def get_estimator(mock=False) -> Pipeline: 741 | """ Build your estimator here. 742 | Notes 743 | ----- 744 | In order to minimize the memory load, the first steps of the pipeline 745 | are applied directly as transforms attached to the Torch Dataset. 746 | Notes 747 | ----- 748 | It is recommended to create an instance of sklearn.pipeline.Pipeline.
749 | """ 750 | if "resnet" in ARCHITECTURE: 751 | net = SupRegResNet(ARCHITECTURE) 752 | elif ARCHITECTURE == "alexnet": 753 | net = SupRegAlexNet() 754 | elif "densenet" in ARCHITECTURE: 755 | net = SupRegDenseNet() 756 | else: 757 | raise ValueError(f"Unsupported ARCHITECTURE: {ARCHITECTURE}") 758 | 759 | selector = FeatureExtractor("vbm", mock=mock) 760 | preproc = transforms.Compose([ 761 | transforms.Lambda(lambda x: selector.transform(x)), 762 | # Crop((1, 121, 128, 121), type="center"), 763 | # Pad((1, 128, 128, 128)), 764 | transforms.Lambda(lambda x: torch.from_numpy(x).float()), 765 | transforms.Normalize(mean=0.0, std=1.0), 766 | ]) 767 | estimator = make_pipeline( 768 | RegressionModel(net, transforms=preproc)) 769 | return estimator 770 | 771 | 772 | if __name__ == '__main__': 773 | estimator = get_estimator(mock=True).fit(None) 774 | estimator.predict(np.random.random((32, 2122945))) 775 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .openbhb import FeatureExtractor, OpenBHB, bin_age -------------------------------------------------------------------------------- /src/data/masks/cat12vbm_space-MNI152_desc-gm_TPM.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/src/data/masks/cat12vbm_space-MNI152_desc-gm_TPM.nii.gz -------------------------------------------------------------------------------- /src/data/masks/quasiraw_space-MNI152_desc-brain_T1w.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/src/data/masks/quasiraw_space-MNI152_desc-brain_T1w.nii.gz 
-------------------------------------------------------------------------------- /src/data/openbhb.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import nibabel 4 | import torch 5 | import pandas as pd 6 | from sklearn.base import BaseEstimator 7 | from sklearn.base import TransformerMixin 8 | from collections import OrderedDict 9 | from nilearn.masking import unmask 10 | 11 | def bin_age(age_real: torch.Tensor): 12 | bins = [i for i in range(4, 92, 2)] 13 | age_binned = age_real.clone() 14 | for value in bins[::-1]: 15 | age_binned[age_real <= value] = value 16 | return age_binned.long() 17 | 18 | def read_data(path, dataset, fast): 19 | print(f"Read {dataset.upper()}") 20 | df = pd.read_csv(os.path.join(path, dataset + ".tsv"), sep="\t") 21 | df.loc[df["split"] == "external_test", "site"] = np.nan 22 | 23 | y_arr = df[["age", "site"]].values 24 | 25 | x_arr = np.zeros((10, 3659572)) 26 | if not fast: 27 | x_arr = np.load(os.path.join(path, dataset + ".npy"), mmap_mode="r") 28 | 29 | print("- y size [original]:", y_arr.shape) 30 | print("- x size [original]:", x_arr.shape) 31 | return x_arr, y_arr 32 | 33 | class OpenBHB(torch.utils.data.Dataset): 34 | def __init__(self, root, train=True, internal=True, transform=None, 35 | label="cont", fast=False, load_feats=None): 36 | self.root = root 37 | 38 | if train and not internal: 39 | raise ValueError("Invalid configuration train=True and internal=False") 40 | 41 | self.train = train 42 | self.internal = internal 43 | 44 | dataset = "train" 45 | if not train: 46 | if internal: 47 | dataset = "internal_test" 48 | else: 49 | dataset = "external_test" 50 | 51 | self.X, self.y = read_data(root, dataset, fast) 52 | self.T = transform 53 | self.label = label 54 | self.fast = fast 55 | 56 | self.bias_feats = None 57 | if load_feats: 58 | print("Loading biased features", load_feats) 59 | self.bias_feats = torch.load(load_feats, map_location="cpu") 60 | 61 | print(f"Read {len(self.X)} records") 62 | 63 | def __len__(self): 64 | return len(self.y) 65 | 
66 | def __getitem__(self, index): 67 | if not self.fast: 68 | x = self.X[index] 69 | else: 70 | x = self.X[0] 71 | 72 | y = self.y[index] 73 | 74 | if self.T is not None: 75 | x = self.T(x) 76 | 77 | # sample, age, site 78 | age, site = y[0], y[1] 79 | if self.label == "bin": 80 | age = bin_age(torch.tensor(age)) 81 | 82 | if self.bias_feats is not None: 83 | return x, age, self.bias_feats[index] 84 | else: 85 | return x, age, site 86 | 87 | class FeatureExtractor(BaseEstimator, TransformerMixin): 88 | """ Select only the features associated with the requested data type from the 89 | input buffered data. 90 | """ 91 | MODALITIES = OrderedDict([ 92 | ("vbm", { 93 | "shape": (1, 121, 145, 121), 94 | "size": 519945}), 95 | ("quasiraw", { 96 | "shape": (1, 182, 218, 182), 97 | "size": 1827095}), 98 | ("xhemi", { 99 | "shape": (8, 163842), 100 | "size": 1310736}), 101 | ("vbm_roi", { 102 | "shape": (1, 284), 103 | "size": 284}), 104 | ("desikan_roi", { 105 | "shape": (7, 68), 106 | "size": 476}), 107 | ("destrieux_roi", { 108 | "shape": (7, 148), 109 | "size": 1036}) 110 | ]) 111 | MASKS = { 112 | "vbm": { 113 | "path": None, 114 | "thr": 0.05}, 115 | "quasiraw": { 116 | "path": None, 117 | "thr": 0} 118 | } 119 | 120 | def __init__(self, dtype, mock=False): 121 | """ Init class. 122 | Parameters 123 | ---------- 124 | dtype: str 125 | the requested data: 'vbm', 'quasiraw', 'vbm_roi', 'desikan_roi', 126 | 'destrieux_roi' or 'xhemi'.
127 | """ 128 | if dtype not in self.MODALITIES: 129 | raise ValueError("Invalid input data type.") 130 | self.dtype = dtype 131 | 132 | data_types = list(self.MODALITIES.keys()) 133 | index = data_types.index(dtype) 134 | 135 | cumsum = np.cumsum([item["size"] for item in self.MODALITIES.values()]) 136 | 137 | if index > 0: 138 | self.start = cumsum[index - 1] 139 | else: 140 | self.start = 0 141 | self.stop = cumsum[index] 142 | 143 | self.masks = dict((key, val["path"]) for key, val in self.MASKS.items()) 144 | self.masks["vbm"] = "./data/masks/cat12vbm_space-MNI152_desc-gm_TPM.nii.gz" 145 | self.masks["quasiraw"] = "./data/masks/quasiraw_space-MNI152_desc-brain_T1w.nii.gz" 146 | 147 | self.mock = mock 148 | if mock: 149 | return 150 | 151 | for key in self.masks: 152 | if self.masks[key] is None or not os.path.isfile(self.masks[key]): 153 | raise ValueError("Impossible to find mask:", key, self.masks[key]) 154 | arr = nibabel.load(self.masks[key]).get_fdata() 155 | thr = self.MASKS[key]["thr"] 156 | arr[arr <= thr] = 0 157 | arr[arr > thr] = 1 158 | self.masks[key] = nibabel.Nifti1Image(arr.astype(int), np.eye(4)) 159 | 160 | def fit(self, X, y): 161 | return self 162 | 163 | def transform(self, X): 164 | if self.mock: 165 | #print("transforming", X.shape) 166 | data = X.reshape(self.MODALITIES[self.dtype]["shape"]) 167 | #print("mock data:", data.shape) 168 | return data 169 | 170 | # print(X.shape) 171 | select_X = X[self.start:self.stop] 172 | if self.dtype in ("vbm", "quasiraw"): 173 | im = unmask(select_X, self.masks[self.dtype]) 174 | select_X = im.get_fdata() 175 | select_X = select_X.transpose(2, 0, 1) 176 | select_X = select_X.reshape(self.MODALITIES[self.dtype]["shape"]) 177 | # print('transformed.shape', select_X.shape) 178 | return select_X 179 | 180 | 181 | if __name__ == '__main__': 182 | import sys 183 | from torchvision import transforms 184 | from .transforms import Crop, Pad 185 | 186 | selector = FeatureExtractor("vbm") 187 | 188 | T_pre = 
transforms.Lambda(lambda x: selector.transform(x)) 189 | T_train = transforms.Compose([ 190 | T_pre, 191 | Crop((1, 121, 128, 121), type="random"), 192 | Pad((1, 128, 128, 128)), 193 | transforms.Lambda(lambda x: torch.from_numpy(x)), 194 | transforms.Normalize(mean=0.0, std=1.0) 195 | ]) 196 | 197 | train_loader = torch.utils.data.DataLoader(OpenBHB(sys.argv[1], train=True, internal=True, transform=T_train), 198 | batch_size=3, shuffle=True, num_workers=8, 199 | persistent_workers=True) 200 | 201 | x, y1, y2 = next(iter(train_loader)) 202 | print(x.shape, y1, y2) -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import operator 3 | import random 4 | import torch 5 | 6 | class Crop(object): 7 | """ Crop the given n-dimensional array either at a random location or 8 | centered. 9 | """ 10 | def __init__(self, shape, type="center", keep_dim=False): 11 | assert type in ["center", "random"] 12 | self.shape = shape 13 | self.cropping_type = type 14 | self.keep_dim = keep_dim 15 | 16 | def slow_crop(self, X): 17 | img_shape = np.array(X.shape) 18 | 19 | if type(self.shape) == int: 20 | size = [self.shape for _ in range(len(img_shape))] 21 | else: 22 | size = np.copy(self.shape) 23 | 24 | # print('img_shape:', img_shape, 'size', size) 25 | 26 | indexes = [] 27 | for ndim in range(len(img_shape)): 28 | if size[ndim] > img_shape[ndim] or size[ndim] < 0: 29 | size[ndim] = img_shape[ndim] 30 | 31 | if self.cropping_type == "center": 32 | delta_before = int((img_shape[ndim] - size[ndim]) / 2.0) 33 | 34 | elif self.cropping_type == "random": 35 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1) 36 | 37 | indexes.append(slice(delta_before, delta_before + size[ndim])) 38 | 39 | if self.keep_dim: 40 | mask = np.zeros(img_shape, dtype=bool) 41 | mask[tuple(indexes)] = True 42 | arr_copy = 
X.copy() 43 | arr_copy[~mask] = 0 44 | return arr_copy 45 | 46 | _X = X[tuple(indexes)] 47 | # print('cropped.shape', _X.shape) 48 | return _X 49 | 50 | def fast_crop(self, X): 51 | # X is a single image (CxWxHxZ) 52 | shape = X.shape 53 | 54 | delta = [shape[1]-self.shape[1], 55 | shape[2]-self.shape[2], 56 | shape[3]-self.shape[3]] 57 | 58 | if self.cropping_type == "center": 59 | offset = list(map(operator.floordiv, delta, [2]*len(delta))) 60 | X = X[:, offset[0]:offset[0]+self.shape[1], 61 | offset[1]:offset[1]+self.shape[2], 62 | offset[2]:offset[2]+self.shape[3]] 63 | 64 | elif self.cropping_type == "random": 65 | offset = [ 66 | int(random.random()*128) % (delta[0]+1), 67 | int(random.random()*128) % (delta[1]+1), 68 | int(random.random()*128) % (delta[2]+1) 69 | ] 70 | X = X[:, offset[0]:offset[0]+self.shape[1], 71 | offset[1]:offset[1]+self.shape[2], 72 | offset[2]:offset[2]+self.shape[3]] 73 | else: 74 | raise ValueError("Invalid cropping_type", self.cropping_type) 75 | 76 | return X 77 | 78 | def __call__(self, X): 79 | return self.fast_crop(X) 80 | 81 | class Cutout(object): 82 | """Apply a cutout on the images 83 | cf. Improved Regularization of Convolutional Neural Networks with Cutout, arXiv, 2017 84 | We assume that the square to be cut is inside the image. 85 | """ 86 | def __init__(self, patch_size=None, value=0, random_size=False, inplace=False, localization=None, probability=0.5): 87 | self.patch_size = patch_size 88 | self.value = value 89 | self.random_size = random_size 90 | self.inplace = inplace 91 | self.localization = localization 92 | self.probability = probability 93 | 94 | def __call__(self, arr): 95 | if np.random.rand() >= self.probability: 96 | return arr 97 | 98 | img_shape = np.array(arr.shape) 99 | if type(self.patch_size) == int: 100 | size = [self.patch_size for _ in range(len(img_shape))] 101 | else: 102 | size = np.copy(self.patch_size) 103 | assert len(size) == len(img_shape), "Incorrect patch dimension." 
104 | indexes = [] 105 | for ndim in range(len(img_shape)): 106 | if size[ndim] > img_shape[ndim] or size[ndim] < 0: 107 | size[ndim] = img_shape[ndim] 108 | if self.random_size: 109 | size[ndim] = np.random.randint(0, size[ndim]) 110 | if self.localization is not None: 111 | delta_before = max(self.localization[ndim] - size[ndim]//2, 0) 112 | else: 113 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1) 114 | indexes.append(slice(int(delta_before), int(delta_before + size[ndim]))) 115 | if self.inplace: 116 | arr[tuple(indexes)] = self.value 117 | return arr 118 | else: 119 | arr_cut = np.copy(arr) 120 | arr_cut[tuple(indexes)] = self.value 121 | return arr_cut 122 | 123 | class Pad(object): 124 | """ Pad the given n-dimensional array 125 | """ 126 | def __init__(self, shape, **kwargs): 127 | self.shape = shape 128 | self.kwargs = kwargs 129 | 130 | def __call__(self, X): 131 | _X = self._apply_padding(X) 132 | return _X 133 | 134 | def _apply_padding(self, arr): 135 | orig_shape = arr.shape 136 | padding = [] 137 | for orig_i, final_i in zip(orig_shape, self.shape): 138 | shape_i = final_i - orig_i 139 | half_shape_i = shape_i // 2 140 | if shape_i % 2 == 0: 141 | padding.append([half_shape_i, half_shape_i]) 142 | else: 143 | padding.append([half_shape_i, half_shape_i + 1]) 144 | for cnt in range(len(arr.shape) - len(padding)): 145 | padding.append([0, 0]) 146 | fill_arr = np.pad(arr, padding, **self.kwargs) 147 | return fill_arr 148 | 149 | 150 | if __name__ == '__main__': 151 | import timeit 152 | x = np.random.rand(1, 128, 128, 128) 153 | 154 | cut = Cutout((1, 10, 10, 10), probability=1.) 
155 | print(cut(x).shape) 156 | 157 | crop = Crop((1, 121, 128, 121), type="center") 158 | print(crop(x).shape) 159 | 160 | crop = Crop((1, 121, 128, 121), type="random") 161 | print(crop(x).shape) 162 | 163 | print("slow crop:", timeit.timeit(lambda: crop.slow_crop(x), number=10000)) 164 | print("fast crop:", timeit.timeit(lambda: crop.fast_crop(x), number=10000)) -------------------------------------------------------------------------------- /src/exp/mae.yaml: -------------------------------------------------------------------------------- 1 | program: main_mse.py 2 | data_dir: /scratch/data-registry/medical/openbhb 3 | save_dir: /scratch/output/brain-age-mri 4 | model: resnet18 5 | epochs: 300 6 | batch_size: 32 7 | lr: 1e-4 8 | lr_decay: step 9 | lr_decay_rate: 0.9 10 | lr_decay_step: 10 11 | optimizer: adam 12 | momentum: 0.9 13 | weight_decay: 5e-5 14 | train_all: 1 15 | trial: 0 16 | tf: none -------------------------------------------------------------------------------- /src/exp/supcon_adam_kernel.yaml: -------------------------------------------------------------------------------- 1 | program: main_infonce.py 2 | data_dir: /scratch/data-registry/medical/openbhb 3 | save_dir: /scratch/output/brain-age-mri 4 | model: resnet18 5 | epochs: 300 6 | batch_size: 32 7 | lr: 1e-4 8 | lr_decay: step 9 | lr_decay_rate: 0.9 10 | lr_decay_step: 10 11 | optimizer: adam 12 | momentum: 0.9 13 | weight_decay: 5e-5 14 | train_all: 1 15 | method: yaware 16 | kernel: gaussian 17 | sigma: 1 18 | trial: 0 19 | tf: none 20 | 21 | -------------------------------------------------------------------------------- /src/exp/supcon_sgd_kernel.yaml: -------------------------------------------------------------------------------- 1 | program: main_infonce.py 2 | data_dir: /scratch/data-registry/medical/openbhb 3 | save_dir: /scratch/output/brain-age-mri 4 | model: resnet18 5 | epochs: 300 6 | batch_size: 32 7 | lr: 0.1 8 | lr_decay: cosine 9 | optimizer: sgd 10 | momentum: 0.9 11 | 
weight_decay: 1e-4 12 | kernel: gaussian 13 | sigma: 1 14 | trial: 0 15 | tf: none -------------------------------------------------------------------------------- /src/figures/ablation.csv: -------------------------------------------------------------------------------- 1 | "Name","ramp/score","ramp/bacc","Created","Runtime","End Time","sigma","Hostname","ID","Notes","Updated","Tags","kernel","method","clip_grad","ramp/bacc_std","ramp/ext_mae","ramp/ext_mae_std","ramp/int_mae","ramp/int_mae_std" 2 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","3.025333","9.2667","2022-10-28T15:04:30.000Z","495350","2022-11-03T08:40:20.000Z","2","NBDOTTI61","3jm0a5ld","-","2022-11-03T08:40:20.000Z","tested","cauchy","threshold","","1.124","6.183","0.038105","3.477333","0.021385" 3 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.848667","5","2022-10-28T15:04:30.000Z","517975","2022-11-03T14:57:25.000Z","2","NBDOTTI61","i3qmz1yh","-","2022-11-03T14:57:25.000Z","tested","cauchy","expw","","0.1","4.547","0.019157","2.666","0.002646" 4 | "resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.815333","6.6","2022-10-28T15:04:30.000Z","79443","2022-10-29T13:08:33.000Z","2","NBDOTTI61","ooz7z2yo","-","2022-10-29T13:08:33.000Z","tested","rbf","yaware","","0.1732","4.102","0.009539","2.664667","0.002082" 5 | 
"resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.482333","8.1667","2022-10-28T15:04:28.000Z","461295","2022-11-02T23:12:43.000Z","2","NBDOTTI61","r07q08n4","-","2022-11-02T23:12:43.000Z","tested","cauchy","yaware","","0.6658","5.267333","0.004509","3.088333","0.019425" 6 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.539667","5.1","2022-10-29T06:27:10.000Z","453260","2022-11-03T12:21:30.000Z","2","NBDOTTI61","w0ci971l","-","2022-11-03T12:21:30.000Z","tested","rbf","expw","","0.1","3.761","0.005","2.552","0.002" 7 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.738333","5.7333","2022-10-29T06:24:22.000Z","440525","2022-11-03T08:46:27.000Z","2","NBDOTTI61","ykx8p7pc","-","2022-11-03T08:46:27.000Z","tested","rbf","threshold","","0.1528","4.098333","0.009504","2.947","0.004359" 8 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.627","8.5","2022-10-29T05:51:32.000Z","472252","2022-11-03T17:02:24.000Z","1","NBDOTTI61","alp80xna","-","2022-11-03T17:02:24.000Z","tested","rbf","threshold","","1.0149","5.508333","0.019858","3.042667","0.011676" 9 | 
"resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.816","4.8667","2022-10-28T15:04:30.000Z","536997","2022-11-03T20:14:27.000Z","1","NBDOTTI61","anck0dqb","-","2022-11-03T20:14:27.000Z","tested","cauchy","expw","","0.2082","4.502667","0.010408","2.731","0.006083" 10 | "resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.428","6.6333","2022-10-28T15:04:30.000Z","180849","2022-10-30T17:18:39.000Z","1","NBDOTTI61","apyqr2sz","-","2022-10-30T17:18:39.000Z","tested","rbf","yaware","","0.8505","5.492667","0.05208","2.850333","0.006807" 11 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.282333","8.5333","2022-10-28T15:04:29.000Z","534546","2022-11-03T19:33:35.000Z","1","NBDOTTI61","bsecf64m","-","2022-11-03T19:33:35.000Z","tested","cauchy","threshold","","0.3512","4.775333","0.013204","2.779667","0.010599" 12 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.575667","4.9667","2022-10-29T06:25:27.000Z","470279","2022-11-03T17:03:26.000Z","1","NBDOTTI61","nl49poa9","-","2022-11-03T17:03:26.000Z","tested","rbf","expw","","0.4041","3.877667","0.004041","2.823","0.002646" 13 | 
"resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.147333","7.7333","2022-10-28T15:04:28.000Z","450797","2022-11-02T20:17:45.000Z","1","NBDOTTI61","w85oskw1","-","2022-11-02T20:17:45.000Z","tested","cauchy","yaware","","0.6506","4.633667","0.028113","2.710667","0.005508" -------------------------------------------------------------------------------- /src/figures/ablation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/src/figures/ablation.pdf -------------------------------------------------------------------------------- /src/figures/ablation.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import numpy as np 4 | from matplotlib import rc 5 | 6 | 7 | if __name__ == '__main__': 8 | rc('axes', titlesize=25) # fontsize of the axes title 9 | rc('axes', labelsize=20) # fontsize of the x and y labels 10 | rc('xtick', labelsize=16) # fontsize of the tick labels 11 | rc('ytick', labelsize=16) # fontsize of the tick labels 12 | rc('legend', fontsize=14) # legend fontsize 13 | rc('figure', titlesize=28) # fontsize of the figure title 14 | rc('font', size=18) 15 | # rc('font', family='Times New Roman') 16 | rc('text', usetex=True) 17 | 18 | df = pd.read_csv('ablation.csv') 19 | print(df.head()) 20 | 21 | fig = plt.figure(figsize=(20, 3)) 22 | plt.rcParams['image.cmap'] = "Set2" 23 | plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors) 24 | 25 | i = 1 26 | for kernel in df.kernel.unique(): 27 | for sigma in sorted(df.sigma.unique()): 28 | ax = fig.add_subplot(1, 4, i) 29 | data = df[(df.kernel == kernel) & (df.sigma == sigma)] 30 | data = data.sort_values(by='method') 31 
| print(data) 32 | 33 | int_mae = data['ramp/int_mae'].values 34 | bacc = data['ramp/bacc'].values 35 | ext_mae = data['ramp/ext_mae'].values 36 | score = data['ramp/score'].values 37 | methods = data['method'].values 38 | 39 | width = 0.4 40 | x = np.arange(4)*2 41 | labels = ['Int. MAE', 'BAcc', 'Ext. MAE', 'Score'] 42 | 43 | data = np.array([[int_mae[i], bacc[i], ext_mae[i], score[i]] for i in range(3)]) 44 | 45 | if kernel == "rbf": 46 | ax.set_title(rf"{kernel} ($\sigma$={sigma})") 47 | else: 48 | ax.set_title(rf"{kernel} ($\gamma$={sigma})") 49 | 50 | alpha = 0.8 51 | ax.bar(x - width, data[0], width, label=methods[0], alpha=alpha) 52 | ax.bar(x, data[1], width, label=methods[1], alpha=alpha) 53 | ax.bar(x + width, data[2], width, label=methods[2], alpha=alpha) 54 | ax.set_ylim(0, 10) 55 | # ax.bar(x + width, data[3], width, label=methods[3]) 56 | 57 | if i == 4: 58 | ax.legend() 59 | ax.set_xticks(x, labels, rotation=45) 60 | 61 | i += 1 62 | # fig.tight_layout() 63 | plt.savefig('ablation.pdf', dpi=200, bbox_inches='tight', pad_inches=0) 64 | plt.show() -------------------------------------------------------------------------------- /src/launcher.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import sys 3 | import subprocess 4 | import os 5 | from yaml.loader import SafeLoader 6 | 7 | 8 | if __name__ == '__main__': 9 | if len(sys.argv) <= 1: 10 | print("Usage: ./launcher.py path/to/yaml") 11 | exit(1) 12 | 13 | with open(sys.argv[1]) as f: 14 | data = yaml.load(f, Loader=SafeLoader) 15 | 16 | program = data['program'] 17 | del data['program'] 18 | 19 | skip = False 20 | for idx, override in enumerate(sys.argv[2:]): 21 | if skip: 22 | skip = False 23 | continue 24 | 25 | if '=' in override: 26 | k, v = override.lstrip('-').split('=', 1) 27 | else: 28 | k = override.replace('--', '') 29 | v = sys.argv[2+idx+1] 30 | skip = True 31 | data[k] = v 32 | 33 | args = ["python3", os.path.join(os.getcwd(), program)] 34 | for k,
v in data.items(): 35 | args.extend(["--" + k, str(v)]) 36 | print("Running:", ' '.join(args)) 37 | subprocess.run(args) 38 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | from cmath import isinf 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class KernelizedSupCon(nn.Module): 8 | """Supervised contrastive loss: https://arxiv.org/pdf/2004.11362.pdf. 9 | It also supports the unsupervised contrastive loss in SimCLR 10 | Based on: https://github.com/HobbitLong/SupContrast""" 11 | def __init__(self, method: str, temperature: float=0.07, contrast_mode: str='all', 12 | base_temperature: float=0.07, kernel: callable=None, delta_reduction: str='sum'): 13 | super().__init__() 14 | self.temperature = temperature 15 | self.contrast_mode = contrast_mode 16 | self.base_temperature = base_temperature 17 | self.method = method 18 | self.kernel = kernel 19 | self.delta_reduction = delta_reduction 20 | 21 | if kernel is not None and method == 'supcon': 22 | raise ValueError('Kernel must be none if method=supcon') 23 | 24 | if kernel is None and method != 'supcon': 25 | raise ValueError('Kernel must not be none if method != supcon') 26 | 27 | if delta_reduction not in ['mean', 'sum']: 28 | raise ValueError(f"Invalid reduction {delta_reduction}") 29 | 30 | def __repr__(self): 31 | return f'{self.__class__.__name__} ' \ 32 | f'(t={self.temperature}, ' \ 33 | f'method={self.method}, ' \ 34 | f'kernel={self.kernel is not None}, ' \ 35 | f'delta_reduction={self.delta_reduction})' 36 | 37 | def forward(self, features, labels=None): 38 | """Compute loss for model. If `labels` is None, 39 | it degenerates to SimCLR unsupervised loss: 40 | https://arxiv.org/pdf/2002.05709.pdf 41 | 42 | Args: 43 | features: hidden vector of shape [bsz, n_views, n_features]. 
44 | input has to be rearranged to [bsz, n_views, n_features] and labels [bsz], 45 | labels: ground truth of shape [bsz]. 46 | Returns: 47 | A loss scalar. 48 | """ 49 | device = features.device 50 | 51 | if len(features.shape) != 3: 52 | raise ValueError('`features` needs to be [bsz, n_views, n_feats], ' 53 | '3 dimensions are required') 54 | 55 | batch_size = features.shape[0] 56 | n_views = features.shape[1] 57 | 58 | if labels is None: 59 | mask = torch.eye(batch_size, device=device) 60 | 61 | else: 62 | labels = labels.view(-1, 1) 63 | if labels.shape[0] != batch_size: 64 | raise ValueError('Num of labels does not match num of features') 65 | 66 | if self.kernel is None: 67 | mask = torch.eq(labels, labels.T) 68 | else: 69 | mask = self.kernel(labels) 70 | 71 | view_count = features.shape[1] 72 | features = torch.cat(torch.unbind(features, dim=1), dim=0) 73 | if self.contrast_mode == 'one': 74 | features = features[:, 0] 75 | anchor_count = 1 76 | elif self.contrast_mode == 'all': 77 | features = features 78 | anchor_count = view_count 79 | else: 80 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 81 | 82 | # Tile mask 83 | mask = mask.repeat(anchor_count, view_count) 84 | 85 | # Inverse of torch-eye to remove self-contrast (diagonal) 86 | inv_diagonal = torch.scatter( 87 | torch.ones_like(mask), 88 | 1, 89 | torch.arange(batch_size*n_views, device=device).view(-1, 1), 90 | 0 91 | ) 92 | 93 | # compute similarity 94 | anchor_dot_contrast = torch.div( 95 | torch.matmul(features, features.T), 96 | self.temperature 97 | ) 98 | 99 | # for numerical stability 100 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 101 | logits = anchor_dot_contrast - logits_max.detach() 102 | 103 | alignment = logits 104 | 105 | # base case is: 106 | # - supcon if kernel = none 107 | # - y-aware if kernel != none 108 | uniformity = torch.exp(logits) * inv_diagonal 109 | 110 | if self.method == 'threshold': 111 | repeated = 
mask.unsqueeze(-1).repeat(1, 1, mask.shape[0]) # repeat kernel mask 112 | 113 | delta = (mask[:, None].T - repeated.T).transpose(1, 2) # compute the difference w_k - w_j for every k,j 114 | delta = (delta > 0.).float() 115 | 116 | # for each z_i, repel only samples j s.t. K(z_i, z_j) < K(z_i, z_k) 117 | uniformity = uniformity.unsqueeze(-1).repeat(1, 1, mask.shape[0]) 118 | 119 | if self.delta_reduction == 'mean': 120 | uniformity = (uniformity * delta).mean(-1) 121 | else: 122 | uniformity = (uniformity * delta).sum(-1) 123 | 124 | elif self.method == 'expw': 125 | # exp weight e^(s_j(1-w_j)) 126 | uniformity = torch.exp(logits * (1 - mask)) * inv_diagonal 127 | 128 | uniformity = torch.log(uniformity.sum(1, keepdim=True)) 129 | 130 | 131 | # positive mask contains the anchor-positive pairs 132 | # excluding the diagonal 133 | positive_mask = mask * inv_diagonal 134 | 135 | log_prob = alignment - uniformity # log(alignment/uniformity) = log(alignment) - log(uniformity) 136 | log_prob = (positive_mask * log_prob).sum(1) / positive_mask.sum(1) # compute mean of log-likelihood over positives 137 | 138 | # loss 139 | loss = - (self.temperature / self.base_temperature) * log_prob 140 | return loss.mean() 141 | 142 | 143 | if __name__ == '__main__': 144 | k_supcon = KernelizedSupCon(method='supcon') 145 | 146 | x = torch.nn.functional.normalize(torch.randn((256, 2, 64)), dim=-1) 147 | labels = torch.randint(0, 4, (256,)) 148 | 149 | l = k_supcon(x, labels) 150 | print(l) -------------------------------------------------------------------------------- /src/main_infonce.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | from random import gauss 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.utils.data 9 | import torchvision 10 | import argparse 11 | import models 12 | import losses 13 | import time 14 | import wandb 15 | import torch.utils.tensorboard
16 | 17 | from torch import nn 18 | from torchvision import transforms 19 | from torchvision import datasets 20 | from util import AverageMeter, NViewTransform, ensure_dir, set_seed, arg2bool, save_model 21 | from util import warmup_learning_rate, adjust_learning_rate 22 | from util import compute_age_mae, compute_site_ba 23 | from data import FeatureExtractor, OpenBHB, bin_age 24 | from data.transforms import Crop, Pad, Cutout 25 | from main_mse import get_transforms 26 | 27 | 28 | def parse_arguments(): 29 | parser = argparse.ArgumentParser(description="Weakly contrastive learning for brain age prediction", 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | 32 | # Misc 33 | parser.add_argument('--device', type=str, help='torch device', default='cuda') 34 | parser.add_argument('--print_freq', type=int, help='print frequency', default=10) 35 | parser.add_argument('--trial', type=int, help='random seed / trial id', default=0) 36 | parser.add_argument('--save_dir', type=str, help='output dir', default='output') 37 | parser.add_argument('--save_freq', type=int, help='save frequency', default=50) 38 | parser.add_argument('--data_dir', type=str, help='path of data dir', default='/data') 39 | parser.add_argument('--amp', type=arg2bool, help='use amp', default=False) 40 | parser.add_argument('--clip_grad', type=arg2bool, help='clip gradient to prevent nan', default=False) 41 | 42 | # Model 43 | parser.add_argument('--model', type=str, help='model architecture', default='resnet18') 44 | 45 | # Optimizer 46 | parser.add_argument('--epochs', type=int, help='number of epochs', default=300) 47 | parser.add_argument('--batch_size', type=int, help='batch size', default=256) 48 | parser.add_argument('--lr', type=float, help='learning rate', default=1e-4) 49 | parser.add_argument('--lr_decay', type=str, help='type of decay', choices=['cosine', 'step'], default='step') 50 | parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='decay rate for learning
rate (for step)') 51 | parser.add_argument('--lr_decay_epochs', type=str, help='steps of lr decay (list)', default="700,800,900") 52 | parser.add_argument('--lr_decay_step', type=int, help='interval in epochs between lr decays (overwrites lr_decay_epochs)', default=10) 53 | parser.add_argument('--warm', type=arg2bool, help='warmup lr', default=False) 54 | parser.add_argument('--optimizer', type=str, help="optimizer (adam or sgd)", choices=["adam", "sgd"], default="adam") 55 | parser.add_argument('--momentum', type=float, help='momentum', default=0.9) 56 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=5e-5) 57 | 58 | # Data 59 | parser.add_argument('--train_all', type=arg2bool, help='train on the whole dataset, including validation sets (int+ext)', default=True) 60 | parser.add_argument('--tf', type=str, help='data augmentation', choices=['none', 'crop', 'cutout', 'all'], default='none') 61 | 62 | # Loss 63 | parser.add_argument('--method', type=str, help='loss function', choices=['supcon', 'yaware', 'threshold', 'expw'], default='supcon') 64 | parser.add_argument('--kernel', type=str, help='Kernel function (not for supcon)', choices=['cauchy', 'gaussian', 'rbf'], default=None) 65 | parser.add_argument('--delta_reduction', type=str, help='use mean or sum to reduce 3d delta mask (only for method=threshold)', default='sum') 66 | parser.add_argument('--temp', type=float, help='loss temperature', default=0.1) 67 | parser.add_argument('--alpha', type=float, help='infonce weight', default=1.) 68 | parser.add_argument('--sigma', type=float, help='gaussian-rbf kernel sigma / cauchy gamma', default=1) 69 | parser.add_argument('--n_views', type=int, help='num.
of multiviews', default=2) 70 | parser.add_argument('--biased_features', type=str, help='path of precomputed biased features (.pth) to load (optional)', default=None) 71 | opts = parser.parse_args() 72 | 73 | if opts.batch_size > 256: 74 | print("Forcing warm") 75 | opts.warm = True 76 | 77 | if opts.lr_decay_step is not None: 78 | opts.lr_decay_epochs = list(range(opts.lr_decay_step, opts.epochs, opts.lr_decay_step)) 79 | print(f"Computed decay epochs based on step ({opts.lr_decay_step}):", opts.lr_decay_epochs) 80 | else: 81 | iterations = opts.lr_decay_epochs.split(',') 82 | opts.lr_decay_epochs = list([]) 83 | for it in iterations: 84 | opts.lr_decay_epochs.append(int(it)) 85 | 86 | if opts.warm: 87 | opts.warmup_from = 0.01 88 | opts.warm_epochs = 10 89 | if opts.lr_decay == 'cosine': 90 | eta_min = opts.lr * (opts.lr_decay_rate ** 3) 91 | opts.warmup_to = eta_min + (opts.lr - eta_min) * ( 92 | 1 + math.cos(math.pi * opts.warm_epochs / opts.epochs)) / 2 93 | else: 94 | opts.milestones = opts.lr_decay_epochs # already parsed into a list of ints above 95 | opts.warmup_to = opts.lr 96 | 97 | if opts.method == 'supcon': 98 | print('method == supcon, binning age') 99 | opts.label = 'bin' 100 | else: 101 | print('method != supcon, using real age value') 102 | opts.label = 'cont' 103 | 104 | if opts.method == 'supcon' and opts.kernel is not None: 105 | print('Invalid kernel for supcon') 106 | exit(1) 107 | 108 | if opts.method != 'supcon' and opts.kernel is None: 109 | print('Kernel cannot be None for method != supcon') 110 | exit(1) 111 | 112 | if opts.model == 'densenet121': 113 | opts.n_views = 1 114 | 115 | return opts 116 | 117 | def load_data(opts): 118 | T_train, T_test = get_transforms(opts) 119 | T_train = NViewTransform(T_train, opts.n_views) 120 | 121 | train_dataset = OpenBHB(opts.data_dir, train=True, internal=True, transform=T_train, label=opts.label, 122 | load_feats=opts.biased_features) 123 | if opts.train_all: 124 | valint_feats, valext_feats = None, None 125 | if opts.biased_features is not None: 126 | valint_feats = opts.biased_features.replace('.pth', '_valint.pth') 127 | valext_feats = 
opts.biased_features.replace('.pth', '_valext.pth') 128 | 129 | valint = OpenBHB(opts.data_dir, train=False, internal=True, transform=T_train, 130 | label=opts.label, load_feats=valint_feats) 131 | valext = OpenBHB(opts.data_dir, train=False, internal=False, transform=T_train, 132 | label=opts.label, load_feats=valext_feats) 133 | train_dataset = torch.utils.data.ConcatDataset([train_dataset, valint, valext]) 134 | print("Total dataset length:", len(train_dataset)) 135 | 136 | 137 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=8, 138 | persistent_workers=True) 139 | train_loader_score = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=True, internal=True, transform=T_train, label=opts.label), 140 | batch_size=opts.batch_size, shuffle=True, num_workers=8, 141 | persistent_workers=True) 142 | test_internal = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=True, transform=T_test), 143 | batch_size=opts.batch_size, shuffle=False, num_workers=8, 144 | persistent_workers=True) 145 | test_external = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=False, transform=T_test), 146 | batch_size=opts.batch_size, shuffle=False, num_workers=8, 147 | persistent_workers=True) 148 | return train_loader, train_loader_score, test_internal, test_external 149 | 150 | def load_model(opts): 151 | if 'resnet' in opts.model: 152 | model = models.SupConResNet(opts.model, feat_dim=128) 153 | elif 'alexnet' in opts.model: 154 | model = models.SupConAlexNet(feat_dim=128) 155 | elif 'densenet121' in opts.model: 156 | model = models.SupConDenseNet(feat_dim=128) 157 | 158 | else: 159 | raise ValueError("Unknown model", opts.model) 160 | 161 | if opts.device == 'cuda' and torch.cuda.device_count() > 1: 162 | print(f"Using multiple CUDA devices ({torch.cuda.device_count()})") 163 | model = torch.nn.DataParallel(model) 164 | model = model.to(opts.device) 165 | 166 | 167 | def 
gaussian_kernel(x): 168 | x = x - x.T 169 | return torch.exp(-(x**2) / (2*(opts.sigma**2))) / (math.sqrt(2*torch.pi)*opts.sigma) 170 | 171 | def rbf(x): 172 | x = x - x.T 173 | return torch.exp(-(x**2)/(2*(opts.sigma**2))) 174 | 175 | def cauchy(x): 176 | x = x - x.T 177 | return 1. / (opts.sigma*(x**2) + 1) 178 | 179 | kernels = { 180 | 'none': None, 181 | 'cauchy': cauchy, 182 | 'gaussian': gaussian_kernel, 183 | 'rbf': rbf 184 | } 185 | 186 | infonce = losses.KernelizedSupCon(method=opts.method, temperature=opts.temp, 187 | kernel=kernels[opts.kernel] if opts.kernel is not None else None, delta_reduction=opts.delta_reduction) 188 | infonce = infonce.to(opts.device) 189 | 190 | 191 | return model, infonce 192 | 193 | def load_optimizer(model, opts): 194 | if opts.optimizer == "sgd": 195 | optimizer = torch.optim.SGD(model.parameters(), lr=opts.lr, 196 | momentum=opts.momentum, 197 | weight_decay=opts.weight_decay) 198 | else: 199 | optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay) 200 | 201 | return optimizer 202 | 203 | def train(train_loader, model, infonce, optimizer, opts, epoch): 204 | loss = AverageMeter() 205 | batch_time = AverageMeter() 206 | data_time = AverageMeter() 207 | 208 | scaler = torch.cuda.amp.GradScaler() if opts.amp else None 209 | model.train() 210 | 211 | t1 = time.time() 212 | for idx, (images, labels, _) in enumerate(train_loader): 213 | data_time.update(time.time() - t1) 214 | 215 | images = torch.cat(images, dim=0).to(opts.device) 216 | bsz = labels.shape[0] 217 | 218 | warmup_learning_rate(opts, epoch, idx, len(train_loader), optimizer) 219 | 220 | with torch.cuda.amp.autocast(scaler is not None): 221 | projected = model(images) 222 | projected = torch.split(projected, [bsz]*opts.n_views, dim=0) 223 | projected = torch.cat([f.unsqueeze(1) for f in projected], dim=1) 224 | running_loss = infonce(projected, labels.to(opts.device)) 225 | 226 | optimizer.zero_grad() 227 | if 
scaler is None: 228 | running_loss.backward() 229 | if 
opts.clip_grad: 230 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) 231 | optimizer.step() 232 | else: 233 | scaler.scale(running_loss).backward() 234 | if opts.clip_grad: 235 | scaler.unscale_(optimizer) 236 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) 237 | scaler.step(optimizer) 238 | scaler.update() 239 | 240 | loss.update(running_loss.item(), bsz) 241 | batch_time.update(time.time() - t1) 242 | t1 = time.time() 243 | eta = batch_time.avg * (len(train_loader) - idx) 244 | 245 | if (idx + 1) % opts.print_freq == 0: 246 | print(f"Train: [{epoch}][{idx + 1}/{len(train_loader)}]:\t" 247 | f"BT {batch_time.avg:.3f}\t" 248 | f"ETA {datetime.timedelta(seconds=eta)}\t" 249 | f"loss {loss.avg:.3f}\t") 250 | 251 | return loss.avg, batch_time.avg, data_time.avg 252 | 253 | if __name__ == '__main__': 254 | opts = parse_arguments() 255 | 256 | set_seed(opts.trial) 257 | 258 | train_loader, train_loader_score, test_loader_int, test_loader_ext = load_data(opts) 259 | model, infonce = load_model(opts) 260 | optimizer = load_optimizer(model, opts) 261 | 262 | model_name = opts.model 263 | if opts.warm: 264 | model_name = f"{model_name}_warm" 265 | if opts.amp: 266 | model_name = f"{model_name}_amp" 267 | 268 | method_name = opts.method 269 | if opts.method == 'threshold': 270 | method_name = f"{method_name}_reduction_{opts.delta_reduction}" 271 | 272 | optimizer_name = opts.optimizer 273 | if opts.clip_grad: 274 | optimizer_name = f"{optimizer_name}_clipgrad" 275 | 276 | kernel_name = opts.kernel 277 | if opts.kernel == "gaussian" or opts.kernel == 'rbf': 278 | kernel_name = f"{kernel_name}_sigma{opts.sigma}" 279 | elif opts.kernel == 'cauchy': 280 | kernel_name = f"{kernel_name}_gamma{opts.sigma}" 281 | 282 | run_name = (f"{model_name}_{method_name}_" 283 | f"{optimizer_name}_" 284 | f"tf{opts.tf}_" 285 | f"lr{opts.lr}_{opts.lr_decay}_step{opts.lr_decay_step}_rate{opts.lr_decay_rate}_" 286 | f"temp{opts.temp}_" 287 | f"wd{opts.weight_decay}_" 288 | 
f"bsz{opts.batch_size}_views{opts.n_views}_" 289 | f"trainall_{opts.train_all}_" 290 | f"kernel_{kernel_name}_" 291 | f"f{opts.alpha}_lambd{opts.lambd}_" 292 | f"trial{opts.trial}") 293 | tb_dir = os.path.join(opts.save_dir, "tensorboard", run_name) 294 | save_dir = os.path.join(opts.save_dir, f"openbhb_models", run_name) 295 | ensure_dir(tb_dir) 296 | ensure_dir(save_dir) 297 | 298 | opts.model_class = model.__class__.__name__ 299 | opts.criterion = infonce.__class__.__name__ 300 | opts.optimizer_class = optimizer.__class__.__name__ 301 | 302 | wandb.init(project="brain-age-prediction", config=opts, name=run_name, sync_tensorboard=True, 303 | settings=wandb.Settings(code_dir="/src"), tags=['to test']) 304 | wandb.run.log_code(root="/src", include_fn=lambda path: path.endswith(".py")) 305 | 306 | print('Config:', opts) 307 | print('Model:', model.__class__.__name__) 308 | print('Criterion:', infonce) 309 | print('Optimizer:', optimizer) 310 | print('Scheduler:', opts.lr_decay) 311 | 312 | writer = torch.utils.tensorboard.writer.SummaryWriter(tb_dir) 313 | if opts.amp: 314 | print("Using AMP") 315 | 316 | start_time = time.time() 317 | best_acc = 0. 
318 | for epoch in range(1, opts.epochs + 1): 319 | adjust_learning_rate(opts, optimizer, epoch) 320 | 321 | t1 = time.time() 322 | loss_train, batch_time, data_time = train(train_loader, model, infonce, optimizer, opts, epoch) 323 | t2 = time.time() 324 | writer.add_scalar("train/loss", loss_train, epoch) 325 | 326 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch) 327 | writer.add_scalar("BT", batch_time, epoch) 328 | writer.add_scalar("DT", data_time, epoch) 329 | print(f"epoch {epoch}, total time {t2-start_time:.2f}, epoch time {t2-t1:.3f} loss {loss_train:.4f}") 330 | 331 | if epoch % opts.save_freq == 0: 332 | # save_file = os.path.join(save_dir, f"ckpt_epoch_{epoch}.pth") 333 | # save_model(model, optimizer, opts, epoch, save_file) 334 | 335 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader_score, test_loader_int, test_loader_ext, opts) 336 | writer.add_scalar("train/mae", mae_train, epoch) 337 | writer.add_scalar("test/mae_int", mae_int, epoch) 338 | writer.add_scalar("test/mae_ext", mae_ext, epoch) 339 | print("Age MAE:", mae_train, mae_int, mae_ext) 340 | 341 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader_score, test_loader_int, test_loader_ext, opts) 342 | writer.add_scalar("train/site_ba", ba_train, epoch) 343 | writer.add_scalar("test/ba_int", ba_int, epoch) 344 | writer.add_scalar("test/ba_ext", ba_ext, epoch) 345 | print("Site BA:", ba_train, ba_int, ba_ext) 346 | 347 | challenge_metric = ba_int**0.3 * mae_ext 348 | writer.add_scalar("test/score", challenge_metric, epoch) 349 | print("Challenge score", challenge_metric) 350 | 351 | save_file = os.path.join(save_dir, f"weights.pth") 352 | save_model(model, optimizer, opts, epoch, save_file) 353 | 354 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader_score, test_loader_int, test_loader_ext, opts) 355 | writer.add_scalar("train/mae", mae_train, epoch) 356 | writer.add_scalar("test/mae_int", mae_int, epoch) 357 | 
writer.add_scalar("test/mae_ext", mae_ext, epoch) 358 | print("Age MAE:", mae_train, mae_int, mae_ext) 359 | 360 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader_score, test_loader_int, test_loader_ext, opts) 361 | writer.add_scalar("train/site_ba", ba_train, epoch) 362 | writer.add_scalar("test/ba_int", ba_int, epoch) 363 | writer.add_scalar("test/ba_ext", ba_ext, epoch) 364 | print("Site BA:", ba_train, ba_int, ba_ext) 365 | 366 | challenge_metric = ba_int**0.3 * mae_ext 367 | writer.add_scalar("test/score", challenge_metric, epoch) 368 | print("Challenge score", challenge_metric) -------------------------------------------------------------------------------- /src/main_mse.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import argparse 9 | import models 10 | import losses 11 | import time 12 | import wandb 13 | import torch.utils.tensorboard 14 | 15 | from torchvision import transforms 16 | from util import AverageMeter, MAE, ensure_dir, set_seed, arg2bool, save_model 17 | from util import warmup_learning_rate, adjust_learning_rate 18 | from util import compute_age_mae, compute_site_ba 19 | from data import FeatureExtractor, OpenBHB, bin_age 20 | from data.transforms import Crop, Pad, Cutout 21 | 22 | def parse_arguments(): 23 | parser = argparse.ArgumentParser(description="Augmentation for multiview", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 25 | 26 | parser.add_argument('--device', type=str, help='torch device', default='cuda') 27 | parser.add_argument('--print_freq', type=int, help='print frequency', default=10) 28 | parser.add_argument('--trial', type=int, help='random seed / trial id', default=0) 29 | parser.add_argument('--save_dir', type=str, help='output dir', default='output') 30 | parser.add_argument('--save_freq', type=int, 
help='save frequency', default=50) 31 | 32 | parser.add_argument('--data_dir', type=str, help='path of data dir', default='/data') 33 | parser.add_argument('--batch_size', type=int, help='batch size', default=256) 34 | 35 | parser.add_argument('--epochs', type=int, help='number of epochs', default=200) 36 | parser.add_argument('--lr', type=float, help='learning rate', default=0.1) 37 | parser.add_argument('--lr_decay', type=str, help='type of decay', choices=['cosine', 'step'], default='cosine') 38 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate (for step)') 39 | parser.add_argument('--lr_decay_epochs', type=str, help='steps of lr decay (list)', default="700,800,900") 40 | parser.add_argument('--lr_decay_step', type=int, help='interval in epochs between lr decays (overrides lr_decay_epochs)', default=None) 41 | 42 | parser.add_argument('--warm', type=arg2bool, help='warmup lr', default=False) 43 | parser.add_argument('--optimizer', type=str, help="optimizer (adam or sgd)", choices=["adam", "sgd"], default="sgd") 44 | parser.add_argument('--momentum', type=float, help='momentum', default=0.9) 45 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=1e-4) 46 | 47 | parser.add_argument('--model', type=str, help='model architecture', default='resnet18') 48 | 49 | parser.add_argument('--method', type=str, help='loss function', choices=['mae', 'mse'], default='mae') 50 | 51 | 52 | parser.add_argument('--train_all', type=arg2bool, help='train on the whole dataset, including validation (int+ext)', default=False) 53 | parser.add_argument('--tf', type=str, help='data augmentation', choices=['none', 'crop', 'cutout', 'all'], default='none') 54 | 55 | parser.add_argument('--amp', action='store_true', help='use amp') 56 | 57 | opts = parser.parse_args() 58 | 59 | if opts.batch_size > 256: 60 | print("Forcing warm") 61 | opts.warm = True 62 | 63 | if opts.lr_decay_step is not None: 64 | opts.lr_decay_epochs = 
list(range(opts.lr_decay_step, opts.epochs, opts.lr_decay_step)) 65 | print(f"Computed decay epochs based on step ({opts.lr_decay_step}):", opts.lr_decay_epochs) 66 | else: 67 | iterations = opts.lr_decay_epochs.split(',') 68 | opts.lr_decay_epochs = list([]) 69 | for it in iterations: 70 | opts.lr_decay_epochs.append(int(it)) 71 | 72 | if opts.warm: 73 | opts.warmup_from = 0.01 74 | opts.warm_epochs = 10 75 | if opts.lr_decay == 'cosine': 76 | eta_min = opts.lr * (opts.lr_decay_rate ** 3) 77 | opts.warmup_to = eta_min + (opts.lr - eta_min) * ( 78 | 1 + math.cos(math.pi * opts.warm_epochs / opts.epochs)) / 2 79 | else: 80 | opts.milestones = list(opts.lr_decay_epochs) # already parsed into a list of ints above 81 | opts.warmup_to = opts.lr 82 | 83 | opts.fairkl_kernel = False # this script defines no --kernel option 84 | return opts 85 | 86 | def get_transforms(opts): 87 | selector = FeatureExtractor("vbm") 88 | 89 | if opts.tf == 'none': 90 | aug = transforms.Lambda(lambda x: x) 91 | 92 | elif opts.tf == 'crop': 93 | aug = transforms.Compose([ 94 | Crop((1, 121, 128, 121), type="random"), 95 | Pad((1, 128, 128, 128)) 96 | ]) 97 | 98 | elif opts.tf == 'cutout': 99 | aug = Cutout(patch_size=[1, 32, 32, 32], probability=0.5) 100 | 101 | elif opts.tf == 'all': 102 | aug = transforms.Compose([ 103 | Cutout(patch_size=[1, 32, 32, 32], probability=0.5), 104 | Crop((1, 121, 128, 121), type="random"), 105 | Pad((1, 128, 128, 128)) 106 | ]) 107 | 108 | T_pre = transforms.Lambda(lambda x: selector.transform(x)) 109 | T_train = transforms.Compose([ 110 | T_pre, 111 | aug, 112 | transforms.Lambda(lambda x: torch.from_numpy(x).float()), 113 | transforms.Normalize(mean=0.0, std=1.0) 114 | ]) 115 | 116 | T_test = transforms.Compose([ 117 | T_pre, 118 | transforms.Lambda(lambda x: torch.from_numpy(x).float()), 119 | transforms.Normalize(mean=0.0, std=1.0) 120 | ]) 121 | 122 | return T_train, T_test 123 | 124 | 125 | def load_data(opts): 126 | T_train, T_test = 
get_transforms(opts) 127 | 128 | train_dataset = 
OpenBHB(opts.data_dir, train=True, internal=True, transform=T_train, 129 | load_feats=getattr(opts, 'biased_features', None)) # no --biased_features option in this script 130 | 131 | if opts.train_all: 132 | valint_feats, valext_feats = None, None 133 | if getattr(opts, 'biased_features', None) is not None: 134 | valint_feats = opts.biased_features.replace('.pth', '_valint.pth') 135 | valext_feats = opts.biased_features.replace('.pth', '_valext.pth') 136 | 137 | valint = OpenBHB(opts.data_dir, train=False, internal=True, transform=T_train, 138 | load_feats=valint_feats) 139 | valext = OpenBHB(opts.data_dir, train=False, internal=False, transform=T_train, 140 | load_feats=valext_feats) 141 | train_dataset = torch.utils.data.ConcatDataset([train_dataset, valint, valext]) 142 | print("Total dataset length:", len(train_dataset)) 143 | 144 | 145 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, 146 | num_workers=8, persistent_workers=True) 147 | 148 | test_internal = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=True, transform=T_test), 149 | batch_size=opts.batch_size, shuffle=False, num_workers=8, 150 | persistent_workers=True) 151 | test_external = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=False, transform=T_test), 152 | batch_size=opts.batch_size, shuffle=False, num_workers=8, 153 | persistent_workers=True) 154 | 155 | return train_loader, test_internal, test_external 156 | 157 | def load_model(opts): 158 | if 'resnet' in opts.model: 159 | model = models.SupRegResNet(opts.model) 160 | 161 | elif 'alexnet' in opts.model: 162 | model = models.SupRegAlexNet() 163 | 164 | elif 'densenet121' in opts.model: 165 | model = models.SupRegDenseNet() 166 | 167 | else: 168 | raise ValueError("Unknown model", opts.model) 169 | 170 | if opts.device == 'cuda' and torch.cuda.device_count() > 1: 171 | print(f"Using multiple CUDA devices ({torch.cuda.device_count()})") 172 | model = 
torch.nn.DataParallel(model) 173 | model = model.to(opts.device) 174 | 175
| methods = { 176 | 'mae': F.l1_loss, 177 | 'mse': F.mse_loss 178 | } 179 | regression_loss = methods[opts.method] 180 | 181 | return model, regression_loss 182 | 183 | def load_optimizer(model, opts): 184 | if opts.optimizer == "sgd": 185 | optimizer = torch.optim.SGD(model.parameters(), lr=opts.lr, 186 | momentum=opts.momentum, 187 | weight_decay=opts.weight_decay) 188 | else: 189 | optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay) 190 | 191 | return optimizer 192 | 193 | def train(train_loader, model, criterion, optimizer, opts, epoch): 194 | loss = AverageMeter() 195 | mae = MAE() 196 | 197 | batch_time = AverageMeter() 198 | data_time = AverageMeter() 199 | 200 | scaler = torch.cuda.amp.GradScaler() if opts.amp else None 201 | model.train() 202 | 203 | t1 = time.time() 204 | for idx, (images, labels, _) in enumerate(train_loader): 205 | data_time.update(time.time() - t1) 206 | 207 | images, labels = images.to(opts.device), labels.to(opts.device) 208 | bsz = labels.shape[0] 209 | 210 | warmup_learning_rate(opts, epoch, idx, len(train_loader), optimizer) 211 | 212 | with torch.cuda.amp.autocast(scaler is not None): 213 | output, _ = model(images) # encoder features are not needed by the regression loss 214 | output = output.view(-1) 215 | running_loss = criterion(output, labels.float()) # F.l1_loss / F.mse_loss take (input, target) 216 | 217 | optimizer.zero_grad() 218 | if scaler is None: 219 | running_loss.backward() 220 | optimizer.step() 221 | else: 222 | scaler.scale(running_loss).backward() 223 | scaler.step(optimizer) 224 | scaler.update() 225 | 226 | loss.update(running_loss.item(), bsz) 227 | mae.update(output, labels) 228 | 229 | batch_time.update(time.time() - t1) 230 | eta = batch_time.avg * (len(train_loader) - idx) 231 | 232 | if (idx + 1) % opts.print_freq == 0: 233 | print(f"Train: [{epoch}][{idx + 1}/{len(train_loader)}]:\t" 234 | f"BT {batch_time.avg:.3f}\t" 235 | f"ETA 
{datetime.timedelta(seconds=eta)}\t" 236 | f"loss {loss.avg:.3f}\t" 237 | f"MAE {mae.avg:.3f}") 238 | 239 | t1 = 
time.time() 240 | 241 | return loss.avg, mae.avg, batch_time.avg, data_time.avg 242 | 243 | @torch.no_grad() 244 | def test(test_loader, model, criterion, opts, epoch): 245 | loss = AverageMeter() 246 | mae = MAE() 247 | batch_time = AverageMeter() 248 | 249 | model.eval() 250 | t1 = time.time() 251 | for idx, (images, labels, _) in enumerate(test_loader): 252 | images, labels = images.to(opts.device), labels.to(opts.device) 253 | bsz = labels.shape[0] 254 | 255 | output, _ = model(images) # encoder features are not needed by the regression loss 256 | output = output.view(-1) 257 | running_loss = criterion(output, labels.float()) # F.l1_loss / F.mse_loss take (input, target) 258 | 259 | loss.update(running_loss.item(), bsz) 260 | mae.update(output, labels) 261 | 262 | batch_time.update(time.time() - t1) 263 | eta = batch_time.avg * (len(test_loader) - idx) 264 | 265 | if (idx + 1) % opts.print_freq == 0: 266 | print(f"Test: [{epoch}][{idx + 1}/{len(test_loader)}]:\t" 267 | f"BT {batch_time.avg:.3f}\t" 268 | f"ETA {datetime.timedelta(seconds=eta)}\t" 269 | f"loss {loss.avg:.3f}\t" 270 | f"MAE {mae.avg:.3f}") 271 | 272 | t1 = time.time() 273 | 274 | return loss.avg, mae.avg 275 | 276 | if __name__ == '__main__': 277 | opts = parse_arguments() 278 | 279 | set_seed(opts.trial) 280 | 281 | train_loader, test_loader_int, test_loader_ext = load_data(opts) 282 | model, criterion = load_model(opts) 283 | optimizer = load_optimizer(model, opts) 284 | 285 | model_name = opts.model 286 | if opts.warm: 287 | model_name = f"{model_name}_warm" 288 | 289 | 290 | run_name = (f"{model_name}_{opts.method}_" 291 | f"{opts.optimizer}_" 292 | f"tf_{opts.tf}_" 293 | f"lr{opts.lr}_{opts.lr_decay}_step{opts.lr_decay_step}_rate{opts.lr_decay_rate}_" 294 | f"wd{opts.weight_decay}_" 295 | f"trainall_{opts.train_all}_" 296 | f"bsz{opts.batch_size}_" 297 | f"trial{opts.trial}") 298 | tb_dir = os.path.join(opts.save_dir, "tensorboard", run_name) 299 | save_dir = 
os.path.join(opts.save_dir, f"openbhb_models", run_name) 300 | ensure_dir(tb_dir) 301 | ensure_dir(save_dir) 302 |
303 | opts.model_class = model.__class__.__name__ 304 | opts.criterion = opts.method 305 | opts.optimizer_class = optimizer.__class__.__name__ 306 | 307 | wandb.init(project="brain-age-prediction", config=opts, name=run_name, sync_tensorboard=True, tags=['to test']) 308 | print('Config:', opts) 309 | print('Model:', model.__class__.__name__) 310 | print('Criterion:', opts.criterion) 311 | print('Optimizer:', optimizer) 312 | print('Scheduler:', opts.lr_decay) 313 | 314 | writer = torch.utils.tensorboard.writer.SummaryWriter(tb_dir) 315 | if opts.amp: 316 | print("Using AMP") 317 | 318 | start_time = time.time() 319 | best_acc = 0. 320 | for epoch in range(1, opts.epochs + 1): 321 | adjust_learning_rate(opts, optimizer, epoch) 322 | 323 | t1 = time.time() 324 | loss_train, mae_train, batch_time, data_time = train(train_loader, model, criterion, optimizer, opts, epoch) 325 | t2 = time.time() 326 | writer.add_scalar("train/loss", loss_train, epoch) 327 | # writer.add_scalar("train/mae", mae_train, epoch) 328 | 329 | loss_test, mae_int = test(test_loader_int, model, criterion, opts, epoch) 330 | writer.add_scalar("test/loss_int", loss_test, epoch) 331 | # writer.add_scalar("test/mae_int", mae_int, epoch) 332 | 333 | loss_test, mae_ext = test(test_loader_ext, model, criterion, opts, epoch) 334 | writer.add_scalar("test/loss_ext", loss_test, epoch) 335 | # writer.add_scalar("test/mae_ext", mae_ext, epoch) 336 | 337 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch) 338 | writer.add_scalar("BT", batch_time, epoch) 339 | writer.add_scalar("DT", data_time, epoch) 340 | print(f"epoch {epoch}, total time {t2-start_time:.2f}, epoch time {t2-t1:.3f} loss {loss_test:.4f} " 341 | f"mae_int {mae_int:.3f} mae_ext {mae_ext:.3f}") 342 | 343 | if epoch % opts.save_freq == 0: 344 | # save_file = os.path.join(save_dir, f"ckpt_epoch_{epoch}.pth") 345 | # save_model(model, optimizer, opts, epoch, save_file) 346 | mae_train, mae_int, mae_ext = compute_age_mae(model, 
train_loader, test_loader_int, test_loader_ext, opts) 347 | 348 | writer.add_scalar("train/mae", mae_train, epoch) 349 | writer.add_scalar("test/mae_int", mae_int, epoch) 350 | writer.add_scalar("test/mae_ext", mae_ext, epoch) 351 | print("Age MAE:", mae_train, mae_int, mae_ext) 352 | 353 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader, test_loader_int, test_loader_ext, opts) 354 | writer.add_scalar("train/site_ba", ba_train, epoch) 355 | writer.add_scalar("test/ba_int", ba_int, epoch) 356 | writer.add_scalar("test/ba_ext", ba_ext, epoch) 357 | print("Site BA:", ba_train, ba_int, ba_ext) 358 | 359 | challenge_metric = ba_int**0.3 * mae_ext 360 | writer.add_scalar("test/score", challenge_metric, epoch) 361 | print("Challenge score", challenge_metric) 362 | 363 | save_file = os.path.join(save_dir, f"weights.pth") 364 | save_model(model, optimizer, opts, epoch, save_file) 365 | 366 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader, test_loader_int, test_loader_ext, opts) 367 | writer.add_scalar("train/mae", mae_train, epoch) 368 | writer.add_scalar("test/mae_int", mae_int, epoch) 369 | writer.add_scalar("test/mae_ext", mae_ext, epoch) 370 | print("Age MAE:", mae_train, mae_int, mae_ext) 371 | 372 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader, test_loader_int, test_loader_ext, opts) 373 | writer.add_scalar("train/site_ba", ba_train, epoch) 374 | writer.add_scalar("test/ba_int", ba_int, epoch) 375 | writer.add_scalar("test/ba_ext", ba_ext, epoch) 376 | print("Site BA:", ba_train, ba_int, ba_ext) 377 | 378 | challenge_metric = ba_int**0.3 * mae_ext 379 | writer.add_scalar("test/score", challenge_metric, epoch) 380 | print("Challenge score", challenge_metric) -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet3d import SupConResNet, SupRegResNet, SupCEResNet, 
LinearRegressor 2 | from .alexnet3d import SupConAlexNet, SupRegAlexNet 3 | from .densenet3d import SupConDenseNet, SupRegDenseNet 4 | from .estimators import AgeEstimator, SiteEstimator -------------------------------------------------------------------------------- /src/models/alexnet3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model implemented in https://doi.org/10.5281/zenodo.4309677 by Abrol et al., 2021 3 | """ 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | class AlexNet3D(nn.Module): 9 | def __init__(self): 10 | """ 11 | 3D AlexNet encoder (no classification head): forward() returns 12 | a 128-d feature vector per input volume (global max pooling). 13 | """ 14 | super().__init__() 15 | self.features = nn.Sequential( 16 | nn.Conv3d(1, 64, kernel_size=5, stride=2, padding=0), 17 | nn.BatchNorm3d(64), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool3d(kernel_size=3, stride=3), 20 | 21 | nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=0), 22 | nn.BatchNorm3d(128), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool3d(kernel_size=3, stride=3), 25 | 26 | nn.Conv3d(128, 192, kernel_size=3, padding=1), 27 | nn.BatchNorm3d(192), 28 | nn.ReLU(inplace=True), 29 | 30 | nn.Conv3d(192, 192, kernel_size=3, padding=1), 31 | nn.BatchNorm3d(192), 32 | nn.ReLU(inplace=True), 33 | 34 | nn.Conv3d(192, 128, kernel_size=3, padding=1), 35 | nn.BatchNorm3d(128), 36 | nn.ReLU(inplace=True), 37 | nn.AdaptiveMaxPool3d(1), 38 | ) 39 | 40 | 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv3d): 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 44 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 45 | elif isinstance(m, nn.BatchNorm3d): 46 | m.weight.data.fill_(1) 47 | m.bias.data.zero_() 48 | 49 | def forward(self, x): 50 | xp = self.features(x) 51 | x = xp.view(xp.size(0), -1) 52 | return x 53 | 54 | class SupConAlexNet(nn.Module): 55 | """backbone + projection head""" 56 | def __init__(self, head='mlp', feat_dim=128): 57 | super().__init__() 58 | self.encoder = AlexNet3D() 59 | dim_in = 128 60 | 61 | if head == 'linear': 62 | self.head = nn.Linear(dim_in, feat_dim) 63 | elif head == 'mlp': 64 | self.head = nn.Sequential( 65 | nn.Linear(dim_in, dim_in), 66 | nn.ReLU(inplace=True), 67 | nn.Linear(dim_in, feat_dim) 68 | ) 69 | 70 | else: 71 | raise NotImplementedError( 72 | 'head not supported: {}'.format(head)) 73 | 74 | def forward(self, x): 75 | feat = self.encoder(x) 76 | feat = F.normalize(self.head(feat), dim=1) 77 | return feat 78 | 79 | def features(self, x): 80 | return self.forward(x) 81 | 82 | 83 | class SupRegAlexNet(nn.Module): 84 | """encoder + regressor""" 85 | def __init__(self,): 86 | super().__init__() 87 | self.encoder = AlexNet3D() 88 | self.fc = nn.Linear(128, 1) 89 | 90 | def forward(self, x): 91 | feats = self.features(x) 92 | return self.fc(feats), feats 93 | 94 | def features(self, x): 95 | return self.encoder(x) -------------------------------------------------------------------------------- /src/models/densenet3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.checkpoint as cp 5 | from collections import OrderedDict 6 | 7 | class DenseNet(nn.Module): 8 | """3D-Densenet-BC model class, based on 9 | `"Densely Connected Convolutional Networks" `_ 10 | Args: 11 | growth_rate (int) - how many filters to add each layer (`k` in paper) 12 | block_config (list of 4 ints) - how many layers in each pooling block 13 | num_init_features (int) - the number of filters to learn in the first convolution 
layer 14 | in_channels (int) - number of input channels (1 for a single MRI volume) 15 | bn_size (int) - multiplicative factor for the number of bottleneck features 16 | (i.e. bn_size * k features in the bottleneck layer) 17 | forward() returns the globally average-pooled feature vector of size `num_features` 18 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 19 | but slower. Default: *False*. See `"paper" `_ 20 | """ 21 | 22 | def __init__(self, growth_rate=32, block_config=(3, 12, 24, 16), 23 | num_init_features=64, 24 | bn_size=4, in_channels=1, 25 | memory_efficient=False): 26 | super(DenseNet, self).__init__() 27 | # First convolution 28 | self.features = nn.Sequential(OrderedDict([ 29 | ('conv0', nn.Conv3d(in_channels, num_init_features, 30 | kernel_size=7, stride=2, padding=3, bias=False)), 31 | ('norm0', nn.BatchNorm3d(num_init_features)), 32 | ('relu0', nn.ReLU(inplace=True)), 33 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)), 34 | ])) 35 | 36 | # Each denseblock 37 | num_features = num_init_features 38 | for i, num_layers in enumerate(block_config): 39 | block = _DenseBlock( 40 | num_layers=num_layers, 41 | num_input_features=num_features, 42 | bn_size=bn_size, 43 | growth_rate=growth_rate, 44 | memory_efficient=memory_efficient 45 | ) 46 | self.features.add_module('denseblock%d' % (i + 1), block) 47 | num_features = num_features + num_layers * growth_rate 48 | if i != len(block_config) - 1: 49 | trans = _Transition(num_input_features=num_features, 50 | num_output_features=num_features // 2) 51 | self.features.add_module('transition%d' % (i + 1), trans) 52 | num_features = num_features // 2 53 | 54 | self.num_features = num_features 55 | 56 | 57 | # Official init from torch repo.
58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv3d): 60 | nn.init.kaiming_normal_(m.weight) 61 | elif isinstance(m, nn.BatchNorm3d): 62 | nn.init.constant_(m.weight, 1) 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, nn.Linear): 65 | nn.init.constant_(m.bias, 0) 66 | 67 | def forward(self, x): 68 | features = self.features(x) 69 | out = F.adaptive_avg_pool3d(features, 1) 70 | out = torch.flatten(out, 1) 71 | return out.squeeze(dim=1) 72 | 73 | 74 | def _bn_function_factory(norm, relu, conv): 75 | def bn_function(*inputs): 76 | concated_features = torch.cat(inputs, 1) 77 | bottleneck_output = conv(relu(norm(concated_features))) 78 | return bottleneck_output 79 | 80 | return bn_function 81 | 82 | 83 | class _DenseLayer(nn.Sequential): 84 | def __init__(self, num_input_features, growth_rate, bn_size, memory_efficient=False): 85 | super(_DenseLayer, self).__init__() 86 | self.add_module('norm1', nn.BatchNorm3d(num_input_features)), 87 | self.add_module('relu1', nn.ReLU(inplace=True)), 88 | self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * 89 | growth_rate, kernel_size=1, stride=1, 90 | bias=False)), 91 | self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)), 92 | self.add_module('relu2', nn.ReLU(inplace=True)), 93 | self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, 94 | kernel_size=3, stride=1, padding=1, 95 | bias=False)), 96 | self.memory_efficient = memory_efficient 97 | 98 | def forward(self, *prev_features): 99 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 100 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 101 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 102 | else: 103 | bottleneck_output = bn_function(*prev_features) 104 | 105 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 106 | 107 | return new_features 108 | 109 | 110 | class _DenseBlock(nn.Module): 111 | def __init__(self, 
num_layers, num_input_features, bn_size, growth_rate, memory_efficient=False): 112 | super(_DenseBlock, self).__init__() 113 | for i in range(num_layers): 114 | layer = _DenseLayer( 115 | num_input_features + i * growth_rate, 116 | growth_rate=growth_rate, 117 | bn_size=bn_size, 118 | memory_efficient=memory_efficient, 119 | ) 120 | self.add_module('denselayer%d' % (i + 1), layer) 121 | 122 | def forward(self, init_features): 123 | features = [init_features] 124 | for name, layer in self.named_children(): 125 | new_features = layer(*features) 126 | features.append(new_features) 127 | return torch.cat(features, 1) 128 | 129 | 130 | class _Transition(nn.Sequential): 131 | def __init__(self, num_input_features, num_output_features): 132 | super(_Transition, self).__init__() 133 | self.add_module('norm', nn.BatchNorm3d(num_input_features)) 134 | self.add_module('relu', nn.ReLU(inplace=True)) 135 | self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, 136 | kernel_size=1, stride=1, bias=False)) 137 | self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2)) 138 | 139 | 140 | def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs): 141 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 142 | return model 143 | 144 | 145 | def densenet121(**kwargs): 146 | r"""3D DenseNet-121 model adapted from 147 | `"Densely Connected Convolutional Networks" `_ 148 | 149 | Args: 150 | **kwargs - forwarded to `DenseNet`; no pretrained weights are provided for this 3D model 151 | in_channels (int) - number of input channels. Default: 1. 152 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 153 | but slower. Default: *False*.
See `"paper" `_ 154 | """ 155 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs) 156 | 157 | class SupConDenseNet(nn.Module): 158 | """backbone + projection head""" 159 | def __init__(self, head='mlp', feat_dim=128): 160 | super().__init__() 161 | self.encoder = densenet121() 162 | dim_in = self.encoder.num_features 163 | 164 | if head == 'linear': 165 | self.head = nn.Linear(dim_in, feat_dim) 166 | elif head == 'mlp': 167 | self.head = nn.Sequential( 168 | nn.Linear(dim_in, dim_in), 169 | nn.ReLU(inplace=True), 170 | nn.Linear(dim_in, feat_dim) 171 | ) 172 | 173 | else: 174 | raise NotImplementedError( 175 | 'head not supported: {}'.format(head)) 176 | 177 | def forward(self, x): 178 | feat = self.encoder(x) 179 | feat = F.normalize(self.head(feat), dim=1) 180 | return feat 181 | 182 | def features(self, x): 183 | return self.forward(x) 184 | 185 | 186 | class SupRegDenseNet(nn.Module): 187 | """encoder + regressor""" 188 | def __init__(self,): 189 | super().__init__() 190 | self.encoder = densenet121() 191 | self.fc = nn.Linear(self.encoder.num_features, 1) 192 | 193 | def forward(self, x): 194 | feats = self.features(x) 195 | return self.fc(feats), feats 196 | 197 | def features(self, x): 198 | return self.encoder(x) -------------------------------------------------------------------------------- /src/models/estimators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import multiprocessing 3 | from sklearn.base import BaseEstimator 4 | from sklearn.linear_model import LogisticRegression, Ridge 5 | from sklearn.model_selection import GridSearchCV 6 | from sklearn.metrics import mean_absolute_error 7 | 8 | class AgeEstimator(BaseEstimator): 9 | """ Define the age estimator on latent space network features. 
10 | """ 11 | def __init__(self): 12 | n_jobs = multiprocessing.cpu_count() 13 | self.age_estimator = GridSearchCV( 14 | Ridge(), param_grid={"alpha": 10.**np.arange(-2, 3)}, cv=5, 15 | scoring="r2", n_jobs=n_jobs) 16 | 17 | def fit(self, X, y): 18 | self.age_estimator.fit(X, y) 19 | return self.score(X, y) 20 | 21 | def predict(self, X): 22 | y_pred = self.age_estimator.predict(X) 23 | return y_pred 24 | 25 | def score(self, X, y): 26 | y_pred = self.age_estimator.predict(X) 27 | return mean_absolute_error(y, y_pred) 28 | 29 | class SiteEstimator(BaseEstimator): 30 | """ Define the site estimator on latent space network features. 31 | """ 32 | def __init__(self): 33 | n_jobs = multiprocessing.cpu_count() 34 | self.site_estimator = GridSearchCV( 35 | LogisticRegression(solver="saga", max_iter=150), cv=5, 36 | param_grid={"C": 10.**np.arange(-2, 3)}, 37 | scoring="balanced_accuracy", n_jobs=n_jobs) 38 | 39 | def fit(self, X, y): 40 | self.site_estimator.fit(X, y) 41 | return self.site_estimator.score(X, y) 42 | 43 | def predict(self, X): 44 | return self.site_estimator.predict(X) 45 | 46 | def score(self, X, y): 47 | return self.site_estimator.score(X, y) -------------------------------------------------------------------------------- /src/models/resnet3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=dilation, groups=groups, bias=False, dilation=dilation) 10 | 11 | def conv1x1(in_planes, out_planes, stride=1): 12 | """1x1 convolution""" 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, 
downsample=None, groups=1, 19 | base_width=64, dilation=1, norm_layer=None): 20 | super(BasicBlock, self).__init__() 21 | if norm_layer is None: 22 | norm_layer = nn.BatchNorm3d 23 | if groups != 1 or base_width != 64: 24 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 25 | if dilation > 1: 26 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 27 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = norm_layer(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = norm_layer(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | identity = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 58 | base_width=64, dilation=1, norm_layer=None): 59 | super(Bottleneck, self).__init__() 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm3d 62 | width = int(planes * (base_width / 64.)) * groups 63 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 64 | self.conv1 = conv1x1(inplanes, width) 65 | self.bn1 = norm_layer(width) 66 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 67 | self.bn2 = norm_layer(width) 68 | self.conv3 = conv1x1(width, planes * self.expansion) 69 | self.bn3 = norm_layer(planes * self.expansion) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | identity = x 76 | 77 | out = self.conv1(x) 78 | out = 
self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | identity = self.downsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | class ResNet(nn.Module): 97 | """ 98 | Standard 3D-ResNet backbone with a large initial 7x7x7 kernel (set by initial_kernel_size). 99 | It is used as an encoder: forward() returns a latent vector of size 512 * block.expansion 100 | (512 for BasicBlock, 2048 for Bottleneck), independent of the input size. 101 | Task-specific heads (projection, regression, classification) are added by the wrapper classes. 102 | """ 103 | def __init__(self, block, layers, in_channels=1, 104 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, 105 | norm_layer=None, initial_kernel_size=7): 106 | super(ResNet, self).__init__() 107 | 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm3d 110 | self._norm_layer = norm_layer 111 | 112 | self.name = "resnet" 113 | self.inputs = None 114 | self.inplanes = 64 115 | self.dilation = 1 116 | 117 | if replace_stride_with_dilation is None: 118 | # each element in the tuple indicates if we should replace 119 | # the 2x2x2 stride with a dilated convolution instead 120 | replace_stride_with_dilation = [False, False, False] 121 | if len(replace_stride_with_dilation) != 3: 122 | raise ValueError("replace_stride_with_dilation should be None " 123 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 124 | self.groups = groups 125 | self.base_width = width_per_group 126 | initial_stride = 2 if initial_kernel_size==7 else 1 127 | padding = (initial_kernel_size-initial_stride+1)//2 128 | self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=initial_kernel_size, stride=initial_stride, padding=padding, bias=False) 129 | self.bn1 = norm_layer(self.inplanes) 130 | self.relu = nn.ReLU(inplace=True) 131 | self.maxpool =
nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 132 | 133 | channels = [64, 128, 256, 512] 134 | 135 | self.layer1 = self._make_layer(block, channels[0], layers[0]) 136 | self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 137 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 138 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 139 | self.avgpool = nn.AdaptiveAvgPool3d(1) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv3d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 144 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | elif isinstance(m, nn.Linear): 148 | nn.init.normal_(m.weight, 0, 0.01) 149 | if m.bias is not None: 150 | nn.init.constant_(m.bias, 0) 151 | 152 | # Zero-initialize the last BN in each residual branch, 153 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
154 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 155 | if zero_init_residual: 156 | for m in self.modules(): 157 | if isinstance(m, Bottleneck): 158 | nn.init.constant_(m.bn3.weight, 0) 159 | elif isinstance(m, BasicBlock): 160 | nn.init.constant_(m.bn2.weight, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 163 | norm_layer = self._norm_layer 164 | downsample = None 165 | previous_dilation = self.dilation 166 | if dilate: 167 | self.dilation *= stride 168 | stride = 1 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | conv1x1(self.inplanes, planes * block.expansion, stride), 172 | norm_layer(planes * block.expansion), 173 | ) 174 | 175 | layers = [] 176 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups, 177 | base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer)) 178 | self.inplanes = planes * block.expansion 179 | for _ in range(1, blocks): 180 | layers.append(block(self.inplanes, planes, groups=self.groups, 181 | base_width=self.base_width, dilation=self.dilation, 182 | norm_layer=norm_layer)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def forward(self, x): 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | 192 | x1 = self.layer1(x) 193 | x2 = self.layer2(x1) 194 | x3 = self.layer3(x2) 195 | x4 = self.layer4(x3) 196 | 197 | x5 = self.avgpool(x4) 198 | return torch.flatten(x5, 1) 199 | 200 | def resnet18(**kwargs): 201 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 202 | 203 | def resnet34(**kwargs): 204 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 205 | 206 | def resnet50(**kwargs): 207 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 208 | 209 | def resnet101(**kwargs): 210 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 211 | 212 | model_dict = { 213 | 'resnet18': 
[resnet18, 512], 214 | 'resnet34': [resnet34, 512], 215 | 'resnet50': [resnet50, 2048], 216 | 'resnet101': [resnet101, 2048], 217 | } 218 | 219 | class SupConResNet(nn.Module): 220 | """backbone + projection head""" 221 | def __init__(self, name='resnet50', head='mlp', feat_dim=128): 222 | super().__init__() 223 | model_fun, dim_in = model_dict[name] 224 | self.encoder = model_fun() 225 | if head == 'linear': 226 | self.head = nn.Linear(dim_in, feat_dim) 227 | elif head == 'mlp': 228 | self.head = nn.Sequential( 229 | nn.Linear(dim_in, dim_in), 230 | nn.ReLU(inplace=True), 231 | nn.Linear(dim_in, feat_dim) 232 | ) 233 | else: 234 | raise NotImplementedError( 235 | 'head not supported: {}'.format(head)) 236 | 237 | def forward(self, x): 238 | feat = self.encoder(x) 239 | feat = F.normalize(self.head(feat), dim=1) 240 | return feat 241 | 242 | def features(self, x): 243 | return self.forward(x) 244 | 245 | 246 | class SupRegResNet(nn.Module): 247 | """encoder + regressor""" 248 | def __init__(self, name='resnet50'): 249 | super().__init__() 250 | model_fun, dim_in = model_dict[name] 251 | self.encoder = model_fun() 252 | self.fc = nn.Linear(dim_in, 1) 253 | 254 | def forward(self, x): 255 | feats = self.features(x) 256 | return self.fc(feats), feats 257 | 258 | def features(self, x): 259 | return self.encoder(x) 260 | 261 | class SupCEResNet(nn.Module): 262 | """encoder + classifier""" 263 | def __init__(self, n_classes, name='resnet50'): 264 | super().__init__() 265 | model_fun, dim_in = model_dict[name] 266 | self.encoder = model_fun() 267 | self.fc = nn.Linear(dim_in, n_classes) 268 | 269 | def forward(self, x): 270 | return self.fc(self.encoder(x)) 271 | 272 | def features(self, x): 273 | return self.encoder(x) 274 | 275 | 276 | class LinearRegressor(nn.Module): 277 | """Linear regressor""" 278 | def __init__(self, name='resnet50'): 279 | super().__init__() 280 | _, feat_dim = model_dict[name] 281 | self.fc = nn.Linear(feat_dim, 1) 282 | 283 | def forward(self, 
features): 284 | return self.fc(features) -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import numpy as np 5 | import os 6 | import wandb 7 | import torch.nn.functional as F 8 | import models 9 | from pathlib import Path 10 | 11 | 12 | class NViewTransform: 13 | """Create N augmented views of the same image""" 14 | def __init__(self, transform, N): 15 | self.transform = transform 16 | self.N = N 17 | 18 | def __call__(self, x): 19 | return [self.transform(x) for _ in range(self.N)] 20 | 21 | def arg2bool(val): 22 | if isinstance(val, bool): 23 | return val 24 | 25 | elif isinstance(val, str): 26 | if val == "true": 27 | return True 28 | 29 | if val == "false": 30 | return False 31 | 32 | val = int(val) 33 | assert val == 0 or val == 1 34 | return val == 1 35 | 36 | class AverageMeter(object): 37 | """Computes and stores the average and current value""" 38 | def __init__(self): 39 | self.reset() 40 | 41 | def reset(self): 42 | self.val = 0 43 | self.avg = 0 44 | self.sum = 0 45 | self.count = 0 46 | 47 | def update(self, val, n=1): 48 | self.val = val 49 | self.sum += val * n 50 | self.count += n 51 | self.avg = self.sum / self.count 52 | 53 | class MAE(): 54 | def __init__(self): 55 | self.reset() 56 | 57 | def reset(self): 58 | self.outputs = [] 59 | self.targets = [] 60 | self.avg = np.inf 61 | 62 | def update(self, outputs, targets): 63 | self.outputs.append(outputs.detach()) 64 | self.targets.append(targets.detach()) 65 | self.avg = F.l1_loss(torch.cat(self.outputs, 0), torch.cat(self.targets, 0)) 66 | 67 | class Accuracy(): 68 | def __init__(self, topk=(1,)): 69 | self.reset() 70 | self.topk = topk 71 | 72 | def reset(self): 73 | self.outputs = [] 74 | self.targets = [] 75 | self.avg = np.inf 76 | 77 | def update(self, outputs, targets): 78 | 
self.outputs.append(outputs.detach()) 79 | self.targets.append(targets.detach()) 80 | self.avg = accuracy(torch.cat(self.outputs, 0), torch.cat(self.targets, 0), self.topk) 81 | 82 | def ensure_dir(dirname): 83 | dirname = Path(dirname) 84 | if not dirname.is_dir(): 85 | dirname.mkdir(parents=True, exist_ok=True) 86 | 87 | def accuracy(output, target, topk=(1,)): 88 | """Computes the accuracy over the k top predictions for the specified values of k""" 89 | with torch.no_grad(): 90 | maxk = max(topk) 91 | batch_size = target.size(0) 92 | 93 | _, pred = output.topk(maxk, 1, True, True) 94 | pred = pred.t() 95 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 96 | 97 | res = [] 98 | for k in topk: 99 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 100 | res.append(correct_k.mul_(100.0 / batch_size).item()) 101 | return res 102 | 103 | def set_seed(seed): 104 | random.seed(seed) 105 | os.environ["PYTHONHASHSEED"] = str(seed) 106 | np.random.seed(seed) 107 | torch.cuda.manual_seed(seed) 108 | torch.cuda.manual_seed_all(seed) 109 | torch.backends.cudnn.deterministic = False 110 | torch.backends.cudnn.benchmark = True 111 | torch.manual_seed(seed) 112 | 113 | def save_model(model, optimizer, opt, epoch, save_file): 114 | print('==> Saving...') 115 | state_dict = model.state_dict() 116 | if torch.cuda.device_count() > 1: 117 | state_dict = model.module.state_dict() 118 | 119 | state = { 120 | 'opts': opt, 121 | 'model': state_dict, 122 | 'optimizer': optimizer.state_dict(), 123 | 'epoch': epoch, 124 | 'run_id': wandb.run.id 125 | } 126 | torch.save(state, save_file) 127 | del state 128 | 129 | def adjust_learning_rate(args, optimizer, epoch): 130 | lr = args.lr 131 | if args.lr_decay == 'cosine': 132 | eta_min = lr * (args.lr_decay_rate ** 3) 133 | lr = eta_min + (lr - eta_min) * ( 134 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 135 | else: 136 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 137 | if steps > 0: 138 | lr = lr * 
(args.lr_decay_rate ** steps) 139 | 140 | for param_group in optimizer.param_groups: 141 | param_group['lr'] = lr 142 | 143 | 144 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 145 | if args.warm and epoch <= args.warm_epochs: 146 | p = (batch_id + (epoch - 1) * total_batches) / \ 147 | (args.warm_epochs * total_batches) 148 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 149 | 150 | for param_group in optimizer.param_groups: 151 | param_group['lr'] = lr 152 | 153 | @torch.no_grad() 154 | def gather_age_feats(model, dataloader, opts): 155 | features = [] 156 | age_labels = [] 157 | 158 | model.eval() 159 | for idx, (images, labels, _) in enumerate(dataloader): 160 | if isinstance(images, list): 161 | images = images[0] 162 | images = images.to(opts.device) 163 | features.append(model.features(images)) 164 | age_labels.append(labels) 165 | 166 | return torch.cat(features, 0).cpu().numpy(), torch.cat(age_labels, 0).cpu().numpy() 167 | 168 | @torch.no_grad() 169 | def compute_age_mae(model, train_loader, test_int, test_ext, opts): 170 | age_estimator = models.AgeEstimator() 171 | 172 | print("Training age estimator") 173 | train_X, train_y = gather_age_feats(model, train_loader, opts) 174 | mae_train = age_estimator.fit(train_X, train_y) 175 | 176 | print("Computing MAE") 177 | int_X, int_y = gather_age_feats(model, test_int, opts) 178 | ext_X, ext_y = gather_age_feats(model, test_ext, opts) 179 | mae_int = age_estimator.score(int_X, int_y) 180 | mae_ext = age_estimator.score(ext_X, ext_y) 181 | 182 | return mae_train, mae_int, mae_ext 183 | 184 | @torch.no_grad() 185 | def gather_site_feats(model, dataloader, opts): 186 | features = [] 187 | site_labels = [] 188 | 189 | model.eval() 190 | for idx, (images, _, sites) in enumerate(dataloader): 191 | if isinstance(images, list): 192 | images = images[0] 193 | images = images.to(opts.device) 194 | features.append(model.features(images)) 195 | site_labels.append(sites)
196 | 197 | return torch.cat(features, 0).cpu().numpy(), torch.cat(site_labels, 0).cpu().numpy() 198 | 199 | @torch.no_grad() 200 | def compute_site_ba(model, train_loader, test_int, test_ext, opts): 201 | site_estimator = models.SiteEstimator() 202 | 203 | print("Training site estimator") 204 | train_X, train_y = gather_site_feats(model, train_loader, opts) 205 | ba_train = site_estimator.fit(train_X, train_y) 206 | 207 | print("Computing BA") 208 | int_X, int_y = gather_site_feats(model, test_int, opts) 209 | ext_X, ext_y = gather_site_feats(model, test_ext, opts) 210 | ba_int = site_estimator.score(int_X, int_y) 211 | ba_ext = site_estimator.score(ext_X, ext_y) 212 | 213 | return ba_train, ba_int, ba_ext --------------------------------------------------------------------------------
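The learning-rate helpers in `util.py` (`adjust_learning_rate` and `warmup_learning_rate`) write the new rate straight into the optimizer's param groups, which makes the schedule hard to inspect before a run. The sketch below isolates the same formulas as pure functions so the per-epoch values can be tabulated. The names `scheduled_lr` and `warmup_lr`, and the default `decay_rate=0.1`, are hypothetical illustrations, not part of the repository; the arithmetic is meant to mirror the code shown above, assuming epochs are 1-based during warmup as in `warmup_learning_rate`.

```python
import math


def scheduled_lr(base_lr, epoch, epochs, decay="cosine", decay_rate=0.1, decay_epochs=()):
    """Per-epoch learning rate, mirroring the formulas in adjust_learning_rate.

    Hypothetical helper: takes plain arguments instead of the `args` namespace.
    """
    if decay == "cosine":
        # Cosine annealing from base_lr down to eta_min = base_lr * decay_rate**3.
        eta_min = base_lr * decay_rate ** 3
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / epochs)) / 2
    # Step decay: multiply by decay_rate once for every milestone already passed.
    steps = sum(epoch > milestone for milestone in decay_epochs)
    return base_lr * decay_rate ** steps


def warmup_lr(epoch, batch_id, total_batches, warm_epochs, warmup_from, warmup_to):
    """Linear per-batch warmup, mirroring warmup_learning_rate (epochs are 1-based)."""
    # Fraction of the whole warmup period completed so far.
    p = (batch_id + (epoch - 1) * total_batches) / (warm_epochs * total_batches)
    return warmup_from + p * (warmup_to - warmup_from)


if __name__ == "__main__":
    # Tabulate a 100-epoch cosine schedule starting at lr=0.1.
    for epoch in (0, 25, 50, 75, 100):
        print(epoch, round(scheduled_lr(0.1, epoch, 100), 5))
```

Because the cosine branch floors at `lr * lr_decay_rate ** 3`, a decay rate of 0.1 ends training at one thousandth of the initial rate rather than at zero; keeping the functions pure makes that kind of property easy to check in isolation.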