├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── ramp-submission
│   ├── estimator.py
│   └── estimator_mse.py
└── src
    ├── data
    │   ├── __init__.py
    │   ├── masks
    │   │   ├── cat12vbm_space-MNI152_desc-gm_TPM.nii.gz
    │   │   └── quasiraw_space-MNI152_desc-brain_T1w.nii.gz
    │   ├── openbhb.py
    │   └── transforms.py
    ├── exp
    │   ├── mae.yaml
    │   ├── supcon_adam_kernel.yaml
    │   └── supcon_sgd_kernel.yaml
    ├── figures
    │   ├── ablation.csv
    │   ├── ablation.pdf
    │   └── ablation.py
    ├── launcher.py
    ├── losses.py
    ├── main_infonce.py
    ├── main_mse.py
    ├── models
    │   ├── __init__.py
    │   ├── alexnet3d.py
    │   ├── densenet3d.py
    │   ├── estimators.py
    │   └── resnet3d.py
    └── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 EIDOSLAB
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive learning for regression in multi-site brain age prediction
2 |
3 | Carlo Alberto Barbano<sup>1,2</sup>, Benoit Dufumier<sup>1,3</sup>, Edouard Duchesnay<sup>3</sup>, Marco Grangetto<sup>2</sup>, Pietro Gori<sup>1</sup> | [[pdf](https://arxiv.org/pdf/2211.08326.pdf)] [[poster](https://drive.google.com/file/d/1gr45EamhVVClPbMT5T5b1Gy9V50fgw3c/view)]
4 |
5 | <sup>1</sup>LTCI, Télécom Paris, IP Paris<br>
6 | <sup>2</sup>University of Turin, Computer Science dept.<br>
7 | <sup>3</sup>NeuroSpin, CEA, Université Paris-Saclay
8 |
9 |
10 | 
11 |
12 | Building accurate Deep Learning (DL) models for brain age prediction is a very relevant topic in neuroimaging, as it could help better understand neurodegenerative disorders and find new biomarkers. To estimate accurate and generalizable models, large datasets have been collected, which are often multi-site and multi-scanner. This large heterogeneity negatively affects the generalization performance of DL models since they are prone to overfit site-related noise. Recently, contrastive learning approaches have been shown to be more robust against noise in data or labels. For this reason, we propose a novel contrastive learning regression loss for robust brain age prediction using MRI scans. Our method achieves state-of-the-art performance on the OpenBHB challenge, yielding the best generalization capability and robustness to site-related noise.
13 |
14 |
15 | ## Running
16 |
17 | ### Training
18 |
19 | The code can be found in the `src` folder. For training, there are two main scripts:
20 |
21 | - `main_mse.py`: for training baseline MSE/MAE models
22 | - `main_infonce.py`: for training models with contrastive losses
23 |
24 | For ease of use, the script `launcher.py` is provided with some predefined experiments, which can be found in `src/exp` as YAML templates. To launch one:
25 |
26 | ```
27 | python3 launcher.py exp/mae.yaml
28 | ```
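
For reference, here is a minimal sketch of what such a template could look like. The field names below are purely illustrative (this is an assumption, not the actual schema): check the real files in `src/exp`, e.g. `mae.yaml`, for the fields `launcher.py` actually expects.

```yaml
# Illustrative only: hypothetical field names, see src/exp/mae.yaml
# for the actual schema consumed by launcher.py.
script: main_mse.py   # entry point to run
args:
  loss: mae           # baseline objective (MSE/MAE)
  epochs: 300
  batch_size: 32
```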
29 |
30 | ### Testing on the leaderboard
31 |
32 | To test on the official leaderboard of the OpenBHB challenge ([https://ramp.studio/events/brain_age_with_site_removal_open_2022](https://ramp.studio/events/brain_age_with_site_removal_open_2022)), you first need to create an account at [https://ramp.studio/](https://ramp.studio/). The source code for the submission (for both the supervised and the contrastive models) can be found in the `ramp-submission` folder.
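
Before uploading, a submission can also be validated locally with the [ramp-workflow](https://github.com/paris-saclay-cds/ramp-workflow) tooling, assuming you run it from within the challenge starting kit and the OpenBHB data is available (here `my_submission` is a hypothetical folder name under `submissions/`):

```
pip install ramp-workflow
ramp-test --submission my_submission --quick-test
```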
33 |
34 | ## Citing
35 |
36 | To cite our work, please use the following BibTeX entry:
37 |
38 | ```bibtex
39 | @inproceedings{barbano2023contrastive,
40 | author = {Barbano, Carlo Alberto and Dufumier, Benoit and Duchesnay, Edouard and Grangetto, Marco and Gori, Pietro},
41 | booktitle = {International Symposium on Biomedical Imaging (ISBI)},
42 | title = {Contrastive learning for regression in multi-site brain age prediction},
43 | year = {2023}
44 | }
45 | ```
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/assets/teaser.png
--------------------------------------------------------------------------------
/ramp-submission/estimator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | ##########################################################################
3 | # Code version 6207dffcc20f461bdb742f5d8a2f6641483b9d83
4 | ##########################################################################
5 |
6 |
7 | """
8 | Each solution to be tested should be stored in its own directory within
9 | submissions/. The name of this new directory will serve as the ID for
10 | the submission. If you wish to launch a RAMP challenge you will need to
11 | provide an example solution within submissions/starting_kit/. Even if
12 | you are not launching a RAMP challenge on RAMP Studio, it is useful to
13 | have an example submission as it shows which files are required, how they
14 | need to be named and how each file should be structured.
15 | """
16 |
17 | # Filename: estimator.py
18 | # Run id:
19 | #
20 | import os
21 | ARCHITECTURE = os.environ.get("ARCHITECTURE", "resnet18")
22 |
23 | from collections import OrderedDict
24 | from abc import ABCMeta
25 | import progressbar
26 | import nibabel
27 | import numpy as np
28 | from nilearn.masking import unmask
29 | from sklearn.base import BaseEstimator
30 | from sklearn.base import TransformerMixin
31 | from sklearn.pipeline import Pipeline, make_pipeline
32 | import torch
33 | import torch.nn as nn
34 | import torch.nn.functional as F
35 | import torch.utils.checkpoint as cp
36 | from torchvision import transforms
37 | import math
38 |
39 |
40 | ############################################################################
41 | # Define here some selectors
42 | ############################################################################
43 |
44 | class FeatureExtractor(BaseEstimator, TransformerMixin):
45 | """ Select only the requested data associatedd features from the the
46 | input buffered data.
47 | """
48 | MODALITIES = OrderedDict([
49 | ("vbm", {
50 | "shape": (1, 121, 145, 121),
51 | "size": 519945}),
52 | ("quasiraw", {
53 | "shape": (1, 182, 218, 182),
54 | "size": 1827095}),
55 | ("xhemi", {
56 | "shape": (8, 163842),
57 | "size": 1310736}),
58 | ("vbm_roi", {
59 | "shape": (1, 284),
60 | "size": 284}),
61 | ("desikan_roi", {
62 | "shape": (7, 68),
63 | "size": 476}),
64 | ("destrieux_roi", {
65 | "shape": (7, 148),
66 | "size": 1036})
67 | ])
68 | MASKS = {
69 | "vbm": {
70 | "path": None,
71 | "thr": 0.05},
72 | "quasiraw": {
73 | "path": None,
74 | "thr": 0}
75 | }
76 |
77 | def __init__(self, dtype, mock=False):
78 | """ Init class.
79 | Parameters
80 | ----------
81 | dtype: str
82 | the requested data: 'vbm', 'quasiraw', 'vbm_roi', 'desikan_roi',
83 | 'destrieux_roi' or 'xhemi'.
84 | """
85 | if dtype not in self.MODALITIES:
86 | raise ValueError("Invalid input data type.")
87 | self.dtype = dtype
88 |
89 | data_types = list(self.MODALITIES.keys())
90 | index = data_types.index(dtype)
91 |
92 | cumsum = np.cumsum([item["size"] for item in self.MODALITIES.values()])
93 |
94 | if index > 0:
95 | self.start = cumsum[index - 1]
96 | else:
97 | self.start = 0
98 | self.stop = cumsum[index]
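# Worked example, derived from the MODALITIES sizes above: the flattened
# input vector is the concatenation of all modalities, so the slices are
#   vbm           -> [0, 519945)
#   quasiraw      -> [519945, 2347040)
#   xhemi         -> [2347040, 3657776)
#   vbm_roi       -> [3657776, 3658060)
#   desikan_roi   -> [3658060, 3658536)
#   destrieux_roi -> [3658536, 3659572)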
99 |
100 | self.masks = dict((key, val["path"]) for key, val in self.MASKS.items())
101 | self.masks["vbm"] = os.environ.get("VBM_MASK")
102 | self.masks["quasiraw"] = os.environ.get("QUASIRAW_MASK")
103 |
104 | self.mock = mock
105 | if mock:
106 | return
107 |
108 | for key in self.masks:
109 | if self.masks[key] is None or not os.path.isfile(self.masks[key]):
110 | raise ValueError("Impossible to find mask:", key, self.masks[key])
111 | arr = nibabel.load(self.masks[key]).get_fdata()
112 | thr = self.MASKS[key]["thr"]
113 | arr[arr <= thr] = 0
114 | arr[arr > thr] = 1
115 | self.masks[key] = nibabel.Nifti1Image(arr.astype(int), np.eye(4))
116 |
117 | def fit(self, X, y):
118 | return self
119 |
120 | def transform(self, X):
121 | if self.mock:
122 | #print("transforming", X.shape)
123 | data = X.reshape(self.MODALITIES[self.dtype]["shape"])
124 | #print("mock data:", data.shape)
125 | return data
126 |
127 | # print(X.shape)
128 | select_X = X[self.start:self.stop]
129 | if self.dtype in ("vbm", "quasiraw"):
130 | im = unmask(select_X, self.masks[self.dtype])
131 | select_X = im.get_fdata()
132 | select_X = select_X.transpose(2, 0, 1)
133 | select_X = select_X.reshape(self.MODALITIES[self.dtype]["shape"])
134 | return select_X
135 |
136 | class Crop(object):
137 | """ Crop the given n-dimensional array either at a random location or
138 | centered.
139 | """
140 | def __init__(self, shape, type="center", keep_dim=False):
141 | assert type in ["center", "random"]
142 | self.shape = shape
143 | self.cropping_type = type
144 | self.keep_dim = keep_dim
145 |
146 | def __call__(self, X):
147 | img_shape = np.array(X.shape)
148 |
149 | if isinstance(self.shape, int):
150 | size = [self.shape] * len(img_shape)  # broadcast the scalar size to every axis
151 | else:
152 | size = np.copy(self.shape)
153 |
154 | # print('img_shape:', img_shape, 'size', size)
155 |
156 | indexes = []
157 | for ndim in range(len(img_shape)):
158 | if size[ndim] > img_shape[ndim] or size[ndim] < 0:
159 | size[ndim] = img_shape[ndim]
160 |
161 | if self.cropping_type == "center":
162 | delta_before = int((img_shape[ndim] - size[ndim]) / 2.0)
163 |
164 | elif self.cropping_type == "random":
165 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1)
166 |
167 | indexes.append(slice(delta_before, delta_before + size[ndim]))
168 |
169 | if self.keep_dim:
170 | mask = np.zeros(img_shape, dtype=bool)  # np.bool was removed from recent NumPy
171 | mask[tuple(indexes)] = True
172 | arr_copy = X.copy()
173 | arr_copy[~mask] = 0
174 | return arr_copy
175 |
176 | _X = X[tuple(indexes)]
177 | # print('cropped.shape', _X.shape)
178 | return _X
179 |
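# Example (center crop): Crop((1, 121, 128, 121)) applied to a VBM volume of
# shape (1, 121, 145, 121) keeps slices [8:136] along the third axis, since
# delta_before = (145 - 128) // 2 = 8; the other axes are left untouched.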
180 | class Pad(object):
181 | """ Pad the given n-dimensional array
182 | """
183 | def __init__(self, shape, **kwargs):
184 | self.shape = shape
185 | self.kwargs = kwargs
186 |
187 | def __call__(self, X):
188 | _X = self._apply_padding(X)
189 | return _X
190 |
191 | def _apply_padding(self, arr):
192 | orig_shape = arr.shape
193 | padding = []
194 | for orig_i, final_i in zip(orig_shape, self.shape):
195 | shape_i = final_i - orig_i
196 | half_shape_i = shape_i // 2
197 | if shape_i % 2 == 0:
198 | padding.append([half_shape_i, half_shape_i])
199 | else:
200 | padding.append([half_shape_i, half_shape_i + 1])
201 | for cnt in range(len(arr.shape) - len(padding)):
202 | padding.append([0, 0])
203 | fill_arr = np.pad(arr, padding, **self.kwargs)
204 | return fill_arr
205 |
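# Example: Pad((1, 128, 128, 128)) applied to a (1, 121, 128, 121) crop pads
# the odd difference of 7 asymmetrically, i.e. np.pad(arr, [[0, 0], [3, 4],
# [0, 0], [3, 4]]), yielding a (1, 128, 128, 128) volume.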
206 | ############################################################################
207 | # Define here your dataset
208 | ############################################################################
209 |
210 | class Dataset(torch.utils.data.Dataset):
211 | def __init__(self, X, y=None, transforms=None, indices=None):
212 | self.T = transforms
213 | self.X = X
214 | self.y = y
215 | self.indices = indices
216 | if indices is None:
217 | self.indices = range(len(X))
218 |
219 | def __len__(self):
220 | return len(self.indices)
221 |
222 | def __getitem__(self, i):
223 | real_i = self.indices[i]
224 | x = self.X[real_i]
225 |
226 | if self.T is not None:
227 | x = self.T(x)
228 |
229 | if self.y is not None:
230 | y = self.y[real_i]
231 | return x, y
232 | else:
233 | return x
234 |
235 |
236 | ############################################################################
237 | # Define here your regression model
238 | ############################################################################
239 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
240 | """3x3 convolution with padding"""
241 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
242 | padding=dilation, groups=groups, bias=False, dilation=dilation)
243 |
244 | def conv1x1(in_planes, out_planes, stride=1):
245 | """1x1 convolution"""
246 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
247 |
248 | class BasicBlock(nn.Module):
249 | expansion = 1
250 |
251 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
252 | base_width=64, dilation=1, norm_layer=None):
253 | super(BasicBlock, self).__init__()
254 | if norm_layer is None:
255 | norm_layer = nn.BatchNorm3d
256 | if groups != 1 or base_width != 64:
257 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
258 | if dilation > 1:
259 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
260 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
261 | self.conv1 = conv3x3(inplanes, planes, stride)
262 | self.bn1 = norm_layer(planes)
263 | self.relu = nn.ReLU(inplace=True)
264 | self.conv2 = conv3x3(planes, planes)
265 | self.bn2 = norm_layer(planes)
266 | self.downsample = downsample
267 | self.stride = stride
268 |
269 | def forward(self, x):
270 | identity = x
271 |
272 | out = self.conv1(x)
273 | out = self.bn1(out)
274 | out = self.relu(out)
275 | out = self.conv2(out)
276 | out = self.bn2(out)
277 |
278 | if self.downsample is not None:
279 | identity = self.downsample(x)
280 |
281 | out += identity
282 | out = self.relu(out)
283 |
284 | return out
285 |
286 |
287 | class Bottleneck(nn.Module):
288 | expansion = 4
289 |
290 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
291 | base_width=64, dilation=1, norm_layer=None):
292 | super(Bottleneck, self).__init__()
293 | if norm_layer is None:
294 | norm_layer = nn.BatchNorm3d
295 | width = int(planes * (base_width / 64.)) * groups
296 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
297 | self.conv1 = conv1x1(inplanes, width)
298 | self.bn1 = norm_layer(width)
299 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
300 | self.bn2 = norm_layer(width)
301 | self.conv3 = conv1x1(width, planes * self.expansion)
302 | self.bn3 = norm_layer(planes * self.expansion)
303 | self.relu = nn.ReLU(inplace=True)
304 | self.downsample = downsample
305 | self.stride = stride
306 |
307 | def forward(self, x):
308 | identity = x
309 |
310 | out = self.conv1(x)
311 | out = self.bn1(out)
312 | out = self.relu(out)
313 |
314 | out = self.conv2(out)
315 | out = self.bn2(out)
316 | out = self.relu(out)
317 |
318 | out = self.conv3(out)
319 | out = self.bn3(out)
320 |
321 | if self.downsample is not None:
322 | identity = self.downsample(x)
323 |
324 | out += identity
325 | out = self.relu(out)
326 |
327 | return out
328 |
329 | class ResNet(nn.Module):
330 | """
331 | Standard 3D-ResNet architecture with a big initial 7x7x7 kernel.
332 | It can be used in "classifier" mode, outputting a vector of class scores,
333 | or in "encoder" mode, outputting a latent vector of size 512 (independent of input size).
334 | Note: only a last FC layer is added on top of the "encoder" backbone.
335 | """
336 | def __init__(self, block, layers, in_channels=1,
337 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
338 | norm_layer=None, initial_kernel_size=7):
339 | super(ResNet, self).__init__()
340 |
341 | if norm_layer is None:
342 | norm_layer = nn.BatchNorm3d
343 | self._norm_layer = norm_layer
344 |
345 | self.name = "resnet"
346 | self.inputs = None
347 | self.inplanes = 64
348 | self.dilation = 1
349 |
350 | if replace_stride_with_dilation is None:
351 | # each element in the tuple indicates if we should replace
352 | # the 2x2 stride with a dilated convolution instead
353 | replace_stride_with_dilation = [False, False, False]
354 | if len(replace_stride_with_dilation) != 3:
355 | raise ValueError("replace_stride_with_dilation should be None "
356 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
357 | self.groups = groups
358 | self.base_width = width_per_group
359 | initial_stride = 2 if initial_kernel_size==7 else 1
360 | padding = (initial_kernel_size-initial_stride+1)//2
361 | self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=initial_kernel_size, stride=initial_stride, padding=padding, bias=False)
362 | self.bn1 = norm_layer(self.inplanes)
363 | self.relu = nn.ReLU(inplace=True)
364 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
365 |
366 | channels = [64, 128, 256, 512]
367 |
368 | self.layer1 = self._make_layer(block, channels[0], layers[0])
369 | self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0])
370 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, dilate=replace_stride_with_dilation[1])
371 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2])
372 | self.avgpool = nn.AdaptiveAvgPool3d(1)
373 |
374 | for m in self.modules():
375 | if isinstance(m, nn.Conv3d):
376 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
377 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
378 | nn.init.constant_(m.weight, 1)
379 | nn.init.constant_(m.bias, 0)
380 | elif isinstance(m, nn.Linear):
381 | nn.init.normal_(m.weight, 0, 0.01)
382 | if m.bias is not None:
383 | nn.init.constant_(m.bias, 0)
384 |
385 | # Zero-initialize the last BN in each residual branch,
386 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
387 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
388 | if zero_init_residual:
389 | for m in self.modules():
390 | if isinstance(m, Bottleneck):
391 | nn.init.constant_(m.bn3.weight, 0)
392 | elif isinstance(m, BasicBlock):
393 | nn.init.constant_(m.bn2.weight, 0)
394 |
395 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
396 | norm_layer = self._norm_layer
397 | downsample = None
398 | previous_dilation = self.dilation
399 | if dilate:
400 | self.dilation *= stride
401 | stride = 1
402 | if stride != 1 or self.inplanes != planes * block.expansion:
403 | downsample = nn.Sequential(
404 | conv1x1(self.inplanes, planes * block.expansion, stride),
405 | norm_layer(planes * block.expansion),
406 | )
407 |
408 | layers = []
409 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
410 | base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer))
411 | self.inplanes = planes * block.expansion
412 | for _ in range(1, blocks):
413 | layers.append(block(self.inplanes, planes, groups=self.groups,
414 | base_width=self.base_width, dilation=self.dilation,
415 | norm_layer=norm_layer))
416 |
417 | return nn.Sequential(*layers)
418 |
419 | def forward(self, x):
420 | x = self.conv1(x)
421 | x = self.bn1(x)
422 | x = self.relu(x)
423 | x = self.maxpool(x)
424 |
425 | x1 = self.layer1(x)
426 | x2 = self.layer2(x1)
427 | x3 = self.layer3(x2)
428 | x4 = self.layer4(x3)
429 |
430 | x5 = self.avgpool(x4)
431 | return torch.flatten(x5, 1)
432 |
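# The encoder output dimension is 512 * block.expansion: 512 for the
# BasicBlock variants (resnet18/34), 2048 for the Bottleneck variants
# (resnet50/101), matching the feature sizes listed in model_dict below.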
433 | def resnet18(**kwargs):
434 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
435 |
436 | def resnet34(**kwargs):
437 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
438 |
439 | def resnet50(**kwargs):
440 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
441 |
442 | def resnet101(**kwargs):
443 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
444 |
445 | model_dict = {
446 | 'resnet18': [resnet18, 512],
447 | 'resnet34': [resnet34, 512],
448 | 'resnet50': [resnet50, 2048],
449 | 'resnet101': [resnet101, 2048],
450 | }
451 |
452 | class SupConResNet(nn.Module):
453 | """backbone + projection head"""
454 | def __init__(self, name='resnet50', head='mlp', feat_dim=128):
455 | super().__init__()
456 | model_fun, dim_in = model_dict[name]
457 | self.encoder = model_fun()
458 | if head == 'linear':
459 | self.head = nn.Linear(dim_in, feat_dim)
460 | elif head == 'mlp':
461 | self.head = nn.Sequential(
462 | nn.Linear(dim_in, dim_in),
463 | nn.ReLU(inplace=True),
464 | nn.Linear(dim_in, feat_dim)
465 | )
466 | else:
467 | raise NotImplementedError(
468 | 'head not supported: {}'.format(head))
469 |
470 | def forward(self, x):
471 | feat = self.encoder(x)
472 | feat = F.normalize(self.head(feat), dim=1)
473 | return feat
474 |
475 |
476 | class AlexNet3D(nn.Module):
477 | def __init__(self):
478 | """
479 | 3D AlexNet-style encoder: the forward pass returns a 128-d feature
480 | vector (no classification head is attached here).
481 | """
482 | super().__init__()
483 | self.features = nn.Sequential(
484 | nn.Conv3d(1, 64, kernel_size=5, stride=2, padding=0),
485 | nn.BatchNorm3d(64),
486 | nn.ReLU(inplace=True),
487 | nn.MaxPool3d(kernel_size=3, stride=3),
488 |
489 | nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=0),
490 | nn.BatchNorm3d(128),
491 | nn.ReLU(inplace=True),
492 | nn.MaxPool3d(kernel_size=3, stride=3),
493 |
494 | nn.Conv3d(128, 192, kernel_size=3, padding=1),
495 | nn.BatchNorm3d(192),
496 | nn.ReLU(inplace=True),
497 |
498 | nn.Conv3d(192, 192, kernel_size=3, padding=1),
499 | nn.BatchNorm3d(192),
500 | nn.ReLU(inplace=True),
501 |
502 | nn.Conv3d(192, 128, kernel_size=3, padding=1),
503 | nn.BatchNorm3d(128),
504 | nn.ReLU(inplace=True),
505 | nn.AdaptiveMaxPool3d(1),
506 | )
507 |
508 |
509 | for m in self.modules():
510 | if isinstance(m, nn.Conv3d):  # was nn.Conv2d, which never matches in this 3D network
511 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
512 | m.weight.data.normal_(0, math.sqrt(2. / n))
513 | elif isinstance(m, nn.BatchNorm3d):
514 | m.weight.data.fill_(1)
515 | m.bias.data.zero_()
516 |
517 | def forward(self, x):
518 | xp = self.features(x)
519 | x = xp.view(xp.size(0), -1)
520 | return x
521 |
522 | class SupConAlexNet(nn.Module):
523 | """backbone + projection head"""
524 | def __init__(self, head='mlp', feat_dim=128):
525 | super().__init__()
526 | self.encoder = AlexNet3D()
527 | dim_in = 128
528 |
529 | if head == 'linear':
530 | self.head = nn.Linear(dim_in, feat_dim)
531 | elif head == 'mlp':
532 | self.head = nn.Sequential(
533 | nn.Linear(dim_in, dim_in),
534 | nn.ReLU(inplace=True),
535 | nn.Linear(dim_in, feat_dim)
536 | )
537 |
538 | else:
539 | raise NotImplementedError(
540 | 'head not supported: {}'.format(head))
541 |
542 | def forward(self, x):
543 | feat = self.encoder(x)
544 | feat = F.normalize(self.head(feat), dim=1)
545 | return feat
546 |
547 | def features(self, x):
548 | return self.forward(x)
549 |
550 | class DenseNet(nn.Module):
551 | """3D-Densenet-BC model class, based on
552 | `"Densely Connected Convolutional Networks" `_
553 | Args:
554 | growth_rate (int) - how many filters to add each layer (`k` in paper)
555 | block_config (list of 4 ints) - how many layers in each pooling block
556 | num_init_features (int) - the number of filters to learn in the first convolution layer
557 | mode (str) - "classifier" or "encoder" (all but last FC layer)
558 | bn_size (int) - multiplicative factor for number of bottle neck layers
559 | (i.e. bn_size * k features in the bottleneck layer)
560 | num_classes (int) - number of classification classes
561 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
562 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
563 | """
564 |
565 | def __init__(self, growth_rate=32, block_config=(3, 12, 24, 16),
566 | num_init_features=64,
567 | bn_size=4, in_channels=1,
568 | memory_efficient=False):
569 | super(DenseNet, self).__init__()
570 | # First convolution
571 | self.features = nn.Sequential(OrderedDict([
572 | ('conv0', nn.Conv3d(in_channels, num_init_features,
573 | kernel_size=7, stride=2, padding=3, bias=False)),
574 | ('norm0', nn.BatchNorm3d(num_init_features)),
575 | ('relu0', nn.ReLU(inplace=True)),
576 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)),
577 | ]))
578 |
579 | # Each denseblock
580 | num_features = num_init_features
581 | for i, num_layers in enumerate(block_config):
582 | block = _DenseBlock(
583 | num_layers=num_layers,
584 | num_input_features=num_features,
585 | bn_size=bn_size,
586 | growth_rate=growth_rate,
587 | memory_efficient=memory_efficient
588 | )
589 | self.features.add_module('denseblock%d' % (i + 1), block)
590 | num_features = num_features + num_layers * growth_rate
591 | if i != len(block_config) - 1:
592 | trans = _Transition(num_input_features=num_features,
593 | num_output_features=num_features // 2)
594 | self.features.add_module('transition%d' % (i + 1), trans)
595 | num_features = num_features // 2
596 |
597 | self.num_features = num_features
598 |
599 |
600 | # Official init from torch repo.
601 | for m in self.modules():
602 | if isinstance(m, nn.Conv3d):
603 | nn.init.kaiming_normal_(m.weight)
604 | elif isinstance(m, nn.BatchNorm3d):
605 | nn.init.constant_(m.weight, 1)
606 | nn.init.constant_(m.bias, 0)
607 | elif isinstance(m, nn.Linear):
608 | nn.init.constant_(m.bias, 0)
609 |
610 | def forward(self, x):
611 | features = self.features(x)
612 | out = F.adaptive_avg_pool3d(features, 1)
613 | out = torch.flatten(out, 1)
614 | return out.squeeze(dim=1)
615 |
616 |
617 | def _bn_function_factory(norm, relu, conv):
618 | def bn_function(*inputs):
619 | concated_features = torch.cat(inputs, 1)
620 | bottleneck_output = conv(relu(norm(concated_features)))
621 | return bottleneck_output
622 |
623 | return bn_function
624 |
625 |
626 | class _DenseLayer(nn.Sequential):
627 | def __init__(self, num_input_features, growth_rate, bn_size, memory_efficient=False):
628 | super(_DenseLayer, self).__init__()
629 | self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
630 | self.add_module('relu1', nn.ReLU(inplace=True)),
631 | self.add_module('conv1', nn.Conv3d(num_input_features, bn_size *
632 | growth_rate, kernel_size=1, stride=1,
633 | bias=False)),
634 | self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
635 | self.add_module('relu2', nn.ReLU(inplace=True)),
636 | self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
637 | kernel_size=3, stride=1, padding=1,
638 | bias=False)),
639 | self.memory_efficient = memory_efficient
640 |
641 | def forward(self, *prev_features):
642 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
643 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
644 | bottleneck_output = cp.checkpoint(bn_function, *prev_features)
645 | else:
646 | bottleneck_output = bn_function(*prev_features)
647 |
648 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
649 |
650 | return new_features
651 |
652 |
653 | class _DenseBlock(nn.Module):
654 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, memory_efficient=False):
655 | super(_DenseBlock, self).__init__()
656 | for i in range(num_layers):
657 | layer = _DenseLayer(
658 | num_input_features + i * growth_rate,
659 | growth_rate=growth_rate,
660 | bn_size=bn_size,
661 | memory_efficient=memory_efficient,
662 | )
663 | self.add_module('denselayer%d' % (i + 1), layer)
664 |
665 | def forward(self, init_features):
666 | features = [init_features]
667 | for name, layer in self.named_children():
668 | new_features = layer(*features)
669 | features.append(new_features)
670 | return torch.cat(features, 1)
671 |
672 |
673 | class _Transition(nn.Sequential):
674 | def __init__(self, num_input_features, num_output_features):
675 | super(_Transition, self).__init__()
676 | self.add_module('norm', nn.BatchNorm3d(num_input_features))
677 | self.add_module('relu', nn.ReLU(inplace=True))
678 | self.add_module('conv', nn.Conv3d(num_input_features, num_output_features,
679 | kernel_size=1, stride=1, bias=False))
680 | self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
681 |
682 |
683 | def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs):
684 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
685 | return model
686 |
687 |
688 | def densenet121(**kwargs):
689 | r"""Densenet-121 model from
690 | `"Densely Connected Convolutional Networks" `_
691 |
692 | Args:
693 | pretrained (bool): If True, returns a model pre-trained on ImageNet
694 | progress (bool): If True, displays a progress bar of the download to stderr
695 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
696 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
697 | """
698 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs)
699 |
700 | class SupConDenseNet(nn.Module):
701 | """backbone + projection head"""
702 | def __init__(self, head='mlp', feat_dim=128):
703 | super().__init__()
704 | self.encoder = densenet121()
705 | dim_in = self.encoder.num_features
706 |
707 | if head == 'linear':
708 | self.head = nn.Linear(dim_in, feat_dim)
709 | elif head == 'mlp':
710 | self.head = nn.Sequential(
711 | nn.Linear(dim_in, dim_in),
712 | nn.ReLU(inplace=True),
713 | nn.Linear(dim_in, feat_dim)
714 | )
715 |
716 | else:
717 | raise NotImplementedError(
718 | 'head not supported: {}'.format(head))
719 |
720 | def forward(self, x):
721 | feat = self.encoder(x)
722 | feat = F.normalize(self.head(feat), dim=1)
723 | return feat
724 |
725 | def features(self, x):
726 | return self.forward(x)
727 |
728 |
729 | class RegressionModel(metaclass=ABCMeta):
730 | __model_local_weights__ = os.path.join(os.path.dirname(__file__), os.environ.get("MODEL", "weights.pth"))
731 | __metadata_local_weights__ = os.path.join(os.path.dirname(__file__), "metadata.pkl")
732 |
733 | def __init__(self, model, batch_size=15, transforms=None):
734 | self.model = model
735 | self.batch_size = batch_size
736 | self.transforms = transforms
737 | self.indices = None
738 |
739 | def fit(self, X, y):
740 | """ Restore weights.
741 | """
742 | if not os.path.isfile(self.__model_local_weights__):
743 | raise ValueError("You must provide the model weigths in your submission folder.")
744 | state = torch.load(self.__model_local_weights__, map_location="cpu")
745 |
746 | if "model" not in state:
747 | raise ValueError("Model weigths are searched in the state dictionary at the 'model' key location.")
748 | self.model.load_state_dict(state["model"], strict=True)
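# The checkpoint is thus expected to have been saved along the lines of
#   torch.save({"model": net.state_dict()}, "weights.pth")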
749 |
750 | def predict(self, X: np.ndarray) -> np.ndarray:
751 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
752 | self.model.to(device)
753 |
754 | dataset = Dataset(X, transforms=self.transforms, indices=self.indices)
755 | testloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
756 |
757 | self.model.eval()
758 | outputs = []
759 |
760 | with progressbar.ProgressBar(max_value=len(testloader)) as bar:
761 | for cnt, inputs in enumerate(testloader):
762 | inputs = inputs.float().to(device)
763 | # print("Batch size", inputs.shape)
764 | with torch.no_grad():
765 | out = self.model(inputs)
766 | # out = torch.randn((inputs.shape[0], 128))
767 |
768 | outputs.append(out.detach())
769 | bar.update(cnt)
770 |
771 | outputs = torch.cat(outputs, dim=0)
772 | return outputs.detach().cpu().numpy()
773 |
774 |
775 | ############################################################################
776 | # Define here your estimator pipeline
777 | ############################################################################
778 |
779 | def get_estimator(mock=False) -> Pipeline:
780 | """ Build your estimator here.
781 | Notes
782 | -----
783 | In order to minimize the memory load the first steps of the pipeline
784 | are applied directly as transforms attached to the Torch Dataset.
785 | Notes
786 | -----
787 | It is recommended to create an instance of sklearn.pipeline.Pipeline.
788 | """
789 | print("InfoNCE")
790 | if "resnet" in ARCHITECTURE:
791 | net = SupConResNet(ARCHITECTURE)
792 | elif ARCHITECTURE == "alexnet":
793 | net = SupConAlexNet()
794 | elif "densenet" in ARCHITECTURE:
795 | net = SupConDenseNet()
796 |
797 | selector = FeatureExtractor("vbm", mock=mock)
798 | preproc = transforms.Compose([
799 | transforms.Lambda(lambda x: selector.transform(x)),
800 | # Crop((1, 121, 128, 121), type="center"),
801 | # Pad((1, 128, 128, 128)),
802 | transforms.Lambda(lambda x: torch.from_numpy(x).float()),
803 | transforms.Normalize(mean=0.0, std=1.0),
804 | ])
805 | estimator = make_pipeline(
806 | RegressionModel(net, transforms=preproc))
807 | return estimator
808 |
809 |
810 | if __name__ == '__main__':
811 | estimator = get_estimator(mock=True).fit(None)
812 | estimator.predict(np.random.random((32, 2122945)))
--------------------------------------------------------------------------------
/ramp-submission/estimator_mse.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | ##########################################################################
3 | # Code version 6207dffcc20f461bdb742f5d8a2f6641483b9d83
4 | ##########################################################################
5 |
6 |
7 | """
8 | Each solution to be tested should be stored in its own directory within
9 | submissions/. The name of this new directory will serve as the ID for
10 | the submission. If you wish to launch a RAMP challenge you will need to
11 | provide an example solution within submissions/starting_kit/. Even if
12 | you are not launching a RAMP challenge on RAMP Studio, it is useful to
13 | have an example submission as it shows which files are required, how they
14 | need to be named and how each file should be structured.
15 | """
16 |
17 | # Filename: estimator_mse.py
18 | # Run id:
19 | #
20 | import os
21 | ARCHITECTURE = os.environ.get("ARCHITECTURE", "resnet18")
22 |
23 |
24 | from collections import OrderedDict
25 | from abc import ABCMeta
26 | import progressbar
27 | import nibabel
28 | import numpy as np
29 | from nilearn.masking import unmask
30 | from sklearn.base import BaseEstimator
31 | from sklearn.base import TransformerMixin
32 | from sklearn.pipeline import Pipeline, make_pipeline
33 | import torch
34 | import torch.nn as nn
35 | import torch.nn.functional as F
36 | import torch.utils.checkpoint as cp
37 | from torchvision import transforms
38 | import math
39 |
40 | ############################################################################
41 | # Define here some selectors
42 | ############################################################################
43 |
44 | class FeatureExtractor(BaseEstimator, TransformerMixin):
45 | """ Select only the requested data associatedd features from the the
46 | input buffered data.
47 | """
48 | MODALITIES = OrderedDict([
49 | ("vbm", {
50 | "shape": (1, 121, 145, 121),
51 | "size": 519945}),
52 | ("quasiraw", {
53 | "shape": (1, 182, 218, 182),
54 | "size": 1827095}),
55 | ("xhemi", {
56 | "shape": (8, 163842),
57 | "size": 1310736}),
58 | ("vbm_roi", {
59 | "shape": (1, 284),
60 | "size": 284}),
61 | ("desikan_roi", {
62 | "shape": (7, 68),
63 | "size": 476}),
64 | ("destrieux_roi", {
65 | "shape": (7, 148),
66 | "size": 1036})
67 | ])
68 | MASKS = {
69 | "vbm": {
70 | "path": None,
71 | "thr": 0.05},
72 | "quasiraw": {
73 | "path": None,
74 | "thr": 0}
75 | }
76 |
77 | def __init__(self, dtype, mock=False):
78 | """ Init class.
79 | Parameters
80 | ----------
81 | dtype: str
82 | the requested data: 'vbm', 'quasiraw', 'vbm_roi', 'desikan_roi',
83 | 'destrieux_roi' or 'xhemi'.
84 | """
85 | if dtype not in self.MODALITIES:
86 | raise ValueError("Invalid input data type.")
87 | self.dtype = dtype
88 |
89 | data_types = list(self.MODALITIES.keys())
90 | index = data_types.index(dtype)
91 |
92 | cumsum = np.cumsum([item["size"] for item in self.MODALITIES.values()])
93 |
94 | if index > 0:
95 | self.start = cumsum[index - 1]
96 | else:
97 | self.start = 0
98 | self.stop = cumsum[index]
99 |
100 | self.masks = dict((key, val["path"]) for key, val in self.MASKS.items())
101 | self.masks["vbm"] = os.environ.get("VBM_MASK")
102 | self.masks["quasiraw"] = os.environ.get("QUASIRAW_MASK")
103 |
104 | self.mock = mock
105 | if mock:
106 | return
107 |
108 | for key in self.masks:
109 | if self.masks[key] is None or not os.path.isfile(self.masks[key]):
110 | raise ValueError("Impossible to find mask:", key, self.masks[key])
111 | arr = nibabel.load(self.masks[key]).get_fdata()
112 | thr = self.MASKS[key]["thr"]
113 | arr[arr <= thr] = 0
114 | arr[arr > thr] = 1
115 | self.masks[key] = nibabel.Nifti1Image(arr.astype(int), np.eye(4))
116 |
117 | def fit(self, X, y):
118 | return self
119 |
120 | def transform(self, X):
121 | if self.mock:
122 | #print("transforming", X.shape)
123 | data = X.reshape(self.MODALITIES[self.dtype]["shape"])
124 | #print("mock data:", data.shape)
125 | return data
126 |
127 | # print(X.shape)
128 | select_X = X[self.start:self.stop]
129 | if self.dtype in ("vbm", "quasiraw"):
130 | im = unmask(select_X, self.masks[self.dtype])
131 | select_X = im.get_fdata()
132 | select_X = select_X.transpose(2, 0, 1)
133 | select_X = select_X.reshape(self.MODALITIES[self.dtype]["shape"])
134 | return select_X
135 |
136 | class Crop(object):
137 | """ Crop the given n-dimensional array either at a random location or
138 | centered.
139 | """
140 | def __init__(self, shape, type="center", keep_dim=False):
141 | assert type in ["center", "random"]
142 | self.shape = shape
143 | self.cropping_type = type
144 | self.keep_dim = keep_dim
145 |
146 | def __call__(self, X):
147 | img_shape = np.array(X.shape)
148 |
149 | if isinstance(self.shape, int):
150 | size = [self.shape] * len(img_shape)  # broadcast the scalar size to every axis
151 | else:
152 | size = np.copy(self.shape)
153 |
154 | # print('img_shape:', img_shape, 'size', size)
155 |
156 | indexes = []
157 | for ndim in range(len(img_shape)):
158 | if size[ndim] > img_shape[ndim] or size[ndim] < 0:
159 | size[ndim] = img_shape[ndim]
160 |
161 | if self.cropping_type == "center":
162 | delta_before = int((img_shape[ndim] - size[ndim]) / 2.0)
163 |
164 | elif self.cropping_type == "random":
165 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1)
166 |
167 | indexes.append(slice(delta_before, delta_before + size[ndim]))
168 |
169 | if self.keep_dim:
170 | mask = np.zeros(img_shape, dtype=bool)  # np.bool was removed from recent NumPy
171 | mask[tuple(indexes)] = True
172 | arr_copy = X.copy()
173 | arr_copy[~mask] = 0
174 | return arr_copy
175 |
176 | _X = X[tuple(indexes)]
177 | # print('cropped.shape', _X.shape)
178 | return _X
179 |
180 | class Pad(object):
181 | """ Pad the given n-dimensional array
182 | """
183 | def __init__(self, shape, **kwargs):
184 | self.shape = shape
185 | self.kwargs = kwargs
186 |
187 | def __call__(self, X):
188 | _X = self._apply_padding(X)
189 | return _X
190 |
191 | def _apply_padding(self, arr):
192 | orig_shape = arr.shape
193 | padding = []
194 | for orig_i, final_i in zip(orig_shape, self.shape):
195 | shape_i = final_i - orig_i
196 | half_shape_i = shape_i // 2
197 | if shape_i % 2 == 0:
198 | padding.append([half_shape_i, half_shape_i])
199 | else:
200 | padding.append([half_shape_i, half_shape_i + 1])
201 | for cnt in range(len(arr.shape) - len(padding)):
202 | padding.append([0, 0])
203 | fill_arr = np.pad(arr, padding, **self.kwargs)
204 | return fill_arr
205 |
206 | ############################################################################
207 | # Define here your dataset
208 | ############################################################################
209 |
210 | class Dataset(torch.utils.data.Dataset):
211 | def __init__(self, X, y=None, transforms=None, indices=None):
212 | self.T = transforms
213 | self.X = X
214 | self.y = y
215 | self.indices = indices
216 | if indices is None:
217 | self.indices = range(len(X))
218 |
219 | def __len__(self):
220 | return len(self.indices)
221 |
222 | def __getitem__(self, i):
223 | real_i = self.indices[i]
224 | x = self.X[real_i]
225 |
226 | if self.T is not None:
227 | x = self.T(x)
228 |
229 | if self.y is not None:
230 | y = self.y[real_i]
231 | return x, y
232 | else:
233 | return x
234 |
235 |
236 | ############################################################################
237 | # Define here your regression model
238 | ############################################################################
239 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
240 | """3x3 convolution with padding"""
241 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
242 | padding=dilation, groups=groups, bias=False, dilation=dilation)
243 |
244 | def conv1x1(in_planes, out_planes, stride=1):
245 | """1x1 convolution"""
246 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
247 |
248 | class BasicBlock(nn.Module):
249 | expansion = 1
250 |
251 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
252 | base_width=64, dilation=1, norm_layer=None):
253 | super(BasicBlock, self).__init__()
254 | if norm_layer is None:
255 | norm_layer = nn.BatchNorm3d
256 | if groups != 1 or base_width != 64:
257 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
258 | if dilation > 1:
259 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
260 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
261 | self.conv1 = conv3x3(inplanes, planes, stride)
262 | self.bn1 = norm_layer(planes)
263 | self.relu = nn.ReLU(inplace=True)
264 | self.conv2 = conv3x3(planes, planes)
265 | self.bn2 = norm_layer(planes)
266 | self.downsample = downsample
267 | self.stride = stride
268 |
269 | def forward(self, x):
270 | identity = x
271 |
272 | out = self.conv1(x)
273 | out = self.bn1(out)
274 | out = self.relu(out)
275 | out = self.conv2(out)
276 | out = self.bn2(out)
277 |
278 | if self.downsample is not None:
279 | identity = self.downsample(x)
280 |
281 | out += identity
282 | out = self.relu(out)
283 |
284 | return out
285 |
286 |
287 | class Bottleneck(nn.Module):
288 | expansion = 4
289 |
290 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
291 | base_width=64, dilation=1, norm_layer=None):
292 | super(Bottleneck, self).__init__()
293 | if norm_layer is None:
294 | norm_layer = nn.BatchNorm3d
295 | width = int(planes * (base_width / 64.)) * groups
296 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
297 | self.conv1 = conv1x1(inplanes, width)
298 | self.bn1 = norm_layer(width)
299 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
300 | self.bn2 = norm_layer(width)
301 | self.conv3 = conv1x1(width, planes * self.expansion)
302 | self.bn3 = norm_layer(planes * self.expansion)
303 | self.relu = nn.ReLU(inplace=True)
304 | self.downsample = downsample
305 | self.stride = stride
306 |
307 | def forward(self, x):
308 | identity = x
309 |
310 | out = self.conv1(x)
311 | out = self.bn1(out)
312 | out = self.relu(out)
313 |
314 | out = self.conv2(out)
315 | out = self.bn2(out)
316 | out = self.relu(out)
317 |
318 | out = self.conv3(out)
319 | out = self.bn3(out)
320 |
321 | if self.downsample is not None:
322 | identity = self.downsample(x)
323 |
324 | out += identity
325 | out = self.relu(out)
326 |
327 | return out
328 |
329 | class ResNet(nn.Module):
330 | """
331 | Standard 3D-ResNet architecture with a big initial 7x7x7 kernel.
332 | It can be used in "classifier" mode, outputting a vector of class scores,
333 | or in "encoder" mode, outputting a latent vector of size 512 (independent of input size).
334 | Note: only a last FC layer is added on top of the "encoder" backbone.
335 | """
336 | def __init__(self, block, layers, in_channels=1,
337 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
338 | norm_layer=None, initial_kernel_size=7):
339 | super(ResNet, self).__init__()
340 |
341 | if norm_layer is None:
342 | norm_layer = nn.BatchNorm3d
343 | self._norm_layer = norm_layer
344 |
345 | self.name = "resnet"
346 | self.inputs = None
347 | self.inplanes = 64
348 | self.dilation = 1
349 |
350 | if replace_stride_with_dilation is None:
351 | # each element in the tuple indicates if we should replace
352 | # the 2x2 stride with a dilated convolution instead
353 | replace_stride_with_dilation = [False, False, False]
354 | if len(replace_stride_with_dilation) != 3:
355 | raise ValueError("replace_stride_with_dilation should be None "
356 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
357 | self.groups = groups
358 | self.base_width = width_per_group
359 | initial_stride = 2 if initial_kernel_size==7 else 1
360 | padding = (initial_kernel_size-initial_stride+1)//2
361 | self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=initial_kernel_size, stride=initial_stride, padding=padding, bias=False)
362 | self.bn1 = norm_layer(self.inplanes)
363 | self.relu = nn.ReLU(inplace=True)
364 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
365 |
366 | channels = [64, 128, 256, 512]
367 |
368 | self.layer1 = self._make_layer(block, channels[0], layers[0])
369 | self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0])
370 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, dilate=replace_stride_with_dilation[1])
371 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2])
372 | self.avgpool = nn.AdaptiveAvgPool3d(1)
373 |
374 | for m in self.modules():
375 | if isinstance(m, nn.Conv3d):
376 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
377 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
378 | nn.init.constant_(m.weight, 1)
379 | nn.init.constant_(m.bias, 0)
380 | elif isinstance(m, nn.Linear):
381 | nn.init.normal_(m.weight, 0, 0.01)
382 | if m.bias is not None:
383 | nn.init.constant_(m.bias, 0)
384 |
385 | # Zero-initialize the last BN in each residual branch,
386 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
387 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
388 | if zero_init_residual:
389 | for m in self.modules():
390 | if isinstance(m, Bottleneck):
391 | nn.init.constant_(m.bn3.weight, 0)
392 | elif isinstance(m, BasicBlock):
393 | nn.init.constant_(m.bn2.weight, 0)
394 |
395 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
396 | norm_layer = self._norm_layer
397 | downsample = None
398 | previous_dilation = self.dilation
399 | if dilate:
400 | self.dilation *= stride
401 | stride = 1
402 | if stride != 1 or self.inplanes != planes * block.expansion:
403 | downsample = nn.Sequential(
404 | conv1x1(self.inplanes, planes * block.expansion, stride),
405 | norm_layer(planes * block.expansion),
406 | )
407 |
408 | layers = []
409 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
410 | base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer))
411 | self.inplanes = planes * block.expansion
412 | for _ in range(1, blocks):
413 | layers.append(block(self.inplanes, planes, groups=self.groups,
414 | base_width=self.base_width, dilation=self.dilation,
415 | norm_layer=norm_layer))
416 |
417 | return nn.Sequential(*layers)
418 |
419 | def forward(self, x):
420 | x = self.conv1(x)
421 | x = self.bn1(x)
422 | x = self.relu(x)
423 | x = self.maxpool(x)
424 |
425 | x1 = self.layer1(x)
426 | x2 = self.layer2(x1)
427 | x3 = self.layer3(x2)
428 | x4 = self.layer4(x3)
429 |
430 | x5 = self.avgpool(x4)
431 | return torch.flatten(x5, 1)
432 |
433 | def resnet18(**kwargs):
434 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
435 |
436 | def resnet34(**kwargs):
437 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
438 |
439 | def resnet50(**kwargs):
440 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
441 |
442 | def resnet101(**kwargs):
443 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
444 |
445 | model_dict = {
446 | 'resnet18': [resnet18, 512],
447 | 'resnet34': [resnet34, 512],
448 | 'resnet50': [resnet50, 2048],
449 | 'resnet101': [resnet101, 2048],
450 | }
451 |
452 | class SupRegResNet(nn.Module):
453 | """encoder + regressor"""
454 | def __init__(self, name='resnet50'):
455 | super().__init__()
456 | model_fun, dim_in = model_dict[name]
457 | self.encoder = model_fun()
458 | self.fc = nn.Linear(dim_in, 1)
459 |
460 | def forward(self, x):
461 | return self.encoder(x)
462 | # return self.fc(self.encoder(x))
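# NOTE: the regression head (self.fc) is deliberately bypassed and the raw
# encoder features are returned, presumably so the challenge pipeline can
# fit its own age/site predictors on top of them (inference from the
# commented-out line above, not documented in the source).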
463 |
464 | class AlexNet3D(nn.Module):
465 | def __init__(self):
466 | """
467 | 3D AlexNet-style encoder: the forward pass returns a 128-d feature
468 | vector (no classification head is attached here).
469 | """
470 | super().__init__()
471 | self.features = nn.Sequential(
472 | nn.Conv3d(1, 64, kernel_size=5, stride=2, padding=0),
473 | nn.BatchNorm3d(64),
474 | nn.ReLU(inplace=True),
475 | nn.MaxPool3d(kernel_size=3, stride=3),
476 |
477 | nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=0),
478 | nn.BatchNorm3d(128),
479 | nn.ReLU(inplace=True),
480 | nn.MaxPool3d(kernel_size=3, stride=3),
481 |
482 | nn.Conv3d(128, 192, kernel_size=3, padding=1),
483 | nn.BatchNorm3d(192),
484 | nn.ReLU(inplace=True),
485 |
486 | nn.Conv3d(192, 192, kernel_size=3, padding=1),
487 | nn.BatchNorm3d(192),
488 | nn.ReLU(inplace=True),
489 |
490 | nn.Conv3d(192, 128, kernel_size=3, padding=1),
491 | nn.BatchNorm3d(128),
492 | nn.ReLU(inplace=True),
493 | nn.AdaptiveMaxPool3d(1),
494 | )
495 |
496 |
497 | for m in self.modules():
498 | if isinstance(m, nn.Conv3d):  # was nn.Conv2d, which never matches in this 3D network
499 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
500 | m.weight.data.normal_(0, math.sqrt(2. / n))
501 | elif isinstance(m, nn.BatchNorm3d):
502 | m.weight.data.fill_(1)
503 | m.bias.data.zero_()
504 |
505 | def forward(self, x):
506 | xp = self.features(x)
507 | x = xp.view(xp.size(0), -1)
508 | return x
509 |
510 | class SupRegAlexNet(nn.Module):
511 | """encoder + regressor"""
512 | def __init__(self,):
513 | super().__init__()
514 | self.encoder = AlexNet3D()
515 | self.fc = nn.Linear(128, 1)
516 |
517 | def forward(self, x):
518 | feats = self.features(x)
519 | return feats
520 | # return self.fc(feats), feats
521 |
522 | def features(self, x):
523 | return self.encoder(x)
524 |
525 | class DenseNet(nn.Module):
526 | """3D-Densenet-BC model class, based on
527 | `"Densely Connected Convolutional Networks" `_
528 | Args:
529 | growth_rate (int) - how many filters to add each layer (`k` in paper)
530 | block_config (list of 4 ints) - how many layers in each pooling block
531 | num_init_features (int) - the number of filters to learn in the first convolution layer
532 | mode (str) - "classifier" or "encoder" (all but last FC layer)
533 | bn_size (int) - multiplicative factor for number of bottle neck layers
534 | (i.e. bn_size * k features in the bottleneck layer)
535 | num_classes (int) - number of classification classes
536 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
537 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
538 | """
539 |
540 | def __init__(self, growth_rate=32, block_config=(3, 12, 24, 16),
541 | num_init_features=64,
542 | bn_size=4, in_channels=1,
543 | memory_efficient=False):
544 | super(DenseNet, self).__init__()
545 | # First convolution
546 | self.features = nn.Sequential(OrderedDict([
547 | ('conv0', nn.Conv3d(in_channels, num_init_features,
548 | kernel_size=7, stride=2, padding=3, bias=False)),
549 | ('norm0', nn.BatchNorm3d(num_init_features)),
550 | ('relu0', nn.ReLU(inplace=True)),
551 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)),
552 | ]))
553 |
554 | # Each denseblock
555 | num_features = num_init_features
556 | for i, num_layers in enumerate(block_config):
557 | block = _DenseBlock(
558 | num_layers=num_layers,
559 | num_input_features=num_features,
560 | bn_size=bn_size,
561 | growth_rate=growth_rate,
562 | memory_efficient=memory_efficient
563 | )
564 | self.features.add_module('denseblock%d' % (i + 1), block)
565 | num_features = num_features + num_layers * growth_rate
566 | if i != len(block_config) - 1:
567 | trans = _Transition(num_input_features=num_features,
568 | num_output_features=num_features // 2)
569 | self.features.add_module('transition%d' % (i + 1), trans)
570 | num_features = num_features // 2
571 |
572 | self.num_features = num_features
573 |
574 |
575 | # Official init from torch repo.
576 | for m in self.modules():
577 | if isinstance(m, nn.Conv3d):
578 | nn.init.kaiming_normal_(m.weight)
579 | elif isinstance(m, nn.BatchNorm3d):
580 | nn.init.constant_(m.weight, 1)
581 | nn.init.constant_(m.bias, 0)
582 | elif isinstance(m, nn.Linear):
583 | nn.init.constant_(m.bias, 0)
584 |
585 | def forward(self, x):
586 | features = self.features(x)
587 | out = F.adaptive_avg_pool3d(features, 1)
588 | out = torch.flatten(out, 1)
589 | return out.squeeze(dim=1)
590 |
591 |
592 | def _bn_function_factory(norm, relu, conv):
593 | def bn_function(*inputs):
594 | concated_features = torch.cat(inputs, 1)
595 | bottleneck_output = conv(relu(norm(concated_features)))
596 | return bottleneck_output
597 |
598 | return bn_function
599 |
600 |
601 | class _DenseLayer(nn.Sequential):
602 | def __init__(self, num_input_features, growth_rate, bn_size, memory_efficient=False):
603 | super(_DenseLayer, self).__init__()
604 | self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
605 | self.add_module('relu1', nn.ReLU(inplace=True)),
606 | self.add_module('conv1', nn.Conv3d(num_input_features, bn_size *
607 | growth_rate, kernel_size=1, stride=1,
608 | bias=False)),
609 | self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
610 | self.add_module('relu2', nn.ReLU(inplace=True)),
611 | self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
612 | kernel_size=3, stride=1, padding=1,
613 | bias=False)),
614 | self.memory_efficient = memory_efficient
615 |
616 | def forward(self, *prev_features):
617 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
618 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
619 | bottleneck_output = cp.checkpoint(bn_function, *prev_features)
620 | else:
621 | bottleneck_output = bn_function(*prev_features)
622 |
623 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
624 |
625 | return new_features
626 |
627 |
628 | class _DenseBlock(nn.Module):
629 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, memory_efficient=False):
630 | super(_DenseBlock, self).__init__()
631 | for i in range(num_layers):
632 | layer = _DenseLayer(
633 | num_input_features + i * growth_rate,
634 | growth_rate=growth_rate,
635 | bn_size=bn_size,
636 | memory_efficient=memory_efficient,
637 | )
638 | self.add_module('denselayer%d' % (i + 1), layer)
639 |
640 | def forward(self, init_features):
641 | features = [init_features]
642 | for name, layer in self.named_children():
643 | new_features = layer(*features)
644 | features.append(new_features)
645 | return torch.cat(features, 1)
646 |
647 |
648 | class _Transition(nn.Sequential):
649 | def __init__(self, num_input_features, num_output_features):
650 | super(_Transition, self).__init__()
651 | self.add_module('norm', nn.BatchNorm3d(num_input_features))
652 | self.add_module('relu', nn.ReLU(inplace=True))
653 | self.add_module('conv', nn.Conv3d(num_input_features, num_output_features,
654 | kernel_size=1, stride=1, bias=False))
655 | self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
656 |
657 |
658 | def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs):
659 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
660 | return model
661 |
662 |
663 | def densenet121(**kwargs):
664 | r"""Densenet-121 model from
665 | `"Densely Connected Convolutional Networks" `_
666 |
667 | Args:
668 | pretrained (bool): If True, returns a model pre-trained on ImageNet
669 | progress (bool): If True, displays a progress bar of the download to stderr
670 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
671 | but slower. Default: *False*. See `"paper" `_
672 | """
673 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs)
674 |
675 | class SupRegDenseNet(nn.Module):
676 | """encoder + regressor"""
677 |     def __init__(self):
678 | super().__init__()
679 | self.encoder = densenet121()
680 | self.fc = nn.Linear(self.encoder.num_features, 1)
681 |
682 | def forward(self, x):
683 | feats = self.features(x)
684 | return feats
685 | # return self.fc(feats), feats
686 |
687 | def features(self, x):
688 | return self.encoder(x)
689 |
690 | class RegressionModel(metaclass=ABCMeta):
691 | __model_local_weights__ = os.path.join(os.path.dirname(__file__), os.environ.get("MODEL", "weights.pth"))
692 | __metadata_local_weights__ = os.path.join(os.path.dirname(__file__), "metadata.pkl")
693 |
694 | def __init__(self, model, batch_size=15, transforms=None):
695 | self.model = model
696 | self.batch_size = batch_size
697 | self.transforms = transforms
698 | self.indices = None
699 |
700 | def fit(self, X, y):
701 | """ Restore weights.
702 | """
703 | if not os.path.isfile(self.__model_local_weights__):
704 |             raise ValueError("You must provide the model weights in your submission folder.")
705 | state = torch.load(self.__model_local_weights__, map_location="cpu")
706 |
707 | if "model" not in state:
708 |             raise ValueError("Model weights are searched in the state dictionary at the 'model' key location.")
709 | self.model.load_state_dict(state["model"], strict=True)
710 |
711 | def predict(self, X: np.ndarray) -> np.ndarray:
712 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
713 | self.model.to(device)
714 |
715 | dataset = Dataset(X, transforms=self.transforms, indices=self.indices)
716 | testloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
717 |
718 | self.model.eval()
719 | outputs = []
720 |
721 | with progressbar.ProgressBar(max_value=len(testloader)) as bar:
722 | for cnt, inputs in enumerate(testloader):
723 | inputs = inputs.float().to(device)
724 | # print("Batch size", inputs.shape)
725 | with torch.no_grad():
726 | out = self.model(inputs)
727 | # out = torch.randn((inputs.shape[0], 128))
728 |
729 | outputs.append(out.detach())
730 | bar.update(cnt)
731 |
732 | outputs = torch.cat(outputs, dim=0)
733 | return outputs.detach().cpu().numpy()
734 |
735 |
736 | ############################################################################
737 | # Define here your estimator pipeline
738 | ############################################################################
739 |
740 | def get_estimator(mock=False) -> Pipeline:
741 |     """ Build your estimator here.
742 |     Notes
743 |     -----
744 |     In order to minimize the memory load, the first steps of the pipeline
745 |     are applied directly as transforms attached to the Torch Dataset.
746 |     It is recommended to create an instance of sklearn.pipeline.Pipeline.
747 |     """
750 | if "resnet" in ARCHITECTURE:
751 | net = SupRegResNet(ARCHITECTURE)
752 | elif ARCHITECTURE == "alexnet":
753 | net = SupRegAlexNet()
754 | elif "densenet" in ARCHITECTURE:
755 | net = SupRegDenseNet()
756 |     else: raise ValueError(f"Unknown architecture: {ARCHITECTURE}")
757 | selector = FeatureExtractor("vbm", mock=mock)
758 | preproc = transforms.Compose([
759 | transforms.Lambda(lambda x: selector.transform(x)),
760 | # Crop((1, 121, 128, 121), type="center"),
761 | # Pad((1, 128, 128, 128)),
762 | transforms.Lambda(lambda x: torch.from_numpy(x).float()),
763 | transforms.Normalize(mean=0.0, std=1.0),
764 | ])
765 | estimator = make_pipeline(
766 | RegressionModel(net, transforms=preproc))
767 | return estimator
768 |
769 |
770 | if __name__ == '__main__':
771 | estimator = get_estimator(mock=True).fit(None)
772 | estimator.predict(np.random.random((32, 2122945)))
773 |
--------------------------------------------------------------------------------
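
Note — a minimal usage sketch for the submission estimator above, not part of the repository. It assumes `weights.pth` sits next to the estimator file (as `fit` requires) and that the `Dataset` class and `ARCHITECTURE` global defined earlier in this file are in scope. `fit` only restores the checkpoint (no training happens), and `predict` returns encoder embeddings rather than scalar ages, since the `self.fc(feats)` head is commented out in `SupRegDenseNet.forward`.

```python
# Hypothetical driver, mirroring the __main__ block above (mock=True skips
# the nilearn masks and just reshapes the flat input to the VBM volume).
import numpy as np

estimator = get_estimator(mock=True)
estimator.fit(None)                      # restores weights.pth, no training
X = np.random.random((4, 2122945))       # 2122945 = 1*121*145*121 (VBM shape)
feats = estimator.predict(X)             # (4, n_features) encoder embeddings
print(feats.shape)
```
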
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .openbhb import FeatureExtractor, OpenBHB, bin_age
--------------------------------------------------------------------------------
/src/data/masks/cat12vbm_space-MNI152_desc-gm_TPM.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/src/data/masks/cat12vbm_space-MNI152_desc-gm_TPM.nii.gz
--------------------------------------------------------------------------------
/src/data/masks/quasiraw_space-MNI152_desc-brain_T1w.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/src/data/masks/quasiraw_space-MNI152_desc-brain_T1w.nii.gz
--------------------------------------------------------------------------------
/src/data/openbhb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import nibabel
4 | import torch
5 | import pandas as pd
6 | from sklearn.base import BaseEstimator
7 | from sklearn.base import TransformerMixin
8 | from collections import OrderedDict
9 | from nilearn.masking import unmask
10 |
11 | def bin_age(age_real: torch.Tensor):
12 | bins = [i for i in range(4, 92, 2)]
13 | age_binned = age_real.clone()
14 | for value in bins[::-1]:
15 | age_binned[age_real <= value] = value
16 | return age_binned.long()
17 |
18 | def read_data(path, dataset, fast):
19 | print(f"Read {dataset.upper()}")
20 | df = pd.read_csv(os.path.join(path, dataset + ".tsv"), sep="\t")
21 | df.loc[df["split"] == "external_test", "site"] = np.nan
22 |
23 | y_arr = df[["age", "site"]].values
24 |
25 | x_arr = np.zeros((10, 3659572))
26 | if not fast:
27 | x_arr = np.load(os.path.join(path, dataset + ".npy"), mmap_mode="r")
28 |
29 | print("- y size [original]:", y_arr.shape)
30 | print("- x size [original]:", x_arr.shape)
31 | return x_arr, y_arr
32 |
33 | class OpenBHB(torch.utils.data.Dataset):
34 | def __init__(self, root, train=True, internal=True, transform=None,
35 | label="cont", fast=False, load_feats=None):
36 | self.root = root
37 |
38 | if train and not internal:
39 | raise ValueError("Invalid configuration train=True and internal=False")
40 |
41 | self.train = train
42 | self.internal = internal
43 |
44 | dataset = "train"
45 | if not train:
46 | if internal:
47 | dataset = "internal_test"
48 | else:
49 | dataset = "external_test"
50 |
51 | self.X, self.y = read_data(root, dataset, fast)
52 | self.T = transform
53 | self.label = label
54 | self.fast = fast
55 |
56 | self.bias_feats = None
57 | if load_feats:
58 | print("Loading biased features", load_feats)
59 | self.bias_feats = torch.load(load_feats, map_location="cpu")
60 |
61 | print(f"Read {len(self.X)} records")
62 |
63 | def __len__(self):
64 | return len(self.y)
65 |
66 | def __getitem__(self, index):
67 | if not self.fast:
68 | x = self.X[index]
69 | else:
70 | x = self.X[0]
71 |
72 | y = self.y[index]
73 |
74 | if self.T is not None:
75 | x = self.T(x)
76 |
77 | # sample, age, site
78 | age, site = y[0], y[1]
79 | if self.label == "bin":
80 | age = bin_age(torch.tensor(age))
81 |
82 | if self.bias_feats is not None:
83 | return x, age, self.bias_feats[index]
84 | else:
85 | return x, age, site
86 |
87 | class FeatureExtractor(BaseEstimator, TransformerMixin):
88 |     """ Select only the requested data associated features from the
89 |     input buffered data.
90 | """
91 | MODALITIES = OrderedDict([
92 | ("vbm", {
93 | "shape": (1, 121, 145, 121),
94 | "size": 519945}),
95 | ("quasiraw", {
96 | "shape": (1, 182, 218, 182),
97 | "size": 1827095}),
98 | ("xhemi", {
99 | "shape": (8, 163842),
100 | "size": 1310736}),
101 | ("vbm_roi", {
102 | "shape": (1, 284),
103 | "size": 284}),
104 | ("desikan_roi", {
105 | "shape": (7, 68),
106 | "size": 476}),
107 | ("destrieux_roi", {
108 | "shape": (7, 148),
109 | "size": 1036})
110 | ])
111 | MASKS = {
112 | "vbm": {
113 | "path": None,
114 | "thr": 0.05},
115 | "quasiraw": {
116 | "path": None,
117 | "thr": 0}
118 | }
119 |
120 | def __init__(self, dtype, mock=False):
121 | """ Init class.
122 | Parameters
123 | ----------
124 | dtype: str
125 | the requested data: 'vbm', 'quasiraw', 'vbm_roi', 'desikan_roi',
126 | 'destrieux_roi' or 'xhemi'.
127 | """
128 | if dtype not in self.MODALITIES:
129 | raise ValueError("Invalid input data type.")
130 | self.dtype = dtype
131 |
132 | data_types = list(self.MODALITIES.keys())
133 | index = data_types.index(dtype)
134 |
135 | cumsum = np.cumsum([item["size"] for item in self.MODALITIES.values()])
136 |
137 | if index > 0:
138 | self.start = cumsum[index - 1]
139 | else:
140 | self.start = 0
141 | self.stop = cumsum[index]
142 |
143 | self.masks = dict((key, val["path"]) for key, val in self.MASKS.items())
144 | self.masks["vbm"] = "./data/masks/cat12vbm_space-MNI152_desc-gm_TPM.nii.gz"
145 | self.masks["quasiraw"] = "./data/masks/quasiraw_space-MNI152_desc-brain_T1w.nii.gz"
146 |
147 | self.mock = mock
148 | if mock:
149 | return
150 |
151 | for key in self.masks:
152 | if self.masks[key] is None or not os.path.isfile(self.masks[key]):
153 | raise ValueError("Impossible to find mask:", key, self.masks[key])
154 | arr = nibabel.load(self.masks[key]).get_fdata()
155 | thr = self.MASKS[key]["thr"]
156 | arr[arr <= thr] = 0
157 | arr[arr > thr] = 1
158 | self.masks[key] = nibabel.Nifti1Image(arr.astype(int), np.eye(4))
159 |
160 | def fit(self, X, y):
161 | return self
162 |
163 | def transform(self, X):
164 | if self.mock:
165 | #print("transforming", X.shape)
166 | data = X.reshape(self.MODALITIES[self.dtype]["shape"])
167 | #print("mock data:", data.shape)
168 | return data
169 |
170 | # print(X.shape)
171 | select_X = X[self.start:self.stop]
172 | if self.dtype in ("vbm", "quasiraw"):
173 | im = unmask(select_X, self.masks[self.dtype])
174 | select_X = im.get_fdata()
175 | select_X = select_X.transpose(2, 0, 1)
176 | select_X = select_X.reshape(self.MODALITIES[self.dtype]["shape"])
177 | # print('transformed.shape', select_X.shape)
178 | return select_X
179 |
180 |
181 | if __name__ == '__main__':
182 | import sys
183 | from torchvision import transforms
184 | from .transforms import Crop, Pad
185 |
186 | selector = FeatureExtractor("vbm")
187 |
188 | T_pre = transforms.Lambda(lambda x: selector.transform(x))
189 | T_train = transforms.Compose([
190 | T_pre,
191 | Crop((1, 121, 128, 121), type="random"),
192 | Pad((1, 128, 128, 128)),
193 | transforms.Lambda(lambda x: torch.from_numpy(x)),
194 | transforms.Normalize(mean=0.0, std=1.0)
195 | ])
196 |
197 | train_loader = torch.utils.data.DataLoader(OpenBHB(sys.argv[1], train=True, internal=True, transform=T_train),
198 | batch_size=3, shuffle=True, num_workers=8,
199 | persistent_workers=True)
200 |
201 | x, y1, y2 = next(iter(train_loader))
202 | print(x.shape, y1, y2)
--------------------------------------------------------------------------------
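
Note — the modalities in `FeatureExtractor.MODALITIES` are concatenated into one flat vector per subject, and a modality is recovered by slicing `[start, stop)` from the cumulative sizes, exactly as computed in `__init__` above. A small sketch of that arithmetic and of `bin_age` (run from the `src` directory with the repo's dependencies installed, so the `data` package resolves):

```python
# Offsets implied by MODALITIES: each modality occupies [start, stop) in the
# flat buffer, with boundaries given by the cumulative sum of the sizes.
import numpy as np
import torch
from data import bin_age

sizes = [519945, 1827095, 1310736, 284, 476, 1036]   # vbm ... destrieux_roi
bounds = np.cumsum(sizes)
print(bounds[0], bounds[1])    # "quasiraw" lives at [519945, 2347040)

# bin_age snaps continuous ages onto the 2-year grid 4, 6, ..., 90:
print(bin_age(torch.tensor([3.0, 5.9, 6.0, 41.2])))  # tensor([ 4,  6,  6, 42])
```
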
/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import operator
3 | import random
4 | import torch
5 |
6 | class Crop(object):
7 | """ Crop the given n-dimensional array either at a random location or
8 | centered.
9 | """
10 | def __init__(self, shape, type="center", keep_dim=False):
11 | assert type in ["center", "random"]
12 | self.shape = shape
13 | self.cropping_type = type
14 | self.keep_dim = keep_dim
15 |
16 | def slow_crop(self, X):
17 | img_shape = np.array(X.shape)
18 |
19 |         if type(self.shape) == int:
20 |             size = [self.shape for _ in range(len(img_shape))]
21 | else:
22 | size = np.copy(self.shape)
23 |
24 | # print('img_shape:', img_shape, 'size', size)
25 |
26 | indexes = []
27 | for ndim in range(len(img_shape)):
28 | if size[ndim] > img_shape[ndim] or size[ndim] < 0:
29 | size[ndim] = img_shape[ndim]
30 |
31 | if self.cropping_type == "center":
32 | delta_before = int((img_shape[ndim] - size[ndim]) / 2.0)
33 |
34 | elif self.cropping_type == "random":
35 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1)
36 |
37 | indexes.append(slice(delta_before, delta_before + size[ndim]))
38 |
39 | if self.keep_dim:
40 |             mask = np.zeros(img_shape, dtype=bool)
41 | mask[tuple(indexes)] = True
42 | arr_copy = X.copy()
43 | arr_copy[~mask] = 0
44 | return arr_copy
45 |
46 | _X = X[tuple(indexes)]
47 | # print('cropped.shape', _X.shape)
48 | return _X
49 |
50 | def fast_crop(self, X):
51 | # X is a single image (CxWxHxZ)
52 | shape = X.shape
53 |
54 | delta = [shape[1]-self.shape[1],
55 | shape[2]-self.shape[2],
56 | shape[3]-self.shape[3]]
57 |
58 | if self.cropping_type == "center":
59 | offset = list(map(operator.floordiv, delta, [2]*len(delta)))
60 | X = X[:, offset[0]:offset[0]+self.shape[1],
61 | offset[1]:offset[1]+self.shape[2],
62 | offset[2]:offset[2]+self.shape[3]]
63 |
64 | elif self.cropping_type == "random":
65 | offset = [
66 | int(random.random()*128) % (delta[0]+1),
67 | int(random.random()*128) % (delta[1]+1),
68 | int(random.random()*128) % (delta[2]+1)
69 | ]
70 | X = X[:, offset[0]:offset[0]+self.shape[1],
71 | offset[1]:offset[1]+self.shape[2],
72 | offset[2]:offset[2]+self.shape[3]]
73 | else:
74 | raise ValueError("Invalid cropping_type", self.cropping_type)
75 |
76 | return X
77 |
78 | def __call__(self, X):
79 | return self.fast_crop(X)
80 |
81 | class Cutout(object):
82 | """Apply a cutout on the images
83 | cf. Improved Regularization of Convolutional Neural Networks with Cutout, arXiv, 2017
84 | We assume that the square to be cut is inside the image.
85 | """
86 | def __init__(self, patch_size=None, value=0, random_size=False, inplace=False, localization=None, probability=0.5):
87 | self.patch_size = patch_size
88 | self.value = value
89 | self.random_size = random_size
90 | self.inplace = inplace
91 | self.localization = localization
92 | self.probability = probability
93 |
94 | def __call__(self, arr):
95 | if np.random.rand() >= self.probability:
96 | return arr
97 |
98 | img_shape = np.array(arr.shape)
99 | if type(self.patch_size) == int:
100 | size = [self.patch_size for _ in range(len(img_shape))]
101 | else:
102 | size = np.copy(self.patch_size)
103 | assert len(size) == len(img_shape), "Incorrect patch dimension."
104 | indexes = []
105 | for ndim in range(len(img_shape)):
106 | if size[ndim] > img_shape[ndim] or size[ndim] < 0:
107 | size[ndim] = img_shape[ndim]
108 | if self.random_size:
109 | size[ndim] = np.random.randint(0, size[ndim])
110 | if self.localization is not None:
111 | delta_before = max(self.localization[ndim] - size[ndim]//2, 0)
112 | else:
113 | delta_before = np.random.randint(0, img_shape[ndim] - size[ndim] + 1)
114 | indexes.append(slice(int(delta_before), int(delta_before + size[ndim])))
115 | if self.inplace:
116 | arr[tuple(indexes)] = self.value
117 | return arr
118 | else:
119 | arr_cut = np.copy(arr)
120 | arr_cut[tuple(indexes)] = self.value
121 | return arr_cut
122 |
123 | class Pad(object):
124 | """ Pad the given n-dimensional array
125 | """
126 | def __init__(self, shape, **kwargs):
127 | self.shape = shape
128 | self.kwargs = kwargs
129 |
130 | def __call__(self, X):
131 | _X = self._apply_padding(X)
132 | return _X
133 |
134 | def _apply_padding(self, arr):
135 | orig_shape = arr.shape
136 | padding = []
137 | for orig_i, final_i in zip(orig_shape, self.shape):
138 | shape_i = final_i - orig_i
139 | half_shape_i = shape_i // 2
140 | if shape_i % 2 == 0:
141 | padding.append([half_shape_i, half_shape_i])
142 | else:
143 | padding.append([half_shape_i, half_shape_i + 1])
144 | for cnt in range(len(arr.shape) - len(padding)):
145 | padding.append([0, 0])
146 | fill_arr = np.pad(arr, padding, **self.kwargs)
147 | return fill_arr
148 |
149 |
150 | if __name__ == '__main__':
151 | import timeit
152 | x = np.random.rand(1, 128, 128, 128)
153 |
154 | cut = Cutout((1, 10, 10, 10), probability=1.)
155 | print(cut(x).shape)
156 |
157 | crop = Crop((1, 121, 128, 121), type="center")
158 | print(crop(x).shape)
159 |
160 | crop = Crop((1, 121, 128, 121), type="random")
161 | print(crop(x).shape)
162 |
163 | print("slow crop:", timeit.timeit(lambda: crop.slow_crop(x), number=10000))
164 | print("fast crop:", timeit.timeit(lambda: crop.fast_crop(x), number=10000))
--------------------------------------------------------------------------------
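
Note — a quick shape check (not in the repo) of the crop-then-pad recipe that `get_transforms` in `main_mse.py` applies to VBM volumes of shape (1, 121, 145, 121); the import assumes you run from the `src` directory:

```python
import numpy as np
from data.transforms import Crop, Pad

x = np.random.rand(1, 121, 145, 121)            # dummy VBM volume (C, X, Y, Z)
x = Crop((1, 121, 128, 121), type="random")(x)  # random crop along Y only
x = Pad((1, 128, 128, 128))(x)                  # zero-pad X and Z up to 128
print(x.shape)                                  # (1, 128, 128, 128)
```
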
/src/exp/mae.yaml:
--------------------------------------------------------------------------------
1 | program: main_mse.py
2 | data_dir: /scratch/data-registry/medical/openbhb
3 | save_dir: /scratch/output/brain-age-mri
4 | model: resnet18
5 | epochs: 300
6 | batch_size: 32
7 | lr: 1e-4
8 | lr_decay: step
9 | lr_decay_rate: 0.9
10 | lr_decay_step: 10
11 | optimizer: adam
12 | momentum: 0.9
13 | weight_decay: 5e-5
14 | train_all: 1
15 | trial: 0
16 | tf: none
--------------------------------------------------------------------------------
/src/exp/supcon_adam_kernel.yaml:
--------------------------------------------------------------------------------
1 | program: main_infonce.py
2 | data_dir: /scratch/data-registry/medical/openbhb
3 | save_dir: /scratch/output/brain-age-mri
4 | model: resnet18
5 | epochs: 300
6 | batch_size: 32
7 | lr: 1e-4
8 | lr_decay: step
9 | lr_decay_rate: 0.9
10 | lr_decay_step: 10
11 | optimizer: adam
12 | momentum: 0.9
13 | weight_decay: 5e-5
14 | train_all: 1
15 | method: yaware
16 | kernel: gaussian
17 | sigma: 1
18 | trial: 0
19 | tf: none
20 |
21 |
--------------------------------------------------------------------------------
/src/exp/supcon_sgd_kernel.yaml:
--------------------------------------------------------------------------------
1 | program: main_infonce.py
2 | data_dir: /scratch/data-registry/medical/openbhb
3 | save_dir: /scratch/output/brain-age-mri
4 | model: resnet18
5 | epochs: 300
6 | batch_size: 32
7 | lr: 0.1
8 | lr_decay: cosine
9 | optimizer: sgd
10 | momentum: 0.9
11 | weight_decay: 1e-4
12 | kernel: gaussian
13 | sigma: 1
14 | trial: 0
15 | tf: none
--------------------------------------------------------------------------------
/src/figures/ablation.csv:
--------------------------------------------------------------------------------
1 | "Name","ramp/score","ramp/bacc","Created","Runtime","End Time","sigma","Hostname","ID","Notes","Updated","Tags","kernel","method","clip_grad","ramp/bacc_std","ramp/ext_mae","ramp/ext_mae_std","ramp/int_mae","ramp/int_mae_std"
2 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","3.025333","9.2667","2022-10-28T15:04:30.000Z","495350","2022-11-03T08:40:20.000Z","2","NBDOTTI61","3jm0a5ld","-","2022-11-03T08:40:20.000Z","tested","cauchy","threshold","","1.124","6.183","0.038105","3.477333","0.021385"
3 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.848667","5","2022-10-28T15:04:30.000Z","517975","2022-11-03T14:57:25.000Z","2","NBDOTTI61","i3qmz1yh","-","2022-11-03T14:57:25.000Z","tested","cauchy","expw","","0.1","4.547","0.019157","2.666","0.002646"
4 | "resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.815333","6.6","2022-10-28T15:04:30.000Z","79443","2022-10-29T13:08:33.000Z","2","NBDOTTI61","ooz7z2yo","-","2022-10-29T13:08:33.000Z","tested","rbf","yaware","","0.1732","4.102","0.009539","2.664667","0.002082"
5 | "resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.482333","8.1667","2022-10-28T15:04:28.000Z","461295","2022-11-02T23:12:43.000Z","2","NBDOTTI61","r07q08n4","-","2022-11-02T23:12:43.000Z","tested","cauchy","yaware","","0.6658","5.267333","0.004509","3.088333","0.019425"
6 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.539667","5.1","2022-10-29T06:27:10.000Z","453260","2022-11-03T12:21:30.000Z","2","NBDOTTI61","w0ci971l","-","2022-11-03T12:21:30.000Z","tested","rbf","expw","","0.1","3.761","0.005","2.552","0.002"
7 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma2.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.738333","5.7333","2022-10-29T06:24:22.000Z","440525","2022-11-03T08:46:27.000Z","2","NBDOTTI61","ykx8p7pc","-","2022-11-03T08:46:27.000Z","tested","rbf","threshold","","0.1528","4.098333","0.009504","2.947","0.004359"
8 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.627","8.5","2022-10-29T05:51:32.000Z","472252","2022-11-03T17:02:24.000Z","1","NBDOTTI61","alp80xna","-","2022-11-03T17:02:24.000Z","tested","rbf","threshold","","1.0149","5.508333","0.019858","3.042667","0.011676"
9 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.816","4.8667","2022-10-28T15:04:30.000Z","536997","2022-11-03T20:14:27.000Z","1","NBDOTTI61","anck0dqb","-","2022-11-03T20:14:27.000Z","tested","cauchy","expw","","0.2082","4.502667","0.010408","2.731","0.006083"
10 | "resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.428","6.6333","2022-10-28T15:04:30.000Z","180849","2022-10-30T17:18:39.000Z","1","NBDOTTI61","apyqr2sz","-","2022-10-30T17:18:39.000Z","tested","rbf","yaware","","0.8505","5.492667","0.05208","2.850333","0.006807"
11 | "resnet18_threshold_reduction_sum_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.282333","8.5333","2022-10-28T15:04:29.000Z","534546","2022-11-03T19:33:35.000Z","1","NBDOTTI61","bsecf64m","-","2022-11-03T19:33:35.000Z","tested","cauchy","threshold","","0.3512","4.775333","0.013204","2.779667","0.010599"
12 | "resnet18_expw_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_rbf_sigma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","1.575667","4.9667","2022-10-29T06:25:27.000Z","470279","2022-11-03T17:03:26.000Z","1","NBDOTTI61","nl49poa9","-","2022-11-03T17:03:26.000Z","tested","rbf","expw","","0.4041","3.877667","0.004041","2.823","0.002646"
13 | "resnet18_yaware_adam_tfnone_lr0.0001_step_step10_rate0.9_temp0.1_wd5e-05_bsz32_views2_trainall_True_kernel_cauchy_gamma1.0_alpha1.0_lambd0.0_fklliteral_True_fklkernel_True_label_trial0","2.147333","7.7333","2022-10-28T15:04:28.000Z","450797","2022-11-02T20:17:45.000Z","1","NBDOTTI61","w85oskw1","-","2022-11-02T20:17:45.000Z","tested","cauchy","yaware","","0.6506","4.633667","0.028113","2.710667","0.005508"
--------------------------------------------------------------------------------
/src/figures/ablation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/contrastive-brain-age-prediction/2fe9e7b81dd53d8f43dfeb34e41250f5450c1094/src/figures/ablation.pdf
--------------------------------------------------------------------------------
/src/figures/ablation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import pandas as pd
3 | import numpy as np
4 | from matplotlib import rc
5 |
6 |
7 | if __name__ == '__main__':
8 | rc('axes', titlesize=25) # fontsize of the axes title
9 | rc('axes', labelsize=20) # fontsize of the x and y labels
10 | rc('xtick', labelsize=16) # fontsize of the tick labels
11 | rc('ytick', labelsize=16) # fontsize of the tick labels
12 | rc('legend', fontsize=14) # legend fontsize
13 | rc('figure', titlesize=28) # fontsize of the figure title
14 | rc('font', size=18)
15 | # rc('font', family='Times New Roman')
16 | rc('text', usetex=True)
17 |
18 | df = pd.read_csv('ablation.csv')
19 | print(df.head())
20 |
21 | fig = plt.figure(figsize=(20, 3))
22 | plt.rcParams['image.cmap'] = "Set2"
23 | plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors)
24 |
25 | i = 1
26 | for kernel in df.kernel.unique():
27 | for sigma in sorted(df.sigma.unique()):
28 | ax = fig.add_subplot(1, 4, i)
29 | data = df[(df.kernel == kernel) & (df.sigma == sigma)]
30 | data = data.sort_values(by='method')
31 | print(data)
32 |
33 | int_mae = data['ramp/int_mae'].values
34 | bacc = data['ramp/bacc'].values
35 | ext_mae = data['ramp/ext_mae'].values
36 | score = data['ramp/score'].values
37 | methods = data['method'].values
38 |
39 | width = 0.4
40 | x = np.arange(4)*2
41 | labels = ['Int. MAE', 'BAcc', 'Ext. MAE', 'Score']
42 |
43 | data = np.array([[int_mae[i], bacc[i], ext_mae[i], score[i]] for i in range(3)])
44 |
45 | if kernel == "rbf":
46 |                 ax.set_title(rf"{kernel} ($\sigma$={sigma})")
47 |             else:
48 |                 ax.set_title(rf"{kernel} ($\gamma$={sigma})")
49 |
50 | alpha = 0.8
51 | ax.bar(x - width, data[0], width, label=methods[0], alpha=alpha)
52 | ax.bar(x, data[1], width, label=methods[1], alpha=alpha)
53 | ax.bar(x + width, data[2], width, label=methods[2], alpha=alpha)
54 | ax.set_ylim(0, 10)
55 | # ax.bar(x + width, data[3], width, label=methods[3])
56 |
57 | if i == 4:
58 | ax.legend()
59 | ax.set_xticks(x, labels, rotation=45)
60 |
61 | i += 1
62 | # fig.tight_layout()
63 | plt.savefig('ablation.pdf', dpi=200, bbox_inches='tight', pad_inches=0)
64 | plt.show()
--------------------------------------------------------------------------------
/src/launcher.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import sys
3 | import subprocess
4 | import os
5 | from yaml.loader import SafeLoader
6 |
7 |
8 | if __name__ == '__main__':
9 | if len(sys.argv) <= 1:
10 | print("Usage: ./launcher.py path/to/yaml")
11 | exit(1)
12 |
13 | with open(sys.argv[1]) as f:
14 | data = yaml.load(f, Loader=SafeLoader)
15 |
16 | program = data['program']
17 | del data['program']
18 |
19 | skip = False
20 | for idx, override in enumerate(sys.argv[2:]):
21 | if skip:
22 | skip = False
23 | continue
24 |
25 | if '=' in override:
26 |             k, v = override.replace('--', '').split('=', 1)
27 | else:
28 | k = override.replace('--', '')
29 | v = sys.argv[2+idx+1]
30 | skip = True
31 | data[k] = v
32 |
33 | args = ["python3", os.path.join(os.getcwd(), program)]
34 | for k, v in data.items():
35 | args.extend(["--" + k, str(v)])
36 | print("Running:", ' '.join(args))
37 | subprocess.run(args)
38 |
--------------------------------------------------------------------------------
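
Note — `launcher.py` turns every YAML key into a `--key value` flag for the target script, and trailing command-line arguments override the file (both `--key=value` and `--key value` forms are accepted by the loop above). For example, run from `src` with the configs above:

```
python3 launcher.py exp/supcon_sgd_kernel.yaml --sigma=2 --kernel cauchy
```

This expands to roughly `python3 $PWD/main_infonce.py --data_dir /scratch/data-registry/medical/openbhb ... --sigma 2 --kernel cauchy ...`, with `sigma` overridden and every remaining YAML key passed through unchanged.
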
/src/losses.py:
--------------------------------------------------------------------------------
1 | from cmath import isinf
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class KernelizedSupCon(nn.Module):
8 | """Supervised contrastive loss: https://arxiv.org/pdf/2004.11362.pdf.
9 | It also supports the unsupervised contrastive loss in SimCLR
10 | Based on: https://github.com/HobbitLong/SupContrast"""
11 | def __init__(self, method: str, temperature: float=0.07, contrast_mode: str='all',
12 | base_temperature: float=0.07, kernel: callable=None, delta_reduction: str='sum'):
13 | super().__init__()
14 | self.temperature = temperature
15 | self.contrast_mode = contrast_mode
16 | self.base_temperature = base_temperature
17 | self.method = method
18 | self.kernel = kernel
19 | self.delta_reduction = delta_reduction
20 |
21 | if kernel is not None and method == 'supcon':
22 | raise ValueError('Kernel must be none if method=supcon')
23 |
24 | if kernel is None and method != 'supcon':
25 | raise ValueError('Kernel must not be none if method != supcon')
26 |
27 | if delta_reduction not in ['mean', 'sum']:
28 | raise ValueError(f"Invalid reduction {delta_reduction}")
29 |
30 | def __repr__(self):
31 | return f'{self.__class__.__name__} ' \
32 | f'(t={self.temperature}, ' \
33 | f'method={self.method}, ' \
34 | f'kernel={self.kernel is not None}, ' \
35 | f'delta_reduction={self.delta_reduction})'
36 |
37 | def forward(self, features, labels=None):
38 | """Compute loss for model. If `labels` is None,
39 | it degenerates to SimCLR unsupervised loss:
40 | https://arxiv.org/pdf/2002.05709.pdf
41 |
42 | Args:
43 | features: hidden vector of shape [bsz, n_views, n_features].
44 |                 The input must already be arranged as [bsz, n_views, n_features].
45 | labels: ground truth of shape [bsz].
46 | Returns:
47 | A loss scalar.
48 | """
49 | device = features.device
50 |
51 | if len(features.shape) != 3:
52 |             raise ValueError('`features` needs to be [bsz, n_views, n_feats], '
53 |                              '3 dimensions are required')
54 |
55 | batch_size = features.shape[0]
56 | n_views = features.shape[1]
57 |
58 | if labels is None:
59 | mask = torch.eye(batch_size, device=device)
60 |
61 | else:
62 | labels = labels.view(-1, 1)
63 | if labels.shape[0] != batch_size:
64 | raise ValueError('Num of labels does not match num of features')
65 |
66 | if self.kernel is None:
67 | mask = torch.eq(labels, labels.T)
68 | else:
69 | mask = self.kernel(labels)
70 |
71 | view_count = features.shape[1]
72 | features = torch.cat(torch.unbind(features, dim=1), dim=0)
73 | if self.contrast_mode == 'one':
74 | features = features[:, 0]
75 | anchor_count = 1
76 | elif self.contrast_mode == 'all':
77 | features = features
78 | anchor_count = view_count
79 | else:
80 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
81 |
82 | # Tile mask
83 | mask = mask.repeat(anchor_count, view_count)
84 |
85 | # Inverse of torch-eye to remove self-contrast (diagonal)
86 | inv_diagonal = torch.scatter(
87 | torch.ones_like(mask),
88 | 1,
89 | torch.arange(batch_size*n_views, device=device).view(-1, 1),
90 | 0
91 | )
92 |
93 | # compute similarity
94 | anchor_dot_contrast = torch.div(
95 | torch.matmul(features, features.T),
96 | self.temperature
97 | )
98 |
99 | # for numerical stability
100 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
101 | logits = anchor_dot_contrast - logits_max.detach()
102 |
103 | alignment = logits
104 |
105 |         # base case is:
106 |         # - supcon if kernel is none
107 |         # - y-aware if kernel is not none
108 | uniformity = torch.exp(logits) * inv_diagonal
109 |
110 | if self.method == 'threshold':
111 | repeated = mask.unsqueeze(-1).repeat(1, 1, mask.shape[0]) # repeat kernel mask
112 |
113 | delta = (mask[:, None].T - repeated.T).transpose(1, 2) # compute the difference w_k - w_j for every k,j
114 | delta = (delta > 0.).float()
115 |
116 | # for each z_i, repel only samples j s.t. K(z_i, z_j) < K(z_i, z_k)
117 | uniformity = uniformity.unsqueeze(-1).repeat(1, 1, mask.shape[0])
118 |
119 | if self.delta_reduction == 'mean':
120 | uniformity = (uniformity * delta).mean(-1)
121 | else:
122 | uniformity = (uniformity * delta).sum(-1)
123 |
124 | elif self.method == 'expw':
125 | # exp weight e^(s_j(1-w_j))
126 | uniformity = torch.exp(logits * (1 - mask)) * inv_diagonal
127 |
128 | uniformity = torch.log(uniformity.sum(1, keepdim=True))
129 |
130 |
131 | # positive mask contains the anchor-positive pairs
132 | # excluding on the diagonal
133 | positive_mask = mask * inv_diagonal
134 |
135 | log_prob = alignment - uniformity # log(alignment/uniformity) = log(alignment) - log(uniformity)
136 | log_prob = (positive_mask * log_prob).sum(1) / positive_mask.sum(1) # compute mean of log-likelihood over positive
137 |
138 | # loss
139 | loss = - (self.temperature / self.base_temperature) * log_prob
140 | return loss.mean()
141 |
142 |
143 | if __name__ == '__main__':
144 |     k_supcon = KernelizedSupCon(method='supcon', temperature=1.0)
145 |
146 |     x = torch.nn.functional.normalize(torch.randn((256, 2, 64)), dim=-1)
147 | labels = torch.randint(0, 4, (256,))
148 |
149 | l = k_supcon(x, labels)
150 | print(l)
--------------------------------------------------------------------------------
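
Note — a minimal sketch (not in the repo) of the kernelized y-aware loss on continuous ages, wiring `KernelizedSupCon` to an RBF kernel of the same form as the one built in `main_infonce.load_model`; `sigma = 1.0` and the batch/feature sizes are arbitrary choices here:

```python
import torch
import torch.nn.functional as F
from losses import KernelizedSupCon

sigma = 1.0
def rbf(labels):                      # labels: [bsz, 1] column of ages
    d = labels - labels.T             # pairwise age differences
    return torch.exp(-(d ** 2) / (2 * sigma ** 2))

loss_fn = KernelizedSupCon(method='yaware', temperature=0.1, kernel=rbf)
feats = F.normalize(torch.randn(32, 2, 64), dim=-1)   # [bsz, n_views, n_feats]
ages = torch.rand(32) * 80 + 5                        # continuous ages
print(loss_fn(feats, ages))
```

With `method='threshold'` or `'expw'` the same kernel weights instead reshape the repulsion (uniformity) term, as implemented in `forward` above.
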
/src/main_infonce.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | from random import gauss
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.utils.data
9 | import torchvision
10 | import argparse
11 | import models
12 | import losses
13 | import time
14 | import wandb
15 | import torch.utils.tensorboard
16 |
17 | from torch import nn
18 | from torchvision import transforms
19 | from torchvision import datasets
20 | from util import AverageMeter, NViewTransform, ensure_dir, set_seed, arg2bool, save_model
21 | from util import warmup_learning_rate, adjust_learning_rate
22 | from util import compute_age_mae, compute_site_ba
23 | from data import FeatureExtractor, OpenBHB, bin_age
24 | from data.transforms import Crop, Pad, Cutout
25 | from main_mse import get_transforms
26 |
27 |
28 | def parse_arguments():
29 |     parser = argparse.ArgumentParser(description="Weakly contrastive learning for brain age prediction",
30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
31 |
32 | # Misc
33 | parser.add_argument('--device', type=str, help='torch device', default='cuda')
34 | parser.add_argument('--print_freq', type=int, help='print frequency', default=10)
35 | parser.add_argument('--trial', type=int, help='random seed / trial id', default=0)
36 | parser.add_argument('--save_dir', type=str, help='output dir', default='output')
37 | parser.add_argument('--save_freq', type=int, help='save frequency', default=50)
38 | parser.add_argument('--data_dir', type=str, help='path of data dir', default='/data')
39 | parser.add_argument('--amp', type=arg2bool, help='use amp', default=False)
40 | parser.add_argument('--clip_grad', type=arg2bool, help='clip gradient to prevent nan', default=False)
41 |
42 | # Model
43 | parser.add_argument('--model', type=str, help='model architecture', default='resnet18')
44 |
45 | # Optimizer
46 | parser.add_argument('--epochs', type=int, help='number of epochs', default=300)
47 | parser.add_argument('--batch_size', type=int, help='batch size', default=256)
48 | parser.add_argument('--lr', type=float, help='learning rate', default=1e-4)
49 | parser.add_argument('--lr_decay', type=str, help='type of decay', choices=['cosine', 'step'], default='step')
50 | parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='decay rate for learning rate (for step)')
51 | parser.add_argument('--lr_decay_epochs', type=str, help='steps of lr decay (list)', default="700,800,900")
52 |     parser.add_argument('--lr_decay_step', type=int, help='decay rate step (overwrites lr_decay_epochs)', default=10)
53 | parser.add_argument('--warm', type=arg2bool, help='warmup lr', default=False)
54 | parser.add_argument('--optimizer', type=str, help="optimizer (adam or sgd)", choices=["adam", "sgd"], default="adam")
55 | parser.add_argument('--momentum', type=float, help='momentum', default=0.9)
56 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=5e-5)
57 |
58 | # Data
59 |     parser.add_argument('--train_all', type=arg2bool, help='train on the full dataset, including validation (int+ext)', default=True)
60 | parser.add_argument('--tf', type=str, help='data augmentation', choices=['none', 'crop', 'cutout', 'all'], default='none')
61 |
62 | # Loss
63 | parser.add_argument('--method', type=str, help='loss function', choices=['supcon', 'yaware', 'threshold', 'expw'], default='supcon')
64 | parser.add_argument('--kernel', type=str, help='Kernel function (not for supcon)', choices=['cauchy', 'gaussian', 'rbf'], default=None)
65 | parser.add_argument('--delta_reduction', type=str, help='use mean or sum to reduce 3d delta mask (only for method=threshold)', default='sum')
66 | parser.add_argument('--temp', type=float, help='loss temperature', default=0.1)
67 | parser.add_argument('--alpha', type=float, help='infonce weight', default=1.)
68 | parser.add_argument('--sigma', type=float, help='gaussian-rbf kernel sigma / cauchy gamma', default=1)
69 | parser.add_argument('--n_views', type=int, help='num. of multiviews', default=2)
70 |
71 | opts = parser.parse_args()
72 |
73 | if opts.batch_size > 256:
74 | print("Forcing warm")
75 | opts.warm = True
76 |
77 | if opts.lr_decay_step is not None:
78 | opts.lr_decay_epochs = list(range(opts.lr_decay_step, opts.epochs, opts.lr_decay_step))
79 | print(f"Computed decay epochs based on step ({opts.lr_decay_step}):", opts.lr_decay_epochs)
80 | else:
81 | iterations = opts.lr_decay_epochs.split(',')
82 |         opts.lr_decay_epochs = []
83 | for it in iterations:
84 | opts.lr_decay_epochs.append(int(it))
85 |
86 | if opts.warm:
87 | opts.warmup_from = 0.01
88 | opts.warm_epochs = 10
89 | if opts.lr_decay == 'cosine':
90 | eta_min = opts.lr * (opts.lr_decay_rate ** 3)
91 | opts.warmup_to = eta_min + (opts.lr - eta_min) * (
92 | 1 + math.cos(math.pi * opts.warm_epochs / opts.epochs)) / 2
93 | else:
94 |             opts.milestones = [int(s) for s in opts.lr_decay_epochs]
95 | opts.warmup_to = opts.lr
96 |
97 | if opts.method == 'supcon':
98 | print('method == supcon, binning age')
99 | opts.label = 'bin'
100 | else:
101 | print('method != supcon, using real age value')
102 | opts.label = 'cont'
103 |
104 | if opts.method == 'supcon' and opts.kernel is not None:
105 | print('Invalid kernel for supcon')
106 |         exit(1)
107 |
108 | if opts.method != 'supcon' and opts.kernel is None:
109 | print('Kernel cannot be None for method != supcon')
110 | exit(1)
111 |
112 | if opts.model == 'densenet121':
113 | opts.n_views = 1
114 |
115 | return opts
116 |
117 | def load_data(opts):
118 | T_train, T_test = get_transforms(opts)
119 | T_train = NViewTransform(T_train, opts.n_views)
120 |
121 |     train_dataset = OpenBHB(opts.data_dir, train=True, internal=True, transform=T_train, label=opts.label,
122 |                             load_feats=getattr(opts, 'biased_features', None))
123 | if opts.train_all:
124 |         valint_feats, valext_feats = None, None
125 |         if getattr(opts, 'biased_features', None) is not None:
126 |             valint_feats = opts.biased_features.replace('.pth', '_valint.pth')
127 |             valext_feats = opts.biased_features.replace('.pth', '_valext.pth')
128 |
129 | valint = OpenBHB(opts.data_dir, train=False, internal=True, transform=T_train,
130 | label=opts.label, load_feats=valint_feats)
131 | valext = OpenBHB(opts.data_dir, train=False, internal=False, transform=T_train,
132 | label=opts.label, load_feats=valext_feats)
133 | train_dataset = torch.utils.data.ConcatDataset([train_dataset, valint, valext])
134 | print("Total dataset length:", len(train_dataset))
135 |
136 |
137 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=8,
138 | persistent_workers=True)
139 | train_loader_score = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=True, internal=True, transform=T_train, label=opts.label),
140 | batch_size=opts.batch_size, shuffle=True, num_workers=8,
141 | persistent_workers=True)
142 | test_internal = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=True, transform=T_test),
143 | batch_size=opts.batch_size, shuffle=False, num_workers=8,
144 | persistent_workers=True)
145 | test_external = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=False, transform=T_test),
146 | batch_size=opts.batch_size, shuffle=False, num_workers=8,
147 | persistent_workers=True)
148 | return train_loader, train_loader_score, test_internal, test_external
149 |
150 | def load_model(opts):
151 | if 'resnet' in opts.model:
152 | model = models.SupConResNet(opts.model, feat_dim=128)
153 | elif 'alexnet' in opts.model:
154 | model = models.SupConAlexNet(feat_dim=128)
155 | elif 'densenet121' in opts.model:
156 | model = models.SupConDenseNet(feat_dim=128)
157 |
158 | else:
159 | raise ValueError("Unknown model", opts.model)
160 |
161 | if opts.device == 'cuda' and torch.cuda.device_count() > 1:
162 | print(f"Using multiple CUDA devices ({torch.cuda.device_count()})")
163 | model = torch.nn.DataParallel(model)
164 | model = model.to(opts.device)
165 |
166 |
167 | def gaussian_kernel(x):
168 | x = x - x.T
169 | return torch.exp(-(x**2) / (2*(opts.sigma**2))) / (math.sqrt(2*torch.pi)*opts.sigma)
170 |
171 | def rbf(x):
172 | x = x - x.T
173 | return torch.exp(-(x**2)/(2*(opts.sigma**2)))
174 |
175 | def cauchy(x):
176 | x = x - x.T
177 | return 1. / (opts.sigma*(x**2) + 1)
178 |
179 | kernels = {
180 | 'none': None,
181 | 'cauchy': cauchy,
182 | 'gaussian': gaussian_kernel,
183 | 'rbf': rbf
184 | }
185 |
186 | infonce = losses.KernelizedSupCon(method=opts.method, temperature=opts.temp,
187 | kernel=kernels[opts.kernel], delta_reduction=opts.delta_reduction)
188 | infonce = infonce.to(opts.device)
189 |
190 |
191 | return model, infonce
192 |
193 | def load_optimizer(model, opts):
194 | if opts.optimizer == "sgd":
195 | optimizer = torch.optim.SGD(model.parameters(), lr=opts.lr,
196 | momentum=opts.momentum,
197 | weight_decay=opts.weight_decay)
198 | else:
199 | optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
200 |
201 | return optimizer
202 |
203 | def train(train_loader, model, infonce, optimizer, opts, epoch):
204 | loss = AverageMeter()
205 | batch_time = AverageMeter()
206 | data_time = AverageMeter()
207 |
208 | scaler = torch.cuda.amp.GradScaler() if opts.amp else None
209 | model.train()
210 |
211 | t1 = time.time()
212 | for idx, (images, labels, _) in enumerate(train_loader):
213 | data_time.update(time.time() - t1)
214 |
215 | images = torch.cat(images, dim=0).to(opts.device)
216 | bsz = labels.shape[0]
217 |
218 | warmup_learning_rate(opts, epoch, idx, len(train_loader), optimizer)
219 |
220 | with torch.cuda.amp.autocast(scaler is not None):
221 | projected = model(images)
222 | projected = torch.split(projected, [bsz]*opts.n_views, dim=0)
223 | projected = torch.cat([f.unsqueeze(1) for f in projected], dim=1)
224 | running_loss = infonce(projected, labels.to(opts.device))
225 |
226 | optimizer.zero_grad()
227 | if scaler is None:
228 | running_loss.backward()
229 | if opts.clip_grad:
230 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
231 | optimizer.step()
232 | else:
233 | scaler.scale(running_loss).backward()
234 | if opts.clip_grad:
235 | scaler.unscale_(optimizer)
236 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
237 | scaler.step(optimizer)
238 | scaler.update()
239 |
240 | loss.update(running_loss.item(), bsz)
241 | batch_time.update(time.time() - t1)
242 | t1 = time.time()
243 | eta = batch_time.avg * (len(train_loader) - idx)
244 |
245 | if (idx + 1) % opts.print_freq == 0:
246 | print(f"Train: [{epoch}][{idx + 1}/{len(train_loader)}]:\t"
247 | f"BT {batch_time.avg:.3f}\t"
248 | f"ETA {datetime.timedelta(seconds=eta)}\t"
249 | f"loss {loss.avg:.3f}\t")
250 |
251 | return loss.avg, batch_time.avg, data_time.avg
252 |
253 | if __name__ == '__main__':
254 | opts = parse_arguments()
255 |
256 | set_seed(opts.trial)
257 |
258 | train_loader, train_loader_score, test_loader_int, test_loader_ext = load_data(opts)
259 | model, infonce = load_model(opts)
260 | optimizer = load_optimizer(model, opts)
261 |
262 | model_name = opts.model
263 | if opts.warm:
264 | model_name = f"{model_name}_warm"
265 | if opts.amp:
266 | model_name = f"{model_name}_amp"
267 |
268 | method_name = opts.method
269 | if opts.method == 'threshold':
270 | method_name = f"{method_name}_reduction_{opts.delta_reduction}"
271 |
272 | optimizer_name = opts.optimizer
273 | if opts.clip_grad:
274 | optimizer_name = f"{optimizer_name}_clipgrad"
275 |
276 | kernel_name = opts.kernel
277 | if opts.kernel == "gaussian" or opts.kernel == 'rbf':
278 | kernel_name = f"{kernel_name}_sigma{opts.sigma}"
279 | elif opts.kernel == 'cauchy':
280 | kernel_name = f"{kernel_name}_gamma{opts.sigma}"
281 |
282 | run_name = (f"{model_name}_{method_name}_"
283 | f"{optimizer_name}_"
284 | f"tf{opts.tf}_"
285 | f"lr{opts.lr}_{opts.lr_decay}_step{opts.lr_decay_step}_rate{opts.lr_decay_rate}_"
286 | f"temp{opts.temp}_"
287 | f"wd{opts.weight_decay}_"
288 | f"bsz{opts.batch_size}_views{opts.n_views}_"
289 | f"trainall_{opts.train_all}_"
290 | f"kernel_{kernel_name}_"
291 |                 f"alpha{opts.alpha}_lambd{getattr(opts, 'lambd', 0.0)}_"
292 | f"trial{opts.trial}")
293 | tb_dir = os.path.join(opts.save_dir, "tensorboard", run_name)
294 | save_dir = os.path.join(opts.save_dir, f"openbhb_models", run_name)
295 | ensure_dir(tb_dir)
296 | ensure_dir(save_dir)
297 |
298 | opts.model_class = model.__class__.__name__
299 | opts.criterion = infonce.__class__.__name__
300 | opts.optimizer_class = optimizer.__class__.__name__
301 |
302 | wandb.init(project="brain-age-prediction", config=opts, name=run_name, sync_tensorboard=True,
303 | settings=wandb.Settings(code_dir="/src"), tags=['to test'])
304 | wandb.run.log_code(root="/src", include_fn=lambda path: path.endswith(".py"))
305 |
306 | print('Config:', opts)
307 | print('Model:', model.__class__.__name__)
308 | print('Criterion:', infonce)
309 | print('Optimizer:', optimizer)
310 | print('Scheduler:', opts.lr_decay)
311 |
312 | writer = torch.utils.tensorboard.writer.SummaryWriter(tb_dir)
313 | if opts.amp:
314 | print("Using AMP")
315 |
316 | start_time = time.time()
317 | best_acc = 0.
318 | for epoch in range(1, opts.epochs + 1):
319 | adjust_learning_rate(opts, optimizer, epoch)
320 |
321 | t1 = time.time()
322 | loss_train, batch_time, data_time = train(train_loader, model, infonce, optimizer, opts, epoch)
323 | t2 = time.time()
324 | writer.add_scalar("train/loss", loss_train, epoch)
325 |
326 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch)
327 | writer.add_scalar("BT", batch_time, epoch)
328 | writer.add_scalar("DT", data_time, epoch)
329 | print(f"epoch {epoch}, total time {t2-start_time:.2f}, epoch time {t2-t1:.3f} loss {loss_train:.4f}")
330 |
331 | if epoch % opts.save_freq == 0:
332 | # save_file = os.path.join(save_dir, f"ckpt_epoch_{epoch}.pth")
333 | # save_model(model, optimizer, opts, epoch, save_file)
334 |
335 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader_score, test_loader_int, test_loader_ext, opts)
336 | writer.add_scalar("train/mae", mae_train, epoch)
337 | writer.add_scalar("test/mae_int", mae_int, epoch)
338 | writer.add_scalar("test/mae_ext", mae_ext, epoch)
339 | print("Age MAE:", mae_train, mae_int, mae_ext)
340 |
341 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader_score, test_loader_int, test_loader_ext, opts)
342 | writer.add_scalar("train/site_ba", ba_train, epoch)
343 | writer.add_scalar("test/ba_int", ba_int, epoch)
344 | writer.add_scalar("test/ba_ext", ba_ext, epoch)
345 | print("Site BA:", ba_train, ba_int, ba_ext)
346 |
347 | challenge_metric = ba_int**0.3 * mae_ext
348 | writer.add_scalar("test/score", challenge_metric, epoch)
349 | print("Challenge score", challenge_metric)
350 |
351 | save_file = os.path.join(save_dir, f"weights.pth")
352 | save_model(model, optimizer, opts, epoch, save_file)
353 |
354 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader_score, test_loader_int, test_loader_ext, opts)
355 | writer.add_scalar("train/mae", mae_train, epoch)
356 | writer.add_scalar("test/mae_int", mae_int, epoch)
357 | writer.add_scalar("test/mae_ext", mae_ext, epoch)
358 | print("Age MAE:", mae_train, mae_int, mae_ext)
359 |
360 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader_score, test_loader_int, test_loader_ext, opts)
361 | writer.add_scalar("train/site_ba", ba_train, epoch)
362 | writer.add_scalar("test/ba_int", ba_int, epoch)
363 | writer.add_scalar("test/ba_ext", ba_ext, epoch)
364 | print("Site BA:", ba_train, ba_int, ba_ext)
365 |
366 | challenge_metric = ba_int**0.3 * mae_ext
367 | writer.add_scalar("test/score", challenge_metric, epoch)
368 | print("Challenge score", challenge_metric)
--------------------------------------------------------------------------------
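
Note — the challenge score logged above combines the site-prediction balanced accuracy (expressed as a fraction) with the external MAE, `score = BAcc^0.3 × MAE_ext`. Plugging in one row of `figures/ablation.csv` (rbf / yaware / sigma=2: BAcc 6.6%, ext. MAE 4.102) reproduces its reported score:

```python
# Worked check of challenge_metric = ba_int**0.3 * mae_ext against ablation.csv
bacc, mae_ext = 0.066, 4.102        # 6.6% balanced accuracy, external MAE
print(bacc ** 0.3 * mae_ext)        # ~1.815 (the CSV reports 1.815333)
```
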
/src/main_mse.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import argparse
9 | import models
10 | import losses
11 | import time
12 | import wandb
13 | import torch.utils.tensorboard
14 |
15 | from torchvision import transforms
16 | from util import AverageMeter, MAE, ensure_dir, set_seed, arg2bool, save_model
17 | from util import warmup_learning_rate, adjust_learning_rate
18 | from util import compute_age_mae, compute_site_ba
19 | from data import FeatureExtractor, OpenBHB, bin_age
20 | from data.transforms import Crop, Pad, Cutout
21 |
22 | def parse_arguments():
23 |     parser = argparse.ArgumentParser(description="Baseline MSE/MAE training for brain age prediction",
24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25 |
26 | parser.add_argument('--device', type=str, help='torch device', default='cuda')
27 | parser.add_argument('--print_freq', type=int, help='print frequency', default=10)
28 | parser.add_argument('--trial', type=int, help='random seed / trial id', default=0)
29 | parser.add_argument('--save_dir', type=str, help='output dir', default='output')
30 | parser.add_argument('--save_freq', type=int, help='save frequency', default=50)
31 |
32 | parser.add_argument('--data_dir', type=str, help='path of data dir', default='/data')
33 | parser.add_argument('--batch_size', type=int, help='batch size', default=256)
34 |
35 | parser.add_argument('--epochs', type=int, help='number of epochs', default=200)
36 | parser.add_argument('--lr', type=float, help='learning rate', default=0.1)
37 | parser.add_argument('--lr_decay', type=str, help='type of decay', choices=['cosine', 'step'], default='cosine')
38 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate (for step)')
39 | parser.add_argument('--lr_decay_epochs', type=str, help='steps of lr decay (list)', default="700,800,900")
40 |     parser.add_argument('--lr_decay_step', type=int, help='decay rate step (overwrites lr_decay_epochs)', default=None)
41 |
42 | parser.add_argument('--warm', type=arg2bool, help='warmup lr', default=False)
43 | parser.add_argument('--optimizer', type=str, help="optimizer (adam or sgd)", choices=["adam", "sgd"], default="sgd")
44 | parser.add_argument('--momentum', type=float, help='momentum', default=0.9)
45 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=1e-4)
46 |
47 | parser.add_argument('--model', type=str, help='model architecture', default='resnet18')
48 |
49 | parser.add_argument('--method', type=str, help='loss function', choices=['mae', 'mse'], default='mae')
50 |
51 |
52 |     parser.add_argument('--train_all', type=arg2bool, help='train on the full dataset, including validation (int+ext)', default=False)
53 | parser.add_argument('--tf', type=str, help='data augmentation', choices=['none', 'crop', 'cutout', 'all'], default='none')
54 |
55 | parser.add_argument('--amp', action='store_true', help='use amp')
56 |
57 | opts = parser.parse_args()
58 |
59 | if opts.batch_size > 256:
60 | print("Forcing warm")
61 | opts.warm = True
62 |
63 | if opts.lr_decay_step is not None:
64 | opts.lr_decay_epochs = list(range(opts.lr_decay_step, opts.epochs, opts.lr_decay_step))
65 | print(f"Computed decay epochs based on step ({opts.lr_decay_step}):", opts.lr_decay_epochs)
66 | else:
67 | iterations = opts.lr_decay_epochs.split(',')
68 |         opts.lr_decay_epochs = []
69 | for it in iterations:
70 | opts.lr_decay_epochs.append(int(it))
71 |
72 | if opts.warm:
73 | opts.warmup_from = 0.01
74 | opts.warm_epochs = 10
75 | if opts.lr_decay == 'cosine':
76 | eta_min = opts.lr * (opts.lr_decay_rate ** 3)
77 | opts.warmup_to = eta_min + (opts.lr - eta_min) * (
78 | 1 + math.cos(math.pi * opts.warm_epochs / opts.epochs)) / 2
79 | else:
80 |             opts.milestones = [int(s) for s in opts.lr_decay_epochs]
81 | opts.warmup_to = opts.lr
82 |
83 |     opts.fairkl_kernel = getattr(opts, 'kernel', 'none') != 'none'
84 | return opts
85 |
86 | def get_transforms(opts):
87 | selector = FeatureExtractor("vbm")
88 |
89 | if opts.tf == 'none':
90 | aug = transforms.Lambda(lambda x: x)
91 |
92 | elif opts.tf == 'crop':
93 | aug = transforms.Compose([
94 | Crop((1, 121, 128, 121), type="random"),
95 | Pad((1, 128, 128, 128))
96 | ])
97 |
98 | elif opts.tf == 'cutout':
99 | aug = Cutout(patch_size=[1, 32, 32, 32], probability=0.5)
100 |
101 | elif opts.tf == 'all':
102 | aug = transforms.Compose([
103 | Cutout(patch_size=[1, 32, 32, 32], probability=0.5),
104 | Crop((1, 121, 128, 121), type="random"),
105 | Pad((1, 128, 128, 128))
106 | ])
107 |
108 | T_pre = transforms.Lambda(lambda x: selector.transform(x))
109 | T_train = transforms.Compose([
110 | T_pre,
111 | aug,
112 | transforms.Lambda(lambda x: torch.from_numpy(x).float()),
113 | transforms.Normalize(mean=0.0, std=1.0)
114 | ])
115 |
116 | T_test = transforms.Compose([
117 | T_pre,
118 | transforms.Lambda(lambda x: torch.from_numpy(x).float()),
119 | transforms.Normalize(mean=0.0, std=1.0)
120 | ])
121 |
122 | return T_train, T_test
123 |
124 |
125 | def load_data(opts):
126 | T_train, T_test = get_transforms(opts)
127 |
128 |     train_dataset = OpenBHB(opts.data_dir, train=True, internal=True, transform=T_train,
129 |                             load_feats=getattr(opts, 'biased_features', None))
130 |
131 | if opts.train_all:
132 |         valint_feats, valext_feats = None, None
133 |         if getattr(opts, 'biased_features', None) is not None:
134 |             valint_feats = opts.biased_features.replace('.pth', '_valint.pth')
135 |             valext_feats = opts.biased_features.replace('.pth', '_valext.pth')
136 |
137 | valint = OpenBHB(opts.data_dir, train=False, internal=True, transform=T_train,
138 | load_feats=valint_feats)
139 | valext = OpenBHB(opts.data_dir, train=False, internal=False, transform=T_train,
140 | load_feats=valext_feats)
141 | train_dataset = torch.utils.data.ConcatDataset([train_dataset, valint, valext])
142 |         print("Total dataset length:", len(train_dataset))
143 |
144 |
145 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True,
146 | num_workers=8, persistent_workers=True)
147 |
148 | test_internal = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=True, transform=T_test),
149 | batch_size=opts.batch_size, shuffle=False, num_workers=8,
150 | persistent_workers=True)
151 | test_external = torch.utils.data.DataLoader(OpenBHB(opts.data_dir, train=False, internal=False, transform=T_test),
152 | batch_size=opts.batch_size, shuffle=False, num_workers=8,
153 | persistent_workers=True)
154 |
155 | return train_loader, test_internal, test_external
156 |
157 | def load_model(opts):
158 | if 'resnet' in opts.model:
159 | model = models.SupRegResNet(opts.model)
160 |
161 | elif 'alexnet' in opts.model:
162 | model = models.SupRegAlexNet()
163 |
164 | elif 'densenet121' in opts.model:
165 | model = models.SupRegDenseNet()
166 |
167 | else:
168 | raise ValueError("Unknown model", opts.model)
169 |
170 | if opts.device == 'cuda' and torch.cuda.device_count() > 1:
171 | print(f"Using multiple CUDA devices ({torch.cuda.device_count()})")
172 | model = torch.nn.DataParallel(model)
173 | model = model.to(opts.device)
174 |
175 | methods = {
176 | 'mae': F.l1_loss,
177 | 'mse': F.mse_loss
178 | }
179 | regression_loss = methods[opts.method]
180 |
181 | return model, regression_loss
182 |
183 | def load_optimizer(model, opts):
184 | if opts.optimizer == "sgd":
185 | optimizer = torch.optim.SGD(model.parameters(), lr=opts.lr,
186 | momentum=opts.momentum,
187 | weight_decay=opts.weight_decay)
188 | else:
189 | optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
190 |
191 | return optimizer
192 |
193 | def train(train_loader, model, criterion, optimizer, opts, epoch):
194 | loss = AverageMeter()
195 | mae = MAE()
196 |
197 | batch_time = AverageMeter()
198 | data_time = AverageMeter()
199 |
200 | scaler = torch.cuda.amp.GradScaler() if opts.amp else None
201 | model.train()
202 |
203 | t1 = time.time()
204 | for idx, (images, labels, _) in enumerate(train_loader):
205 | data_time.update(time.time() - t1)
206 |
207 | images, labels = images.to(opts.device), labels.to(opts.device)
208 | bsz = labels.shape[0]
209 |
210 | warmup_learning_rate(opts, epoch, idx, len(train_loader), optimizer)
211 |
212 | with torch.cuda.amp.autocast(scaler is not None):
213 | output, features = model(images)
214 | output = output.view(-1)
215 | running_loss = criterion(output, features, labels.float())
216 |
217 | optimizer.zero_grad()
218 | if scaler is None:
219 | running_loss.backward()
220 | optimizer.step()
221 | else:
222 | scaler.scale(running_loss).backward()
223 | scaler.step(optimizer)
224 | scaler.update()
225 |
226 | loss.update(running_loss.item(), bsz)
227 | mae.update(output, labels)
228 |
229 | batch_time.update(time.time() - t1)
230 | eta = batch_time.avg * (len(train_loader) - idx)
231 |
232 | if (idx + 1) % opts.print_freq == 0:
233 | print(f"Train: [{epoch}][{idx + 1}/{len(train_loader)}]:\t"
234 | f"BT {batch_time.avg:.3f}\t"
235 | f"ETA {datetime.timedelta(seconds=eta)}\t"
236 | f"loss {loss.avg:.3f}\t"
237 | f"MAE {mae.avg:.3f}")
238 |
239 | t1 = time.time()
240 |
241 | return loss.avg, mae.avg, batch_time.avg, data_time.avg
242 |
243 | @torch.no_grad()
244 | def test(test_loader, model, criterion, opts, epoch):
245 | loss = AverageMeter()
246 | mae = MAE()
247 | batch_time = AverageMeter()
248 |
249 | model.eval()
250 | t1 = time.time()
251 | for idx, (images, labels, _) in enumerate(test_loader):
252 | images, labels = images.to(opts.device), labels.to(opts.device)
253 | bsz = labels.shape[0]
254 |
255 | output, features = model(images)
256 | output = output.view(-1)
257 | running_loss = criterion(output, features, labels.float())
258 |
259 | loss.update(running_loss.item(), bsz)
260 | mae.update(output, labels)
261 |
262 | batch_time.update(time.time() - t1)
263 |         eta = batch_time.avg * (len(test_loader) - idx)
264 |
265 | if (idx + 1) % opts.print_freq == 0:
266 |             print(f"Test: [{epoch}][{idx + 1}/{len(test_loader)}]:\t"
267 | f"BT {batch_time.avg:.3f}\t"
268 | f"ETA {datetime.timedelta(seconds=eta)}\t"
269 | f"loss {loss.avg:.3f}\t"
270 | f"MAE {mae.avg:.3f}")
271 |
272 | t1 = time.time()
273 |
274 | return loss.avg, mae.avg
275 |
276 | if __name__ == '__main__':
277 | opts = parse_arguments()
278 |
279 | set_seed(opts.trial)
280 |
281 | train_loader, test_loader_int, test_loader_ext = load_data(opts)
282 | model, criterion = load_model(opts)
283 | optimizer = load_optimizer(model, opts)
284 |
285 | model_name = opts.model
286 | if opts.warm:
287 | model_name = f"{model_name}_warm"
288 |
289 |
290 | run_name = (f"{model_name}_{opts.method}_"
291 | f"{opts.optimizer}_"
292 | f"tf_{opts.tf}_"
293 | f"lr{opts.lr}_{opts.lr_decay}_step{opts.lr_decay_step}_rate{opts.lr_decay_rate}_"
294 | f"wd{opts.weight_decay}_"
295 | f"trainall_{opts.train_all}_"
296 | f"bsz{opts.batch_size}_"
297 | f"trial{opts.trial}")
298 | tb_dir = os.path.join(opts.save_dir, "tensorboard", run_name)
299 |     save_dir = os.path.join(opts.save_dir, "openbhb_models", run_name)
300 | ensure_dir(tb_dir)
301 | ensure_dir(save_dir)
302 |
303 | opts.model_class = model.__class__.__name__
304 | opts.criterion = opts.method
305 | opts.optimizer_class = optimizer.__class__.__name__
306 |
307 | wandb.init(project="brain-age-prediction", config=opts, name=run_name, sync_tensorboard=True, tags=['to test'])
308 | print('Config:', opts)
309 | print('Model:', model.__class__.__name__)
310 | print('Criterion:', opts.criterion)
311 | print('Optimizer:', optimizer)
312 | print('Scheduler:', opts.lr_decay)
313 |
314 | writer = torch.utils.tensorboard.writer.SummaryWriter(tb_dir)
315 | if opts.amp:
316 | print("Using AMP")
317 |
318 | start_time = time.time()
319 | best_acc = 0.
320 | for epoch in range(1, opts.epochs + 1):
321 | adjust_learning_rate(opts, optimizer, epoch)
322 |
323 | t1 = time.time()
324 | loss_train, mae_train, batch_time, data_time = train(train_loader, model, criterion, optimizer, opts, epoch)
325 | t2 = time.time()
326 | writer.add_scalar("train/loss", loss_train, epoch)
327 | # writer.add_scalar("train/mae", mae_train, epoch)
328 |
329 | loss_test, mae_int = test(test_loader_int, model, criterion, opts, epoch)
330 | writer.add_scalar("test/loss_int", loss_test, epoch)
331 | # writer.add_scalar("test/mae_int", mae_int, epoch)
332 |
333 | loss_test, mae_ext = test(test_loader_ext, model, criterion, opts, epoch)
334 | writer.add_scalar("test/loss_ext", loss_test, epoch)
335 | # writer.add_scalar("test/mae_ext", mae_ext, epoch)
336 |
337 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch)
338 | writer.add_scalar("BT", batch_time, epoch)
339 | writer.add_scalar("DT", data_time, epoch)
340 | print(f"epoch {epoch}, total time {t2-start_time:.2f}, epoch time {t2-t1:.3f} loss {loss_test:.4f} "
341 | f"mae_int {mae_int:.3f} mae_ext {mae_ext:.3f}")
342 |
343 | if epoch % opts.save_freq == 0:
344 | # save_file = os.path.join(save_dir, f"ckpt_epoch_{epoch}.pth")
345 | # save_model(model, optimizer, opts, epoch, save_file)
346 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader, test_loader_int, test_loader_ext, opts)
347 |
348 | writer.add_scalar("train/mae", mae_train, epoch)
349 | writer.add_scalar("test/mae_int", mae_int, epoch)
350 | writer.add_scalar("test/mae_ext", mae_ext, epoch)
351 | print("Age MAE:", mae_train, mae_int, mae_ext)
352 |
353 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader, test_loader_int, test_loader_ext, opts)
354 | writer.add_scalar("train/site_ba", ba_train, epoch)
355 | writer.add_scalar("test/ba_int", ba_int, epoch)
356 | writer.add_scalar("test/ba_ext", ba_ext, epoch)
357 | print("Site BA:", ba_train, ba_int, ba_ext)
358 |
359 | challenge_metric = ba_int**0.3 * mae_ext
360 | writer.add_scalar("test/score", challenge_metric, epoch)
361 | print("Challenge score", challenge_metric)
362 |
363 |     save_file = os.path.join(save_dir, "weights.pth")
364 | save_model(model, optimizer, opts, epoch, save_file)
365 |
366 | mae_train, mae_int, mae_ext = compute_age_mae(model, train_loader, test_loader_int, test_loader_ext, opts)
367 | writer.add_scalar("train/mae", mae_train, epoch)
368 | writer.add_scalar("test/mae_int", mae_int, epoch)
369 | writer.add_scalar("test/mae_ext", mae_ext, epoch)
370 | print("Age MAE:", mae_train, mae_int, mae_ext)
371 |
372 | ba_train, ba_int, ba_ext = compute_site_ba(model, train_loader, test_loader_int, test_loader_ext, opts)
373 | writer.add_scalar("train/site_ba", ba_train, epoch)
374 | writer.add_scalar("test/ba_int", ba_int, epoch)
375 | writer.add_scalar("test/ba_ext", ba_ext, epoch)
376 | print("Site BA:", ba_train, ba_int, ba_ext)
377 |
378 | challenge_metric = ba_int**0.3 * mae_ext
379 | writer.add_scalar("test/score", challenge_metric, epoch)
380 | print("Challenge score", challenge_metric)
--------------------------------------------------------------------------------
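Note: the challenge score logged at the end of each evaluation above is `ba_int**0.3 * mae_ext`, i.e. external-site age error discounted by how little site information the learned features retain. A minimal sketch of that computation (function name and numbers below are illustrative only):

```python
def challenge_score(ba_int: float, mae_ext: float) -> float:
    """The score computed in the script above (lower is better).

    ba_int  -- balanced accuracy of site prediction on internal test features
    mae_ext -- age MAE on the external (unseen-site) test set
    """
    return ba_int ** 0.3 * mae_ext

# Illustrative values: near-chance site accuracy with the same external MAE
# yields a much better (lower) score.
print(challenge_score(0.05, 2.7))  # ~1.10
print(challenge_score(0.50, 2.7))  # ~2.19
```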
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet3d import SupConResNet, SupRegResNet, SupCEResNet, LinearRegressor
2 | from .alexnet3d import SupConAlexNet, SupRegAlexNet
3 | from .densenet3d import SupConDenseNet, SupRegDenseNet
4 | from .estimators import AgeEstimator, SiteEstimator
--------------------------------------------------------------------------------
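All models and estimators are re-exported at package level, so the training scripts can use flat imports; a quick sketch of the resulting surface (assuming `src/` is on the import path):

```python
# Flat import surface used by the training scripts (cf. load_model above).
from models import SupRegResNet, SupRegAlexNet, SupRegDenseNet
from models import SupConResNet, AgeEstimator, SiteEstimator

model = SupRegResNet('resnet18')  # the 'resnet' branch of load_model
```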
/src/models/alexnet3d.py:
--------------------------------------------------------------------------------
1 | """
2 | Model implemented in https://doi.org/10.5281/zenodo.4309677 by Abrol et al., 2021
3 | """
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 | class AlexNet3D(nn.Module):
9 | def __init__(self):
10 | """
11 |         3D AlexNet encoder (Abrol et al., 2021): five convolutional blocks
12 |         followed by adaptive max pooling; returns a 128-d feature vector.
13 | """
14 | super().__init__()
15 | self.features = nn.Sequential(
16 | nn.Conv3d(1, 64, kernel_size=5, stride=2, padding=0),
17 | nn.BatchNorm3d(64),
18 | nn.ReLU(inplace=True),
19 | nn.MaxPool3d(kernel_size=3, stride=3),
20 |
21 | nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=0),
22 | nn.BatchNorm3d(128),
23 | nn.ReLU(inplace=True),
24 | nn.MaxPool3d(kernel_size=3, stride=3),
25 |
26 | nn.Conv3d(128, 192, kernel_size=3, padding=1),
27 | nn.BatchNorm3d(192),
28 | nn.ReLU(inplace=True),
29 |
30 | nn.Conv3d(192, 192, kernel_size=3, padding=1),
31 | nn.BatchNorm3d(192),
32 | nn.ReLU(inplace=True),
33 |
34 | nn.Conv3d(192, 128, kernel_size=3, padding=1),
35 | nn.BatchNorm3d(128),
36 | nn.ReLU(inplace=True),
37 | nn.AdaptiveMaxPool3d(1),
38 | )
39 |
40 |
41 | for m in self.modules():
42 |             if isinstance(m, nn.Conv3d):  # was nn.Conv2d, which never matched in this 3D network
43 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
44 | m.weight.data.normal_(0, math.sqrt(2. / n))
45 | elif isinstance(m, nn.BatchNorm3d):
46 | m.weight.data.fill_(1)
47 | m.bias.data.zero_()
48 |
49 | def forward(self, x):
50 | xp = self.features(x)
51 | x = xp.view(xp.size(0), -1)
52 | return x
53 |
54 | class SupConAlexNet(nn.Module):
55 | """backbone + projection head"""
56 | def __init__(self, head='mlp', feat_dim=128):
57 | super().__init__()
58 | self.encoder = AlexNet3D()
59 | dim_in = 128
60 |
61 | if head == 'linear':
62 | self.head = nn.Linear(dim_in, feat_dim)
63 | elif head == 'mlp':
64 | self.head = nn.Sequential(
65 | nn.Linear(dim_in, dim_in),
66 | nn.ReLU(inplace=True),
67 | nn.Linear(dim_in, feat_dim)
68 | )
69 |
70 | else:
71 | raise NotImplementedError(
72 | 'head not supported: {}'.format(head))
73 |
74 | def forward(self, x):
75 | feat = self.encoder(x)
76 | feat = F.normalize(self.head(feat), dim=1)
77 | return feat
78 |
79 | def features(self, x):
80 | return self.forward(x)
81 |
82 |
83 | class SupRegAlexNet(nn.Module):
84 | """encoder + regressor"""
85 | def __init__(self,):
86 | super().__init__()
87 | self.encoder = AlexNet3D()
88 | self.fc = nn.Linear(128, 1)
89 |
90 | def forward(self, x):
91 | feats = self.features(x)
92 | return self.fc(feats), feats
93 |
94 | def features(self, x):
95 | return self.encoder(x)
--------------------------------------------------------------------------------
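A quick shape check for the regression wrapper; the 121×145×121 grid is only an example size (the cat12vbm MNI grid), and any reasonably large volume works because the encoder ends in `AdaptiveMaxPool3d(1)`. This sketch assumes `src/` is on the import path:

```python
import torch
from models.alexnet3d import SupRegAlexNet

model = SupRegAlexNet().eval()
x = torch.randn(2, 1, 121, 145, 121)  # batch of 2 single-channel volumes
with torch.no_grad():
    pred, feats = model(x)
print(pred.shape)   # torch.Size([2, 1])   -- age predictions
print(feats.shape)  # torch.Size([2, 128]) -- encoder features
```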
/src/models/densenet3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.checkpoint as cp
5 | from collections import OrderedDict
6 |
7 | class DenseNet(nn.Module):
8 |     """3D DenseNet-BC model class, based on
9 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
10 |     Used here as an encoder: forward() returns pooled ``num_features``-d features.
11 |     Args:
12 |         growth_rate (int) - how many filters to add each layer (`k` in paper)
13 |         block_config (list of 4 ints) - how many layers in each pooling block
14 |         num_init_features (int) - the number of filters to learn in the first convolution layer
15 |         in_channels (int) - number of input channels (1 for single-channel MRI volumes)
16 |         bn_size (int) - multiplicative factor for number of bottleneck layers
17 |             (i.e. bn_size * k features in the bottleneck layer)
18 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
19 |             but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
20 |     """
21 |
22 | def __init__(self, growth_rate=32, block_config=(3, 12, 24, 16),
23 | num_init_features=64,
24 | bn_size=4, in_channels=1,
25 | memory_efficient=False):
26 | super(DenseNet, self).__init__()
27 | # First convolution
28 | self.features = nn.Sequential(OrderedDict([
29 | ('conv0', nn.Conv3d(in_channels, num_init_features,
30 | kernel_size=7, stride=2, padding=3, bias=False)),
31 | ('norm0', nn.BatchNorm3d(num_init_features)),
32 | ('relu0', nn.ReLU(inplace=True)),
33 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)),
34 | ]))
35 |
36 | # Each denseblock
37 | num_features = num_init_features
38 | for i, num_layers in enumerate(block_config):
39 | block = _DenseBlock(
40 | num_layers=num_layers,
41 | num_input_features=num_features,
42 | bn_size=bn_size,
43 | growth_rate=growth_rate,
44 | memory_efficient=memory_efficient
45 | )
46 | self.features.add_module('denseblock%d' % (i + 1), block)
47 | num_features = num_features + num_layers * growth_rate
48 | if i != len(block_config) - 1:
49 | trans = _Transition(num_input_features=num_features,
50 | num_output_features=num_features // 2)
51 | self.features.add_module('transition%d' % (i + 1), trans)
52 | num_features = num_features // 2
53 |
54 | self.num_features = num_features
55 |
56 |
57 | # Official init from torch repo.
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv3d):
60 | nn.init.kaiming_normal_(m.weight)
61 | elif isinstance(m, nn.BatchNorm3d):
62 | nn.init.constant_(m.weight, 1)
63 | nn.init.constant_(m.bias, 0)
64 | elif isinstance(m, nn.Linear):
65 | nn.init.constant_(m.bias, 0)
66 |
67 | def forward(self, x):
68 | features = self.features(x)
69 | out = F.adaptive_avg_pool3d(features, 1)
70 | out = torch.flatten(out, 1)
71 | return out.squeeze(dim=1)
72 |
73 |
74 | def _bn_function_factory(norm, relu, conv):
75 | def bn_function(*inputs):
76 | concated_features = torch.cat(inputs, 1)
77 | bottleneck_output = conv(relu(norm(concated_features)))
78 | return bottleneck_output
79 |
80 | return bn_function
81 |
82 |
83 | class _DenseLayer(nn.Sequential):
84 | def __init__(self, num_input_features, growth_rate, bn_size, memory_efficient=False):
85 | super(_DenseLayer, self).__init__()
86 | self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
87 | self.add_module('relu1', nn.ReLU(inplace=True)),
88 | self.add_module('conv1', nn.Conv3d(num_input_features, bn_size *
89 | growth_rate, kernel_size=1, stride=1,
90 | bias=False)),
91 | self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
92 | self.add_module('relu2', nn.ReLU(inplace=True)),
93 | self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
94 | kernel_size=3, stride=1, padding=1,
95 | bias=False)),
96 | self.memory_efficient = memory_efficient
97 |
98 | def forward(self, *prev_features):
99 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
100 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
101 | bottleneck_output = cp.checkpoint(bn_function, *prev_features)
102 | else:
103 | bottleneck_output = bn_function(*prev_features)
104 |
105 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
106 |
107 | return new_features
108 |
109 |
110 | class _DenseBlock(nn.Module):
111 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, memory_efficient=False):
112 | super(_DenseBlock, self).__init__()
113 | for i in range(num_layers):
114 | layer = _DenseLayer(
115 | num_input_features + i * growth_rate,
116 | growth_rate=growth_rate,
117 | bn_size=bn_size,
118 | memory_efficient=memory_efficient,
119 | )
120 | self.add_module('denselayer%d' % (i + 1), layer)
121 |
122 | def forward(self, init_features):
123 | features = [init_features]
124 | for name, layer in self.named_children():
125 | new_features = layer(*features)
126 | features.append(new_features)
127 | return torch.cat(features, 1)
128 |
129 |
130 | class _Transition(nn.Sequential):
131 | def __init__(self, num_input_features, num_output_features):
132 | super(_Transition, self).__init__()
133 | self.add_module('norm', nn.BatchNorm3d(num_input_features))
134 | self.add_module('relu', nn.ReLU(inplace=True))
135 | self.add_module('conv', nn.Conv3d(num_input_features, num_output_features,
136 | kernel_size=1, stride=1, bias=False))
137 | self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
138 |
139 |
140 | def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs):
141 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
142 | return model
143 |
144 |
145 | def densenet121(**kwargs):
146 |     r"""Densenet-121 model from
147 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
148 |
149 |     Keyword args are forwarded to :class:`DenseNet`; no pretrained 3D weights
150 |     are available, so the model is always trained from scratch. Notably:
151 |         in_channels (int) - number of input channels (default: 1)
152 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
153 |             but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
154 |     """
155 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs)
156 |
157 | class SupConDenseNet(nn.Module):
158 | """backbone + projection head"""
159 | def __init__(self, head='mlp', feat_dim=128):
160 | super().__init__()
161 | self.encoder = densenet121()
162 | dim_in = self.encoder.num_features
163 |
164 | if head == 'linear':
165 | self.head = nn.Linear(dim_in, feat_dim)
166 | elif head == 'mlp':
167 | self.head = nn.Sequential(
168 | nn.Linear(dim_in, dim_in),
169 | nn.ReLU(inplace=True),
170 | nn.Linear(dim_in, feat_dim)
171 | )
172 |
173 | else:
174 | raise NotImplementedError(
175 | 'head not supported: {}'.format(head))
176 |
177 | def forward(self, x):
178 | feat = self.encoder(x)
179 | feat = F.normalize(self.head(feat), dim=1)
180 | return feat
181 |
182 | def features(self, x):
183 | return self.forward(x)
184 |
185 |
186 | class SupRegDenseNet(nn.Module):
187 | """encoder + regressor"""
188 | def __init__(self,):
189 | super().__init__()
190 | self.encoder = densenet121()
191 | self.fc = nn.Linear(self.encoder.num_features, 1)
192 |
193 | def forward(self, x):
194 | feats = self.features(x)
195 | return self.fc(feats), feats
196 |
197 | def features(self, x):
198 | return self.encoder(x)
--------------------------------------------------------------------------------
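The `memory_efficient` flag enables gradient checkpointing in `_DenseLayer.forward`: bottleneck activations are recomputed during the backward pass instead of being stored, which helps fit full-resolution 3D volumes at the cost of a slower backward. A minimal sketch:

```python
from models.densenet3d import densenet121

# Checkpointed variant: same outputs, lower activation memory, slower backward.
model = densenet121(memory_efficient=True)
print(model.num_features)  # 1024 for the (6, 12, 24, 16) DenseNet-121 config
```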
/src/models/estimators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import multiprocessing
3 | from sklearn.base import BaseEstimator
4 | from sklearn.linear_model import LogisticRegression, Ridge
5 | from sklearn.model_selection import GridSearchCV
6 | from sklearn.metrics import mean_absolute_error
7 |
8 | class AgeEstimator(BaseEstimator):
9 | """ Define the age estimator on latent space network features.
10 | """
11 | def __init__(self):
12 | n_jobs = multiprocessing.cpu_count()
13 | self.age_estimator = GridSearchCV(
14 | Ridge(), param_grid={"alpha": 10.**np.arange(-2, 3)}, cv=5,
15 | scoring="r2", n_jobs=n_jobs)
16 |
17 | def fit(self, X, y):
18 | self.age_estimator.fit(X, y)
19 | return self.score(X, y)
20 |
21 | def predict(self, X):
22 | y_pred = self.age_estimator.predict(X)
23 | return y_pred
24 |
25 | def score(self, X, y):
26 | y_pred = self.age_estimator.predict(X)
27 | return mean_absolute_error(y, y_pred)
28 |
29 | class SiteEstimator(BaseEstimator):
30 | """ Define the site estimator on latent space network features.
31 | """
32 | def __init__(self):
33 | n_jobs = multiprocessing.cpu_count()
34 | self.site_estimator = GridSearchCV(
35 | LogisticRegression(solver="saga", max_iter=150), cv=5,
36 | param_grid={"C": 10.**np.arange(-2, 3)},
37 | scoring="balanced_accuracy", n_jobs=n_jobs)
38 |
39 | def fit(self, X, y):
40 | self.site_estimator.fit(X, y)
41 | return self.site_estimator.score(X, y)
42 |
43 | def predict(self, X):
44 | return self.site_estimator.predict(X)
45 |
46 | def score(self, X, y):
47 | return self.site_estimator.score(X, y)
--------------------------------------------------------------------------------
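Both estimators follow the scikit-learn pattern: `fit` cross-validates the regularization strength and returns the training score (age MAE, or site balanced accuracy), and `score` evaluates held-out features. A minimal sketch on synthetic data, where only the shapes are meaningful:

```python
import numpy as np
from models.estimators import AgeEstimator, SiteEstimator

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 128))       # stand-in for encoder features
ages = rng.uniform(6, 80, size=200)   # continuous age targets
sites = rng.integers(0, 4, size=200)  # categorical acquisition-site labels

print("train age MAE:", AgeEstimator().fit(X, ages))
print("train site BA:", SiteEstimator().fit(X, sites))
```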
/src/models/resnet3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
7 |     """3x3x3 convolution with padding"""
8 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=dilation, groups=groups, bias=False, dilation=dilation)
10 |
11 | def conv1x1(in_planes, out_planes, stride=1):
12 |     """1x1x1 convolution"""
13 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
19 | base_width=64, dilation=1, norm_layer=None):
20 | super(BasicBlock, self).__init__()
21 | if norm_layer is None:
22 | norm_layer = nn.BatchNorm3d
23 | if groups != 1 or base_width != 64:
24 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
25 | if dilation > 1:
26 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
27 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
28 | self.conv1 = conv3x3(inplanes, planes, stride)
29 | self.bn1 = norm_layer(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = norm_layer(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | identity = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | identity = self.downsample(x)
47 |
48 | out += identity
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
58 | base_width=64, dilation=1, norm_layer=None):
59 | super(Bottleneck, self).__init__()
60 | if norm_layer is None:
61 | norm_layer = nn.BatchNorm3d
62 | width = int(planes * (base_width / 64.)) * groups
63 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
64 | self.conv1 = conv1x1(inplanes, width)
65 | self.bn1 = norm_layer(width)
66 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
67 | self.bn2 = norm_layer(width)
68 | self.conv3 = conv1x1(width, planes * self.expansion)
69 | self.bn3 = norm_layer(planes * self.expansion)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | identity = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | identity = self.downsample(x)
90 |
91 | out += identity
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 | class ResNet(nn.Module):
97 | """
98 | Standard 3D-ResNet architecture with big initial 7x7x7 kernel.
99 |     Used here as an "encoder": forward() outputs a latent vector of size 512
100 |     (BasicBlock variants) or 2048 (Bottleneck variants), independent of input size.
101 |     Task-specific FC heads are added on top by the wrapper classes below.
102 | """
103 | def __init__(self, block, layers, in_channels=1,
104 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
105 | norm_layer=None, initial_kernel_size=7):
106 | super(ResNet, self).__init__()
107 |
108 | if norm_layer is None:
109 | norm_layer = nn.BatchNorm3d
110 | self._norm_layer = norm_layer
111 |
112 | self.name = "resnet"
113 | self.inputs = None
114 | self.inplanes = 64
115 | self.dilation = 1
116 |
117 | if replace_stride_with_dilation is None:
118 | # each element in the tuple indicates if we should replace
119 | # the 2x2 stride with a dilated convolution instead
120 | replace_stride_with_dilation = [False, False, False]
121 | if len(replace_stride_with_dilation) != 3:
122 | raise ValueError("replace_stride_with_dilation should be None "
123 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
124 | self.groups = groups
125 | self.base_width = width_per_group
126 | initial_stride = 2 if initial_kernel_size==7 else 1
127 | padding = (initial_kernel_size-initial_stride+1)//2
128 | self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=initial_kernel_size, stride=initial_stride, padding=padding, bias=False)
129 | self.bn1 = norm_layer(self.inplanes)
130 | self.relu = nn.ReLU(inplace=True)
131 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
132 |
133 | channels = [64, 128, 256, 512]
134 |
135 | self.layer1 = self._make_layer(block, channels[0], layers[0])
136 | self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0])
137 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, dilate=replace_stride_with_dilation[1])
138 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2])
139 | self.avgpool = nn.AdaptiveAvgPool3d(1)
140 |
141 | for m in self.modules():
142 | if isinstance(m, nn.Conv3d):
143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
144 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
145 | nn.init.constant_(m.weight, 1)
146 | nn.init.constant_(m.bias, 0)
147 | elif isinstance(m, nn.Linear):
148 | nn.init.normal_(m.weight, 0, 0.01)
149 | if m.bias is not None:
150 | nn.init.constant_(m.bias, 0)
151 |
152 | # Zero-initialize the last BN in each residual branch,
153 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
154 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
155 | if zero_init_residual:
156 | for m in self.modules():
157 | if isinstance(m, Bottleneck):
158 | nn.init.constant_(m.bn3.weight, 0)
159 | elif isinstance(m, BasicBlock):
160 | nn.init.constant_(m.bn2.weight, 0)
161 |
162 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
163 | norm_layer = self._norm_layer
164 | downsample = None
165 | previous_dilation = self.dilation
166 | if dilate:
167 | self.dilation *= stride
168 | stride = 1
169 | if stride != 1 or self.inplanes != planes * block.expansion:
170 | downsample = nn.Sequential(
171 | conv1x1(self.inplanes, planes * block.expansion, stride),
172 | norm_layer(planes * block.expansion),
173 | )
174 |
175 | layers = []
176 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
177 | base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer))
178 | self.inplanes = planes * block.expansion
179 | for _ in range(1, blocks):
180 | layers.append(block(self.inplanes, planes, groups=self.groups,
181 | base_width=self.base_width, dilation=self.dilation,
182 | norm_layer=norm_layer))
183 |
184 | return nn.Sequential(*layers)
185 |
186 | def forward(self, x):
187 | x = self.conv1(x)
188 | x = self.bn1(x)
189 | x = self.relu(x)
190 | x = self.maxpool(x)
191 |
192 | x1 = self.layer1(x)
193 | x2 = self.layer2(x1)
194 | x3 = self.layer3(x2)
195 | x4 = self.layer4(x3)
196 |
197 | x5 = self.avgpool(x4)
198 | return torch.flatten(x5, 1)
199 |
200 | def resnet18(**kwargs):
201 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
202 |
203 | def resnet34(**kwargs):
204 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
205 |
206 | def resnet50(**kwargs):
207 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
208 |
209 | def resnet101(**kwargs):
210 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
211 |
212 | model_dict = {
213 | 'resnet18': [resnet18, 512],
214 | 'resnet34': [resnet34, 512],
215 | 'resnet50': [resnet50, 2048],
216 | 'resnet101': [resnet101, 2048],
217 | }
218 |
219 | class SupConResNet(nn.Module):
220 | """backbone + projection head"""
221 | def __init__(self, name='resnet50', head='mlp', feat_dim=128):
222 | super().__init__()
223 | model_fun, dim_in = model_dict[name]
224 | self.encoder = model_fun()
225 | if head == 'linear':
226 | self.head = nn.Linear(dim_in, feat_dim)
227 | elif head == 'mlp':
228 | self.head = nn.Sequential(
229 | nn.Linear(dim_in, dim_in),
230 | nn.ReLU(inplace=True),
231 | nn.Linear(dim_in, feat_dim)
232 | )
233 | else:
234 | raise NotImplementedError(
235 | 'head not supported: {}'.format(head))
236 |
237 | def forward(self, x):
238 | feat = self.encoder(x)
239 | feat = F.normalize(self.head(feat), dim=1)
240 | return feat
241 |
242 | def features(self, x):
243 | return self.forward(x)
244 |
245 |
246 | class SupRegResNet(nn.Module):
247 | """encoder + regressor"""
248 | def __init__(self, name='resnet50'):
249 | super().__init__()
250 | model_fun, dim_in = model_dict[name]
251 | self.encoder = model_fun()
252 | self.fc = nn.Linear(dim_in, 1)
253 |
254 | def forward(self, x):
255 | feats = self.features(x)
256 | return self.fc(feats), feats
257 |
258 | def features(self, x):
259 | return self.encoder(x)
260 |
261 | class SupCEResNet(nn.Module):
262 | """encoder + classifier"""
263 | def __init__(self, n_classes, name='resnet50'):
264 | super().__init__()
265 | model_fun, dim_in = model_dict[name]
266 | self.encoder = model_fun()
267 | self.fc = nn.Linear(dim_in, n_classes)
268 |
269 | def forward(self, x):
270 | return self.fc(self.encoder(x))
271 |
272 | def features(self, x):
273 | return self.encoder(x)
274 |
275 |
276 | class LinearRegressor(nn.Module):
277 | """Linear regressor"""
278 | def __init__(self, name='resnet50'):
279 | super().__init__()
280 | _, feat_dim = model_dict[name]
281 | self.fc = nn.Linear(feat_dim, 1)
282 |
283 | def forward(self, features):
284 | return self.fc(features)
--------------------------------------------------------------------------------
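The two wrappers share a backbone from `model_dict` but expose different outputs: `SupConResNet` returns L2-normalized 128-d embeddings for the contrastive loss, while `SupRegResNet` returns an age prediction together with the raw encoder features. A shape sketch (the 32³ dummy volume is arbitrary):

```python
import torch
from models.resnet3d import SupConResNet, SupRegResNet

x = torch.randn(2, 1, 32, 32, 32)  # tiny dummy volume, shape check only

z = SupConResNet(name='resnet18')(x)
print(z.shape, z.norm(dim=1))       # torch.Size([2, 128]), norms ~1.0

pred, feats = SupRegResNet(name='resnet18')(x)
print(pred.shape, feats.shape)      # torch.Size([2, 1]), torch.Size([2, 512])
```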
/src/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import random
4 | import numpy as np
5 | import os
6 | import wandb
7 | import torch.nn.functional as F
8 | import models
9 | from pathlib import Path
10 |
11 |
12 | class NViewTransform:
13 | """Create N augmented views of the same image"""
14 | def __init__(self, transform, N):
15 | self.transform = transform
16 | self.N = N
17 |
18 | def __call__(self, x):
19 | return [self.transform(x) for _ in range(self.N)]
20 |
21 | def arg2bool(val):
22 | if isinstance(val, bool):
23 | return val
24 |
25 | elif isinstance(val, str):
26 | if val == "true":
27 | return True
28 |
29 | if val == "false":
30 | return False
31 |
32 | val = int(val)
33 | assert val == 0 or val == 1
34 | return val == 1
35 |
36 | class AverageMeter(object):
37 | """Computes and stores the average and current value"""
38 | def __init__(self):
39 | self.reset()
40 |
41 | def reset(self):
42 | self.val = 0
43 | self.avg = 0
44 | self.sum = 0
45 | self.count = 0
46 |
47 | def update(self, val, n=1):
48 | self.val = val
49 | self.sum += val * n
50 | self.count += n
51 | self.avg = self.sum / self.count
52 |
53 | class MAE():
54 | def __init__(self):
55 | self.reset()
56 |
57 | def reset(self):
58 | self.outputs = []
59 | self.targets = []
60 | self.avg = np.inf
61 |
62 | def update(self, outputs, targets):
63 | self.outputs.append(outputs.detach())
64 | self.targets.append(targets.detach())
65 | self.avg = F.l1_loss(torch.cat(self.outputs, 0), torch.cat(self.targets, 0))
66 |
67 | class Accuracy():
68 | def __init__(self, topk=(1,)):
69 | self.reset()
70 | self.topk = topk
71 |
72 | def reset(self):
73 | self.outputs = []
74 | self.targets = []
75 | self.avg = np.inf
76 |
77 | def update(self, outputs, targets):
78 | self.outputs.append(outputs.detach())
79 | self.targets.append(targets.detach())
80 | self.avg = accuracy(torch.cat(self.outputs, 0), torch.cat(self.targets, 0), self.topk)
81 |
82 | def ensure_dir(dirname):
83 | dirname = Path(dirname)
84 | if not dirname.is_dir():
85 | dirname.mkdir(parents=True, exist_ok=True)
86 |
87 | def accuracy(output, target, topk=(1,)):
88 | """Computes the accuracy over the k top predictions for the specified values of k"""
89 | with torch.no_grad():
90 | maxk = max(topk)
91 | batch_size = target.size(0)
92 |
93 | _, pred = output.topk(maxk, 1, True, True)
94 | pred = pred.t()
95 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
96 |
97 | res = []
98 | for k in topk:
99 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
100 | res.append(correct_k.mul_(100.0 / batch_size).item())
101 | return res
102 |
103 | def set_seed(seed):
104 | random.seed(seed)
105 | os.environ["PYTHONHASHSEED"] = str(seed)
106 | np.random.seed(seed)
107 | torch.cuda.manual_seed(seed)
108 | torch.cuda.manual_seed_all(seed)
109 |     torch.backends.cudnn.deterministic = False  # trade strict reproducibility for cuDNN speed
110 | torch.backends.cudnn.benchmark = True
111 | torch.manual_seed(seed)
112 |
113 | def save_model(model, optimizer, opt, epoch, save_file):
114 | print('==> Saving...')
115 | state_dict = model.state_dict()
116 | if torch.cuda.device_count() > 1:
117 | state_dict = model.module.state_dict()
118 |
119 | state = {
120 | 'opts': opt,
121 | 'model': state_dict,
122 | 'optimizer': optimizer.state_dict(),
123 | 'epoch': epoch,
124 | 'run_id': wandb.run.id
125 | }
126 | torch.save(state, save_file)
127 | del state
128 |
129 | def adjust_learning_rate(args, optimizer, epoch):
130 | lr = args.lr
131 | if args.lr_decay == 'cosine':
132 | eta_min = lr * (args.lr_decay_rate ** 3)
133 | lr = eta_min + (lr - eta_min) * (
134 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
135 | else:
136 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
137 | if steps > 0:
138 | lr = lr * (args.lr_decay_rate ** steps)
139 |
140 | for param_group in optimizer.param_groups:
141 | param_group['lr'] = lr
142 |
143 |
144 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
145 | if args.warm and epoch <= args.warm_epochs:
146 | p = (batch_id + (epoch - 1) * total_batches) / \
147 | (args.warm_epochs * total_batches)
148 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
149 |
150 | for param_group in optimizer.param_groups:
151 | param_group['lr'] = lr
152 |
153 | @torch.no_grad()
154 | def gather_age_feats(model, dataloader, opts):
155 | features = []
156 | age_labels = []
157 |
158 | model.eval()
159 | for idx, (images, labels, _) in enumerate(dataloader):
160 | if isinstance(images, list):
161 | images = images[0]
162 | images = images.to(opts.device)
163 | features.append(model.features(images))
164 | age_labels.append(labels)
165 |
166 | return torch.cat(features, 0).cpu().numpy(), torch.cat(age_labels, 0).cpu().numpy()
167 |
168 | @torch.no_grad()
169 | def compute_age_mae(model, train_loader, test_int, test_ext, opts):
170 |     age_estimator = models.AgeEstimator()
171 |
172 |     print("Training age estimator")
173 |     train_X, train_y = gather_age_feats(model, train_loader, opts)
174 |     mae_train = age_estimator.fit(train_X, train_y)
175 |
176 |     print("Computing age MAE")
177 |     int_X, int_y = gather_age_feats(model, test_int, opts)
178 |     ext_X, ext_y = gather_age_feats(model, test_ext, opts)
179 |     mae_int = age_estimator.score(int_X, int_y)
180 |     mae_ext = age_estimator.score(ext_X, ext_y)
181 |
182 | return mae_train, mae_int, mae_ext
183 |
184 | @torch.no_grad()
185 | def gather_site_feats(model, dataloader, opts):
186 | features = []
187 | site_labels = []
188 |
189 | model.eval()
190 | for idx, (images, _, sites) in enumerate(dataloader):
191 | if isinstance(images, list):
192 | images = images[0]
193 | images = images.to(opts.device)
194 | features.append(model.features(images))
195 | site_labels.append(sites)
196 |
197 | return torch.cat(features, 0).cpu().numpy(), torch.cat(site_labels, 0).cpu().numpy()
198 |
199 | @torch.no_grad()
200 | def compute_site_ba(model, train_loader, test_int, test_ext, opts):
201 | site_estimator = models.SiteEstimator()
202 |
203 | print("Training site estimator")
204 | train_X, train_y = gather_site_feats(model, train_loader, opts)
205 | ba_train = site_estimator.fit(train_X, train_y)
206 |
207 | print("Computing BA")
208 | int_X, int_y = gather_site_feats(model, test_int, opts)
209 | ext_X, ext_y = gather_site_feats(model, test_ext, opts)
210 | ba_int = site_estimator.score(int_X, int_y)
211 | ba_ext = site_estimator.score(ext_X, ext_y)
212 |
213 | return ba_train, ba_int, ba_ext
--------------------------------------------------------------------------------
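`NViewTransform` at the top of this file is the glue between augmentation and the contrastive batch layout: it emits N independently augmented views per sample, which are then concatenated along the batch dimension. A small sketch with a stand-in augmentation:

```python
import torch
from util import NViewTransform

aug = lambda x: x + 0.1 * torch.randn_like(x)  # stand-in augmentation
two_views = NViewTransform(aug, N=2)

x = torch.randn(1, 8, 8, 8)
views = two_views(x)             # list of 2 independently augmented views
batch = torch.cat(views, dim=0)  # contrastive batch layout
print(len(views), batch.shape)   # 2, torch.Size([2, 8, 8, 8])
```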