├── assets └── Model.png ├── requirements.txt ├── sh ├── upmc_hm.sh ├── charville_hm.sh ├── dfci_general.sh ├── charville_bc.sh └── upmc_bc.sh ├── src ├── __init__.py ├── inference.py ├── precomputing.py ├── models.py ├── train.py ├── utils.py ├── transform.py ├── features.py ├── data.py └── graph_build.py ├── dataset └── README.md ├── README.md └── main.py /assets/Model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UNITES-Lab/Mew/HEAD/assets/Model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | pandas==2.2.2 3 | torch==2.3.1 4 | torch-geometric==2.5.3 5 | scikit-learn==1.4.2 6 | scipy==1.13.1 7 | tqdm==4.66.4 -------------------------------------------------------------------------------- /sh/upmc_hm.sh: -------------------------------------------------------------------------------- 1 | device=4 2 | 3 | data=upmc 4 | task=cox 5 | shared=True 6 | attn_weight=True 7 | lr=0.001 8 | num_layers=1 9 | emb_dim=256 10 | batch_size=16 11 | drop_ratio=0.0 12 | 13 | python main.py \ 14 | --data $data \ 15 | --task $task \ 16 | --shared $shared \ 17 | --attn_weight $attn_weight \ 18 | --lr $lr \ 19 | --num_layers $num_layers \ 20 | --emb_dim $emb_dim \ 21 | --batch_size $batch_size \ 22 | --drop_ratio $drop_ratio \ 23 | --device $device 24 | -------------------------------------------------------------------------------- /sh/charville_hm.sh: -------------------------------------------------------------------------------- 1 | device=1 2 | 3 | data=charville 4 | task=cox 5 | shared=False 6 | attn_weight=False 7 | lr=0.0001 8 | num_layers=4 9 | emb_dim=128 10 | batch_size=32 11 | drop_ratio=0.5 12 | 13 | python main.py \ 14 | --data $data \ 15 | --task $task \ 16 | --shared $shared \ 17 | --attn_weight $attn_weight \ 18 | --lr $lr \ 19 | --num_layers $num_layers \ 20 | --emb_dim $emb_dim \ 21 | --batch_size $batch_size \ 22 | --drop_ratio $drop_ratio \ 23 | --device $device 24 | -------------------------------------------------------------------------------- /sh/dfci_general.sh: -------------------------------------------------------------------------------- 1 | device=2 2 | 3 | data=dfci 4 | task=classification 5 | shared=False 6 | attn_weight=True 7 | lr=0.001 8 | num_layers=2 9 | emb_dim=128 10 | batch_size=16 11 | drop_ratio=0.25 12 | 13 | python main.py \ 14 | --data $data \ 15 | --task $task \ 16 | --shared $shared \ 17 | --attn_weight $attn_weight \ 18 | --lr $lr \ 19 | --num_layers $num_layers \ 20 | --emb_dim $emb_dim \ 21 | --batch_size $batch_size \ 22 | --drop_ratio $drop_ratio \ 23 | --device $device 24 | -------------------------------------------------------------------------------- /sh/charville_bc.sh: -------------------------------------------------------------------------------- 1 | device=0 2 | 3 | data=charville 4 | task=classification 5 | shared=False 6 | attn_weight=False 7 | lr=0.0001 8 | num_layers=4 9 | emb_dim=512 10 | batch_size=16 11 | drop_ratio=0.5 12 | 13 | python main.py \ 14 | --data $data \ 15 | --task $task \ 16 | --shared $shared \ 17 | --attn_weight $attn_weight \ 18 | --lr $lr \ 19 | --num_layers $num_layers \ 20 | --emb_dim $emb_dim \ 21 | --batch_size $batch_size \ 22 | --drop_ratio $drop_ratio \ 23 | --device $device 24 | -------------------------------------------------------------------------------- /sh/upmc_bc.sh: 
-------------------------------------------------------------------------------- 1 | device=3 2 | 3 | data=upmc 4 | task=classification 5 | shared=True 6 | attn_weight=True 7 | pool=mean 8 | lr=0.0001 9 | num_layers=1 10 | emb_dim=256 11 | batch_size=16 12 | drop_ratio=0.0 13 | 14 | python main.py \ 15 | --data $data \ 16 | --task $task \ 17 | --shared $shared \ 18 | --attn_weight $attn_weight \ 19 | --pool $pool \ 20 | --lr $lr \ 21 | --num_layers $num_layers \ 22 | --emb_dim $emb_dim \ 23 | --batch_size $batch_size \ 24 | --drop_ratio $drop_ratio \ 25 | --device $device 26 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from src.graph_build import plot_graph, plot_voronoi_polygons, construct_graph_for_region 2 | from src.data import CellularGraphDataset 3 | from src.models import SIGN_pred 4 | from src.transform import ( 5 | FeatureMask, 6 | AddCenterCellBiomarkerExpression, 7 | AddCenterCellType, 8 | AddCenterCellIdentifier, 9 | AddGraphLabel, 10 | AddTwoGraphLabel 11 | ) 12 | from src.inference import collect_predict_for_all_nodes 13 | from src.train import train_full_graph 14 | from src.precomputing import PrecomputingBase 15 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | ## **Dataset** 2 | 3 | 1. Please download the following raw data here: [Enable Medicine](https://app.enablemedicine.com/portal/atlas-library/studies/92394a9f-6b48-4897-87de-999614952d94?sid=1168) 4 | - UPMC-HNC: `upmc_raw_data.zip` and `upmc_labels.csv` 5 | - Stanford-CRC: `charville_raw_data.zip` and `charville_labels.csv` 6 | - DFCI-HNC: `dfci_raw_data.zip` and `dfci_labels.csv` 7 | 8 | 2. After the download is complete, please locate the above files as follows: 9 | ``` 10 | dataset/ 11 | ├── charville_data 12 | └── charville_raw_data.zip 13 | └── charville_labels.csv 14 | ├── upmc_data 15 | └── upmc_raw_data.zip 16 | └── upmc_labels.csv 17 | ├── dfci_data 18 | └── dfci_labels.csv 19 | ├── general_data 20 | └── upmc_raw_data.zip 21 | └── dfci_raw_data.zip 22 | ``` 23 | 24 | 3. By running each dataset with a certain task, the preprocessing (which would take a few hours to generate graphs) will automatically happen. 
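For example, the first run of a command like the following will build the graphs for the Charville classification setting (a minimal sketch; every other hyperparameter falls back to its default in `main.py`):
```
python main.py --data charville --task classification --device 0
```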
Finally, the preprocessed structure for charville data will be as follows: 25 | ``` 26 | dataset/ 27 | ├── charville_data 28 | └── charville_raw_data.zip 29 | └── charville_labels.csv 30 | └── dataset_mew 31 | └── fig 32 | └── graph 33 | └── model 34 | └── tg_graph 35 | └── raw_data 36 | ``` 37 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from tqdm import trange 4 | from lifelines.utils import concordance_index 5 | from sklearn.metrics import roc_auc_score 6 | 7 | def collect_predict_for_all_nodes(model, 8 | xs, 9 | device, 10 | inds=None, 11 | graph=False, 12 | **kwargs): 13 | 14 | model = model.to(device) 15 | model.eval() 16 | 17 | node_results = {} 18 | graph_results = {} 19 | voronoi_attn_results = {} 20 | cell_type_attn_results = {} 21 | 22 | start_time = time.time() 23 | for i in inds: 24 | input_1 = [_x.to(device) for _x in xs[i][0]] 25 | input_2 = [_x.to(device) for _x in xs[i][1]] 26 | res, attn_values = model([input_1, input_2]) 27 | 28 | voronoi_attn = attn_values[:,0].mean().item() 29 | cell_type_attn = attn_values[:,1].mean().item() 30 | 31 | graph_results[i] = res[0].cpu().data.numpy() 32 | voronoi_attn_results[i] = voronoi_attn 33 | cell_type_attn_results[i] = cell_type_attn 34 | 35 | # print(f'Prediction for {len(inds)} graphs:', time.time() - start_time) 36 | 37 | return node_results, graph_results, voronoi_attn_results, cell_type_attn_results 38 | 39 | 40 | def graph_classification_evaluate_fn(graph_preds, 41 | graph_ys, 42 | graph_ws=None, 43 | task='classification', 44 | print_res=True): 45 | """ Evaluate graph classification accuracy 46 | 47 | Args: 48 | graph_preds (array-like): binary classification logits for graph-level tasks, (num_subgraphs, num_tasks) 49 | graph_ys (array-like): binary labels for graph-level tasks, (num_subgraphs, num_tasks) 50 | graph_ws (array-like): weights for graph-level tasks, (num_subgraphs, num_tasks) 51 | print_res (bool): if to print the accuracy results 52 | 53 | Returns: 54 | list: list of metrics on all graph-level tasks 55 | """ 56 | if graph_ws is None: 57 | graph_ws = np.ones_like(graph_ys) 58 | scores = [] 59 | if task != 'classification': 60 | for task_i in trange(graph_preds.shape[1]): 61 | _pred = graph_preds[:, task_i] 62 | _times = graph_ys[:, task_i] 63 | _observed = graph_ws[:, task_i] 64 | idx = _times == _times 65 | s = concordance_index(_observed[idx], _pred[idx], _times[idx]) 66 | scores.append(s) 67 | else: 68 | for task_i in trange(graph_ys.shape[1]): 69 | _label = graph_ys[:, task_i] 70 | _pred = graph_preds[:, task_i] 71 | _w = graph_ws[:, task_i] 72 | s = roc_auc_score(_label[np.where(_w > 0)], _pred[np.where(_w > 0)]) 73 | scores.append(s) 74 | if print_res: 75 | print("GRAPH %s" % str(scores)) 76 | return scores 77 | 78 | 79 | def full_graph_graph_classification_evaluate_fn( 80 | dataset_yw, 81 | graph_results, 82 | task, 83 | aggr='mean', 84 | print_res=True): 85 | 86 | n_tasks = list(graph_results.values())[0].shape[1] 87 | graph_preds = [] 88 | graph_ys = [] 89 | graph_ws = [] 90 | for i in graph_results: 91 | graph_pred = [p for p in graph_results[i] if ((p is not None) and np.all(p == p))] 92 | graph_pred = np.stack(graph_pred, 0) 93 | 94 | if aggr == 'mean': 95 | graph_pred = np.nanmean(graph_pred, 0) 96 | else: 97 | raise NotImplementedError("Only mean-aggregation is supported now") 98 | 99 | graph_y = 
dataset_yw[i][0].numpy().flatten() 100 | graph_w = dataset_yw[i][1].numpy().flatten() 101 | graph_preds.append(graph_pred) 102 | graph_ys.append(graph_y) 103 | graph_ws.append(graph_w) 104 | 105 | graph_preds = np.concatenate(graph_preds, 0).reshape((-1, n_tasks)) 106 | graph_ys = np.concatenate(graph_ys, 0).reshape((-1, n_tasks)) 107 | graph_ws = np.concatenate(graph_ws, 0).reshape((-1, n_tasks)) 108 | 109 | return graph_classification_evaluate_fn(graph_preds, graph_ys, graph_ws, task, print_res=print_res) 110 | -------------------------------------------------------------------------------- /src/precomputing.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch_geometric.transforms import SIGN 4 | from tqdm import trange 5 | from torch_geometric.data import Data 6 | from torch_geometric.transforms import BaseTransform 7 | import torch.distributions as dist 8 | from torch_geometric.utils import add_remaining_self_loops 9 | 10 | class PrecomputingBase(torch.nn.Module): 11 | def __init__(self, num_layer, device): 12 | super(PrecomputingBase, self).__init__() 13 | self.device = device 14 | self.num_layers = num_layer 15 | 16 | def precompute(self, data): 17 | print("precomputing features, may take a while.") 18 | t1 = time.time() 19 | print('Start Precomputing...!') 20 | 21 | self.xs = [] 22 | for i in trange(len(data)): 23 | _data = data[i] 24 | # Geom data 25 | geom_data = Data() 26 | geom_data.x = _data['cell'].x 27 | geom_data.edge_index = _data['cell', 'geom', 'cell'].edge_index 28 | geom_data.edge_attr = _data['cell', 'geom', 'cell'].edge_attr 29 | geom_data.edge_index , geom_data.edge_attr = add_remaining_self_loops(geom_data.edge_index, geom_data.edge_attr) 30 | geom_data = geom_data.to(self.device) 31 | _geom_x = SIGN(self.num_layers)(geom_data, stochastic=False) 32 | geom_x = [_geom_x.x.cpu()] + [_geom_x[f"x{i}"] for i in range(1, self.num_layers + 1)] 33 | 34 | # Cell-type data 35 | cell_type_data = Data() 36 | cell_type_data.x = _data['cell'].x 37 | cell_type_data.edge_index = _data['cell', 'type', 'cell'].edge_index 38 | cell_type_data.edge_attr = _data['cell', 'type', 'cell'].edge_attr 39 | cell_type_data.edge_index , cell_type_data.edge_attr = add_remaining_self_loops(cell_type_data.edge_index, cell_type_data.edge_attr) 40 | cell_type_data = cell_type_data.to(self.device) 41 | _cell_type_x = SIGN(self.num_layers)(cell_type_data, stochastic=True) 42 | cell_type_x = [_cell_type_x.x.cpu()] + [_cell_type_x[f"x{i}"] for i in range(1, self.num_layers + 1)] 43 | 44 | self.xs.append([geom_x, cell_type_x]) 45 | 46 | t2 = time.time() 47 | print("Precomputing finished by %.4f s." 
% (t2 - t1)) 48 | 49 | return self.xs 50 | 51 | def forward(self, xs): 52 | raise NotImplementedError 53 | 54 | 55 | class SIGN(BaseTransform): 56 | def __init__(self, K): 57 | self.K = K 58 | 59 | def __call__(self, data, stochastic=False): 60 | assert data.edge_index is not None 61 | assert data.x is not None 62 | if not stochastic: 63 | weight = torch.ones_like(data.edge_attr[:,1]) 64 | _edge_weight = data.edge_attr.sum(dim=1) 65 | 66 | else: 67 | weight = data.edge_attr[:,1] 68 | weight = 1 - weight 69 | _edge_weight = torch.ones_like(data.edge_attr[:,1]) 70 | 71 | xs = [data.x] 72 | for i in range(1, self.K + 1): 73 | adj_t = self.stochastic_adj(data, weight, _edge_weight=_edge_weight) 74 | xs += [adj_t @ xs[-1]] 75 | data[f'x{i}'] = xs[-1].cpu() 76 | 77 | return data 78 | 79 | def __repr__(self) -> str: 80 | return f'{self.__class__.__name__}(K={self.K})' 81 | 82 | def stochastic_adj(self, data, weight, _edge_weight=None): 83 | bernoulli_dist = dist.Bernoulli(weight) 84 | mask = bernoulli_dist.sample().bool() 85 | 86 | N = data.num_nodes 87 | N_edges = data.edge_index.shape[1] 88 | 89 | if _edge_weight == None: 90 | edge_attr = torch.ones((N_edges,), device=data.device) 91 | else: 92 | edge_attr = _edge_weight[mask] 93 | 94 | edge_index = data.edge_index[:,mask] 95 | row, col = edge_index 96 | 97 | deg = torch.zeros(N, device=row.device, dtype=edge_attr.dtype) 98 | deg.scatter_add_(dim=0, index=col, src=edge_attr) 99 | 100 | deg_inv_sqrt = deg.pow_(-0.5) 101 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) 102 | edge_weight = deg_inv_sqrt[row] * edge_attr * deg_inv_sqrt[col] 103 | adj = torch.sparse_coo_tensor(edge_index, edge_weight, size=torch.Size([N, N])) 104 | adj_t = adj.t() 105 | 106 | return adj_t 107 | 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mew: Multiplexed Immunofluorescence Image Analysis through an Efficient Multiplex Network 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) [![ECCV 2024](https://img.shields.io/badge/ECCV'24-red)](https://eccv.ecva.net/) 4 | 5 | Official implementation for "Mew: Multiplexed Immunofluorescence Image Analysis through an Efficient Multiplex Network" accepted by ECCV 2024. Our implementation is based on [SPACE-GM](https://gitlab.com/enable-medicine-public/space-gm). 6 | 7 | - Authors: [Sukwon Yun](https://sukwonyun.github.io/), [Jie Peng](https://openreview.net/profile?id=~Jie_Peng4), [Alexandro E. Trevino](https://scholar.google.com/citations?user=z7HDsuAAAAAJ&hl=en), [Chanyoung Park](https://dsail.kaist.ac.kr/professor/) and [Tianlong Chen](https://tianlong-chen.github.io/) 8 | - Paper: [arXiv](https://arxiv.org/abs/2407.17857) 9 | 10 | ## Overview 11 | 12 | Recent advancements in graph-based approaches for multiplexed immunofluorescence (mIF) images have significantly propelled the field forward, offering deeper insights into patient-level phenotyping. However, current graph-based methodologies encounter two primary challenges: (1) Cellular Heterogeneity, where existing approaches fail to adequately address the inductive biases inherent in graphs, particularly the homophily characteristic observed in cellular connectivity and; (2) Scalability, where handling cellular graphs from high-dimensional images faces difficulties in managing a high number of cells. 
To overcome these limitations, we introduce Mew, a novel framework designed to efficiently process mIF images through the lens of multiplex network. Mew innovatively constructs a multiplex network comprising two distinct layers: a Voronoi network for geometric information and a Cell-type network for capturing cell-wise homogeneity. This framework equips a scalable and efficient Graph Neural Network (GNN), capable of processing the entire graph during training. Furthermore, Mew integrates an interpretable attention module that autonomously identifies relevant layers for image classification. Extensive experiments on a real-world patient dataset from various institutions highlight Mew's remarkable efficacy and efficiency, marking a significant advancement in mIF image analysis. 13 | 14 | 15 | 16 | 17 | ## **Setup** 18 | 19 | ``` 20 | conda create -n mew python=3.10 -y && conda activate mew 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## **Dataset** 25 | 26 | 1. Please download the following raw data here: [Enable Medicine](https://app.enablemedicine.com/portal/atlas-library/studies/92394a9f-6b48-4897-87de-999614952d94?sid=1168) 27 | - UPMC-HNC: `upmc_raw_data.zip` and `upmc_labels.csv` 28 | - Stanford-CRC: `charville_raw_data.zip` and `charville_labels.csv` 29 | - DFCI-HNC: `dfci_raw_data.zip` and `dfci_labels.csv` 30 | 31 | 2. After the download is complete, please locate the above files as follows: 32 | ``` 33 | dataset/ 34 | ├── charville_data 35 | └── charville_raw_data.zip 36 | └── charville_labels.csv 37 | ├── upmc_data 38 | └── upmc_raw_data.zip 39 | └── upmc_labels.csv 40 | ├── dfci_data 41 | └── dfci_labels.csv 42 | ├── general_data 43 | └── upmc_raw_data.zip 44 | └── dfci_raw_data.zip 45 | ``` 46 | 47 | 3. By running each dataset with a certain task, the preprocessing (which would take a few hours to generate graphs) will automatically happen. Finally, the preprocessed structure for charville data will be as follows: 48 | ``` 49 | dataset/ 50 | ├── charville_data 51 | └── charville_raw_data.zip 52 | └── charville_labels.csv 53 | └── dataset_mew 54 | └── fig 55 | └── graph 56 | └── model 57 | └── tg_graph 58 | └── raw_data 59 | ``` 60 | 61 | ## **Usage and Example** 62 | 1. Choose the data (upmc, charville, dfci) and task (classification, cox) and pass the hyperparameters. For the UPMC-HM (Hazard Modeling) task, it can be as follows: 63 | ``` 64 | python main.py \ 65 | --data upmc \ 66 | --task cox \ 67 | --shared True \ 68 | --attn_weight True \ 69 | --lr 0.001 \ 70 | --num_layers 1 \ 71 | --emb_dim 256 \ 72 | --batch_size 16 \ 73 | --drop_ratio 0.0 \ 74 | --device 0 75 | ``` 76 | 77 | 2. For others, please use: 78 | - **UPMC-BC**: `sh ./sh/upmc_bc.sh` 79 | - **UPMC-HM**: `sh ./sh/upmc_hm.sh` 80 | - **Charville-BC**: `sh ./sh/charville_bc.sh` 81 | - **Charville-HM**: `sh ./sh/charville_hm.sh` 82 | - **DFCI-Generalization**: `sh ./sh/dfci_general.sh` 83 | 84 | 85 | ## Citation 86 | 87 | ```bibtex 88 | @misc{yun2024mew, 89 | title={Mew: Multiplexed Immunofluorescence Image Analysis through an Efficient Multiplex Network}, 90 | author={Sukwon Yun and Jie Peng and Alexandro E. 
Trevino and Chanyoung Park and Tianlong Chen}, 91 | year={2024}, 92 | eprint={2407.17857}, 93 | archivePrefix={arXiv}, 94 | primaryClass={cs.CV}, 95 | url={https://arxiv.org/abs/2407.17857}, 96 | } 97 | ``` -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_add 7 | from torch_geometric.nn.inits import glorot, zeros 8 | 9 | class FeedForwardNet(nn.Module): 10 | def __init__(self, in_feats, hidden, out_feats, n_layers, dropout): 11 | super(FeedForwardNet, self).__init__() 12 | self.layers = nn.ModuleList() 13 | self.n_layers = n_layers 14 | if n_layers == 1: 15 | self.layers.append(nn.Linear(in_feats, out_feats)) 16 | else: 17 | self.layers.append(nn.Linear(in_feats, hidden)) 18 | for i in range(n_layers - 2): 19 | self.layers.append(nn.Linear(hidden, hidden)) 20 | self.layers.append(nn.Linear(hidden, out_feats)) 21 | if self.n_layers > 1: 22 | self.prelu = nn.PReLU() 23 | self.dropout = nn.Dropout(dropout) 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | gain = nn.init.calculate_gain("relu") 28 | for layer in self.layers: 29 | nn.init.xavier_uniform_(layer.weight, gain=gain) 30 | nn.init.zeros_(layer.bias) 31 | 32 | def forward(self, x): 33 | for layer_id, layer in enumerate(self.layers): 34 | x = layer(x) 35 | if layer_id < self.n_layers - 1: 36 | x = self.dropout(self.prelu(x)) 37 | return x 38 | 39 | class SIGN_v2(nn.Module): 40 | def __init__(self, num_layer, num_feat, emb_dim, ffn_layers=2, dropout=0.25): 41 | super(SIGN_v2, self).__init__() 42 | 43 | in_feats = num_feat 44 | emb_dim = emb_dim 45 | out_feats = emb_dim 46 | num_hops = num_layer + 1 47 | 48 | self.dropout = nn.Dropout(dropout) 49 | self.prelu = nn.PReLU() 50 | self.inception_ffs = nn.ModuleList() 51 | 52 | self.batch_norms = torch.nn.ModuleList() 53 | for hop in range(num_hops): 54 | self.inception_ffs.append( 55 | FeedForwardNet(in_feats, emb_dim, emb_dim, ffn_layers, dropout)) 56 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 57 | self.project = FeedForwardNet(num_hops * emb_dim, emb_dim, out_feats, 58 | ffn_layers, dropout) 59 | 60 | def forward(self, feats): 61 | feats = [feat.float() for feat in feats] 62 | hidden = [] 63 | for i, (feat, ff) in enumerate(zip(feats, self.inception_ffs)): 64 | emb = ff(feat) 65 | hidden.append(self.batch_norms[i](emb)) 66 | out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1)))) 67 | return out 68 | 69 | def reset_parameters(self): 70 | for ff in self.inception_ffs: 71 | ff.reset_parameters() 72 | self.project.reset_parameters() 73 | 74 | 75 | class SIGN_pred(torch.nn.Module): 76 | def __init__(self, 77 | num_layer=2, 78 | num_feat=38, 79 | emb_dim=256, 80 | num_additional_feat=0, 81 | num_node_tasks=15, 82 | num_graph_tasks=2, 83 | node_embedding_output="last", 84 | drop_ratio=0, 85 | graph_pooling="mean", 86 | attn_weight=False, 87 | shared=False, 88 | ): 89 | super(SIGN_pred, self).__init__() 90 | self.drop_ratio = drop_ratio 91 | self.emb_dim = emb_dim 92 | self.num_node_tasks = num_node_tasks 93 | self.num_graph_tasks = num_graph_tasks 94 | self.attn_weight = attn_weight 95 | 96 | self.sign = SIGN_v2(num_layer, num_feat, emb_dim, dropout=drop_ratio) 97 | 
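# `self.sign` encodes the Voronoi (geometric) layer and `self.sign2` the cell-type layer of the
# multiplex network; when `shared` is True the two layers reuse the same SIGN_v2 encoder,
# otherwise an independent second encoder is created below.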
self.sign2 = self.sign if shared else SIGN_v2(num_layer, num_feat, emb_dim, dropout=drop_ratio) 98 | self.leakyrelu = nn.LeakyReLU(0.3) 99 | self.attention = nn.Parameter(torch.empty(size=(emb_dim, 1))) 100 | glorot(self.attention) 101 | if self.attn_weight: 102 | self.w1 = torch.nn.Linear(emb_dim, emb_dim) 103 | self.w2 = torch.nn.Linear(emb_dim, emb_dim) 104 | torch.nn.init.xavier_uniform_(self.w1.weight.data) 105 | torch.nn.init.xavier_uniform_(self.w1.weight.data) 106 | 107 | # Different kind of graph pooling 108 | if graph_pooling == "sum": 109 | self.pool = global_add_pool 110 | elif graph_pooling == "mean": 111 | self.pool = global_mean_pool 112 | elif graph_pooling == "max": 113 | self.pool = global_max_pool 114 | elif graph_pooling == "attention": 115 | if node_embedding_output == "concat": 116 | self.pool = GlobalAttention(gate_nn=torch.nn.Linear((self.num_layer + 1) * emb_dim, 1)) 117 | else: 118 | self.pool = GlobalAttention(gate_nn=torch.nn.Linear(emb_dim, 1)) 119 | elif graph_pooling[:-1] == "set2set": 120 | set2set_iter = int(graph_pooling[-1]) 121 | if node_embedding_output == "concat": 122 | self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter) 123 | else: 124 | self.pool = Set2Set(emb_dim, set2set_iter) 125 | else: 126 | raise ValueError("Invalid graph pooling type.") 127 | 128 | # For node and graph predictions 129 | self.mult = 1 130 | if graph_pooling[:-1] == "set2set": 131 | self.mult *= 2 132 | if node_embedding_output == "concat": 133 | self.mult *= (self.num_layer + 1) 134 | 135 | node_embedding_dim = self.mult * self.emb_dim 136 | if self.num_graph_tasks > 0: 137 | self.graph_pred_module = torch.nn.Sequential( 138 | torch.nn.Linear(node_embedding_dim + num_additional_feat, node_embedding_dim), 139 | torch.nn.LeakyReLU(), 140 | torch.nn.Linear(node_embedding_dim, node_embedding_dim), 141 | torch.nn.LeakyReLU(), 142 | torch.nn.Linear(node_embedding_dim, self.num_graph_tasks)) 143 | 144 | 145 | def from_pretrained(self, model_file): 146 | original_dict = torch.load(model_file) 147 | new_dict = {key.replace('gnn.', '') if 'gnn.' 
in key else key: value for key, value in original_dict.items()} 148 | if self.num_graph_tasks > 0: 149 | graph_tasks_idx = [] 150 | for key in new_dict.keys(): 151 | if key.startswith('graph_pred_module'): 152 | idx_part = key.split('.')[1] 153 | graph_tasks_idx.append(int(idx_part)) 154 | 155 | for i in range(len(self.graph_pred_module)): 156 | if i in graph_tasks_idx: 157 | self.graph_pred_module[i].weight = torch.nn.Parameter(new_dict[f'graph_pred_module.{i}.weight']) 158 | self.graph_pred_module[i].bias = torch.nn.Parameter(new_dict[f'graph_pred_module.{i}.bias']) 159 | del new_dict[f'graph_pred_module.{i}.weight'] 160 | del new_dict[f'graph_pred_module.{i}.bias'] 161 | 162 | self.gnn.load_state_dict(new_dict) 163 | 164 | def forward(self, data, batch=None, embed=False): 165 | batch = batch if batch != None else torch.zeros(len(data[0][0]),).long() 166 | 167 | node_representation_1 = self.sign(data[0]) # geom 168 | node_representation_2 = self.sign2(data[1]) # cell_type 169 | 170 | if self.attn_weight: 171 | geom_ = self.leakyrelu(torch.mm(self.w1(node_representation_1), self.attention)) 172 | cell_type_ = self.leakyrelu(torch.mm(self.w2(node_representation_2), self.attention)) 173 | else: 174 | geom_ = self.leakyrelu(torch.mm(node_representation_1, self.attention)) 175 | cell_type_ = self.leakyrelu(torch.mm(node_representation_2, self.attention)) 176 | 177 | values = torch.softmax(torch.cat((geom_, cell_type_), dim=1), dim=1) 178 | node_representation = (values[:,0].unsqueeze(1) * node_representation_1) + (values[:,1].unsqueeze(1) * node_representation_2) 179 | 180 | return_vals = [] 181 | if self.num_graph_tasks > 0: 182 | input = self.pool(node_representation, batch.to(node_representation.device)) 183 | graph_pred = self.graph_pred_module(input) 184 | return_vals.append(graph_pred) 185 | 186 | if embed: 187 | return return_vals, values, node_representation_1, node_representation_2 188 | else: 189 | return return_vals, values 190 | 191 | # Loss functions 192 | class BinaryCrossEntropy(torch.nn.Module): 193 | """Weighted binary cross entropy loss function""" 194 | def __init__(self, **kwargs): 195 | super(BinaryCrossEntropy, self).__init__(**kwargs) 196 | self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none') 197 | 198 | def forward(self, y_pred, y, w): 199 | return (self.loss_fn(y_pred, y) * w).mean() 200 | 201 | class CrossEntropy(torch.nn.Module): 202 | """Cross entropy loss function for multi-class classification""" 203 | def __init__(self, **kwargs): 204 | super(CrossEntropy, self).__init__(**kwargs) 205 | self.loss_fn = torch.nn.CrossEntropyLoss() 206 | 207 | def forward(self, y_pred, y, w): 208 | return self.loss_fn(y_pred, y) 209 | 210 | 211 | class CoxSGDLossFn(torch.nn.Module): 212 | """Cox SGD loss function""" 213 | def __init__(self, top_n=2, regularizer_weight=0.05, **kwargs): 214 | self.top_n = top_n 215 | self.regularizer_weight = regularizer_weight 216 | super(CoxSGDLossFn, self).__init__(**kwargs) 217 | 218 | def forward(self, y_pred, length, event): 219 | assert y_pred.shape[0] == length.shape[0] == event.shape[0] 220 | n_samples = y_pred.shape[0] 221 | num_tasks = y_pred.shape[1] 222 | loss = 0 223 | 224 | for task in range(num_tasks): 225 | _length = length[:, task] 226 | _event = event[:, task] 227 | _pred = y_pred[:, task] 228 | 229 | pair_mat = (_length.reshape((1, -1)) - _length.reshape((-1, 1)) > 0) * _event.reshape((-1, 1)) 230 | if self.top_n > 0: 231 | p_with_rand = pair_mat * (1 + torch.rand_like(pair_mat)) 232 | rand_thr_ind = 
torch.argsort(p_with_rand, axis=1)[:, -(self.top_n + 1)] 233 | rand_thr = p_with_rand[(torch.arange(n_samples), rand_thr_ind)].reshape((-1, 1)) 234 | pair_mat = pair_mat * (p_with_rand > rand_thr) 235 | 236 | valid_sample_is = torch.nonzero(pair_mat.sum(1)).flatten() 237 | pair_mat[(valid_sample_is, valid_sample_is)] = 1 238 | 239 | score_diff = (_pred.reshape((1, -1)) - _pred.reshape((-1, 1))) 240 | score_diff_row_max = torch.max(score_diff, axis=1, keepdims=True).values 241 | loss_tmp = (torch.exp(score_diff - score_diff_row_max) * pair_mat).sum(1) 242 | loss += (score_diff_row_max[:, 0][valid_sample_is] + torch.log(loss_tmp[valid_sample_is])).sum() 243 | 244 | regularizer = torch.abs(pair_mat.sum(0) * _pred.flatten()).sum() 245 | loss += self.regularizer_weight * regularizer 246 | 247 | return loss -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.optim 5 | from torch.utils.data import Dataset 6 | from torch.utils.data import DataLoader 7 | 8 | from src.inference import collect_predict_for_all_nodes 9 | from tqdm import trange 10 | from copy import deepcopy 11 | 12 | class CombinedDataset(Dataset): 13 | def __init__(self, graph_dataset, x_dataset): 14 | self.graph_dataset = graph_dataset 15 | self.x_dataset = x_dataset 16 | 17 | def __getitem__(self, idx): 18 | graph_y, graph_w = self.graph_dataset[idx] 19 | geom_x, cell_type_x = self.x_dataset[idx] 20 | geom_x_tensor = torch.stack(geom_x) 21 | cell_type_x_tensor = torch.stack(cell_type_x) 22 | batch_length = geom_x[0].shape[0] 23 | 24 | return graph_y, graph_w, geom_x_tensor, cell_type_x_tensor, batch_length 25 | 26 | def __len__(self): 27 | return len(self.graph_dataset) 28 | 29 | def custom_collate(batch): 30 | graph_y_list, graph_w_list, geom_x_list, cell_type_x_list, batch_length_list = zip(*batch) 31 | batch_idx_list = torch.tensor(sum([[i]*batch_length_list[i] for i in range(len(batch))], [])) 32 | graph_y = torch.stack(graph_y_list).squeeze(2) 33 | graph_w = torch.stack(graph_w_list).squeeze(2) 34 | geom_x_layers = torch.cat(geom_x_list, dim=1) 35 | cell_type_x_layers = torch.cat(cell_type_x_list, dim=1) 36 | 37 | return batch_idx_list, graph_y, graph_w, geom_x_layers, cell_type_x_layers 38 | 39 | def train_full_graph(model, 40 | dataset_yw, 41 | xs, 42 | device, 43 | filename, 44 | fold, 45 | logger, 46 | graph_task_loss_fn=None, 47 | train_inds=None, 48 | valid_inds=None, 49 | test_inds=None, 50 | early_stop=100, 51 | num_iterations=1e5, 52 | evaluate_freq=1e4, 53 | evaluate_fn=[], 54 | evaluate_on_train=True, 55 | batch_size=64, 56 | lr=0.001, 57 | graph_loss_weight=1., 58 | task='classification', 59 | name=None, 60 | save_model=False, 61 | **kwargs): 62 | 63 | if train_inds is None: 64 | train_inds = np.arange(len(xs)) 65 | 66 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 67 | node_losses = [] 68 | graph_losses = [] 69 | 70 | train_dataset = [dataset_yw[i] for i in train_inds] 71 | train_xs = [xs[i] for i in train_inds] 72 | 73 | train_dataset_combined = CombinedDataset(train_dataset, train_xs) 74 | train_loader = DataLoader(train_dataset_combined, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=custom_collate) 75 | 76 | best_score = 0 77 | cnt = 0 78 | 79 | for i_iter in trange(int(num_iterations)): 80 | model.train() 81 | model.zero_grad() 82 | 83 | batch, graph_y, graph_w, geom_x, cell_type_x = 
next(iter(train_loader)) 84 | loss = 0. 85 | geom_x = [_x.to(device) for _x in geom_x] 86 | cell_type_x = [_x.to(device) for _x in cell_type_x] 87 | 88 | res, _, z1, z2 = model([geom_x, cell_type_x], batch, embed=True) 89 | if res[0].sum().isnan(): 90 | return [-1] * model.num_graph_tasks 91 | 92 | if model.num_graph_tasks > 0: 93 | assert graph_task_loss_fn is not None, \ 94 | "Please specify `graph_task_loss_fn` in the training kwargs" 95 | 96 | graph_pred = res[-1] 97 | graph_y, graph_w = graph_y.float().to(device), graph_w.float().to(device) 98 | 99 | if model.num_graph_tasks > 1: 100 | graph_y, graph_w = graph_y.squeeze(1), graph_w.squeeze(1) 101 | idx = graph_y[:,0] == graph_y[:,0] 102 | graph_loss = graph_task_loss_fn(graph_pred[idx], graph_y[idx], graph_w[idx]) 103 | loss += graph_loss * graph_loss_weight 104 | 105 | graph_losses.append(graph_loss.to('cpu').data.item()) 106 | 107 | loss.backward() 108 | optimizer.step() 109 | 110 | if i_iter > 0 and (i_iter+1) % evaluate_freq == 0: 111 | model.eval() 112 | summary_str = "Finished iterations %d" % (i_iter+1) 113 | if len(node_losses) > evaluate_freq: 114 | summary_str += ", node loss %.2f" % np.mean(node_losses[-evaluate_freq:]) 115 | if len(graph_losses) > evaluate_freq: 116 | summary_str += ", graph loss %.2f" % np.mean(graph_losses[-evaluate_freq:]) 117 | print(summary_str) 118 | 119 | fn = evaluate_fn[0] 120 | result = fn(model, 121 | dataset_yw, 122 | xs, 123 | device, 124 | filename, 125 | fold, 126 | logger, 127 | train_inds=train_inds if evaluate_on_train else None, 128 | valid_inds=valid_inds, 129 | batch_size=batch_size, 130 | task=task, 131 | **kwargs) 132 | 133 | _txt = ",".join([s if isinstance(s, str) else ("%.3f" % s) for s in result]) 134 | logger.info(f'[{name}-{task}][Fold {fold}][Epoch {i_iter+1}/{num_iterations}]-{_txt} \t Filename: {filename}') 135 | if cnt > 0: 136 | print(best_txt_test) 137 | valid_score = np.mean(result[-model.num_graph_tasks:]) 138 | 139 | if best_score < valid_score: 140 | cnt +=1 141 | print(f'Iter: {i_iter+1}, Reached best valid score! 
Start testing ...') 142 | best_score = valid_score 143 | 144 | # Test on test_inds 145 | if test_inds != None: 146 | result_test = fn(model, 147 | dataset_yw, 148 | xs, 149 | device, 150 | filename, 151 | fold, 152 | logger, 153 | train_inds=None, 154 | valid_inds=test_inds, 155 | batch_size=batch_size, 156 | task=task, 157 | save=True, 158 | **kwargs) 159 | 160 | _txt_test = ",".join([s if isinstance(s, str) else ("%.3f" % s) for s in result_test]) 161 | best_txt_test = f'[{name}-{task}][Fold {fold}][(Best Epoch) {i_iter+1}/{num_iterations}][(Test)]-{_txt_test} \t Filename: {filename}' 162 | logger.info(best_txt_test) 163 | 164 | best_epoch = i_iter 165 | 166 | if save_model: 167 | fn_save = evaluate_fn[1] 168 | fn_save(model, 169 | name, 170 | task, 171 | fold) 172 | 173 | else: 174 | _txt_test = _txt 175 | 176 | elif i_iter - best_epoch >= early_stop: 177 | print('Early Stop!') 178 | break 179 | 180 | else: 181 | pass 182 | 183 | return model, _txt_test 184 | 185 | 186 | def evaluate_by_full_graph(model, 187 | dataset_yw, 188 | xs, 189 | device, 190 | filename, 191 | fold, 192 | logger, 193 | train_inds=None, 194 | valid_inds=None, 195 | batch_size=64, 196 | shuffle=True, 197 | full_graph_graph_task_evaluate_fn=None, 198 | score_file=None, 199 | save=False, 200 | task='classification', 201 | **kwargs): 202 | 203 | score_row = ["Eval-Full-Graph"] 204 | if train_inds is not None: 205 | score_row.append("Train") 206 | node_preds, graph_preds, voronoi_attn, cell_type_attn = collect_predict_for_all_nodes( 207 | model, 208 | xs, 209 | device, 210 | inds=train_inds, 211 | batch_size=batch_size, 212 | shuffle=shuffle, 213 | **kwargs) 214 | 215 | if model.num_graph_tasks > 0: 216 | # Evaluate graph-level predictions 217 | assert full_graph_graph_task_evaluate_fn is not None, \ 218 | "Please specify `full_graph_graph_task_evaluate_fn` in the training kwargs" 219 | score_row.append("attn-score") 220 | score_row.extend([np.nanmean(np.array(list(voronoi_attn.values()))), np.nanmean(np.array(list(cell_type_attn.values())))]) 221 | score_row.append("graph-score") 222 | score_row.extend(full_graph_graph_task_evaluate_fn(dataset_yw, graph_preds, task, print_res=False)) 223 | 224 | if valid_inds is not None: 225 | score_row.append("Valid") 226 | node_preds, graph_preds, voronoi_attn, cell_type_attn = collect_predict_for_all_nodes( 227 | model, 228 | xs, 229 | device, 230 | inds=valid_inds, 231 | batch_size=batch_size, 232 | shuffle=shuffle, 233 | **kwargs) 234 | 235 | region_ids = [] 236 | graph_pred_list = [] 237 | graph_ys = [] 238 | graph_ws = [] 239 | 240 | for _i, (i, pred) in enumerate(graph_preds.items()): 241 | region_id = valid_inds[_i] 242 | graph_y = dataset_yw[i][0] 243 | graph_w = dataset_yw[i][1] 244 | graph_pred_ = np.nanmean(pred, 0) 245 | 246 | region_ids.append(region_id) 247 | graph_pred_list.append(graph_pred_) 248 | graph_ys.append(graph_y) 249 | graph_ws.append(graph_w) 250 | 251 | graph_ys = np.stack(graph_ys).squeeze(1) 252 | graph_ws = np.stack(graph_ws).squeeze(1) 253 | 254 | if model.num_graph_tasks > 0: 255 | assert full_graph_graph_task_evaluate_fn is not None, \ 256 | "Please specify `full_graph_graph_task_evaluate_fn` in the training kwargs" 257 | score_row.append("attn-score") 258 | score_row.extend([np.nanmean(np.array(list(voronoi_attn.values()))), np.nanmean(np.array(list(cell_type_attn.values())))]) 259 | score_row.append("graph-score") 260 | score_row.extend(full_graph_graph_task_evaluate_fn(dataset_yw, graph_preds, task, print_res=False)) 261 | 262 | if score_file is 
not None: 263 | with open(score_file, 'a') as f: 264 | f.write(",".join([s if isinstance(s, str) else ("%.3f" % s) for s in score_row]) + '\n') 265 | 266 | return score_row 267 | 268 | 269 | def save_model_weight(model, 270 | name, 271 | task, 272 | fold): 273 | os.makedirs('./ckpts/', exist_ok=True) 274 | model_tmp = deepcopy(model).cpu() 275 | torch.save(model_tmp.state_dict(), f'./ckpts/Mew_{name}_{task}_{fold}.pt') -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import torch 4 | import os 5 | import sys 6 | import random 7 | import logging 8 | import numpy as np 9 | import pandas as pd 10 | 11 | MB = 1024 ** 2 12 | GB = 1024 ** 3 13 | 14 | EDGE_TYPES = { 15 | "neighbor": 0, 16 | "distant": 1, 17 | "self": 2, 18 | } 19 | 20 | # Metadata for the example dataset 21 | BIOMARKERS_UPMC = [ 22 | "CD11b", "CD14", "CD15", "CD163", "CD20", "CD21", "CD31", "CD34", "CD3e", 23 | "CD4", "CD45", "CD45RA", "CD45RO", "CD68", "CD8", "CollagenIV", "HLA-DR", 24 | "Ki67", "PanCK", "Podoplanin", "Vimentin", "aSMA", 25 | ] 26 | 27 | CELL_TYPE_MAPPING_UPMC = { 28 | 'APC': 0, 29 | 'B cell': 1, 30 | 'CD4 T cell': 2, 31 | 'CD8 T cell': 3, 32 | 'Granulocyte': 4, 33 | 'Lymph vessel': 5, 34 | 'Macrophage': 6, 35 | 'Naive immune cell': 7, 36 | 'Stromal / Fibroblast': 8, 37 | 'Tumor': 9, 38 | 'Tumor (CD15+)': 10, 39 | 'Tumor (CD20+)': 11, 40 | 'Tumor (CD21+)': 12, 41 | 'Tumor (Ki67+)': 13, 42 | 'Tumor (Podo+)': 14, 43 | 'Vessel': 15, 44 | 'Unassigned': 16, 45 | } 46 | 47 | CELL_TYPE_FREQ_UPMC = { 48 | 'APC': 0.038220815854819415, 49 | 'B cell': 0.06635091324932002, 50 | 'CD4 T cell': 0.09489001514723677, 51 | 'CD8 T cell': 0.07824503590797544, 52 | 'Granulocyte': 0.026886102677111563, 53 | 'Lymph vessel': 0.006429085023448621, 54 | 'Macrophage': 0.10251942892685563, 55 | 'Naive immune cell': 0.033537398925429215, 56 | 'Stromal / Fibroblast': 0.07692583870182068, 57 | 'Tumor': 0.10921293560435145, 58 | 'Tumor (CD15+)': 0.06106975782857908, 59 | 'Tumor (CD20+)': 0.02098925720318548, 60 | 'Tumor (CD21+)': 0.053892044158901406, 61 | 'Tumor (Ki67+)': 0.13373768013421947, 62 | 'Tumor (Podo+)': 0.06276108605978743, 63 | 'Vessel': 0.034332604596958326, 64 | 'Unassigned': 0.001, 65 | } 66 | 67 | 68 | def setup_logger(save_dir, text, filename = 'log.txt'): 69 | os.makedirs(save_dir, exist_ok=True) 70 | logger = logging.getLogger(text) 71 | logger.setLevel(4) 72 | ch = logging.StreamHandler(stream=sys.stdout) 73 | ch.setLevel(logging.DEBUG) 74 | formatter = logging.Formatter("%(message)s") 75 | ch.setFormatter(formatter) 76 | logger.addHandler(ch) 77 | if save_dir: 78 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 79 | fh.setLevel(logging.DEBUG) 80 | fh.setFormatter(formatter) 81 | logger.addHandler(fh) 82 | logger.info("======================================================================================") 83 | 84 | return logger 85 | 86 | def seed_everything(seed=0): 87 | random.seed(seed) 88 | os.environ['PYTHONHASHSEED'] = str(seed) 89 | np.random.seed(seed) 90 | torch.manual_seed(seed) 91 | torch.cuda.manual_seed(seed) 92 | torch.backends.cudnn.deterministic = True 93 | 94 | def get_initials(string_list): 95 | initials = "".join(word[0] for word in string_list if word) # Ensure the word is not empty 96 | return initials.lower() 97 | 98 | def generate_charville_label(df): 99 | ## recurrence_interval & recurrence_event 100 | # Convert 
'surgery_date' to datetime & Convert 'first_recurrence_date' to datetime, handling 'NONE' cases 101 | df['surgery_date'] = pd.to_datetime(df['surgery_date']) 102 | df['first_recurrence_date'] = pd.to_datetime(df['first_recurrence_date'], errors='coerce') 103 | df['last_contact_date'] = pd.to_datetime(df['last_contact_date'], errors='coerce') 104 | 105 | # Calculate the recurrence interval as the number of days from 'surgery_date' to 'first_recurrence_date' 106 | df['recurrence_interval'] = (df['first_recurrence_date'] - df['surgery_date']).dt.days 107 | 108 | # Recurrence event: 1 if there is a recurrence, 0 if 'NONE/DISEASE FREE' or 'NEVER DISEASE FREE' 109 | df['recurrence_event'] = df['type_of_first_recurrence'].apply(lambda x: 0 if 'FREE' in x else 1) 110 | df.loc[df['first_recurrence_date'].isna(), 'recurrence_interval'] = (df['last_contact_date'] - df['surgery_date']).dt.days 111 | 112 | df.loc[df['first_recurrence_date'].isna(), 'recurrence_interval'] = (df['last_contact_date'] - df['surgery_date']).dt.days 113 | 114 | ## survival_legnth & survival_status 115 | def calculate_length(row): 116 | # If the 'first_recurrence_date' is not missing, use it to calculate the length 117 | if pd.notna(row['first_recurrence_date']): 118 | return (row['first_recurrence_date'] - row['surgery_date']).days / 30.4375 # Average days per month 119 | # If 'first_recurrence_date' is missing, but 'last_contact_date' is not, use 'last_contact_date' 120 | elif pd.notna(row['last_contact_date']): 121 | return (row['last_contact_date'] - row['surgery_date']).days / 30.4375 122 | # If both are missing, return NaN 123 | else: 124 | return pd.NA 125 | 126 | # Apply the function to the rows where 'length_of_disease_free_survival' is missing 127 | df['survival_legnth'] = df.apply(calculate_length, axis=1) 128 | df['survival_legnth'] = np.round(pd.to_numeric(df['survival_legnth'], errors='coerce')) 129 | 130 | # Survival event: 1 if patient survives, 0 if dead 131 | df['survival_status'] = df['alive_or_deceased'].apply(lambda x: 1 if 'Dead' in x else 0) 132 | 133 | return df 134 | 135 | def preprocess_generalization(root='./dataset'): 136 | upmc_biomarker_cols = set(pd.read_csv(f'{root}/upmc_data/raw_data/UPMC_c001_v001_r001_reg001.expression.csv').columns) 137 | dfci_biomarker_cols = set(pd.read_csv(f'{root}/dfci_data/raw_data/s271_c001_v001_r001_reg001.expression.csv').columns) 138 | 139 | common_biomarker_cols_list = list(upmc_biomarker_cols & dfci_biomarker_cols) 140 | common_cell_type_dict = { 141 | 'APC': 'APC', 142 | 'Dendritic cell': 'APC', 143 | 'APC/macrophage': 'APC', 144 | 'B cell': 'B cell', 145 | 'CD4 T cell': 'CD4 T cell', 146 | 'T cell (CD45RO+/FoxP3+/ICOS+)': 'CD4 T cell', 147 | 'CD4 T cell (ICOS+/FoxP3+)': 'CD4 T cell', 148 | 'CD8 T cell': 'CD8 T cell', 149 | 'T cell (GranzymeB+/LAG3+)': 'CD8 T cell', 150 | 'Granulocyte': 'Granulocyte', 151 | 'Lymph vessel': 'Vessel cell', 152 | 'Macrophage': 'Macrophage', 153 | 'Naive immune cell': 'Naive immune cell', 154 | 'Naive B cell': 'Naive immune cell', 155 | 'Naive lymphocyte (CD45RA+/CD38+)': 'Naive immune cell', 156 | 'Stromal / Fibroblast': 'Stromal cell', 157 | 'Stroma': 'Stromal cell', 158 | 'Tumor': 'Tumor', 159 | 'Tumor (CD15+)': 'Tumor', 160 | 'Tumor (CD20+)': 'Tumor', 161 | 'Tumor (CD21+)': 'Tumor', 162 | 'Tumor (Ki67+)': 'Tumor (Ki67+)', 163 | 'Tumor (Podo+)': 'Tumor', 164 | 'Tumor (PanCK hi)': 'Tumor', 165 | 'Tumor (PanCK low)': 'Tumor', 166 | 'Unassigned': 'Other cell', 167 | 'NK cell': 'Other cell', 168 | 'Mast cell': 'Other cell', 169 | 
'Unknown (TCF1+)': 'Other cell', 170 | 'Unclassified': 'Other cell', 171 | 'Vessel': 'Vessel cell', 172 | 'Vessel endothelium': 'Vessel cell', 173 | } 174 | 175 | return common_cell_type_dict, common_biomarker_cols_list 176 | 177 | def get_cell_type_metadata(nx_graph_files): 178 | """Find all unique cell types from a list of cellular graphs 179 | 180 | Args: 181 | nx_graph_files (list/str): path/list of paths to cellular graph files (gpickle) 182 | 183 | Returns: 184 | cell_type_mapping (dict): mapping of unique cell types to integer indices 185 | cell_type_freq (dict): mapping of unique cell types to their frequency 186 | """ 187 | if isinstance(nx_graph_files, str): 188 | nx_graph_files = [nx_graph_files] 189 | cell_type_mapping = {} 190 | for g_f in nx_graph_files: 191 | try: 192 | G = pickle.load(open(g_f, 'rb')) 193 | except: 194 | print('Error detected! in file:', g_f) 195 | assert 'cell_type' in G['layer_1'].nodes[0] 196 | for n in G['layer_1'].nodes: 197 | ct = G['layer_1'].nodes[n]['cell_type'] 198 | if ct not in cell_type_mapping: 199 | cell_type_mapping[ct] = 0 200 | cell_type_mapping[ct] += 1 201 | unique_cell_types = sorted(cell_type_mapping.keys()) 202 | unique_cell_types_ct = [cell_type_mapping[ct] for ct in unique_cell_types] 203 | unique_cell_type_freq = [count / sum(unique_cell_types_ct) for count in unique_cell_types_ct] 204 | cell_type_mapping = {ct: i for i, ct in enumerate(unique_cell_types)} 205 | cell_type_freq = dict(zip(unique_cell_types, unique_cell_type_freq)) 206 | return cell_type_mapping, cell_type_freq 207 | 208 | 209 | def get_biomarker_metadata(nx_graph_files): 210 | """Load all biomarkers from a list of cellular graphs 211 | 212 | Args: 213 | nx_graph_files (list/str): path/list of paths to cellular graph files (gpickle) 214 | 215 | Returns: 216 | shared_bms (list): list of biomarkers shared by all cells (intersect) 217 | all_bms (list): list of all biomarkers (union) 218 | """ 219 | if isinstance(nx_graph_files, str): 220 | nx_graph_files = [nx_graph_files] 221 | all_bms = set() 222 | shared_bms = None 223 | for g_f in nx_graph_files: 224 | G = pickle.load(open(g_f, 'rb')) 225 | for n in G['layer_1'].nodes: 226 | bms = sorted(G['layer_1'].nodes[n]["biomarker_expression"].keys()) 227 | for bm in bms: 228 | all_bms.add(bm) 229 | valid_bms = [ 230 | bm for bm in bms if G['layer_1'].nodes[n]["biomarker_expression"][bm] == G['layer_1'].nodes[n]["biomarker_expression"][bm]] 231 | shared_bms = set(valid_bms) if shared_bms is None else shared_bms & set(valid_bms) 232 | shared_bms = sorted(shared_bms) 233 | all_bms = sorted(all_bms) 234 | return shared_bms, all_bms 235 | 236 | 237 | def get_graph_splits(dataset, 238 | split='random', 239 | cv_k=5, 240 | seed=None, 241 | fold_mapping=None): 242 | """ Define train/valid split 243 | 244 | Args: 245 | dataset (CellularGraphDataset): dataset to split 246 | split (str): split method, one of 'random', 'fold' 247 | cv_k (int): number of splits for random split 248 | seed (int): random seed 249 | fold_mapping (dict): mapping of region ids to folds, 250 | fold could be coverslip, patient, etc. 
251 | 252 | Returns: 253 | split_inds (list): fold indices for each region in the dataset 254 | """ 255 | splits = {} 256 | region_ids = set([dataset.get_full(i).region_id for i in range(dataset.N)]) 257 | _region_ids = sorted(region_ids) 258 | if split == 'random': 259 | if seed is not None: 260 | np.random.seed(seed) 261 | if fold_mapping is None: 262 | fold_mapping = {region_id: region_id for region_id in _region_ids} 263 | # `_ids` could be sample ids / patient ids / certain properties 264 | _folds = sorted(set(list(fold_mapping.values()))) 265 | np.random.shuffle(_folds) 266 | cv_shard_size = len(_folds) / cv_k 267 | for i, region_id in enumerate(_region_ids): 268 | splits[region_id] = _folds.index(fold_mapping[region_id]) // cv_shard_size 269 | elif split == 'fold': 270 | # Split into folds, one fold per group 271 | assert fold_mapping is not None 272 | _folds = sorted(set(list(fold_mapping.values()))) 273 | for i, region_id in enumerate(_region_ids): 274 | splits[region_id] = _folds.index(fold_mapping[region_id]) 275 | else: 276 | raise ValueError("split mode not recognized") 277 | 278 | split_inds = [] 279 | for i in range(dataset.N): 280 | split_inds.append(splits[dataset.get_full(i).region_id]) 281 | return split_inds 282 | 283 | def get_memory_usage(gpu, print_info=False): 284 | """Get accurate gpu memory usage by querying torch runtime""" 285 | allocated = torch.cuda.memory_allocated(gpu) 286 | reserved = torch.cuda.memory_reserved(gpu) 287 | if print_info: 288 | print("allocated: %.2f MB" % (allocated / 1024 / 1024), flush=True) 289 | print("reserved: %.2f MB" % (reserved / 1024 / 1024), flush=True) 290 | return allocated 291 | 292 | def compute_tensor_bytes(tensors): 293 | """Compute the bytes used by a list of tensors""" 294 | if not isinstance(tensors, (list, tuple)): 295 | tensors = [tensors] 296 | 297 | ret = 0 298 | for x in tensors: 299 | if x.dtype in [torch.int64, torch.long]: 300 | ret += np.prod(x.size()) * 8 301 | if x.dtype in [torch.float32, torch.int, torch.int32]: 302 | ret += np.prod(x.size()) * 4 303 | elif x.dtype in [torch.bfloat16, torch.float16, torch.int16]: 304 | ret += np.prod(x.size()) * 2 305 | elif x.dtype in [torch.int8]: 306 | ret += np.prod(x.size()) 307 | else: 308 | print(x.dtype) 309 | raise ValueError() 310 | return ret 311 | 312 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import src 4 | from src.utils import setup_logger, seed_everything, preprocess_generalization 5 | import numpy as np 6 | import datetime 7 | import argparse 8 | import time 9 | import re 10 | from src.precomputing import PrecomputingBase 11 | import gc 12 | 13 | torch.set_num_threads(4) 14 | 15 | def str2bool(s): 16 | if s not in {'False', 'True', 'false', 'true'}: 17 | raise ValueError('Not a valid boolean string') 18 | return (s == 'True') or (s == 'true') 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Mew') 22 | parser.add_argument('--device', type=int, default=0) 23 | parser.add_argument('--data', type=str, default='charville') # upmc, dfci, charville 24 | parser.add_argument('--task', type=str, default='classification') # classification, cox 25 | parser.add_argument('--use_node_features', metavar='S', type=str, nargs='+', default=['biomarker_expression', 'SIZE']) 26 | parser.add_argument('--shared', type=str2bool, default=False) # True, False 27 | 
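# `--shared` ties the two SIGN encoders so the Voronoi and cell-type layers share weights;
# `--attn_weight` adds the learnable projections (w1/w2) applied before the layer-attention scores.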
parser.add_argument('--attn_weight', type=str2bool, default=False) # True, False 28 | parser.add_argument('--lr', type=float, default=0.0001) # 0.1, 0.01, 0.001, 0.0001 29 | parser.add_argument('--num_layers', type=int, default=4) # 1, 2, 3, 4 30 | parser.add_argument('--emb_dim', type=int, default=512) # 64, 128, 256, 512 31 | parser.add_argument('--batch_size', type=int, default=16) # 16, 32 32 | parser.add_argument('--drop_ratio', type=float, default=0.5) # 0.0, 0.25, 0.5 33 | parser.add_argument('--pool', type=str, default='mean') 34 | parser.add_argument('--num_epochs', type=int, default=1000) 35 | parser.add_argument('--eval_epoch', type=int, default=100) 36 | parser.add_argument('--early_stop', type=int, default=300) 37 | parser.add_argument('--seed', type=int, default=0) 38 | parser.add_argument('--save_model', type=str2bool, default=True) 39 | return parser.parse_known_args() 40 | 41 | def main(): 42 | args, _ = parse_args() 43 | filename = f'Mew.txt' 44 | if args.data == 'dfci': 45 | args.task = 'classification' 46 | 47 | os.makedirs(f'./log/{args.data}/{args.task}', exist_ok=True) 48 | logger = setup_logger('./', '-', f'./log/{args.data}/{args.task}/{filename}') 49 | logger.info(datetime.datetime.now()) 50 | seed_everything(args.seed) 51 | 52 | print(f'Data: {args.data}, Task: {args.task}') 53 | 54 | # Settings 55 | data = args.data # 'upmc', 'charville' 56 | print(args.use_node_features) 57 | use_node_features = args.use_node_features 58 | device = f'cuda:{args.device}' 59 | 60 | data_root = './dataset/'+data + "_data" if data in ['upmc', 'charville'] else './dataset/general_data' 61 | print(data_root) 62 | raw_data_root = f"{data_root}/raw_data" 63 | dataset_root = f"{data_root}/dataset_mew" 64 | graph_label_file = f"{data_root}/{data}_labels.csv" 65 | 66 | if data == "upmc": 67 | _common_cell_type_dict, _common_biomarker_list = None, None 68 | if args.task == 'classification': 69 | graph_tasks = ['primary_outcome', 'recurred', 'hpvstatus'] 70 | 71 | elif args.task == 'cox': 72 | graph_tasks = [['survival_day', 'survival_status']] 73 | 74 | elif data == "charville": 75 | _common_cell_type_dict, _common_biomarker_list = None, None 76 | if args.task == 'classification': 77 | graph_tasks = ['primary_outcome', 'recurrence'] 78 | 79 | elif args.task == 'cox': 80 | graph_tasks = [['survival_legnth', 'survival_status'], ['recurrence_interval', 'recurrence_event']] 81 | 82 | else: # 'dfci' 83 | _common_cell_type_dict, _common_biomarker_list = preprocess_generalization('./dataset') 84 | graph_label_file1 = f"./dataset/upmc_data/upmc_labels.csv" 85 | graph_label_file2 = f"./dataset/dfci_data/dfci_labels.csv" 86 | graph_tasks1 = ['primary_outcome'] # from UPMC 87 | graph_tasks2 = ['pTR_label'] # from DFCI 88 | graph_tasks = graph_tasks2 89 | 90 | # Generate cellular graphs from raw inputs 91 | nx_graph_root = os.path.join(dataset_root, "graph") 92 | fig_save_root = os.path.join(dataset_root, "fig") 93 | model_save_root = os.path.join(dataset_root, "model") 94 | 95 | region_ids = set([f.split('.')[0] for f in os.listdir(raw_data_root)]) 96 | os.makedirs(nx_graph_root, exist_ok=True) 97 | os.makedirs(fig_save_root, exist_ok=True) 98 | os.makedirs(model_save_root, exist_ok=True) 99 | 100 | # Save graph generated from each region 101 | for region_id in region_ids: 102 | graph_output = os.path.join(nx_graph_root, "%s.gpkl" % region_id) 103 | if not os.path.exists(graph_output): 104 | print("Processing %s" % region_id) 105 | _voronoi_file=os.path.join(raw_data_root, "%s.json" % region_id) if 
data == 'upmc' else None 106 | G = src.construct_graph_for_region( 107 | region_id, 108 | cell_coords_file=os.path.join(raw_data_root, "%s.cell_data.csv" % region_id), 109 | cell_types_file=os.path.join(raw_data_root, "%s.cell_types.csv" % region_id), 110 | cell_biomarker_expression_file=os.path.join(raw_data_root, "%s.expression.csv" % region_id), 111 | cell_features_file=os.path.join(raw_data_root, "%s.cell_features.csv" % region_id), 112 | voronoi_file=_voronoi_file, 113 | graph_output=graph_output, 114 | voronoi_polygon_img_output=None, 115 | graph_img_output=None, 116 | common_cell_type_dict=_common_cell_type_dict, 117 | common_biomarker_list=_common_biomarker_list, 118 | figsize=10) 119 | 120 | # Define Cellular Graph Dataset 121 | dataset_kwargs = { 122 | 'raw_folder_name': 'graph', 123 | 'processed_folder_name': 'tg_graph', 124 | 'node_features': ["cell_type", "SIZE", "biomarker_expression", "neighborhood_composition", "center_coord"], 125 | 'edge_features': ["edge_type", "distance"], 126 | 'cell_type_mapping': None, 127 | 'cell_type_freq': None, 128 | 'biomarkers': None 129 | } 130 | feature_kwargs = { 131 | "biomarker_expression_process_method": "linear", 132 | "biomarker_expression_lower_bound": -2, 133 | "biomarker_expression_upper_bound": 3, 134 | "neighborhood_size": 10, 135 | } 136 | dataset_kwargs.update(feature_kwargs) 137 | dataset = src.CellularGraphDataset(dataset_root, **dataset_kwargs) 138 | 139 | # Define Transformers 140 | transformers = [ 141 | src.FeatureMask(dataset, 142 | use_center_node_features=use_node_features, 143 | use_neighbor_node_features=use_node_features) 144 | ] 145 | if data in ['upmc', 'charville']: 146 | transformers.append(src.AddGraphLabel(graph_label_file, tasks=graph_tasks)) 147 | else: 148 | transformers.append(src.AddTwoGraphLabel([graph_label_file1, graph_label_file2], tasks=[graph_tasks1, graph_tasks2])) 149 | 150 | dataset.set_transforms(transformers) 151 | 152 | # Precomputing & Label Extracting 153 | region_ids = [dataset.get_full(i).region_id for i in range(dataset.N)] 154 | coverslip_ids = [r_id.split('_')[1] for r_id in region_ids] 155 | 156 | num_feat = dataset[0]['cell'].x.shape[1] 157 | precomputer = PrecomputingBase(args.num_layers, device) 158 | sign_xs = precomputer.precompute(dataset) 159 | dataset_yw = [[dataset[i].graph_y, dataset[i].graph_w] for i in range(len(dataset))] 160 | del dataset 161 | gc.collect() 162 | 163 | # Define train/valid split 164 | if data == "upmc": 165 | fold0_coverslips = {'train': ['c001', 'c002', 'c004', 'c006'], 'val': ['c007'], 'test': ['c003', 'c005']} 166 | fold1_coverslips = {'train': ['c002', 'c003', 'c005', 'c007'], 'val': ['c004'], 'test': ['c001', 'c006']} 167 | fold2_coverslips = {'train': ['c001', 'c003', 'c004', 'c006'], 'val': ['c005'], 'test': ['c002', 'c007']} 168 | elif data == "charville": 169 | fold0_coverslips = {'train': ['c001', 'c002'], 'val': ['c003'], 'test': ['c004']} 170 | fold1_coverslips = {'train': ['c003', 'c004'], 'val': ['c002'], 'test': ['c001']} 171 | fold2_coverslips = {'train': ['c002', 'c004'], 'val': ['c001'], 'test': ['c003']} 172 | 173 | if data in ['upmc', 'charville']: 174 | split_indices = {} 175 | split_indices[0] = fold0_coverslips 176 | split_indices[1] = fold1_coverslips 177 | split_indices[2] = fold2_coverslips 178 | fold_list = [0,1,2] 179 | else: 180 | fold_list = [0] 181 | 182 | fold_results = [] 183 | for fold in fold_list: 184 | # Define Model kwargs 185 | model_kwargs = { 186 | 'num_layer': args.num_layers, 187 | 'num_feat': num_feat, 188 | 
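# `num_feat` is read from the precomputed node-feature matrix; `num_graph_tasks` below matches the label columns chosen for this dataset/task.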
'emb_dim': args.emb_dim, 189 | 'num_node_tasks': 0, 190 | 'num_graph_tasks': len(graph_tasks), 191 | 'node_embedding_output': 'last', 192 | 'drop_ratio': args.drop_ratio, 193 | 'graph_pooling': args.pool, 194 | 'attn_weight': args.attn_weight, 195 | 'shared': args.shared, 196 | } 197 | # Define Train and Evaluate kwargs 198 | train_kwargs = { 199 | 'batch_size': args.batch_size, 200 | 'lr': args.lr, 201 | 'graph_loss_weight': 1.0, 202 | 'num_iterations': args.num_epochs, 203 | 'early_stop': args.early_stop, 204 | 'node_task_loss_fn': None, 205 | 'graph_task_loss_fn': src.models.BinaryCrossEntropy() if args.task == 'classification' else src.models.CoxSGDLossFn(), 206 | 'evaluate_fn': [src.train.evaluate_by_full_graph, src.train.save_model_weight], 207 | 'evaluate_freq': args.eval_epoch, 208 | } 209 | 210 | evaluate_kwargs = { 211 | 'shuffle': True, 212 | 'full_graph_graph_task_evaluate_fn': src.inference.full_graph_graph_classification_evaluate_fn, 213 | 'score_file': os.path.join(model_save_root, 'Mew-%s-%s-%d.txt' % (graph_tasks[0], '_'.join(use_node_features), fold)), 214 | 'model_folder': os.path.join(model_save_root, 'Mew'), 215 | } 216 | 217 | train_kwargs.update(evaluate_kwargs) 218 | os.makedirs(evaluate_kwargs['model_folder'], exist_ok=True) 219 | 220 | train_inds = [] 221 | valid_inds = [] 222 | test_inds = [] 223 | 224 | if data in ['upmc', 'charville']: 225 | for i, cs_id in enumerate(coverslip_ids): 226 | if cs_id in split_indices[fold]['train']: 227 | train_inds.append(i) 228 | elif cs_id in split_indices[fold]['val']: 229 | valid_inds.append(i) 230 | else: 231 | test_inds.append(i) 232 | else: 233 | for i, region_id in enumerate(region_ids): 234 | if 'UPMC' in region_id: 235 | if 'c007' in region_id: 236 | valid_inds.append(i) 237 | else: 238 | train_inds.append(i) 239 | else: 240 | test_inds.append(i) 241 | 242 | start_time = time.time() 243 | model = src.models.SIGN_pred(**model_kwargs) 244 | model = model.to(device) 245 | 246 | model, fold_result = src.train.train_full_graph( 247 | model, dataset_yw, sign_xs, device, filename, fold, logger, 248 | train_inds=train_inds, valid_inds=valid_inds, test_inds=test_inds, task=args.task, name=data, save_model=args.save_model, **train_kwargs) 249 | 250 | if fold_result == -1: 251 | print(f'The learning rate: {args.lr} is too high, causing numerical instability. 
Please try using a lower learning rate') 252 | break 253 | 254 | performance_list = re.findall(r"[-+]?\d*\.\d+|\d+", fold_result) 255 | fold_results.append([float(performance) for performance in performance_list[-len(graph_tasks):]]) 256 | 257 | if fold == fold_list[-1]: 258 | averages = [round(sum(pair)/len(pair), 3) for pair in zip(*fold_results)] 259 | std_devs = [round(np.std(pair), 3) for pair in zip(*fold_results)] 260 | 261 | averages_std_combined = [f"{avg} ± {std}" for avg, std in zip(averages, std_devs)] 262 | combined_str = ', '.join(averages_std_combined) 263 | logger.info(f'[*Fold Average*] Total Epoch: {args.num_epochs}, Valid, graph-score, [{combined_str}]') 264 | 265 | print(f"Fold: {fold}") 266 | print("Total time elapsed: {:.4f}s".format(time.time() - start_time)) 267 | logger.info("Filename: {}, Total time elapsed: {:.4f}s, Args: {}".format(filename, time.time() - start_time, args)) 268 | logger.info("") 269 | 270 | if __name__ == '__main__': 271 | main() -------------------------------------------------------------------------------- /src/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from copy import deepcopy 4 | import torch 5 | from src.utils import generate_charville_label 6 | 7 | class FeatureMask(object): 8 | """ Transformer object for masking features """ 9 | def __init__(self, 10 | dataset, 11 | use_neighbor_node_features=None, 12 | use_center_node_features=None, 13 | use_edge_features=None, 14 | **kwargs): 15 | """ Construct the transformer 16 | 17 | Args: 18 | dataset (CellularGraphDataset): dataset object 19 | use_neighbor_node_features (list): list of node feature items to use 20 | for non-center nodes, all other features will be masked out 21 | use_center_node_features (list): list of node feature items to use 22 | for the center node, all other features will be masked out 23 | use_edge_features (list): list of edge feature items to use, 24 | all other features will be masked out 25 | """ 26 | 27 | self.node_feature_names = dataset.node_feature_names 28 | self.edge_feature_names = dataset.edge_feature_names 29 | 30 | self.use_neighbor_node_features = use_neighbor_node_features if \ 31 | use_neighbor_node_features is not None else dataset.node_features 32 | self.use_center_node_features = use_center_node_features if \ 33 | use_center_node_features is not None else dataset.node_features 34 | self.use_edge_features = use_edge_features if \ 35 | use_edge_features is not None else dataset.edge_features 36 | 37 | self.center_node_feature_masks = [ 38 | 1 if any(name.startswith(feat) for feat in self.use_center_node_features) 39 | else 0 for name in self.node_feature_names] 40 | self.neighbor_node_feature_masks = [ 41 | 1 if any(name.startswith(feat) for feat in self.use_neighbor_node_features) 42 | else 0 for name in self.node_feature_names] 43 | 44 | self.center_node_feature_masks = \ 45 | torch.from_numpy(np.array(self.center_node_feature_masks).reshape((-1,))).float() 46 | self.neighbor_node_feature_masks = \ 47 | torch.from_numpy(np.array(self.neighbor_node_feature_masks).reshape((1, -1))).float() 48 | 49 | def __call__(self, data): 50 | data = deepcopy(data) 51 | if "center_node_index" in data: 52 | center_node_feat = data.x[data.center_node_index].detach().data.clone() 53 | else: 54 | center_node_feat = None 55 | data = self.transform_neighbor_node(data) 56 | data = self.transform_center_node(data, center_node_feat) 57 | return data 58 | 59 | def 
transform_neighbor_node(self, data): 60 | """Apply neighbor node feature masking""" 61 | data['cell'].x = data['cell'].x * self.neighbor_node_feature_masks 62 | return data 63 | 64 | def transform_center_node(self, data, center_node_feat=None): 65 | """Apply center node feature masking""" 66 | if center_node_feat is None: 67 | return data 68 | assert "center_node_index" in data 69 | center_node_feat = center_node_feat * self.center_node_feature_masks 70 | data['cell'].x[data['cell'].center_node_index] = center_node_feat 71 | return data 72 | 73 | 74 | class AddCenterCellType(object): 75 | """Transformer for center cell type prediction""" 76 | def __init__(self, dataset, **kwargs): 77 | self.node_feature_names = dataset.node_feature_names 78 | self.cell_type_feat = self.node_feature_names.index('cell_type') 79 | # Assign a placeholder cell type for the center node 80 | self.placeholder_cell_type = max(dataset.cell_type_mapping.values()) + 1 81 | 82 | def __call__(self, data): 83 | data = deepcopy(data) 84 | assert "center_node_index" in data, \ 85 | "Only subgraphs with center nodes are supported, cannot find `center_node_index`" 86 | center_node_feat = data.x[data.center_node_index].detach().clone() 87 | center_cell_type = center_node_feat[self.cell_type_feat] 88 | data.node_y = center_cell_type.long().view((1,)) 89 | data.x[data.center_node_index, self.cell_type_feat] = self.placeholder_cell_type 90 | return data 91 | 92 | 93 | class AddCenterCellBiomarkerExpression(object): 94 | """Transformer for center cell biomarker expression prediction""" 95 | def __init__(self, dataset, **kwargs): 96 | self.node_feature_names = dataset.node_feature_names 97 | self.bm_exp_feat = np.array([ 98 | i for i, feat in enumerate(self.node_feature_names) 99 | if feat.startswith('biomarker_expression')]) 100 | 101 | def __call__(self, data): 102 | assert "center_node_index" in data, \ 103 | "Only subgraphs with center nodes are supported, cannot find `center_node_index`" 104 | center_node_feat = data.x[data.center_node_index].detach().clone() 105 | center_cell_exp = center_node_feat[self.bm_exp_feat].float() 106 | data.node_y = center_cell_exp.view(1, -1) 107 | return data 108 | 109 | 110 | class AddCenterCellIdentifier(object): 111 | """Transformer for adding another feature column for identifying center cell 112 | Helpful when predicting node-level tasks. 113 | """ 114 | def __init__(self, *args, **kwargs): 115 | pass 116 | 117 | def __call__(self, data): 118 | assert "center_node_index" in data, \ 119 | "Only subgraphs with center nodes are supported, cannot find `center_node_index`" 120 | center_cell_identifier_column = torch.zeros((data.x.shape[0], 1), dtype=data.x.dtype) 121 | center_cell_identifier_column[data.center_node_index, 0] = 1. 122 | data.x = torch.cat([data.x, center_cell_identifier_column], dim=1) 123 | return data 124 | 125 | 126 | class AddGraphLabel(object): 127 | """Transformer for adding graph-level task labels""" 128 | def __init__(self, graph_label_file, tasks=[], **kwargs): 129 | """ Construct the transformer 130 | 131 | Args: 132 | graph_label_file (str): path to the csv file containing graph-level 133 | task labels. This file should always have the first column as region id. 134 | tasks (list): list of tasks to use, corresponding to column names 135 | of the csv file. 
If empty, use all tasks in the file 136 | """ 137 | self.label_df = pd.read_csv(graph_label_file) 138 | graph_tasks = list(self.label_df.columns) if len(tasks) == 0 else tasks 139 | if ('charville' in graph_label_file) & ('recurrence_event' not in self.label_df.columns): 140 | self.label_df = generate_charville_label(self.label_df) 141 | self.label_df.to_csv(graph_label_file, index=False) # update csv file 142 | self.label_df.index = self.label_df.index.map(str) # Convert index to str 143 | self.graph_tasks = graph_tasks 144 | self.tasks, self.class_label_weights = self.build_class_weights(graph_tasks) 145 | self.c_index = True if ('survival_status' in np.stack(graph_tasks)) or ('recurrence_event' in np.stack(graph_tasks)) else False 146 | 147 | def build_class_weights(self, graph_tasks): 148 | valid_tasks = [] 149 | class_label_weights = {} 150 | for task in graph_tasks: 151 | ar = list(self.label_df[task]) 152 | valid_vals = [_y for _y in ar if _y == _y] 153 | unique_vals = set(valid_vals) 154 | if not all(v.__class__ in [int, float] for v in unique_vals): 155 | # Skip tasks with non-numeric labels 156 | continue 157 | valid_tasks.append(task) 158 | if len(unique_vals) > 5: 159 | # More than 5 unique values in labels, likely a regression task 160 | class_label_weights[task] = {_y: 1 for _y in unique_vals} 161 | else: 162 | # Classification task, compute class weights 163 | val_counts = {_y: valid_vals.count(_y) for _y in unique_vals} 164 | max_count = max(val_counts.values()) 165 | class_label_weights[task] = {_y: max_count / val_counts[_y] for _y in unique_vals} 166 | return valid_tasks, class_label_weights 167 | 168 | def fetch_label(self, region_id, task_name): 169 | # Updated from SPACE-GM 170 | new_int = int(region_id.split('_')[1][1:]) 171 | if 'UPMC' in region_id: 172 | new_int += 4 173 | if len(str(new_int)) == 1: 174 | new_int = f'00{new_int}' 175 | elif len(str(new_int)) == 2: 176 | new_int = f'0{new_int}' 177 | new_region_id = f'SpaceGMP-65_c{new_int}_' + region_id.split('_')[2] + '_' + region_id.split('_')[3] + '_' + region_id.split('_')[4] 178 | if 'acquisition_id_visualizer' in self.label_df: 179 | y = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name].item() 180 | else: 181 | y = self.label_df[self.label_df["region_id"] == new_region_id][task_name].item() 182 | 183 | if y != y: # np.nan 184 | y = 0 185 | w = 0 186 | else: 187 | w = self.class_label_weights[task_name][y] 188 | return y, w 189 | 190 | def fetch_length_event(self, region_id, task_name): 191 | # Updated from SPACE-GM 192 | new_int = int(region_id.split('_')[1][1:]) 193 | if 'UPMC' in region_id: 194 | new_int += 4 195 | if len(str(new_int)) == 1: 196 | new_int = f'00{new_int}' 197 | elif len(str(new_int)) == 2: 198 | new_int = f'0{new_int}' 199 | new_region_id = f'SpaceGMP-65_c{new_int}_' + region_id.split('_')[2] + '_' + region_id.split('_')[3] + '_' + region_id.split('_')[4] 200 | 201 | if 'acquisition_id_visualizer' in self.label_df: 202 | length = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name[0]].item() 203 | event = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name[1]].item() 204 | else: 205 | length = self.label_df[self.label_df["region_id"] == new_region_id][task_name[0]].item() 206 | event = self.label_df[self.label_df["region_id"] == new_region_id][task_name[1]].item() 207 | 208 | return length, event 209 | 210 | def __call__(self, data): 211 | graph_y = [] 212 | graph_w = [] 213 | 214 | for 
task in self.graph_tasks: 215 | if self.c_index: 216 | y, w = self.fetch_length_event(data.region_id, task) 217 | else: 218 | y, w = self.fetch_label(data.region_id, task) 219 | 220 | graph_y.append(y) 221 | graph_w.append(w) 222 | data.graph_y = torch.from_numpy(np.array(graph_y).reshape((1, -1))) 223 | data.graph_w = torch.from_numpy(np.array(graph_w).reshape((1, -1))) 224 | 225 | return data 226 | 227 | class AddTwoGraphLabel(object): 228 | """Transformer for adding graph-level task labels""" 229 | def __init__(self, graph_label_files=[], tasks=[], **kwargs): 230 | """ Construct the transformer 231 | 232 | Args: 233 | graph_label_file (str): path to the csv file containing graph-level 234 | task labels. This file should always have the first column as region id. 235 | tasks (list): list of tasks to use, corresponding to column names 236 | of the csv file. If empty, use all tasks in the file 237 | """ 238 | self.label_df_upmc = pd.read_csv(graph_label_files[0]) 239 | self.label_df_dfci = pd.read_csv(graph_label_files[1]) 240 | 241 | self.label_df_upmc.index = self.label_df_upmc.index.map(str) # Convert index to str 242 | self.label_df_dfci.index = self.label_df_dfci.index.map(str) # Convert index to str 243 | 244 | graph_tasks_upmc = tasks[0] 245 | graph_tasks_dfci = tasks[1] 246 | self.graph_tasks_upmc = graph_tasks_upmc 247 | self.graph_tasks_dfci = graph_tasks_dfci 248 | 249 | self.tasks_upmc, self.class_label_weights_upmc = self.build_class_weights(graph_tasks_upmc, label_df=self.label_df_upmc) 250 | self.tasks_dfci, self.class_label_weights_dfci = self.build_class_weights(graph_tasks_dfci, label_df=self.label_df_dfci) 251 | self.c_index = False 252 | 253 | def build_class_weights(self, graph_tasks, label_df=None): 254 | valid_tasks = [] 255 | class_label_weights = {} 256 | for task in graph_tasks: 257 | if type(task) == list: 258 | task = task[0] # length 259 | ar = list(label_df[task]) 260 | valid_vals = [_y for _y in ar if _y == _y] 261 | unique_vals = set(valid_vals) 262 | if not all(v.__class__ in [int, float] for v in unique_vals): 263 | # Skip tasks with non-numeric labels 264 | continue 265 | valid_tasks.append(task) 266 | 267 | if len(unique_vals) > 5: 268 | # More than 5 unique values in labels, likely a regression task 269 | class_label_weights[task] = {_y: 1 for _y in unique_vals} 270 | else: 271 | # Classification task, compute class weights 272 | val_counts = {_y: valid_vals.count(_y) for _y in unique_vals} 273 | max_count = max(val_counts.values()) 274 | class_label_weights[task] = {_y: max_count / val_counts[_y] for _y in unique_vals} 275 | 276 | return valid_tasks, class_label_weights 277 | 278 | def fetch_label(self, region_id, task_name): 279 | # Updated from SPACE-GM 280 | self.label_df = self.label_df_upmc if 'UPMC' in region_id else self.label_df_dfci 281 | new_int = int(region_id.split('_')[1][1:]) 282 | if 'UPMC' in region_id: 283 | new_int += 4 284 | if len(str(new_int)) == 1: 285 | new_int = f'00{new_int}' 286 | elif len(str(new_int)) == 2: 287 | new_int = f'0{new_int}' 288 | 289 | if 'UPMC' in region_id: 290 | new_region_id = f'SpaceGMP-65_c{new_int}_' + region_id.split('_')[2] + '_' + region_id.split('_')[3] + '_' + region_id.split('_')[4] 291 | else: 292 | new_region_id = region_id 293 | if "acquisition_id_visualizer" in self.label_df: 294 | y = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name].item() 295 | 296 | else: 297 | y = self.label_df[self.label_df["region_id"] == new_region_id][task_name].item() 298 | 299 | if 
y != y: # np.nan 300 | y = 0 301 | w = 0 302 | else: 303 | self.class_label_weights = self.class_label_weights_upmc if 'UPMC' in region_id else self.class_label_weights_dfci 304 | w = self.class_label_weights[task_name][y] 305 | return y, w 306 | 307 | def __call__(self, data): 308 | graph_y = [] 309 | graph_w = [] 310 | 311 | self.tasks = self.tasks_upmc if 'UPMC' in data.region_id else self.tasks_dfci 312 | for task in self.tasks: 313 | y, w = self.fetch_label(data.region_id, task) 314 | graph_y.append(y) 315 | graph_w.append(w) 316 | data.graph_y = torch.from_numpy(np.array(graph_y).reshape((1, -1))) 317 | data.graph_w = torch.from_numpy(np.array(graph_w).reshape((1, -1))) 318 | 319 | return data -------------------------------------------------------------------------------- /src/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | from scipy.stats import rankdata 4 | import warnings 5 | import torch 6 | import torch_geometric as tg 7 | from torch_geometric.data import HeteroData 8 | 9 | from src.utils import EDGE_TYPES 10 | 11 | 12 | def process_biomarker_expression(G, 13 | node_ind, 14 | biomarkers=None, 15 | biomarker_expression_process_method='raw', 16 | biomarker_expression_lower_bound=-3, 17 | biomarker_expression_upper_bound=3, 18 | **kwargs): 19 | """ Process biomarker expression 20 | 21 | Args: 22 | G (nx.Graph): full cellular graph of the region 23 | node_ind (int): target node index 24 | biomarkers (list): list of biomarkers 25 | biomarker_expression_process_method (str): process method, one of 'raw', 'linear', 'log', 'rank' 26 | biomarker_expression_lower_bound (float): lower bound for min-max normalization, used for 'linear' and 'log' 27 | biomarker_expression_upper_bound (float): upper bound for min-max normalization, used for 'linear' and 'log' 28 | 29 | Returns: 30 | list: processed biomarker expression values 31 | """ 32 | 33 | bm_exp_dict = G.nodes[node_ind]["biomarker_expression"] 34 | bm_exp_vec = [] 35 | for bm in biomarkers: 36 | if bm_exp_dict is None or bm not in bm_exp_dict: 37 | bm_exp_vec.append(0.) 
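# Note (added comment): biomarkers absent from this node's expression dict are zero-filled
# here, so every cell yields an expression vector of identical length and ordering,
# one slot per entry in `biomarkers`.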
38 | else: 39 | bm_exp_vec.append(float(bm_exp_dict[bm])) 40 | 41 | bm_exp_vec = np.array(bm_exp_vec) 42 | lb = biomarker_expression_lower_bound 43 | ub = biomarker_expression_upper_bound 44 | 45 | if biomarker_expression_process_method == 'raw': 46 | return list(bm_exp_vec) 47 | elif biomarker_expression_process_method == 'linear': 48 | bm_exp_vec = np.clip(bm_exp_vec, lb, ub) 49 | bm_exp_vec = (bm_exp_vec - lb) / (ub - lb) 50 | return list(bm_exp_vec) 51 | elif biomarker_expression_process_method == 'log': 52 | bm_exp_vec = np.clip(np.log(bm_exp_vec + 1e-9), lb, ub) 53 | bm_exp_vec = (bm_exp_vec - lb) / (ub - lb) 54 | return list(bm_exp_vec) 55 | elif biomarker_expression_process_method == 'rank': 56 | bm_exp_vec = rankdata(bm_exp_vec, method='min') 57 | num_exp = len(bm_exp_vec) 58 | bm_exp_vec = (bm_exp_vec - 1) / (num_exp - 1) 59 | return list(bm_exp_vec) 60 | else: 61 | raise ValueError("expression process method %s not recognized" % biomarker_expression_process_method) 62 | 63 | 64 | def process_neighbor_composition(G, 65 | node_ind, 66 | cell_type_mapping=None, 67 | neighborhood_size=10, 68 | **kwargs): 69 | """ Calculate the composition vector of k-nearest neighboring cells 70 | 71 | Args: 72 | G (nx.Graph): full cellular graph of the region 73 | node_ind (int): target node index 74 | cell_type_mapping (dict): mapping of unique cell types to integer indices 75 | neighborhood_size (int): number of nearest neighbors to consider 76 | 77 | Returns: 78 | comp_vec (list): composition vector of k-nearest neighboring cells 79 | """ 80 | center_coord = G.nodes[node_ind]["center_coord"] 81 | 82 | def node_dist(c1, c2): 83 | return np.linalg.norm(np.array(c1) - np.array(c2), ord=2) 84 | 85 | radius = 1 86 | neighbors = {} 87 | while len(neighbors) < 2 * neighborhood_size and radius < 5: 88 | radius += 1 89 | ego_g = nx.ego_graph(G, node_ind, radius=radius) 90 | neighbors = {n: feat_dict["center_coord"] for n, feat_dict in ego_g.nodes.data()} 91 | 92 | closest_neighbors = sorted(neighbors.keys(), key=lambda x: node_dist(center_coord, neighbors[x])) 93 | closest_neighbors = closest_neighbors[1:(neighborhood_size + 1)] 94 | 95 | comp_vec = np.zeros((len(cell_type_mapping),)) 96 | for n in closest_neighbors: 97 | cell_type = cell_type_mapping[G.nodes[n]["cell_type"]] 98 | comp_vec[cell_type] += 1 99 | comp_vec = list(comp_vec / comp_vec.sum()) 100 | return comp_vec 101 | 102 | 103 | def process_edge_distance(G, 104 | edge_ind, 105 | log_distance_lower_bound=2., 106 | log_distance_upper_bound=5., 107 | **kwargs): 108 | """ Process edge distance, distance will be log-transformed and min-max normalized 109 | 110 | Default parameters assume distances are usually within the range: 10-100 pixels / 3.7-37 um 111 | 112 | Args: 113 | G (nx.Graph): full cellular graph of the region 114 | edge_ind (int): target edge index 115 | log_distance_lower_bound (float): lower bound for log-transformed distance 116 | log_distance_upper_bound (float): upper bound for log-transformed distance 117 | 118 | Returns: 119 | list: list of normalized log-transformed distance 120 | """ 121 | dist = G.edges[edge_ind]["distance"] 122 | log_dist = np.log(dist + 1e-5) 123 | _d = np.clip((log_dist - log_distance_lower_bound) / 124 | (log_distance_upper_bound - log_distance_lower_bound), 0, 1) 125 | return [_d] 126 | 127 | 128 | def process_feature(G, feature_item, node_ind=None, edge_ind=None, **feature_kwargs): 129 | """ Process a single node/edge feature item 130 | 131 | The following feature items are supported, note that some 
of them require 132 | keyword arguments in `feature_kwargs`: 133 | 134 | Node features: 135 | - feature_item: "cell_type" 136 | (required) "cell_type_mapping" 137 | - feature_item: "center_coord" 138 | - feature_item: "biomarker_expression" 139 | (required) "biomarkers", 140 | (optional) "biomarker_expression_process_method", 141 | (optional, if method is "linear" or "log") "biomarker_expression_lower_bound", 142 | (optional, if method is "linear" or "log") "biomarker_expression_upper_bound" 143 | - feature_item: "neighborhood_composition" 144 | (required) "cell_type_mapping", 145 | (optional) "neighborhood_size" 146 | - other additional feature items stored in the node attributes 147 | (see `graph_build.construct_graph_for_region`, argument `cell_features_file`) 148 | 149 | Edge features: 150 | - feature_item: "edge_type" 151 | - feature_item: "distance" 152 | (optional) "log_distance_lower_bound", 153 | (optional) "log_distance_upper_bound" 154 | 155 | Args: 156 | G (nx.Graph): full cellular graph of the region 157 | feature_item (str): feature item 158 | node_ind (int): target node index (if feature item is node feature) 159 | edge_ind (tuple): target edge index (if feature item is edge feature) 160 | feature_kwargs (dict): arguments for processing features 161 | 162 | Returns: 163 | v (list): feature vector 164 | """ 165 | # Node features 166 | if node_ind is not None and edge_ind is None: 167 | if feature_item == "cell_type": 168 | # Integer index of the cell type 169 | assert "cell_type_mapping" in feature_kwargs, \ 170 | "'cell_type_mapping' is required in the kwargs for feature item 'cell_type'" 171 | v = [feature_kwargs["cell_type_mapping"][G.nodes[node_ind]["cell_type"]]] 172 | return v 173 | elif feature_item == "center_coord": 174 | # Coordinates of the cell centroid 175 | v = list(G.nodes[node_ind]["center_coord"]) 176 | return v 177 | elif feature_item == "biomarker_expression": 178 | # Biomarker expression of the cell 179 | assert "biomarkers" in feature_kwargs, \ 180 | "'biomarkers' is required in the kwargs for feature item 'biomarker_expression'" 181 | v = process_biomarker_expression(G, node_ind, **feature_kwargs) 182 | return v 183 | elif feature_item == "neighborhood_composition": 184 | # Composition vector of the k-nearest neighboring cells 185 | assert "cell_type_mapping" in feature_kwargs, \ 186 | "'cell_type_mapping' is required in the kwargs for feature item 'neighborhood_composition'" 187 | v = process_neighbor_composition(G, node_ind, **feature_kwargs) 188 | return v 189 | elif feature_item in G.nodes[node_ind]: 190 | # Additional features specified by users, e.g. 
"SIZE" in the example 191 | v = [G.nodes[node_ind][feature_item]] 192 | return v 193 | else: 194 | raise ValueError("Feature %s not found in the node attributes of graph %s, node %s" % 195 | (feature_item, G.region_id, str(node_ind))) 196 | 197 | # Edge features 198 | elif edge_ind is not None and node_ind is None: 199 | if feature_item == "edge_type": 200 | v = [EDGE_TYPES[G.edges[edge_ind]["edge_type"]]] 201 | return v 202 | elif feature_item == "distance": 203 | v = process_edge_distance(G, edge_ind, **feature_kwargs) 204 | return v 205 | elif feature_item in G.edges[edge_ind]: 206 | v = [G.edges[edge_ind][feature_item]] 207 | return v 208 | else: 209 | raise ValueError("Feature %s not found in the edge attributes of graph %s, edge %s" % 210 | (feature_item, G.region_id, str(edge_ind))) 211 | 212 | else: 213 | raise ValueError("One of node_ind or edge_ind should be specified") 214 | 215 | 216 | def nx_to_tg_graph(G, 217 | node_features=["cell_type", 218 | "biomarker_expression", 219 | "neighborhood_composition", 220 | "center_coord"], 221 | edge_features=["edge_type", 222 | "distance"], 223 | **feature_kwargs): 224 | """ Build pyg data objects from nx graphs 225 | 226 | Args: 227 | G (nx.Graph): full cellular graph of the region 228 | node_features (list, optional): list of node feature items 229 | edge_features (list, optional): list of edge feature items 230 | feature_kwargs (dict): arguments for processing features 231 | 232 | Returns: 233 | data_list (list): list of pyg data objects 234 | """ 235 | data_list = [] 236 | 237 | # Each connected component of the cellular graph will be a separate pyg data object 238 | # Usually there should only be one connected component for each cellular graph 239 | for inds in nx.connected_components(G): 240 | # Skip small connected components 241 | if len(inds) < len(G) * 0.1: 242 | continue 243 | sub_G = G.subgraph(inds) 244 | 245 | # Relabel nodes to be consecutive integers, note that node indices are 246 | # not meaningful here, cells are identified by the key "cell_id" in each node 247 | mapping = {n: i for i, n in enumerate(sorted(sub_G.nodes))} 248 | sub_G = nx.relabel.relabel_nodes(sub_G, mapping) 249 | assert np.all(sub_G.nodes == np.arange(len(sub_G))) 250 | 251 | # Append node and edge features to the pyg data object 252 | data = {"x": [], "edge_attr": [], "edge_index": []} 253 | for node_ind in sub_G.nodes: 254 | feat_val = [] 255 | for key in node_features: 256 | feat_val.extend(process_feature(sub_G, key, node_ind=node_ind, **feature_kwargs)) 257 | data["x"].append(feat_val) 258 | 259 | for edge_ind in sub_G.edges: 260 | feat_val = [] 261 | for key in edge_features: 262 | feat_val.extend(process_feature(sub_G, key, edge_ind=edge_ind, **feature_kwargs)) 263 | data["edge_attr"].append(feat_val) 264 | data["edge_index"].append(edge_ind) 265 | data["edge_attr"].append(feat_val) 266 | data["edge_index"].append(tuple(reversed(edge_ind))) 267 | 268 | for key, item in data.items(): 269 | data[key] = torch.tensor(item) 270 | data['edge_index'] = data['edge_index'].t().long() 271 | data = tg.data.Data.from_dict(data) 272 | data.num_nodes = sub_G.number_of_nodes() 273 | data.region_id = G.region_id 274 | data_list.append(data) 275 | return data_list 276 | 277 | def nx_to_tg_hetero_graph(G, 278 | node_features=["cell_type", 279 | "biomarker_expression", 280 | "neighborhood_composition", 281 | "center_coord"], 282 | edge_features=["edge_type", 283 | "distance"], 284 | drop_edge=0.0, 285 | **feature_kwargs): 286 | """ Build pyg data objects from nx 
graphs 287 | 288 | Args: 289 | G (nx.Graph): full cellular graph of the region 290 | node_features (list, optional): list of node feature items 291 | edge_features (list, optional): list of edge feature items 292 | feature_kwargs (dict): arguments for processing features 293 | 294 | Returns: 295 | data_list (list): list of pyg data objects 296 | """ 297 | data_list = [] 298 | 299 | # Each connected component of the cellular graph will be a separate pyg data object 300 | # Usually there should only be one connected component for each cellular graph 301 | for inds in nx.connected_components(G['layer_1']): 302 | # Skip small connected components 303 | if len(inds) < len(G['layer_1']) * 0.1: 304 | continue 305 | sub_G_1 = G['layer_1'].subgraph(inds) 306 | sub_G_2 = G['layer_2'].subgraph(inds) # same inds for layer_2 307 | 308 | # Relabel nodes to be consecutive integers, note that node indices are 309 | # not meaningful here, cells are identified by the key "cell_id" in each node 310 | mapping = {n: i for i, n in enumerate(sorted(sub_G_1.nodes))} 311 | sub_G_1 = nx.relabel.relabel_nodes(sub_G_1, mapping) 312 | sub_G_2 = nx.relabel.relabel_nodes(sub_G_2, mapping) 313 | 314 | assert np.all(sub_G_1.nodes == np.arange(len(sub_G_1))) 315 | 316 | # Common node feature 317 | node_feature_list = [] 318 | for node_ind in sub_G_1.nodes: 319 | feat_val = [] 320 | for key in node_features: 321 | feat_val.extend(process_feature(sub_G_1, key, node_ind=node_ind, **feature_kwargs)) 322 | node_feature_list.append(feat_val) 323 | 324 | # Edge for layer 1 325 | edge_index_1_list = [] 326 | edge_attr_1_list = [] 327 | for edge_ind in sub_G_1.edges: 328 | feat_val = [] 329 | for key in edge_features: 330 | feat_val.extend(process_feature(sub_G_1, key, edge_ind=edge_ind, **feature_kwargs)) 331 | edge_attr_1_list.append(feat_val) 332 | edge_index_1_list.append(edge_ind) 333 | edge_attr_1_list.append(feat_val) 334 | edge_index_1_list.append(tuple(reversed(edge_ind))) 335 | 336 | # Edge for layer 2 337 | edge_index_2_list = [] 338 | edge_attr_2_list = [] 339 | for edge_ind in sub_G_2.edges: 340 | feat_val = [] 341 | for key in edge_features: 342 | feat_val.extend(process_feature(sub_G_2, key, edge_ind=edge_ind, **feature_kwargs)) 343 | edge_attr_2_list.append(feat_val) 344 | edge_index_2_list.append(edge_ind) 345 | edge_attr_2_list.append(feat_val) 346 | edge_index_2_list.append(tuple(reversed(edge_ind))) 347 | 348 | node_features_tensor = torch.tensor(node_feature_list) 349 | edge_index_1_tensor = torch.tensor(edge_index_1_list, dtype=torch.long).t().contiguous() 350 | edge_attr_1_tensor = torch.tensor(edge_attr_1_list) 351 | edge_index_2_tensor = torch.tensor(edge_index_2_list, dtype=torch.long).t().contiguous() 352 | edge_attr_2_tensor = torch.tensor(edge_attr_2_list) 353 | 354 | # Append node and edge features to the pyg data object 355 | data = HeteroData() 356 | data['cell'].x = node_features_tensor 357 | data['cell', 'geom', 'cell'].edge_index = edge_index_1_tensor.long() 358 | data['cell', 'geom', 'cell'].edge_attr = edge_attr_1_tensor 359 | data['cell', 'type', 'cell'].edge_index = edge_index_2_tensor.long() 360 | data['cell', 'type', 'cell'].edge_attr = edge_attr_2_tensor 361 | 362 | data.num_nodes = sub_G_1.number_of_nodes() 363 | data.region_id = G['layer_1'].region_id 364 | data_list.append(data) 365 | 366 | return data_list 367 | 368 | 369 | def get_feature_names(features, cell_type_mapping=None, biomarkers=None): 370 | """ Helper fn for getting a list of feature names from a list of feature items 371 | 
372 | Args: 373 | features (list): list of feature items 374 | cell_type_mapping (dict): mapping of unique cell types to integer indices 375 | biomarkers (list): list of biomarkers 376 | 377 | Returns: 378 | feat_names(list): list of feature names 379 | """ 380 | feat_names = [] 381 | for feat in features: 382 | if feat in ["distance", "cell_type", "edge_type"]: 383 | # feature "cell_type", "edge_type" will be a single integer indice 384 | # feature "distance" will be a single float value 385 | feat_names.append(feat) 386 | elif feat == "center_coord": 387 | # feature "center_coord" will be a tuple of two float values 388 | feat_names.extend(["center_coord-x", "center_coord-y"]) 389 | elif feat == "biomarker_expression": 390 | # feature "biomarker_expression" will contain a list of biomarker expression values 391 | feat_names.extend(["biomarker_expression-%s" % bm for bm in biomarkers]) 392 | elif feat == "neighborhood_composition": 393 | # feature "neighborhood_composition" will contain a composition vector of the immediate neighbors 394 | # The vector will have the same length as the number of unique cell types 395 | feat_names.extend(["neighborhood_composition-%s" % ct 396 | for ct in sorted(cell_type_mapping.keys(), key=lambda x: cell_type_mapping[x])]) 397 | else: 398 | warnings.warn("Using additional feature: %s" % feat) 399 | feat_names.append(feat) 400 | return feat_names 401 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import multiprocessing 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import warnings 8 | 9 | import torch 10 | from torch.utils.data import RandomSampler 11 | 12 | import torch_geometric as tg 13 | from torch_geometric.data import Dataset 14 | from torch_geometric.loader import DataLoader 15 | from torch_geometric.utils import subgraph 16 | 17 | from src.graph_build import plot_graph 18 | from src.features import get_feature_names, nx_to_tg_graph, nx_to_tg_hetero_graph 19 | from src.utils import ( 20 | EDGE_TYPES, 21 | get_cell_type_metadata, 22 | get_biomarker_metadata, 23 | ) 24 | 25 | 26 | class CellularGraphDataset(Dataset): 27 | """ Main dataset structure for cellular graphs 28 | Inherited from https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Dataset.html 29 | """ 30 | def __init__(self, 31 | root, 32 | transform=[], 33 | pre_transform=None, 34 | raw_folder_name='graph', 35 | processed_folder_name='tg_graph', 36 | node_features=["cell_type", "expression", "neighborhood_composition", "center_coord"], 37 | edge_features=["edge_type", "distance"], 38 | cell_type_mapping=None, 39 | cell_type_freq=None, 40 | biomarkers=None, 41 | subgraph_size=0, 42 | subgraph_source='on-the-fly', 43 | subgraph_allow_distant_edge=True, 44 | subgraph_radius_limit=-1, 45 | sampling_avoid_unassigned=True, 46 | unassigned_cell_type='Unassigned', 47 | **feature_kwargs): 48 | """ Initialize the dataset 49 | 50 | Args: 51 | root (str): path to the dataset directory 52 | transform (list): list of transformations (see `transform.py`), 53 | applied to each output graph on-the-fly 54 | pre_transform (list): list of transformations, applied to each graph before saving 55 | raw_folder_name (str): name of the sub-folder containing raw graphs (gpickle) 56 | processed_folder_name (str): name of the sub-folder containing processed graphs (pyg data object) 57 | 
node_features (list): list of all available node feature items, 58 | see `features.process_features` for details 59 | edge_features (list): list of all available edge feature items 60 | cell_type_mapping (dict): mapping of unique cell types to integer indices, 61 | see `utils.get_cell_type_metadata` 62 | cell_type_freq (dict): mapping of unique cell types to their frequencies, 63 | see `utils.get_cell_type_metadata` 64 | biomarkers (list): list of biomarkers, 65 | see `utils.get_expression_biomarker_metadata` 66 | subgraph_size (int): number of hops for subgraph, 0 means using the full cellular graph 67 | subgraph_source (str): source of subgraphs, one of 'on-the-fly', 'chunk_save' 68 | subgraph_allow_distant_edge (bool): whether to consider distant edges 69 | subgraph_radius_limit (float): radius (distance to center cell in pixel) limit for subgraphs, 70 | -1 means no limit 71 | sampling_avoid_unassigned (bool): whether to avoid sampling cells with unassigned cell type 72 | unassigned_cell_type (str): name of the unassigned cell type 73 | feature_kwargs (dict): optional arguments for processing features 74 | see `features.process_features` for details 75 | """ 76 | self.root = root 77 | self.raw_folder_name = raw_folder_name 78 | self.processed_folder_name = processed_folder_name 79 | os.makedirs(self.raw_dir, exist_ok=True) 80 | os.makedirs(self.processed_dir, exist_ok=True) 81 | 82 | # Find all unique cell types in the dataset 83 | if cell_type_mapping is None or cell_type_freq is None: 84 | try: 85 | with open(self.raw_dir+'/cell_type_mapping_freq.pkl', 'rb') as f: 86 | self.cell_type_mapping, self.cell_type_freq = pickle.load(f) 87 | 88 | except: 89 | nx_graph_files = [os.path.join(self.raw_dir, f) for f in self.raw_file_names] 90 | self.cell_type_mapping, self.cell_type_freq = get_cell_type_metadata(nx_graph_files) 91 | 92 | with open(self.raw_dir+'/cell_type_mapping_freq.pkl', 'wb') as f: 93 | pickle.dump([self.cell_type_mapping, self.cell_type_freq], f) 94 | 95 | else: 96 | self.cell_type_mapping = cell_type_mapping 97 | self.cell_type_freq = cell_type_freq 98 | 99 | # Find all available biomarkers for cells in the dataset 100 | if biomarkers is None: 101 | try: 102 | with open(self.raw_dir+'/biomarkers.pkl', 'rb') as f: 103 | self.biomarkers = pickle.load(f) 104 | except: 105 | nx_graph_files = [os.path.join(self.raw_dir, f) for f in self.raw_file_names] 106 | self.biomarkers, _ = get_biomarker_metadata(nx_graph_files) 107 | 108 | with open(self.raw_dir+'/biomarkers.pkl', 'wb') as f: 109 | pickle.dump(self.biomarkers, f) 110 | 111 | else: 112 | self.biomarkers = biomarkers 113 | 114 | # Node features & edge features 115 | self.node_features = node_features 116 | self.edge_features = edge_features 117 | if "cell_type" in self.node_features: 118 | assert self.node_features.index("cell_type") == 0, "cell_type must be the first node feature" 119 | if "edge_type" in self.edge_features: # distant, neighboring based on 20 micrometer 120 | assert self.edge_features.index("edge_type") == 0, "edge_type must be the first edge feature" 121 | 122 | self.node_feature_names = get_feature_names(node_features, 123 | cell_type_mapping=self.cell_type_mapping, 124 | biomarkers=self.biomarkers) 125 | self.edge_feature_names = get_feature_names(edge_features, 126 | cell_type_mapping=self.cell_type_mapping, 127 | biomarkers=self.biomarkers) 128 | 129 | # Prepare kwargs for node and edge featurization 130 | self.feature_kwargs = feature_kwargs 131 | self.feature_kwargs['cell_type_mapping'] = 
self.cell_type_mapping 132 | self.feature_kwargs['biomarkers'] = self.biomarkers 133 | 134 | # Note this command below calls the `process` function 135 | super(CellularGraphDataset, self).__init__(root, None, pre_transform) 136 | 137 | # Transformations, e.g. masking features, adding graph-level labels 138 | self.transform = transform 139 | 140 | # SPACE-GM uses n-hop ego graphs (subgraphs) to perform prediction 141 | self.subgraph_size = subgraph_size # number of hops, 0 = use full graph 142 | self.subgraph_source = subgraph_source 143 | self.subgraph_allow_distant_edge = subgraph_allow_distant_edge 144 | self.subgraph_radius_limit = subgraph_radius_limit 145 | 146 | self.N = len(self.processed_paths) 147 | # Sampling frequency for each cell type 148 | self.sampling_freq = {self.cell_type_mapping[ct]: 1. / (self.cell_type_freq[ct] + 1e-5) 149 | for ct in self.cell_type_mapping} 150 | self.sampling_freq = torch.from_numpy(np.array([self.sampling_freq[i] for i in range(len(self.sampling_freq))])) 151 | # Avoid sampling unassigned cell 152 | self.unassigned_cell_type = unassigned_cell_type 153 | if sampling_avoid_unassigned and unassigned_cell_type in self.cell_type_mapping: 154 | self.sampling_freq[self.cell_type_mapping[unassigned_cell_type]] = 0. 155 | 156 | # Cache for subgraphs 157 | self.cached_data = {} 158 | 159 | def set_indices(self, inds=None): 160 | """Limit subgraph sampling to a subset of region indices, 161 | helpful when splitting dataset into training/validation/test regions 162 | """ 163 | self._indices = inds 164 | return 165 | 166 | def set_subgraph_source(self, subgraph_source): 167 | """Set subgraph source""" 168 | assert subgraph_source in ['chunk_save', 'on-the-fly'] 169 | self.subgraph_source = subgraph_source 170 | 171 | def set_transforms(self, transform=[]): 172 | """Set transformation functions""" 173 | self.transform = transform 174 | 175 | @property 176 | def raw_dir(self) -> str: 177 | return os.path.join(self.root, self.raw_folder_name) 178 | 179 | @property 180 | def processed_dir(self) -> str: 181 | return os.path.join(self.root, self.processed_folder_name) 182 | 183 | @property 184 | def raw_file_names(self): 185 | return sorted([f for f in os.listdir(self.raw_dir) if f.endswith('.gpkl')]) 186 | 187 | @property 188 | def processed_file_names(self): 189 | # Only files for full graphs 190 | return sorted([f for f in os.listdir(self.processed_dir) if f.endswith('.gpt') and 'hop' not in f]) 191 | 192 | def len(self): 193 | return len(self.processed_paths) 194 | 195 | def process(self): 196 | """Featurize all cellular graphs""" 197 | for raw_path in self.raw_paths: 198 | G = pickle.load(open(raw_path, 'rb')) 199 | region_id = G['layer_1'].region_id 200 | if os.path.exists(os.path.join(self.processed_dir, '%s.0.gpt' % region_id)): 201 | continue 202 | 203 | # Transform networkx graphs to pyg data objects, and add features for nodes and edges 204 | data_list = nx_to_tg_hetero_graph(G, 205 | node_features=self.node_features, 206 | edge_features=self.edge_features, 207 | **self.feature_kwargs) 208 | 209 | for i, d in enumerate(data_list): 210 | try: 211 | assert d.region_id == region_id # graph identifier 212 | assert d['cell'].x.shape[1] == len(self.node_feature_names) # make sure feature dimension matches 213 | assert d['cell', 'geom', 'cell'].edge_attr.shape[1] == len(self.edge_feature_names) 214 | assert d['cell', 'type', 'cell'].edge_attr.shape[1] == len(self.edge_feature_names) 215 | except: 216 | continue 217 | d.component_id = i 218 | if self.pre_transform 
is not None: 219 | for transform_fn in self.pre_transform: 220 | d = transform_fn(d) 221 | torch.save(d, os.path.join(self.processed_dir, '%s.%d.gpt' % (d.region_id, d.component_id))) 222 | return 223 | 224 | def __getitem__(self, idx): 225 | """Sample a graph/subgraph from the dataset and apply transformations""" 226 | # data = self.get(self.indices()[idx]) 227 | data = self.cached_data[idx] 228 | # Apply transformations 229 | for transform_fn in self.transform: 230 | data = transform_fn(data) 231 | return data 232 | 233 | def get(self, idx): 234 | data = self.get_full(idx) 235 | return data 236 | 237 | def get_subgraph(self, idx, center_ind): 238 | """Get a subgraph from the dataset""" 239 | # Check cache 240 | if (idx, center_ind) in self.cached_data: 241 | return self.cached_data[(idx, center_ind)] 242 | if self.subgraph_source == 'on-the-fly': 243 | # Construct subgraph on-the-fly 244 | return self.calculate_subgraph(idx, center_ind) 245 | elif self.subgraph_source == 'chunk_save': 246 | # Load subgraph from pre-saved chunk file 247 | return self.get_saved_subgraph_from_chunk(idx, center_ind) 248 | 249 | def get_full(self, idx): 250 | """Read the full cellular graph of region `idx`""" 251 | if idx in self.cached_data: 252 | return self.cached_data[idx] 253 | else: 254 | data = torch.load(self.processed_paths[idx]) 255 | self.cached_data[idx] = data 256 | return data 257 | 258 | def get_full_nx(self, idx): 259 | """Read the full cellular graph (nx.Graph) of region `idx`""" 260 | return pickle.load(open(self.raw_paths[idx], 'rb')) 261 | 262 | def calculate_subgraph(self, idx, center_ind): 263 | """Generate the n-hop subgraph around cell `center_ind` from region `idx`""" 264 | data = self.get_full(idx) 265 | if not self.subgraph_allow_distant_edge: 266 | edge_type_mask = (data.edge_attr[:, 0] == EDGE_TYPES["neighbor"]) 267 | else: 268 | edge_type_mask = None 269 | sub_node_inds = k_hop_subgraph(int(center_ind), 270 | self.subgraph_size, 271 | data.edge_index, 272 | edge_type_mask=edge_type_mask, 273 | relabel_nodes=False, 274 | num_nodes=data.x.shape[0])[0] 275 | 276 | if self.subgraph_radius_limit > 0: 277 | # Restrict to neighboring cells that are within the radius (distance to center cell) limit 278 | assert "center_coord" in self.node_features 279 | coord_feature_inds = [i for i, n in enumerate(self.node_feature_names) if n.startswith('center_coord')] 280 | assert len(coord_feature_inds) == 2 281 | center_cell_coord = data.x[[center_ind]][:, coord_feature_inds] 282 | neighbor_cells_coord = data.x[sub_node_inds][:, coord_feature_inds] 283 | dists = ((neighbor_cells_coord - center_cell_coord)**2).sum(1).sqrt() 284 | sub_node_inds = sub_node_inds[(dists < self.subgraph_radius_limit)] 285 | 286 | # Construct subgraphs as separate pyg data objects 287 | sub_x = data.x[sub_node_inds] 288 | sub_edge_index, sub_edge_attr = subgraph(sub_node_inds, 289 | data.edge_index, 290 | edge_attr=data.edge_attr, 291 | relabel_nodes=True) 292 | 293 | relabeled_node_ind = list(sub_node_inds.numpy()).index(center_ind) 294 | 295 | sub_data = {'center_node_index': relabeled_node_ind, # center node index in the subgraph 296 | 'original_center_node': center_ind, # center node index in the original full cellular graph 297 | 'x': sub_x, 298 | 'edge_index': sub_edge_index, 299 | 'edge_attr': sub_edge_attr, 300 | 'num_nodes': len(sub_node_inds)} 301 | 302 | # Assign graph-level attributes 303 | for k in data: 304 | if not k[0] in sub_data: 305 | sub_data[k[0]] = k[1] 306 | 307 | sub_data = 
tg.data.Data.from_dict(sub_data) 308 | self.cached_data[(idx, center_ind)] = sub_data 309 | return sub_data 310 | 311 | def get_saved_subgraph_from_chunk(self, idx, center_ind): 312 | """Read the n-hop subgraph around cell `center_ind` from region `idx` 313 | Subgraph will be extracted from a pre-saved chunk file, which is generated by calling 314 | `save_all_subgraphs_to_chunk` 315 | """ 316 | full_graph_path = self.processed_paths[idx] 317 | subgraphs_path = full_graph_path.replace('.gpt', '.%d-hop.gpt' % self.subgraph_size) 318 | if not os.path.exists(subgraphs_path): 319 | warnings.warn("Subgraph save %s not found" % subgraphs_path) 320 | return self.calculate_subgraph(idx, center_ind) 321 | 322 | subgraphs = torch.load(subgraphs_path) 323 | # Store to cache first 324 | for j, g in enumerate(subgraphs): 325 | self.cached_data[(idx, j)] = g 326 | return self.cached_data[(idx, center_ind)] 327 | 328 | def pick_center(self, data): 329 | """Randomly pick a center cell from a full cellular graph, cell type balanced""" 330 | cell_types = data["x"][:, 0].long() 331 | freq = self.sampling_freq.gather(0, cell_types) 332 | freq = freq / freq.sum() 333 | center_node_ind = np.random.choice(np.arange(len(freq)), p=freq.cpu().data.numpy()) 334 | return center_node_ind 335 | 336 | def load_to_cache(self, idx, subgraphs=True): 337 | """Pre-load full cellular graph of region `idx` and all its n-hop subgraphs to cache""" 338 | data = torch.load(self.processed_paths[idx]) 339 | self.cached_data[idx] = data 340 | if subgraphs or self.subgraph_source == 'chunk_save': 341 | subgraphs_path = self.processed_paths[idx].replace('.gpt', '.%d-hop.gpt' % self.subgraph_size) 342 | if not os.path.exists(subgraphs_path): 343 | raise FileNotFoundError("Subgraph save %s not found, please run `save_all_subgraphs_to_chunk`." 
344 | % subgraphs_path) 345 | neighbor_graphs = torch.load(subgraphs_path) 346 | for j, ng in enumerate(neighbor_graphs): 347 | self.cached_data[(idx, j)] = ng 348 | 349 | def save_all_subgraphs_to_chunk(self): 350 | """Save all n-hop subgraphs for all regions to chunk files (one file per region)""" 351 | for idx, p in enumerate(self.processed_paths): 352 | data = self.get_full(idx) 353 | n_nodes = data.x.shape[0] 354 | neighbor_graph_path = p.replace('.gpt', '.%d-hop.gpt' % self.subgraph_size) 355 | if os.path.exists(neighbor_graph_path): 356 | continue 357 | subgraphs = [] 358 | for node_i in range(n_nodes): 359 | subgraphs.append(self.calculate_subgraph(idx, node_i)) 360 | torch.save(subgraphs, neighbor_graph_path) 361 | return 362 | 363 | def clear_cache(self): 364 | del self.cached_data 365 | self.cached_data = {} 366 | return 367 | 368 | def plot_subgraph(self, idx, center_ind): 369 | """Plot the n-hop subgraph around cell `center_ind` from region `idx`""" 370 | xcoord_ind = self.node_feature_names.index('center_coord-x') 371 | ycoord_ind = self.node_feature_names.index('center_coord-y') 372 | 373 | _subg = self.calculate_subgraph(idx, center_ind) 374 | coords = _subg.x.data.numpy()[:, [xcoord_ind, ycoord_ind]].astype(float) 375 | x_c, y_c = coords[_subg.center_node_index] 376 | 377 | G = self.get_full_nx(idx) 378 | sub_node_inds = [] 379 | for n in G.nodes: 380 | c = np.array(G.nodes[n]['center_coord']).astype(float).reshape((1, -1)) 381 | if np.linalg.norm(coords - c, ord=2, axis=1).min() < 1e-2: 382 | sub_node_inds.append(n) 383 | assert len(sub_node_inds) == len(coords) 384 | _G = G.subgraph(sub_node_inds) 385 | 386 | node_colors = [self.cell_type_mapping[_G.nodes[n]['cell_type']] for n in _G.nodes] 387 | node_colors = [matplotlib.cm.tab20(ct) for ct in node_colors] 388 | plot_graph(_G, node_colors=node_colors) 389 | xmin, xmax = plt.gca().xaxis.get_data_interval() 390 | ymin, ymax = plt.gca().yaxis.get_data_interval() 391 | 392 | scale = max(x_c - xmin, xmax - x_c, y_c - ymin, ymax - y_c) * 1.05 393 | plt.xlim(x_c - scale, x_c + scale) 394 | plt.ylim(y_c - scale, y_c + scale) 395 | plt.plot([x_c], [y_c], 'x', markersize=5, color='k') 396 | 397 | def plot_graph_legend(self): 398 | """Legend for cell type colors""" 399 | plt.clf() 400 | plt.figure(figsize=(2, 2)) 401 | for ct, i in self.cell_type_mapping.items(): 402 | plt.plot([0], [0], '.', label=ct, color=matplotlib.cm.tab20(int(i) % 20)) 403 | plt.legend() 404 | plt.plot([0], [0], color='w', markersize=10) 405 | plt.axis('off') 406 | 407 | 408 | def k_hop_subgraph(node_ind, 409 | subgraph_size, 410 | edge_index, 411 | edge_type_mask=None, 412 | relabel_nodes=False, 413 | num_nodes=None): 414 | """A customized k-hop subgraph fn that filter for edge_type 415 | 416 | Args: 417 | node_ind (int): center node index 418 | subgraph_size (int): number of hops for the neighborhood subgraph 419 | edge_index (torch.Tensor): edge index tensor for the full graph 420 | edge_type_mask (torch.Tensor): edge type mask 421 | relabel_nodes (bool): if to relabel node indices to consecutive integers 422 | num_nodes (int): number of nodes in the full graph 423 | 424 | Returns: 425 | subset (torch.LongTensor): indices of nodes in the subgraph 426 | edge_index (torch.LongTensor): edges in the subgraph 427 | inv (toch.LongTensor): location of the center node in the subgraph 428 | edge_mask (torch.BoolTensor): edge mask indicating which edges were preserved 429 | """ 430 | 431 | num_nodes = edge_index.max().item() + 1 if num_nodes is None else num_nodes 
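# Added descriptive comment: the code below unpacks `edge_index`, allocates boolean node/edge
# masks, and then expands outward from `node_ind` for `subgraph_size` iterations. Each iteration
# marks the current frontier in `node_mask`, selects its incident edges via `edge_mask`, and
# collects their `col` endpoints; `edge_type_mask` (when provided) restricts which edges may be
# used to span the next hop, e.g. "neighbor" edges only. The union of all visited nodes then
# induces the subgraph edges that are returned at the end.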
432 | col, row = edge_index 433 | 434 | node_mask = row.new_empty(num_nodes, dtype=torch.bool) 435 | edge_mask = row.new_empty(row.size(0), dtype=torch.bool) 436 | edge_type_mask = torch.ones_like(edge_mask) if edge_type_mask is None else edge_type_mask 437 | 438 | if isinstance(node_ind, (int, list, tuple)): 439 | node_ind = torch.tensor([node_ind], device=row.device).flatten() 440 | else: 441 | node_ind = node_ind.to(row.device) 442 | 443 | subsets = [node_ind] 444 | next_root = node_ind 445 | 446 | for _ in range(subgraph_size): 447 | node_mask.fill_(False) 448 | node_mask[next_root] = True 449 | torch.index_select(node_mask, 0, row, out=edge_mask) 450 | subsets.append(col[edge_mask]) 451 | next_root = col[edge_mask * edge_type_mask] # use nodes connected with mask=True to span 452 | 453 | subset, inv = torch.cat(subsets).unique(return_inverse=True) 454 | inv = inv[:node_ind.numel()] 455 | 456 | node_mask.fill_(False) 457 | node_mask[subset] = True 458 | edge_mask = node_mask[row] & node_mask[col] 459 | 460 | edge_index = edge_index[:, edge_mask] 461 | 462 | if relabel_nodes: 463 | node_ind = row.new_full((num_nodes, ), -1) 464 | node_ind[subset] = torch.arange(subset.size(0), device=row.device) 465 | edge_index = node_ind[edge_index] 466 | 467 | return subset, edge_index, inv, edge_mask -------------------------------------------------------------------------------- /src/graph_build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import json 5 | import pickle 6 | import matplotlib 7 | import matplotlib.path as mplPath 8 | import matplotlib.pyplot as plt 9 | import networkx as nx 10 | import warnings 11 | from scipy.spatial import Delaunay 12 | from itertools import combinations 13 | 14 | RADIUS_RELAXATION = 0.1 15 | NEIGHBOR_EDGE_CUTOFF = 55 # distance cutoff for neighbor edges, 55 pixels~20 um 16 | 17 | 18 | def plot_voronoi_polygons(voronoi_polygons, voronoi_polygon_colors=None): 19 | """Plot voronoi polygons for the cellular graph 20 | 21 | Args: 22 | voronoi_polygons (nx.Graph/list): cellular graph or list of voronoi polygons 23 | voronoi_polygon_colors (list): list of colors for voronoi polygons 24 | """ 25 | if isinstance(voronoi_polygons, nx.Graph): 26 | voronoi_polygons = [voronoi_polygons.nodes[n]['voronoi_polygon'] for n in voronoi_polygons.nodes] 27 | 28 | if voronoi_polygon_colors is None: 29 | voronoi_polygon_colors = ['w'] * len(voronoi_polygons) 30 | assert len(voronoi_polygon_colors) == len(voronoi_polygons) 31 | 32 | xmax = 0 33 | ymax = 0 34 | for polygon, polygon_color in zip(voronoi_polygons, voronoi_polygon_colors): 35 | x, y = polygon[:, 0], polygon[:, 1] 36 | plt.fill(x, y, facecolor=polygon_color, edgecolor='k', linewidth=0.5) 37 | xmax = max(xmax, x.max()) 38 | ymax = max(ymax, y.max()) 39 | 40 | plt.xlim(0, xmax) 41 | plt.ylim(0, ymax) 42 | return 43 | 44 | 45 | def plot_graph(G, node_colors=None, cell_type=False): 46 | """Plot dot-line graph for the cellular graph 47 | 48 | Args: 49 | G (nx.Graph): full cellular graph of the region 50 | node_colors (list): list of node colors. Defaults to None. 
51 | """ 52 | # Extract basic node attributes 53 | node_coords = [G.nodes[n]['center_coord'] for n in G.nodes] 54 | node_coords = np.stack(node_coords, 0) 55 | 56 | if node_colors is None: 57 | unique_cell_types = sorted(set([G.nodes[n]['cell_type'] for n in G.nodes])) 58 | cell_type_to_color = {ct: matplotlib.cm.get_cmap("tab20")(i % 20) for i, ct in enumerate(unique_cell_types)} 59 | node_colors = [cell_type_to_color[G.nodes[n]['cell_type']] for n in G.nodes] 60 | assert len(node_colors) == node_coords.shape[0] 61 | 62 | for (i, j, edge_type) in G.edges.data(): 63 | xi, yi = G.nodes[i]['center_coord'] 64 | xj, yj = G.nodes[j]['center_coord'] 65 | if cell_type: 66 | plotting_kwargs = {"c": "k", 67 | "linewidth": 1, 68 | "linestyle": '-'} 69 | else: 70 | if edge_type['edge_type'] == 'neighbor': 71 | plotting_kwargs = {"c": "k", 72 | "linewidth": 1, 73 | "linestyle": '-'} 74 | else: 75 | plotting_kwargs = {"c": (0.4, 0.4, 0.4, 1.0), 76 | "linewidth": 0.3, 77 | "linestyle": '--'} 78 | plt.plot([xi, xj], [yi, yj], zorder=1, **plotting_kwargs) 79 | 80 | plt.scatter(node_coords[:, 0], 81 | node_coords[:, 1], 82 | s=10, 83 | c=node_colors, 84 | linewidths=0.3, 85 | zorder=2) 86 | plt.xlim(0, node_coords[:, 0].max() * 1.01) 87 | plt.ylim(0, node_coords[:, 1].max() * 1.01) 88 | return 89 | 90 | 91 | def load_cell_coords(cell_coords_file): 92 | """Load cell coordinates from file 93 | 94 | Args: 95 | cell_coords_file (str): path to csv file containing cell coordinates 96 | 97 | Returns: 98 | pd.DataFrame: dataframe containing cell coordinates, columns ['CELL_ID', 'X', 'Y'] 99 | """ 100 | df = pd.read_csv(cell_coords_file) 101 | df.columns = [c.upper() for c in df.columns] 102 | assert 'X' in df.columns, "Cannot find column for X coordinates" 103 | assert 'Y' in df.columns, "Cannot find column for Y coordinates" 104 | if 'CELL_ID' not in df.columns: 105 | warnings.warn("Cannot find column for cell id, using index as cell id") 106 | df['CELL_ID'] = df.index 107 | return df[['CELL_ID', 'X', 'Y']] 108 | 109 | 110 | def load_cell_types(cell_types_file): 111 | """Load cell types from file 112 | 113 | Args: 114 | cell_types_file (str): path to csv file containing cell types 115 | 116 | Returns: 117 | pd.DataFrame: dataframe containing cell types, columns ['CELL_ID', 'CELL_TYPE'] 118 | """ 119 | df = pd.read_csv(cell_types_file) 120 | df.columns = [c.upper() for c in df.columns] 121 | 122 | cell_type_column = [c for c in df.columns if c != 'CELL_ID'] 123 | if len(cell_type_column) == 1: 124 | cell_type_column = cell_type_column[0] 125 | elif 'CELL_TYPE' in cell_type_column: 126 | cell_type_column = 'CELL_TYPE' 127 | elif 'CELL_TYPES' in cell_type_column: 128 | cell_type_column = 'CELL_TYPES' 129 | else: 130 | raise ValueError("Please rename the column for cell type as 'CELL_TYPE'") 131 | 132 | if 'CELL_ID' not in df.columns: 133 | warnings.warn("Cannot find column for cell id, using index as cell id") 134 | df['CELL_ID'] = df.index 135 | _df = df[['CELL_ID', cell_type_column]] 136 | _df.columns = ['CELL_ID', 'CELL_TYPE'] # rename columns for clarity 137 | return _df 138 | 139 | 140 | def load_cell_biomarker_expression(cell_biomarker_expression_file): 141 | """Load cell biomarker expression from file 142 | 143 | Args: 144 | cell_biomarker_expression_file (str): path to csv file containing cell biomarker expression 145 | 146 | Returns: 147 | pd.DataFrame: dataframe containing cell biomarker expression, 148 | columns ['CELL_ID', 'BM-', 'BM-', ...] 
149 | """ 150 | df = pd.read_csv(cell_biomarker_expression_file) 151 | df.columns = [c.upper() for c in df.columns] 152 | biomarkers = sorted([c for c in df.columns if c != 'CELL_ID']) 153 | for bm in biomarkers: 154 | if df[bm].dtype not in [np.dtype(int), np.dtype(float), np.dtype('float64')]: 155 | warnings.warn("Skipping column %s as it is not numeric" % bm) 156 | biomarkers.remove(bm) 157 | 158 | if 'CELL_ID' not in df.columns: 159 | warnings.warn("Cannot find column for cell id, using index as cell id") 160 | df['CELL_ID'] = df.index 161 | _df = df[['CELL_ID'] + biomarkers] 162 | _df.columns = ['CELL_ID'] + ['BM-%s' % bm for bm in biomarkers] 163 | return _df 164 | 165 | 166 | def load_cell_features(cell_features_file): 167 | """Load additional cell features from file 168 | 169 | Args: 170 | cell_features_file (str): path to csv file containing additional cell features 171 | 172 | Returns: 173 | pd.DataFrame: dataframe containing cell features 174 | columns ['CELL_ID', '', '', ...] 175 | """ 176 | df = pd.read_csv(cell_features_file) 177 | df.columns = [c.upper() for c in df.columns] 178 | 179 | feature_columns = sorted([c for c in df.columns if c != 'CELL_ID']) 180 | for feat in feature_columns: 181 | if df[feat].dtype not in [np.dtype(int), np.dtype(float), np.dtype('float64')]: 182 | warnings.warn("Skipping column %s as it is not numeric" % feat) 183 | feature_columns.remove(feat) 184 | 185 | if 'CELL_ID' not in df.columns: 186 | warnings.warn("Cannot find column for cell id, using index as cell id") 187 | df['CELL_ID'] = df.index 188 | 189 | return df[['CELL_ID'] + feature_columns] 190 | 191 | 192 | def read_raw_voronoi(voronoi_file): 193 | """Read raw coordinates of voronoi polygons from file 194 | 195 | Args: 196 | voronoi_file (str): path to the voronoi polygon file 197 | 198 | Returns: 199 | voronoi_polygons (list): list of voronoi polygons, 200 | represented by the coordinates of their exterior vertices 201 | """ 202 | if voronoi_file.endswith('json'): 203 | with open(voronoi_file) as f: 204 | raw_voronoi_polygons = json.load(f) 205 | elif voronoi_file.endswith('.pkl'): 206 | with open(voronoi_file, 'rb') as f: 207 | raw_voronoi_polygons = pickle.load(f) 208 | 209 | voronoi_polygons = [] 210 | for i, polygon in enumerate(raw_voronoi_polygons): 211 | if isinstance(polygon, list): 212 | polygon = np.array(polygon).reshape((-1, 2)) 213 | elif isinstance(polygon, dict): 214 | assert len(polygon) == 1 215 | polygon = list(polygon.values())[0] 216 | polygon = np.array(polygon).reshape((-1, 2)) 217 | voronoi_polygons.append(polygon) 218 | return voronoi_polygons 219 | 220 | 221 | def calcualte_voronoi_from_coords(x, y, xmax=None, ymax=None): 222 | """Calculate voronoi polygons from a set of points 223 | 224 | Points are assumed to have coordinates in ([0, xmax], [0, ymax]) 225 | 226 | Args: 227 | x (array-like): x coordinates of points 228 | y (array-like): y coordinates of points 229 | xmax (float): maximum x coordinate 230 | ymax (float): maximum y coordinate 231 | 232 | Returns: 233 | voronoi_polygons (list): list of voronoi polygons, 234 | represented by the coordinates of their exterior vertices 235 | """ 236 | from geovoronoi import voronoi_regions_from_coords 237 | from shapely import geometry 238 | xmax = 1.01 * max(x) if xmax is None else xmax 239 | ymax = 1.01 * max(y) if ymax is None else ymax 240 | boundary = geometry.Polygon([[0, 0], [xmax, 0], [xmax, ymax], [0, ymax]]) 241 | coords = np.stack([ 242 | np.array(x).reshape((-1,)), 243 | np.array(y).reshape((-1,))], 1) 
244 | region_polys, _ = voronoi_regions_from_coords(coords, boundary) 245 | voronoi_polygons = [np.array(list(region_polys[k].exterior.coords)) for k in region_polys] 246 | return voronoi_polygons 247 | 248 | 249 | def build_graph_from_cell_coords(cell_data, voronoi_polygons): 250 | """Construct a networkx graph based on cell coordinates 251 | 252 | Args: 253 | cell_data (pd.DataFrame): dataframe containing cell data, 254 | columns ['CELL_ID', 'X', 'Y', ...] 255 | voronoi_polygons (list): list of voronoi polygons, 256 | represented by the coordinates of their exterior vertices 257 | 258 | Returns: 259 | G (nx.Graph): full cellular graph of the region 260 | """ 261 | save_polygon = True 262 | if not len(cell_data) == len(voronoi_polygons): 263 | warnings.warn("Number of cells does not match number of voronoi polygons") 264 | save_polygon = False 265 | 266 | coord_ar = np.array(cell_data[['CELL_ID', 'X', 'Y']]) 267 | G = nx.Graph() 268 | node_to_cell_mapping = {} 269 | for i, row in enumerate(coord_ar): 270 | vp = voronoi_polygons[i] if save_polygon else None 271 | G.add_node(i, voronoi_polygon=vp) 272 | node_to_cell_mapping[i] = row[0] 273 | 274 | dln = Delaunay(coord_ar[:, 1:3]) 275 | neighbors = [set() for _ in range(len(coord_ar))] 276 | for t in dln.simplices: 277 | for v in t: 278 | neighbors[v].update(t) 279 | 280 | for i, ns in enumerate(neighbors): 281 | for n in ns: 282 | G.add_edge(int(i), int(n)) 283 | 284 | return G, node_to_cell_mapping 285 | 286 | 287 | def build_graph_from_voronoi_polygons(voronoi_polygons, radius_relaxation=RADIUS_RELAXATION): 288 | """Construct a networkx graph based on voronoi polygons 289 | 290 | Args: 291 | voronoi_polygons (list): list of voronoi polygons, 292 | represented by the coordinates of their exterior vertices 293 | 294 | Returns: 295 | G (nx.Graph): full cellular graph of the region 296 | """ 297 | G = nx.Graph() 298 | 299 | polygon_vertices = [] 300 | vertice_identities = [] 301 | for i, polygon in enumerate(voronoi_polygons): 302 | G.add_node(i, voronoi_polygon=polygon) 303 | polygon_vertices.append(polygon) 304 | vertice_identities.append(np.ones((polygon.shape[0],)) * i) 305 | 306 | polygon_vertices = np.concatenate(polygon_vertices, 0) 307 | vertice_identities = np.concatenate(vertice_identities, 0).astype(int) 308 | for i, polygon in enumerate(voronoi_polygons): 309 | path = mplPath.Path(polygon) 310 | points_inside = np.where(path.contains_points(polygon_vertices, radius=radius_relaxation) + 311 | path.contains_points(polygon_vertices, radius=-radius_relaxation))[0] 312 | id_inside = set(vertice_identities[points_inside]) 313 | for j in id_inside: 314 | if j > i: 315 | G.add_edge(int(i), int(j)) 316 | return G 317 | 318 | 319 | def build_voronoi_polygon_to_cell_mapping(G, voronoi_polygons, cell_data): 320 | """Construct 1-to-1 mapping between voronoi polygons and cells 321 | 322 | Args: 323 | G (nx.Graph): full cellular graph of the region 324 | voronoi_polygons (list): list of voronoi polygons, 325 | represented by the coordinates of their exterior vertices 326 | cell_data (pd.DataFrame): dataframe containing cellular data 327 | 328 | Returns: 329 | voronoi_polygon_to_cell_mapping (dict): 1-to-1 mapping between 330 | polygon index (also node index in `G`) and cell id 331 | """ 332 | cell_coords = np.array(list(zip(cell_data['X'], cell_data['Y']))).reshape((-1, 2)) 333 | # Fetch all cells within each polygon 334 | cells_in_polygon = {} 335 | for i, polygon in enumerate(voronoi_polygons): 336 | path = mplPath.Path(polygon) 337 | 
_cell_ids = cell_data.iloc[np.where(path.contains_points(cell_coords))[0]] 338 | _cells = list(_cell_ids[['CELL_ID', 'X', 'Y']].values) 339 | cells_in_polygon[i] = _cells 340 | 341 | def get_point_reflection(c1, c2, c3): 342 | # Reflection of point c1 across line defined by c2 & c3 343 | x1, y1 = c1 344 | x2, y2 = c2 345 | x3, y3 = c3 346 | if x2 == x3: 347 | return (2 * x2 - x1, y1) 348 | m = (y3 - y2) / (x3 - x2) 349 | c = (x3 * y2 - x2 * y3) / (x3 - x2) 350 | d = (float(x1) + (float(y1) - c) * m) / (1 + m**2) 351 | x4 = 2 * d - x1 352 | y4 = 2 * d * m - y1 + 2 * c 353 | return (x4, y4) 354 | 355 | # Establish 1-to-1 mapping between polygons and cell ids 356 | voronoi_polygon_to_cell_mapping = {} 357 | for i, polygon in enumerate(voronoi_polygons): 358 | path = mplPath.Path(polygon) 359 | if len(cells_in_polygon[i]) == 1: 360 | # A single polygon contains a single cell centroid, assign cell id 361 | voronoi_polygon_to_cell_mapping[i] = cells_in_polygon[i][0][0] 362 | 363 | elif len(cells_in_polygon[i]) == 0: 364 | # Skipping polygons that do not contain any cell centroids 365 | continue 366 | 367 | else: 368 | # A single polygon contains multiple cell centroids 369 | polygon_edges = [(polygon[_i], polygon[_i + 1]) for _i in range(-1, len(polygon) - 1)] 370 | # Use the reflection of neighbor polygon's center cell 371 | neighbor_cells = sum([cells_in_polygon[j] for j in G.neighbors(i)], []) 372 | reflection_points = np.concatenate( 373 | [[get_point_reflection(cell[1:], edge[0], edge[1]) for edge in polygon_edges] 374 | for cell in neighbor_cells], 0) 375 | reflection_points = reflection_points[np.where(path.contains_points(reflection_points))] 376 | # Reflection should be very close to the center cell 377 | dists = [((reflection_points - c[1:])**2).sum(1).min(0) for c in cells_in_polygon[i]] 378 | if not np.min(dists) < 0.01: 379 | warnings.warn("Cannot find the exact center cell for polygon %d" % i) 380 | voronoi_polygon_to_cell_mapping[i] = cells_in_polygon[i][np.argmin(dists)][0] 381 | 382 | return voronoi_polygon_to_cell_mapping 383 | 384 | 385 | def assign_attributes(G, cell_data, node_to_cell_mapping, _assert=True): 386 | """Assign node and edge attributes to the cellular graph 387 | 388 | Args: 389 | G (nx.Graph): full cellular graph of the region 390 | cell_data (pd.DataFrame): dataframe containing cellular data 391 | node_to_cell_mapping (dict): 1-to-1 mapping between 392 | node index in `G` and cell id 393 | 394 | Returns: 395 | nx.Graph: populated cellular graph 396 | """ 397 | if _assert: 398 | assert set(G.nodes) == set(node_to_cell_mapping.keys()) 399 | biomarkers = sorted([c for c in cell_data.columns if c.startswith('BM-')]) 400 | 401 | additional_features = sorted([ 402 | c for c in cell_data.columns if c not in biomarkers + ['CELL_ID', 'X', 'Y', 'CELL_TYPE']]) 403 | 404 | cell_to_node_mapping = {v: k for k, v in node_to_cell_mapping.items()} 405 | node_properties = {} 406 | for _, cell_row in cell_data.iterrows(): 407 | cell_id = cell_row['CELL_ID'] 408 | if cell_id not in cell_to_node_mapping: 409 | continue 410 | node_index = cell_to_node_mapping[cell_id] 411 | p = {"cell_id": cell_id} 412 | p["center_coord"] = (cell_row['X'], cell_row['Y']) 413 | if "CELL_TYPE" in cell_row: 414 | p["cell_type"] = cell_row["CELL_TYPE"] 415 | else: 416 | p["cell_type"] = "Unassigned" 417 | biomarker_expression_dict = {bm.split('BM-')[1]: cell_row[bm] for bm in biomarkers} 418 | p["biomarker_expression"] = biomarker_expression_dict 419 | for feat_name in additional_features: 420 | 
p[feat_name] = cell_row[feat_name] 421 | node_properties[node_index] = p 422 | 423 | G = G.subgraph(node_properties.keys()) 424 | nx.set_node_attributes(G, node_properties) 425 | 426 | # Add distance, edge type (by thresholding) to edge feature 427 | edge_properties = get_edge_type(G) 428 | nx.set_edge_attributes(G, edge_properties) 429 | 430 | return G 431 | 432 | 433 | def get_edge_type(G, neighbor_edge_cutoff=NEIGHBOR_EDGE_CUTOFF): 434 | """Define neighbor vs distant edges based on distance 435 | 436 | Args: 437 | G (nx.Graph): full cellular graph of the region 438 | neighbor_edge_cutoff (float): distance cutoff for neighbor edges. 439 | By default we use 55 pixels (~20 um) 440 | 441 | Returns: 442 | dict: edge properties 443 | """ 444 | edge_properties = {} 445 | for (i, j) in G.edges: 446 | ci = G.nodes[i]['center_coord'] 447 | cj = G.nodes[j]['center_coord'] 448 | dist = np.linalg.norm(np.array(ci) - np.array(cj), ord=2) 449 | edge_properties[(i, j)] = { 450 | "distance": dist, 451 | "edge_type": "neighbor" if dist < neighbor_edge_cutoff else "distant" 452 | } 453 | ''' 454 | if ('center_coord' in G.nodes[i].keys()) & ('center_coord' in G.nodes[j].keys()): 455 | ci = G.nodes[i]['center_coord'] 456 | cj = G.nodes[j]['center_coord'] 457 | dist = np.linalg.norm(np.array(ci) - np.array(cj), ord=2) 458 | edge_properties[(i, j)] = { 459 | "distance": dist, 460 | "edge_type": "neighbor" if dist < neighbor_edge_cutoff else "distant" 461 | } 462 | else: 463 | G.remove_edge(i, j) 464 | # if 'center_coord' not in G.nodes[i].keys(): 465 | # G.remove_node(i) 466 | # else: 467 | # G.remove_node(j) 468 | 469 | # G.remove_edge(i, j) 470 | ''' 471 | return edge_properties 472 | 473 | 474 | def merge_cell_dataframes(df1, df2): 475 | """Merge two cell dataframes on shared rows (cells)""" 476 | if set(df2['CELL_ID']) != set(df1['CELL_ID']): 477 | warnings.warn("Cell ids in the two dataframes do not match") 478 | shared_cell_ids = set(df2['CELL_ID']).intersection(set(df1['CELL_ID'])) 479 | df1 = df1[df1['CELL_ID'].isin(shared_cell_ids)] 480 | df1 = df1.merge(df2, on='CELL_ID') 481 | return df1 482 | 483 | 484 | def construct_graph_for_region(region_id, 485 | cell_coords_file=None, 486 | cell_types_file=None, 487 | cell_biomarker_expression_file=None, 488 | cell_features_file=None, 489 | voronoi_file=None, 490 | graph_source='polygon', 491 | graph_output=None, 492 | voronoi_polygon_img_output=None, 493 | graph_img_output=None, 494 | common_cell_type_dict=None, 495 | common_biomarker_list=None, 496 | figsize=10): 497 | """Construct cellular graph for a region 498 | 499 | Args: 500 | region_id (str): region id 501 | cell_coords_file (str): path to csv file containing cell coordinates 502 | cell_types_file (str): path to csv file containing cell types/annotations 503 | cell_biomarker_expression_file (str): path to csv file containing cell biomarker expression 504 | cell_features_file (str): path to csv file containing additional cell features 505 | Note that features stored in this file can only be numeric and 506 | will be saved and used as is. 
507 |         voronoi_file (str): path to the voronoi coordinates file
508 |         graph_source (str): source of edges in the graph, either "polygon" or "cell"
509 |         graph_output (str): path for saving the multiplex network as a pickle file
510 |         voronoi_polygon_img_output (str): path for saving voronoi image
511 |         graph_img_output (list): two paths for saving dot-line graph images, one for the spatial layer and one for the cell type layer
512 |         figsize (int): figure size for plotting
513 | 
514 |     Returns:
515 |         multiplex_network (dict): {'layer_1': spatial cellular graph, 'layer_2': cell type graph}
516 |     """
517 |     assert cell_coords_file is not None, "cell coordinates must be provided"
518 |     cell_data = load_cell_coords(cell_coords_file)
519 | 
520 |     if voronoi_file is None:
521 |         # Calculate voronoi polygons based on cell coordinates
522 |         voronoi_polygons = calcualte_voronoi_from_coords(cell_data['X'], cell_data['Y'])
523 |     else:
524 |         # Load voronoi polygons from file
525 |         voronoi_polygons = read_raw_voronoi(voronoi_file)
526 | 
527 |     if cell_types_file is not None:
528 |         # Load cell types
529 |         cell_types = load_cell_types(cell_types_file)
530 |         if common_cell_type_dict is not None:
531 |             cell_types['CELL_TYPE'] = cell_types['CELL_TYPE'].map(common_cell_type_dict)
532 |         cell_data = merge_cell_dataframes(cell_data, cell_types)
533 | 
534 |     if cell_biomarker_expression_file is not None:
535 |         # Load cell biomarker expression
536 |         cell_expression = load_cell_biomarker_expression(cell_biomarker_expression_file)
537 |         if common_biomarker_list is not None:
538 |             columns_to_keep = ['CELL_ID'] + ['BM-' + suffix.upper() for suffix in common_biomarker_list if suffix != 'CELL_ID']
539 |             cell_expression = cell_expression[columns_to_keep]
540 |         cell_data = merge_cell_dataframes(cell_data, cell_expression)
541 | 
542 |     if cell_features_file is not None:
543 |         # Load additional cell features
544 |         additional_cell_features = load_cell_features(cell_features_file)
545 |         cell_data = merge_cell_dataframes(cell_data, additional_cell_features)
546 | 
547 |     if graph_source == 'polygon':
548 |         # Build initial cellular graph
549 |         G = build_graph_from_voronoi_polygons(voronoi_polygons)
550 |         # Construct matching between voronoi polygons and cells
551 |         node_to_cell_mapping = build_voronoi_polygon_to_cell_mapping(G, voronoi_polygons, cell_data)
552 |         # Prune graph to contain only voronoi polygons that have corresponding cells
553 |         G = G.subgraph(node_to_cell_mapping.keys())
554 |     elif graph_source == 'cell':
555 |         G, node_to_cell_mapping = build_graph_from_cell_coords(cell_data, voronoi_polygons)
556 |     else:
557 |         raise ValueError("graph_source must be either 'polygon' or 'cell'")
558 | 
559 |     # Assign attributes to cellular graph
560 |     G = assign_attributes(G, cell_data, node_to_cell_mapping)
561 |     G.region_id = region_id
562 | 
563 |     # Build Multiplex Network
564 |     cell_to_node_mapping = {value: key for key, value in node_to_cell_mapping.items()}
565 |     cell_type_graph = nx.Graph()
566 |     cell_type_graph.add_nodes_from(G.nodes(data=True))  # This copies nodes with their features
567 | 
568 |     grouped = cell_data.groupby('CELL_TYPE')
569 | 
570 |     for cell_type, group in grouped:
571 |         nodes = group['CELL_ID']
572 |         nodes = [cell_to_node_mapping[cell_id] for cell_id in nodes if cell_id in cell_to_node_mapping]  # Convert cell IDs to node indices
573 |         edges = combinations(nodes, 2)  # Create combinations of nodes
574 |         cell_type_graph.add_edges_from(edges)
575 | 
576 |     # Assign attributes to cell type graph
577 |     cell_type_graph = assign_attributes(cell_type_graph, cell_data, node_to_cell_mapping, _assert=False)
578 |     cell_type_graph.region_id = region_id
579 | 
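    # Assemble the two-layer multiplex network: layer_1 is the spatial cellular graph,
    # layer_2 links cells that share a cell type; both layers share the same node set and attributes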
580 |     multiplex_network = {
581 |         'layer_1': G,
582 |         'layer_2': cell_type_graph
583 |     }
584 | 
585 |     # Visualization of cellular graph
586 |     if voronoi_polygon_img_output is not None:
587 |         plt.clf()
588 |         plt.figure(figsize=(figsize, figsize))
589 |         plot_voronoi_polygons(G)
590 |         plt.axis('scaled')
591 |         plt.savefig(voronoi_polygon_img_output, dpi=300, bbox_inches='tight')
592 | 
593 |     if graph_img_output is not None:
594 |         plt.clf()
595 |         plt.figure(figsize=(figsize, figsize))
596 |         plot_graph(G)
597 |         plt.axis('scaled')
598 |         plt.savefig(graph_img_output[0], dpi=300, bbox_inches='tight')
599 |         plt.clf()
600 |         plt.figure(figsize=(figsize, figsize))
601 |         plot_graph(cell_type_graph, cell_type=True)
602 |         plt.axis('scaled')
603 |         plt.savefig(graph_img_output[1], dpi=300, bbox_inches='tight')
604 | 
605 |     # Save graph to file
606 |     if graph_output is not None:
607 |         with open(graph_output, 'wb') as f:
608 |             pickle.dump(multiplex_network, f)
609 | 
610 |     return multiplex_network
611 | 
612 | 
613 | if __name__ == "__main__":
614 |     raw_data_root = "data/voronoi/"
615 |     nx_graph_root = "data/example_dataset/graph"
616 |     fig_save_root = "data/example_dataset/fig"
617 |     os.makedirs(nx_graph_root, exist_ok=True)
618 |     os.makedirs(fig_save_root, exist_ok=True)
619 | 
620 |     region_ids = sorted(set(f.split('.')[0] for f in os.listdir(raw_data_root)))
621 | 
622 |     for region_id in region_ids:
623 |         print("Processing %s" % region_id)
624 |         cell_coords_file = os.path.join(raw_data_root, "%s.cell_data.csv" % region_id)
625 |         cell_types_file = os.path.join(raw_data_root, "%s.cell_types.csv" % region_id)
626 |         cell_biomarker_expression_file = os.path.join(raw_data_root, "%s.expression.csv" % region_id)
627 |         cell_features_file = os.path.join(raw_data_root, "%s.cell_features.csv" % region_id)
628 |         voronoi_file = os.path.join(raw_data_root, "%s.json" % region_id)
629 | 
630 |         voronoi_img_output = os.path.join(fig_save_root, "%s_voronoi.png" % region_id)
631 |         graph_img_output = [os.path.join(fig_save_root, "%s_graph.png" % region_id), os.path.join(fig_save_root, "%s_cell_type_graph.png" % region_id)]  # one image per layer, as expected by construct_graph_for_region
632 |         graph_output = os.path.join(nx_graph_root, "%s.gpkl" % region_id)
633 | 
634 |         if not os.path.exists(graph_output):
635 |             multiplex_network = construct_graph_for_region(
636 |                 region_id,
637 |                 cell_coords_file=cell_coords_file,
638 |                 cell_types_file=cell_types_file,
639 |                 cell_biomarker_expression_file=cell_biomarker_expression_file,
640 |                 cell_features_file=cell_features_file,
641 |                 voronoi_file=voronoi_file,
642 |                 graph_output=graph_output,
643 |                 voronoi_polygon_img_output=voronoi_img_output,
644 |                 graph_img_output=graph_img_output,
645 |                 figsize=10)
646 | 
--------------------------------------------------------------------------------
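For a quick sanity check of the preprocessing output, the minimal sketch below loads one of the pickled multiplex networks written by the script above and inspects both layers. The file name is illustrative (actual names follow the region ids found under `data/voronoi/`).

```
import pickle

# Illustrative path; replace region_0 with an actual region id
graph_path = "data/example_dataset/graph/region_0.gpkl"

with open(graph_path, "rb") as f:
    multiplex_network = pickle.load(f)

spatial_layer = multiplex_network["layer_1"]    # Delaunay/Voronoi cellular graph
cell_type_layer = multiplex_network["layer_2"]  # same nodes, edges among cells of the same type

print(spatial_layer.region_id, spatial_layer.number_of_nodes())
print(spatial_layer.number_of_edges(), cell_type_layer.number_of_edges())

# Each node carries cell_id, center_coord, cell_type, biomarker_expression and any extra features
first_node = next(iter(spatial_layer.nodes))
print(spatial_layer.nodes[first_node].keys())
```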