├── README.md ├── condssl ├── __init__.py ├── builder.py └── loader.py ├── dataset └── dataloader.py ├── feature_extraction ├── extract_embeddings.py └── get_clusters.py ├── network └── inception_v4.py ├── plots ├── progression_plot.png └── umap.png ├── preprocessing ├── colorstandard.png ├── process_cptac.py ├── process_tcga.py └── utils.py ├── survival_models ├── cox.py └── utils.py └── train_ssl.py /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Self-supervised Learning for Histopathology Images 2 | 3 | This repository contains the code for the paper [Interpretable Prediction of Lung Squamous Cell Carcinoma Recurrence With Self-supervised Learning](https://arxiv.org/pdf/2203.12204.pdf). 4 | 5 | ## Introduction 6 | In this study, we explore the morphological features of LSCC recurrence and metastasis with a novel method based on conditional self-supervised learning (SSL). We propose a sampling mechanism within the contrastive SSL framework for histopathology images that avoids overfitting to batch effects. 7 | 8 | The figure below shows 2D UMAP projections of tile representations trained with different sampling strategies 9 | in self-supervised learning. Tiles from 8 slides with mostly LSCC tumor content 10 | are highlighted with different colors. Left: model trained by MoCo contrastive 11 | learning with uniform sampling; tiles within each slide cluster 12 | together. Right: model trained with the proposed conditional contrastive learning; 13 | the tiles from each slide are less clustered together. 14 | ![UMAP](./plots/umap.png) 15 | 16 | The Kaplan-Meier curves below show the proportion of recurrence-free patients over time in 17 | sub-cohorts of the test set. The two sub-cohorts are stratified by the 18 | recurrence risk predicted by our Cox regression: the high-risk cohort contains the half of the 19 | patients with the highest estimated risk; the low-risk cohort contains the remaining half. 20 | 21 |

22 | 23 |

24 | 25 | ## Data 26 | 27 | ### TCGA-LUSC 28 | Download the TCGA-LUSC whole slide images from this [filter](https://portal.gdc.cancer.gov/repository?facetTab=files&filters=%7B%22op%22%3A%22and%22%2C%22content%22%3A%5B%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.project_id%22%2C%22value%22%3A%5B%22TCGA-LUSC%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22files.data_format%22%2C%22value%22%3A%5B%22svs%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22files.experimental_strategy%22%2C%22value%22%3A%5B%22Tissue%20Slide%22%5D%7D%7D%5D%7D). 29 | 30 | ### CPTAC-LSCC 31 | Download the CPTAC-LSCC whole slide images from [here](https://wiki.cancerimagingarchive.net/display/Public/CPTAC-LSCC). 32 | 33 | ## Preprocessing 34 | 35 | To preprocess the WSIs, run the scripts in the `preprocessing` folder: 36 | 37 | `python process_tcga.py --followup_path {followup_table} --wsi_path {directory_of_WSIs} --refer_img {color_norm_img} --s {proportion_of_tissue}` 38 | 39 | `python process_cptac.py --followup_path {followup_table} --wsi_path {directory_of_WSIs} --refer_img {color_norm_img} --s {proportion_of_tissue}` 40 | 41 | ## Self-supervised learning 42 | 43 | Run the following command to train Inception V4 with conditional SSL using two-layer sampling (a simplified sketch of the conditional contrastive loss is included at the end of this README): 44 | 45 | `torchrun train_ssl.py --data_dir {data_dir} --split_dir {annotation_dir} --batch_slide_num {number of slides in batch} --cos --out_dir {output_dir}` 46 | 47 | Pretrained weights can be downloaded [here](https://drive.google.com/drive/folders/1Uc7JZZRkBNxoKkDmy-fcLsy9cUz_ixcr?usp=sharing). 48 | 49 | ## Extract features 50 | 51 | To extract features, we first extract the tile representations with the SSL-pretrained Inception V4: 52 | 53 | `python extract_embeddings.py --feature_extractor_dir {checkpoint of pretrained feature extractor} --subtype_model_dir {subtype model} --root_dir {tiles directory} --split_dir {annotation files} --out_dir {output directory}` 54 | 55 | Then we fit clusters on the tile representations of the training data and assign clusters to the tiles in the validation and test sets. After clustering each tile, we aggregate the tile cluster 56 | probabilities with average pooling over clusters to generate the slide-level features. Run the following command: 57 | 58 | `python get_clusters.py --data_dir {data_dir} --cluster_type {method_of_clustering} --n_cluster {number_of_clusters} --out_dir {out_dir}` 59 | 60 | ## Survival model (Cox-PH) 61 | 62 | We run a Cox-PH regression on the extracted slide-level features together with the time and status of recurrence (a minimal illustrative sketch is included at the end of this README). 63 | 64 | The triplets of features and 65 | slide labels $\{(v_j , y_j , t_j)\}^N_{j=1}$ are used, where $v_j$ is the vector of cluster features, $y_j$ is the 66 | binary label indicating LSCC recurrence, and $t_j$ encodes the recurrence-free followup time 67 | for the patient: if a patient was not observed to have a recurrence during the followup 68 | period, we use the length of the followup time as the censoring time $t_j$. Each $t_j$ is computed 69 | with a granularity of 6 months. We fit a Cox regression model with L2-norm regularization 70 | on $\{(v_j , y_j , t_j)\}^N_{j=1}$ to estimate the proportional hazard function of recurrence $\lambda (t|v)$. 71 | 72 | `python cox.py --data_dir {data_dir} --cluster_name {cluster_model_checkpoint} --normalize {pooling_method_over_slides}` 73 | 74 | ## Reference 75 | 76 |
77 |

@misc{https://doi.org/10.48550/arxiv.2203.12204, 78 | doi = {10.48550/ARXIV.2203.12204}, 79 | url = {https://arxiv.org/abs/2203.12204}, 80 | author = {Zhu, Weicheng and Fernandez-Granda, Carlos and Razavian, Narges}, 81 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences}, 82 | title = {Interpretable Prediction of Lung Squamous Cell Carcinoma Recurrence With Self-supervised Learning}, 83 | publisher = {arXiv}, 84 | year = {2022}, 85 | copyright = {arXiv.org perpetual, non-exclusive license} 86 | } 87 |

88 |
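For reference, the snippet below is a simplified, self-contained restatement of the conditional contrastive objective implemented in `condssl/builder.py` (the branch used when `condition` is enabled); the actual training additionally uses the MoCo query/key encoders, momentum updates, and distributed batch shuffling. Because each batch is drawn from only `batch_slide_num` slides (two-layer sampling in `dataset/dataloader.py`), the negatives for every tile come from the same small set of slides, so the pretext task cannot be solved from slide identity alone. Tensor shapes and values here are illustrative.

```python
# Simplified sketch of the conditional (in-batch) InfoNCE loss.
import torch
import torch.nn.functional as F

def conditional_infonce(q: torch.Tensor, k: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    # q, k: embeddings of two augmentations of the same N tiles, shape (N, D)
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    logits = q @ k.t() / temperature                  # (N, N); positives lie on the diagonal
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)

# toy usage with random embeddings
loss = conditional_infonce(torch.randn(8, 128), torch.randn(8, 128))
```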
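Similarly, the snippet below is a minimal sketch of the survival step described in the "Survival model (Cox-PH)" section, assuming the slide-level cluster features have already been extracted. The toy data, column names, and penalizer value are illustrative only; the repository's `survival_models/cox.py` and `survival_models/utils.py` implement the full pipeline.

```python
# Minimal sketch: L2-regularized Cox-PH on slide-level cluster features (toy data).
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter

rng = np.random.default_rng(0)
n_slides, n_clusters = 40, 5
df = pd.DataFrame(rng.random((n_slides, n_clusters)),
                  columns=[f"c_{i}" for i in range(n_clusters)])  # v_j: cluster features
df["outcome"] = rng.integers(0, 2, n_slides)                      # y_j: recurrence observed (1) or censored (0)
days = rng.integers(60, 2000, n_slides)                           # recurrence-free or followup days
df["day"] = days // 180 + 1                                       # t_j with 6-month granularity

cph = CoxPHFitter(penalizer=0.1)                                  # L2 regularization
cph.fit(df, duration_col="day", event_col="outcome")
print(cph.concordance_index_)
```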
89 | -------------------------------------------------------------------------------- /condssl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /condssl/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class MoCo(nn.Module): 7 | """ 8 | Build a MoCo model with: a query encoder, a key encoder, and a queue 9 | https://arxiv.org/abs/1911.05722 10 | """ 11 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, condition=True): 12 | """ 13 | dim: feature dimension (default: 128) 14 | K: queue size; number of negative keys (default: 65536) 15 | m: moco momentum of updating key encoder (default: 0.999) 16 | T: softmax temperature (default: 0.07) 17 | """ 18 | super(MoCo, self).__init__() 19 | 20 | self.condition = condition 21 | self.K = K 22 | self.m = m 23 | self.T = T 24 | 25 | # create the encoders 26 | # num_classes is the output fc dimension 27 | self.encoder_q = base_encoder(num_classes=dim) 28 | self.encoder_k = base_encoder(num_classes=dim) 29 | 30 | 31 | if mlp: # hack: brute-force replacement 32 | dim_mlp = self.encoder_q.fc.weight.shape[1] 33 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 34 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 35 | 36 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 37 | param_k.data.copy_(param_q.data) # initialize 38 | param_k.requires_grad = False # not update by gradient 39 | 40 | # create the queue 41 | self.register_buffer("queue", torch.randn(dim, K)) 42 | self.queue = nn.functional.normalize(self.queue, dim=0) 43 | 44 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 45 | 46 | @torch.no_grad() 47 | def _momentum_update_key_encoder(self): 48 | """ 49 | Momentum update of the key encoder 50 | """ 51 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 52 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 53 | 54 | @torch.no_grad() 55 | def _dequeue_and_enqueue(self, keys): 56 | # gather keys before updating queue 57 | keys = concat_all_gather(keys) 58 | 59 | batch_size = keys.shape[0] 60 | 61 | ptr = int(self.queue_ptr) 62 | assert self.K % batch_size == 0 # for simplicity 63 | 64 | # replace the keys at ptr (dequeue and enqueue) 65 | # self.queue[:, ptr:ptr + batch_size] = keys.T 66 | self.queue[:, ptr:ptr + batch_size] = torch.t(keys) 67 | ptr = (ptr + batch_size) % self.K # move pointer 68 | 69 | self.queue_ptr[0] = ptr 70 | 71 | @torch.no_grad() 72 | def _batch_shuffle_ddp(self, x): 73 | """ 74 | Batch shuffle, for making use of BatchNorm. 75 | *** Only support DistributedDataParallel (DDP) model. 
*** 76 | """ 77 | # gather from all gpus 78 | batch_size_this = x.shape[0] 79 | x_gather = concat_all_gather(x) 80 | batch_size_all = x_gather.shape[0] 81 | 82 | num_gpus = batch_size_all // batch_size_this 83 | 84 | # random shuffle index 85 | idx_shuffle = torch.randperm(batch_size_all).cuda() 86 | 87 | # broadcast to all gpus 88 | torch.distributed.broadcast(idx_shuffle, src=0) 89 | 90 | # index for restoring 91 | idx_unshuffle = torch.argsort(idx_shuffle) 92 | 93 | # shuffled index for this gpu 94 | gpu_idx = torch.distributed.get_rank() 95 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 96 | return x_gather[idx_this], idx_unshuffle 97 | 98 | @torch.no_grad() 99 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 100 | """ 101 | Undo batch shuffle. 102 | *** Only support DistributedDataParallel (DDP) model. *** 103 | """ 104 | # gather from all gpus 105 | batch_size_this = x.shape[0] 106 | x_gather = concat_all_gather(x) 107 | batch_size_all = x_gather.shape[0] 108 | 109 | num_gpus = batch_size_all // batch_size_this 110 | 111 | # restored index for this gpu 112 | gpu_idx = torch.distributed.get_rank() 113 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 114 | 115 | return x_gather[idx_this] 116 | 117 | def forward(self, im_q, im_k): 118 | """ 119 | Input: 120 | im_q: a batch of query images 121 | im_k: a batch of key images 122 | Output: 123 | logits, targets 124 | """ 125 | 126 | # compute query features 127 | q = self.encoder_q(im_q) # queries: NxC 128 | q = nn.functional.normalize(q, dim=1) 129 | 130 | # compute key features 131 | with torch.no_grad(): # no gradient to keys 132 | self._momentum_update_key_encoder() # update the key encoder 133 | 134 | # shuffle for making use of BN 135 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 136 | 137 | k = self.encoder_k(im_k) # keys: NxC 138 | k = nn.functional.normalize(k, dim=1) 139 | 140 | # undo shuffle 141 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 142 | 143 | 144 | if self.condition: 145 | # conditional ssl 146 | logits = torch.mm(q, k.T) / self.T 147 | labels = torch.arange(logits.shape[0], dtype=torch.long).cuda() 148 | return logits, labels 149 | 150 | else: 151 | # compute logits 152 | # Einstein sum is more intuitive 153 | # positive logits: Nx1 154 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 155 | 156 | # negative logits: NxK 157 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 158 | 159 | # logits: Nx(1+K) 160 | logits = torch.cat([l_pos, l_neg], dim=1) 161 | 162 | # apply temperature 163 | logits /= self.T 164 | 165 | # labels: positive key indicators 166 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 167 | 168 | # dequeue and enqueue 169 | self._dequeue_and_enqueue(k) 170 | return logits, labels 171 | 172 | # utils 173 | @torch.no_grad() 174 | def concat_all_gather(tensor): 175 | """ 176 | Performs all_gather operation on the provided tensors. 177 | *** Warning ***: torch.distributed.all_gather has no gradient. 178 | """ 179 | tensors_gather = [torch.ones_like(tensor) 180 | for _ in range(torch.distributed.get_world_size())] 181 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 182 | 183 | output = torch.cat(tensors_gather, dim=0) 184 | return output 185 | -------------------------------------------------------------------------------- /condssl/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from PIL import ImageFilter 3 | import random 4 | 5 | 6 | class TwoCropsTransform: 7 | """Take two random crops of one image as the query and key.""" 8 | 9 | def __init__(self, base_transform): 10 | self.base_transform = base_transform 11 | 12 | def __call__(self, x): 13 | q = self.base_transform(x) 14 | k = self.base_transform(x) 15 | return [q, k] 16 | 17 | 18 | class GaussianBlur(object): 19 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 20 | 21 | def __init__(self, sigma=[.1, 2.]): 22 | self.sigma = sigma 23 | 24 | def __call__(self, x): 25 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 26 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 27 | return x 28 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import cv2 8 | 9 | class TCGA_CPTAC_Dataset(Dataset): 10 | def __init__(self, cptac_dir, tcga_dir, split_dir, transform=None, mode='train', batch_slide_num=4, batch_size=128): 11 | self.cptac_dir = cptac_dir 12 | self.tcga_dir = tcga_dir 13 | 14 | slide_list = pickle.load(open(split_dir + '/case_split.pkl', 'rb'))[mode + "_id"] 15 | # slide_list = [s for s in slide_list if "TCGA" in s] 16 | self.slide2tiles = {} 17 | for slide_id in slide_list: 18 | if "TCGA" in slide_id: 19 | self.slide2tiles[slide_id] = os.listdir(os.path.join(self.tcga_dir, slide_id)) 20 | else: 21 | self.slide2tiles[slide_id] = os.listdir(os.path.join(self.cptac_dir, slide_id)) 22 | self.idx2tiles = [os.path.join(slide_id, tile_name) for slide_id, tile_list in self.slide2tiles.items() 23 | for tile_name in tile_list if 'jpg' in tile_name] 24 | self.tiles2idx = dict(zip(self.idx2tiles, range(len(self.idx2tiles)))) 25 | self.idx2slide = slide_list 26 | self.slide2idx = dict(zip(self.idx2slide, range(len(self.idx2slide)))) 27 | self.transform = transform 28 | self.batch_slide_num = batch_slide_num 29 | self.batch_size = batch_size 30 | 31 | def __getitem__(self, index): 32 | tile_names = [] 33 | slide_id = self.idx2slide[index] 34 | selected_tiles = [slide_id + '/' + t for t in np.random.choice(self.slide2tiles[slide_id], self.batch_size // self.batch_slide_num)] 35 | tile_names += selected_tiles 36 | for i in range(self.batch_slide_num - 1): 37 | slide_id = self.idx2slide[np.random.randint(len(self.idx2slide))] 38 | tile_names += [slide_id + '/' + t for t in np.random.choice(self.slide2tiles[slide_id], self.batch_size // self.batch_slide_num)] 39 | indices = [] 40 | imgs = [] 41 | for tile_name in tile_names: 42 | if "TCGA" in tile_name: 43 | image = cv2.imread(self.tcga_dir + tile_name) 44 | else: 45 | image = cv2.imread(self.cptac_dir + tile_name) 46 | image = Image.fromarray(image) 47 | image_tensor = self.transform(image) 48 | imgs.append(image_tensor) 49 | indices.append(index) 50 | return imgs, indices 51 | 52 | def __len__(self): 53 | return len(self.idx2slide) 54 | 55 | 56 | class TCGA_CPTAC_Bag_Dataset(Dataset): 57 | def __init__(self, data_dir, split_dir, mode='train'): 58 | self.data_dir = data_dir 59 | slide_list = pickle.load(open(os.path.join(split_dir, 'case_split_2yr.pkl'), 'rb'))[mode + '_id'] 60 | self.slide2tiles = {} 61 | for slide_id in slide_list: 62 | if "TCGA" in slide_id: 63 | tile_dir = self.data_dir + '/TCGA/tiles/' 64 | else: 65 | 
tile_dir = self.data_dir + '/CPTAC/tiles/' 66 | self.slide2tiles[slide_id] = os.listdir(os.path.join(tile_dir, slide_id)) 67 | self.idx2tiles = [os.path.join(slide_id, tile_name) for slide_id, tile_list in self.slide2tiles.items() 68 | for tile_name in tile_list if 'jpg' in tile_name] 69 | self.tiles2idx = dict(zip(self.idx2tiles, range(len(self.idx2tiles)))) 70 | self.idx2slide = slide_list 71 | self.slide2idx = dict(zip(self.idx2slide, range(len(self.idx2slide)))) 72 | self.transform = transforms.Compose([transforms.ToTensor()]) 73 | 74 | def __getitem__(self, index): 75 | tile_path = self.idx2tiles[index] 76 | slide_id, tile_name = tile_path.split('/') 77 | tile_idx = self.tiles2idx[tile_path] 78 | slide_idx = self.slide2idx[slide_id] 79 | if "TCGA" in tile_path: 80 | prefix = self.data_dir + '/TCGA/tiles/' 81 | else: 82 | prefix = self.data_dir + '/CPTAC/tiles/' 83 | image = cv2.imread(os.path.join(prefix, tile_path)) 84 | image = Image.fromarray(image) 85 | image_tensor = self.transform(image) 86 | return image_tensor, tile_idx, slide_idx 87 | 88 | def __len__(self): 89 | return len(self.idx2tiles) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /feature_extraction/extract_embeddings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import argparse 4 | from inception_v4 import InceptionV4 5 | import numpy as np 6 | import pickle 7 | import torch 8 | import torch.nn as nn 9 | import os 10 | import cv2 11 | from torchvision import transforms 12 | from collections import defaultdict 13 | from tqdm import tqdm 14 | import pandas as pd 15 | from PIL import Image 16 | import shelve 17 | from dataset.dataloader import TCGA_CPTAC_Bag_Dataset 18 | 19 | if torch.cuda.is_available(): 20 | device = 'cuda' 21 | else: 22 | device = 'cpu' 23 | print(device) 24 | 25 | def get_embeddings_bagging(feature_extractor, subtype_model, data_set): 26 | embedding_dict = defaultdict(list) 27 | outcomes_dict = defaultdict(list) 28 | feature_extractor.eval() 29 | data_loader = torch.utils.data.DataLoader(data_set, batch_size=256, shuffle=False, num_workers=torch.cuda.device_count()) 30 | with torch.no_grad(): 31 | count = 0 32 | for batch in tqdm(data_loader, position=0, leave=True): 33 | count += 1 34 | img, _, bag_idx = batch 35 | feat = feature_extractor(img.to(device)).cpu() 36 | subtype_model.eval() 37 | subtype_prob = subtype_model(img) 38 | subtype_pred = torch.argmax(subtype_prob, dim=1) 39 | tumor_idx = (subtype_pred != 0) 40 | feat = feat[tumor_idx].numpy() 41 | bag_idx = bag_idx[tumor_idx] 42 | for i in range(len(bag_idx)): 43 | embedding_dict[bag_idx[i].item()].append(feat[i][np.newaxis,:]) 44 | slide_id = data_set.idx2slide[bag_idx[i].item()] 45 | if "TCGA" in slide_id: 46 | case_id = '-'.join(slide_id.split('-', 3)[:3]) 47 | else: 48 | case_id = slide_id.rsplit('-', 1)[0] 49 | outcomes_dict[bag_idx[i].item()] = annotations[case_id] 50 | for k in embedding_dict: 51 | embedding_dict[k] = np.concatenate(embedding_dict[k], axis=0) 52 | return embedding_dict, outcomes_dict 53 | 54 | def load_pretrained(net, model_dir): 55 | 56 | print(model_dir) 57 | checkpoint = torch.load(model_dir) 58 | model_state_dict = {k.replace("module.encoder_q.", ""): v for k, v in checkpoint['state_dict'].items() if 59 | "encoder_q" in k} 60 | net.load_state_dict(model_state_dict) 61 | net.last_linear = nn.Identity() 62 | 63 | parser = argparse.ArgumentParser(description='Extract embeddings ') 
64 | 65 | parser.add_argument('--feature_extractor_dir', default='./pretrained/checkpoint.pth.tar', type=str) 66 | parser.add_argument('--subtype_model_dir', default='./subtype_cls/checkpoint.pth.tar', type=str) 67 | parser.add_argument('--root_dir', type=str) 68 | parser.add_argument('--split_dir', type=str) 69 | parser.add_argument('--out_dir', type=str) 70 | 71 | args = parser.parse_args() 72 | 73 | tcga_annotation = pickle.load(open('../TCGA/recurrence_annotation.pkl', 'rb')) 74 | cptac_annotation = pickle.load(open('../CPTAC/recurrence_annotation.pkl', 'rb')) 75 | annotations = {**tcga_annotation, **cptac_annotation} 76 | feature_extractor = InceptionV4(num_classes=256) 77 | load_pretrained(feature_extractor, args.feature_extractor_dir) 78 | feature_extractor.to('cuda') 79 | feature_extractor = nn.DataParallel(feature_extractor)  # use all available GPUs 80 | 81 | subtype_model = InceptionV4(num_classes=2).to('cuda') 82 | subtype_model.load_state_dict(torch.load(args.subtype_model_dir)) 83 | subtype_model = nn.DataParallel(subtype_model)  # use all available GPUs 84 | 85 | 86 | train_dataset = TCGA_CPTAC_Bag_Dataset(args.root_dir, args.split_dir, 'train') 87 | val_dataset = TCGA_CPTAC_Bag_Dataset(args.root_dir, args.split_dir, 'val') 88 | test_dataset = TCGA_CPTAC_Bag_Dataset(args.root_dir, args.split_dir, 'test') 89 | 90 | 91 | with torch.no_grad(): 92 | names = ['train', 'val', 'test'] 93 | for name, data_set in zip(names, [train_dataset, val_dataset, test_dataset]): 94 | print(name) 95 | embedding_dict, outcomes_dict = get_embeddings_bagging(feature_extractor, subtype_model, data_set) 96 | pickle.dump(embedding_dict, open("{}_{}_embedding.pkl".format(args.out_dir, name), 'wb'), protocol=4) 97 | pickle.dump(outcomes_dict, open("{}_{}_outcomes.pkl".format(args.out_dir, name), 'wb'), protocol=4) 98 | 99 | -------------------------------------------------------------------------------- /feature_extraction/get_clusters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from sklearn.cluster import KMeans 3 | from sklearn.mixture import GaussianMixture 4 | import pickle 5 | import numpy as np 6 | import sys 7 | 8 | 9 | 10 | parser = argparse.ArgumentParser(description='Get cluster features') 11 | 12 | parser.add_argument('--data_dir', default='./', type=str) 13 | parser.add_argument('--cluster_type', default='gmm', type=str) 14 | parser.add_argument('--n_cluster', default=50, type=int) 15 | parser.add_argument('--out_dir', default='./', type=str) 16 | 17 | 18 | args = parser.parse_args() 19 | 20 | train_features = pickle.load(open(args.data_dir + 'train_embedding.pkl', 'rb')) 21 | train_features = np.concatenate(list(train_features.values()), axis=0) 22 | if args.cluster_type == "kmeans": 23 | print('kmeans') 24 | cluster = KMeans(n_clusters=args.n_cluster).fit(train_features) 25 | pickle.dump(cluster, open(args.out_dir + 'kmeans_{}.pkl'.format(args.n_cluster), 'wb')) 26 | else: 27 | print('gmm') 28 | cluster = GaussianMixture(n_components=args.n_cluster).fit(train_features) 29 | pickle.dump(cluster, open(args.out_dir + 'gmm_{}.pkl'.format(args.n_cluster), 'wb')) 30 | 31 | 32 | -------------------------------------------------------------------------------- /network/inception_v4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | import os 6 | import sys 7 | 8 | 9 | class BasicConv2d(nn.Module): 10 | 11 |
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 12 | super(BasicConv2d, self).__init__() 13 | self.conv = nn.Conv2d(in_planes, out_planes, 14 | kernel_size=kernel_size, stride=stride, 15 | padding=padding, bias=False) # verify bias false 16 | self.bn = nn.BatchNorm2d(out_planes, 17 | eps=0.001, # value found in tensorflow 18 | momentum=0.1, # default pytorch value 19 | affine=True) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | x = self.relu(x) 26 | return x 27 | 28 | 29 | class Mixed_3a(nn.Module): 30 | 31 | def __init__(self): 32 | super(Mixed_3a, self).__init__() 33 | self.maxpool = nn.MaxPool2d(3, stride=2) 34 | self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) 35 | 36 | def forward(self, x): 37 | x0 = self.maxpool(x) 38 | x1 = self.conv(x) 39 | out = torch.cat((x0, x1), 1) 40 | return out 41 | 42 | 43 | class Mixed_4a(nn.Module): 44 | 45 | def __init__(self): 46 | super(Mixed_4a, self).__init__() 47 | 48 | self.branch0 = nn.Sequential( 49 | BasicConv2d(160, 64, kernel_size=1, stride=1), 50 | BasicConv2d(64, 96, kernel_size=3, stride=1) 51 | ) 52 | 53 | self.branch1 = nn.Sequential( 54 | BasicConv2d(160, 64, kernel_size=1, stride=1), 55 | BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)), 56 | BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)), 57 | BasicConv2d(64, 96, kernel_size=(3,3), stride=1) 58 | ) 59 | 60 | def forward(self, x): 61 | x0 = self.branch0(x) 62 | x1 = self.branch1(x) 63 | out = torch.cat((x0, x1), 1) 64 | return out 65 | 66 | 67 | class Mixed_5a(nn.Module): 68 | 69 | def __init__(self): 70 | super(Mixed_5a, self).__init__() 71 | self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) 72 | self.maxpool = nn.MaxPool2d(3, stride=2) 73 | 74 | def forward(self, x): 75 | x0 = self.conv(x) 76 | x1 = self.maxpool(x) 77 | out = torch.cat((x0, x1), 1) 78 | return out 79 | 80 | 81 | class Inception_A(nn.Module): 82 | 83 | def __init__(self): 84 | super(Inception_A, self).__init__() 85 | self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) 86 | 87 | self.branch1 = nn.Sequential( 88 | BasicConv2d(384, 64, kernel_size=1, stride=1), 89 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) 90 | ) 91 | 92 | self.branch2 = nn.Sequential( 93 | BasicConv2d(384, 64, kernel_size=1, stride=1), 94 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), 95 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) 96 | ) 97 | 98 | self.branch3 = nn.Sequential( 99 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 100 | BasicConv2d(384, 96, kernel_size=1, stride=1) 101 | ) 102 | 103 | def forward(self, x): 104 | x0 = self.branch0(x) 105 | x1 = self.branch1(x) 106 | x2 = self.branch2(x) 107 | x3 = self.branch3(x) 108 | out = torch.cat((x0, x1, x2, x3), 1) 109 | return out 110 | 111 | 112 | class Reduction_A(nn.Module): 113 | 114 | def __init__(self): 115 | super(Reduction_A, self).__init__() 116 | self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) 117 | 118 | self.branch1 = nn.Sequential( 119 | BasicConv2d(384, 192, kernel_size=1, stride=1), 120 | BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), 121 | BasicConv2d(224, 256, kernel_size=3, stride=2) 122 | ) 123 | 124 | self.branch2 = nn.MaxPool2d(3, stride=2) 125 | 126 | def forward(self, x): 127 | x0 = self.branch0(x) 128 | x1 = self.branch1(x) 129 | x2 = self.branch2(x) 130 | out = torch.cat((x0, x1, x2), 1) 131 | return out 132 | 133 | 134 | class 
Inception_B(nn.Module): 135 | 136 | def __init__(self): 137 | super(Inception_B, self).__init__() 138 | self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) 139 | 140 | self.branch1 = nn.Sequential( 141 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 142 | BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), 143 | BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0)) 144 | ) 145 | 146 | self.branch2 = nn.Sequential( 147 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 148 | BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)), 149 | BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), 150 | BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)), 151 | BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3)) 152 | ) 153 | 154 | self.branch3 = nn.Sequential( 155 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 156 | BasicConv2d(1024, 128, kernel_size=1, stride=1) 157 | ) 158 | 159 | def forward(self, x): 160 | x0 = self.branch0(x) 161 | x1 = self.branch1(x) 162 | x2 = self.branch2(x) 163 | x3 = self.branch3(x) 164 | out = torch.cat((x0, x1, x2, x3), 1) 165 | return out 166 | 167 | 168 | class Reduction_B(nn.Module): 169 | 170 | def __init__(self): 171 | super(Reduction_B, self).__init__() 172 | 173 | self.branch0 = nn.Sequential( 174 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 175 | BasicConv2d(192, 192, kernel_size=3, stride=2) 176 | ) 177 | 178 | self.branch1 = nn.Sequential( 179 | BasicConv2d(1024, 256, kernel_size=1, stride=1), 180 | BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)), 181 | BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)), 182 | BasicConv2d(320, 320, kernel_size=3, stride=2) 183 | ) 184 | 185 | self.branch2 = nn.MaxPool2d(3, stride=2) 186 | 187 | def forward(self, x): 188 | x0 = self.branch0(x) 189 | x1 = self.branch1(x) 190 | x2 = self.branch2(x) 191 | out = torch.cat((x0, x1, x2), 1) 192 | return out 193 | 194 | 195 | class Inception_C(nn.Module): 196 | 197 | def __init__(self): 198 | super(Inception_C, self).__init__() 199 | 200 | self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) 201 | 202 | self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) 203 | self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1)) 204 | self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 205 | 206 | self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) 207 | self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0)) 208 | self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1)) 209 | self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1)) 210 | self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 211 | 212 | self.branch3 = nn.Sequential( 213 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 214 | BasicConv2d(1536, 256, kernel_size=1, stride=1) 215 | ) 216 | 217 | def forward(self, x): 218 | x0 = self.branch0(x) 219 | 220 | x1_0 = self.branch1_0(x) 221 | x1_1a = self.branch1_1a(x1_0) 222 | x1_1b = self.branch1_1b(x1_0) 223 | x1 = torch.cat((x1_1a, x1_1b), 1) 224 | 225 | x2_0 = self.branch2_0(x) 226 | x2_1 = self.branch2_1(x2_0) 227 | x2_2 = self.branch2_2(x2_1) 228 | x2_3a = self.branch2_3a(x2_2) 229 | x2_3b = self.branch2_3b(x2_2) 230 | x2 = torch.cat((x2_3a, x2_3b), 1) 231 | 232 | x3 = self.branch3(x) 233 | 234 | out = torch.cat((x0, x1, 
x2, x3), 1) 235 | return out 236 | 237 | 238 | class InceptionV4(nn.Module): 239 | 240 | def __init__(self, num_classes=1001): 241 | super(InceptionV4, self).__init__() 242 | # Special attributs 243 | self.input_space = None 244 | self.input_size = (299, 299, 3) 245 | self.mean = None 246 | self.std = None 247 | # Modules 248 | self.features = nn.Sequential( 249 | BasicConv2d(3, 32, kernel_size=3, stride=2), 250 | BasicConv2d(32, 32, kernel_size=3, stride=1), 251 | BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), 252 | Mixed_3a(), 253 | Mixed_4a(), 254 | Mixed_5a(), 255 | Inception_A(), 256 | Inception_A(), 257 | Inception_A(), 258 | Inception_A(), 259 | Reduction_A(), # Mixed_6a 260 | Inception_B(), 261 | Inception_B(), 262 | Inception_B(), 263 | Inception_B(), 264 | Inception_B(), 265 | Inception_B(), 266 | Inception_B(), 267 | Reduction_B(), # Mixed_7a 268 | Inception_C(), 269 | Inception_C(), 270 | Inception_C() 271 | ) 272 | self.last_linear = nn.Linear(1536, num_classes) 273 | 274 | def logits(self, features): 275 | #Allows image of any size to be processed 276 | adaptiveAvgPoolWidth = features.shape[2] 277 | x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth) 278 | x = x.view(x.size(0), -1) 279 | x = self.last_linear(x) 280 | return x 281 | 282 | def forward(self, input): 283 | x = self.features(input) 284 | x = self.logits(x) 285 | return x 286 | 287 | 288 | def inceptionv4(num_classes=1000, pretrained='imagenet'): 289 | if pretrained: 290 | settings = pretrained_settings['inceptionv4'][pretrained] 291 | assert num_classes == settings['num_classes'], \ 292 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 293 | 294 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 295 | model = InceptionV4(num_classes=1001) 296 | model.load_state_dict(model_zoo.load_url(settings['url'])) 297 | 298 | if pretrained == 'imagenet': 299 | new_last_linear = nn.Linear(1536, 1000) 300 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 301 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 302 | model.last_linear = new_last_linear 303 | 304 | model.input_space = settings['input_space'] 305 | model.input_size = settings['input_size'] 306 | model.input_range = settings['input_range'] 307 | model.mean = settings['mean'] 308 | model.std = settings['std'] 309 | else: 310 | model = InceptionV4(num_classes=num_classes) 311 | return model -------------------------------------------------------------------------------- /plots/progression_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYUMedML/conditional_ssl_hist/32c055130c9d4a40e66af50835b97364599792d6/plots/progression_plot.png -------------------------------------------------------------------------------- /plots/umap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYUMedML/conditional_ssl_hist/32c055130c9d4a40e66af50835b97364599792d6/plots/umap.png -------------------------------------------------------------------------------- /preprocessing/colorstandard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYUMedML/conditional_ssl_hist/32c055130c9d4a40e66af50835b97364599792d6/preprocessing/colorstandard.png -------------------------------------------------------------------------------- /preprocessing/process_cptac.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from utils import wsi_to_tiles 5 | import pickle 6 | 7 | parser = argparse.ArgumentParser(description='Process TCGA') 8 | 9 | parser.add_argument('--followup_path', default='./LSCC-clinicalTable.csv', type=str) 10 | parser.add_argument('--wsi_path', default='../CPTAC_WSI', type=str) 11 | parser.add_argument('--refer_img', default='./colorstandard.png', type=str) 12 | parser.add_argument('--s', default=0.9, type=float, help='The proportion of tissues') 13 | args = parser.parse_args() 14 | 15 | clinicalTable = pd.read_csv(args.followup_path).set_index('case_id') 16 | wsi_dir_dict = {} 17 | wsi_list = os.popen("find {} -name '*.svs'".format(args.wsi_path)).read().strip('\n').replace(wsi_path,'').split("\n") 18 | for slide_id in wsi_list: 19 | slide_id = slide_id.lstrip('/').rstrip('.svs') 20 | tile_path = os.path.join('../CPTAC/tiles', slide_id) 21 | if not os.path.exists(tile_path): 22 | os.mkdir(tile_path) 23 | 24 | for idx, wsi in enumerate(wsi_list): 25 | wsi_to_tiles(idx, wsi, args.refer_img, args.s) 26 | 27 | # Get annotation 28 | annotation = {} 29 | for case_id in clinicalTable.index: 30 | clinicalRow = clinicalTable.loc[case_id].to_dict() 31 | try: 32 | imageRow = imageTable.loc[case_id].to_dict(orient='list') 33 | slide_id = imageRow['Slide_ID'] 34 | except: 35 | slide_id = [] 36 | annotation[case_id] = {'recurrence': clinicalRow['Recurrence.status..1..yes..0..no.'], 37 | 'stage': stage_dict[clinicalRow['baseline.tumor_stage_pathological']], 38 | 'survival_days': clinicalRow['Overall.survival..days'], 39 | 'survival': clinicalRow['Survival.status..1..dead..0..alive.'], 40 | 'recurrence_free_days':clinicalRow['Recurrence.free.survival..days'], 41 | 'age':clinicalRow['consent.age'], 42 | 'gender':clinicalRow['consent.sex'], 43 | 'followup_days':clinicalRow['follow.up.number_of_days_from_date_of_initial_pathologic_diagnosis_to_date_of_last_contact'], 44 | 'slide_id': slide_id} 45 | pickle.dump(annotation, open('../CPTAC/recurrence_annotation.pkl', 'wb')) 46 | 47 | -------------------------------------------------------------------------------- /preprocessing/process_tcga.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from utils import wsi_to_tiles 5 | import pickle 6 | 7 | parser = argparse.ArgumentParser(description='Process TCGA') 8 | 9 | parser.add_argument('--followup_path', default='./clinical_follow_up_v1.0_lusc.xlsx', type=str) 10 | parser.add_argument('--clinical_table_path', default='./clinical_follow_up_v1.0_lusc.xlsx', type=str) 11 | parser.add_argument('--wsi_path', default='../TCGA_WSI', type=str) 12 | parser.add_argument('--refer_img', default='./colorstandard.png', type=str) 13 | parser.add_argument('--s', default=0.9, type=float, help='The proportion of tissues') 14 | 15 | 16 | args = parser.parse_args() 17 | 18 | followupTable = pd.read_excel(args.followup_path, skiprows=[1,2], engine='openpyxl') 19 | followupTable = followupTable.loc[followupTable['new_tumor_event_dx_indicator'].isin({'YES', 'NO'})] 20 | followupTable['recurrence'] = ((followupTable['new_tumor_event_dx_indicator'] == 'YES') & 21 | (followupTable['new_tumor_event_type'] != 'New Primary Tumor')) 22 | followupTable = followupTable.sort_values(['bcr_patient_barcode', 'form_completion_date']).drop_duplicates('bcr_patient_barcode', keep='last') 23 | LUSC_patientids = 
set(followupTable['bcr_patient_barcode']) 24 | 25 | 26 | wsi_list = os.popen("find {} -name '*.svs'".format(wsi_path)).read().strip('\n').split('\n') 27 | wsi_list_LUSC = [] 28 | for idx in range(len(wsi_list)): 29 | slide_id = wsi_list[idx].rsplit('/', 1)[1].split('.')[0] 30 | patient_id = '-'.join(slide_id.split('-', 3)[:3]) 31 | tile_path = os.path.join('../TCGA/tiles', slide_id) 32 | if patient_id in LUSC_patientids: 33 | if not os.path.exists(tile_path): 34 | os.mkdir(tile_path) 35 | wsi_list_LUSC.append(wsi_list[idx]) 36 | 37 | for idx, wsi in enumerate(wsi_list_LUSC): 38 | wsi_to_tiles(idx, wsi, args.refer_img, args.s) 39 | 40 | # Get annotation 41 | clinicalTable = pd.read_csv(args.clinical_table_path).set_index('bcr_patient_barcode') 42 | annotation = defaultdict(lambda: {"recurrence": None, "slide_id": []}) 43 | slide_ids = os.listdir('./TCGA/tiles') 44 | included_slides = [s for s in slide_ids if s.rsplit('-',3)[0] in set(followupTable.index)] 45 | for slide_id in included_slides: 46 | case_id = '-'.join(slide_id.split('-', 3)[:3]) 47 | clinicalRow = followupTable.loc[case_id].to_dict() 48 | annotation[case_id]['recurrence'] = 1 if clinicalRow['recurrence'] else 0 49 | annotation[case_id]['slide_id'].append(slide_id) 50 | annotation[case_id]['stage'] = clinicalTable.loc[case_id]['ajcc_pathologic_tumor_stage'] 51 | annotation[case_id]['survival_days'] = clinicalTable.loc[case_id]['death_days_to'] 52 | annotation[case_id]['survival'] = clinicalTable.loc[case_id]['vital_status'] 53 | annotation[case_id]['recurrence_free_days'] = pd.to_numeric(followupTable.new_tumor_event_dx_days_to, errors='coerce').loc[case_id] 54 | annotation[case_id]['followup_days'] = pd.to_numeric(followupTable.last_contact_days_to, errors='coerce').loc[case_id] 55 | annotation[case_id]['gender'] = clinicalTable['gender'].loc[case_id] 56 | pickle.dump(annotation, open('../TCGA/recurrence_annotation.pkl', 'wb')) 57 | -------------------------------------------------------------------------------- /preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from openslide import OpenSlide 3 | from PIL import Image 4 | import numpy as np 5 | import staintools 6 | 7 | 8 | def getGradientMagnitude(im): 9 | "Get magnitude of gradient for given image" 10 | ddepth = cv2.CV_32F 11 | dx = cv2.Sobel(im, ddepth, 1, 0) 12 | dy = cv2.Sobel(im, ddepth, 0, 1) 13 | dxabs = cv2.convertScaleAbs(dx) 14 | dyabs = cv2.convertScaleAbs(dy) 15 | mag = cv2.addWeighted(dxabs, 0.5, dyabs, 0.5, 0) 16 | return mag 17 | 18 | 19 | def wsi_to_tiles(idx, wsi, refer_img, s): 20 | normalizer = staintools.StainNormalizer(method='vahadane') 21 | refer_img = staintools.read_image(refer_img) 22 | normalizer.fit(refer_img) 23 | count = 0 24 | sys.stdout.write('Start task %d: %s \n' % (idx, wsi)) 25 | slide_id = wsi.rsplit('/', 1)[1].split('.')[0] 26 | tile_path = os.path.join('./tiles', slide_id) 27 | img = OpenSlide(os.path.join(wsi)) 28 | if str(img.properties.values.__self__.get('tiff.ImageDescription')).split("|")[1] == "AppMag = 40": 29 | sz = 2048 30 | seq = 1536 31 | else: 32 | sz = 1024 33 | seq = 768 34 | [w, h] = img.dimensions 35 | for x in range(1, w, seq): 36 | for y in range(1, h, seq): 37 | img_tmp = img.read_region(location=(x, y), level=0, size=(sz, sz)) \ 38 | .convert("RGB").resize((299, 299), Image.ANTIALIAS) 39 | grad = getGradientMagnitude(np.array(img_tmp)) 40 | unique, counts = np.unique(grad, return_counts=True) 41 | if counts[np.argwhere(unique <= 15)].sum() < 
299 * 299 * s: 42 | img_tmp = normalizer.transform(np.array(img_tmp)) 43 | img_tmp = Image.fromarray(img_tmp) 44 | img_tmp.save(tile_path + "/" + str(x) + "_" + str(y) + '.jpg', 'JPEG', optimize=True, quality=94) 45 | count += 1 46 | sys.stdout.write('End task %d with %d tiles\n' % (idx, count)) 47 | 48 | -------------------------------------------------------------------------------- /survival_models/cox.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from lifelines import CoxPHFitter 4 | from lifelines.utils import concordance_index 5 | import pickle 6 | import utils 7 | import os 8 | import pandas as pd 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from sklearn.metrics import brier_score_loss 12 | import cox_utils 13 | from sksurv.linear_model import CoxPHSurvivalAnalysis 14 | 15 | 16 | 17 | parser = argparse.ArgumentParser(description='Cox-PH survival models') 18 | 19 | parser.add_argument('--data_dir', default='.', type=str) 20 | parser.add_argument('--cluster_name', default='gmm_50.pkl', type=str) 21 | parser.add_argument('--normalize', default='mean', type=str) 22 | 23 | 24 | args = parser.parse_args() 25 | 26 | cluster = pickle.load(open(os.path.join(args.data_dir, args.cluster_name), 'rb')) 27 | cluster_method = type(cluster).__name__ 28 | if cluster_method == 'GaussianMixture': 29 | n_clusters = len(cluster.weights_) 30 | else: 31 | n_clusters = cluster.n_clusters 32 | 33 | train_data, val_data, test_data = utils.load_data(args.data_dir, 34 | os.path.join(args.data_dir, args.cluster_name), normalize=args.normalize) 35 | 36 | train_df, val_df, test_df = utils.preprocess_data(train_data, val_data, test_data) 37 | if data_source == 'TCGA': 38 | test_df = test_df.loc[test_df['tcga_flag']==1.0] 39 | elif data_source =='CPTAC': 40 | test_df = test_df.loc[test_df['tcga_flag']==0.0] 41 | train_df, val_df, test_df = train_df.drop(columns=['tcga_flag']), val_df.drop(columns=['tcga_flag']), test_df.drop(columns=['tcga_flag']) 42 | y_train = np.array([tuple((bool(row[0]), row[1])) for row in zip(train_df['outcome'], train_df['day'])], 43 | dtype=[('outcome', 'bool'), ('day', ' 0] 42 | train_df['day'] = train_df['day'] // 180 + 1 43 | 44 | if 'tcga' not in val_data: 45 | val_df = pd.DataFrame(np.concatenate([val_data['cluster'], np.stack([val_data['recur_day'], val_data['followup_day'], val_data['outcome']]).T], axis=1), 46 | columns=['c_{}'.format(i) for i in range(50)] + ['recur', 'followup', 'outcome']) 47 | else: 48 | val_df = pd.DataFrame(np.concatenate([val_data['cluster'], np.stack([val_data['recur_day'], val_data['followup_day'], val_data['outcome']]).T], axis=1),columns=['c_{}'.format(i) for i in range(50)] + ['recur', 'followup', 'outcome']) 49 | val_df = val_df[(val_df['recur'].notna() | val_df['followup'].notna())] 50 | val_df['day'] = val_df['recur'] 51 | val_df.loc[val_df['recur'].isna(), 'day'] = val_df.loc[val_df['recur'].isna(), 'followup'] 52 | val_df = val_df.drop(columns=['recur', 'followup']) 53 | val_df = val_df[val_df['day'] > 0] 54 | val_df['day'] = val_df['day'] // 180 + 1 55 | 56 | 57 | if 'tcga' not in test_data: 58 | test_df = pd.DataFrame(np.concatenate([test_data['cluster'], np.stack([test_data['recur_day'], test_data['followup_day'], test_data['outcome']]).T], axis=1), 59 | columns=['c_{}'.format(i) for i in range(50)] + ['recur', 'followup', 'outcome']) 60 | else: 61 | test_df = pd.DataFrame(np.concatenate([test_data['cluster'], 
np.stack([test_data['recur_day'], test_data['followup_day'], test_data['outcome']]).T], axis=1), 62 | columns=['c_{}'.format(i) for i in range(50)] + ['recur', 'followup', 'outcome']) 63 | 64 | test_df['day'] = test_df['recur'] 65 | test_df = test_df[(test_df['recur'].notna() | test_df['followup'].notna())] 66 | 67 | test_df.loc[test_df['recur'].isna(), 'day'] = test_df.loc[test_df['recur'].isna(), 'followup'] 68 | test_df = test_df.drop(columns=['recur', 'followup']) 69 | test_df = test_df[test_df['day'] > 0] 70 | test_df['day'] = test_df['day'] // 180 + 1 71 | return train_df, val_df, test_df 72 | 73 | 74 | def label_cluster(feature, cluster): 75 | clusters = defaultdict() 76 | cluster_method = type(cluster).__name__ 77 | for k in tqdm(list(feature.keys())): 78 | if cluster_method == 'GaussianMixture': 79 | clusters[k] = cluster.predict_proba(feature[int(k)]) 80 | else: 81 | clusters[k] = cluster.predict(feature[int(k)]) 82 | return clusters 83 | 84 | 85 | def load_data(data_dir, cluster_dir, normalize='mean', cls=1): 86 | split_dir = data_dir.rsplit('/', 1)[0] + '/' 87 | cluster = pickle.load(open(cluster_dir, 'rb')) 88 | cluster_method = type(cluster).__name__ 89 | if cluster_method == 'GaussianMixture': 90 | n_clusters = len(cluster.weights_) 91 | else: 92 | n_clusters = cluster.n_clusters 93 | train_features = pickle.load(open(data_dir + 'train_embedding.pkl', 'rb')) 94 | train_outcomes = pickle.load(open(data_dir + 'train_outcomes.pkl', 'rb')) 95 | val_features = pickle.load(open(data_dir + 'val_embedding.pkl', 'rb')) 96 | val_outcomes = pickle.load(open(data_dir + 'val_outcomes.pkl', 'rb')) 97 | test_features = pickle.load(open(data_dir + 'test_embedding.pkl', 'rb')) 98 | test_outcomes = pickle.load(open(data_dir + 'test_outcomes.pkl', 'rb')) 99 | 100 | val_tcga_flag = np.array([]) 101 | test_tcga_flag = np.array([]) 102 | 103 | 104 | train_cluster = label_cluster(train_features, cluster) 105 | val_cluster = label_cluster(val_features, cluster) 106 | test_cluster = label_cluster(test_features, cluster) 107 | 108 | train_data = transform(train_features, train_cluster, train_outcomes, train_tcga_flag, n_clusters, normalize, demo=False) 109 | val_data = transform(val_features, val_cluster, val_outcomes, val_tcga_flag, n_clusters, normalize, demo=False) 110 | test_data = transform(test_features, test_cluster, test_outcomes, test_tcga_flag, n_clusters, normalize, demo=False) 111 | 112 | return train_data, val_data, test_data 113 | 114 | 115 | def counter(arr, n): 116 | count = defaultdict(lambda: 0) 117 | for k, v in Counter(arr).items(): 118 | count[k] = v 119 | return [count[i] for i in range(n)] 120 | 121 | 122 | def transform(features, cluster, outcomes, tcga_flag, n_clusters, normalize='count', cls=1, weight=None, demo=None): 123 | count_list = [] 124 | outcome_list = [] 125 | recur_day_list = [] 126 | followup_day_list = [] 127 | raw_featuer_list = [] 128 | tile_outcome = [] 129 | demo_list = [] 130 | tcga_flag_list = [] 131 | cluster_method = type(cluster).__name__ 132 | for k in cluster: 133 | for v in features[int(k)]: 134 | raw_featuer_list.append(v) 135 | tile_outcome.append(outcomes[int(k)]) 136 | if normalize == 'mean': 137 | count_list.append(cluster[int(k)].mean(axis=0)) 138 | else: 139 | count_list.append(counter(cluster[int(k)], n_clusters)) 140 | 141 | outcome_list.append(outcomes[int(k)]['recurrence']) 142 | recur_day_list.append(outcomes[int(k)]['recurrence_free_days']) 143 | followup_day_list.append(outcomes[int(k)]['followup_days']) 144 | count_list, outcome_list = 
np.array(count_list), np.array(outcome_list) 145 | count_list = count_list + 1e-10 146 | 147 | if normalize == 'mean': 148 | cluster_features = (count_list.T/count_list.sum(axis=1)).T 149 | elif normalize == 'count': 150 | cluster_features = (count_list.T/count_list.sum(axis=1)).T 151 | elif normalize == 'onehot': 152 | cluster_features = (count_list > 1e-5) 153 | elif normalize == 'avg': 154 | cluster_features = count_list 155 | elif normalize == 'weight': 156 | cluster_features = count_list * weight 157 | elif normalize == 'sum': 158 | cluster_features = count_list 159 | return {'cluster': cluster_features, 160 | 'tile_feat': np.array(raw_featuer_list), 161 | 'tile_outcome': np.array(tile_outcome), 162 | 'recur_day': np.array(recur_day_list), 163 | 'followup_day': np.array(followup_day_list), 164 | 'outcome': outcome_list, 165 | 'demo': demo_list, 166 | 'tcga': np.array(tcga_flag_list), 167 | } -------------------------------------------------------------------------------- /train_ssl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | 23 | import condssl.loader 24 | import condssl.builder 25 | 26 | import sys 27 | from network.inception_v4 import InceptionV4 28 | from dataset.dataloader import TCGA_CPTAC_Dataset 29 | 30 | 31 | class TwoCropsTransform: 32 | """Take two random crops of one image as the query and key.""" 33 | 34 | def __init__(self, base_transform): 35 | self.base_transform = base_transform 36 | 37 | def __call__(self, x): 38 | q = self.base_transform(x) 39 | k = self.base_transform(x) 40 | return [q, k] 41 | 42 | 43 | def collate_fn_moco(batch): 44 | q_list = [] 45 | k_list = [] 46 | for imgs, indices in batch: 47 | for img in imgs: 48 | q_list.append(img[0].unsqueeze(0)) 49 | k_list.append(img[1].unsqueeze(0)) 50 | return torch.cat(q_list, dim=0), torch.cat(k_list, dim=0) 51 | 52 | 53 | def train(train_loader, model, criterion, optimizer, epoch, args): 54 | batch_time = AverageMeter('Time', ':6.3f') 55 | data_time = AverageMeter('Data', ':6.3f') 56 | losses = AverageMeter('Loss', ':.4e') 57 | top1 = AverageMeter('Acc@1', ':6.2f') 58 | top5 = AverageMeter('Acc@5', ':6.2f') 59 | progress = ProgressMeter( 60 | len(train_loader), 61 | [batch_time, data_time, losses, top1, top5], 62 | prefix="Epoch: [{}]".format(epoch)) 63 | 64 | # switch to train mode 65 | model.train() 66 | 67 | end = time.time() 68 | for i, (images) in enumerate(train_loader): 69 | # measure data loading time 70 | data_time.update(time.time() - end) 71 | if args.gpu is not None: 72 | images[0] = images[0].cuda(args.gpu, non_blocking=True) 73 | images[1] = images[1].cuda(args.gpu, non_blocking=True) 74 | # compute output 75 | output, target = model(im_q=images[0].cuda(), im_k=images[1].cuda()) 76 | loss = criterion(output, target) 77 | 78 | # acc1/acc5 are (K+1)-way contrast classifier accuracy 79 | # measure accuracy and record loss 80 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 81 | 
losses.update(loss.item(), images[0].size(0)) 82 | top1.update(acc1[0], images[0].size(0)) 83 | top5.update(acc5[0], images[0].size(0)) 84 | 85 | # compute gradient and do SGD step 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | 90 | # measure elapsed time 91 | batch_time.update(time.time() - end) 92 | end = time.time() 93 | 94 | if i % args.print_freq == 0: 95 | progress.display(i) 96 | 97 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 98 | torch.save(state, filename) 99 | if is_best: 100 | shutil.copyfile(filename, 'model_best.pth.tar') 101 | 102 | 103 | class AverageMeter(object): 104 | """Computes and stores the average and current value""" 105 | def __init__(self, name, fmt=':f'): 106 | self.name = name 107 | self.fmt = fmt 108 | self.reset() 109 | 110 | def reset(self): 111 | self.val = 0 112 | self.avg = 0 113 | self.sum = 0 114 | self.count = 0 115 | 116 | def update(self, val, n=1): 117 | self.val = val 118 | self.sum += val * n 119 | self.count += n 120 | self.avg = self.sum / self.count 121 | 122 | def __str__(self): 123 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 124 | return fmtstr.format(**self.__dict__) 125 | 126 | 127 | class ProgressMeter(object): 128 | def __init__(self, num_batches, meters, prefix=""): 129 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 130 | self.meters = meters 131 | self.prefix = prefix 132 | 133 | def display(self, batch): 134 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 135 | entries += [str(meter) for meter in self.meters] 136 | print('\t'.join(entries)) 137 | 138 | def _get_batch_fmtstr(self, num_batches): 139 | num_digits = len(str(num_batches // 1)) 140 | fmt = '{:' + str(num_digits) + 'd}' 141 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 142 | 143 | 144 | def adjust_learning_rate(optimizer, epoch, args): 145 | """Decay the learning rate based on schedule""" 146 | lr = args.lr 147 | if args.cos: # cosine lr schedule 148 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 149 | else: # stepwise lr schedule 150 | for milestone in args.schedule: 151 | lr *= 0.1 if epoch >= milestone else 1. 
152 | for param_group in optimizer.param_groups: 153 | param_group['lr'] = lr 154 | 155 | 156 | def accuracy(output, target, topk=(1,)): 157 | """Computes the accuracy over the k top predictions for the specified values of k""" 158 | with torch.no_grad(): 159 | maxk = max(topk) 160 | batch_size = target.size(0) 161 | 162 | _, pred = output.topk(maxk, 1, True, True) 163 | pred = pred.t() 164 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 165 | 166 | res = [] 167 | for k in topk: 168 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 169 | res.append(correct_k.mul_(100.0 / batch_size)) 170 | return res 171 | 172 | 173 | 174 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 175 | # parser.add_argument('data', metavar='DIR', 176 | # help='path to dataset') 177 | 178 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 179 | help='number of data loading workers (default: 32)') 180 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 181 | help='number of total epochs to run') 182 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 183 | help='manual epoch number (useful on restarts)') 184 | parser.add_argument('-b', '--batch-size', default=1, type=int, 185 | metavar='N', 186 | help='mini-batch size (default: 256), this is the total ' 187 | 'batch size of all GPUs on the current node when ' 188 | 'using Data Parallel or Distributed Data Parallel') 189 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, 190 | metavar='LR', help='initial learning rate', dest='lr') 191 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, 192 | help='learning rate schedule (when to drop lr by 10x)') 193 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 194 | help='momentum of SGD solver') 195 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 196 | metavar='W', help='weight decay (default: 1e-4)', 197 | dest='weight_decay') 198 | parser.add_argument('-p', '--print-freq', default=10, type=int, 199 | metavar='N', help='print frequency (default: 10)') 200 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 201 | help='path to latest checkpoint (default: none)') 202 | parser.add_argument('--world-size', default=-1, type=int, 203 | help='number of nodes for distributed training') 204 | parser.add_argument('--rank', default=-1, type=int, 205 | help='node rank for distributed training') 206 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 207 | help='url used to set up distributed training') 208 | parser.add_argument('--dist-backend', default='nccl', type=str, 209 | help='distributed backend') 210 | parser.add_argument('--seed', default=None, type=int, 211 | help='seed for initializing training. ') 212 | parser.add_argument('--gpu', default=None, type=int, 213 | help='GPU id to use.') 214 | parser.add_argument('--multiprocessing-distributed', action='store_true', 215 | help='Use multi-processing distributed training to launch ' 216 | 'N processes per node, which has N GPUs. 
This is the ' 217 | 'fastest way to use PyTorch for either single node or ' 218 | 'multi node data parallel training') 219 | parser.add_argument("--local_rank", type=int, default=0) 220 | 221 | # moco specific configs: 222 | parser.add_argument('--moco-dim', default=128, type=int, 223 | help='feature dimension (default: 128)') 224 | parser.add_argument('--moco-k', default=65536, type=int, 225 | help='queue size; number of negative keys (default: 65536)') 226 | parser.add_argument('--moco-m', default=0.999, type=float, 227 | help='moco momentum of updating key encoder (default: 0.999)') 228 | parser.add_argument('--moco-t', default=0.07, type=float, 229 | help='softmax temperature (default: 0.07)') 230 | 231 | # options for moco v2 232 | parser.add_argument('--mlp', action='store_true', 233 | help='use mlp head') 234 | parser.add_argument('--aug-plus', action='store_true', 235 | help='use moco v2 data augmentation') 236 | parser.add_argument('--cos', action='store_true', 237 | help='use cosine lr schedule') 238 | 239 | parser.add_argument('--partition_name', default='train_Lung', type=str) 240 | parser.add_argument('--data_dir', default='./data/', type=str, 241 | help='path to output directory') 242 | parser.add_argument('--split_dir', default='./split/', type=str, 243 | help='path to output directory') 244 | parser.add_argument('--out_dir', default='./models/', type=str, 245 | help='path to output directory') 246 | parser.add_argument('--batch_slide_num', default=4, type=int) 247 | parser.add_argument('--condition', default=True, type=bool) 248 | 249 | args = parser.parse_args() 250 | 251 | print(args.out_dir) 252 | 253 | if args.seed is not None: 254 | random.seed(args.seed) 255 | torch.manual_seed(args.seed) 256 | cudnn.deterministic = True 257 | warnings.warn('You have chosen to seed training. ' 258 | 'This will turn on the CUDNN deterministic setting, ' 259 | 'which can slow down your training considerably! ' 260 | 'You may see unexpected behavior when restarting ' 261 | 'from checkpoints.') 262 | 263 | if args.gpu is not None: 264 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 265 | 'disable data parallelism.') 266 | 267 | if args.dist_url == "env://" and args.world_size == -1: 268 | args.world_size = int(os.environ["WORLD_SIZE"]) 269 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 270 | ngpus_per_node = torch.cuda.device_count() 271 | 272 | 273 | print("=> creating model '{}'".format(args.arch)) 274 | 275 | encoder = InceptionV4 276 | 277 | model = condssl.builder.MoCo( 278 | encoder, args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp, condition=args.condition) 279 | 280 | model = model.cuda() 281 | torch.distributed.init_process_group('nccl') 282 | 283 | model = torch.nn.parallel.DistributedDataParallel(model) 284 | 285 | print('Model builder Done.') 286 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 287 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 288 | momentum=args.momentum, 289 | weight_decay=args.weight_decay) 290 | 291 | augmentation = [ 292 | transforms.RandomResizedCrop(299, scale=(0.2, 1.)), 293 | transforms.RandomApply([ 294 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 295 | ], p=0.8), 296 | transforms.RandomGrayscale(p=0.2), 297 | transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5), 298 | transforms.RandomHorizontalFlip(), 299 | transforms.RandomVerticalFlip(), 300 | transforms.ToTensor() 301 | # normalize 302 | ] 303 | 304 | print('Create dataset') 305 | 306 | 307 | train_dataset = TCGA_CPTAC_Dataset(cptac_dir=args.data_dir + "/CPTAC/tiles/", 308 | tcga_dir=args.data_dir + "/TCGA/tiles/", 309 | split_dir=arg.split_dir, 310 | transform=TwoCropsTransform(transforms.Compose(augmentation)), 311 | batch_slide_num=args.batch_slide_num) 312 | 313 | 314 | print("Dataset Created ...") 315 | print(args.batch_slide_num) 316 | 317 | if args.distributed: 318 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 319 | else: 320 | train_sampler = None 321 | train_loader = torch.utils.data.DataLoader( 322 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 323 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True, collate_fn=collate_fn_moco) 324 | 325 | 326 | for epoch in range(args.start_epoch, args.epochs): 327 | if args.distributed: 328 | train_sampler.set_epoch(epoch) 329 | adjust_learning_rate(optimizer, epoch, args) 330 | 331 | # train for one epoch 332 | train(train_loader, model, criterion, optimizer, epoch, args) 333 | 334 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 335 | and args.rank % ngpus_per_node == 0): 336 | if (epoch + 1) % 50 == 0: 337 | save_checkpoint({ 338 | 'epoch': epoch + 1, 339 | 'arch': args.arch, 340 | 'state_dict': model.state_dict(), 341 | 'optimizer' : optimizer.state_dict(), 342 | }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.out_dir, epoch)) 343 | --------------------------------------------------------------------------------