├── README.md
├── condssl
│   ├── __init__.py
│   ├── builder.py
│   └── loader.py
├── dataset
│   └── dataloader.py
├── feature_extraction
│   ├── extract_embeddings.py
│   └── get_clusters.py
├── network
│   └── inception_v4.py
├── plots
│   ├── progression_plot.png
│   └── umap.png
├── preprocessing
│   ├── colorstandard.png
│   ├── process_cptac.py
│   ├── process_tcga.py
│   └── utils.py
├── survival_models
│   ├── cox.py
│   └── utils.py
└── train_ssl.py

/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Self-supervised Learning for Histopathology Images
2 |
3 | This repository contains the code for the paper [Interpretable Prediction of Lung Squamous Cell Carcinoma Recurrence With Self-supervised Learning](https://arxiv.org/pdf/2203.12204.pdf).
4 |
5 | ## Introduction
6 | In this study, we explore the morphological features of LSCC recurrence and metastasis with a novel self-supervised learning (SSL) method based on conditional SSL. We propose a sampling mechanism within the contrastive SSL framework for histopathology images that avoids overfitting to batch effects; a minimal sketch of the sampling scheme is given below.
7 |
8 | The 2D UMAP projection of tile representations learned with different sampling strategies
9 | in self-supervised learning. Tiles from 8 slides with mostly LSCC tumor content
10 | are highlighted in different colors. Left: model trained with MoCo contrastive
11 | learning and uniform sampling; tiles within each slide cluster
12 | together. Right: model trained with the proposed conditional contrastive learning;
13 | tiles from each slide are much less clustered by slide.
14 | ![UMAP projection of tile representations under uniform vs. conditional sampling](plots/umap.png)
15 |
16 | The Kaplan-Meier curves show the rates of recurrence-free patients over time in
17 | sub-cohorts of the test set defined by different criteria. The two sub-cohorts are stratified by the
18 | recurrence risk predicted by our Cox regression: the high-risk cohort contains the half of
19 | patients with the highest estimated risk, and the low-risk cohort contains the lower half.
20 |
21 | ![Kaplan-Meier recurrence-free survival curves for the predicted high- and low-risk cohorts](plots/progression_plot.png)
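As a quick illustration of the sampling mechanism: instead of drawing tiles uniformly from the whole training set, every mini-batch is restricted to a small number of slides, so queries, keys, and in-batch negatives share the same slide-level conditions. The snippet below is only a minimal sketch of this idea; the helper name and the use of Python's `random` module are illustrative choices, while the actual implementation lives in `dataset/dataloader.py` (`TCGA_CPTAC_Dataset`, which draws `batch_size // batch_slide_num` tiles from each of `batch_slide_num` slides).

```python
# Illustrative sketch only; see dataset/dataloader.py (TCGA_CPTAC_Dataset) for the real code.
import random

def sample_conditional_batch(slide2tiles, batch_slide_num=4, batch_size=128):
    """Draw one mini-batch of tiles from only `batch_slide_num` slides."""
    slides = random.sample(list(slide2tiles), batch_slide_num)
    tiles_per_slide = batch_size // batch_slide_num
    batch = []
    for slide_id in slides:
        # sample tiles with replacement within the chosen slide
        for tile in random.choices(slide2tiles[slide_id], k=tiles_per_slide):
            batch.append((slide_id, tile))
    return batch  # list of (slide_id, tile_filename) pairs, ~batch_size long
```

Because slide identity is nearly constant within such a batch, it is far less useful for telling positives and negatives apart, which discourages the encoder from encoding slide-specific batch effects.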
22 |
23 |
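On top of this sampling, the conditional branch of the MoCo builder (`condition=True` in `condssl/builder.py` below) contrasts each query only against the keys of the current, slide-conditioned batch instead of the global memory queue. A minimal self-contained version of that objective is sketched here; the function name and the final cross-entropy call are assumptions about how the returned `(logits, labels)` pair is consumed by the training script (`train_ssl.py`, not shown in this listing).

```python
import torch
import torch.nn.functional as F

def conditional_infonce(q, k, T=0.07):
    """In-batch InfoNCE, mirroring the condition=True branch of condssl/builder.py.

    q, k: L2-normalized query/key embeddings of shape (N, C) computed from two
    augmented views of the same conditionally sampled mini-batch.
    """
    logits = torch.mm(q, k.t()) / T                    # (N, N): query i vs. every key
    labels = torch.arange(q.size(0), device=q.device)  # positive for query i is k[i]
    return F.cross_entropy(logits, labels)
```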
77 |89 | -------------------------------------------------------------------------------- /condssl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /condssl/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class MoCo(nn.Module): 7 | """ 8 | Build a MoCo model with: a query encoder, a key encoder, and a queue 9 | https://arxiv.org/abs/1911.05722 10 | """ 11 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, condition=True): 12 | """ 13 | dim: feature dimension (default: 128) 14 | K: queue size; number of negative keys (default: 65536) 15 | m: moco momentum of updating key encoder (default: 0.999) 16 | T: softmax temperature (default: 0.07) 17 | """ 18 | super(MoCo, self).__init__() 19 | 20 | self.condition = condition 21 | self.K = K 22 | self.m = m 23 | self.T = T 24 | 25 | # create the encoders 26 | # num_classes is the output fc dimension 27 | self.encoder_q = base_encoder(num_classes=dim) 28 | self.encoder_k = base_encoder(num_classes=dim) 29 | 30 | 31 | if mlp: # hack: brute-force replacement 32 | dim_mlp = self.encoder_q.fc.weight.shape[1] 33 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 34 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 35 | 36 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 37 | param_k.data.copy_(param_q.data) # initialize 38 | param_k.requires_grad = False # not update by gradient 39 | 40 | # create the queue 41 | self.register_buffer("queue", torch.randn(dim, K)) 42 | self.queue = nn.functional.normalize(self.queue, dim=0) 43 | 44 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 45 | 46 | @torch.no_grad() 47 | def _momentum_update_key_encoder(self): 48 | """ 49 | Momentum update of the key encoder 50 | """ 51 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 52 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 53 | 54 | @torch.no_grad() 55 | def _dequeue_and_enqueue(self, keys): 56 | # gather keys before updating queue 57 | keys = concat_all_gather(keys) 58 | 59 | batch_size = keys.shape[0] 60 | 61 | ptr = int(self.queue_ptr) 62 | assert self.K % batch_size == 0 # for simplicity 63 | 64 | # replace the keys at ptr (dequeue and enqueue) 65 | # self.queue[:, ptr:ptr + batch_size] = keys.T 66 | self.queue[:, ptr:ptr + batch_size] = torch.t(keys) 67 | ptr = (ptr + batch_size) % self.K # move pointer 68 | 69 | self.queue_ptr[0] = ptr 70 | 71 | @torch.no_grad() 72 | def _batch_shuffle_ddp(self, x): 73 | """ 74 | Batch shuffle, for making use of BatchNorm. 75 | *** Only support DistributedDataParallel (DDP) model. 
*** 76 | """ 77 | # gather from all gpus 78 | batch_size_this = x.shape[0] 79 | x_gather = concat_all_gather(x) 80 | batch_size_all = x_gather.shape[0] 81 | 82 | num_gpus = batch_size_all // batch_size_this 83 | 84 | # random shuffle index 85 | idx_shuffle = torch.randperm(batch_size_all).cuda() 86 | 87 | # broadcast to all gpus 88 | torch.distributed.broadcast(idx_shuffle, src=0) 89 | 90 | # index for restoring 91 | idx_unshuffle = torch.argsort(idx_shuffle) 92 | 93 | # shuffled index for this gpu 94 | gpu_idx = torch.distributed.get_rank() 95 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 96 | return x_gather[idx_this], idx_unshuffle 97 | 98 | @torch.no_grad() 99 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 100 | """ 101 | Undo batch shuffle. 102 | *** Only support DistributedDataParallel (DDP) model. *** 103 | """ 104 | # gather from all gpus 105 | batch_size_this = x.shape[0] 106 | x_gather = concat_all_gather(x) 107 | batch_size_all = x_gather.shape[0] 108 | 109 | num_gpus = batch_size_all // batch_size_this 110 | 111 | # restored index for this gpu 112 | gpu_idx = torch.distributed.get_rank() 113 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 114 | 115 | return x_gather[idx_this] 116 | 117 | def forward(self, im_q, im_k): 118 | """ 119 | Input: 120 | im_q: a batch of query images 121 | im_k: a batch of key images 122 | Output: 123 | logits, targets 124 | """ 125 | 126 | # compute query features 127 | q = self.encoder_q(im_q) # queries: NxC 128 | q = nn.functional.normalize(q, dim=1) 129 | 130 | # compute key features 131 | with torch.no_grad(): # no gradient to keys 132 | self._momentum_update_key_encoder() # update the key encoder 133 | 134 | # shuffle for making use of BN 135 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 136 | 137 | k = self.encoder_k(im_k) # keys: NxC 138 | k = nn.functional.normalize(k, dim=1) 139 | 140 | # undo shuffle 141 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 142 | 143 | 144 | if self.condition: 145 | # conditional ssl 146 | logits = torch.mm(q, k.T) / self.T 147 | labels = torch.arange(logits.shape[0], dtype=torch.long).cuda() 148 | return logits, labels 149 | 150 | else: 151 | # compute logits 152 | # Einstein sum is more intuitive 153 | # positive logits: Nx1 154 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 155 | 156 | # negative logits: NxK 157 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 158 | 159 | # logits: Nx(1+K) 160 | logits = torch.cat([l_pos, l_neg], dim=1) 161 | 162 | # apply temperature 163 | logits /= self.T 164 | 165 | # labels: positive key indicators 166 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 167 | 168 | # dequeue and enqueue 169 | self._dequeue_and_enqueue(k) 170 | return logits, labels 171 | 172 | # utils 173 | @torch.no_grad() 174 | def concat_all_gather(tensor): 175 | """ 176 | Performs all_gather operation on the provided tensors. 177 | *** Warning ***: torch.distributed.all_gather has no gradient. 178 | """ 179 | tensors_gather = [torch.ones_like(tensor) 180 | for _ in range(torch.distributed.get_world_size())] 181 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 182 | 183 | output = torch.cat(tensors_gather, dim=0) 184 | return output 185 | -------------------------------------------------------------------------------- /condssl/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from PIL import ImageFilter 3 | import random 4 | 5 | 6 | class TwoCropsTransform: 7 | """Take two random crops of one image as the query and key.""" 8 | 9 | def __init__(self, base_transform): 10 | self.base_transform = base_transform 11 | 12 | def __call__(self, x): 13 | q = self.base_transform(x) 14 | k = self.base_transform(x) 15 | return [q, k] 16 | 17 | 18 | class GaussianBlur(object): 19 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 20 | 21 | def __init__(self, sigma=[.1, 2.]): 22 | self.sigma = sigma 23 | 24 | def __call__(self, x): 25 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 26 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 27 | return x 28 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import cv2 8 | 9 | class TCGA_CPTAC_Dataset(Dataset): 10 | def __init__(self, cptac_dir, tcga_dir, split_dir, transform=None, mode='train', batch_slide_num=4, batch_size=128): 11 | self.cptac_dir = cptac_dir 12 | self.tcga_dir = tcga_dir 13 | 14 | slide_list = pickle.load(open(split_dir + '/case_split.pkl', 'rb'))[mode + "_id"] 15 | # slide_list = [s for s in slide_list if "TCGA" in s] 16 | self.slide2tiles = {} 17 | for slide_id in slide_list: 18 | if "TCGA" in slide_id: 19 | self.slide2tiles[slide_id] = os.listdir(os.path.join(self.tcga_dir, slide_id)) 20 | else: 21 | self.slide2tiles[slide_id] = os.listdir(os.path.join(self.cptac_dir, slide_id)) 22 | self.idx2tiles = [os.path.join(slide_id, tile_name) for slide_id, tile_list in self.slide2tiles.items() 23 | for tile_name in tile_list if 'jpg' in tile_name] 24 | self.tiles2idx = dict(zip(self.idx2tiles, range(len(self.idx2tiles)))) 25 | self.idx2slide = slide_list 26 | self.slide2idx = dict(zip(self.idx2slide, range(len(self.idx2slide)))) 27 | self.transform = transform 28 | self.batch_slide_num = batch_slide_num 29 | self.batch_size = batch_size 30 | 31 | def __getitem__(self, index): 32 | tile_names = [] 33 | slide_id = self.idx2slide[index] 34 | selected_tiles = [slide_id + '/' + t for t in np.random.choice(self.slide2tiles[slide_id], self.batch_size // self.batch_slide_num)] 35 | tile_names += selected_tiles 36 | for i in range(self.batch_slide_num - 1): 37 | slide_id = self.idx2slide[np.random.randint(len(self.idx2slide))] 38 | tile_names += [slide_id + '/' + t for t in np.random.choice(self.slide2tiles[slide_id], self.batch_size // self.batch_slide_num)] 39 | indices = [] 40 | imgs = [] 41 | for tile_name in tile_names: 42 | if "TCGA" in tile_name: 43 | image = cv2.imread(self.tcga_dir + tile_name) 44 | else: 45 | image = cv2.imread(self.cptac_dir + tile_name) 46 | image = Image.fromarray(image) 47 | image_tensor = self.transform(image) 48 | imgs.append(image_tensor) 49 | indices.append(index) 50 | return imgs, indices 51 | 52 | def __len__(self): 53 | return len(self.idx2slide) 54 | 55 | 56 | class TCGA_CPTAC_Bag_Dataset(Dataset): 57 | def __init__(self, data_dir, split_dir, mode='train'): 58 | self.data_dir = data_dir 59 | slide_list = pickle.load(open(os.path.join(split_dir, 'case_split_2yr.pkl'), 'rb'))[mode + '_id'] 60 | self.slide2tiles = {} 61 | for slide_id in slide_list: 62 | if "TCGA" in slide_id: 63 | tile_dir = self.data_dir + '/TCGA/tiles/' 64 | else: 65 | 
tile_dir = self.data_dir + '/CPTAC/tiles/' 66 | self.slide2tiles[slide_id] = os.listdir(os.path.join(tile_dir, slide_id)) 67 | self.idx2tiles = [os.path.join(slide_id, tile_name) for slide_id, tile_list in self.slide2tiles.items() 68 | for tile_name in tile_list if 'jpg' in tile_name] 69 | self.tiles2idx = dict(zip(self.idx2tiles, range(len(self.idx2tiles)))) 70 | self.idx2slide = slide_list 71 | self.slide2idx = dict(zip(self.idx2slide, range(len(self.idx2slide)))) 72 | self.transform = transforms.Compose([transforms.ToTensor()]) 73 | 74 | def __getitem__(self, index): 75 | tile_path = self.idx2tiles[index] 76 | slide_id, tile_name = tile_path.split('/') 77 | tile_idx = self.tiles2idx[tile_path] 78 | slide_idx = self.slide2idx[slide_id] 79 | if "TCGA" in tile_path: 80 | prefix = self.data_dir + '/TCGA/tiles/' 81 | else: 82 | prefix = self.data_dir + '/CPTAC/tiles/' 83 | image = cv2.imread(os.path.join(prefix, tile_path)) 84 | image = Image.fromarray(image) 85 | image_tensor = self.transform(image) 86 | return image_tensor, tile_idx, slide_idx 87 | 88 | def __len__(self): 89 | return len(self.idx2tiles) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /feature_extraction/extract_embeddings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import argparse 4 | from inception_v4 import InceptionV4 5 | import numpy as np 6 | import pickle 7 | import torch 8 | import torch.nn as nn 9 | import os 10 | import cv2 11 | from torchvision import transforms 12 | from collections import defaultdict 13 | from tqdm import tqdm 14 | import pandas as pd 15 | from PIL import Image 16 | import shelve 17 | from dataset.dataloader import TCGA_CPTAC_Bag_Dataset 18 | 19 | if torch.cuda.is_available(): 20 | device = 'cuda' 21 | else: 22 | device = 'cpu' 23 | print(device) 24 | 25 | def get_embeddings_bagging(feature_extractor, subtype_model, data_set): 26 | embedding_dict = defaultdict(list) 27 | outcomes_dict = defaultdict(list) 28 | feature_extractor.eval() 29 | data_loader = torch.utils.data.DataLoader(data_set, batch_size=256, shuffle=False, num_workers=torch.cuda.device_count()) 30 | with torch.no_grad(): 31 | count = 0 32 | for batch in tqdm(data_loader, position=0, leave=True): 33 | count += 1 34 | img, _, bag_idx = batch 35 | feat = feature_extractor(img.to(device)).cpu() 36 | subtype_model.eval() 37 | subtype_prob = subtype_model(img) 38 | subtype_pred = torch.argmax(subtype_prob, dim=1) 39 | tumor_idx = (subtype_pred != 0) 40 | feat = feat[tumor_idx].numpy() 41 | bag_idx = bag_idx[tumor_idx] 42 | for i in range(len(bag_idx)): 43 | embedding_dict[bag_idx[i].item()].append(feat[i][np.newaxis,:]) 44 | slide_id = data_set.idx2slide[bag_idx[i].item()] 45 | if "TCGA" in slide_id: 46 | case_id = '-'.join(slide_id.split('-', 3)[:3]) 47 | else: 48 | case_id = slide_id.rsplit('-', 1)[0] 49 | outcomes_dict[bag_idx[i].item()] = annotations[case_id] 50 | for k in embedding_dict: 51 | embedding_dict[k] = np.concatenate(embedding_dict[k], axis=0) 52 | return embedding_dict, outcomes_dict 53 | 54 | def load_pretrained(net, model_dir): 55 | 56 | print(model_dir) 57 | checkpoint = torch.load(model_dir) 58 | model_state_dict = {k.replace("module.encoder_q.", ""): v for k, v in checkpoint['state_dict'].items() if 59 | "encoder_q" in k} 60 | net.load_state_dict(model_state_dict) 61 | net.last_linear = nn.Identity() 62 | 63 | parser = argparse.ArgumentParser(description='Extract embeddings ') 
64 |
65 | parser.add_argument('--feature_extractor_dir', default='./pretrained/checkpoint.pth.tar', type=str)
66 | parser.add_argument('--subtype_model_dir', default='./subtype_cls/checkpoint.pth.tar', type=str)
67 | parser.add_argument('--root_dir', type=str)
68 | parser.add_argument('--split_dir', type=str)
69 | parser.add_argument('--out_dir', type=str)
70 |
71 | args = parser.parse_args()
72 |
73 | tcga_annotation = pickle.load(open('../TCGA/recurrence_annotation.pkl', 'rb'))
74 | cptac_annotation = pickle.load(open('../CPTAC/recurrence_annotation.pkl', 'rb'))
75 | annotations = {**tcga_annotation, **cptac_annotation}
76 | feature_extractor = InceptionV4(num_classes=256)
77 | load_pretrained(feature_extractor, args.feature_extractor_dir)
78 | feature_extractor.to('cuda')
79 | feature_extractor = nn.DataParallel(feature_extractor)  # use all visible GPUs
80 |
81 | subtype_model = InceptionV4(num_classes=2).to('cuda')
82 | subtype_model.load_state_dict(torch.load(args.subtype_model_dir))
83 | subtype_model = nn.DataParallel(subtype_model)  # use all visible GPUs
84 |
85 |
86 | train_dataset = TCGA_CPTAC_Bag_Dataset(args.root_dir, args.split_dir, 'train')
87 | val_dataset = TCGA_CPTAC_Bag_Dataset(args.root_dir, args.split_dir, 'val')
88 | test_dataset = TCGA_CPTAC_Bag_Dataset(args.root_dir, args.split_dir, 'test')
89 |
90 |
91 | with torch.no_grad():
92 |     names = ['train', 'val', 'test']
93 |     for name, data_set in zip(names, [train_dataset, val_dataset, test_dataset]):
94 |         print(name)
95 |         embedding_dict, outcomes_dict = get_embeddings_bagging(feature_extractor, subtype_model, data_set)
96 |         pickle.dump(embedding_dict, open('{}/{}_embedding.pkl'.format(args.out_dir, name), 'wb'), protocol=4)
97 |         pickle.dump(outcomes_dict, open('{}/{}_outcomes.pkl'.format(args.out_dir, name), 'wb'), protocol=4)
98 |
99 |
--------------------------------------------------------------------------------
/feature_extraction/get_clusters.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from sklearn.cluster import KMeans
3 | from sklearn.mixture import GaussianMixture
4 | import pickle
5 | import numpy as np
6 | import sys
7 |
8 |
9 |
10 | parser = argparse.ArgumentParser(description='Get cluster features')
11 |
12 | parser.add_argument('--data_dir', default='./', type=str)
13 | parser.add_argument('--cluster_type', default='gmm', type=str)
14 | parser.add_argument('--n_cluster', default=50, type=int)
15 | parser.add_argument('--out_dir', default='./', type=str)
16 |
17 |
18 | args = parser.parse_args()
19 |
20 | train_features = pickle.load(open(args.data_dir + 'train_embedding.pkl', 'rb'))
21 | train_features = np.concatenate(list(train_features.values()), axis=0)
22 | if args.cluster_type == "kmeans":
23 |     print('kmeans')
24 |     cluster = KMeans(n_clusters=args.n_cluster).fit(train_features)
25 |     pickle.dump(cluster, open(args.data_dir + 'kmeans_{}.pkl'.format(args.n_cluster), 'wb'))
26 | else:
27 |     print('gmm')
28 |     cluster = GaussianMixture(n_components=args.n_cluster).fit(train_features)
29 |     pickle.dump(cluster, open(args.data_dir + 'gmm_{}.pkl'.format(args.n_cluster), 'wb'))
30 |
31 |
--------------------------------------------------------------------------------
/network/inception_v4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 | import os
6 | import sys
7 |
8 |
9 | class BasicConv2d(nn.Module):
10 |
11 |
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 12 | super(BasicConv2d, self).__init__() 13 | self.conv = nn.Conv2d(in_planes, out_planes, 14 | kernel_size=kernel_size, stride=stride, 15 | padding=padding, bias=False) # verify bias false 16 | self.bn = nn.BatchNorm2d(out_planes, 17 | eps=0.001, # value found in tensorflow 18 | momentum=0.1, # default pytorch value 19 | affine=True) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | x = self.relu(x) 26 | return x 27 | 28 | 29 | class Mixed_3a(nn.Module): 30 | 31 | def __init__(self): 32 | super(Mixed_3a, self).__init__() 33 | self.maxpool = nn.MaxPool2d(3, stride=2) 34 | self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) 35 | 36 | def forward(self, x): 37 | x0 = self.maxpool(x) 38 | x1 = self.conv(x) 39 | out = torch.cat((x0, x1), 1) 40 | return out 41 | 42 | 43 | class Mixed_4a(nn.Module): 44 | 45 | def __init__(self): 46 | super(Mixed_4a, self).__init__() 47 | 48 | self.branch0 = nn.Sequential( 49 | BasicConv2d(160, 64, kernel_size=1, stride=1), 50 | BasicConv2d(64, 96, kernel_size=3, stride=1) 51 | ) 52 | 53 | self.branch1 = nn.Sequential( 54 | BasicConv2d(160, 64, kernel_size=1, stride=1), 55 | BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)), 56 | BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)), 57 | BasicConv2d(64, 96, kernel_size=(3,3), stride=1) 58 | ) 59 | 60 | def forward(self, x): 61 | x0 = self.branch0(x) 62 | x1 = self.branch1(x) 63 | out = torch.cat((x0, x1), 1) 64 | return out 65 | 66 | 67 | class Mixed_5a(nn.Module): 68 | 69 | def __init__(self): 70 | super(Mixed_5a, self).__init__() 71 | self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) 72 | self.maxpool = nn.MaxPool2d(3, stride=2) 73 | 74 | def forward(self, x): 75 | x0 = self.conv(x) 76 | x1 = self.maxpool(x) 77 | out = torch.cat((x0, x1), 1) 78 | return out 79 | 80 | 81 | class Inception_A(nn.Module): 82 | 83 | def __init__(self): 84 | super(Inception_A, self).__init__() 85 | self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) 86 | 87 | self.branch1 = nn.Sequential( 88 | BasicConv2d(384, 64, kernel_size=1, stride=1), 89 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) 90 | ) 91 | 92 | self.branch2 = nn.Sequential( 93 | BasicConv2d(384, 64, kernel_size=1, stride=1), 94 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), 95 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) 96 | ) 97 | 98 | self.branch3 = nn.Sequential( 99 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 100 | BasicConv2d(384, 96, kernel_size=1, stride=1) 101 | ) 102 | 103 | def forward(self, x): 104 | x0 = self.branch0(x) 105 | x1 = self.branch1(x) 106 | x2 = self.branch2(x) 107 | x3 = self.branch3(x) 108 | out = torch.cat((x0, x1, x2, x3), 1) 109 | return out 110 | 111 | 112 | class Reduction_A(nn.Module): 113 | 114 | def __init__(self): 115 | super(Reduction_A, self).__init__() 116 | self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) 117 | 118 | self.branch1 = nn.Sequential( 119 | BasicConv2d(384, 192, kernel_size=1, stride=1), 120 | BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), 121 | BasicConv2d(224, 256, kernel_size=3, stride=2) 122 | ) 123 | 124 | self.branch2 = nn.MaxPool2d(3, stride=2) 125 | 126 | def forward(self, x): 127 | x0 = self.branch0(x) 128 | x1 = self.branch1(x) 129 | x2 = self.branch2(x) 130 | out = torch.cat((x0, x1, x2), 1) 131 | return out 132 | 133 | 134 | class 
Inception_B(nn.Module): 135 | 136 | def __init__(self): 137 | super(Inception_B, self).__init__() 138 | self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) 139 | 140 | self.branch1 = nn.Sequential( 141 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 142 | BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), 143 | BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0)) 144 | ) 145 | 146 | self.branch2 = nn.Sequential( 147 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 148 | BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)), 149 | BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), 150 | BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)), 151 | BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3)) 152 | ) 153 | 154 | self.branch3 = nn.Sequential( 155 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 156 | BasicConv2d(1024, 128, kernel_size=1, stride=1) 157 | ) 158 | 159 | def forward(self, x): 160 | x0 = self.branch0(x) 161 | x1 = self.branch1(x) 162 | x2 = self.branch2(x) 163 | x3 = self.branch3(x) 164 | out = torch.cat((x0, x1, x2, x3), 1) 165 | return out 166 | 167 | 168 | class Reduction_B(nn.Module): 169 | 170 | def __init__(self): 171 | super(Reduction_B, self).__init__() 172 | 173 | self.branch0 = nn.Sequential( 174 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 175 | BasicConv2d(192, 192, kernel_size=3, stride=2) 176 | ) 177 | 178 | self.branch1 = nn.Sequential( 179 | BasicConv2d(1024, 256, kernel_size=1, stride=1), 180 | BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)), 181 | BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)), 182 | BasicConv2d(320, 320, kernel_size=3, stride=2) 183 | ) 184 | 185 | self.branch2 = nn.MaxPool2d(3, stride=2) 186 | 187 | def forward(self, x): 188 | x0 = self.branch0(x) 189 | x1 = self.branch1(x) 190 | x2 = self.branch2(x) 191 | out = torch.cat((x0, x1, x2), 1) 192 | return out 193 | 194 | 195 | class Inception_C(nn.Module): 196 | 197 | def __init__(self): 198 | super(Inception_C, self).__init__() 199 | 200 | self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) 201 | 202 | self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) 203 | self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1)) 204 | self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 205 | 206 | self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) 207 | self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0)) 208 | self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1)) 209 | self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1)) 210 | self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 211 | 212 | self.branch3 = nn.Sequential( 213 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 214 | BasicConv2d(1536, 256, kernel_size=1, stride=1) 215 | ) 216 | 217 | def forward(self, x): 218 | x0 = self.branch0(x) 219 | 220 | x1_0 = self.branch1_0(x) 221 | x1_1a = self.branch1_1a(x1_0) 222 | x1_1b = self.branch1_1b(x1_0) 223 | x1 = torch.cat((x1_1a, x1_1b), 1) 224 | 225 | x2_0 = self.branch2_0(x) 226 | x2_1 = self.branch2_1(x2_0) 227 | x2_2 = self.branch2_2(x2_1) 228 | x2_3a = self.branch2_3a(x2_2) 229 | x2_3b = self.branch2_3b(x2_2) 230 | x2 = torch.cat((x2_3a, x2_3b), 1) 231 | 232 | x3 = self.branch3(x) 233 | 234 | out = torch.cat((x0, x1, 
x2, x3), 1) 235 | return out 236 | 237 | 238 | class InceptionV4(nn.Module): 239 | 240 | def __init__(self, num_classes=1001): 241 | super(InceptionV4, self).__init__() 242 | # Special attributs 243 | self.input_space = None 244 | self.input_size = (299, 299, 3) 245 | self.mean = None 246 | self.std = None 247 | # Modules 248 | self.features = nn.Sequential( 249 | BasicConv2d(3, 32, kernel_size=3, stride=2), 250 | BasicConv2d(32, 32, kernel_size=3, stride=1), 251 | BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), 252 | Mixed_3a(), 253 | Mixed_4a(), 254 | Mixed_5a(), 255 | Inception_A(), 256 | Inception_A(), 257 | Inception_A(), 258 | Inception_A(), 259 | Reduction_A(), # Mixed_6a 260 | Inception_B(), 261 | Inception_B(), 262 | Inception_B(), 263 | Inception_B(), 264 | Inception_B(), 265 | Inception_B(), 266 | Inception_B(), 267 | Reduction_B(), # Mixed_7a 268 | Inception_C(), 269 | Inception_C(), 270 | Inception_C() 271 | ) 272 | self.last_linear = nn.Linear(1536, num_classes) 273 | 274 | def logits(self, features): 275 | #Allows image of any size to be processed 276 | adaptiveAvgPoolWidth = features.shape[2] 277 | x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth) 278 | x = x.view(x.size(0), -1) 279 | x = self.last_linear(x) 280 | return x 281 | 282 | def forward(self, input): 283 | x = self.features(input) 284 | x = self.logits(x) 285 | return x 286 | 287 | 288 | def inceptionv4(num_classes=1000, pretrained='imagenet'): 289 | if pretrained: 290 | settings = pretrained_settings['inceptionv4'][pretrained] 291 | assert num_classes == settings['num_classes'], \ 292 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 293 | 294 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 295 | model = InceptionV4(num_classes=1001) 296 | model.load_state_dict(model_zoo.load_url(settings['url'])) 297 | 298 | if pretrained == 'imagenet': 299 | new_last_linear = nn.Linear(1536, 1000) 300 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 301 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 302 | model.last_linear = new_last_linear 303 | 304 | model.input_space = settings['input_space'] 305 | model.input_size = settings['input_size'] 306 | model.input_range = settings['input_range'] 307 | model.mean = settings['mean'] 308 | model.std = settings['std'] 309 | else: 310 | model = InceptionV4(num_classes=num_classes) 311 | return model -------------------------------------------------------------------------------- /plots/progression_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYUMedML/conditional_ssl_hist/32c055130c9d4a40e66af50835b97364599792d6/plots/progression_plot.png -------------------------------------------------------------------------------- /plots/umap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYUMedML/conditional_ssl_hist/32c055130c9d4a40e66af50835b97364599792d6/plots/umap.png -------------------------------------------------------------------------------- /preprocessing/colorstandard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYUMedML/conditional_ssl_hist/32c055130c9d4a40e66af50835b97364599792d6/preprocessing/colorstandard.png -------------------------------------------------------------------------------- /preprocessing/process_cptac.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from utils import wsi_to_tiles 5 | import pickle 6 | 7 | parser = argparse.ArgumentParser(description='Process TCGA') 8 | 9 | parser.add_argument('--followup_path', default='./LSCC-clinicalTable.csv', type=str) 10 | parser.add_argument('--wsi_path', default='../CPTAC_WSI', type=str) 11 | parser.add_argument('--refer_img', default='./colorstandard.png', type=str) 12 | parser.add_argument('--s', default=0.9, type=float, help='The proportion of tissues') 13 | args = parser.parse_args() 14 | 15 | clinicalTable = pd.read_csv(args.followup_path).set_index('case_id') 16 | wsi_dir_dict = {} 17 | wsi_list = os.popen("find {} -name '*.svs'".format(args.wsi_path)).read().strip('\n').replace(wsi_path,'').split("\n") 18 | for slide_id in wsi_list: 19 | slide_id = slide_id.lstrip('/').rstrip('.svs') 20 | tile_path = os.path.join('../CPTAC/tiles', slide_id) 21 | if not os.path.exists(tile_path): 22 | os.mkdir(tile_path) 23 | 24 | for idx, wsi in enumerate(wsi_list): 25 | wsi_to_tiles(idx, wsi, args.refer_img, args.s) 26 | 27 | # Get annotation 28 | annotation = {} 29 | for case_id in clinicalTable.index: 30 | clinicalRow = clinicalTable.loc[case_id].to_dict() 31 | try: 32 | imageRow = imageTable.loc[case_id].to_dict(orient='list') 33 | slide_id = imageRow['Slide_ID'] 34 | except: 35 | slide_id = [] 36 | annotation[case_id] = {'recurrence': clinicalRow['Recurrence.status..1..yes..0..no.'], 37 | 'stage': stage_dict[clinicalRow['baseline.tumor_stage_pathological']], 38 | 'survival_days': clinicalRow['Overall.survival..days'], 39 | 'survival': clinicalRow['Survival.status..1..dead..0..alive.'], 40 | 'recurrence_free_days':clinicalRow['Recurrence.free.survival..days'], 41 | 'age':clinicalRow['consent.age'], 42 | 'gender':clinicalRow['consent.sex'], 43 | 'followup_days':clinicalRow['follow.up.number_of_days_from_date_of_initial_pathologic_diagnosis_to_date_of_last_contact'], 44 | 'slide_id': slide_id} 45 | pickle.dump(annotation, open('../CPTAC/recurrence_annotation.pkl', 'wb')) 46 | 47 | -------------------------------------------------------------------------------- /preprocessing/process_tcga.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from utils import wsi_to_tiles 5 | import pickle 6 | 7 | parser = argparse.ArgumentParser(description='Process TCGA') 8 | 9 | parser.add_argument('--followup_path', default='./clinical_follow_up_v1.0_lusc.xlsx', type=str) 10 | parser.add_argument('--clinical_table_path', default='./clinical_follow_up_v1.0_lusc.xlsx', type=str) 11 | parser.add_argument('--wsi_path', default='../TCGA_WSI', type=str) 12 | parser.add_argument('--refer_img', default='./colorstandard.png', type=str) 13 | parser.add_argument('--s', default=0.9, type=float, help='The proportion of tissues') 14 | 15 | 16 | args = parser.parse_args() 17 | 18 | followupTable = pd.read_excel(args.followup_path, skiprows=[1,2], engine='openpyxl') 19 | followupTable = followupTable.loc[followupTable['new_tumor_event_dx_indicator'].isin({'YES', 'NO'})] 20 | followupTable['recurrence'] = ((followupTable['new_tumor_event_dx_indicator'] == 'YES') & 21 | (followupTable['new_tumor_event_type'] != 'New Primary Tumor')) 22 | followupTable = followupTable.sort_values(['bcr_patient_barcode', 'form_completion_date']).drop_duplicates('bcr_patient_barcode', keep='last') 23 | LUSC_patientids = 
set(followupTable['bcr_patient_barcode']) 24 | 25 | 26 | wsi_list = os.popen("find {} -name '*.svs'".format(wsi_path)).read().strip('\n').split('\n') 27 | wsi_list_LUSC = [] 28 | for idx in range(len(wsi_list)): 29 | slide_id = wsi_list[idx].rsplit('/', 1)[1].split('.')[0] 30 | patient_id = '-'.join(slide_id.split('-', 3)[:3]) 31 | tile_path = os.path.join('../TCGA/tiles', slide_id) 32 | if patient_id in LUSC_patientids: 33 | if not os.path.exists(tile_path): 34 | os.mkdir(tile_path) 35 | wsi_list_LUSC.append(wsi_list[idx]) 36 | 37 | for idx, wsi in enumerate(wsi_list_LUSC): 38 | wsi_to_tiles(idx, wsi, args.refer_img, args.s) 39 | 40 | # Get annotation 41 | clinicalTable = pd.read_csv(args.clinical_table_path).set_index('bcr_patient_barcode') 42 | annotation = defaultdict(lambda: {"recurrence": None, "slide_id": []}) 43 | slide_ids = os.listdir('./TCGA/tiles') 44 | included_slides = [s for s in slide_ids if s.rsplit('-',3)[0] in set(followupTable.index)] 45 | for slide_id in included_slides: 46 | case_id = '-'.join(slide_id.split('-', 3)[:3]) 47 | clinicalRow = followupTable.loc[case_id].to_dict() 48 | annotation[case_id]['recurrence'] = 1 if clinicalRow['recurrence'] else 0 49 | annotation[case_id]['slide_id'].append(slide_id) 50 | annotation[case_id]['stage'] = clinicalTable.loc[case_id]['ajcc_pathologic_tumor_stage'] 51 | annotation[case_id]['survival_days'] = clinicalTable.loc[case_id]['death_days_to'] 52 | annotation[case_id]['survival'] = clinicalTable.loc[case_id]['vital_status'] 53 | annotation[case_id]['recurrence_free_days'] = pd.to_numeric(followupTable.new_tumor_event_dx_days_to, errors='coerce').loc[case_id] 54 | annotation[case_id]['followup_days'] = pd.to_numeric(followupTable.last_contact_days_to, errors='coerce').loc[case_id] 55 | annotation[case_id]['gender'] = clinicalTable['gender'].loc[case_id] 56 | pickle.dump(annotation, open('../TCGA/recurrence_annotation.pkl', 'wb')) 57 | -------------------------------------------------------------------------------- /preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from openslide import OpenSlide 3 | from PIL import Image 4 | import numpy as np 5 | import staintools 6 | 7 | 8 | def getGradientMagnitude(im): 9 | "Get magnitude of gradient for given image" 10 | ddepth = cv2.CV_32F 11 | dx = cv2.Sobel(im, ddepth, 1, 0) 12 | dy = cv2.Sobel(im, ddepth, 0, 1) 13 | dxabs = cv2.convertScaleAbs(dx) 14 | dyabs = cv2.convertScaleAbs(dy) 15 | mag = cv2.addWeighted(dxabs, 0.5, dyabs, 0.5, 0) 16 | return mag 17 | 18 | 19 | def wsi_to_tiles(idx, wsi, refer_img, s): 20 | normalizer = staintools.StainNormalizer(method='vahadane') 21 | refer_img = staintools.read_image(refer_img) 22 | normalizer.fit(refer_img) 23 | count = 0 24 | sys.stdout.write('Start task %d: %s \n' % (idx, wsi)) 25 | slide_id = wsi.rsplit('/', 1)[1].split('.')[0] 26 | tile_path = os.path.join('./tiles', slide_id) 27 | img = OpenSlide(os.path.join(wsi)) 28 | if str(img.properties.values.__self__.get('tiff.ImageDescription')).split("|")[1] == "AppMag = 40": 29 | sz = 2048 30 | seq = 1536 31 | else: 32 | sz = 1024 33 | seq = 768 34 | [w, h] = img.dimensions 35 | for x in range(1, w, seq): 36 | for y in range(1, h, seq): 37 | img_tmp = img.read_region(location=(x, y), level=0, size=(sz, sz)) \ 38 | .convert("RGB").resize((299, 299), Image.ANTIALIAS) 39 | grad = getGradientMagnitude(np.array(img_tmp)) 40 | unique, counts = np.unique(grad, return_counts=True) 41 | if counts[np.argwhere(unique <= 15)].sum() < 
299 * 299 * s:
42 |                 img_tmp = normalizer.transform(np.array(img_tmp))
43 |                 img_tmp = Image.fromarray(img_tmp)
44 |                 img_tmp.save(tile_path + "/" + str(x) + "_" + str(y) + '.jpg', 'JPEG', optimize=True, quality=94)
45 |                 count += 1
46 |     sys.stdout.write('End task %d with %d tiles\n' % (idx, count))
47 |
48 |
--------------------------------------------------------------------------------
/survival_models/cox.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from lifelines import CoxPHFitter
4 | from lifelines.utils import concordance_index
5 | import pickle
6 | import utils
7 | import os
8 | import pandas as pd
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from sklearn.metrics import brier_score_loss
12 | import cox_utils
13 | from sksurv.linear_model import CoxPHSurvivalAnalysis
14 | import argparse
15 |
16 |
17 | parser = argparse.ArgumentParser(description='Cox-PH survival models')
18 |
19 | parser.add_argument('--data_dir', default='.', type=str)
20 | parser.add_argument('--cluster_name', default='gmm_50.pkl', type=str)
21 | parser.add_argument('--normalize', default='mean', type=str)
22 | parser.add_argument('--data_source', default='all', type=str)  # evaluate the test set on 'TCGA', 'CPTAC', or both
23 |
24 | args = parser.parse_args()
25 | data_source = args.data_source
26 | cluster = pickle.load(open(os.path.join(args.data_dir, args.cluster_name), 'rb'))
27 | cluster_method = type(cluster).__name__
28 | if cluster_method == 'GaussianMixture':
29 |     n_clusters = len(cluster.weights_)
30 | else:
31 |     n_clusters = cluster.n_clusters
32 |
33 | train_data, val_data, test_data = utils.load_data(args.data_dir,
34 |                                                   os.path.join(args.data_dir, args.cluster_name), normalize=args.normalize)
35 |
36 | train_df, val_df, test_df = utils.preprocess_data(train_data, val_data, test_data)
37 | if data_source == 'TCGA':
38 |     test_df = test_df.loc[test_df['tcga_flag'] == 1.0]
39 | elif data_source == 'CPTAC':
40 |     test_df = test_df.loc[test_df['tcga_flag'] == 0.0]
41 | train_df, val_df, test_df = train_df.drop(columns=['tcga_flag']), val_df.drop(columns=['tcga_flag']), test_df.drop(columns=['tcga_flag'])
42 | y_train = np.array([tuple((bool(row[0]), row[1])) for row in zip(train_df['outcome'], train_df['day'])],
43 |                    dtype=[('outcome', 'bool'), ('day', '<f8')])
77 | @misc{https://doi.org/10.48550/arxiv.2203.12204,
78 |   doi = {10.48550/ARXIV.2203.12204},
79 |   url = {https://arxiv.org/abs/2203.12204},
80 |   author = {Zhu, Weicheng and Fernandez-Granda, Carlos and Razavian, Narges},
81 |   keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
82 |   title = {Interpretable Prediction of Lung Squamous Cell Carcinoma Recurrence With Self-supervised Learning},
83 |   publisher = {arXiv},
84 |   year = {2022},
85 |   copyright = {arXiv.org perpetual, non-exclusive license}
86 | }
87 |
88 |
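The listing of `survival_models/cox.py` above stops after `y_train` is built, so the model fitting and evaluation code is not visible here. The sketch below shows one plausible continuation based on the modules the script imports (`CoxPHSurvivalAnalysis` from scikit-survival and `concordance_index` from lifelines); the feature-column selection, the `alpha` value, and the evaluation details are assumptions, not the repository's exact code.

```python
# Hypothetical continuation of cox.py (assumptions noted above).
from sksurv.linear_model import CoxPHSurvivalAnalysis
from lifelines.utils import concordance_index

# Cluster-proportion features are everything except the outcome columns.
feature_cols = [c for c in train_df.columns if c not in ('outcome', 'day')]

cox = CoxPHSurvivalAnalysis(alpha=0.01)   # small ridge penalty; value is a guess
cox.fit(train_df[feature_cols], y_train)

# predict() returns risk scores (higher = higher hazard); concordance_index
# expects larger values to mean longer survival, so the score is negated.
test_risk = cox.predict(test_df[feature_cols])
c_index = concordance_index(test_df['day'], -test_risk, test_df['outcome'])
print('Test concordance index: {:.3f}'.format(c_index))
```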