GBM360 is a software that harnesses the power of machine learning to investigate the cellular heterogeneity and spatial architecture of glioblastoma
', unsafe_allow_html=True) 14 | st.image('pictures/demo.png', width=1000) 15 | 16 | with st.expander("Citation"): 17 | st.markdown("""Zheng, Y., Carrillo-Perez, F., Pizurica, M. et al. Spatial cellular architecture predicts prognosis in glioblastoma. Nat Commun 14, 4122 (2023). https://doi.org/10.1038/s41467-023-39933-0""") 18 | 19 | with st.expander("Disclaimer"): 20 | st.markdown("""GBM360 is an academic research project and should **not** be considered a medical device approved by any federal authorities.""") 21 | st.markdown("Please remove Personal Health Information (PHI) from all uploaded files, as we are not responsible for data compliance issues", unsafe_allow_html=True) 22 | 23 | with st.expander("Contact"): 24 | paragraph = "- Dr. Yuanning Zheng is a postdoctoral scholar at Stanford University. He obtained his PhD degree in Medical Sciences from Texas A&M University and a Master in Computer Science from Georgia Institute of Technology.\n" \ 25 | "Dr. Zheng's research focuses on developing innovative machine learning and bioinformatics methods to unravel the heterogeneity and improve personalized diagnosis of cancers and other complex diseases. Email: eric2021@stanford.edu\n\n" \ 26 | "- Dr. Olivier Gevaert is an associate professor at Stanford University focusing on developing machine-learning methods for biomedical decision support from multi-scale data. Email: ogevaert@stanford.edu\n\n" \ 27 | "- Other contributors: Francisco Carrillo-Perez \n\n" \ 28 | "- Visit us at: Dr. Gevaert lab\n\n" \ 29 | "- For bug reporting, please visit:Please log in first.
', unsafe_allow_html=True) 90 | # else: -------------------------------------------------------------------------------- /src/pathology_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions of pathology models 3 | 4 | """ 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | #from onmt.encoders.encoder import EncoderBase 12 | #from onmt.modules import MultiHeadedAttention 13 | #from onmt.modules.position_ffn import PositionwiseFeedForward 14 | 15 | 16 | class Identity(nn.Module): 17 | def __init__(self): 18 | super(Identity, self).__init__() 19 | 20 | def forward(self, x): 21 | out = x 22 | attention_weights = torch.ones(x.shape[0], x.shape[1], device=x.device) 23 | return out, attention_weights 24 | 25 | 26 | class TanhAttention(nn.Module): 27 | def __init__(self, dim=2048): 28 | super(TanhAttention, self).__init__() 29 | self.dim = dim 30 | self.vector = torch.nn.Parameter(torch.zeros(dim)) 31 | self.linear = nn.Linear(dim, dim, bias=False) 32 | 33 | def forward(self, x): 34 | logits = torch.tanh(self.linear(x)).matmul(self.vector.unsqueeze(-1)) 35 | attention_weights = torch.nn.functional.softmax(logits, dim=1) 36 | out = x * attention_weights * x.shape[1] 37 | return out,attention_weights 38 | 39 | 40 | class AggregationModel(nn.Module): 41 | def __init__(self, resnet, aggregator, aggregator_dim, resnet_dim=2048, out_features=1, task = "classification"): 42 | super(AggregationModel, self).__init__() 43 | self.task = task 44 | self.resnet = resnet 45 | self.aggregator = aggregator 46 | self.fc = nn.Linear(aggregator_dim, out_features) 47 | self.aggregator_dim = aggregator_dim 48 | self.resnet_dim = resnet_dim 49 | self.softmax = nn.Softmax(dim=1) 50 | 51 | def forward(self, x): 52 | features,attention_weights = self.extract(x) 53 | out = self.fc(features) 54 | if self.task == 'prob': 55 | out = self.softmax(out) 56 | return out, attention_weights 57 | 58 | def extract(self,x): 59 | (batch_size, c, h, w) = x.shape 60 | x = x.reshape(-1, c, h, w) 61 | features = self.resnet.forward_extract(x) 62 | features = features.view(batch_size, self.resnet_dim) # bsize, resnet_dim 63 | features, attention_weights = self.aggregator(features) # bsize, aggregator_dim 64 | return features,attention_weights 65 | 66 | class AggregationProjectModel(nn.Module): 67 | def __init__(self, resnet, aggregator, aggregator_dim, resnet_dim=2048, out_features=1,hdim=200,dropout=.3): 68 | super(AggregationProjectModel, self).__init__() 69 | self.resnet = resnet 70 | self.aggregator = aggregator 71 | self.aggregator_dim = aggregator_dim 72 | self.resnet_dim = resnet_dim 73 | self.hdim = hdim 74 | self.dropout = nn.Dropout(p=dropout) 75 | self.project = nn.Linear(aggregator_dim, hdim) 76 | self.fc = nn.Linear(hdim, out_features) 77 | 78 | def forward(self, x): 79 | features,attention_weights = self.extract(x) 80 | out = self.fc(features) 81 | return out, attention_weights 82 | 83 | def extract(self,x): 84 | (batch_size, bag_size, c, h, w) = x.shape 85 | x = x.reshape(-1, c, h, w) 86 | features = self.resnet.forward_extract(x) 87 | features = features.view(batch_size, bag_size, self.resnet_dim) # bsize, bagsize, resnet_dim 88 | 89 | features, attention_weights = self.aggregator(features) # bsize, bagsize, aggregator_dim 90 | features = features.mean(dim=1) # batch_size,aggregator_dim 91 | features = self.project(features) 92 | features = F.tanh(features) 93 | features = self.dropout(features) 94 | 95 | return features,attention_weights 96 | 97 | def cox_loss(cox_scores, times, status): 98 | ''' 99 | :param cox_scores: cox scores, size (batch_size) 100 | :param times: event times (either death or censor), size batch_size 101 | :param status: event status (1 for death, 0 for censor), size batch_size 102 | :return: loss of size 1, the sum of cox losses for the batch 103 | ''' 104 | 105 | times, sorted_indices = torch.sort(-times) 106 | cox_scores = cox_scores[sorted_indices] 107 | status = status[sorted_indices] 108 | cox_scores = cox_scores -torch.max(cox_scores) 109 | exp_scores = torch.exp(cox_scores) 110 | loss = cox_scores - torch.log(torch.cumsum(exp_scores, dim=0)+1e-5) 111 | loss = - loss * status 112 | # TODO maybe divide by status.sum() 113 | 114 | if (loss != loss).any(): 115 | import pdb; 116 | pdb.set_trace() 117 | 118 | return loss.mean() 119 | 120 | class CoxLoss(nn.Module): 121 | def __init__(self): 122 | super(CoxLoss,self).__init__() 123 | 124 | def forward(self,cox_scores,times,status): 125 | return cox_loss(cox_scores,times,status) -------------------------------------------------------------------------------- /src/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet models 3 | 4 | """ 5 | 6 | import torch.nn as nn 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 13 | 'resnet152'] 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 64 105 | super(ResNet, self).__init__() 106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 107 | bias=False) 108 | self.bn1 = nn.BatchNorm2d(64) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 116 | self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = x.view(x.size(0), -1) 156 | x = self.fc(x) 157 | 158 | return x 159 | 160 | def forward_extract(self, x): 161 | x = self.conv1(x) 162 | x = self.bn1(x) 163 | x = self.relu(x) 164 | x = self.maxpool(x) 165 | 166 | x = self.layer1(x) 167 | x = self.layer2(x) 168 | x = self.layer3(x) 169 | x = self.layer4(x) 170 | 171 | x = self.avgpool(x) 172 | x = x.view(x.size(0), -1) 173 | 174 | return x 175 | 176 | class RNfour(nn.Module): 177 | 178 | def __init__(self, block, layers, num_classes=1000): 179 | self.inplanes = 64 180 | super(RNfour, self).__init__() 181 | self.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, 182 | bias=False) 183 | self.bn1 = nn.BatchNorm2d(64) 184 | self.relu = nn.ReLU(inplace=True) 185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 186 | self.layer1 = self._make_layer(block, 64, layers[0]) 187 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 189 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 190 | self.avgpool = nn.AvgPool2d(7, stride=1) 191 | self.fc = nn.Linear(512 * block.expansion, num_classes) 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 196 | m.weight.data.normal_(0, math.sqrt(2. / n)) 197 | elif isinstance(m, nn.BatchNorm2d): 198 | m.weight.data.fill_(1) 199 | m.bias.data.zero_() 200 | 201 | def _make_layer(self, block, planes, blocks, stride=1): 202 | downsample = None 203 | if stride != 1 or self.inplanes != planes * block.expansion: 204 | downsample = nn.Sequential( 205 | nn.Conv2d(self.inplanes, planes * block.expansion, 206 | kernel_size=1, stride=stride, bias=False), 207 | nn.BatchNorm2d(planes * block.expansion), 208 | ) 209 | 210 | layers = [] 211 | layers.append(block(self.inplanes, planes, stride, downsample)) 212 | self.inplanes = planes * block.expansion 213 | for i in range(1, blocks): 214 | layers.append(block(self.inplanes, planes)) 215 | 216 | return nn.Sequential(*layers) 217 | 218 | def forward(self, x): 219 | x = self.conv1(x) 220 | x = self.bn1(x) 221 | x = self.relu(x) 222 | x = self.maxpool(x) 223 | 224 | x = self.layer1(x) 225 | x = self.layer2(x) 226 | x = self.layer3(x) 227 | x = self.layer4(x) 228 | 229 | x = self.avgpool(x) 230 | x = x.view(x.size(0), -1) 231 | x = self.fc(x) 232 | 233 | return x 234 | 235 | def forward_extract(self, x): 236 | x = self.conv1(x) 237 | x = self.bn1(x) 238 | x = self.relu(x) 239 | x = self.maxpool(x) 240 | 241 | x = self.layer1(x) 242 | x = self.layer2(x) 243 | x = self.layer3(x) 244 | x = self.layer4(x) 245 | 246 | x = self.avgpool(x) 247 | x = x.view(x.size(0), -1) 248 | 249 | return x 250 | 251 | class RNone(nn.Module): 252 | 253 | def __init__(self, block, layers, num_classes=1000): 254 | self.inplanes = 64 255 | super(RNone, self).__init__() 256 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 257 | bias=False) 258 | self.bn1 = nn.BatchNorm2d(64) 259 | self.relu = nn.ReLU(inplace=True) 260 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 261 | self.layer1 = self._make_layer(block, 64, layers[0]) 262 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 263 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 264 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 265 | self.avgpool = nn.AvgPool2d(7, stride=1) 266 | self.fc = nn.Linear(512 * block.expansion, num_classes) 267 | 268 | for m in self.modules(): 269 | if isinstance(m, nn.Conv2d): 270 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 271 | m.weight.data.normal_(0, math.sqrt(2. / n)) 272 | elif isinstance(m, nn.BatchNorm2d): 273 | m.weight.data.fill_(1) 274 | m.bias.data.zero_() 275 | 276 | def _make_layer(self, block, planes, blocks, stride=1): 277 | downsample = None 278 | if stride != 1 or self.inplanes != planes * block.expansion: 279 | downsample = nn.Sequential( 280 | nn.Conv2d(self.inplanes, planes * block.expansion, 281 | kernel_size=1, stride=stride, bias=False), 282 | nn.BatchNorm2d(planes * block.expansion), 283 | ) 284 | 285 | layers = [] 286 | layers.append(block(self.inplanes, planes, stride, downsample)) 287 | self.inplanes = planes * block.expansion 288 | for i in range(1, blocks): 289 | layers.append(block(self.inplanes, planes)) 290 | 291 | return nn.Sequential(*layers) 292 | 293 | def forward(self, x): 294 | x = self.conv1(x) 295 | x = self.bn1(x) 296 | x = self.relu(x) 297 | x = self.maxpool(x) 298 | 299 | x = self.layer1(x) 300 | x = self.layer2(x) 301 | x = self.layer3(x) 302 | x = self.layer4(x) 303 | 304 | x = self.avgpool(x) 305 | x = x.view(x.size(0), -1) 306 | x = self.fc(x) 307 | 308 | return x 309 | 310 | def forward_extract(self, x): 311 | x = self.conv1(x) 312 | x = self.bn1(x) 313 | x = self.relu(x) 314 | x = self.maxpool(x) 315 | 316 | x = self.layer1(x) 317 | x = self.layer2(x) 318 | x = self.layer3(x) 319 | x = self.layer4(x) 320 | 321 | x = self.avgpool(x) 322 | x = x.view(x.size(0), -1) 323 | 324 | return x 325 | 326 | 327 | class ResNetProject(nn.Module): 328 | 329 | def __init__(self, resnet, hdim=200, input_dim=2048, dropout=.3): 330 | super(ResNetProject, self).__init__() 331 | self.resnet = resnet 332 | self.hdim = hdim 333 | self.dropout = nn.Dropout(p=dropout) 334 | self.project = nn.Linear(input_dim, hdim) 335 | self.fc = nn.Linear(hdim, 1) 336 | 337 | def forward_extract(self, x): 338 | x = self.resnet.forward_extract(x) 339 | x = self.project(x) 340 | x = F.tanh(x) 341 | x = self.dropout(x) 342 | return x 343 | 344 | def forward(self, x): 345 | x = self.forward_extract(x) 346 | x = self.fc(x) 347 | return x 348 | 349 | 350 | def resnet18(pretrained=False, **kwargs): 351 | """Constructs a ResNet-18 model. 352 | Args: 353 | pretrained (bool): If True, returns a model pre-trained on ImageNet 354 | """ 355 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 356 | if pretrained: 357 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 358 | return model 359 | 360 | 361 | def resnet34(pretrained=False, **kwargs): 362 | """Constructs a ResNet-34 model. 363 | Args: 364 | pretrained (bool): If True, returns a model pre-trained on ImageNet 365 | """ 366 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 367 | if pretrained: 368 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 369 | return model 370 | 371 | 372 | def resnet50(pretrained=False, **kwargs): 373 | """Constructs a ResNet-50 model. 374 | Args: 375 | pretrained (bool): If True, returns a model pre-trained on ImageNet 376 | """ 377 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 378 | if pretrained: 379 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 380 | return model 381 | 382 | def resnet50_4channel(pretrained=False, **kwargs): 383 | """Constructs a ResNet-50 model. 384 | Args: 385 | pretrained (bool): If True, returns a model pre-trained on ImageNet 386 | """ 387 | new_model = RNfour(Bottleneck, [3, 4, 6, 3], **kwargs) 388 | 389 | if pretrained: 390 | 391 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 392 | new_model_dict = new_model.state_dict() 393 | 394 | # 1. filter out unnecessary keys 395 | filtered_pretrained_dict = {k: v for k, v in pretrained_dict.items() if k!='conv1.weight'} 396 | # 2. overwrite entries in the existing state dict 397 | new_model_dict.update(filtered_pretrained_dict) 398 | # 3. load the new state dict 399 | new_model.load_state_dict(new_model_dict) 400 | 401 | new_model.conv1.weight.data.normal_(0, 0.001) 402 | new_model.conv1.weight.data[:, :3, :, :] = pretrained_dict['conv1.weight'] 403 | 404 | 405 | return new_model 406 | 407 | def resnet50_1channel(pretrained=False, **kwargs): 408 | """Constructs a ResNet-50 model. 409 | Args: 410 | pretrained (bool): If True, returns a model pre-trained on ImageNet 411 | """ 412 | new_model = RNone(Bottleneck, [3, 4, 6, 3], **kwargs) 413 | 414 | if pretrained: 415 | 416 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 417 | new_model_dict = new_model.state_dict() 418 | 419 | # 1. filter out unnecessary keys 420 | filtered_pretrained_dict = {k: v for k, v in pretrained_dict.items() if k!='conv1.weight'} 421 | # 2. overwrite entries in the existing state dict 422 | new_model_dict.update(filtered_pretrained_dict) 423 | # 3. load the new state dict 424 | new_model.load_state_dict(new_model_dict) 425 | 426 | con1w=pretrained_dict['conv1.weight'] 427 | con1w_mean=torch.mean(con1w, dim=1, keepdim=True) 428 | 429 | new_model.conv1.weight.data=con1w_mean 430 | 431 | 432 | 433 | 434 | return new_model 435 | 436 | def resnet101(pretrained=False, **kwargs): 437 | """Constructs a ResNet-101 model. 438 | Args: 439 | pretrained (bool): If True, returns a model pre-trained on ImageNet 440 | """ 441 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 442 | if pretrained: 443 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 444 | return model 445 | 446 | 447 | def resnet152(pretrained=False, **kwargs): 448 | """Constructs a ResNet-152 model. 449 | Args: 450 | pretrained (bool): If True, returns a model pre-trained on ImageNet 451 | """ 452 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 453 | if pretrained: 454 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 455 | return -------------------------------------------------------------------------------- /src/run_predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | The run prediction page 3 | """ 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import streamlit as st 8 | from openslide import OpenSlide 9 | from utils import * 10 | from spa_mapping import generate_heatmap_cell_type, generate_heatpmap_survival 11 | from spatial_stat import gen_graph, compute_percent 12 | from get_patch_img import read_patches 13 | import seaborn as sns 14 | import matplotlib.pyplot as plt 15 | from io import BytesIO 16 | import json 17 | import pyvips 18 | 19 | def app(): 20 | 21 | with open("config/config.json", 'r') as f: 22 | config = json.load(f) 23 | 24 | clear_dir('temp') 25 | 26 | # Specify canvas parameters in application 27 | bg_image = st.file_uploader("Image:", type=["tiff", 'tif', "svs"]) 28 | 29 | # Control panel 30 | example_button = st.button("Use an example slide") 31 | test_mode = st.selectbox("Run Mode:", ("Test mode (only 1,000 patches will be predicted)", "Complete")) 32 | st.markdown("**Note**: We are currently working on obtaining GPU support for this software. To expedite the process, the default mode " 33 | "is now set to `Test mode`, which will only predict 1,000 patches of the image. " 34 | "To predict the entire image, please switch to the `Complete` mode.") 35 | 36 | cell_type_button = st.button("Get cell type visualization") 37 | prognosis_button = st.button("Get prognosis visualization") 38 | clear_button = st.button("Clear the session") 39 | 40 | # Check available device 41 | device = check_device(config['use_cuda']) 42 | st.write('Device available:', device) 43 | 44 | # Initialization 45 | if 'slide' not in st.session_state: 46 | st.session_state.slide = None 47 | if 'image' not in st.session_state: 48 | st.session_state.image = None 49 | if 'image_type' not in st.session_state: 50 | st.session_state.image_type = None 51 | if 'dataloader' not in st.session_state: 52 | st.session_state.dataloader = None 53 | 54 | if bg_image: 55 | path = save_uploaded_file(bg_image) 56 | st.session_state.image_type = "svs" 57 | 58 | if path.endswith("tiff") or path.endswith("tif"): 59 | image = pyvips.Image.new_from_file(path) 60 | image.write_to_file("temp/test.tiff", pyramid=True, tile=True) 61 | path = "temp/test.tiff" 62 | st.session_state.image_type = "tif" 63 | 64 | st.session_state.slide = OpenSlide(path) 65 | st.session_state.image = st.session_state.slide.get_thumbnail(size=(512,512)) 66 | st.image(st.session_state.image) 67 | bg_image = None 68 | 69 | if example_button: 70 | path = os.path.join("example", "C3L-00365-21.svs") 71 | st.session_state.slide = OpenSlide(path) 72 | st.session_state.image = st.session_state.slide.get_thumbnail(size=(512,512)) 73 | st.image(st.session_state.image) 74 | 75 | max_patches_per_slide = np.inf 76 | if test_mode == "Test mode (only 1,000 patches will be predicted)": 77 | max_patches_per_slide = 1000 78 | 79 | if cell_type_button and st.session_state.slide: 80 | slide = st.session_state.slide 81 | with st.spinner('Reading patches...'): 82 | dataloader = read_patches(slide, max_patches_per_slide, image_type = st.session_state.image_type) 83 | 84 | with st.spinner('Loading model...'): 85 | model = load_model(checkpoint='model_weights/train_2023-04-28_prob_multi_label_weighted/model_cell.pt', config = config) 86 | 87 | with st.spinner('Predicting transcriptional subtypes...'): 88 | results = predict_cell(model, dataloader, device=device) 89 | 90 | with st.spinner('Generating visualization...'): 91 | heatmap = generate_heatmap_cell_type(slide, patch_size= (112,112), labels=results, config=config) 92 | im = Image.fromarray(heatmap) 93 | legend = Image.open('pictures/cell-type-hor.png') 94 | st.image(legend) 95 | st.image(im, caption='Subtype distribution across the tissue') 96 | 97 | with st.spinner('Calculating spatial statistics...'): 98 | df_percent = compute_percent(results) # cell type composition 99 | dgr_centr, im_mtx_slide, im_mtx_row, df_cluster = gen_graph(slide, results = results) # graph statistics 100 | 101 | # Display statistic tables for cell proportions 102 | color_ids, cluster_colors = get_color_ids() 103 | st.markdown('Cell fraction (%)
', unsafe_allow_html=True) 104 | data_container = st.container() 105 | with data_container: 106 | table, plot, _ , _ = st.columns(4) 107 | with table: 108 | st.table(data=style_table(df_percent)) 109 | with plot: 110 | buf = BytesIO() 111 | fig, ax = plt.subplots() 112 | sns.barplot(data = df_percent, y = 'Subtype', x = "Percentage", palette = cluster_colors, ax = ax) 113 | ax.tick_params(labelsize=14) 114 | ax.set_ylabel('', fontdict= {'fontsize': 16, 'fontweight':'bold'}) 115 | ax.set_xlabel('Percentage (%)',fontdict= { 'fontsize': 16, 'fontweight':'bold'}) 116 | fig.savefig(buf, format="png", bbox_inches = "tight") 117 | st.image(buf) 118 | 119 | # Display row-normalized interaction matrix 120 | st.markdown('Interaction matrix (row-wise normalized)
', unsafe_allow_html=True) 121 | data_container = st.container() 122 | with data_container: 123 | table, plot, _ , _ = st.columns(4) 124 | with table: 125 | st.table(data=style_table(im_mtx_row)) 126 | with plot: 127 | buf = BytesIO() 128 | fig, ax = plt.subplots() 129 | sns.heatmap(im_mtx_row, ax = ax) 130 | #ax.tick_params(labelsize=12) 131 | fig.savefig(buf, format="png", bbox_inches = "tight") 132 | st.image(buf) 133 | 134 | # Display slide-normalized interaction matrix 135 | st.markdown('Interaction matrix (slide-wise normalized)
', unsafe_allow_html=True) 136 | data_container = st.container() 137 | with data_container: 138 | table, plot, _ , _ = st.columns(4) 139 | with table: 140 | st.table(data=style_table(im_mtx_slide)) 141 | with plot: 142 | buf = BytesIO() 143 | fig, ax = plt.subplots() 144 | sns.heatmap(im_mtx_slide, ax = ax) 145 | #ax.tick_params(labelsize=12) 146 | fig.savefig(buf, format="png", bbox_inches = "tight") 147 | st.image(buf) 148 | 149 | # Display statistic tables for clustering coefficient 150 | st.markdown('Clustering coefficient
', unsafe_allow_html=True) 151 | data_container = st.container() 152 | with data_container: 153 | table, plot, _ , _ = st.columns(4) 154 | with table: 155 | st.table(data=style_table(df_cluster)) 156 | with plot: 157 | buf = BytesIO() 158 | fig, ax = plt.subplots() 159 | sns.barplot(data = df_cluster, y = 'Subtype', x = 'cluster_coeff' , palette = cluster_colors, ax = ax) 160 | ax.tick_params(labelsize=14) 161 | ax.set_ylabel('', fontdict= {'fontsize': 16, 'fontweight':'bold'}) 162 | ax.set_xlabel('Clustering coefficient',fontdict= { 'fontsize': 16, 'fontweight':'bold'}) 163 | fig.savefig(buf, format="png", bbox_inches = "tight") 164 | st.image(buf) 165 | 166 | if prognosis_button and st.session_state.slide: 167 | 168 | slide = st.session_state.slide 169 | 170 | with st.spinner('Reading patches...'): 171 | dataloader = read_patches(slide, max_patches_per_slide) 172 | 173 | config['num_classes'] = 1 174 | with st.spinner('Loading model...'): 175 | model = load_model(checkpoint='model_weights/model_survival.pt', config = config) 176 | 177 | with st.spinner('Predicting aggressive scores...'): 178 | results = predict_survival(model, dataloader, device=device) 179 | config['label_column'] = 'risk_score' 180 | 181 | 182 | with st.spinner('Generating visualization...'): 183 | heatmap = generate_heatpmap_survival(slide, patch_size= (112,112), 184 | results=results, 185 | config = config) 186 | 187 | legend = Image.open('pictures/risk_score_legend.png') 188 | st.image(legend) 189 | im = Image.fromarray(heatmap) 190 | st.image(im, caption='Aggressive score prediction') 191 | 192 | if clear_button: 193 | clear(path) 194 | 195 | 196 | 197 | # # Display statistic tables for degree centrality 198 | # st.markdown('Degree centrality
', unsafe_allow_html=True) 199 | # data_container = st.container() 200 | # with data_container: 201 | # table, plot, _ , _ = st.columns(4) 202 | # with table: 203 | # st.table(data=style_table(dgr_centr)) 204 | # with plot: 205 | # buf = BytesIO() 206 | # fig, ax = plt.subplots() 207 | # sns.barplot(data = dgr_centr, y = 'Subtype', x = 'centrality' , palette = cluster_colors, ax = ax) 208 | # ax.tick_params(labelsize=14) 209 | # ax.set_ylabel('', fontdict= {'fontsize': 16, 'fontweight':'bold'}) 210 | # ax.set_xlabel('Centrality score',fontdict= { 'fontsize': 16, 'fontweight':'bold'}) 211 | # fig.savefig(buf, format="png", bbox_inches = "tight") 212 | # st.image(buf) -------------------------------------------------------------------------------- /src/spa_mapping.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Script to generate the visualization of cell types and prognostic scores in whole slide images 4 | 5 | """ 6 | 7 | from typing import Tuple 8 | import pandas as pd 9 | import numpy as np 10 | from openslide import OpenSlide 11 | import seaborn as sns 12 | from stqdm import stqdm 13 | import matplotlib.pyplot as plt 14 | from PIL import Image 15 | import math 16 | import seaborn as sns 17 | from utils import get_class, get_color_ids 18 | import pdb 19 | 20 | def assig_to_heatmap(heatmap, patch, x, y, ratio_patch_x, ratio_patch_y): 21 | 22 | new_x = int(x / ratio_patch_x) 23 | new_y = int(y / ratio_patch_y) 24 | 25 | try: 26 | if new_x+patch.shape[0] > heatmap.shape[0] and new_y+patch.shape[1] < heatmap.shape[1]: 27 | dif = heatmap.shape[0] - new_x 28 | heatmap[new_x:heatmap.shape[0], new_y:new_y+patch.shape[1], :] = patch[:dif, :, :] 29 | elif new_x+patch.shape[0] < heatmap.shape[0] and new_y+patch.shape[1] > heatmap.shape[1]: 30 | dif = heatmap.shape[1] - new_y 31 | heatmap[new_x:new_x+patch.shape[0], new_y:, :] = patch[:, :dif, :] 32 | elif new_x+patch.shape[0] > heatmap.shape[0] and new_y+patch.shape[1] > heatmap.shape[1]: 33 | return heatmap 34 | else: 35 | heatmap[new_x:new_x+patch.shape[0], new_y:new_y+patch.shape[1], :] = patch 36 | return heatmap 37 | except: 38 | return heatmap 39 | 40 | def get_indices(slide : OpenSlide, patch_size: Tuple, PATCH_LEVEL = 0, dezoom_factor = 1, use_h5 = False): 41 | 42 | xmax, ymax = slide.level_dimensions[PATCH_LEVEL] 43 | 44 | # handle slides with 40 magnification at base level 45 | if use_h5: 46 | resize_factor = 0.5 / float(slide.properties.get('openslide.mpp-x', 0.5)) 47 | else: 48 | resize_factor = float(slide.properties.get('aperio.AppMag', 20)) / 20.0 49 | 50 | resize_factor = resize_factor * dezoom_factor 51 | patch_size_resized = (int(resize_factor * patch_size[0]), int(resize_factor * patch_size[1])) 52 | 53 | indices = [(x, y) for x in range(0, xmax, patch_size_resized[0]) 54 | for y in range(0, ymax, patch_size_resized[0])] 55 | 56 | return(indices, xmax, ymax, patch_size_resized, resize_factor) 57 | 58 | def get_color_linear(minimum, maximum, value): 59 | # give the minimun and maxium value, generate a color mapped to blue-red heatmap 60 | minimum, maximum = float(minimum), float(maximum) 61 | ratio = 2 * (value-minimum) / (maximum - minimum) 62 | b = int(max(0, 255*(1 - ratio))) 63 | r = int(max(0, 255*(ratio - 1))) 64 | g = 255 - b - r 65 | return r, g, b 66 | 67 | def make_dict_cell_type(labels, config): 68 | keys = labels['coordinates'] 69 | labels = labels[config['label_column']] 70 | keys = np.concatenate((keys), axis=0) 71 | labels = np.concatenate((labels), axis=0) 72 | # convert predicted labels to actual cell types 73 | class2idx, id2class = get_class() 74 | cell_types = [id2class[k] for k in labels] 75 | # Match cell types to colors 76 | color_ids, cluster_colors = get_color_ids() 77 | colors = [color_ids[k] for k in cell_types] 78 | color_labels = dict() 79 | for key, value in zip(keys, colors): 80 | color_labels[tuple(key)] = value 81 | return color_labels 82 | 83 | def make_dict_survival(labels): 84 | keys = labels['coordinates'] 85 | values = labels['risk_score'] 86 | survival_labels = dict() 87 | for key, value in zip(keys, values): 88 | for k, v in zip(key, value): 89 | survival_labels[k] = v 90 | return survival_labels 91 | 92 | def generate_heatmap_cell_type(slide, patch_size: Tuple, labels, config): 93 | PATCH_LEVEL = 0 94 | indices, xmax, ymax, patch_size_resized, resize_factor = get_indices(slide, patch_size, PATCH_LEVEL, use_h5 = config['use_h5']) 95 | 96 | compress_factor = config['compress_factor'] * round(resize_factor) 97 | 98 | heatmap = np.zeros((xmax // compress_factor, ymax // compress_factor, 3)) 99 | labels_dict = make_dict_cell_type(labels, config) 100 | 101 | print(f'Overlap patches: {len(set(labels_dict.keys()) & set(indices))}') 102 | 103 | for x, y in stqdm(indices): 104 | try: 105 | patch = np.transpose(np.array(slide.read_region((x, y), PATCH_LEVEL, patch_size_resized).convert('RGB')), axes=[1, 0, 2]) 106 | patch = Image.fromarray(patch) 107 | patch = patch.resize((math.ceil(patch_size_resized[0] / compress_factor), math.ceil(patch_size_resized[1] / compress_factor))) 108 | patch = np.asarray(patch) 109 | 110 | if (x, y) in labels_dict: 111 | score = labels_dict[(x,y)] 112 | color = sns.color_palette()[score] 113 | visualization = np.empty((math.ceil(patch_size_resized[0] / compress_factor), math.ceil(patch_size_resized[1] / compress_factor), 3), np.uint8) 114 | visualization[:] = color[0] * 255, color[1] * 255, color[2] * 255 115 | heatmap = assig_to_heatmap(heatmap, visualization, x, y, compress_factor, compress_factor) 116 | else: 117 | heatmap = assig_to_heatmap(heatmap, patch, x, y, compress_factor, compress_factor) 118 | except Exception as e: 119 | print(e) 120 | 121 | # since the x and y coordiante is flipped after converting the patch to RGB, we flipped the image again to match the original image 122 | heatmap = np.transpose(heatmap, axes=[1, 0, 2]).astype(np.uint8) 123 | return heatmap 124 | 125 | def generate_heatpmap_survival(slide, patch_size: Tuple, results: dict, min_val=-2, max_val=2.34, config = None): 126 | 127 | PATCH_LEVEL = 0 128 | indices, xmax, ymax, patch_size_resized, resize_factor = get_indices(slide, patch_size, PATCH_LEVEL, use_h5 = config['use_h5']) 129 | 130 | compress_factor = config['compress_factor'] * round(resize_factor) 131 | heatmap = np.zeros((xmax // compress_factor, ymax // compress_factor, 3)) 132 | labels_dict = make_dict_survival(results) 133 | 134 | risk_score = [s for sublist in results['risk_score'] for s in sublist] 135 | min_val = np.min(risk_score) 136 | max_val = np.max(risk_score) 137 | 138 | for x, y in stqdm(indices): 139 | try: 140 | patch = np.transpose(np.array(slide.read_region((x, y), PATCH_LEVEL, patch_size_resized).convert('RGB')), axes=[1, 0, 2]) 141 | patch = Image.fromarray(patch) 142 | patch = patch.resize((math.ceil(patch_size_resized[0] / compress_factor), math.ceil(patch_size_resized[1] / compress_factor))) 143 | patch = np.asarray(patch) 144 | 145 | if (x, y) in labels_dict: 146 | score = labels_dict[(x,y)] 147 | color = get_color_linear(min_val, max_val, score) 148 | visualization = np.empty((math.ceil(patch_size_resized[0] / compress_factor), math.ceil(patch_size_resized[1] / compress_factor), 3), np.uint8) 149 | visualization[:] = color[0] * 255, color[1] * 255, color[2] * 255 150 | heatmap = assig_to_heatmap(heatmap, visualization, x, y, compress_factor, compress_factor) 151 | else: 152 | heatmap = assig_to_heatmap(heatmap, patch, x, y, compress_factor, compress_factor) 153 | except Exception as e: 154 | print(e) 155 | 156 | # since the x and y coordiante is flipped after converting the patch to RGB, we flipped the image again to match the original image 157 | heatmap = np.transpose(heatmap, axes=[1, 0, 2]).astype(np.uint8) 158 | 159 | return heatmap 160 | -------------------------------------------------------------------------------- /src/spatial_stat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to perform spatial statistical analysis 3 | """ 4 | 5 | from openslide import OpenSlide 6 | import math 7 | from anndata import AnnData 8 | import squidpy as sq 9 | import numpy as np 10 | import pandas as pd 11 | import scanpy as sc 12 | import os 13 | import matplotlib.pyplot as plt 14 | from numpy.random import default_rng 15 | import seaborn as sns 16 | import warnings 17 | from utils import get_class, get_color_ids 18 | import pdb 19 | 20 | def gen_output(results): 21 | coord = [c for sublist in results['coordinates'] for c in sublist] 22 | x = [c[0].detach().cpu().numpy() for c in coord] 23 | y = [c[1].detach().cpu().numpy() for c in coord] 24 | labels = [l for sublist in results['label'] for l in sublist] 25 | df_res = pd.DataFrame({'x': x, 'y': y, 'label': labels}) 26 | return df_res 27 | 28 | def get_matracies(slide, cluster_df, patch_size = (112, 112)): 29 | 30 | if not slide.properties.get('openslide.mpp-x'): print(f"resolution is not found, using default 0.5um/px") 31 | resize_factor = float(slide.properties.get('aperio.AppMag', 20)) / 20.0 32 | patch_size_resized = (int(resize_factor * patch_size[0]), int(resize_factor * patch_size[1])) 33 | 34 | with warnings.catch_warnings(): 35 | warnings.simplefilter("ignore") 36 | cluster_df['new_x'] = np.ceil(cluster_df['x'].values / patch_size_resized[0]) 37 | cluster_df['new_y'] = np.ceil(cluster_df['y'].values / patch_size_resized[1]) 38 | 39 | cluster_df['new_x'] = cluster_df['new_x'].astype(int) 40 | cluster_df['new_y'] = cluster_df['new_y'].astype(int) 41 | 42 | matrix_trait = pd.DataFrame({'label': cluster_df['label'], 'x': cluster_df['new_x'], 'y': cluster_df['new_y']}) 43 | 44 | return(matrix_trait) 45 | 46 | 47 | def get_interactions(cells): 48 | """ 49 | Generate a list of interactions between cell types, except for self-interactions. 50 | """ 51 | interactions = [] 52 | i, j = 0, 0 53 | while i < len(cells): 54 | j = i + 1 55 | while j < len(cells): 56 | concat = [cells[i], cells[j]] 57 | interactions.append(concat) 58 | j = j + 1 59 | i = i + 1 60 | return interactions 61 | 62 | def normalize_interactions(im_mtx, cell_types): 63 | """ 64 | Given a cell interaction matrix, normalize each type of interaction by dividing the total number of interactions 65 | """ 66 | sum_links = im_mtx.sum().sum() 67 | interactions = get_interactions(cell_types) 68 | duplicated_links = 0 69 | for inter in interactions: 70 | duplicated_links = duplicated_links + im_mtx.loc[inter[0], inter[1]] 71 | new_links = sum_links - duplicated_links 72 | im_mtx_norm = im_mtx.div(new_links) 73 | return im_mtx_norm 74 | 75 | def gen_graph(slide, results): 76 | 77 | cluster_df = gen_output(results) 78 | 79 | trait = get_matracies(slide, cluster_df = cluster_df, patch_size = (112, 112)) 80 | trait['label'] = trait['label'].astype('category') 81 | class2idx, id2class = get_class() 82 | labels = sorted(np.unique(trait['label'])) 83 | cell_types = [id2class[k] for k in labels] 84 | 85 | cell_number = trait.shape[0] 86 | rng = default_rng(0) 87 | counts = rng.integers(0, 15, size=(cell_number, 50)) # feature matrix 88 | 89 | with warnings.catch_warnings(): 90 | warnings.simplefilter("ignore") 91 | adata = AnnData(counts, obs = trait, obsm={"spatial": np.asarray(trait[['x', 'y']])}, dtype = counts.dtype) 92 | 93 | sq.gr.spatial_neighbors(adata, n_neighs=8, n_rings=2, coord_type="grid") 94 | sq.gr.centrality_scores(adata, cluster_key='label', show_progress_bar=False) 95 | sq.gr.interaction_matrix(adata, cluster_key='label') 96 | 97 | # Generate dataframes 98 | dgr_centr = pd.DataFrame({'Subtype': cell_types, 'centrality':adata.uns['label_centrality_scores']['degree_centrality']}) 99 | 100 | im_mtx = pd.DataFrame(adata.uns['label_interactions'], columns=cell_types, index=cell_types) 101 | im_mtx_slide = normalize_interactions(im_mtx, cell_types) 102 | im_mtx_row = im_mtx.div(im_mtx.sum(axis=1), axis=0) 103 | 104 | cluster_res = [] 105 | for cell in cell_types: 106 | cluster_res.append(im_mtx_row.loc[cell][cell]) 107 | df_cluster = pd.DataFrame({'Subtype': cell_types, 'cluster_coeff': cluster_res}) 108 | 109 | dgr_centr = dgr_centr.sort_values(["centrality"], ascending=False) 110 | df_cluster = df_cluster.sort_values(['cluster_coeff'], ascending=False) 111 | 112 | dgr_centr = dgr_centr.reset_index(drop = True) 113 | df_cluster = df_cluster.reset_index(drop = True) 114 | 115 | return dgr_centr, im_mtx_slide, im_mtx_row, df_cluster 116 | 117 | def compute_percent(labels): 118 | """ 119 | Compute cell type compositions 120 | """ 121 | 122 | labels = labels['label'] 123 | labels = np.concatenate((labels), axis=0) 124 | # convert predicted labels to actual cell types 125 | class2idx, id2class = get_class() 126 | pred_labels = [id2class[k] for k in labels] 127 | total = len(pred_labels) 128 | cell_types = class2idx.keys() 129 | frac = [] 130 | for cell in cell_types: 131 | count = pred_labels.count(cell) 132 | percent = float(count/total) * 100 133 | frac.append(percent) 134 | df = pd.DataFrame({'Subtype': cell_types, 'Percentage': frac}) 135 | df = df.sort_values(['Percentage'], ascending=False) 136 | df = df.reset_index(drop=True) 137 | return df 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /src/tutorial.md: -------------------------------------------------------------------------------- 1 | 1. Click the `Run` tab located at the top of the page. 2 | 2. To start the analysis, user can either upload a new histology image or simply click `Use an example slide`.