├── data
│   ├── __init__.py
│   ├── data_rfmid.py
│   ├── data_palm.py
│   ├── data_aptos.py
│   └── data_ukb.py
├── models
│   ├── __init__.py
│   ├── genetics_model.py
│   └── resnet_unet.py
├── util
│   ├── __init__.py
│   ├── lr_sched.py
│   ├── crop.py
│   ├── lars.py
│   ├── datasets.py
│   ├── lr_decay.py
│   ├── pos_embed.py
│   └── misc.py
├── image_preprocessing
│   ├── __init__.py
│   ├── filtering_images.py
│   └── resize.py
├── README.md
├── LICENSE
├── main_pretrain.py
└── models_mrm.py

/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/image_preprocessing/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MRM
[ICCV 2023] MRM: Masked Relation Modeling for Medical Image Pre-Training with Genetics
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        # linear warmup from 0 up to args.lr
        lr = args.lr * epoch / args.warmup_epochs
    else:
        # half-cycle cosine decay from args.lr down to args.min_lr
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:  # set per group by layer-wise lr decay (util/lr_decay.py)
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
--------------------------------------------------------------------------------
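To see the schedule this produces, here is a small self-contained sketch; the `Namespace` values below are hypothetical placeholders, not the repo's defaults:

# sketch: evaluate the warmup + half-cycle cosine schedule in isolation
from argparse import Namespace

import torch

from util.lr_sched import adjust_learning_rate

args = Namespace(lr=1.5e-4, min_lr=0.0, warmup_epochs=40, epochs=400)
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=args.lr)
for epoch in [0, 20, 40, 220, 399]:
    lr = adjust_learning_rate(opt, epoch, args)
    print(f"epoch {epoch:3d}: lr = {lr:.2e}")
# ramps linearly from 0 to the peak lr at epoch 40, then decays toward min_lr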
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 CityU-AIM-Group

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/models/genetics_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class DietNetworkBasic(nn.Module):
    def __init__(
        self,
        n_feats,  # this can change depending on what kind of feature dimension reduction was done
        n_hidden1_u,
        n_hidden2_u=None,
        eps=1e-05,
    ):
        super(DietNetworkBasic, self).__init__()

        # 1st hidden layer
        self.hidden_1 = nn.Linear(n_feats, n_hidden1_u)
        self.bn1 = nn.BatchNorm1d(num_features=n_hidden1_u, eps=eps)

        # 2nd hidden layer
        self.hidden_2 = None
        if n_hidden2_u is not None:
            self.hidden_2 = nn.Linear(n_hidden1_u, n_hidden2_u)
            self.bn2 = nn.BatchNorm1d(num_features=n_hidden2_u, eps=eps)

    def forward(self, x):
        z1 = self.hidden_1(x)
        a1 = torch.relu(z1)
        a1 = self.bn1(a1)
        out = a1

        if self.hidden_2 is not None:
            z2 = self.hidden_2(a1)
            a2 = torch.relu(z2)
            a2 = self.bn2(a2)
            out = a2

        return out
--------------------------------------------------------------------------------
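A minimal smoke test for `DietNetworkBasic`; the feature and hidden dimensions below are hypothetical placeholders (e.g. an already dimension-reduced genetic feature vector):

import torch

from models.genetics_model import DietNetworkBasic

# hypothetical sizes: 512 reduced genetic features -> 2048 -> 512-dim embedding
net = DietNetworkBasic(n_feats=512, n_hidden1_u=2048, n_hidden2_u=512)
net.eval()  # BatchNorm1d requires batch size > 1 in training mode
x = torch.randn(4, 512)
print(net(x).shape)  # torch.Size([4, 512])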
/image_preprocessing/filtering_images.py:
--------------------------------------------------------------------------------
"""basic quality control for retinal fundus images

`drop_qc_paths` filters out the top_p*100% brightest and bot_p*100% darkest images
"""
from glob import glob
from os.path import join

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

# TODO: fill in your paths
BASE_IMG = "PATH/TO/IMAGES/"
LEFT = join(BASE_IMG, "left/512_left/processed")
RIGHT = join(BASE_IMG, "right/512_right/processed")
IMG_EXT = ".jpg"


def drop_qc_paths(
    fn=join(BASE_IMG, "{eye}", "qc_paths_{eye}.txt"),
    top_p=0.005,
    bot_p=0.005,
):
    for base, eye in [(LEFT, "left"), (RIGHT, "right")]:
        paths = get_paths(base, subset=None)
        B = compute_brightness(paths)
        ind = np.argsort(B)
        N = len(B)
        B = B[ind]
        paths = np.array(paths)[ind]
        # keep only paths between the bot_p and (1 - top_p) brightness quantiles
        paths = paths[int(bot_p * N) : int((1 - top_p) * N)]
        pd.DataFrame([p.split("/")[-1] for p in paths]).to_csv(
            fn.format(eye=eye), index=False, header=False
        )


def get_paths(base, subset=None):
    return glob(join(base, "*" + IMG_EXT))[:subset]


def compute_brightness(paths):
    brightnesses = []
    for p in tqdm(paths):
        brightness = np.array(Image.open(p)).mean()
        brightnesses.append(brightness)
    return np.array(brightnesses)
--------------------------------------------------------------------------------
/util/crop.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different from torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        # note: `_get_image_size` is a private API of older torchvision releases;
        # it was renamed to `F.get_image_size` in torchvision >= 0.10
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
--------------------------------------------------------------------------------
/util/lars.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
--------------------------------------------------------------------------------
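A sketch of how this optimizer is typically wired up (e.g. large-batch linear probing); the model and learning rate are placeholders:

import torch
import torchvision

from util.lars import LARS

model = torchvision.models.resnet50()
# rate scaling and weight decay only touch parameters with ndim > 1, so
# biases and norm parameters can safely share the same param group
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=0.0, momentum=0.9)
model(torch.randn(2, 3, 224, 224)).sum().backward()
optimizer.step()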
/util/datasets.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
--------------------------------------------------------------------------------
/util/lr_decay.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import json


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}

    num_layers = len(model.blocks) + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    else:
        return num_layers
--------------------------------------------------------------------------------
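A sketch of feeding these groups to an optimizer for fine-tuning; it assumes a timm ViT exposing `blocks` and `no_weight_decay()` (true for the ViT backbones used here), with placeholder hyperparameters:

import timm
import torch

from util.lr_decay import param_groups_lrd

model = timm.create_model('vit_base_patch16_224')
param_groups = param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),
    layer_decay=0.75,
)
# each group carries an "lr_scale"; adjust_learning_rate (util/lr_sched.py)
# multiplies the scheduled lr by it, so deeper blocks get larger lrs
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)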
/image_preprocessing/resize.py:
--------------------------------------------------------------------------------
import argparse
import os
import warnings

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
from os.path import join
from glob import glob
from joblib import Parallel, delayed

from tqdm import tqdm

import PIL.Image
import cv2

cv2.setNumThreads(0)
from skimage.draw import circle_perimeter


warnings.filterwarnings("ignore", category=UserWarning)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("inp", type=str, help="path to input directory")
    parser.add_argument("out", type=str, help="path for the resized images")
    parser.add_argument("--ext", type=str, default=".jpg", help="output extension")
    parser.add_argument(
        "--num_workers", type=int, default=10, help="number of parallel workers"
    )
    parser.add_argument("--size", type=int, default=672, help="resize to which size")
    args = parser.parse_args()

    resize_all_imgs(
        args.inp, args.out, args.size, extension=args.ext, n_jobs=args.num_workers
    )


def detect_circle(p, buf=5):
    img = cv2.imread(p)
    h, w = img.shape[:2]
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.medianBlur(img, 25)
    circles = cv2.HoughCircles(
        img,
        cv2.HOUGH_GRADIENT,
        1,
        minDist=50,
        minRadius=min(h, w) // 4,
    )
    if circles is None:
        return
    else:
        C = circles[0, 0].round().astype(int)
        cc, rr = circle_perimeter(C[0], C[1], C[2])
        # (left, upper, right, lower) bounding box for PIL.Image.crop
        return [cc.min() - buf, rr.min() - buf, cc.max() + buf, rr.max() + buf]


def resize_all_imgs(input_dir, output_dir, max_size, extension=".png", n_jobs=10):
    img_files = glob(join(input_dir, "*"))
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs(join(output_dir, "failed"), exist_ok=True)
    os.makedirs(join(output_dir, "processed"), exist_ok=True)
    Parallel(n_jobs=n_jobs)(
        delayed(resize_one)(fn, max_size, output_dir, extension=extension)
        for fn in tqdm(img_files)
    )


def resize_one(fn, max_size, output_dir, buffer=5, extension=".jpg"):
    try:
        img = PIL.Image.open(fn)
        box = detect_circle(fn, buf=buffer)
        fnn = fn.split("/")[-1].split(".")[0] + extension
        if box is None:
            # no fundus circle detected: keep the image aside for manual inspection
            img.save(join(output_dir, "failed", fnn))
        else:
            img.crop(box).resize((max_size, max_size)).save(
                join(output_dir, "processed", fnn)
            )
    except Exception as e:
        print(e)
        print("cannot handle file", fn)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
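The script is invoked from the command line; for example (paths hypothetical):

    python image_preprocessing/resize.py /path/to/raw_fundus /path/to/resized --size 672 --ext .jpg --num_workers 10

Cropping to the detected fundus circle before resizing keeps the retina centered and square prior to the 672x672 rescale.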
/util/pos_embed.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)  # np.float was removed in NumPy 1.24
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
--------------------------------------------------------------------------------
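For reference, this is how models_mrm.py consumes the helper; a standalone sketch for a ViT-B/16 at 224x224 (a 14x14 patch grid):

import torch

from util.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)  # (197, 768): one all-zero row for the cls token + 14*14 patch positions
pos_embed = torch.from_numpy(pos).float().unsqueeze(0)  # (1, 197, 768), kept frozen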
/data/data_rfmid.py:
--------------------------------------------------------------------------------
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

from os.path import join
from functools import partial
from joblib import Parallel, delayed

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm
import toml


torch.multiprocessing.set_sharing_strategy("file_system")

BASE = toml.load(join(os.path.dirname(os.path.realpath(__file__)), "../paths.toml"))[
    "RFMID_PATH"
]


def get_tfms(size=256, augmentation=False):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    if augmentation:
        tfms = transforms.Compose(
            [
                transforms.Resize(size=size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
    else:
        tfms = transforms.Compose(
            [
                transforms.Resize(size=size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
    return tfms


def get_rfmid_loaders(
    size,
    batch_size=64,
    num_workers=8,
):
    """return train/valid/test dataloaders for the RFMiD dataset"""
    loaders = []
    for split in ["train", "valid", "test"]:
        tfms = get_tfms(size=size, augmentation=split == "train")
        D = RFMiD(
            split=split, tfms=tfms, drop_disease_risk_hr_odpm=True, use_cropped=True
        )
        loader = DataLoader(
            D,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=num_workers,
            pin_memory=True,
        )
        loaders.append(loader)

    return loaders


class RFMiD(Dataset):
    def __init__(
        self, split="train", tfms=None, drop_disease_risk_hr_odpm=True, use_cropped=True
    ):
        if split == "train":
            subdir = "Training_Set"
            label_fn = "RFMiD_Training_Labels.csv"
            img_subdir = "Training"
        elif split == "val" or split == "valid":
            subdir = "Evaluation_Set"
            label_fn = "RFMiD_Validation_Labels.csv"
            img_subdir = "Validation"
        elif split == "test":
            subdir = "Test_Set"
            label_fn = "RFMiD_Testing_Labels.csv"
            img_subdir = "Test"
        else:
            raise ValueError(f"split {split} not valid")
        if use_cropped:
            img_subdir = img_subdir + "_cropped"
        label_pth = join(BASE, subdir, subdir, label_fn)
        self.labels = pd.read_csv(label_pth, index_col=0)
        self.ext = ".png"
        if drop_disease_risk_hr_odpm:
            # the positional `axis` argument of DataFrame.drop was removed in pandas 2.0
            self.labels = self.labels.drop(columns=["Disease_Risk", "HR", "ODPM"])
            # the standardized images written by crop_resize_all below are stored as .jpg
            self.ext = ".jpg"
        self.img_dir = join(BASE, subdir, subdir, img_subdir)
        self.tfms = tfms

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if isinstance(idx, torch.Tensor):
            idx = idx.item()
        inst = self.labels.iloc[idx]
        labels = inst.values
        id = inst.name
        p = join(self.img_dir, str(id) + self.ext)
        img = Image.open(p)
        if self.tfms:
            img = self.tfms(img)
        return img, np.array(labels, dtype=float)  # np.float was removed in NumPy 1.24


#### data standardization utils
def crop_resize_all(split, dst, size=512, buffer=10, n_jobs=10):
    """prepare RFMiD images by center-crop-padding"""
    os.makedirs(dst, exist_ok=True)
    tfms = transforms.Compose(
        [
            partial(center_crop_pad, buffer=buffer),
            transforms.Resize(size),
        ]
    )
    D = RFMiD(split=split, tfms=tfms)
    Parallel(n_jobs=n_jobs)(
        delayed(lambda i: D[i][0].save(join(dst, f"{i+1}.jpg")))(i)
        for i in tqdm(range(len(D)))
    )


def center_crop_pad(img, buffer=0, min_mean=10):
    """dynamically center crop image, cropping away black space left and right"""
    g = np.array(img).mean(-1)
    h, w = g.shape
    zeros = g.mean(0)
    zero_inds = np.where(zeros < min_mean)[0]
    lo, hi = zero_inds[zero_inds < w // 2].max(), zero_inds[zero_inds > w // 2].min()
    return expand2square(img.crop((lo - buffer, 0, hi + buffer, h)))


def expand2square(pil_img, background_color=0):
    """from https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/"""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
--------------------------------------------------------------------------------
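A usage sketch (assumes `paths.toml` defines `RFMID_PATH` and that the `*_cropped` folders were produced by `crop_resize_all`; the image size and batch size are placeholders):

from data.data_rfmid import get_rfmid_loaders

train_loader, valid_loader, test_loader = get_rfmid_loaders(size=448, batch_size=16)
imgs, labels = next(iter(train_loader))
# imgs: (16, 3, 448, 448); labels: multi-hot disease matrix of shape (16, n_labels)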
--------------------------------------------------------------------------------
/data/data_palm.py:
--------------------------------------------------------------------------------
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

from os.path import join
from glob import glob

import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import toml


torch.multiprocessing.set_sharing_strategy("file_system")

BASE = toml.load(join(os.path.dirname(os.path.realpath(__file__)), "../paths.toml"))[
    "PALM_PATH"
]


def get_tfms(size=256):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    tfms = transforms.Compose(
        [
            transforms.Resize(size=(size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    mask_tfms = transforms.Compose(
        [
            transforms.Resize(
                size=(size, size),
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.ToTensor(),
        ]
    )

    return tfms, mask_tfms


def get_palm_loaders(
    size, batch_size=64, num_workers=8, joint_mask=True, train_pct=0.6, val_pct=0.2
):
    """return train/valid/test dataloaders for the PALM dataset (images plus segmentation masks)"""
    loaders = []
    tfms, mask_tfms = get_tfms(size=size)
    for split in ["train", "valid", "test"]:
        D = PALM(
            split=split,
            tfms=tfms,
            mask_tfms=mask_tfms,
            joint_mask=joint_mask,
            train_pct=train_pct,
            val_pct=val_pct,
        )
        loader = DataLoader(
            D,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=num_workers,
            pin_memory=True,
        )
        loaders.append(loader)

    return loaders


class PALM(Dataset):
    def __init__(
        self,
        split="train",
        tfms=None,
        mask_tfms=None,
        joint_mask=True,
        split_seed=42,
        train_pct=0.6,
        val_pct=0.2,
    ):
        img_subdir = "PALM-Training400"
        disk_subdir = "Disc_Masks"
        atrophy_subdir = "Lesion_Masks/Atrophy/"
        detachment_subdir = "Lesion_Masks/Detachment/"

        img_paths = glob(join(BASE, img_subdir, "*.jpg"))
        ids = [x.split("/")[-1].split(".")[0] for x in img_paths]

        disks = [
            x if os.path.isfile(x) else ""
            for x in [join(BASE, disk_subdir, f"{id}.bmp") for id in ids]
        ]
        atrophies = [
            x if os.path.isfile(x) else ""
            for x in [join(BASE, atrophy_subdir, f"{id}.bmp") for id in ids]
        ]
        detachments = [
            x if os.path.isfile(x) else ""
            for x in [join(BASE, detachment_subdir, f"{id}.bmp") for id in ids]
        ]
        # binary class label derived from the filename prefix
        classes = [1 * (x[0] == "H") for x in ids]

        df = pd.DataFrame(
            {
                "img_path": img_paths,
                "disk": disks,
                "atrophy": atrophies,
                "detachment": detachments,
                "classes": classes,
            }
        )

        # stratified train vs. (val + test) split, then val vs. test
        sss = StratifiedShuffleSplit(
            n_splits=1, test_size=1 - train_pct, random_state=split_seed
        )
        [(train_inds, val_test_inds)] = sss.split(df.index, df.classes)
        sss = StratifiedShuffleSplit(
            n_splits=1,
            test_size=1 - val_pct / (1 - train_pct),
            random_state=split_seed + 1,
        )
        [(val_inds, test_inds)] = sss.split(
            df.loc[val_test_inds], df.loc[val_test_inds].classes
        )
        val_inds = val_test_inds[val_inds]
        test_inds = val_test_inds[test_inds]

        if split == "train":
            inds = train_inds
        elif split in ["val", "valid"]:
            inds = val_inds
        elif split == "test":
            inds = test_inds
        else:
            raise ValueError(split)

        self.df = df.loc[inds]

        self.tfms = tfms
        self.joint_mask = joint_mask
        self.mask_tfms = mask_tfms

    def __len__(self):
        return len(self.df)

    def load_default(self, p, def_size=(256, 256)):
        # missing lesion masks default to an all-white (i.e. empty) mask
        if os.path.isfile(p):
            img = Image.open(p)
        else:
            img = Image.new("L", size=def_size, color=255)
        return img

    def __getitem__(self, idx):
        if isinstance(idx, torch.Tensor):
            idx = idx.item()
        inst = self.df.iloc[idx]
        img = Image.open(inst.img_path)
        disk = Image.open(inst.disk)
        atrophy = self.load_default(inst.atrophy, def_size=img.size)
        if self.tfms:
            img = self.tfms(img)
        if self.mask_tfms:
            # the mask bitmaps store the foreground as 0 (black), so invert after scaling to [0, 1]
            disk = 1 - self.mask_tfms(disk)
            atrophy = 1 - self.mask_tfms(atrophy)
        if self.joint_mask:
            mask = torch.cat([disk, atrophy])
            return img, mask
        else:
            return img, disk, atrophy
--------------------------------------------------------------------------------
/models/resnet_unet.py:
--------------------------------------------------------------------------------
"""
https://github.com/kevinlu1211/pytorch-unet-resnet-50-encoder
"""
import torch
import torch.nn as nn


# resnet = torchvision.models.resnet.resnet50(pretrained=True)


class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv -> BN -> ReLU
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        padding=1,
        kernel_size=3,
        stride=1,
        with_nonlinearity=True,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            padding=padding,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    This is the middle layer of the UNet which just consists of two ConvBlocks
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels), ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        up_conv_in_channels=None,
        up_conv_out_channels=None,
        upsampling_method="conv_transpose",
    ):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(
                up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2
            )
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    DEPTH = 6

    def __init__(
        self,
        resnet,
        n_classes=2,
    ):
        super().__init__()
        down_blocks = []
        up_blocks = []
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(
            UpBlockForUNetWithResNet50(
                in_channels=128 + 64,
                out_channels=128,
                up_conv_in_channels=256,
                up_conv_out_channels=128,
            )
        )
        up_blocks.append(
            UpBlockForUNetWithResNet50(
                in_channels=64 + 3,
                out_channels=64,
                up_conv_in_channels=128,
                up_conv_out_channels=64,
            )
        )

        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        pre_pools = dict()
        pre_pools["layer_0"] = x
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)
        del pre_pools
        if with_output_feature_map:
            return x, output_feature_map
        else:
            return x


# model = UNetWithResnet50Encoder().cuda()
# inp = torch.rand((2, 3, 512, 512)).cuda()
# out = model(inp)
--------------------------------------------------------------------------------
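Expanding the commented-out snippet above into a runnable sketch (the commented call predates the explicit `resnet` constructor argument); `n_classes=2` matches the joint disc + atrophy masks from data_palm.py:

import torch
import torchvision

from models.resnet_unet import UNetWithResnet50Encoder

resnet = torchvision.models.resnet50()  # the header comment suggests pretrained=True
model = UNetWithResnet50Encoder(resnet, n_classes=2)
inp = torch.rand((2, 3, 512, 512))
out = model(inp)  # (2, 2, 512, 512) logits, one channel per mask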
/data/data_aptos.py:
--------------------------------------------------------------------------------
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

from os.path import join
from functools import partial

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm
import toml


torch.multiprocessing.set_sharing_strategy("file_system")

BASE = toml.load(join(os.path.dirname(os.path.realpath(__file__)), "../paths.toml"))[
    "APTOS_PATH"
]


def get_tfms(size=256, augmentation=False):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    if augmentation:
        tfms = transforms.Compose(
            [
                transforms.Resize(size=size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
    else:
        tfms = transforms.Compose(
            [
                transforms.Resize(size=size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
    return tfms


def get_aptos_loaders(
    size,
    batch_size=64,
    num_workers=8,
    multilabel=True,
    train_pct=0.6,
    val_pct=0.2,
):
    """return train/valid/test dataloaders for the APTOS dataset"""
    loaders = []
    for split in ["train", "valid", "test"]:
        tfms = get_tfms(size=size, augmentation=split == "train")
        D = APTOS(
            split=split,
            tfms=tfms,
            multilabel=multilabel,
            train_pct=train_pct,
            val_pct=val_pct,
        )
        loader = DataLoader(
            D,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=num_workers,
            pin_memory=True,
        )
        loaders.append(loader)

    return loaders


class APTOS(Dataset):
    def __init__(
        self,
        split="train",
        tfms=None,
        split_seed=42,
        train_pct=0.6,
        val_pct=0.2,
        use_cropped=True,
        multilabel=True,
    ):
        label_fn = "train.csv"
        img_subdir = "train_images"
        if use_cropped:
            img_subdir = img_subdir + "_cropped"

        label_pth = join(BASE, label_fn)
        self.labels = pd.read_csv(label_pth)

        # data split
        rng = np.random.RandomState(split_seed)
        N = len(self.labels)
        perm = rng.permutation(N)
        m = int(N * train_pct)
        mv = int(N * (train_pct + val_pct))
        if split == "train":
            self.labels = self.labels.iloc[perm[:m]]
        elif split in ["val", "valid"]:
            self.labels = self.labels.iloc[perm[m:mv]]
        elif split == "test":
            self.labels = self.labels.iloc[perm[mv:]]
        else:
            raise ValueError(f"split {split} not a valid option")
        self.ext = ".png"
        self.img_dir = join(BASE, img_subdir)
        self.tfms = tfms
        self.multilabel = multilabel

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if isinstance(idx, torch.Tensor):
            idx = idx.item()
        inst = self.labels.iloc[idx]
        label = inst.diagnosis
        if self.multilabel:
            # cumulative (ordinal) encoding of the DR grade: grade k maps to k+1 leading ones
            if label == 0:
                label = [1, 0, 0, 0, 0]
            elif label == 1:
                label = [1, 1, 0, 0, 0]
            elif label == 2:
                label = [1, 1, 1, 0, 0]
            elif label == 3:
                label = [1, 1, 1, 1, 0]
            elif label == 4:
                label = [1, 1, 1, 1, 1]
            label = np.array(label)
        id = inst.id_code
        p = join(self.img_dir, str(id) + self.ext)
        img = Image.open(p)
        if self.tfms:
            img = self.tfms(img)
        return img, label.astype(float)  # np.float was removed in NumPy 1.24


#### data standardization utils
def crop_resize_all(split, dst, size=512, buffer=10, n_jobs=10):
    """prepare APTOS images by disc-cropping or center-crop-padding"""
    os.makedirs(dst, exist_ok=True)
    tfms = transforms.Compose(
        [
            partial(center_crop_pad, buffer=buffer),
            transforms.Resize(size),
        ]
    )
    label_pth = join(BASE, "test.csv" if split == "test" else "train.csv")
    img_subdir = "test_images" if split == "test" else "train_images"
    paths = [
        join(BASE, img_subdir, f"{id}.png") for id in pd.read_csv(label_pth).id_code
    ]
    out_paths = [join(dst, p.split("/")[-1]) for p in paths]

    def process(p):
        # prefer the Hough-detected fundus circle; fall back to center-crop-padding
        box = detect_circle(p)
        img = Image.open(p)
        if box is not None:
            img = img.crop(box)
        else:
            img = center_crop_pad(img)
        return img.resize((size, size))

    for i in tqdm(range(len(paths))):
        img = process(paths[i])
        img.save(out_paths[i])


def center_crop_pad(img, buffer=0, min_mean=10):
    """dynamically center crop image, cropping away black space left and right"""
    g = np.array(img).mean(-1)
    h, w = g.shape
    zeros = g.mean(0)
    zero_inds = np.where(zeros < min_mean)[0]
    if len(zero_inds) == 0 or zero_inds.min() > w // 2 or zero_inds.max() < w // 2:
        return expand2square(img)
    lo, hi = zero_inds[zero_inds < w // 2].max(), zero_inds[zero_inds > w // 2].min()
    return expand2square(img.crop((lo - buffer, 0, hi + buffer, h)))


def detect_circle(p, buf=5):
    # only for preprocessing, so no need to install otherwise
    import cv2
    from skimage.draw import circle_perimeter

    img = cv2.imread(p)
    h, w = img.shape[:2]
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.medianBlur(img, 25)
    circles = cv2.HoughCircles(
        img,
        cv2.HOUGH_GRADIENT,
        1,
        minDist=50,
        minRadius=min(h, w) // 4,
    )
    if circles is None:
        return
    else:
        C = circles[0, 0].round().astype(int)
        cc, rr = circle_perimeter(C[0], C[1], C[2])
        return [cc.min() - buf, rr.min() - buf, cc.max() + buf, rr.max() + buf]


def expand2square(pil_img, background_color=0):
    """from https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/"""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
--------------------------------------------------------------------------------
/main_pretrain.py:
--------------------------------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import models_mrm  # this repo's MAE-style model definitions (models_mrm.py)

from engine_pretrain import train_one_epoch


def get_args_parser():
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')

    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='Masking ratio (percentage of removed patches).')

    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')

    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser


def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # simple augmentation
    transform_train = transforms.Compose([
            transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    print(dataset_train)

    if True:  # args.distributed: kept always-on so global_rank and the sampler exist even in single-process runs
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # define the model
    model = models_mrm.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

    model.to(device)

    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )
        if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------
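A typical single-node launch, with the dataset path and hyperparameters as placeholders (all flags come from `get_args_parser` above):

    python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
        --model mae_vit_large_patch16 --batch_size 64 --epochs 400 --warmup_epochs 40 \
        --blr 1.5e-4 --weight_decay 0.05 --mask_ratio 0.75 \
        --data_path /path/to/imagefolder_dataset --output_dir ./output_dir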
/models_mrm.py:
--------------------------------------------------------------------------------
from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
x.shape[2])) # unshuffle 169 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 170 | 171 | # add pos embed 172 | x = x + self.decoder_pos_embed 173 | 174 | # apply Transformer blocks 175 | for blk in self.decoder_blocks: 176 | x = blk(x) 177 | x = self.decoder_norm(x) 178 | 179 | # predictor projection 180 | x = self.decoder_pred(x) 181 | 182 | # remove cls token 183 | x = x[:, 1:, :] 184 | 185 | return x 186 | 187 | def forward_loss(self, imgs, pred, mask): 188 | """ 189 | imgs: [N, 3, H, W] 190 | pred: [N, L, p*p*3] 191 | mask: [N, L], 0 is keep, 1 is remove, 192 | """ 193 | target = self.patchify(imgs) 194 | if self.norm_pix_loss: 195 | mean = target.mean(dim=-1, keepdim=True) 196 | var = target.var(dim=-1, keepdim=True) 197 | target = (target - mean) / (var + 1.e-6)**.5 198 | 199 | loss = (pred - target) ** 2 200 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 201 | 202 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 203 | return loss 204 | 205 | def forward(self, imgs, mask_ratio=0.75): 206 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 207 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 208 | loss = self.forward_loss(imgs, pred, mask) 209 | return loss, pred, mask 210 | 211 | 212 | def mae_vit_base_patch16_dec512d8b(**kwargs): 213 | model = MaskedAutoencoderViT( 214 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 215 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 216 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 217 | return model 218 | 219 | 220 | def mae_vit_large_patch16_dec512d8b(**kwargs): 221 | model = MaskedAutoencoderViT( 222 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 223 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 224 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 225 | return model 226 | 227 | 228 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 229 | model = MaskedAutoencoderViT( 230 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 231 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 232 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 233 | return model 234 | 235 | 236 | # set recommended archs 237 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 238 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 239 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 240 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 | 
19 | import torch
20 | import torch.distributed as dist
21 | from torch import inf  # torch._six was removed in PyTorch >= 1.13; torch.inf is the drop-in replacement
22 | 
23 | 
24 | class SmoothedValue(object):
25 | """Track a series of values and provide access to smoothed values over a
26 | window or the global series average.
27 | """
28 | 
29 | def __init__(self, window_size=20, fmt=None):
30 | if fmt is None:
31 | fmt = "{median:.4f} ({global_avg:.4f})"
32 | self.deque = deque(maxlen=window_size)
33 | self.total = 0.0
34 | self.count = 0
35 | self.fmt = fmt
36 | 
37 | def update(self, value, n=1):
38 | self.deque.append(value)
39 | self.count += n
40 | self.total += value * n
41 | 
42 | def synchronize_between_processes(self):
43 | """
44 | Warning: does not synchronize the deque!
45 | """
46 | if not is_dist_avail_and_initialized():
47 | return
48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
49 | dist.barrier()
50 | dist.all_reduce(t)
51 | t = t.tolist()
52 | self.count = int(t[0])
53 | self.total = t[1]
54 | 
55 | @property
56 | def median(self):
57 | d = torch.tensor(list(self.deque))
58 | return d.median().item()
59 | 
60 | @property
61 | def avg(self):
62 | d = torch.tensor(list(self.deque), dtype=torch.float32)
63 | return d.mean().item()
64 | 
65 | @property
66 | def global_avg(self):
67 | return self.total / self.count
68 | 
69 | @property
70 | def max(self):
71 | return max(self.deque)
72 | 
73 | @property
74 | def value(self):
75 | return self.deque[-1]
76 | 
77 | def __str__(self):
78 | return self.fmt.format(
79 | median=self.median,
80 | avg=self.avg,
81 | global_avg=self.global_avg,
82 | max=self.max,
83 | value=self.value)
84 | 
85 | 
86 | class MetricLogger(object):
87 | def __init__(self, delimiter="\t"):
88 | self.meters = defaultdict(SmoothedValue)
89 | self.delimiter = delimiter
90 | 
91 | def update(self, **kwargs):
92 | for k, v in kwargs.items():
93 | if v is None:
94 | continue
95 | if isinstance(v, torch.Tensor):
96 | v = v.item()
97 | assert isinstance(v, (float, int))
98 | self.meters[k].update(v)
99 | 
100 | def __getattr__(self, attr):
101 | if attr in self.meters:
102 | return self.meters[attr]
103 | if attr in self.__dict__:
104 | return self.__dict__[attr]
105 | raise AttributeError("'{}' object has no attribute '{}'".format(
106 | type(self).__name__, attr))
107 | 
108 | def __str__(self):
109 | loss_str = []
110 | for name, meter in self.meters.items():
111 | loss_str.append(
112 | "{}: {}".format(name, str(meter))
113 | )
114 | return self.delimiter.join(loss_str)
115 | 
116 | def synchronize_between_processes(self):
117 | for meter in self.meters.values():
118 | meter.synchronize_between_processes()
119 | 
120 | def add_meter(self, name, meter):
121 | self.meters[name] = meter
122 | 
123 | def log_every(self, iterable, print_freq, header=None):
124 | i = 0
125 | if not header:
126 | header = ''
127 | start_time = time.time()
128 | end = time.time()
129 | iter_time = SmoothedValue(fmt='{avg:.4f}')
130 | data_time = SmoothedValue(fmt='{avg:.4f}')
131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
132 | log_msg = [
133 | header,
134 | '[{0' + space_fmt + '}/{1}]',
135 | 'eta: {eta}',
136 
| '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed 
init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | print("Resume checkpoint %s" % args.resume) 324 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 325 | optimizer.load_state_dict(checkpoint['optimizer']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if 'scaler' in checkpoint: 328 | loss_scaler.load_state_dict(checkpoint['scaler']) 329 | print("With optim & sched!") 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 
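# averages a python scalar (e.g. a loss value) across all ranks; a no-op in single-process runs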
334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | return x_reduce.item() 339 | else: 340 | return x -------------------------------------------------------------------------------- /data/data_ukb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lightly.transforms import GaussianBlur 4 | 5 | os.environ["OPENBLAS_NUM_THREADS"] = "1" 6 | os.environ["MKL_NUM_THREADS"] = "1" 7 | os.environ["NUMEXPR_NUM_THREADS"] = "1" 8 | from collections import defaultdict 9 | from glob import glob 10 | from os.path import join 11 | from typing import List 12 | 13 | import h5py 14 | import lightly.data as ldata 15 | import numpy as np 16 | import pandas as pd 17 | import torch 18 | import torchvision 19 | from PIL import Image 20 | from pysnptools.snpreader import Bed 21 | from sklearn.preprocessing import StandardScaler 22 | from torch.utils.data import DataLoader 23 | from torch.utils.data.dataset import Dataset 24 | from torchvision import transforms 25 | from tqdm import tqdm 26 | import toml 27 | 28 | torch.multiprocessing.set_sharing_strategy("file_system") 29 | 30 | DEBUG = False 31 | 32 | # Consts --------------------------- 33 | config = toml.load(join(os.path.dirname(os.path.realpath(__file__)), "../paths.toml")) 34 | BASE_IMG = config["BASE_IMG"] 35 | LEFT = join(BASE_IMG, config["LEFT_SUBDIR"]) 36 | RIGHT = join(BASE_IMG, config["RIGHT_SUBDIR"]) 37 | IMG_EXT = ".jpg" 38 | 39 | PHENO = config["UKB_PHENO_FILE"] 40 | 41 | PATH_TO_COV = config["PATH_TO_COV"] 42 | 43 | BASE_GEN = config["BASE_GEN"] 44 | 45 | BASE_BURDEN = config["BASE_BURDEN"] 46 | 47 | BASE_PGS = config["BASE_PGS"] 48 | BLOOD_BIOMARKERS = config["BLOOD_BIOMARKERS"] 49 | ANCESTRY = config["ANCESTRY"] 50 | 51 | COVAR = [ 52 | "eid", 53 | "31-0.0", # sex 54 | "21022-0.0", # age 55 | "4079-0.0", 56 | "4079-0.1", # DBP 57 | "4080-0.0", 58 | "4080-0.1", # SBP 59 | "20116-0.0", # smoking; 2 == current; -3==prefer not to answer 60 | "21001-0.0", # bmi 61 | ] 62 | GENET_PCS = [f"22009-0.{i}" for i in range(1, 41)] 63 | 64 | COVAR_NAMES = [ 65 | "sex", 66 | "age", 67 | "BMI", 68 | "smoking", 69 | "SBP", 70 | "DBP", 71 | ] + [f"genet_pc_{i}" for i in range(1, 41)] 72 | 73 | # Classes -------------------------------- 74 | 75 | 76 | class UKBRetina(Dataset): 77 | def __init__( 78 | self, 79 | eye="left", 80 | iid_selection=None, 81 | tfms=None, 82 | subset=None, 83 | return_iid=False, 84 | normalize_features=True, 85 | img_extension=IMG_EXT, 86 | cov_fillna="mean", 87 | include_biomarkers=False, 88 | biomarkers_filter_nan_cols=0.2, 89 | biomarkers_filter_nan_rows=1.0, 90 | ): 91 | self.return_iid = return_iid 92 | self.img_extension = img_extension 93 | self.tfms = tfms 94 | self.eye = eye 95 | if eye == "left": 96 | self.path = LEFT 97 | elif eye == "right": 98 | self.path = RIGHT 99 | else: 100 | raise ValueError() 101 | if iid_selection is None: 102 | iid_selection = get_indiv() 103 | self._process_ids(iid_select=iid_selection) 104 | self._load_covs( 105 | normalize_features, 106 | cov_fillna=cov_fillna, 107 | include_biomarkers=include_biomarkers, 108 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 109 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 110 | ) 111 | self.subset = subset 112 | if subset: 113 | self.paths = self.paths[:subset] 114 | self.iids = self.iids[:subset] 115 | 116 | def __len__(self): 117 | return len(self.iids) 118 | 119 | def __getitem__(self, idx): 120 | img = 
self._load_img_item(idx) 121 | iid = self.iids[idx] 122 | cov = torch.from_numpy(self.cov_np[idx]).float() 123 | if self.return_iid: 124 | return img, cov, iid 125 | else: 126 | return img, cov 127 | 128 | def _load_img_item(self, idx): 129 | if isinstance(idx, torch.Tensor): 130 | idx = idx.item() 131 | p = self.paths[idx] 132 | img = Image.open(p) 133 | if self.tfms: 134 | img = self.tfms(img) 135 | return img 136 | 137 | def _load_covs( 138 | self, 139 | normalize_features, 140 | include_biomarkers, 141 | biomarkers_filter_nan_cols, 142 | biomarkers_filter_nan_rows, 143 | cov_fillna, 144 | ): 145 | cov = pd.read_csv(PATH_TO_COV, index_col=0) 146 | self.cov = cov.loc[self.iids] 147 | cols = self.cov.columns.tolist() 148 | df_sex_smoking = self.cov[["sex", "smoking"]].copy() 149 | self.cov = self.cov[self.cov.columns.difference(["sex", "smoking"])] 150 | if include_biomarkers: 151 | biomarkers = get_biomarker_data( 152 | filter_nan_cols=biomarkers_filter_nan_cols, 153 | filter_nan_rows=biomarkers_filter_nan_rows, 154 | iids=self.iids, 155 | ) 156 | self.cov = pd.merge( 157 | self.cov, biomarkers, left_index=True, right_index=True, how="outer" 158 | ) 159 | if cov_fillna == "mean" or cov_fillna is True: 160 | self.cov = self.cov.fillna(self.cov.mean()) 161 | df_sex_smoking = df_sex_smoking.fillna(df_sex_smoking.median()) 162 | elif cov_fillna == "median": 163 | self.cov = self.cov.fillna(self.cov.median()) 164 | df_sex_smoking = df_sex_smoking.fillna(df_sex_smoking.median()) 165 | elif (cov_fillna != False) and (cov_fillna is not None): 166 | raise NotImplementedError( 167 | f"covariate/biomarker fillna method {cov_fillna} is not yet implemented" 168 | ) 169 | 170 | if normalize_features: 171 | scaler = StandardScaler() 172 | self.cov[:] = scaler.fit_transform(self.cov.values) 173 | self.cov = pd.concat([self.cov, df_sex_smoking], axis=1) 174 | self.cov = self.cov[cols] 175 | self.cov_np = self.cov.to_numpy() 176 | self.cov_columns = list(self.cov.columns) 177 | 178 | def _process_ids(self, iid_select=None): 179 | qc_fns = pd.read_csv( 180 | join(BASE_IMG, self.eye, f"qc_paths_{self.eye}.txt"), header=None 181 | ).values.flatten() 182 | qc_paths = np.array([join(self.path, fn) for fn in qc_fns]) 183 | iids = np.array([int(fn.split("_")[0]) for fn in qc_fns]) 184 | 185 | if iid_select is not None: 186 | iid_select = set(iid_select) 187 | self.paths = np.array( 188 | [p for iid, p in zip(iids, qc_paths) if iid in iid_select] 189 | ) 190 | self.iids = np.array([iid for iid in iids if iid in iid_select]) 191 | else: 192 | self.iids = iids 193 | self.paths = qc_paths 194 | 195 | 196 | class UKBRetinaGen(UKBRetina): 197 | gen = None 198 | gen_lookup = None 199 | feature_names = None 200 | 201 | def __init__( 202 | self, 203 | chromos, 204 | rsids=None, 205 | sid_slice=slice(0, None, 100), 206 | iid_selection=None, 207 | eye="left", 208 | tfms=None, 209 | subset=None, 210 | fillna=True, 211 | return_iid=False, 212 | normalize_features=True, 213 | cov_fillna="mean", 214 | include_biomarkers=False, 215 | biomarkers_filter_nan_cols=0.2, 216 | biomarkers_filter_nan_rows=1.0, 217 | ): 218 | super().__init__( 219 | eye=eye, 220 | iid_selection=iid_selection, 221 | tfms=tfms, 222 | subset=subset, 223 | return_iid=return_iid, 224 | normalize_features=normalize_features, 225 | cov_fillna=cov_fillna, 226 | include_biomarkers=include_biomarkers, 227 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 228 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 229 | ) 230 | if UKBRetinaGen.gen is None: 
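# class-level cache: the SNP matrix is loaded once and then shared by every UKBRetinaGen instance (train/valid/test, both eyes)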
231 | gen = get_gen_data( 232 | chromos=chromos, 233 | rsids=rsids, 234 | sid_slice=sid_slice, 235 | ) 236 | 237 | # imputing the missing SNP values (NaNs) with column-wise mode 238 | if fillna: 239 | for column in gen: 240 | if gen[column].isnull().any(): 241 | gen[column].fillna(gen[column].mode()[0], inplace=True) 242 | UKBRetinaGen.gen = gen.to_numpy() 243 | UKBRetinaGen.gen_lookup = dict( 244 | (iid, idx) for idx, iid in enumerate(gen.index) 245 | ) 246 | UKBRetinaGen.feature_names = list(gen.columns) 247 | 248 | gen_iids = np.array(list(UKBRetinaGen.gen_lookup.keys())) 249 | inter_iids = set(np.intersect1d(self.iids, gen_iids)) 250 | 251 | inds = np.array([i for i, iid in enumerate(self.iids) if iid in inter_iids]) 252 | self.paths = self.paths[inds] 253 | self.iids = self.iids[inds] 254 | self.cov = self.cov.iloc[inds] 255 | self.cov_np = self.cov_np[inds] 256 | 257 | def __getitem__(self, idx): 258 | img = self._load_img_item(idx) 259 | iid = self.iids[idx] 260 | cov = torch.from_numpy(self.cov_np[idx]).float() 261 | gen_idx = self.gen_lookup[iid] 262 | gen = torch.from_numpy(UKBRetinaGen.gen[gen_idx]).float() 263 | if self.return_iid: 264 | return img, cov, gen, iid 265 | else: 266 | return img, cov, gen 267 | 268 | 269 | class UKBRetinaBurden(UKBRetina): 270 | burdens = None 271 | iid_lookup = None 272 | feature_names = None 273 | 274 | def __init__( 275 | self, 276 | filter_zeros=0.01, 277 | eye="left", 278 | iid_selection=None, 279 | tfms=None, 280 | subset=None, 281 | return_iid=False, 282 | normalize_features=True, 283 | cov_fillna="mean", 284 | include_biomarkers=False, 285 | biomarkers_filter_nan_cols=0.2, 286 | biomarkers_filter_nan_rows=1.0, 287 | ): 288 | super().__init__( 289 | eye=eye, 290 | iid_selection=iid_selection, 291 | tfms=tfms, 292 | subset=subset, 293 | return_iid=return_iid, 294 | normalize_features=normalize_features, 295 | cov_fillna=cov_fillna, 296 | include_biomarkers=include_biomarkers, 297 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 298 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 299 | ) 300 | if UKBRetinaBurden.burdens is None: 301 | burdens = get_burden_data(filter_zeros=filter_zeros) 302 | burden_iids = burdens.index.to_numpy() 303 | UKBRetinaBurden.burdens = burdens.to_numpy() 304 | UKBRetinaBurden.iid_lookup = dict( 305 | (iid, idx) for idx, iid in enumerate(burden_iids) 306 | ) 307 | UKBRetinaBurden.feature_names = list(burdens.columns) 308 | 309 | burden_iids = np.array(list(UKBRetinaBurden.iid_lookup.keys())) 310 | inter_iids = set(np.intersect1d(self.iids, burden_iids)) 311 | 312 | inds = np.array([i for i, iid in enumerate(self.iids) if iid in inter_iids]) 313 | self.paths = self.paths[inds] 314 | self.iids = self.iids[inds] 315 | self.cov = self.cov.iloc[inds] 316 | self.cov_np = self.cov_np[inds] 317 | 318 | def __getitem__(self, idx): 319 | img = self._load_img_item(idx) 320 | cov = torch.from_numpy(self.cov_np[idx]).float() 321 | iid = self.iids[idx] 322 | burden_idx = self.iid_lookup[iid] 323 | burdens = torch.from_numpy(UKBRetinaBurden.burdens[burden_idx]).float() 324 | if self.return_iid: 325 | return img, cov, burdens, iid 326 | else: 327 | return img, cov, burdens 328 | 329 | 330 | class UKBRetinaPGS(UKBRetina): 331 | pgs = None 332 | iid_lookup = None 333 | 334 | def __init__( 335 | self, 336 | normalize_pgs=True, 337 | eye="left", 338 | iid_selection=None, 339 | tfms=None, 340 | subset=None, 341 | return_iid=False, 342 | normalize_features=True, 343 | cov_fillna="mean", 344 | include_biomarkers=False, 345 
| biomarkers_filter_nan_cols=0.2, 346 | biomarkers_filter_nan_rows=1.0, 347 | ): 348 | super().__init__( 349 | eye=eye, 350 | iid_selection=iid_selection, 351 | tfms=tfms, 352 | subset=subset, 353 | return_iid=return_iid, 354 | normalize_features=normalize_features, 355 | cov_fillna=cov_fillna, 356 | include_biomarkers=include_biomarkers, 357 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 358 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 359 | ) 360 | if UKBRetinaPGS.pgs is None: 361 | pgs = get_pgs_data(normalize=normalize_pgs) 362 | pgs_iids = pgs.index.to_numpy() 363 | UKBRetinaPGS.pgs = pgs.to_numpy() 364 | UKBRetinaPGS.iid_lookup = dict( 365 | (iid, idx) for idx, iid in enumerate(pgs_iids) 366 | ) 367 | 368 | pgs_iids = np.array(list(UKBRetinaPGS.iid_lookup.keys())) 369 | inter_iids = set(np.intersect1d(self.iids, pgs_iids)) 370 | 371 | inds = np.array([i for i, iid in enumerate(self.iids) if iid in inter_iids]) 372 | self.paths = self.paths[inds] 373 | self.iids = self.iids[inds] 374 | self.cov = self.cov.iloc[inds] 375 | self.cov_np = self.cov_np[inds] 376 | 377 | def __getitem__(self, idx): 378 | img = self._load_img_item(idx) 379 | cov = torch.from_numpy(self.cov_np[idx]).float() 380 | iid = self.iids[idx] 381 | pgs_idx = self.iid_lookup[iid] 382 | pgs = torch.from_numpy(UKBRetinaPGS.pgs[pgs_idx]).float() 383 | if self.return_iid: 384 | return img, cov, pgs, iid 385 | else: 386 | return img, cov, pgs 387 | 388 | 389 | class UKBRetinaMultimodal(UKBRetina): 390 | gen = None 391 | gen_lookup = None 392 | gen_feature_names = None 393 | 394 | pgs = None 395 | pgs_lookup = None 396 | 397 | burdens = None 398 | burdens_lookup = None 399 | burdens_feature_names = None 400 | 401 | def __init__( 402 | self, 403 | # gen (raw SNPs): 404 | gen_chromos=[i for i in range(1, 23)], 405 | gen_rsids=None, 406 | gen_sid_slice=slice(0, None, 100), 407 | gen_fillna=True, 408 | # inner (=intersection, no missings) or outer (=union, with missings) 409 | aggregate_modalities="inner", 410 | modalities=["raw_snps", "risk_scores", "burden_scores"], 411 | # pgs: 412 | normalize_pgs=True, 413 | # burdens; 414 | filter_burdens=0.01, 415 | # general: 416 | eye="left", 417 | iid_selection=None, 418 | tfms=None, 419 | subset=None, 420 | return_iid=False, 421 | normalize_features=True, 422 | cov_fillna="mean", 423 | include_biomarkers=False, 424 | biomarkers_filter_nan_cols=0.2, 425 | biomarkers_filter_nan_rows=1.0, 426 | ): 427 | super().__init__( 428 | eye=eye, 429 | iid_selection=iid_selection, 430 | tfms=tfms, 431 | subset=subset, 432 | return_iid=return_iid, 433 | normalize_features=normalize_features, 434 | cov_fillna=cov_fillna, 435 | include_biomarkers=include_biomarkers, 436 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 437 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 438 | ) 439 | self.modalities = modalities 440 | gen_iids, burden_iids, pgs_iids = np.array([]), np.array([]), np.array([]) 441 | if "raw_snps" in modalities and UKBRetinaMultimodal.gen is None: 442 | print("loading raw genetic data...") 443 | gen = get_gen_data( 444 | chromos=gen_chromos, 445 | rsids=gen_rsids, 446 | sid_slice=gen_sid_slice, 447 | ) 448 | # imputing the missing SNP values (NaNs) with column-wise mode 449 | if gen_fillna: 450 | for column in gen: 451 | if gen[column].isnull().any(): 452 | gen[column].fillna(gen[column].mode()[0], inplace=True) 453 | UKBRetinaMultimodal.gen = gen.to_numpy() 454 | UKBRetinaMultimodal.gen_lookup = defaultdict( 455 | lambda: None, [(iid, idx) for idx, 
iid in enumerate(gen.index)] 456 | ) 457 | UKBRetinaMultimodal.gen_feature_names = list(gen.columns) 458 | if UKBRetinaMultimodal.gen_lookup is not None: 459 | gen_iids = np.array(list(UKBRetinaMultimodal.gen_lookup.keys())) 460 | if "risk_scores" in modalities and UKBRetinaMultimodal.pgs is None: 461 | print("loading polygenic risk score data...") 462 | pgs = get_pgs_data(normalize=normalize_pgs) 463 | pgs_iids = pgs.index.to_numpy() 464 | UKBRetinaMultimodal.pgs = pgs.to_numpy() 465 | UKBRetinaMultimodal.pgs_lookup = defaultdict( 466 | lambda: None, [(iid, idx) for idx, iid in enumerate(pgs_iids)] 467 | ) 468 | if UKBRetinaMultimodal.pgs_lookup is not None: 469 | pgs_iids = np.array(list(UKBRetinaMultimodal.pgs_lookup.keys())) 470 | if "burden_scores" in modalities and UKBRetinaMultimodal.burdens is None: 471 | print("loading burden score data...") 472 | burdens = get_burden_data(filter_zeros=filter_burdens) 473 | burden_iids = burdens.index.to_numpy() 474 | UKBRetinaMultimodal.burdens = burdens.to_numpy() 475 | UKBRetinaMultimodal.burdens_lookup = defaultdict( 476 | lambda: None, [(iid, idx) for idx, iid in enumerate(burden_iids)] 477 | ) 478 | UKBRetinaMultimodal.burdens_feature_names = list(burdens.columns) 479 | if UKBRetinaMultimodal.burdens_lookup is not None: 480 | burden_iids = np.array(list(UKBRetinaMultimodal.burdens_lookup.keys())) 481 | 482 | if ( 483 | aggregate_modalities == "inner" 484 | ): # make sure that all genetic modalities are available 485 | if gen_iids.size == 0 and pgs_iids.size != 0 and burden_iids.size != 0: 486 | selected_iids = set( 487 | np.intersect1d( 488 | self.iids, 489 | np.intersect1d(burden_iids, pgs_iids), 490 | ) 491 | ) 492 | elif gen_iids.size != 0 and pgs_iids.size == 0 and burden_iids.size != 0: 493 | selected_iids = set( 494 | np.intersect1d( 495 | self.iids, 496 | np.intersect1d(burden_iids, gen_iids), 497 | ) 498 | ) 499 | elif gen_iids.size != 0 and pgs_iids.size != 0 and burden_iids.size == 0: 500 | selected_iids = set( 501 | np.intersect1d( 502 | self.iids, 503 | np.intersect1d(pgs_iids, gen_iids), 504 | ) 505 | ) 506 | else: 507 | selected_iids = set( 508 | np.intersect1d( 509 | self.iids, 510 | np.intersect1d(gen_iids, np.intersect1d(burden_iids, pgs_iids)), 511 | ) 512 | ) 513 | elif aggregate_modalities == "outer": 514 | selected_iids = set( 515 | np.union1d( 516 | self.iids, 517 | np.union1d(gen_iids, np.union1d(burden_iids, pgs_iids)), 518 | ) 519 | ) 520 | else: 521 | raise ValueError(f"aggregation {aggregate_modalities} not known") 522 | inds = np.array([i for i, iid in enumerate(self.iids) if iid in selected_iids]) 523 | self.paths = self.paths[inds] 524 | self.iids = self.iids[inds] 525 | self.cov = self.cov.iloc[inds] 526 | self.cov_np = self.cov_np[inds] 527 | 528 | def __getitem__(self, idx): 529 | img = self._load_img_item(idx) 530 | cov = torch.from_numpy(self.cov_np[idx]).float() 531 | iid = self.iids[idx] 532 | 533 | gen, pgs, burdens = torch.empty(1), torch.empty(1), torch.empty(1) 534 | gen_idx, pgs_idx, burdens_idx = None, None, None 535 | if "raw_snps" in self.modalities: 536 | gen_idx = self.gen_lookup[iid] 537 | gen = torch.from_numpy( 538 | np.full(self.gen.shape[1], np.nan) 539 | if gen_idx is None 540 | else self.gen[gen_idx] 541 | ).float() 542 | if "risk_scores" in self.modalities: 543 | pgs_idx = self.pgs_lookup[iid] 544 | pgs = torch.from_numpy( 545 | np.full(self.pgs.shape[1], np.nan) 546 | if pgs_idx is None 547 | else self.pgs[pgs_idx] 548 | ).float() 549 | if "burden_scores" in self.modalities: 550 | 
burdens_idx = self.burdens_lookup[iid] 551 | burdens = torch.from_numpy( 552 | np.full(self.burdens.shape[1], np.nan) 553 | if burdens_idx is None 554 | else self.burdens[burdens_idx] 555 | ).float() 556 | 557 | missing = torch.tensor( 558 | [gen_idx is None, pgs_idx is None, burdens_idx is None], 559 | dtype=torch.long, 560 | ) 561 | 562 | if self.return_iid: 563 | return { 564 | "iid": iid, 565 | "img": img, 566 | "cov": cov, 567 | "gen": gen, 568 | "pgs": pgs, 569 | "burdens": burdens, 570 | "missing": missing, 571 | } 572 | else: 573 | return { 574 | "img": img, 575 | "cov": cov, 576 | "gen": gen, 577 | "pgs": pgs, 578 | "burdens": burdens, 579 | "missing": missing, 580 | } 581 | 582 | 583 | # Loaders ------------------------------- 584 | 585 | 586 | def get_multimodal_pretraining_data( 587 | # inner (=intersection, no missings) or outer (=union, with missings) 588 | aggregate_modalities="inner", 589 | modalities=["raw_snps", "risk_scores", "burden_scores"], 590 | # raw genetics 591 | gen_chromos=[i for i in range(1, 23)], 592 | gen_sid_slice=slice(0, None, 100), 593 | # pgs 594 | normalize_pgs=True, 595 | # burdens 596 | burdens_zeros=0.1, # filter burden scores by numbers of non-zeros (percentage or absolute) 597 | # general 598 | seed=42, 599 | num_workers=8, 600 | size=256, 601 | batch_size=32, 602 | train_pct=0.6, 603 | val_pct=0.2, 604 | subset=None, 605 | normalize_features=True, 606 | return_iid=False, 607 | tfms_settings="default", 608 | cov_fillna="mean", 609 | include_biomarkers=False, 610 | biomarkers_filter_nan_cols=0.2, 611 | biomarkers_filter_nan_rows=1.0, 612 | ): 613 | t_iids, v_iids, tt_iids = get_indiv_split( 614 | train_pct=train_pct, val_pct=val_pct, seed=seed 615 | ) 616 | loaders = [] 617 | for iids, mode in [(t_iids, "train"), (v_iids, "valid"), (tt_iids, "test")]: 618 | tfms = get_tfms(size=size, augmentation=mode == "train", setting=tfms_settings) 619 | dsets = [ 620 | UKBRetinaMultimodal( 621 | aggregate_modalities=aggregate_modalities, 622 | modalities=modalities, 623 | # raw SNPs 624 | gen_chromos=gen_chromos, 625 | gen_rsids=None, 626 | gen_sid_slice=gen_sid_slice, 627 | gen_fillna=True, 628 | # pgs: 629 | normalize_pgs=normalize_pgs, 630 | # burdens; 631 | filter_burdens=burdens_zeros, 632 | # general: 633 | iid_selection=iids, 634 | return_iid=return_iid, 635 | eye=eye, 636 | tfms=tfms, 637 | subset=subset, 638 | normalize_features=normalize_features, 639 | cov_fillna=cov_fillna, 640 | include_biomarkers=include_biomarkers, 641 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 642 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 643 | ) 644 | for eye in ["left", "right"] 645 | ] 646 | dataset = torch.utils.data.ConcatDataset(dsets) 647 | loader = DataLoader( 648 | dataset, 649 | batch_size=batch_size, 650 | shuffle=mode == "train", 651 | num_workers=num_workers, 652 | pin_memory=True, 653 | ) 654 | loaders.append(loader) 655 | num_features = { 656 | "gen": dsets[0].gen.shape[1] if dsets[0].gen is not None else 0, 657 | "pgs": dsets[0].pgs.shape[1] if dsets[0].pgs is not None else 0, 658 | "burdens": dsets[0].burdens.shape[1] if dsets[0].burdens is not None else 0, 659 | "cov": dsets[0].cov.shape[1], 660 | } 661 | 662 | return loaders, num_features 663 | 664 | 665 | def get_genetics_imaging_data( 666 | rsids=[], 667 | chromos=[i for i in range(1, 23)], 668 | sid_slice=slice(0, None, 100), 669 | burdens_zeros=None, # filter burden scores by numbers of non-zeros (percentage or absolute) 670 | seed=42, 671 | num_workers=4, 672 | size=256, 673 
| normalize_features=True, 674 | batch_size=32, 675 | train_pct=0.6, 676 | val_pct=0.2, 677 | subset=None, 678 | return_iid=False, 679 | tfms_settings="default", 680 | cov_fillna="mean", 681 | include_biomarkers=False, 682 | biomarkers_filter_nan_cols=0.2, 683 | biomarkers_filter_nan_rows=1.0, 684 | ): 685 | """load imaging with either raw genetic or with burden data 686 | 687 | for raw SNPs: 688 | get_genetics_imaging_data(rsids=['rs123', ...], chromos=[1, ...], sid_slice=None, ...) 689 | or 690 | get_genetics_imaging_data(rsids=None, chromos=[1, ...], sid_slice=slice(0, None, 100), ...) 691 | for burden data, use 692 | get_genetics_imaging_data(rsids=None, chromos=None, sid_slice=None, burdens_zeros=0.01, ...) 693 | 694 | """ 695 | assert ( 696 | rsids is None or sid_slice is None 697 | ), "specified both rsids and sid_slice; need to choose one or the other" 698 | assert ( 699 | burdens_zeros is None or rsids is None 700 | ), "specify either burdens or snps, not both" 701 | 702 | t_iids, v_iids, tt_iids = get_indiv_split( 703 | train_pct=train_pct, val_pct=val_pct, seed=seed 704 | ) 705 | loaders = [] 706 | for iids, mode in [(t_iids, "train"), (v_iids, "valid"), (tt_iids, "test")]: 707 | tfms = get_tfms(size=size, augmentation=mode == "train", setting=tfms_settings) 708 | if burdens_zeros is None: 709 | dsets = [ 710 | UKBRetinaGen( 711 | iid_selection=iids, 712 | chromos=chromos, 713 | rsids=rsids, 714 | sid_slice=sid_slice, 715 | eye=eye, 716 | tfms=tfms, 717 | normalize_features=normalize_features, 718 | subset=subset, 719 | return_iid=return_iid, 720 | cov_fillna=cov_fillna, 721 | include_biomarkers=include_biomarkers, 722 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 723 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 724 | ) 725 | for eye in ["left", "right"] 726 | ] 727 | gen_num_features = dsets[0].gen.shape[-1] 728 | else: 729 | dsets = [ 730 | UKBRetinaBurden( 731 | filter_zeros=burdens_zeros, 732 | iid_selection=iids, 733 | eye=eye, 734 | tfms=tfms, 735 | normalize_features=normalize_features, 736 | subset=subset, 737 | return_iid=return_iid, 738 | cov_fillna=cov_fillna, 739 | include_biomarkers=include_biomarkers, 740 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 741 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 742 | ) 743 | for eye in ["left", "right"] 744 | ] 745 | gen_num_features = dsets[0].burdens.shape[-1] 746 | 747 | dataset = torch.utils.data.ConcatDataset(dsets) 748 | loader = DataLoader( 749 | dataset, 750 | batch_size=batch_size, 751 | shuffle=mode == "train", 752 | num_workers=num_workers, 753 | pin_memory=True, 754 | ) 755 | loaders.append(loader) 756 | return loaders, gen_num_features 757 | 758 | 759 | def get_imaging_pretraining_data( 760 | seed=42, 761 | num_workers=4, 762 | size=256, 763 | batch_size=50, 764 | train_pct=0.6, 765 | val_pct=0.2, 766 | tfms_settings="default", 767 | ): 768 | t_iids, v_iids, tt_iids = get_indiv_split( 769 | train_pct=train_pct, val_pct=val_pct, seed=seed 770 | ) 771 | loaders = [] 772 | 773 | def get_path_by_index(dataset, index): 774 | # filename is the path of the image relative to the dataset root 775 | return dataset.paths[index] 776 | 777 | class UKBLightlyCollateFunction(ldata.BaseCollateFunction): 778 | def __init__(self, transform: torchvision.transforms.Compose): 779 | super().__init__(transform) 780 | 781 | def forward(self, batch: List[tuple]): 782 | batch_size = len(batch) 783 | 784 | # list of transformed images 785 | transforms = [ 786 | self.transform(batch[i % 
batch_size][0]).unsqueeze_(0) 787 | for i in range(2 * batch_size) 788 | ] 789 | # list of labels 790 | labels = torch.LongTensor([0 for _ in batch]) 791 | # list of filenames 792 | fnames = [item[2] for item in batch] 793 | 794 | # tuple of transforms 795 | transforms = ( 796 | torch.cat(transforms[:batch_size], 0), 797 | torch.cat(transforms[batch_size:], 0), 798 | ) 799 | return transforms, labels, fnames 800 | 801 | for iids, mode in [(t_iids, "train"), (v_iids, "valid"), (tt_iids, "test")]: 802 | our_dsets = [ 803 | UKBRetina( 804 | eye=eye, 805 | iid_selection=iids, 806 | ) 807 | for eye in ["left", "right"] 808 | ] 809 | lightly_dsets = [ 810 | ldata.LightlyDataset.from_torch_dataset( 811 | dset, index_to_filename=get_path_by_index 812 | ) 813 | for dset in our_dsets 814 | ] 815 | dataset = torch.utils.data.ConcatDataset(lightly_dsets) 816 | loader = DataLoader( 817 | dataset, 818 | batch_size=batch_size, 819 | shuffle=mode == "train", 820 | num_workers=num_workers, 821 | pin_memory=True, 822 | collate_fn=UKBLightlyCollateFunction( 823 | transform=get_tfms( 824 | size=size, augmentation=mode == "train", setting=tfms_settings 825 | ) 826 | ), 827 | ) 828 | loaders.append(loader) 829 | return loaders 830 | 831 | 832 | def get_imaging_card_data( 833 | seed=42, 834 | num_workers=8, 835 | size=256, 836 | normalize_features=True, 837 | batch_size=50, 838 | train_pct=0.6, 839 | val_pct=0.2, 840 | subset=None, 841 | tfms_settings="default", 842 | cov_fillna="mean", 843 | return_iid=False, 844 | include_biomarkers=False, 845 | biomarkers_filter_nan_cols=0.2, 846 | biomarkers_filter_nan_rows=1.0, 847 | ): 848 | t_iids, v_iids, tt_iids = get_indiv_split( 849 | train_pct=train_pct, val_pct=val_pct, seed=seed 850 | ) 851 | loaders = [] 852 | for iids, mode in [(t_iids, "train"), (v_iids, "valid"), (tt_iids, "test")]: 853 | tfms = get_tfms(size=size, augmentation=mode == "train", setting=tfms_settings) 854 | dsets = [ 855 | UKBRetina( 856 | eye=eye, 857 | iid_selection=iids, 858 | tfms=tfms, 859 | normalize_features=normalize_features, 860 | subset=subset, 861 | cov_fillna=cov_fillna, 862 | return_iid=return_iid, 863 | include_biomarkers=include_biomarkers, 864 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols, 865 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows, 866 | ) 867 | for eye in ["left", "right"] 868 | ] 869 | dataset = torch.utils.data.ConcatDataset(dsets) 870 | cov_num_features = dsets[0].cov_np.shape[-1] 871 | 872 | loader = DataLoader( 873 | dataset, 874 | batch_size=batch_size, 875 | shuffle=mode == "train", 876 | num_workers=num_workers, 877 | pin_memory=True, 878 | ) 879 | loaders.append(loader) 880 | return loaders, cov_num_features 881 | 882 | 883 | def get_pgs_imaging_data( 884 | normalize_pgs=True, 885 | seed=42, 886 | num_workers=4, 887 | size=256, 888 | normalize_features=True, 889 | batch_size=32, 890 | train_pct=0.6, 891 | val_pct=0.2, 892 | subset=None, 893 | return_iid=False, 894 | tfms_settings="default", 895 | cov_fillna="mean", 896 | include_biomarkers=False, 897 | biomarkers_filter_nan_cols=0.2, 898 | biomarkers_filter_nan_rows=1.0, 899 | ): 900 | t_iids, v_iids, tt_iids = get_indiv_split( 901 | train_pct=train_pct, val_pct=val_pct, seed=seed 902 | ) 903 | loaders = [] 904 | for iids, mode in [(t_iids, "train"), (v_iids, "valid"), (tt_iids, "test")]: 905 | tfms = get_tfms(size=size, augmentation=mode == "train", setting=tfms_settings) 906 | dsets = [ 907 | UKBRetinaPGS( 908 | eye=eye, 909 | iid_selection=iids, 910 | tfms=tfms, 911 | 
normalize_features=normalize_features,
912 | normalize_pgs=normalize_pgs,
913 | subset=subset,
914 | cov_fillna=cov_fillna,
915 | return_iid=return_iid,
916 | include_biomarkers=include_biomarkers,
917 | biomarkers_filter_nan_cols=biomarkers_filter_nan_cols,
918 | biomarkers_filter_nan_rows=biomarkers_filter_nan_rows,
919 | )
920 | for eye in ["left", "right"]
921 | ]
922 | dataset = torch.utils.data.ConcatDataset(dsets)
923 | 
924 | loader = DataLoader(
925 | dataset,
926 | batch_size=batch_size,
927 | shuffle=mode == "train",
928 | num_workers=num_workers,
929 | pin_memory=True,
930 | )
931 | loaders.append(loader)
932 | return loaders, dsets[0].pgs.shape[1]
933 | 
934 | 
935 | def get_burden_data(filter_zeros=0):
936 | """load burden data and filter columns with low numbers of non-zeros
937 | 
938 | if filter_zeros >= 1: minimum number of non-zero individuals
939 | if filter_zeros in (0, 1): minimum fraction of non-zero individuals
940 | """
941 | cols = (
942 | pd.read_csv(join(BASE_BURDEN, "combined_burdens_colnames.txt"), header=None)
943 | .to_numpy()
944 | .flatten()
945 | )
946 | 
947 | main_iids = get_indiv()
948 | burden_iids = (
949 | pd.read_csv(join(BASE_BURDEN, "combined_burdens_iid.txt"), header=None)
950 | .to_numpy()
951 | .flatten()
952 | )
953 | inds = fast_index_lookup(burden_iids, main_iids)
954 | inds.sort()
955 | iids = burden_iids[inds]
956 | if DEBUG:
957 | data = pd.DataFrame(
958 | data=np.random.randint(0, 10, size=(len(iids), len(cols))),
959 | index=iids,
960 | columns=cols,
961 | )
962 | return data.sort_index()
963 | 
964 | print("loading burden data...")
965 | G = h5py.File(join(BASE_BURDEN, "combined_burdens.h5"))["G"][inds]
966 | 
967 | if filter_zeros > 0:
968 | if filter_zeros < 1:
969 | filter_zeros = len(iids) * filter_zeros
970 | g0 = (G > 0).sum(0)
971 | col_ind = g0 >= filter_zeros
972 | print(f"selecting {col_ind.sum()} columns with a minimum of {filter_zeros} non-zeros")
973 | G = G[:, col_ind]
974 | cols = cols[col_ind]
975 | 
976 | data = pd.DataFrame(data=G, index=iids, columns=cols)
977 | 
978 | return data.sort_index()
979 | 
980 | 
981 | def get_biomarker_data(filter_nan_cols=0.2, filter_nan_rows=1.0, iids=None):
982 | """
983 | filter biomarkers:
984 | first drop all individuals with more than 100*filter_nan_rows% NaNs,
985 | then drop all biomarkers with more than 100*filter_nan_cols% NaNs in the remaining data
986 | 
987 | """
988 | if iids is None:
989 | iids = get_indiv()
990 | df = pd.read_csv(BLOOD_BIOMARKERS, sep="\t", index_col=0)
991 | inter_iids = np.intersect1d(iids, df.index)
992 | df = df.loc[inter_iids]
993 | nan_means_row = df.isna().mean(1)
994 | df = df.loc[nan_means_row <= filter_nan_rows]
995 | nan_means_col = df.isna().mean()
996 | df = df.loc[:, nan_means_col <= filter_nan_cols]
997 | return df
998 | 
999 | 
1000 | def get_pgs_data(normalize=True):
1001 | available_pgs = sorted(glob(join(BASE_PGS, "*.sscore")))
1002 | iids = get_indiv()
1003 | for pgs_p in tqdm(available_pgs):
1004 | pgs = pgs_p.split("/")[-1].split(".")[0]
1005 | df = pd.read_csv(pgs_p, usecols=["IID", "SCORE1_AVG"], sep="\t", index_col=0)
1006 | if pgs_p == available_pgs[0]:
1007 | iids = np.intersect1d(iids, df.index)
1008 | full_df = df.loc[iids]
1009 | full_df.columns = [pgs]
1010 | else:
1011 | full_df[pgs] = df.loc[iids]
1012 | if normalize:
1013 | full_df = (full_df - full_df.mean()) / full_df.std()
1014 | return full_df
1015 | 
1016 | 
1017 | def get_gen_data(chromos=[15, 19], rsids=[], sid_slice=None):
1018 | """
1019 | load all SNPs that lie on one of the provided chromosomes;
1020 | SNPs that cannot be found (e.g. if MAF < 0.001 or not on the microarray) are silently ignored
1021 | # ld = 0.8 would additionally drop SNPs that are close to each other (LD pruning)
1022 | """
1023 | ids = get_indiv()
1024 | for chromo in tqdm(chromos):
1025 | path_to_genetic = join(BASE_GEN, f"ukb_chr{chromo}_v2")
1026 | 
1027 | bed = Bed(path_to_genetic, count_A1=False)
1028 | ind = bed.iid_to_index([[str(i), str(i)] for i in ids])
1029 | if sid_slice is not None:
1030 | sid_ind = sid_slice
1031 | else:
1032 | sid_ind = bed.sid_to_index(np.intersect1d(rsids, bed.sid))
1033 | labels = bed[ind, sid_ind].read().val
1034 | df = pd.DataFrame(
1035 | index=ids, data=labels, columns=bed.sid[sid_ind], dtype=np.float32
1036 | )
1037 | if chromo == chromos[0]:
1038 | full_df = df
1039 | else:
1040 | full_df = pd.merge(full_df, df, left_index=True, right_index=True)
1041 | return full_df
1042 | 
1043 | 
1044 | # Helpers -------------------------
1045 | 
1046 | 
1047 | def test_train_valid_leak(tl, vl, ttl):
1048 | """short utility to check for data leakage (prints train/valid and train/test IID overlaps)"""
1049 | D0 = tl.dataset.datasets[0]
1050 | D1 = tl.dataset.datasets[1]
1051 | N = len(D0)
1052 | # tI = tl.sampler.indices
1053 | tI = tl.sampler if isinstance(tl.sampler, torch.Tensor) else tl.sampler.indices
1054 | vI = vl.sampler if isinstance(vl.sampler, torch.Tensor) else vl.sampler.indices
1055 | ttI = ttl.sampler if isinstance(ttl.sampler, torch.Tensor) else ttl.sampler.indices
1056 | train_iids = [D0.iids[i] if i < N else D1.iids[i - N] for i in tI]
1057 | valid_iids = [D0.iids[i] if i < N else D1.iids[i - N] for i in vI]
1058 | test_iids = [D0.iids[i] if i < N else D1.iids[i - N] for i in ttI]
1059 | inter_valid = np.intersect1d(train_iids, valid_iids)
1060 | inter_test = np.intersect1d(train_iids, test_iids)
1061 | print(
1062 | f"intersection: {len(inter_valid)} of {len(train_iids)}(train) and {len(valid_iids)}(valid)"
1063 | )
1064 | print(
1065 | f"intersection: {len(inter_test)} of {len(train_iids)}(train) and {len(test_iids)}(test)"
1066 | )
1067 | 
1068 | 
1069 | def get_indiv_split(train_pct=0.6, val_pct=0.2, seed=42):
1070 | """train/val/test split, stratified by individuals (no data leakage)"""
1071 | rng = np.random.RandomState(seed)
1072 | iids = get_indiv()
1073 | iids = rng.permutation(iids)
1074 | 
1075 | m = len(iids)
1076 | t_cut = int(train_pct * m)
1077 | v_cut = int(val_pct * m) + t_cut
1078 | train_iids = iids[:t_cut]
1079 | valid_iids = iids[t_cut:v_cut]
1080 | test_iids = iids[v_cut:]
1081 | return train_iids, valid_iids, test_iids
1082 | 
1083 | 
1084 | def export_card():
1085 | # encoding errors when pandas.__version__ == '1.3.x'?
but works in 1.2.5 1086 | df = pd.read_csv(PHENO, usecols=COVAR + GENET_PCS) 1087 | iids = get_indiv() 1088 | df["iid"] = df["eid"] 1089 | df.index = df.iid 1090 | df = df.loc[iids] 1091 | 1092 | df["SBP"] = df[["4080-0.0", "4080-0.1"]].mean(1) 1093 | df["DBP"] = df[["4079-0.0", "4079-0.1"]].mean(1) 1094 | 1095 | df["age"] = df["21022-0.0"] 1096 | df["sex"] = df["31-0.0"] 1097 | df["smoking"] = 1 * (df["20116-0.0"] == 2) 1098 | df["BMI"] = df["21001-0.0"] 1099 | 1100 | for i, col in enumerate(GENET_PCS): 1101 | df[f"genet_pc_{i + 1}"] = df[col] 1102 | 1103 | df = df[ 1104 | ["sex", "age", "BMI", "smoking", "SBP", "DBP"] 1105 | + [f"genet_pc_{i + 1}" for i in range(40)] 1106 | ] 1107 | df.to_csv(PATH_TO_COV, index=True) 1108 | return df 1109 | 1110 | 1111 | def get_indiv(ancestry_threshold=0.99): 1112 | all_iids = [] 1113 | for eye in ["left", "right"]: 1114 | iids = list( 1115 | pd.read_csv( 1116 | join(BASE_IMG, eye, f"qc_paths_{eye}.txt"), 1117 | header=None, 1118 | sep="_", 1119 | usecols=[0], 1120 | )[0] 1121 | ) 1122 | all_iids += iids 1123 | iids = np.unique(all_iids) 1124 | 1125 | ancestry = pd.read_csv(ANCESTRY, sep="\t", index_col=0, usecols=["IID", "EUR"]) 1126 | anc_iids = ancestry[ancestry.EUR >= ancestry_threshold].index 1127 | iids = np.intersect1d(iids, anc_iids) 1128 | 1129 | return iids 1130 | 1131 | 1132 | def get_augmented_tfms(size=224, setting="default"): 1133 | mean = [0.485, 0.456, 0.406] 1134 | std = [0.229, 0.224, 0.225] 1135 | if setting == "none": 1136 | tfms = transforms.Compose( 1137 | [ 1138 | transforms.Resize(size=size), 1139 | transforms.ToTensor(), 1140 | transforms.Normalize(mean=mean, std=std), 1141 | ] 1142 | ) 1143 | elif setting == "default": 1144 | tfms = transforms.Compose( 1145 | [ 1146 | transforms.Resize(size=size), 1147 | transforms.RandomRotation(degrees=20), 1148 | transforms.RandomHorizontalFlip(p=0.5), 1149 | transforms.ToTensor(), 1150 | transforms.Normalize(mean=mean, std=std), 1151 | ] 1152 | ) 1153 | elif isinstance(setting, dict): 1154 | tfms_list = [transforms.Resize(size=size)] 1155 | if setting["rrc"] is not None: 1156 | tfms_list.append(transforms.RandomResizedCrop(**setting["rrc"])) 1157 | if setting["rotate"] is not None: 1158 | tfms_list.append(transforms.RandomRotation(**setting["rotate"])) 1159 | if setting["flip"]: 1160 | tfms_list.append(transforms.RandomHorizontalFlip(p=0.5)) 1161 | # TODO: this should probably be adaptive? 
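# note: sharpness_factor=0.5 reduces sharpness, so this acts as a mild random blur (applied with the default p=0.5)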
1162 | if setting["blur"]:
1163 | tfms_list.append(transforms.RandomAdjustSharpness(sharpness_factor=0.5))
1164 | if setting["autocontrast"]:
1165 | tfms_list.append(transforms.RandomAutocontrast(p=0.5))
1166 | if setting["jitter"] is not None:
1167 | tfms_list.append(transforms.ColorJitter(**setting["jitter"]))
1168 | 
1169 | tfms_list += [
1170 | transforms.ToTensor(),
1171 | transforms.Normalize(mean=mean, std=std),
1172 | ]
1173 | tfms = transforms.Compose(tfms_list)
1174 | elif setting == "simclr":
1175 | cj_prob = 0.8
1176 | cj_bright = 0.7
1177 | cj_contrast = 0.7
1178 | cj_sat = 0.7
1179 | cj_hue = 0.2
1180 | min_scale = 0.08
1181 | random_gray_scale = 0.2
1182 | gaussian_blur = 0.5
1183 | kernel_size = 0.1
1184 | hf_prob = 0.5
1185 | color_jitter = transforms.ColorJitter(cj_bright, cj_contrast, cj_sat, cj_hue)
1186 | tfms = transforms.Compose(
1187 | [
1188 | transforms.RandomResizedCrop(size=size, scale=(min_scale, 1.0)),
1189 | transforms.RandomHorizontalFlip(p=hf_prob),
1190 | transforms.RandomApply([color_jitter], p=cj_prob),
1191 | transforms.RandomGrayscale(p=random_gray_scale),
1192 | GaussianBlur(kernel_size=kernel_size * size, prob=gaussian_blur),
1193 | transforms.ToTensor(),
1194 | transforms.Normalize(mean=mean, std=std),
1195 | ]
1196 | )
1197 | 
1198 | return tfms
1199 | 
1200 | 
1201 | def get_tfms(size=224, augmentation=False, setting="default"):
1202 | mean = [0.485, 0.456, 0.406]
1203 | std = [0.229, 0.224, 0.225]
1204 | if augmentation:
1205 | return get_augmented_tfms(size=size, setting=setting)
1206 | else:
1207 | tfms = transforms.Compose(
1208 | [
1209 | transforms.Resize(size=size),
1210 | transforms.ToTensor(),
1211 | transforms.Normalize(mean=mean, std=std),
1212 | ]
1213 | )
1214 | return tfms
1215 | 
1216 | 
1217 | def fast_index_lookup(a1, a2):
1218 | """more efficient implementation of np.where(a1 == a2[:, None])[1]
1219 | 
1220 | assumes no duplicates in a1; otherwise only the last occurrence of a duplicate value is used
1221 | """
1222 | ll = defaultdict(lambda: None, [(v, i) for i, v in enumerate(a1)])
1223 | return np.array([ll[v] for v in a2 if ll[v] is not None])
1224 | 
--------------------------------------------------------------------------------
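To make the moving parts above concrete, the following self-contained sketch (not part of the original repository; the file name, the hyperparameter values, and the norm_pix_loss kwarg are assumptions carried over from the reference MAE implementation) shows how the MAE model factories, NativeScalerWithGradNormCount, MetricLogger, and the UK Biobank image loaders could be wired into a single pretraining loop:
--------------------------------------------------------------------------------
/examples/pretrain_sketch.py (hypothetical, illustration only):
--------------------------------------------------------------------------------
1 | # Minimal single-process pretraining sketch built only from pieces shown above.
2 | # Assumed (not confirmed by the repo): the MAE code lives in models_mrm.py,
3 | # norm_pix_loss is a constructor kwarg (as in the reference MAE), paths.toml is
4 | # configured, and all hyperparameters here are placeholders.
5 | import torch
6 | 
7 | import models_mrm
8 | from data.data_ukb import get_imaging_card_data
9 | from util.misc import MetricLogger, NativeScalerWithGradNormCount
10 | 
11 | 
12 | def main():
13 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | 
15 |     # train/valid/test loaders over left+right-eye fundus images plus covariates;
16 |     # size=224 assumes the processed images are square, so Resize yields 224x224
17 |     loaders, n_cov = get_imaging_card_data(size=224, batch_size=16)
18 |     train_loader = loaders[0]
19 | 
20 |     model = models_mrm.mae_vit_base_patch16(norm_pix_loss=True).to(device)
21 |     optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)
22 |     loss_scaler = NativeScalerWithGradNormCount()  # AMP scaler + grad-norm bookkeeping
23 | 
24 |     logger = MetricLogger(delimiter="  ")
25 |     model.train()
26 |     for imgs, cov in logger.log_every(train_loader, 20, header="pretrain:"):
27 |         imgs = imgs.to(device, non_blocking=True)  # cov is unused in this image-only sketch
28 |         optimizer.zero_grad()
29 |         with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
30 |             loss, pred, mask = model(imgs, mask_ratio=0.75)  # (loss, pred, mask) per forward()
31 |         # one call performs backward, unscaling, optional clipping, and the optimizer step
32 |         loss_scaler(loss, optimizer, clip_grad=None, parameters=model.parameters())
33 |         logger.update(loss=loss.item())
34 | 
35 | 
36 | if __name__ == "__main__":
37 |     main()
--------------------------------------------------------------------------------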