├── assets
│   └── Model.png
├── requirements.txt
├── sh
│   ├── upmc_hm.sh
│   ├── charville_hm.sh
│   ├── dfci_general.sh
│   ├── charville_bc.sh
│   └── upmc_bc.sh
├── src
│   ├── __init__.py
│   ├── inference.py
│   ├── precomputing.py
│   ├── models.py
│   ├── train.py
│   ├── utils.py
│   ├── transform.py
│   ├── features.py
│   ├── data.py
│   └── graph_build.py
├── dataset
│   └── README.md
├── README.md
└── main.py
/assets/Model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UNITES-Lab/Mew/HEAD/assets/Model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.26.4
2 | pandas==2.2.2
3 | torch==2.3.1
4 | torch-geometric==2.5.3
5 | scikit-learn==1.4.2
6 | scipy==1.13.1
7 | tqdm==4.66.4
8 | torch-scatter
9 | lifelines
--------------------------------------------------------------------------------
/sh/upmc_hm.sh:
--------------------------------------------------------------------------------
1 | device=4
2 |
3 | data=upmc
4 | task=cox
5 | shared=True
6 | attn_weight=True
7 | lr=0.001
8 | num_layers=1
9 | emb_dim=256
10 | batch_size=16
11 | drop_ratio=0.0
12 |
13 | python main.py \
14 | --data $data \
15 | --task $task \
16 | --shared $shared \
17 | --attn_weight $attn_weight \
18 | --lr $lr \
19 | --num_layers $num_layers \
20 | --emb_dim $emb_dim \
21 | --batch_size $batch_size \
22 | --drop_ratio $drop_ratio \
23 | --device $device
24 |
--------------------------------------------------------------------------------
/sh/charville_hm.sh:
--------------------------------------------------------------------------------
1 | device=1
2 |
3 | data=charville
4 | task=cox
5 | shared=False
6 | attn_weight=False
7 | lr=0.0001
8 | num_layers=4
9 | emb_dim=128
10 | batch_size=32
11 | drop_ratio=0.5
12 |
13 | python main.py \
14 | --data $data \
15 | --task $task \
16 | --shared $shared \
17 | --attn_weight $attn_weight \
18 | --lr $lr \
19 | --num_layers $num_layers \
20 | --emb_dim $emb_dim \
21 | --batch_size $batch_size \
22 | --drop_ratio $drop_ratio \
23 | --device $device
24 |
--------------------------------------------------------------------------------
/sh/dfci_general.sh:
--------------------------------------------------------------------------------
1 | device=2
2 |
3 | data=dfci
4 | task=classification
5 | shared=False
6 | attn_weight=True
7 | lr=0.001
8 | num_layers=2
9 | emb_dim=128
10 | batch_size=16
11 | drop_ratio=0.25
12 |
13 | python main.py \
14 | --data $data \
15 | --task $task \
16 | --shared $shared \
17 | --attn_weight $attn_weight \
18 | --lr $lr \
19 | --num_layers $num_layers \
20 | --emb_dim $emb_dim \
21 | --batch_size $batch_size \
22 | --drop_ratio $drop_ratio \
23 | --device $device
24 |
--------------------------------------------------------------------------------
/sh/charville_bc.sh:
--------------------------------------------------------------------------------
1 | device=0
2 |
3 | data=charville
4 | task=classification
5 | shared=False
6 | attn_weight=False
7 | lr=0.0001
8 | num_layers=4
9 | emb_dim=512
10 | batch_size=16
11 | drop_ratio=0.5
12 |
13 | python main.py \
14 | --data $data \
15 | --task $task \
16 | --shared $shared \
17 | --attn_weight $attn_weight \
18 | --lr $lr \
19 | --num_layers $num_layers \
20 | --emb_dim $emb_dim \
21 | --batch_size $batch_size \
22 | --drop_ratio $drop_ratio \
23 | --device $device
24 |
--------------------------------------------------------------------------------
/sh/upmc_bc.sh:
--------------------------------------------------------------------------------
1 | device=3
2 |
3 | data=upmc
4 | task=classification
5 | shared=True
6 | attn_weight=True
7 | pool=mean
8 | lr=0.0001
9 | num_layers=1
10 | emb_dim=256
11 | batch_size=16
12 | drop_ratio=0.0
13 |
14 | python main.py \
15 | --data $data \
16 | --task $task \
17 | --shared $shared \
18 | --attn_weight $attn_weight \
19 | --pool $pool \
20 | --lr $lr \
21 | --num_layers $num_layers \
22 | --emb_dim $emb_dim \
23 | --batch_size $batch_size \
24 | --drop_ratio $drop_ratio \
25 | --device $device
26 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from src.graph_build import plot_graph, plot_voronoi_polygons, construct_graph_for_region
2 | from src.data import CellularGraphDataset
3 | from src.models import SIGN_pred
4 | from src.transform import (
5 | FeatureMask,
6 | AddCenterCellBiomarkerExpression,
7 | AddCenterCellType,
8 | AddCenterCellIdentifier,
9 | AddGraphLabel,
10 | AddTwoGraphLabel
11 | )
12 | from src.inference import collect_predict_for_all_nodes
13 | from src.train import train_full_graph
14 | from src.precomputing import PrecomputingBase
15 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | ## **Dataset**
2 |
3 | 1. Please download the following raw data from [Enable Medicine](https://app.enablemedicine.com/portal/atlas-library/studies/92394a9f-6b48-4897-87de-999614952d94?sid=1168):
4 | - UPMC-HNC: `upmc_raw_data.zip` and `upmc_labels.csv`
5 | - Stanford-CRC: `charville_raw_data.zip` and `charville_labels.csv`
6 | - DFCI-HNC: `dfci_raw_data.zip` and `dfci_labels.csv`
7 |
8 | 2. After the download is complete, please place the files in the following structure:
9 | ```
10 | dataset/
11 | ├── charville_data
12 | │   ├── charville_raw_data.zip
13 | │   └── charville_labels.csv
14 | ├── upmc_data
15 | │   ├── upmc_raw_data.zip
16 | │   └── upmc_labels.csv
17 | ├── dfci_data
18 | │   └── dfci_labels.csv
19 | └── general_data
20 |     ├── upmc_raw_data.zip
21 |     └── dfci_raw_data.zip
22 | ```
23 |
24 | 3. Running a dataset with a given task triggers the preprocessing automatically (generating the graphs can take a few hours). The resulting directory structure for the charville data will be as follows:
25 | ```
26 | dataset/
27 | └── charville_data
28 |     ├── charville_raw_data.zip
29 |     ├── charville_labels.csv
30 |     ├── dataset_mew
31 |     │   ├── fig
32 |     │   ├── graph
33 |     │   ├── model
34 |     │   └── tg_graph
35 |     └── raw_data
36 | ```
37 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from tqdm import trange
4 | from lifelines.utils import concordance_index
5 | from sklearn.metrics import roc_auc_score
6 |
7 | def collect_predict_for_all_nodes(model,
8 | xs,
9 | device,
10 | inds=None,
11 | graph=False,
12 | **kwargs):
13 |
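    # Run the model once per region graph in `inds`, collecting graph-level predictions and the mean
    # attention weights assigned to the Voronoi and cell-type layers (node_results is returned but left empty).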
14 | model = model.to(device)
15 | model.eval()
16 |
17 | node_results = {}
18 | graph_results = {}
19 | voronoi_attn_results = {}
20 | cell_type_attn_results = {}
21 |
22 | start_time = time.time()
23 | for i in inds:
24 | input_1 = [_x.to(device) for _x in xs[i][0]]
25 | input_2 = [_x.to(device) for _x in xs[i][1]]
26 | res, attn_values = model([input_1, input_2])
27 |
28 | voronoi_attn = attn_values[:,0].mean().item()
29 | cell_type_attn = attn_values[:,1].mean().item()
30 |
31 | graph_results[i] = res[0].cpu().data.numpy()
32 | voronoi_attn_results[i] = voronoi_attn
33 | cell_type_attn_results[i] = cell_type_attn
34 |
35 | # print(f'Prediction for {len(inds)} graphs:', time.time() - start_time)
36 |
37 | return node_results, graph_results, voronoi_attn_results, cell_type_attn_results
38 |
39 |
40 | def graph_classification_evaluate_fn(graph_preds,
41 | graph_ys,
42 | graph_ws=None,
43 | task='classification',
44 | print_res=True):
45 | """ Evaluate graph classification accuracy
46 |
47 | Args:
48 | graph_preds (array-like): binary classification logits for graph-level tasks, (num_subgraphs, num_tasks)
49 | graph_ys (array-like): binary labels for graph-level tasks, (num_subgraphs, num_tasks)
50 | graph_ws (array-like): weights for graph-level tasks, (num_subgraphs, num_tasks)
51 |         print_res (bool): whether to print the accuracy results
52 |
53 | Returns:
54 | list: list of metrics on all graph-level tasks
55 | """
56 | if graph_ws is None:
57 | graph_ws = np.ones_like(graph_ys)
58 | scores = []
59 | if task != 'classification':
60 | for task_i in trange(graph_preds.shape[1]):
61 | _pred = graph_preds[:, task_i]
62 | _times = graph_ys[:, task_i]
63 | _observed = graph_ws[:, task_i]
64 | idx = _times == _times
65 | s = concordance_index(_observed[idx], _pred[idx], _times[idx])
66 | scores.append(s)
67 | else:
68 | for task_i in trange(graph_ys.shape[1]):
69 | _label = graph_ys[:, task_i]
70 | _pred = graph_preds[:, task_i]
71 | _w = graph_ws[:, task_i]
72 | s = roc_auc_score(_label[np.where(_w > 0)], _pred[np.where(_w > 0)])
73 | scores.append(s)
74 | if print_res:
75 | print("GRAPH %s" % str(scores))
76 | return scores
77 |
78 |
79 | def full_graph_graph_classification_evaluate_fn(
80 | dataset_yw,
81 | graph_results,
82 | task,
83 | aggr='mean',
84 | print_res=True):
85 |
86 | n_tasks = list(graph_results.values())[0].shape[1]
87 | graph_preds = []
88 | graph_ys = []
89 | graph_ws = []
90 | for i in graph_results:
91 | graph_pred = [p for p in graph_results[i] if ((p is not None) and np.all(p == p))]
92 | graph_pred = np.stack(graph_pred, 0)
93 |
94 | if aggr == 'mean':
95 | graph_pred = np.nanmean(graph_pred, 0)
96 | else:
97 | raise NotImplementedError("Only mean-aggregation is supported now")
98 |
99 | graph_y = dataset_yw[i][0].numpy().flatten()
100 | graph_w = dataset_yw[i][1].numpy().flatten()
101 | graph_preds.append(graph_pred)
102 | graph_ys.append(graph_y)
103 | graph_ws.append(graph_w)
104 |
105 | graph_preds = np.concatenate(graph_preds, 0).reshape((-1, n_tasks))
106 | graph_ys = np.concatenate(graph_ys, 0).reshape((-1, n_tasks))
107 | graph_ws = np.concatenate(graph_ws, 0).reshape((-1, n_tasks))
108 |
109 | return graph_classification_evaluate_fn(graph_preds, graph_ys, graph_ws, task, print_res=print_res)
110 |
--------------------------------------------------------------------------------
/src/precomputing.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch_geometric.transforms import SIGN
4 | from tqdm import trange
5 | from torch_geometric.data import Data
6 | from torch_geometric.transforms import BaseTransform
7 | import torch.distributions as dist
8 | from torch_geometric.utils import add_remaining_self_loops
9 |
10 | class PrecomputingBase(torch.nn.Module):
11 | def __init__(self, num_layer, device):
12 | super(PrecomputingBase, self).__init__()
13 | self.device = device
14 | self.num_layers = num_layer
15 |
16 | def precompute(self, data):
17 | print("precomputing features, may take a while.")
18 | t1 = time.time()
19 | print('Start Precomputing...!')
20 |
21 | self.xs = []
22 | for i in trange(len(data)):
23 | _data = data[i]
24 | # Geom data
25 | geom_data = Data()
26 | geom_data.x = _data['cell'].x
27 | geom_data.edge_index = _data['cell', 'geom', 'cell'].edge_index
28 | geom_data.edge_attr = _data['cell', 'geom', 'cell'].edge_attr
29 | geom_data.edge_index , geom_data.edge_attr = add_remaining_self_loops(geom_data.edge_index, geom_data.edge_attr)
30 | geom_data = geom_data.to(self.device)
31 | _geom_x = SIGN(self.num_layers)(geom_data, stochastic=False)
32 | geom_x = [_geom_x.x.cpu()] + [_geom_x[f"x{i}"] for i in range(1, self.num_layers + 1)]
33 |
34 | # Cell-type data
35 | cell_type_data = Data()
36 | cell_type_data.x = _data['cell'].x
37 | cell_type_data.edge_index = _data['cell', 'type', 'cell'].edge_index
38 | cell_type_data.edge_attr = _data['cell', 'type', 'cell'].edge_attr
39 | cell_type_data.edge_index , cell_type_data.edge_attr = add_remaining_self_loops(cell_type_data.edge_index, cell_type_data.edge_attr)
40 | cell_type_data = cell_type_data.to(self.device)
41 | _cell_type_x = SIGN(self.num_layers)(cell_type_data, stochastic=True)
42 | cell_type_x = [_cell_type_x.x.cpu()] + [_cell_type_x[f"x{i}"] for i in range(1, self.num_layers + 1)]
43 |
44 | self.xs.append([geom_x, cell_type_x])
45 |
46 | t2 = time.time()
47 | print("Precomputing finished by %.4f s." % (t2 - t1))
48 |
49 | return self.xs
50 |
51 | def forward(self, xs):
52 | raise NotImplementedError
53 |
54 |
55 | class SIGN(BaseTransform):
56 | def __init__(self, K):
57 | self.K = K
58 |
59 | def __call__(self, data, stochastic=False):
60 | assert data.edge_index is not None
61 | assert data.x is not None
62 | if not stochastic:
63 | weight = torch.ones_like(data.edge_attr[:,1])
64 | _edge_weight = data.edge_attr.sum(dim=1)
65 |
66 | else:
67 | weight = data.edge_attr[:,1]
68 | weight = 1 - weight
69 | _edge_weight = torch.ones_like(data.edge_attr[:,1])
70 |
71 | xs = [data.x]
72 | for i in range(1, self.K + 1):
73 | adj_t = self.stochastic_adj(data, weight, _edge_weight=_edge_weight)
74 | xs += [adj_t @ xs[-1]]
75 | data[f'x{i}'] = xs[-1].cpu()
76 |
77 | return data
78 |
79 | def __repr__(self) -> str:
80 | return f'{self.__class__.__name__}(K={self.K})'
81 |
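    # Build a symmetrically normalized sparse adjacency (D^-1/2 A D^-1/2) over the edges kept by a
    # Bernoulli sample of `weight`; for the deterministic (geometric) view `weight` is all ones,
    # so no edge is dropped.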
82 | def stochastic_adj(self, data, weight, _edge_weight=None):
83 | bernoulli_dist = dist.Bernoulli(weight)
84 | mask = bernoulli_dist.sample().bool()
85 |
86 | N = data.num_nodes
87 | N_edges = data.edge_index.shape[1]
88 |
89 |         if _edge_weight is None:
90 | edge_attr = torch.ones((N_edges,), device=data.device)
91 | else:
92 | edge_attr = _edge_weight[mask]
93 |
94 | edge_index = data.edge_index[:,mask]
95 | row, col = edge_index
96 |
97 | deg = torch.zeros(N, device=row.device, dtype=edge_attr.dtype)
98 | deg.scatter_add_(dim=0, index=col, src=edge_attr)
99 |
100 | deg_inv_sqrt = deg.pow_(-0.5)
101 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
102 | edge_weight = deg_inv_sqrt[row] * edge_attr * deg_inv_sqrt[col]
103 | adj = torch.sparse_coo_tensor(edge_index, edge_weight, size=torch.Size([N, N]))
104 | adj_t = adj.t()
105 |
106 | return adj_t
107 |
108 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mew: Multiplexed Immunofluorescence Image Analysis through an Efficient Multiplex Network
2 |
3 | [MIT License](https://opensource.org/licenses/MIT) [ECCV 2024](https://eccv.ecva.net/)
4 |
5 | Official implementation of "Mew: Multiplexed Immunofluorescence Image Analysis through an Efficient Multiplex Network", accepted at ECCV 2024. Our implementation is based on [SPACE-GM](https://gitlab.com/enable-medicine-public/space-gm).
6 |
7 | - Authors: [Sukwon Yun](https://sukwonyun.github.io/), [Jie Peng](https://openreview.net/profile?id=~Jie_Peng4), [Alexandro E. Trevino](https://scholar.google.com/citations?user=z7HDsuAAAAAJ&hl=en), [Chanyoung Park](https://dsail.kaist.ac.kr/professor/) and [Tianlong Chen](https://tianlong-chen.github.io/)
8 | - Paper: [arXiv](https://arxiv.org/abs/2407.17857)
9 |
10 | ## Overview
11 |
12 | Recent advancements in graph-based approaches for multiplexed immunofluorescence (mIF) images have significantly propelled the field forward, offering deeper insights into patient-level phenotyping. However, current graph-based methodologies encounter two primary challenges: (1) Cellular Heterogeneity, where existing approaches fail to adequately address the inductive biases inherent in graphs, particularly the homophily characteristic observed in cellular connectivity; and (2) Scalability, where handling cellular graphs derived from high-dimensional images is difficult due to the large number of cells. To overcome these limitations, we introduce Mew, a novel framework designed to efficiently process mIF images through the lens of multiplex networks. Mew innovatively constructs a multiplex network comprising two distinct layers: a Voronoi network for geometric information and a Cell-type network for capturing cell-wise homogeneity. This framework is equipped with a scalable and efficient Graph Neural Network (GNN) capable of processing the entire graph during training. Furthermore, Mew integrates an interpretable attention module that autonomously identifies relevant layers for image classification. Extensive experiments on real-world patient datasets from various institutions highlight Mew's remarkable efficacy and efficiency, marking a significant advancement in mIF image analysis.
13 |
14 | <img src="./assets/Model.png" width="100%" />
15 |
16 |
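The layer-attention fusion described above is implemented in `SIGN_pred.forward` (src/models.py). Below is a minimal illustrative sketch; the dimensions and inputs are placeholders, the two embeddings in the model come from the SIGN branches over the Voronoi and cell-type layers, and with `--attn_weight True` each embedding is first passed through a learned linear map:

```python
import torch
import torch.nn as nn

emb_dim = 256                                  # illustrative embedding size
attention = nn.Parameter(torch.empty(emb_dim, 1))
nn.init.xavier_uniform_(attention)             # glorot init, as in SIGN_pred
leakyrelu = nn.LeakyReLU(0.3)

h_geom = torch.randn(100, emb_dim)             # node embeddings from the Voronoi (geometric) layer
h_type = torch.randn(100, emb_dim)             # node embeddings from the cell-type layer

scores = torch.cat((leakyrelu(h_geom @ attention),
                    leakyrelu(h_type @ attention)), dim=1)
alpha = torch.softmax(scores, dim=1)           # per-node weights over the two layers
h = alpha[:, :1] * h_geom + alpha[:, 1:] * h_type   # fused node representation, pooled for prediction
```
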
17 | ## **Setup**
18 |
19 | ```
20 | conda create -n mew python=3.10 -y && conda activate mew
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## **Dataset**
25 |
26 | 1. Please download the following raw data from [Enable Medicine](https://app.enablemedicine.com/portal/atlas-library/studies/92394a9f-6b48-4897-87de-999614952d94?sid=1168):
27 | - UPMC-HNC: `upmc_raw_data.zip` and `upmc_labels.csv`
28 | - Stanford-CRC: `charville_raw_data.zip` and `charville_labels.csv`
29 | - DFCI-HNC: `dfci_raw_data.zip` and `dfci_labels.csv`
30 |
31 | 2. After the download is complete, please place the files in the following structure:
32 | ```
33 | dataset/
34 | ├── charville_data
35 | │   ├── charville_raw_data.zip
36 | │   └── charville_labels.csv
37 | ├── upmc_data
38 | │   ├── upmc_raw_data.zip
39 | │   └── upmc_labels.csv
40 | ├── dfci_data
41 | │   └── dfci_labels.csv
42 | └── general_data
43 |     ├── upmc_raw_data.zip
44 |     └── dfci_raw_data.zip
45 | ```
46 |
47 | 3. Running a dataset with a given task triggers the preprocessing automatically (generating the graphs can take a few hours). The resulting directory structure for the charville data will be as follows:
48 | ```
49 | dataset/
50 | └── charville_data
51 |     ├── charville_raw_data.zip
52 |     ├── charville_labels.csv
53 |     ├── dataset_mew
54 |     │   ├── fig
55 |     │   ├── graph
56 |     │   ├── model
57 |     │   └── tg_graph
58 |     └── raw_data
59 | ```
60 |
61 | ## **Usage and Example**
63 | 1. Choose a dataset (upmc, charville, dfci) and a task (classification, cox), and pass the hyperparameters. For the UPMC-HM (Hazard Modeling) task, the command is as follows:
63 | ```
64 | python main.py \
65 | --data upmc \
66 | --task cox \
67 | --shared True \
68 | --attn_weight True \
69 | --lr 0.001 \
70 | --num_layers 1 \
71 | --emb_dim 256 \
72 | --batch_size 16 \
73 | --drop_ratio 0.0 \
74 | --device 0
75 | ```
76 |
77 | 2. For the other benchmarks, please use the provided scripts:
78 | - **UPMC-BC**: `sh ./sh/upmc_bc.sh`
79 | - **UPMC-HM**: `sh ./sh/upmc_hm.sh`
80 | - **Charville-BC**: `sh ./sh/charville_bc.sh`
81 | - **Charville-HM**: `sh ./sh/charville_hm.sh`
82 | - **DFCI-Generalization**: `sh ./sh/dfci_general.sh`
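
3. (Optional) With `--save_model True` (the default), the best model per fold is saved by `save_model_weight` to `./ckpts/Mew_{data}_{task}_{fold}.pt`. Below is a minimal sketch for reloading such a checkpoint; the keyword arguments are placeholders and must match the hyperparameters used at training time, in particular `num_feat` (which depends on the dataset's node features) and `num_graph_tasks` (which depends on the task):

```python
import torch
from src.models import SIGN_pred

# Hypothetical example for a UPMC cox run (fold 0); adjust to your own training configuration.
model = SIGN_pred(num_layer=1, num_feat=38, emb_dim=256,
                  num_node_tasks=0, num_graph_tasks=1,
                  graph_pooling='mean', attn_weight=True, shared=True)
model.load_state_dict(torch.load('./ckpts/Mew_upmc_cox_0.pt', map_location='cpu'))
model.eval()
```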
83 |
84 |
85 | ## Citation
86 |
87 | ```bibtex
88 | @misc{yun2024mew,
89 | title={Mew: Multiplexed Immunofluorescence Image Analysis through an Efficient Multiplex Network},
90 | author={Sukwon Yun and Jie Peng and Alexandro E. Trevino and Chanyoung Park and Tianlong Chen},
91 | year={2024},
92 | eprint={2407.17857},
93 | archivePrefix={arXiv},
94 | primaryClass={cs.CV},
95 | url={https://arxiv.org/abs/2407.17857},
96 | }
97 | ```
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch_scatter import scatter_add
7 | from torch_geometric.nn.inits import glorot, zeros
8 |
9 | class FeedForwardNet(nn.Module):
10 | def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
11 | super(FeedForwardNet, self).__init__()
12 | self.layers = nn.ModuleList()
13 | self.n_layers = n_layers
14 | if n_layers == 1:
15 | self.layers.append(nn.Linear(in_feats, out_feats))
16 | else:
17 | self.layers.append(nn.Linear(in_feats, hidden))
18 | for i in range(n_layers - 2):
19 | self.layers.append(nn.Linear(hidden, hidden))
20 | self.layers.append(nn.Linear(hidden, out_feats))
21 | if self.n_layers > 1:
22 | self.prelu = nn.PReLU()
23 | self.dropout = nn.Dropout(dropout)
24 | self.reset_parameters()
25 |
26 | def reset_parameters(self):
27 | gain = nn.init.calculate_gain("relu")
28 | for layer in self.layers:
29 | nn.init.xavier_uniform_(layer.weight, gain=gain)
30 | nn.init.zeros_(layer.bias)
31 |
32 | def forward(self, x):
33 | for layer_id, layer in enumerate(self.layers):
34 | x = layer(x)
35 | if layer_id < self.n_layers - 1:
36 | x = self.dropout(self.prelu(x))
37 | return x
38 |
39 | class SIGN_v2(nn.Module):
40 | def __init__(self, num_layer, num_feat, emb_dim, ffn_layers=2, dropout=0.25):
41 | super(SIGN_v2, self).__init__()
42 |
43 | in_feats = num_feat
44 | emb_dim = emb_dim
45 | out_feats = emb_dim
46 | num_hops = num_layer + 1
47 |
48 | self.dropout = nn.Dropout(dropout)
49 | self.prelu = nn.PReLU()
50 | self.inception_ffs = nn.ModuleList()
51 |
52 | self.batch_norms = torch.nn.ModuleList()
53 | for hop in range(num_hops):
54 | self.inception_ffs.append(
55 | FeedForwardNet(in_feats, emb_dim, emb_dim, ffn_layers, dropout))
56 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
57 | self.project = FeedForwardNet(num_hops * emb_dim, emb_dim, out_feats,
58 | ffn_layers, dropout)
59 |
60 | def forward(self, feats):
61 | feats = [feat.float() for feat in feats]
62 | hidden = []
63 | for i, (feat, ff) in enumerate(zip(feats, self.inception_ffs)):
64 | emb = ff(feat)
65 | hidden.append(self.batch_norms[i](emb))
66 | out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1))))
67 | return out
68 |
69 | def reset_parameters(self):
70 | for ff in self.inception_ffs:
71 | ff.reset_parameters()
72 | self.project.reset_parameters()
73 |
74 |
75 | class SIGN_pred(torch.nn.Module):
76 | def __init__(self,
77 | num_layer=2,
78 | num_feat=38,
79 | emb_dim=256,
80 | num_additional_feat=0,
81 | num_node_tasks=15,
82 | num_graph_tasks=2,
83 | node_embedding_output="last",
84 | drop_ratio=0,
85 | graph_pooling="mean",
86 | attn_weight=False,
87 | shared=False,
88 | ):
89 | super(SIGN_pred, self).__init__()
90 | self.drop_ratio = drop_ratio
91 | self.emb_dim = emb_dim
92 | self.num_node_tasks = num_node_tasks
93 | self.num_graph_tasks = num_graph_tasks
94 | self.attn_weight = attn_weight
95 |
96 | self.sign = SIGN_v2(num_layer, num_feat, emb_dim, dropout=drop_ratio)
97 | self.sign2 = self.sign if shared else SIGN_v2(num_layer, num_feat, emb_dim, dropout=drop_ratio)
98 | self.leakyrelu = nn.LeakyReLU(0.3)
99 | self.attention = nn.Parameter(torch.empty(size=(emb_dim, 1)))
100 | glorot(self.attention)
101 | if self.attn_weight:
102 | self.w1 = torch.nn.Linear(emb_dim, emb_dim)
103 | self.w2 = torch.nn.Linear(emb_dim, emb_dim)
104 | torch.nn.init.xavier_uniform_(self.w1.weight.data)
105 |             torch.nn.init.xavier_uniform_(self.w2.weight.data)
106 |
107 | # Different kind of graph pooling
108 | if graph_pooling == "sum":
109 | self.pool = global_add_pool
110 | elif graph_pooling == "mean":
111 | self.pool = global_mean_pool
112 | elif graph_pooling == "max":
113 | self.pool = global_max_pool
114 | elif graph_pooling == "attention":
115 | if node_embedding_output == "concat":
116 |                 self.pool = GlobalAttention(gate_nn=torch.nn.Linear((num_layer + 1) * emb_dim, 1))
117 | else:
118 | self.pool = GlobalAttention(gate_nn=torch.nn.Linear(emb_dim, 1))
119 | elif graph_pooling[:-1] == "set2set":
120 | set2set_iter = int(graph_pooling[-1])
121 | if node_embedding_output == "concat":
122 |                 self.pool = Set2Set((num_layer + 1) * emb_dim, set2set_iter)
123 | else:
124 | self.pool = Set2Set(emb_dim, set2set_iter)
125 | else:
126 | raise ValueError("Invalid graph pooling type.")
127 |
128 | # For node and graph predictions
129 | self.mult = 1
130 | if graph_pooling[:-1] == "set2set":
131 | self.mult *= 2
132 | if node_embedding_output == "concat":
133 |             self.mult *= (num_layer + 1)
134 |
135 | node_embedding_dim = self.mult * self.emb_dim
136 | if self.num_graph_tasks > 0:
137 | self.graph_pred_module = torch.nn.Sequential(
138 | torch.nn.Linear(node_embedding_dim + num_additional_feat, node_embedding_dim),
139 | torch.nn.LeakyReLU(),
140 | torch.nn.Linear(node_embedding_dim, node_embedding_dim),
141 | torch.nn.LeakyReLU(),
142 | torch.nn.Linear(node_embedding_dim, self.num_graph_tasks))
143 |
144 |
145 | def from_pretrained(self, model_file):
146 | original_dict = torch.load(model_file)
147 | new_dict = {key.replace('gnn.', '') if 'gnn.' in key else key: value for key, value in original_dict.items()}
148 | if self.num_graph_tasks > 0:
149 | graph_tasks_idx = []
150 | for key in new_dict.keys():
151 | if key.startswith('graph_pred_module'):
152 | idx_part = key.split('.')[1]
153 | graph_tasks_idx.append(int(idx_part))
154 |
155 | for i in range(len(self.graph_pred_module)):
156 | if i in graph_tasks_idx:
157 | self.graph_pred_module[i].weight = torch.nn.Parameter(new_dict[f'graph_pred_module.{i}.weight'])
158 | self.graph_pred_module[i].bias = torch.nn.Parameter(new_dict[f'graph_pred_module.{i}.bias'])
159 | del new_dict[f'graph_pred_module.{i}.weight']
160 | del new_dict[f'graph_pred_module.{i}.bias']
161 |
162 | self.gnn.load_state_dict(new_dict)
163 |
164 | def forward(self, data, batch=None, embed=False):
165 |         batch = batch if batch is not None else torch.zeros(len(data[0][0]),).long()
166 |
167 | node_representation_1 = self.sign(data[0]) # geom
168 | node_representation_2 = self.sign2(data[1]) # cell_type
169 |
170 | if self.attn_weight:
171 | geom_ = self.leakyrelu(torch.mm(self.w1(node_representation_1), self.attention))
172 | cell_type_ = self.leakyrelu(torch.mm(self.w2(node_representation_2), self.attention))
173 | else:
174 | geom_ = self.leakyrelu(torch.mm(node_representation_1, self.attention))
175 | cell_type_ = self.leakyrelu(torch.mm(node_representation_2, self.attention))
176 |
177 | values = torch.softmax(torch.cat((geom_, cell_type_), dim=1), dim=1)
178 | node_representation = (values[:,0].unsqueeze(1) * node_representation_1) + (values[:,1].unsqueeze(1) * node_representation_2)
179 |
180 | return_vals = []
181 | if self.num_graph_tasks > 0:
182 | input = self.pool(node_representation, batch.to(node_representation.device))
183 | graph_pred = self.graph_pred_module(input)
184 | return_vals.append(graph_pred)
185 |
186 | if embed:
187 | return return_vals, values, node_representation_1, node_representation_2
188 | else:
189 | return return_vals, values
190 |
191 | # Loss functions
192 | class BinaryCrossEntropy(torch.nn.Module):
193 | """Weighted binary cross entropy loss function"""
194 | def __init__(self, **kwargs):
195 | super(BinaryCrossEntropy, self).__init__(**kwargs)
196 | self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
197 |
198 | def forward(self, y_pred, y, w):
199 | return (self.loss_fn(y_pred, y) * w).mean()
200 |
201 | class CrossEntropy(torch.nn.Module):
202 | """Cross entropy loss function for multi-class classification"""
203 | def __init__(self, **kwargs):
204 | super(CrossEntropy, self).__init__(**kwargs)
205 | self.loss_fn = torch.nn.CrossEntropyLoss()
206 |
207 | def forward(self, y_pred, y, w):
208 | return self.loss_fn(y_pred, y)
209 |
210 |
211 | class CoxSGDLossFn(torch.nn.Module):
212 | """Cox SGD loss function"""
213 | def __init__(self, top_n=2, regularizer_weight=0.05, **kwargs):
214 | self.top_n = top_n
215 | self.regularizer_weight = regularizer_weight
216 | super(CoxSGDLossFn, self).__init__(**kwargs)
217 |
218 | def forward(self, y_pred, length, event):
219 | assert y_pred.shape[0] == length.shape[0] == event.shape[0]
220 | n_samples = y_pred.shape[0]
221 | num_tasks = y_pred.shape[1]
222 | loss = 0
223 |
224 | for task in range(num_tasks):
225 | _length = length[:, task]
226 | _event = event[:, task]
227 | _pred = y_pred[:, task]
228 |
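            # pair_mat[i, j] = 1 when sample j's time exceeds sample i's time and sample i had an observed
            # event, i.e. the risk pairs that should be ranked correctly in this Cox-style partial likelihood.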
229 | pair_mat = (_length.reshape((1, -1)) - _length.reshape((-1, 1)) > 0) * _event.reshape((-1, 1))
230 | if self.top_n > 0:
231 | p_with_rand = pair_mat * (1 + torch.rand_like(pair_mat))
232 | rand_thr_ind = torch.argsort(p_with_rand, axis=1)[:, -(self.top_n + 1)]
233 | rand_thr = p_with_rand[(torch.arange(n_samples), rand_thr_ind)].reshape((-1, 1))
234 | pair_mat = pair_mat * (p_with_rand > rand_thr)
235 |
236 | valid_sample_is = torch.nonzero(pair_mat.sum(1)).flatten()
237 | pair_mat[(valid_sample_is, valid_sample_is)] = 1
238 |
239 | score_diff = (_pred.reshape((1, -1)) - _pred.reshape((-1, 1)))
240 | score_diff_row_max = torch.max(score_diff, axis=1, keepdims=True).values
241 | loss_tmp = (torch.exp(score_diff - score_diff_row_max) * pair_mat).sum(1)
242 | loss += (score_diff_row_max[:, 0][valid_sample_is] + torch.log(loss_tmp[valid_sample_is])).sum()
243 |
244 | regularizer = torch.abs(pair_mat.sum(0) * _pred.flatten()).sum()
245 | loss += self.regularizer_weight * regularizer
246 |
247 | return loss
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.optim
5 | from torch.utils.data import Dataset
6 | from torch.utils.data import DataLoader
7 |
8 | from src.inference import collect_predict_for_all_nodes
9 | from tqdm import trange
10 | from copy import deepcopy
11 |
12 | class CombinedDataset(Dataset):
13 | def __init__(self, graph_dataset, x_dataset):
14 | self.graph_dataset = graph_dataset
15 | self.x_dataset = x_dataset
16 |
17 | def __getitem__(self, idx):
18 | graph_y, graph_w = self.graph_dataset[idx]
19 | geom_x, cell_type_x = self.x_dataset[idx]
20 | geom_x_tensor = torch.stack(geom_x)
21 | cell_type_x_tensor = torch.stack(cell_type_x)
22 | batch_length = geom_x[0].shape[0]
23 |
24 | return graph_y, graph_w, geom_x_tensor, cell_type_x_tensor, batch_length
25 |
26 | def __len__(self):
27 | return len(self.graph_dataset)
28 |
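# Collate a batch of full graphs: node features are concatenated along the node dimension, and
# `batch_idx_list` records which graph each node belongs to (used later for graph-level pooling).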
29 | def custom_collate(batch):
30 | graph_y_list, graph_w_list, geom_x_list, cell_type_x_list, batch_length_list = zip(*batch)
31 | batch_idx_list = torch.tensor(sum([[i]*batch_length_list[i] for i in range(len(batch))], []))
32 | graph_y = torch.stack(graph_y_list).squeeze(2)
33 | graph_w = torch.stack(graph_w_list).squeeze(2)
34 | geom_x_layers = torch.cat(geom_x_list, dim=1)
35 | cell_type_x_layers = torch.cat(cell_type_x_list, dim=1)
36 |
37 | return batch_idx_list, graph_y, graph_w, geom_x_layers, cell_type_x_layers
38 |
39 | def train_full_graph(model,
40 | dataset_yw,
41 | xs,
42 | device,
43 | filename,
44 | fold,
45 | logger,
46 | graph_task_loss_fn=None,
47 | train_inds=None,
48 | valid_inds=None,
49 | test_inds=None,
50 | early_stop=100,
51 | num_iterations=1e5,
52 | evaluate_freq=1e4,
53 | evaluate_fn=[],
54 | evaluate_on_train=True,
55 | batch_size=64,
56 | lr=0.001,
57 | graph_loss_weight=1.,
58 | task='classification',
59 | name=None,
60 | save_model=False,
61 | **kwargs):
62 |
63 | if train_inds is None:
64 | train_inds = np.arange(len(xs))
65 |
66 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
67 | node_losses = []
68 | graph_losses = []
69 |
70 | train_dataset = [dataset_yw[i] for i in train_inds]
71 | train_xs = [xs[i] for i in train_inds]
72 |
73 | train_dataset_combined = CombinedDataset(train_dataset, train_xs)
74 | train_loader = DataLoader(train_dataset_combined, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=custom_collate)
75 |
76 | best_score = 0
77 | cnt = 0
78 |
79 | for i_iter in trange(int(num_iterations)):
80 | model.train()
81 | model.zero_grad()
82 |
83 | batch, graph_y, graph_w, geom_x, cell_type_x = next(iter(train_loader))
84 | loss = 0.
85 | geom_x = [_x.to(device) for _x in geom_x]
86 | cell_type_x = [_x.to(device) for _x in cell_type_x]
87 |
88 | res, _, z1, z2 = model([geom_x, cell_type_x], batch, embed=True)
89 | if res[0].sum().isnan():
90 | return [-1] * model.num_graph_tasks
91 |
92 | if model.num_graph_tasks > 0:
93 | assert graph_task_loss_fn is not None, \
94 | "Please specify `graph_task_loss_fn` in the training kwargs"
95 |
96 | graph_pred = res[-1]
97 | graph_y, graph_w = graph_y.float().to(device), graph_w.float().to(device)
98 |
99 | if model.num_graph_tasks > 1:
100 | graph_y, graph_w = graph_y.squeeze(1), graph_w.squeeze(1)
101 | idx = graph_y[:,0] == graph_y[:,0]
102 | graph_loss = graph_task_loss_fn(graph_pred[idx], graph_y[idx], graph_w[idx])
103 | loss += graph_loss * graph_loss_weight
104 |
105 | graph_losses.append(graph_loss.to('cpu').data.item())
106 |
107 | loss.backward()
108 | optimizer.step()
109 |
110 | if i_iter > 0 and (i_iter+1) % evaluate_freq == 0:
111 | model.eval()
112 | summary_str = "Finished iterations %d" % (i_iter+1)
113 | if len(node_losses) > evaluate_freq:
114 | summary_str += ", node loss %.2f" % np.mean(node_losses[-evaluate_freq:])
115 | if len(graph_losses) > evaluate_freq:
116 | summary_str += ", graph loss %.2f" % np.mean(graph_losses[-evaluate_freq:])
117 | print(summary_str)
118 |
119 | fn = evaluate_fn[0]
120 | result = fn(model,
121 | dataset_yw,
122 | xs,
123 | device,
124 | filename,
125 | fold,
126 | logger,
127 | train_inds=train_inds if evaluate_on_train else None,
128 | valid_inds=valid_inds,
129 | batch_size=batch_size,
130 | task=task,
131 | **kwargs)
132 |
133 | _txt = ",".join([s if isinstance(s, str) else ("%.3f" % s) for s in result])
134 | logger.info(f'[{name}-{task}][Fold {fold}][Epoch {i_iter+1}/{num_iterations}]-{_txt} \t Filename: {filename}')
135 | if cnt > 0:
136 | print(best_txt_test)
137 | valid_score = np.mean(result[-model.num_graph_tasks:])
138 |
139 | if best_score < valid_score:
140 | cnt +=1
141 | print(f'Iter: {i_iter+1}, Reached best valid score! Start testing ...')
142 | best_score = valid_score
143 |
144 | # Test on test_inds
145 | if test_inds != None:
146 | result_test = fn(model,
147 | dataset_yw,
148 | xs,
149 | device,
150 | filename,
151 | fold,
152 | logger,
153 | train_inds=None,
154 | valid_inds=test_inds,
155 | batch_size=batch_size,
156 | task=task,
157 | save=True,
158 | **kwargs)
159 |
160 | _txt_test = ",".join([s if isinstance(s, str) else ("%.3f" % s) for s in result_test])
161 | best_txt_test = f'[{name}-{task}][Fold {fold}][(Best Epoch) {i_iter+1}/{num_iterations}][(Test)]-{_txt_test} \t Filename: {filename}'
162 | logger.info(best_txt_test)
163 |
164 | best_epoch = i_iter
165 |
166 | if save_model:
167 | fn_save = evaluate_fn[1]
168 | fn_save(model,
169 | name,
170 | task,
171 | fold)
172 |
173 | else:
174 | _txt_test = _txt
175 |
176 | elif i_iter - best_epoch >= early_stop:
177 | print('Early Stop!')
178 | break
179 |
180 | else:
181 | pass
182 |
183 | return model, _txt_test
184 |
185 |
186 | def evaluate_by_full_graph(model,
187 | dataset_yw,
188 | xs,
189 | device,
190 | filename,
191 | fold,
192 | logger,
193 | train_inds=None,
194 | valid_inds=None,
195 | batch_size=64,
196 | shuffle=True,
197 | full_graph_graph_task_evaluate_fn=None,
198 | score_file=None,
199 | save=False,
200 | task='classification',
201 | **kwargs):
202 |
203 | score_row = ["Eval-Full-Graph"]
204 | if train_inds is not None:
205 | score_row.append("Train")
206 | node_preds, graph_preds, voronoi_attn, cell_type_attn = collect_predict_for_all_nodes(
207 | model,
208 | xs,
209 | device,
210 | inds=train_inds,
211 | batch_size=batch_size,
212 | shuffle=shuffle,
213 | **kwargs)
214 |
215 | if model.num_graph_tasks > 0:
216 | # Evaluate graph-level predictions
217 | assert full_graph_graph_task_evaluate_fn is not None, \
218 | "Please specify `full_graph_graph_task_evaluate_fn` in the training kwargs"
219 | score_row.append("attn-score")
220 | score_row.extend([np.nanmean(np.array(list(voronoi_attn.values()))), np.nanmean(np.array(list(cell_type_attn.values())))])
221 | score_row.append("graph-score")
222 | score_row.extend(full_graph_graph_task_evaluate_fn(dataset_yw, graph_preds, task, print_res=False))
223 |
224 | if valid_inds is not None:
225 | score_row.append("Valid")
226 | node_preds, graph_preds, voronoi_attn, cell_type_attn = collect_predict_for_all_nodes(
227 | model,
228 | xs,
229 | device,
230 | inds=valid_inds,
231 | batch_size=batch_size,
232 | shuffle=shuffle,
233 | **kwargs)
234 |
235 | region_ids = []
236 | graph_pred_list = []
237 | graph_ys = []
238 | graph_ws = []
239 |
240 | for _i, (i, pred) in enumerate(graph_preds.items()):
241 | region_id = valid_inds[_i]
242 | graph_y = dataset_yw[i][0]
243 | graph_w = dataset_yw[i][1]
244 | graph_pred_ = np.nanmean(pred, 0)
245 |
246 | region_ids.append(region_id)
247 | graph_pred_list.append(graph_pred_)
248 | graph_ys.append(graph_y)
249 | graph_ws.append(graph_w)
250 |
251 | graph_ys = np.stack(graph_ys).squeeze(1)
252 | graph_ws = np.stack(graph_ws).squeeze(1)
253 |
254 | if model.num_graph_tasks > 0:
255 | assert full_graph_graph_task_evaluate_fn is not None, \
256 | "Please specify `full_graph_graph_task_evaluate_fn` in the training kwargs"
257 | score_row.append("attn-score")
258 | score_row.extend([np.nanmean(np.array(list(voronoi_attn.values()))), np.nanmean(np.array(list(cell_type_attn.values())))])
259 | score_row.append("graph-score")
260 | score_row.extend(full_graph_graph_task_evaluate_fn(dataset_yw, graph_preds, task, print_res=False))
261 |
262 | if score_file is not None:
263 | with open(score_file, 'a') as f:
264 | f.write(",".join([s if isinstance(s, str) else ("%.3f" % s) for s in score_row]) + '\n')
265 |
266 | return score_row
267 |
268 |
269 | def save_model_weight(model,
270 | name,
271 | task,
272 | fold):
273 | os.makedirs('./ckpts/', exist_ok=True)
274 | model_tmp = deepcopy(model).cpu()
275 | torch.save(model_tmp.state_dict(), f'./ckpts/Mew_{name}_{task}_{fold}.pt')
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import torch
4 | import os
5 | import sys
6 | import random
7 | import logging
9 | import pandas as pd
10 |
11 | MB = 1024 ** 2
12 | GB = 1024 ** 3
13 |
14 | EDGE_TYPES = {
15 | "neighbor": 0,
16 | "distant": 1,
17 | "self": 2,
18 | }
19 |
20 | # Metadata for the example dataset
21 | BIOMARKERS_UPMC = [
22 | "CD11b", "CD14", "CD15", "CD163", "CD20", "CD21", "CD31", "CD34", "CD3e",
23 | "CD4", "CD45", "CD45RA", "CD45RO", "CD68", "CD8", "CollagenIV", "HLA-DR",
24 | "Ki67", "PanCK", "Podoplanin", "Vimentin", "aSMA",
25 | ]
26 |
27 | CELL_TYPE_MAPPING_UPMC = {
28 | 'APC': 0,
29 | 'B cell': 1,
30 | 'CD4 T cell': 2,
31 | 'CD8 T cell': 3,
32 | 'Granulocyte': 4,
33 | 'Lymph vessel': 5,
34 | 'Macrophage': 6,
35 | 'Naive immune cell': 7,
36 | 'Stromal / Fibroblast': 8,
37 | 'Tumor': 9,
38 | 'Tumor (CD15+)': 10,
39 | 'Tumor (CD20+)': 11,
40 | 'Tumor (CD21+)': 12,
41 | 'Tumor (Ki67+)': 13,
42 | 'Tumor (Podo+)': 14,
43 | 'Vessel': 15,
44 | 'Unassigned': 16,
45 | }
46 |
47 | CELL_TYPE_FREQ_UPMC = {
48 | 'APC': 0.038220815854819415,
49 | 'B cell': 0.06635091324932002,
50 | 'CD4 T cell': 0.09489001514723677,
51 | 'CD8 T cell': 0.07824503590797544,
52 | 'Granulocyte': 0.026886102677111563,
53 | 'Lymph vessel': 0.006429085023448621,
54 | 'Macrophage': 0.10251942892685563,
55 | 'Naive immune cell': 0.033537398925429215,
56 | 'Stromal / Fibroblast': 0.07692583870182068,
57 | 'Tumor': 0.10921293560435145,
58 | 'Tumor (CD15+)': 0.06106975782857908,
59 | 'Tumor (CD20+)': 0.02098925720318548,
60 | 'Tumor (CD21+)': 0.053892044158901406,
61 | 'Tumor (Ki67+)': 0.13373768013421947,
62 | 'Tumor (Podo+)': 0.06276108605978743,
63 | 'Vessel': 0.034332604596958326,
64 | 'Unassigned': 0.001,
65 | }
66 |
67 |
68 | def setup_logger(save_dir, text, filename = 'log.txt'):
69 | os.makedirs(save_dir, exist_ok=True)
70 | logger = logging.getLogger(text)
71 | logger.setLevel(4)
72 | ch = logging.StreamHandler(stream=sys.stdout)
73 | ch.setLevel(logging.DEBUG)
74 | formatter = logging.Formatter("%(message)s")
75 | ch.setFormatter(formatter)
76 | logger.addHandler(ch)
77 | if save_dir:
78 | fh = logging.FileHandler(os.path.join(save_dir, filename))
79 | fh.setLevel(logging.DEBUG)
80 | fh.setFormatter(formatter)
81 | logger.addHandler(fh)
82 | logger.info("======================================================================================")
83 |
84 | return logger
85 |
86 | def seed_everything(seed=0):
87 | random.seed(seed)
88 | os.environ['PYTHONHASHSEED'] = str(seed)
89 | np.random.seed(seed)
90 | torch.manual_seed(seed)
91 | torch.cuda.manual_seed(seed)
92 | torch.backends.cudnn.deterministic = True
93 |
94 | def get_initials(string_list):
95 | initials = "".join(word[0] for word in string_list if word) # Ensure the word is not empty
96 | return initials.lower()
97 |
98 | def generate_charville_label(df):
99 | ## recurrence_interval & recurrence_event
100 | # Convert 'surgery_date' to datetime & Convert 'first_recurrence_date' to datetime, handling 'NONE' cases
101 | df['surgery_date'] = pd.to_datetime(df['surgery_date'])
102 | df['first_recurrence_date'] = pd.to_datetime(df['first_recurrence_date'], errors='coerce')
103 | df['last_contact_date'] = pd.to_datetime(df['last_contact_date'], errors='coerce')
104 |
105 | # Calculate the recurrence interval as the number of days from 'surgery_date' to 'first_recurrence_date'
106 | df['recurrence_interval'] = (df['first_recurrence_date'] - df['surgery_date']).dt.days
107 |
108 | # Recurrence event: 1 if there is a recurrence, 0 if 'NONE/DISEASE FREE' or 'NEVER DISEASE FREE'
109 | df['recurrence_event'] = df['type_of_first_recurrence'].apply(lambda x: 0 if 'FREE' in x else 1)
110 | df.loc[df['first_recurrence_date'].isna(), 'recurrence_interval'] = (df['last_contact_date'] - df['surgery_date']).dt.days
111 |
113 |
114 | ## survival_legnth & survival_status
115 | def calculate_length(row):
116 | # If the 'first_recurrence_date' is not missing, use it to calculate the length
117 | if pd.notna(row['first_recurrence_date']):
118 | return (row['first_recurrence_date'] - row['surgery_date']).days / 30.4375 # Average days per month
119 | # If 'first_recurrence_date' is missing, but 'last_contact_date' is not, use 'last_contact_date'
120 | elif pd.notna(row['last_contact_date']):
121 | return (row['last_contact_date'] - row['surgery_date']).days / 30.4375
122 | # If both are missing, return NaN
123 | else:
124 | return pd.NA
125 |
126 | # Apply the function to the rows where 'length_of_disease_free_survival' is missing
127 | df['survival_legnth'] = df.apply(calculate_length, axis=1)
128 | df['survival_legnth'] = np.round(pd.to_numeric(df['survival_legnth'], errors='coerce'))
129 |
130 |     # Survival status: 1 if deceased, 0 if alive
131 | df['survival_status'] = df['alive_or_deceased'].apply(lambda x: 1 if 'Dead' in x else 0)
132 |
133 | return df
134 |
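# Build the shared vocabulary for the UPMC -> DFCI generalization setting: biomarker columns present in
# both datasets and a mapping that collapses dataset-specific cell-type labels into a common set.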
135 | def preprocess_generalization(root='./dataset'):
136 | upmc_biomarker_cols = set(pd.read_csv(f'{root}/upmc_data/raw_data/UPMC_c001_v001_r001_reg001.expression.csv').columns)
137 | dfci_biomarker_cols = set(pd.read_csv(f'{root}/dfci_data/raw_data/s271_c001_v001_r001_reg001.expression.csv').columns)
138 |
139 | common_biomarker_cols_list = list(upmc_biomarker_cols & dfci_biomarker_cols)
140 | common_cell_type_dict = {
141 | 'APC': 'APC',
142 | 'Dendritic cell': 'APC',
143 | 'APC/macrophage': 'APC',
144 | 'B cell': 'B cell',
145 | 'CD4 T cell': 'CD4 T cell',
146 | 'T cell (CD45RO+/FoxP3+/ICOS+)': 'CD4 T cell',
147 | 'CD4 T cell (ICOS+/FoxP3+)': 'CD4 T cell',
148 | 'CD8 T cell': 'CD8 T cell',
149 | 'T cell (GranzymeB+/LAG3+)': 'CD8 T cell',
150 | 'Granulocyte': 'Granulocyte',
151 | 'Lymph vessel': 'Vessel cell',
152 | 'Macrophage': 'Macrophage',
153 | 'Naive immune cell': 'Naive immune cell',
154 | 'Naive B cell': 'Naive immune cell',
155 | 'Naive lymphocyte (CD45RA+/CD38+)': 'Naive immune cell',
156 | 'Stromal / Fibroblast': 'Stromal cell',
157 | 'Stroma': 'Stromal cell',
158 | 'Tumor': 'Tumor',
159 | 'Tumor (CD15+)': 'Tumor',
160 | 'Tumor (CD20+)': 'Tumor',
161 | 'Tumor (CD21+)': 'Tumor',
162 | 'Tumor (Ki67+)': 'Tumor (Ki67+)',
163 | 'Tumor (Podo+)': 'Tumor',
164 | 'Tumor (PanCK hi)': 'Tumor',
165 | 'Tumor (PanCK low)': 'Tumor',
166 | 'Unassigned': 'Other cell',
167 | 'NK cell': 'Other cell',
168 | 'Mast cell': 'Other cell',
169 | 'Unknown (TCF1+)': 'Other cell',
170 | 'Unclassified': 'Other cell',
171 | 'Vessel': 'Vessel cell',
172 | 'Vessel endothelium': 'Vessel cell',
173 | }
174 |
175 | return common_cell_type_dict, common_biomarker_cols_list
176 |
177 | def get_cell_type_metadata(nx_graph_files):
178 | """Find all unique cell types from a list of cellular graphs
179 |
180 | Args:
181 | nx_graph_files (list/str): path/list of paths to cellular graph files (gpickle)
182 |
183 | Returns:
184 | cell_type_mapping (dict): mapping of unique cell types to integer indices
185 | cell_type_freq (dict): mapping of unique cell types to their frequency
186 | """
187 | if isinstance(nx_graph_files, str):
188 | nx_graph_files = [nx_graph_files]
189 | cell_type_mapping = {}
190 | for g_f in nx_graph_files:
191 | try:
192 | G = pickle.load(open(g_f, 'rb'))
193 |         except Exception as e:
194 |             raise RuntimeError('Error detected in file: %s' % g_f) from e
195 | assert 'cell_type' in G['layer_1'].nodes[0]
196 | for n in G['layer_1'].nodes:
197 | ct = G['layer_1'].nodes[n]['cell_type']
198 | if ct not in cell_type_mapping:
199 | cell_type_mapping[ct] = 0
200 | cell_type_mapping[ct] += 1
201 | unique_cell_types = sorted(cell_type_mapping.keys())
202 | unique_cell_types_ct = [cell_type_mapping[ct] for ct in unique_cell_types]
203 | unique_cell_type_freq = [count / sum(unique_cell_types_ct) for count in unique_cell_types_ct]
204 | cell_type_mapping = {ct: i for i, ct in enumerate(unique_cell_types)}
205 | cell_type_freq = dict(zip(unique_cell_types, unique_cell_type_freq))
206 | return cell_type_mapping, cell_type_freq
207 |
208 |
209 | def get_biomarker_metadata(nx_graph_files):
210 | """Load all biomarkers from a list of cellular graphs
211 |
212 | Args:
213 | nx_graph_files (list/str): path/list of paths to cellular graph files (gpickle)
214 |
215 | Returns:
216 | shared_bms (list): list of biomarkers shared by all cells (intersect)
217 | all_bms (list): list of all biomarkers (union)
218 | """
219 | if isinstance(nx_graph_files, str):
220 | nx_graph_files = [nx_graph_files]
221 | all_bms = set()
222 | shared_bms = None
223 | for g_f in nx_graph_files:
224 | G = pickle.load(open(g_f, 'rb'))
225 | for n in G['layer_1'].nodes:
226 | bms = sorted(G['layer_1'].nodes[n]["biomarker_expression"].keys())
227 | for bm in bms:
228 | all_bms.add(bm)
229 | valid_bms = [
230 | bm for bm in bms if G['layer_1'].nodes[n]["biomarker_expression"][bm] == G['layer_1'].nodes[n]["biomarker_expression"][bm]]
231 | shared_bms = set(valid_bms) if shared_bms is None else shared_bms & set(valid_bms)
232 | shared_bms = sorted(shared_bms)
233 | all_bms = sorted(all_bms)
234 | return shared_bms, all_bms
235 |
236 |
237 | def get_graph_splits(dataset,
238 | split='random',
239 | cv_k=5,
240 | seed=None,
241 | fold_mapping=None):
242 | """ Define train/valid split
243 |
244 | Args:
245 | dataset (CellularGraphDataset): dataset to split
246 | split (str): split method, one of 'random', 'fold'
247 | cv_k (int): number of splits for random split
248 | seed (int): random seed
249 | fold_mapping (dict): mapping of region ids to folds,
250 | fold could be coverslip, patient, etc.
251 |
252 | Returns:
253 | split_inds (list): fold indices for each region in the dataset
254 | """
255 | splits = {}
256 | region_ids = set([dataset.get_full(i).region_id for i in range(dataset.N)])
257 | _region_ids = sorted(region_ids)
258 | if split == 'random':
259 | if seed is not None:
260 | np.random.seed(seed)
261 | if fold_mapping is None:
262 | fold_mapping = {region_id: region_id for region_id in _region_ids}
263 | # `_ids` could be sample ids / patient ids / certain properties
264 | _folds = sorted(set(list(fold_mapping.values())))
265 | np.random.shuffle(_folds)
266 | cv_shard_size = len(_folds) / cv_k
267 | for i, region_id in enumerate(_region_ids):
268 | splits[region_id] = _folds.index(fold_mapping[region_id]) // cv_shard_size
269 | elif split == 'fold':
270 | # Split into folds, one fold per group
271 | assert fold_mapping is not None
272 | _folds = sorted(set(list(fold_mapping.values())))
273 | for i, region_id in enumerate(_region_ids):
274 | splits[region_id] = _folds.index(fold_mapping[region_id])
275 | else:
276 | raise ValueError("split mode not recognized")
277 |
278 | split_inds = []
279 | for i in range(dataset.N):
280 | split_inds.append(splits[dataset.get_full(i).region_id])
281 | return split_inds
282 |
283 | def get_memory_usage(gpu, print_info=False):
284 | """Get accurate gpu memory usage by querying torch runtime"""
285 | allocated = torch.cuda.memory_allocated(gpu)
286 | reserved = torch.cuda.memory_reserved(gpu)
287 | if print_info:
288 | print("allocated: %.2f MB" % (allocated / 1024 / 1024), flush=True)
289 | print("reserved: %.2f MB" % (reserved / 1024 / 1024), flush=True)
290 | return allocated
291 |
292 | def compute_tensor_bytes(tensors):
293 | """Compute the bytes used by a list of tensors"""
294 | if not isinstance(tensors, (list, tuple)):
295 | tensors = [tensors]
296 |
297 | ret = 0
298 | for x in tensors:
299 | if x.dtype in [torch.int64, torch.long]:
300 | ret += np.prod(x.size()) * 8
301 |         elif x.dtype in [torch.float32, torch.int, torch.int32]:
302 | ret += np.prod(x.size()) * 4
303 | elif x.dtype in [torch.bfloat16, torch.float16, torch.int16]:
304 | ret += np.prod(x.size()) * 2
305 | elif x.dtype in [torch.int8]:
306 | ret += np.prod(x.size())
307 | else:
308 | print(x.dtype)
309 | raise ValueError()
310 | return ret
311 |
312 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import src
4 | from src.utils import setup_logger, seed_everything, preprocess_generalization
5 | import numpy as np
6 | import datetime
7 | import argparse
8 | import time
9 | import re
10 | from src.precomputing import PrecomputingBase
11 | import gc
12 |
13 | torch.set_num_threads(4)
14 |
15 | def str2bool(s):
16 | if s not in {'False', 'True', 'false', 'true'}:
17 | raise ValueError('Not a valid boolean string')
18 | return (s == 'True') or (s == 'true')
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(description='Mew')
22 | parser.add_argument('--device', type=int, default=0)
23 | parser.add_argument('--data', type=str, default='charville') # upmc, dfci, charville
24 | parser.add_argument('--task', type=str, default='classification') # classification, cox
25 | parser.add_argument('--use_node_features', metavar='S', type=str, nargs='+', default=['biomarker_expression', 'SIZE'])
26 | parser.add_argument('--shared', type=str2bool, default=False) # True, False
27 | parser.add_argument('--attn_weight', type=str2bool, default=False) # True, False
28 | parser.add_argument('--lr', type=float, default=0.0001) # 0.1, 0.01, 0.001, 0.0001
29 | parser.add_argument('--num_layers', type=int, default=4) # 1, 2, 3, 4
30 | parser.add_argument('--emb_dim', type=int, default=512) # 64, 128, 256, 512
31 | parser.add_argument('--batch_size', type=int, default=16) # 16, 32
32 | parser.add_argument('--drop_ratio', type=float, default=0.5) # 0.0, 0.25, 0.5
33 | parser.add_argument('--pool', type=str, default='mean')
34 | parser.add_argument('--num_epochs', type=int, default=1000)
35 | parser.add_argument('--eval_epoch', type=int, default=100)
36 | parser.add_argument('--early_stop', type=int, default=300)
37 | parser.add_argument('--seed', type=int, default=0)
38 | parser.add_argument('--save_model', type=str2bool, default=True)
39 | return parser.parse_known_args()
40 |
41 | def main():
42 | args, _ = parse_args()
43 | filename = f'Mew.txt'
44 | if args.data == 'dfci':
45 | args.task = 'classification'
46 |
47 | os.makedirs(f'./log/{args.data}/{args.task}', exist_ok=True)
48 | logger = setup_logger('./', '-', f'./log/{args.data}/{args.task}/{filename}')
49 | logger.info(datetime.datetime.now())
50 | seed_everything(args.seed)
51 |
52 | print(f'Data: {args.data}, Task: {args.task}')
53 |
54 | # Settings
55 | data = args.data # 'upmc', 'charville'
56 | print(args.use_node_features)
57 | use_node_features = args.use_node_features
58 | device = f'cuda:{args.device}'
59 |
60 | data_root = './dataset/'+data + "_data" if data in ['upmc', 'charville'] else './dataset/general_data'
61 | print(data_root)
62 | raw_data_root = f"{data_root}/raw_data"
63 | dataset_root = f"{data_root}/dataset_mew"
64 | graph_label_file = f"{data_root}/{data}_labels.csv"
65 |
66 | if data == "upmc":
67 | _common_cell_type_dict, _common_biomarker_list = None, None
68 | if args.task == 'classification':
69 | graph_tasks = ['primary_outcome', 'recurred', 'hpvstatus']
70 |
71 | elif args.task == 'cox':
72 | graph_tasks = [['survival_day', 'survival_status']]
73 |
74 | elif data == "charville":
75 | _common_cell_type_dict, _common_biomarker_list = None, None
76 | if args.task == 'classification':
77 | graph_tasks = ['primary_outcome', 'recurrence']
78 |
79 | elif args.task == 'cox':
80 | graph_tasks = [['survival_legnth', 'survival_status'], ['recurrence_interval', 'recurrence_event']]
81 |
82 | else: # 'dfci'
83 | _common_cell_type_dict, _common_biomarker_list = preprocess_generalization('./dataset')
84 | graph_label_file1 = f"./dataset/upmc_data/upmc_labels.csv"
85 | graph_label_file2 = f"./dataset/dfci_data/dfci_labels.csv"
86 | graph_tasks1 = ['primary_outcome'] # from UPMC
87 | graph_tasks2 = ['pTR_label'] # from DFCI
88 | graph_tasks = graph_tasks2
89 |
90 | # Generate cellular graphs from raw inputs
91 | nx_graph_root = os.path.join(dataset_root, "graph")
92 | fig_save_root = os.path.join(dataset_root, "fig")
93 | model_save_root = os.path.join(dataset_root, "model")
94 |
95 | region_ids = set([f.split('.')[0] for f in os.listdir(raw_data_root)])
96 | os.makedirs(nx_graph_root, exist_ok=True)
97 | os.makedirs(fig_save_root, exist_ok=True)
98 | os.makedirs(model_save_root, exist_ok=True)
99 |
100 | # Save graph generated from each region
101 | for region_id in region_ids:
102 | graph_output = os.path.join(nx_graph_root, "%s.gpkl" % region_id)
103 | if not os.path.exists(graph_output):
104 | print("Processing %s" % region_id)
105 | _voronoi_file=os.path.join(raw_data_root, "%s.json" % region_id) if data == 'upmc' else None
106 | G = src.construct_graph_for_region(
107 | region_id,
108 | cell_coords_file=os.path.join(raw_data_root, "%s.cell_data.csv" % region_id),
109 | cell_types_file=os.path.join(raw_data_root, "%s.cell_types.csv" % region_id),
110 | cell_biomarker_expression_file=os.path.join(raw_data_root, "%s.expression.csv" % region_id),
111 | cell_features_file=os.path.join(raw_data_root, "%s.cell_features.csv" % region_id),
112 | voronoi_file=_voronoi_file,
113 | graph_output=graph_output,
114 | voronoi_polygon_img_output=None,
115 | graph_img_output=None,
116 | common_cell_type_dict=_common_cell_type_dict,
117 | common_biomarker_list=_common_biomarker_list,
118 | figsize=10)
119 |
120 | # Define Cellular Graph Dataset
121 | dataset_kwargs = {
122 | 'raw_folder_name': 'graph',
123 | 'processed_folder_name': 'tg_graph',
124 | 'node_features': ["cell_type", "SIZE", "biomarker_expression", "neighborhood_composition", "center_coord"],
125 | 'edge_features': ["edge_type", "distance"],
126 | 'cell_type_mapping': None,
127 | 'cell_type_freq': None,
128 | 'biomarkers': None
129 | }
130 | feature_kwargs = {
131 | "biomarker_expression_process_method": "linear",
132 | "biomarker_expression_lower_bound": -2,
133 | "biomarker_expression_upper_bound": 3,
134 | "neighborhood_size": 10,
135 | }
136 | dataset_kwargs.update(feature_kwargs)
137 | dataset = src.CellularGraphDataset(dataset_root, **dataset_kwargs)
138 |
139 | # Define Transformers
140 | transformers = [
141 | src.FeatureMask(dataset,
142 | use_center_node_features=use_node_features,
143 | use_neighbor_node_features=use_node_features)
144 | ]
145 | if data in ['upmc', 'charville']:
146 | transformers.append(src.AddGraphLabel(graph_label_file, tasks=graph_tasks))
147 | else:
148 | transformers.append(src.AddTwoGraphLabel([graph_label_file1, graph_label_file2], tasks=[graph_tasks1, graph_tasks2]))
149 |
150 | dataset.set_transforms(transformers)
151 |
152 | # Precomputing & Label Extracting
153 | region_ids = [dataset.get_full(i).region_id for i in range(dataset.N)]
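    # Region ids follow the pattern '<study>_<coverslip>_...' (e.g. 'UPMC_c001_v001_r001_reg001'); the
    # second '_'-separated token is the coverslip id used below to assign regions to train/valid/test folds.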
154 | coverslip_ids = [r_id.split('_')[1] for r_id in region_ids]
155 |
156 | num_feat = dataset[0]['cell'].x.shape[1]
157 | precomputer = PrecomputingBase(args.num_layers, device)
158 | sign_xs = precomputer.precompute(dataset)
159 | dataset_yw = [[dataset[i].graph_y, dataset[i].graph_w] for i in range(len(dataset))]
160 | del dataset
161 | gc.collect()
162 |
163 | # Define train/valid split
164 | if data == "upmc":
165 | fold0_coverslips = {'train': ['c001', 'c002', 'c004', 'c006'], 'val': ['c007'], 'test': ['c003', 'c005']}
166 | fold1_coverslips = {'train': ['c002', 'c003', 'c005', 'c007'], 'val': ['c004'], 'test': ['c001', 'c006']}
167 | fold2_coverslips = {'train': ['c001', 'c003', 'c004', 'c006'], 'val': ['c005'], 'test': ['c002', 'c007']}
168 | elif data == "charville":
169 | fold0_coverslips = {'train': ['c001', 'c002'], 'val': ['c003'], 'test': ['c004']}
170 | fold1_coverslips = {'train': ['c003', 'c004'], 'val': ['c002'], 'test': ['c001']}
171 | fold2_coverslips = {'train': ['c002', 'c004'], 'val': ['c001'], 'test': ['c003']}
172 |
173 | if data in ['upmc', 'charville']:
174 | split_indices = {}
175 | split_indices[0] = fold0_coverslips
176 | split_indices[1] = fold1_coverslips
177 | split_indices[2] = fold2_coverslips
178 | fold_list = [0,1,2]
179 | else:
180 | fold_list = [0]
181 |
182 | fold_results = []
183 | for fold in fold_list:
184 | # Define Model kwargs
185 | model_kwargs = {
186 | 'num_layer': args.num_layers,
187 | 'num_feat': num_feat,
188 | 'emb_dim': args.emb_dim,
189 | 'num_node_tasks': 0,
190 | 'num_graph_tasks': len(graph_tasks),
191 | 'node_embedding_output': 'last',
192 | 'drop_ratio': args.drop_ratio,
193 | 'graph_pooling': args.pool,
194 | 'attn_weight': args.attn_weight,
195 | 'shared': args.shared,
196 | }
197 | # Define Train and Evaluate kwargs
198 | train_kwargs = {
199 | 'batch_size': args.batch_size,
200 | 'lr': args.lr,
201 | 'graph_loss_weight': 1.0,
202 | 'num_iterations': args.num_epochs,
203 | 'early_stop': args.early_stop,
204 | 'node_task_loss_fn': None,
205 | 'graph_task_loss_fn': src.models.BinaryCrossEntropy() if args.task == 'classification' else src.models.CoxSGDLossFn(),
206 | 'evaluate_fn': [src.train.evaluate_by_full_graph, src.train.save_model_weight],
207 | 'evaluate_freq': args.eval_epoch,
208 | }
209 |
210 | evaluate_kwargs = {
211 | 'shuffle': True,
212 | 'full_graph_graph_task_evaluate_fn': src.inference.full_graph_graph_classification_evaluate_fn,
213 | 'score_file': os.path.join(model_save_root, 'Mew-%s-%s-%d.txt' % (graph_tasks[0], '_'.join(use_node_features), fold)),
214 | 'model_folder': os.path.join(model_save_root, 'Mew'),
215 | }
216 |
217 | train_kwargs.update(evaluate_kwargs)
218 | os.makedirs(evaluate_kwargs['model_folder'], exist_ok=True)
219 |
220 | train_inds = []
221 | valid_inds = []
222 | test_inds = []
223 |
224 | if data in ['upmc', 'charville']:
225 | for i, cs_id in enumerate(coverslip_ids):
226 | if cs_id in split_indices[fold]['train']:
227 | train_inds.append(i)
228 | elif cs_id in split_indices[fold]['val']:
229 | valid_inds.append(i)
230 | else:
231 | test_inds.append(i)
232 | else:
233 | for i, region_id in enumerate(region_ids):
234 | if 'UPMC' in region_id:
235 | if 'c007' in region_id:
236 | valid_inds.append(i)
237 | else:
238 | train_inds.append(i)
239 | else:
240 | test_inds.append(i)
241 |
242 | start_time = time.time()
243 | model = src.models.SIGN_pred(**model_kwargs)
244 | model = model.to(device)
245 |
246 | model, fold_result = src.train.train_full_graph(
247 | model, dataset_yw, sign_xs, device, filename, fold, logger,
248 | train_inds=train_inds, valid_inds=valid_inds, test_inds=test_inds, task=args.task, name=data, save_model=args.save_model, **train_kwargs)
249 |
250 | if fold_result == -1:
251 |             print(f'The learning rate {args.lr} is too high and causes numerical instability; please try a lower learning rate.')
252 | break
253 |
254 | performance_list = re.findall(r"[-+]?\d*\.\d+|\d+", fold_result)
255 | fold_results.append([float(performance) for performance in performance_list[-len(graph_tasks):]])
256 |
257 | if fold == fold_list[-1]:
258 | averages = [round(sum(pair)/len(pair), 3) for pair in zip(*fold_results)]
259 | std_devs = [round(np.std(pair), 3) for pair in zip(*fold_results)]
260 |
261 | averages_std_combined = [f"{avg} ± {std}" for avg, std in zip(averages, std_devs)]
262 | combined_str = ', '.join(averages_std_combined)
263 | logger.info(f'[*Fold Average*] Total Epoch: {args.num_epochs}, Valid, graph-score, [{combined_str}]')
264 |
265 | print(f"Fold: {fold}")
266 | print("Total time elapsed: {:.4f}s".format(time.time() - start_time))
267 | logger.info("Filename: {}, Total time elapsed: {:.4f}s, Args: {}".format(filename, time.time() - start_time, args))
268 | logger.info("")
269 |
270 | if __name__ == '__main__':
271 | main()
--------------------------------------------------------------------------------
/src/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from copy import deepcopy
4 | import torch
5 | from src.utils import generate_charville_label
6 |
7 | class FeatureMask(object):
8 | """ Transformer object for masking features """
9 | def __init__(self,
10 | dataset,
11 | use_neighbor_node_features=None,
12 | use_center_node_features=None,
13 | use_edge_features=None,
14 | **kwargs):
15 | """ Construct the transformer
16 |
17 | Args:
18 | dataset (CellularGraphDataset): dataset object
19 | use_neighbor_node_features (list): list of node feature items to use
20 | for non-center nodes, all other features will be masked out
21 | use_center_node_features (list): list of node feature items to use
22 | for the center node, all other features will be masked out
23 | use_edge_features (list): list of edge feature items to use,
24 | all other features will be masked out
25 | """
26 |
27 | self.node_feature_names = dataset.node_feature_names
28 | self.edge_feature_names = dataset.edge_feature_names
29 |
30 | self.use_neighbor_node_features = use_neighbor_node_features if \
31 | use_neighbor_node_features is not None else dataset.node_features
32 | self.use_center_node_features = use_center_node_features if \
33 | use_center_node_features is not None else dataset.node_features
34 | self.use_edge_features = use_edge_features if \
35 | use_edge_features is not None else dataset.edge_features
36 |
37 | self.center_node_feature_masks = [
38 | 1 if any(name.startswith(feat) for feat in self.use_center_node_features)
39 | else 0 for name in self.node_feature_names]
40 | self.neighbor_node_feature_masks = [
41 | 1 if any(name.startswith(feat) for feat in self.use_neighbor_node_features)
42 | else 0 for name in self.node_feature_names]
43 |
44 | self.center_node_feature_masks = \
45 | torch.from_numpy(np.array(self.center_node_feature_masks).reshape((-1,))).float()
46 | self.neighbor_node_feature_masks = \
47 | torch.from_numpy(np.array(self.neighbor_node_feature_masks).reshape((1, -1))).float()
48 |
49 | def __call__(self, data):
50 | data = deepcopy(data)
51 | if "center_node_index" in data:
52 | center_node_feat = data.x[data.center_node_index].detach().data.clone()
53 | else:
54 | center_node_feat = None
55 | data = self.transform_neighbor_node(data)
56 | data = self.transform_center_node(data, center_node_feat)
57 | return data
58 |
59 | def transform_neighbor_node(self, data):
60 | """Apply neighbor node feature masking"""
61 | data['cell'].x = data['cell'].x * self.neighbor_node_feature_masks
62 | return data
63 |
64 | def transform_center_node(self, data, center_node_feat=None):
65 | """Apply center node feature masking"""
66 | if center_node_feat is None:
67 | return data
68 | assert "center_node_index" in data
69 | center_node_feat = center_node_feat * self.center_node_feature_masks
70 | data['cell'].x[data['cell'].center_node_index] = center_node_feat
71 | return data
72 |
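# Illustrative sketch (not part of the original file): FeatureMask keeps a
# feature column only when its expanded name starts with one of the requested
# feature items, so requesting just ["cell_type", "center_coord"] zeroes out
# every biomarker column. Feature names below are hypothetical.
example_names = ["cell_type", "biomarker_expression-CD3", "biomarker_expression-CD8", "center_coord-x"]
example_use = ["cell_type", "center_coord"]
example_mask = [1 if any(n.startswith(f) for f in example_use) else 0 for n in example_names]
# example_mask -> [1, 0, 0, 1]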
73 |
74 | class AddCenterCellType(object):
75 | """Transformer for center cell type prediction"""
76 | def __init__(self, dataset, **kwargs):
77 | self.node_feature_names = dataset.node_feature_names
78 | self.cell_type_feat = self.node_feature_names.index('cell_type')
79 | # Assign a placeholder cell type for the center node
80 | self.placeholder_cell_type = max(dataset.cell_type_mapping.values()) + 1
81 |
82 | def __call__(self, data):
83 | data = deepcopy(data)
84 | assert "center_node_index" in data, \
85 | "Only subgraphs with center nodes are supported, cannot find `center_node_index`"
86 | center_node_feat = data.x[data.center_node_index].detach().clone()
87 | center_cell_type = center_node_feat[self.cell_type_feat]
88 | data.node_y = center_cell_type.long().view((1,))
89 | data.x[data.center_node_index, self.cell_type_feat] = self.placeholder_cell_type
90 | return data
91 |
92 |
93 | class AddCenterCellBiomarkerExpression(object):
94 | """Transformer for center cell biomarker expression prediction"""
95 | def __init__(self, dataset, **kwargs):
96 | self.node_feature_names = dataset.node_feature_names
97 | self.bm_exp_feat = np.array([
98 | i for i, feat in enumerate(self.node_feature_names)
99 | if feat.startswith('biomarker_expression')])
100 |
101 | def __call__(self, data):
102 | assert "center_node_index" in data, \
103 | "Only subgraphs with center nodes are supported, cannot find `center_node_index`"
104 | center_node_feat = data.x[data.center_node_index].detach().clone()
105 | center_cell_exp = center_node_feat[self.bm_exp_feat].float()
106 | data.node_y = center_cell_exp.view(1, -1)
107 | return data
108 |
109 |
110 | class AddCenterCellIdentifier(object):
111 |     """Transformer that adds an extra feature column identifying the center cell.
112 |     Helpful for node-level prediction tasks.
113 | """
114 | def __init__(self, *args, **kwargs):
115 | pass
116 |
117 | def __call__(self, data):
118 | assert "center_node_index" in data, \
119 | "Only subgraphs with center nodes are supported, cannot find `center_node_index`"
120 | center_cell_identifier_column = torch.zeros((data.x.shape[0], 1), dtype=data.x.dtype)
121 | center_cell_identifier_column[data.center_node_index, 0] = 1.
122 | data.x = torch.cat([data.x, center_cell_identifier_column], dim=1)
123 | return data
124 |
125 |
126 | class AddGraphLabel(object):
127 | """Transformer for adding graph-level task labels"""
128 | def __init__(self, graph_label_file, tasks=[], **kwargs):
129 | """ Construct the transformer
130 |
131 | Args:
132 | graph_label_file (str): path to the csv file containing graph-level
133 | task labels. This file should always have the first column as region id.
134 | tasks (list): list of tasks to use, corresponding to column names
135 | of the csv file. If empty, use all tasks in the file
136 | """
137 | self.label_df = pd.read_csv(graph_label_file)
138 | graph_tasks = list(self.label_df.columns) if len(tasks) == 0 else tasks
139 |         if ('charville' in graph_label_file) and ('recurrence_event' not in self.label_df.columns):
140 | self.label_df = generate_charville_label(self.label_df)
141 | self.label_df.to_csv(graph_label_file, index=False) # update csv file
142 | self.label_df.index = self.label_df.index.map(str) # Convert index to str
143 | self.graph_tasks = graph_tasks
144 | self.tasks, self.class_label_weights = self.build_class_weights(graph_tasks)
145 |         self.c_index = bool(('survival_status' in np.stack(graph_tasks)) or ('recurrence_event' in np.stack(graph_tasks)))
146 |
147 | def build_class_weights(self, graph_tasks):
148 | valid_tasks = []
149 | class_label_weights = {}
150 | for task in graph_tasks:
151 | ar = list(self.label_df[task])
152 | valid_vals = [_y for _y in ar if _y == _y]
153 | unique_vals = set(valid_vals)
154 | if not all(v.__class__ in [int, float] for v in unique_vals):
155 | # Skip tasks with non-numeric labels
156 | continue
157 | valid_tasks.append(task)
158 | if len(unique_vals) > 5:
159 | # More than 5 unique values in labels, likely a regression task
160 | class_label_weights[task] = {_y: 1 for _y in unique_vals}
161 | else:
162 | # Classification task, compute class weights
163 | val_counts = {_y: valid_vals.count(_y) for _y in unique_vals}
164 | max_count = max(val_counts.values())
165 | class_label_weights[task] = {_y: max_count / val_counts[_y] for _y in unique_vals}
166 | return valid_tasks, class_label_weights
167 |
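# Illustrative sketch (not part of the original file): for a binary task with
# label counts {0: 80, 1: 20}, max_count = 80, so the resulting class weights
# are {0: 80 / 80 = 1.0, 1: 80 / 20 = 4.0}; the minority class is up-weighted
# in the per-sample weights later returned by fetch_label.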
168 | def fetch_label(self, region_id, task_name):
169 | # Updated from SPACE-GM
170 | new_int = int(region_id.split('_')[1][1:])
171 | if 'UPMC' in region_id:
172 | new_int += 4
173 | if len(str(new_int)) == 1:
174 | new_int = f'00{new_int}'
175 | elif len(str(new_int)) == 2:
176 | new_int = f'0{new_int}'
177 | new_region_id = f'SpaceGMP-65_c{new_int}_' + region_id.split('_')[2] + '_' + region_id.split('_')[3] + '_' + region_id.split('_')[4]
178 | if 'acquisition_id_visualizer' in self.label_df:
179 | y = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name].item()
180 | else:
181 | y = self.label_df[self.label_df["region_id"] == new_region_id][task_name].item()
182 |
183 | if y != y: # np.nan
184 | y = 0
185 | w = 0
186 | else:
187 | w = self.class_label_weights[task_name][y]
188 | return y, w
189 |
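# Illustrative sketch with a hypothetical region id (not part of the original
# file): for UPMC regions the coverslip index is shifted by 4 and zero-padded
# to three digits before the label-table look-up, e.g.
#   'UPMC_c001_<p2>_<p3>_<p4>'  ->  new_int = 1 + 4 = 5  ->  'SpaceGMP-65_c005_<p2>_<p3>_<p4>'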
190 | def fetch_length_event(self, region_id, task_name):
191 | # Updated from SPACE-GM
192 | new_int = int(region_id.split('_')[1][1:])
193 | if 'UPMC' in region_id:
194 | new_int += 4
195 | if len(str(new_int)) == 1:
196 | new_int = f'00{new_int}'
197 | elif len(str(new_int)) == 2:
198 | new_int = f'0{new_int}'
199 | new_region_id = f'SpaceGMP-65_c{new_int}_' + region_id.split('_')[2] + '_' + region_id.split('_')[3] + '_' + region_id.split('_')[4]
200 |
201 | if 'acquisition_id_visualizer' in self.label_df:
202 | length = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name[0]].item()
203 | event = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name[1]].item()
204 | else:
205 | length = self.label_df[self.label_df["region_id"] == new_region_id][task_name[0]].item()
206 | event = self.label_df[self.label_df["region_id"] == new_region_id][task_name[1]].item()
207 |
208 | return length, event
209 |
210 | def __call__(self, data):
211 | graph_y = []
212 | graph_w = []
213 |
214 | for task in self.graph_tasks:
215 | if self.c_index:
216 | y, w = self.fetch_length_event(data.region_id, task)
217 | else:
218 | y, w = self.fetch_label(data.region_id, task)
219 |
220 | graph_y.append(y)
221 | graph_w.append(w)
222 | data.graph_y = torch.from_numpy(np.array(graph_y).reshape((1, -1)))
223 | data.graph_w = torch.from_numpy(np.array(graph_w).reshape((1, -1)))
224 |
225 | return data
226 |
227 | class AddTwoGraphLabel(object):
228 | """Transformer for adding graph-level task labels"""
229 | def __init__(self, graph_label_files=[], tasks=[], **kwargs):
230 | """ Construct the transformer
231 |
232 | Args:
233 |             graph_label_files (list): paths to the two csv files containing graph-level
234 |                 task labels (UPMC first, DFCI second); each file should have the first column as region id.
235 |             tasks (list): list of two task lists, corresponding to column names
236 |                 of the respective csv files
237 | """
238 | self.label_df_upmc = pd.read_csv(graph_label_files[0])
239 | self.label_df_dfci = pd.read_csv(graph_label_files[1])
240 |
241 | self.label_df_upmc.index = self.label_df_upmc.index.map(str) # Convert index to str
242 | self.label_df_dfci.index = self.label_df_dfci.index.map(str) # Convert index to str
243 |
244 | graph_tasks_upmc = tasks[0]
245 | graph_tasks_dfci = tasks[1]
246 | self.graph_tasks_upmc = graph_tasks_upmc
247 | self.graph_tasks_dfci = graph_tasks_dfci
248 |
249 | self.tasks_upmc, self.class_label_weights_upmc = self.build_class_weights(graph_tasks_upmc, label_df=self.label_df_upmc)
250 | self.tasks_dfci, self.class_label_weights_dfci = self.build_class_weights(graph_tasks_dfci, label_df=self.label_df_dfci)
251 | self.c_index = False
252 |
253 | def build_class_weights(self, graph_tasks, label_df=None):
254 | valid_tasks = []
255 | class_label_weights = {}
256 | for task in graph_tasks:
257 |             if isinstance(task, list):
258 | task = task[0] # length
259 | ar = list(label_df[task])
260 | valid_vals = [_y for _y in ar if _y == _y]
261 | unique_vals = set(valid_vals)
262 | if not all(v.__class__ in [int, float] for v in unique_vals):
263 | # Skip tasks with non-numeric labels
264 | continue
265 | valid_tasks.append(task)
266 |
267 | if len(unique_vals) > 5:
268 | # More than 5 unique values in labels, likely a regression task
269 | class_label_weights[task] = {_y: 1 for _y in unique_vals}
270 | else:
271 | # Classification task, compute class weights
272 | val_counts = {_y: valid_vals.count(_y) for _y in unique_vals}
273 | max_count = max(val_counts.values())
274 | class_label_weights[task] = {_y: max_count / val_counts[_y] for _y in unique_vals}
275 |
276 | return valid_tasks, class_label_weights
277 |
278 | def fetch_label(self, region_id, task_name):
279 | # Updated from SPACE-GM
280 | self.label_df = self.label_df_upmc if 'UPMC' in region_id else self.label_df_dfci
281 | new_int = int(region_id.split('_')[1][1:])
282 | if 'UPMC' in region_id:
283 | new_int += 4
284 | if len(str(new_int)) == 1:
285 | new_int = f'00{new_int}'
286 | elif len(str(new_int)) == 2:
287 | new_int = f'0{new_int}'
288 |
289 | if 'UPMC' in region_id:
290 | new_region_id = f'SpaceGMP-65_c{new_int}_' + region_id.split('_')[2] + '_' + region_id.split('_')[3] + '_' + region_id.split('_')[4]
291 | else:
292 | new_region_id = region_id
293 | if "acquisition_id_visualizer" in self.label_df:
294 | y = self.label_df[self.label_df["acquisition_id_visualizer"] == new_region_id][task_name].item()
295 |
296 | else:
297 | y = self.label_df[self.label_df["region_id"] == new_region_id][task_name].item()
298 |
299 | if y != y: # np.nan
300 | y = 0
301 | w = 0
302 | else:
303 | self.class_label_weights = self.class_label_weights_upmc if 'UPMC' in region_id else self.class_label_weights_dfci
304 | w = self.class_label_weights[task_name][y]
305 | return y, w
306 |
307 | def __call__(self, data):
308 | graph_y = []
309 | graph_w = []
310 |
311 | self.tasks = self.tasks_upmc if 'UPMC' in data.region_id else self.tasks_dfci
312 | for task in self.tasks:
313 | y, w = self.fetch_label(data.region_id, task)
314 | graph_y.append(y)
315 | graph_w.append(w)
316 | data.graph_y = torch.from_numpy(np.array(graph_y).reshape((1, -1)))
317 | data.graph_w = torch.from_numpy(np.array(graph_w).reshape((1, -1)))
318 |
319 | return data
--------------------------------------------------------------------------------
/src/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 | from scipy.stats import rankdata
4 | import warnings
5 | import torch
6 | import torch_geometric as tg
7 | from torch_geometric.data import HeteroData
8 |
9 | from src.utils import EDGE_TYPES
10 |
11 |
12 | def process_biomarker_expression(G,
13 | node_ind,
14 | biomarkers=None,
15 | biomarker_expression_process_method='raw',
16 | biomarker_expression_lower_bound=-3,
17 | biomarker_expression_upper_bound=3,
18 | **kwargs):
19 | """ Process biomarker expression
20 |
21 | Args:
22 | G (nx.Graph): full cellular graph of the region
23 | node_ind (int): target node index
24 | biomarkers (list): list of biomarkers
25 | biomarker_expression_process_method (str): process method, one of 'raw', 'linear', 'log', 'rank'
26 | biomarker_expression_lower_bound (float): lower bound for min-max normalization, used for 'linear' and 'log'
27 | biomarker_expression_upper_bound (float): upper bound for min-max normalization, used for 'linear' and 'log'
28 |
29 | Returns:
30 | list: processed biomarker expression values
31 | """
32 |
33 | bm_exp_dict = G.nodes[node_ind]["biomarker_expression"]
34 | bm_exp_vec = []
35 | for bm in biomarkers:
36 | if bm_exp_dict is None or bm not in bm_exp_dict:
37 | bm_exp_vec.append(0.)
38 | else:
39 | bm_exp_vec.append(float(bm_exp_dict[bm]))
40 |
41 | bm_exp_vec = np.array(bm_exp_vec)
42 | lb = biomarker_expression_lower_bound
43 | ub = biomarker_expression_upper_bound
44 |
45 | if biomarker_expression_process_method == 'raw':
46 | return list(bm_exp_vec)
47 | elif biomarker_expression_process_method == 'linear':
48 | bm_exp_vec = np.clip(bm_exp_vec, lb, ub)
49 | bm_exp_vec = (bm_exp_vec - lb) / (ub - lb)
50 | return list(bm_exp_vec)
51 | elif biomarker_expression_process_method == 'log':
52 | bm_exp_vec = np.clip(np.log(bm_exp_vec + 1e-9), lb, ub)
53 | bm_exp_vec = (bm_exp_vec - lb) / (ub - lb)
54 | return list(bm_exp_vec)
55 | elif biomarker_expression_process_method == 'rank':
56 | bm_exp_vec = rankdata(bm_exp_vec, method='min')
57 | num_exp = len(bm_exp_vec)
58 | bm_exp_vec = (bm_exp_vec - 1) / (num_exp - 1)
59 | return list(bm_exp_vec)
60 | else:
61 | raise ValueError("expression process method %s not recognized" % biomarker_expression_process_method)
62 |
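# Illustrative sketch (not part of the original file): the 'linear' method with
# the bounds used in main.py (lower=-2, upper=3) clips raw values to [-2, 3]
# and min-max normalizes them to [0, 1].
import numpy as np

_raw = np.array([-5.0, 1.0, 4.0])
_lb, _ub = -2, 3
_scaled = (np.clip(_raw, _lb, _ub) - _lb) / (_ub - _lb)  # -> [0.0, 0.6, 1.0]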
63 |
64 | def process_neighbor_composition(G,
65 | node_ind,
66 | cell_type_mapping=None,
67 | neighborhood_size=10,
68 | **kwargs):
69 | """ Calculate the composition vector of k-nearest neighboring cells
70 |
71 | Args:
72 | G (nx.Graph): full cellular graph of the region
73 | node_ind (int): target node index
74 | cell_type_mapping (dict): mapping of unique cell types to integer indices
75 | neighborhood_size (int): number of nearest neighbors to consider
76 |
77 | Returns:
78 | comp_vec (list): composition vector of k-nearest neighboring cells
79 | """
80 | center_coord = G.nodes[node_ind]["center_coord"]
81 |
82 | def node_dist(c1, c2):
83 | return np.linalg.norm(np.array(c1) - np.array(c2), ord=2)
84 |
85 | radius = 1
86 | neighbors = {}
87 | while len(neighbors) < 2 * neighborhood_size and radius < 5:
88 | radius += 1
89 | ego_g = nx.ego_graph(G, node_ind, radius=radius)
90 | neighbors = {n: feat_dict["center_coord"] for n, feat_dict in ego_g.nodes.data()}
91 |
92 | closest_neighbors = sorted(neighbors.keys(), key=lambda x: node_dist(center_coord, neighbors[x]))
93 | closest_neighbors = closest_neighbors[1:(neighborhood_size + 1)]
94 |
95 | comp_vec = np.zeros((len(cell_type_mapping),))
96 | for n in closest_neighbors:
97 | cell_type = cell_type_mapping[G.nodes[n]["cell_type"]]
98 | comp_vec[cell_type] += 1
99 | comp_vec = list(comp_vec / comp_vec.sum())
100 | return comp_vec
101 |
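# Illustrative sketch (not part of the original file): with 3 cell types and
# the 10 nearest neighbors having types [0, 0, 0, 0, 1, 1, 1, 2, 2, 2], the
# composition vector is the normalized type histogram.
import numpy as np

_neighbor_types = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
_comp = np.zeros(3)
for _t in _neighbor_types:
    _comp[_t] += 1
_comp = _comp / _comp.sum()  # -> [0.4, 0.3, 0.3]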
102 |
103 | def process_edge_distance(G,
104 | edge_ind,
105 | log_distance_lower_bound=2.,
106 | log_distance_upper_bound=5.,
107 | **kwargs):
108 | """ Process edge distance, distance will be log-transformed and min-max normalized
109 |
110 | Default parameters assume distances are usually within the range: 10-100 pixels / 3.7-37 um
111 |
112 | Args:
113 | G (nx.Graph): full cellular graph of the region
114 | edge_ind (int): target edge index
115 | log_distance_lower_bound (float): lower bound for log-transformed distance
116 | log_distance_upper_bound (float): upper bound for log-transformed distance
117 |
118 | Returns:
119 | list: list of normalized log-transformed distance
120 | """
121 | dist = G.edges[edge_ind]["distance"]
122 | log_dist = np.log(dist + 1e-5)
123 | _d = np.clip((log_dist - log_distance_lower_bound) /
124 | (log_distance_upper_bound - log_distance_lower_bound), 0, 1)
125 | return [_d]
126 |
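# Illustrative sketch (not part of the original file): with the default bounds,
# a 55-pixel edge (the neighbor-edge cutoff in graph_build.py, roughly 20 um)
# maps to about 0.67 after the log-transform and min-max normalization.
import numpy as np

_dist = 55.0
_log_dist = np.log(_dist + 1e-5)                            # ~4.007
_norm = np.clip((_log_dist - 2.0) / (5.0 - 2.0), 0.0, 1.0)  # ~0.669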
127 |
128 | def process_feature(G, feature_item, node_ind=None, edge_ind=None, **feature_kwargs):
129 | """ Process a single node/edge feature item
130 |
131 | The following feature items are supported, note that some of them require
132 | keyword arguments in `feature_kwargs`:
133 |
134 | Node features:
135 | - feature_item: "cell_type"
136 | (required) "cell_type_mapping"
137 | - feature_item: "center_coord"
138 | - feature_item: "biomarker_expression"
139 | (required) "biomarkers",
140 | (optional) "biomarker_expression_process_method",
141 | (optional, if method is "linear" or "log") "biomarker_expression_lower_bound",
142 | (optional, if method is "linear" or "log") "biomarker_expression_upper_bound"
143 | - feature_item: "neighborhood_composition"
144 | (required) "cell_type_mapping",
145 | (optional) "neighborhood_size"
146 | - other additional feature items stored in the node attributes
147 | (see `graph_build.construct_graph_for_region`, argument `cell_features_file`)
148 |
149 | Edge features:
150 | - feature_item: "edge_type"
151 | - feature_item: "distance"
152 | (optional) "log_distance_lower_bound",
153 | (optional) "log_distance_upper_bound"
154 |
155 | Args:
156 | G (nx.Graph): full cellular graph of the region
157 | feature_item (str): feature item
158 | node_ind (int): target node index (if feature item is node feature)
159 | edge_ind (tuple): target edge index (if feature item is edge feature)
160 | feature_kwargs (dict): arguments for processing features
161 |
162 | Returns:
163 | v (list): feature vector
164 | """
165 | # Node features
166 | if node_ind is not None and edge_ind is None:
167 | if feature_item == "cell_type":
168 | # Integer index of the cell type
169 | assert "cell_type_mapping" in feature_kwargs, \
170 | "'cell_type_mapping' is required in the kwargs for feature item 'cell_type'"
171 | v = [feature_kwargs["cell_type_mapping"][G.nodes[node_ind]["cell_type"]]]
172 | return v
173 | elif feature_item == "center_coord":
174 | # Coordinates of the cell centroid
175 | v = list(G.nodes[node_ind]["center_coord"])
176 | return v
177 | elif feature_item == "biomarker_expression":
178 | # Biomarker expression of the cell
179 | assert "biomarkers" in feature_kwargs, \
180 | "'biomarkers' is required in the kwargs for feature item 'biomarker_expression'"
181 | v = process_biomarker_expression(G, node_ind, **feature_kwargs)
182 | return v
183 | elif feature_item == "neighborhood_composition":
184 | # Composition vector of the k-nearest neighboring cells
185 | assert "cell_type_mapping" in feature_kwargs, \
186 | "'cell_type_mapping' is required in the kwargs for feature item 'neighborhood_composition'"
187 | v = process_neighbor_composition(G, node_ind, **feature_kwargs)
188 | return v
189 | elif feature_item in G.nodes[node_ind]:
190 | # Additional features specified by users, e.g. "SIZE" in the example
191 | v = [G.nodes[node_ind][feature_item]]
192 | return v
193 | else:
194 | raise ValueError("Feature %s not found in the node attributes of graph %s, node %s" %
195 | (feature_item, G.region_id, str(node_ind)))
196 |
197 | # Edge features
198 | elif edge_ind is not None and node_ind is None:
199 | if feature_item == "edge_type":
200 | v = [EDGE_TYPES[G.edges[edge_ind]["edge_type"]]]
201 | return v
202 | elif feature_item == "distance":
203 | v = process_edge_distance(G, edge_ind, **feature_kwargs)
204 | return v
205 | elif feature_item in G.edges[edge_ind]:
206 | v = [G.edges[edge_ind][feature_item]]
207 | return v
208 | else:
209 | raise ValueError("Feature %s not found in the edge attributes of graph %s, edge %s" %
210 | (feature_item, G.region_id, str(edge_ind)))
211 |
212 | else:
213 | raise ValueError("One of node_ind or edge_ind should be specified")
214 |
215 |
216 | def nx_to_tg_graph(G,
217 | node_features=["cell_type",
218 | "biomarker_expression",
219 | "neighborhood_composition",
220 | "center_coord"],
221 | edge_features=["edge_type",
222 | "distance"],
223 | **feature_kwargs):
224 | """ Build pyg data objects from nx graphs
225 |
226 | Args:
227 | G (nx.Graph): full cellular graph of the region
228 | node_features (list, optional): list of node feature items
229 | edge_features (list, optional): list of edge feature items
230 | feature_kwargs (dict): arguments for processing features
231 |
232 | Returns:
233 | data_list (list): list of pyg data objects
234 | """
235 | data_list = []
236 |
237 | # Each connected component of the cellular graph will be a separate pyg data object
238 | # Usually there should only be one connected component for each cellular graph
239 | for inds in nx.connected_components(G):
240 | # Skip small connected components
241 | if len(inds) < len(G) * 0.1:
242 | continue
243 | sub_G = G.subgraph(inds)
244 |
245 | # Relabel nodes to be consecutive integers, note that node indices are
246 | # not meaningful here, cells are identified by the key "cell_id" in each node
247 | mapping = {n: i for i, n in enumerate(sorted(sub_G.nodes))}
248 | sub_G = nx.relabel.relabel_nodes(sub_G, mapping)
249 | assert np.all(sub_G.nodes == np.arange(len(sub_G)))
250 |
251 | # Append node and edge features to the pyg data object
252 | data = {"x": [], "edge_attr": [], "edge_index": []}
253 | for node_ind in sub_G.nodes:
254 | feat_val = []
255 | for key in node_features:
256 | feat_val.extend(process_feature(sub_G, key, node_ind=node_ind, **feature_kwargs))
257 | data["x"].append(feat_val)
258 |
259 | for edge_ind in sub_G.edges:
260 | feat_val = []
261 | for key in edge_features:
262 | feat_val.extend(process_feature(sub_G, key, edge_ind=edge_ind, **feature_kwargs))
263 | data["edge_attr"].append(feat_val)
264 | data["edge_index"].append(edge_ind)
265 | data["edge_attr"].append(feat_val)
266 | data["edge_index"].append(tuple(reversed(edge_ind)))
267 |
268 | for key, item in data.items():
269 | data[key] = torch.tensor(item)
270 | data['edge_index'] = data['edge_index'].t().long()
271 | data = tg.data.Data.from_dict(data)
272 | data.num_nodes = sub_G.number_of_nodes()
273 | data.region_id = G.region_id
274 | data_list.append(data)
275 | return data_list
276 |
277 | def nx_to_tg_hetero_graph(G,
278 | node_features=["cell_type",
279 | "biomarker_expression",
280 | "neighborhood_composition",
281 | "center_coord"],
282 | edge_features=["edge_type",
283 | "distance"],
284 | drop_edge=0.0,
285 | **feature_kwargs):
286 |     """ Build heterogeneous pyg data objects from two-layer nx graphs
287 |
288 | Args:
289 |         G (dict): two-layer cellular graph of the region, an nx.Graph under each of the keys 'layer_1' and 'layer_2'
290 | node_features (list, optional): list of node feature items
291 | edge_features (list, optional): list of edge feature items
292 | feature_kwargs (dict): arguments for processing features
293 |
294 | Returns:
295 |         data_list (list): list of pyg HeteroData objects
296 | """
297 | data_list = []
298 |
299 | # Each connected component of the cellular graph will be a separate pyg data object
300 | # Usually there should only be one connected component for each cellular graph
301 | for inds in nx.connected_components(G['layer_1']):
302 | # Skip small connected components
303 | if len(inds) < len(G['layer_1']) * 0.1:
304 | continue
305 | sub_G_1 = G['layer_1'].subgraph(inds)
306 | sub_G_2 = G['layer_2'].subgraph(inds) # same inds for layer_2
307 |
308 | # Relabel nodes to be consecutive integers, note that node indices are
309 | # not meaningful here, cells are identified by the key "cell_id" in each node
310 | mapping = {n: i for i, n in enumerate(sorted(sub_G_1.nodes))}
311 | sub_G_1 = nx.relabel.relabel_nodes(sub_G_1, mapping)
312 | sub_G_2 = nx.relabel.relabel_nodes(sub_G_2, mapping)
313 |
314 | assert np.all(sub_G_1.nodes == np.arange(len(sub_G_1)))
315 |
316 | # Common node feature
317 | node_feature_list = []
318 | for node_ind in sub_G_1.nodes:
319 | feat_val = []
320 | for key in node_features:
321 | feat_val.extend(process_feature(sub_G_1, key, node_ind=node_ind, **feature_kwargs))
322 | node_feature_list.append(feat_val)
323 |
324 | # Edge for layer 1
325 | edge_index_1_list = []
326 | edge_attr_1_list = []
327 | for edge_ind in sub_G_1.edges:
328 | feat_val = []
329 | for key in edge_features:
330 | feat_val.extend(process_feature(sub_G_1, key, edge_ind=edge_ind, **feature_kwargs))
331 | edge_attr_1_list.append(feat_val)
332 | edge_index_1_list.append(edge_ind)
333 | edge_attr_1_list.append(feat_val)
334 | edge_index_1_list.append(tuple(reversed(edge_ind)))
335 |
336 | # Edge for layer 2
337 | edge_index_2_list = []
338 | edge_attr_2_list = []
339 | for edge_ind in sub_G_2.edges:
340 | feat_val = []
341 | for key in edge_features:
342 | feat_val.extend(process_feature(sub_G_2, key, edge_ind=edge_ind, **feature_kwargs))
343 | edge_attr_2_list.append(feat_val)
344 | edge_index_2_list.append(edge_ind)
345 | edge_attr_2_list.append(feat_val)
346 | edge_index_2_list.append(tuple(reversed(edge_ind)))
347 |
348 | node_features_tensor = torch.tensor(node_feature_list)
349 | edge_index_1_tensor = torch.tensor(edge_index_1_list, dtype=torch.long).t().contiguous()
350 | edge_attr_1_tensor = torch.tensor(edge_attr_1_list)
351 | edge_index_2_tensor = torch.tensor(edge_index_2_list, dtype=torch.long).t().contiguous()
352 | edge_attr_2_tensor = torch.tensor(edge_attr_2_list)
353 |
354 | # Append node and edge features to the pyg data object
355 | data = HeteroData()
356 | data['cell'].x = node_features_tensor
357 | data['cell', 'geom', 'cell'].edge_index = edge_index_1_tensor.long()
358 | data['cell', 'geom', 'cell'].edge_attr = edge_attr_1_tensor
359 | data['cell', 'type', 'cell'].edge_index = edge_index_2_tensor.long()
360 | data['cell', 'type', 'cell'].edge_attr = edge_attr_2_tensor
361 |
362 | data.num_nodes = sub_G_1.number_of_nodes()
363 | data.region_id = G['layer_1'].region_id
364 | data_list.append(data)
365 |
366 | return data_list
367 |
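# Illustrative sketch (not part of the original file): layout of the HeteroData
# object produced for one connected component -- a single 'cell' node store
# shared by two edge relations, 'geom' built from G['layer_1'] and 'type'
# built from G['layer_2']. Shapes below are placeholders.
import torch
from torch_geometric.data import HeteroData

_example = HeteroData()
_example['cell'].x = torch.zeros((4, 10))                                      # 4 cells, 10 node features
_example['cell', 'geom', 'cell'].edge_index = torch.tensor([[0, 1], [1, 0]])   # layer_1 edges
_example['cell', 'type', 'cell'].edge_index = torch.tensor([[2, 3], [3, 2]])   # layer_2 edges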
368 |
369 | def get_feature_names(features, cell_type_mapping=None, biomarkers=None):
370 | """ Helper fn for getting a list of feature names from a list of feature items
371 |
372 | Args:
373 | features (list): list of feature items
374 | cell_type_mapping (dict): mapping of unique cell types to integer indices
375 | biomarkers (list): list of biomarkers
376 |
377 | Returns:
378 | feat_names(list): list of feature names
379 | """
380 | feat_names = []
381 | for feat in features:
382 | if feat in ["distance", "cell_type", "edge_type"]:
383 |             # features "cell_type" and "edge_type" are each a single integer index
384 | # feature "distance" will be a single float value
385 | feat_names.append(feat)
386 | elif feat == "center_coord":
387 | # feature "center_coord" will be a tuple of two float values
388 | feat_names.extend(["center_coord-x", "center_coord-y"])
389 | elif feat == "biomarker_expression":
390 | # feature "biomarker_expression" will contain a list of biomarker expression values
391 | feat_names.extend(["biomarker_expression-%s" % bm for bm in biomarkers])
392 | elif feat == "neighborhood_composition":
393 |             # feature "neighborhood_composition" will contain a composition vector of the k-nearest neighboring cells
394 | # The vector will have the same length as the number of unique cell types
395 | feat_names.extend(["neighborhood_composition-%s" % ct
396 | for ct in sorted(cell_type_mapping.keys(), key=lambda x: cell_type_mapping[x])])
397 | else:
398 | warnings.warn("Using additional feature: %s" % feat)
399 | feat_names.append(feat)
400 | return feat_names
401 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import multiprocessing
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | import warnings
8 |
9 | import torch
10 | from torch.utils.data import RandomSampler
11 |
12 | import torch_geometric as tg
13 | from torch_geometric.data import Dataset
14 | from torch_geometric.loader import DataLoader
15 | from torch_geometric.utils import subgraph
16 |
17 | from src.graph_build import plot_graph
18 | from src.features import get_feature_names, nx_to_tg_graph, nx_to_tg_hetero_graph
19 | from src.utils import (
20 | EDGE_TYPES,
21 | get_cell_type_metadata,
22 | get_biomarker_metadata,
23 | )
24 |
25 |
26 | class CellularGraphDataset(Dataset):
27 | """ Main dataset structure for cellular graphs
28 | Inherited from https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Dataset.html
29 | """
30 | def __init__(self,
31 | root,
32 | transform=[],
33 | pre_transform=None,
34 | raw_folder_name='graph',
35 | processed_folder_name='tg_graph',
36 |                  node_features=["cell_type", "biomarker_expression", "neighborhood_composition", "center_coord"],
37 | edge_features=["edge_type", "distance"],
38 | cell_type_mapping=None,
39 | cell_type_freq=None,
40 | biomarkers=None,
41 | subgraph_size=0,
42 | subgraph_source='on-the-fly',
43 | subgraph_allow_distant_edge=True,
44 | subgraph_radius_limit=-1,
45 | sampling_avoid_unassigned=True,
46 | unassigned_cell_type='Unassigned',
47 | **feature_kwargs):
48 | """ Initialize the dataset
49 |
50 | Args:
51 | root (str): path to the dataset directory
52 | transform (list): list of transformations (see `transform.py`),
53 | applied to each output graph on-the-fly
54 | pre_transform (list): list of transformations, applied to each graph before saving
55 | raw_folder_name (str): name of the sub-folder containing raw graphs (gpickle)
56 | processed_folder_name (str): name of the sub-folder containing processed graphs (pyg data object)
57 | node_features (list): list of all available node feature items,
58 |                 see `features.process_feature` for details
59 | edge_features (list): list of all available edge feature items
60 | cell_type_mapping (dict): mapping of unique cell types to integer indices,
61 | see `utils.get_cell_type_metadata`
62 | cell_type_freq (dict): mapping of unique cell types to their frequencies,
63 | see `utils.get_cell_type_metadata`
64 | biomarkers (list): list of biomarkers,
65 |                 see `utils.get_biomarker_metadata`
66 | subgraph_size (int): number of hops for subgraph, 0 means using the full cellular graph
67 | subgraph_source (str): source of subgraphs, one of 'on-the-fly', 'chunk_save'
68 | subgraph_allow_distant_edge (bool): whether to consider distant edges
69 | subgraph_radius_limit (float): radius (distance to center cell in pixel) limit for subgraphs,
70 | -1 means no limit
71 | sampling_avoid_unassigned (bool): whether to avoid sampling cells with unassigned cell type
72 | unassigned_cell_type (str): name of the unassigned cell type
73 | feature_kwargs (dict): optional arguments for processing features
74 |                 see `features.process_feature` for details
75 | """
76 | self.root = root
77 | self.raw_folder_name = raw_folder_name
78 | self.processed_folder_name = processed_folder_name
79 | os.makedirs(self.raw_dir, exist_ok=True)
80 | os.makedirs(self.processed_dir, exist_ok=True)
81 |
82 | # Find all unique cell types in the dataset
83 | if cell_type_mapping is None or cell_type_freq is None:
84 | try:
85 | with open(self.raw_dir+'/cell_type_mapping_freq.pkl', 'rb') as f:
86 | self.cell_type_mapping, self.cell_type_freq = pickle.load(f)
87 |
88 |             except Exception:
89 | nx_graph_files = [os.path.join(self.raw_dir, f) for f in self.raw_file_names]
90 | self.cell_type_mapping, self.cell_type_freq = get_cell_type_metadata(nx_graph_files)
91 |
92 | with open(self.raw_dir+'/cell_type_mapping_freq.pkl', 'wb') as f:
93 | pickle.dump([self.cell_type_mapping, self.cell_type_freq], f)
94 |
95 | else:
96 | self.cell_type_mapping = cell_type_mapping
97 | self.cell_type_freq = cell_type_freq
98 |
99 | # Find all available biomarkers for cells in the dataset
100 | if biomarkers is None:
101 | try:
102 | with open(self.raw_dir+'/biomarkers.pkl', 'rb') as f:
103 | self.biomarkers = pickle.load(f)
104 |             except Exception:
105 | nx_graph_files = [os.path.join(self.raw_dir, f) for f in self.raw_file_names]
106 | self.biomarkers, _ = get_biomarker_metadata(nx_graph_files)
107 |
108 | with open(self.raw_dir+'/biomarkers.pkl', 'wb') as f:
109 | pickle.dump(self.biomarkers, f)
110 |
111 | else:
112 | self.biomarkers = biomarkers
113 |
114 | # Node features & edge features
115 | self.node_features = node_features
116 | self.edge_features = edge_features
117 | if "cell_type" in self.node_features:
118 | assert self.node_features.index("cell_type") == 0, "cell_type must be the first node feature"
119 |         if "edge_type" in self.edge_features:  # edge types: 'neighbor' (within ~20 um) vs 'distant'
120 | assert self.edge_features.index("edge_type") == 0, "edge_type must be the first edge feature"
121 |
122 | self.node_feature_names = get_feature_names(node_features,
123 | cell_type_mapping=self.cell_type_mapping,
124 | biomarkers=self.biomarkers)
125 | self.edge_feature_names = get_feature_names(edge_features,
126 | cell_type_mapping=self.cell_type_mapping,
127 | biomarkers=self.biomarkers)
128 |
129 | # Prepare kwargs for node and edge featurization
130 | self.feature_kwargs = feature_kwargs
131 | self.feature_kwargs['cell_type_mapping'] = self.cell_type_mapping
132 | self.feature_kwargs['biomarkers'] = self.biomarkers
133 |
134 | # Note this command below calls the `process` function
135 | super(CellularGraphDataset, self).__init__(root, None, pre_transform)
136 |
137 | # Transformations, e.g. masking features, adding graph-level labels
138 | self.transform = transform
139 |
140 | # SPACE-GM uses n-hop ego graphs (subgraphs) to perform prediction
141 | self.subgraph_size = subgraph_size # number of hops, 0 = use full graph
142 | self.subgraph_source = subgraph_source
143 | self.subgraph_allow_distant_edge = subgraph_allow_distant_edge
144 | self.subgraph_radius_limit = subgraph_radius_limit
145 |
146 | self.N = len(self.processed_paths)
147 | # Sampling frequency for each cell type
148 | self.sampling_freq = {self.cell_type_mapping[ct]: 1. / (self.cell_type_freq[ct] + 1e-5)
149 | for ct in self.cell_type_mapping}
150 | self.sampling_freq = torch.from_numpy(np.array([self.sampling_freq[i] for i in range(len(self.sampling_freq))]))
151 | # Avoid sampling unassigned cell
152 | self.unassigned_cell_type = unassigned_cell_type
153 | if sampling_avoid_unassigned and unassigned_cell_type in self.cell_type_mapping:
154 | self.sampling_freq[self.cell_type_mapping[unassigned_cell_type]] = 0.
155 |
156 | # Cache for subgraphs
157 | self.cached_data = {}
158 |
159 | def set_indices(self, inds=None):
160 | """Limit subgraph sampling to a subset of region indices,
161 | helpful when splitting dataset into training/validation/test regions
162 | """
163 | self._indices = inds
164 | return
165 |
166 | def set_subgraph_source(self, subgraph_source):
167 | """Set subgraph source"""
168 | assert subgraph_source in ['chunk_save', 'on-the-fly']
169 | self.subgraph_source = subgraph_source
170 |
171 | def set_transforms(self, transform=[]):
172 | """Set transformation functions"""
173 | self.transform = transform
174 |
175 | @property
176 | def raw_dir(self) -> str:
177 | return os.path.join(self.root, self.raw_folder_name)
178 |
179 | @property
180 | def processed_dir(self) -> str:
181 | return os.path.join(self.root, self.processed_folder_name)
182 |
183 | @property
184 | def raw_file_names(self):
185 | return sorted([f for f in os.listdir(self.raw_dir) if f.endswith('.gpkl')])
186 |
187 | @property
188 | def processed_file_names(self):
189 | # Only files for full graphs
190 | return sorted([f for f in os.listdir(self.processed_dir) if f.endswith('.gpt') and 'hop' not in f])
191 |
192 | def len(self):
193 | return len(self.processed_paths)
194 |
195 | def process(self):
196 | """Featurize all cellular graphs"""
197 | for raw_path in self.raw_paths:
198 | G = pickle.load(open(raw_path, 'rb'))
199 | region_id = G['layer_1'].region_id
200 | if os.path.exists(os.path.join(self.processed_dir, '%s.0.gpt' % region_id)):
201 | continue
202 |
203 | # Transform networkx graphs to pyg data objects, and add features for nodes and edges
204 | data_list = nx_to_tg_hetero_graph(G,
205 | node_features=self.node_features,
206 | edge_features=self.edge_features,
207 | **self.feature_kwargs)
208 |
209 | for i, d in enumerate(data_list):
210 | try:
211 | assert d.region_id == region_id # graph identifier
212 | assert d['cell'].x.shape[1] == len(self.node_feature_names) # make sure feature dimension matches
213 | assert d['cell', 'geom', 'cell'].edge_attr.shape[1] == len(self.edge_feature_names)
214 | assert d['cell', 'type', 'cell'].edge_attr.shape[1] == len(self.edge_feature_names)
215 |                 except Exception:
216 | continue
217 | d.component_id = i
218 | if self.pre_transform is not None:
219 | for transform_fn in self.pre_transform:
220 | d = transform_fn(d)
221 | torch.save(d, os.path.join(self.processed_dir, '%s.%d.gpt' % (d.region_id, d.component_id)))
222 | return
223 |
224 | def __getitem__(self, idx):
225 | """Sample a graph/subgraph from the dataset and apply transformations"""
226 | # data = self.get(self.indices()[idx])
227 | data = self.cached_data[idx]
228 | # Apply transformations
229 | for transform_fn in self.transform:
230 | data = transform_fn(data)
231 | return data
232 |
233 | def get(self, idx):
234 | data = self.get_full(idx)
235 | return data
236 |
237 | def get_subgraph(self, idx, center_ind):
238 | """Get a subgraph from the dataset"""
239 | # Check cache
240 | if (idx, center_ind) in self.cached_data:
241 | return self.cached_data[(idx, center_ind)]
242 | if self.subgraph_source == 'on-the-fly':
243 | # Construct subgraph on-the-fly
244 | return self.calculate_subgraph(idx, center_ind)
245 | elif self.subgraph_source == 'chunk_save':
246 | # Load subgraph from pre-saved chunk file
247 | return self.get_saved_subgraph_from_chunk(idx, center_ind)
248 |
249 | def get_full(self, idx):
250 | """Read the full cellular graph of region `idx`"""
251 | if idx in self.cached_data:
252 | return self.cached_data[idx]
253 | else:
254 | data = torch.load(self.processed_paths[idx])
255 | self.cached_data[idx] = data
256 | return data
257 |
258 | def get_full_nx(self, idx):
259 | """Read the full cellular graph (nx.Graph) of region `idx`"""
260 | return pickle.load(open(self.raw_paths[idx], 'rb'))
261 |
262 | def calculate_subgraph(self, idx, center_ind):
263 | """Generate the n-hop subgraph around cell `center_ind` from region `idx`"""
264 | data = self.get_full(idx)
265 | if not self.subgraph_allow_distant_edge:
266 | edge_type_mask = (data.edge_attr[:, 0] == EDGE_TYPES["neighbor"])
267 | else:
268 | edge_type_mask = None
269 | sub_node_inds = k_hop_subgraph(int(center_ind),
270 | self.subgraph_size,
271 | data.edge_index,
272 | edge_type_mask=edge_type_mask,
273 | relabel_nodes=False,
274 | num_nodes=data.x.shape[0])[0]
275 |
276 | if self.subgraph_radius_limit > 0:
277 | # Restrict to neighboring cells that are within the radius (distance to center cell) limit
278 | assert "center_coord" in self.node_features
279 | coord_feature_inds = [i for i, n in enumerate(self.node_feature_names) if n.startswith('center_coord')]
280 | assert len(coord_feature_inds) == 2
281 | center_cell_coord = data.x[[center_ind]][:, coord_feature_inds]
282 | neighbor_cells_coord = data.x[sub_node_inds][:, coord_feature_inds]
283 | dists = ((neighbor_cells_coord - center_cell_coord)**2).sum(1).sqrt()
284 | sub_node_inds = sub_node_inds[(dists < self.subgraph_radius_limit)]
285 |
286 | # Construct subgraphs as separate pyg data objects
287 | sub_x = data.x[sub_node_inds]
288 | sub_edge_index, sub_edge_attr = subgraph(sub_node_inds,
289 | data.edge_index,
290 | edge_attr=data.edge_attr,
291 | relabel_nodes=True)
292 |
293 | relabeled_node_ind = list(sub_node_inds.numpy()).index(center_ind)
294 |
295 | sub_data = {'center_node_index': relabeled_node_ind, # center node index in the subgraph
296 | 'original_center_node': center_ind, # center node index in the original full cellular graph
297 | 'x': sub_x,
298 | 'edge_index': sub_edge_index,
299 | 'edge_attr': sub_edge_attr,
300 | 'num_nodes': len(sub_node_inds)}
301 |
302 | # Assign graph-level attributes
303 | for k in data:
304 | if not k[0] in sub_data:
305 | sub_data[k[0]] = k[1]
306 |
307 | sub_data = tg.data.Data.from_dict(sub_data)
308 | self.cached_data[(idx, center_ind)] = sub_data
309 | return sub_data
310 |
311 | def get_saved_subgraph_from_chunk(self, idx, center_ind):
312 | """Read the n-hop subgraph around cell `center_ind` from region `idx`
313 | Subgraph will be extracted from a pre-saved chunk file, which is generated by calling
314 | `save_all_subgraphs_to_chunk`
315 | """
316 | full_graph_path = self.processed_paths[idx]
317 | subgraphs_path = full_graph_path.replace('.gpt', '.%d-hop.gpt' % self.subgraph_size)
318 | if not os.path.exists(subgraphs_path):
319 | warnings.warn("Subgraph save %s not found" % subgraphs_path)
320 | return self.calculate_subgraph(idx, center_ind)
321 |
322 | subgraphs = torch.load(subgraphs_path)
323 | # Store to cache first
324 | for j, g in enumerate(subgraphs):
325 | self.cached_data[(idx, j)] = g
326 | return self.cached_data[(idx, center_ind)]
327 |
328 | def pick_center(self, data):
329 | """Randomly pick a center cell from a full cellular graph, cell type balanced"""
330 | cell_types = data["x"][:, 0].long()
331 | freq = self.sampling_freq.gather(0, cell_types)
332 | freq = freq / freq.sum()
333 | center_node_ind = np.random.choice(np.arange(len(freq)), p=freq.cpu().data.numpy())
334 | return center_node_ind
335 |
336 | def load_to_cache(self, idx, subgraphs=True):
337 | """Pre-load full cellular graph of region `idx` and all its n-hop subgraphs to cache"""
338 | data = torch.load(self.processed_paths[idx])
339 | self.cached_data[idx] = data
340 | if subgraphs or self.subgraph_source == 'chunk_save':
341 | subgraphs_path = self.processed_paths[idx].replace('.gpt', '.%d-hop.gpt' % self.subgraph_size)
342 | if not os.path.exists(subgraphs_path):
343 | raise FileNotFoundError("Subgraph save %s not found, please run `save_all_subgraphs_to_chunk`."
344 | % subgraphs_path)
345 | neighbor_graphs = torch.load(subgraphs_path)
346 | for j, ng in enumerate(neighbor_graphs):
347 | self.cached_data[(idx, j)] = ng
348 |
349 | def save_all_subgraphs_to_chunk(self):
350 | """Save all n-hop subgraphs for all regions to chunk files (one file per region)"""
351 | for idx, p in enumerate(self.processed_paths):
352 | data = self.get_full(idx)
353 | n_nodes = data.x.shape[0]
354 | neighbor_graph_path = p.replace('.gpt', '.%d-hop.gpt' % self.subgraph_size)
355 | if os.path.exists(neighbor_graph_path):
356 | continue
357 | subgraphs = []
358 | for node_i in range(n_nodes):
359 | subgraphs.append(self.calculate_subgraph(idx, node_i))
360 | torch.save(subgraphs, neighbor_graph_path)
361 | return
362 |
363 | def clear_cache(self):
364 | del self.cached_data
365 | self.cached_data = {}
366 | return
367 |
368 | def plot_subgraph(self, idx, center_ind):
369 | """Plot the n-hop subgraph around cell `center_ind` from region `idx`"""
370 | xcoord_ind = self.node_feature_names.index('center_coord-x')
371 | ycoord_ind = self.node_feature_names.index('center_coord-y')
372 |
373 | _subg = self.calculate_subgraph(idx, center_ind)
374 | coords = _subg.x.data.numpy()[:, [xcoord_ind, ycoord_ind]].astype(float)
375 | x_c, y_c = coords[_subg.center_node_index]
376 |
377 | G = self.get_full_nx(idx)
378 | sub_node_inds = []
379 | for n in G.nodes:
380 | c = np.array(G.nodes[n]['center_coord']).astype(float).reshape((1, -1))
381 | if np.linalg.norm(coords - c, ord=2, axis=1).min() < 1e-2:
382 | sub_node_inds.append(n)
383 | assert len(sub_node_inds) == len(coords)
384 | _G = G.subgraph(sub_node_inds)
385 |
386 | node_colors = [self.cell_type_mapping[_G.nodes[n]['cell_type']] for n in _G.nodes]
387 | node_colors = [matplotlib.cm.tab20(ct) for ct in node_colors]
388 | plot_graph(_G, node_colors=node_colors)
389 | xmin, xmax = plt.gca().xaxis.get_data_interval()
390 | ymin, ymax = plt.gca().yaxis.get_data_interval()
391 |
392 | scale = max(x_c - xmin, xmax - x_c, y_c - ymin, ymax - y_c) * 1.05
393 | plt.xlim(x_c - scale, x_c + scale)
394 | plt.ylim(y_c - scale, y_c + scale)
395 | plt.plot([x_c], [y_c], 'x', markersize=5, color='k')
396 |
397 | def plot_graph_legend(self):
398 | """Legend for cell type colors"""
399 | plt.clf()
400 | plt.figure(figsize=(2, 2))
401 | for ct, i in self.cell_type_mapping.items():
402 | plt.plot([0], [0], '.', label=ct, color=matplotlib.cm.tab20(int(i) % 20))
403 | plt.legend()
404 | plt.plot([0], [0], color='w', markersize=10)
405 | plt.axis('off')
406 |
407 |
408 | def k_hop_subgraph(node_ind,
409 | subgraph_size,
410 | edge_index,
411 | edge_type_mask=None,
412 | relabel_nodes=False,
413 | num_nodes=None):
414 |     """A customized k-hop subgraph fn that filters edges by edge type
415 |
416 | Args:
417 | node_ind (int): center node index
418 | subgraph_size (int): number of hops for the neighborhood subgraph
419 | edge_index (torch.Tensor): edge index tensor for the full graph
420 | edge_type_mask (torch.Tensor): edge type mask
421 |         relabel_nodes (bool): whether to relabel node indices to consecutive integers
422 | num_nodes (int): number of nodes in the full graph
423 |
424 | Returns:
425 | subset (torch.LongTensor): indices of nodes in the subgraph
426 | edge_index (torch.LongTensor): edges in the subgraph
427 |         inv (torch.LongTensor): location of the center node in the subgraph
428 | edge_mask (torch.BoolTensor): edge mask indicating which edges were preserved
429 | """
430 |
431 | num_nodes = edge_index.max().item() + 1 if num_nodes is None else num_nodes
432 | col, row = edge_index
433 |
434 | node_mask = row.new_empty(num_nodes, dtype=torch.bool)
435 | edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
436 | edge_type_mask = torch.ones_like(edge_mask) if edge_type_mask is None else edge_type_mask
437 |
438 | if isinstance(node_ind, (int, list, tuple)):
439 | node_ind = torch.tensor([node_ind], device=row.device).flatten()
440 | else:
441 | node_ind = node_ind.to(row.device)
442 |
443 | subsets = [node_ind]
444 | next_root = node_ind
445 |
446 | for _ in range(subgraph_size):
447 | node_mask.fill_(False)
448 | node_mask[next_root] = True
449 | torch.index_select(node_mask, 0, row, out=edge_mask)
450 | subsets.append(col[edge_mask])
451 | next_root = col[edge_mask * edge_type_mask] # use nodes connected with mask=True to span
452 |
453 | subset, inv = torch.cat(subsets).unique(return_inverse=True)
454 | inv = inv[:node_ind.numel()]
455 |
456 | node_mask.fill_(False)
457 | node_mask[subset] = True
458 | edge_mask = node_mask[row] & node_mask[col]
459 |
460 | edge_index = edge_index[:, edge_mask]
461 |
462 | if relabel_nodes:
463 | node_ind = row.new_full((num_nodes, ), -1)
464 | node_ind[subset] = torch.arange(subset.size(0), device=row.device)
465 | edge_index = node_ind[edge_index]
466 |
467 | return subset, edge_index, inv, edge_mask
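# Illustrative sketch (not part of the original file): on a toy path graph
# 0-1-2-3 with every edge stored in both directions, a 2-hop neighborhood
# around node 0 keeps nodes {0, 1, 2} and the four edges among them.
import torch

_toy_edges = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
_subset, _sub_edges, _inv, _edge_mask = k_hop_subgraph(0, 2, _toy_edges, num_nodes=4)
# _subset -> tensor([0, 1, 2])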
--------------------------------------------------------------------------------
/src/graph_build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import json
5 | import pickle
6 | import matplotlib
7 | import matplotlib.path as mplPath
8 | import matplotlib.pyplot as plt
9 | import networkx as nx
10 | import warnings
11 | from scipy.spatial import Delaunay
12 | from itertools import combinations
13 |
14 | RADIUS_RELAXATION = 0.1
15 | NEIGHBOR_EDGE_CUTOFF = 55  # distance cutoff for neighbor edges, 55 pixels ≈ 20 um
16 |
17 |
18 | def plot_voronoi_polygons(voronoi_polygons, voronoi_polygon_colors=None):
19 | """Plot voronoi polygons for the cellular graph
20 |
21 | Args:
22 | voronoi_polygons (nx.Graph/list): cellular graph or list of voronoi polygons
23 | voronoi_polygon_colors (list): list of colors for voronoi polygons
24 | """
25 | if isinstance(voronoi_polygons, nx.Graph):
26 | voronoi_polygons = [voronoi_polygons.nodes[n]['voronoi_polygon'] for n in voronoi_polygons.nodes]
27 |
28 | if voronoi_polygon_colors is None:
29 | voronoi_polygon_colors = ['w'] * len(voronoi_polygons)
30 | assert len(voronoi_polygon_colors) == len(voronoi_polygons)
31 |
32 | xmax = 0
33 | ymax = 0
34 | for polygon, polygon_color in zip(voronoi_polygons, voronoi_polygon_colors):
35 | x, y = polygon[:, 0], polygon[:, 1]
36 | plt.fill(x, y, facecolor=polygon_color, edgecolor='k', linewidth=0.5)
37 | xmax = max(xmax, x.max())
38 | ymax = max(ymax, y.max())
39 |
40 | plt.xlim(0, xmax)
41 | plt.ylim(0, ymax)
42 | return
43 |
44 |
45 | def plot_graph(G, node_colors=None, cell_type=False):
46 | """Plot dot-line graph for the cellular graph
47 |
48 | Args:
49 | G (nx.Graph): full cellular graph of the region
50 |         node_colors (list): list of node colors (default None); cell_type (bool): if True, draw all edges in a uniform solid style
51 | """
52 | # Extract basic node attributes
53 | node_coords = [G.nodes[n]['center_coord'] for n in G.nodes]
54 | node_coords = np.stack(node_coords, 0)
55 |
56 | if node_colors is None:
57 | unique_cell_types = sorted(set([G.nodes[n]['cell_type'] for n in G.nodes]))
58 | cell_type_to_color = {ct: matplotlib.cm.get_cmap("tab20")(i % 20) for i, ct in enumerate(unique_cell_types)}
59 | node_colors = [cell_type_to_color[G.nodes[n]['cell_type']] for n in G.nodes]
60 | assert len(node_colors) == node_coords.shape[0]
61 |
62 | for (i, j, edge_type) in G.edges.data():
63 | xi, yi = G.nodes[i]['center_coord']
64 | xj, yj = G.nodes[j]['center_coord']
65 | if cell_type:
66 | plotting_kwargs = {"c": "k",
67 | "linewidth": 1,
68 | "linestyle": '-'}
69 | else:
70 | if edge_type['edge_type'] == 'neighbor':
71 | plotting_kwargs = {"c": "k",
72 | "linewidth": 1,
73 | "linestyle": '-'}
74 | else:
75 | plotting_kwargs = {"c": (0.4, 0.4, 0.4, 1.0),
76 | "linewidth": 0.3,
77 | "linestyle": '--'}
78 | plt.plot([xi, xj], [yi, yj], zorder=1, **plotting_kwargs)
79 |
80 | plt.scatter(node_coords[:, 0],
81 | node_coords[:, 1],
82 | s=10,
83 | c=node_colors,
84 | linewidths=0.3,
85 | zorder=2)
86 | plt.xlim(0, node_coords[:, 0].max() * 1.01)
87 | plt.ylim(0, node_coords[:, 1].max() * 1.01)
88 | return
89 |
90 |
91 | def load_cell_coords(cell_coords_file):
92 | """Load cell coordinates from file
93 |
94 | Args:
95 | cell_coords_file (str): path to csv file containing cell coordinates
96 |
97 | Returns:
98 | pd.DataFrame: dataframe containing cell coordinates, columns ['CELL_ID', 'X', 'Y']
99 | """
100 | df = pd.read_csv(cell_coords_file)
101 | df.columns = [c.upper() for c in df.columns]
102 | assert 'X' in df.columns, "Cannot find column for X coordinates"
103 | assert 'Y' in df.columns, "Cannot find column for Y coordinates"
104 | if 'CELL_ID' not in df.columns:
105 | warnings.warn("Cannot find column for cell id, using index as cell id")
106 | df['CELL_ID'] = df.index
107 | return df[['CELL_ID', 'X', 'Y']]
108 |
109 |
110 | def load_cell_types(cell_types_file):
111 | """Load cell types from file
112 |
113 | Args:
114 | cell_types_file (str): path to csv file containing cell types
115 |
116 | Returns:
117 | pd.DataFrame: dataframe containing cell types, columns ['CELL_ID', 'CELL_TYPE']
118 | """
119 | df = pd.read_csv(cell_types_file)
120 | df.columns = [c.upper() for c in df.columns]
121 |
122 | cell_type_column = [c for c in df.columns if c != 'CELL_ID']
123 | if len(cell_type_column) == 1:
124 | cell_type_column = cell_type_column[0]
125 | elif 'CELL_TYPE' in cell_type_column:
126 | cell_type_column = 'CELL_TYPE'
127 | elif 'CELL_TYPES' in cell_type_column:
128 | cell_type_column = 'CELL_TYPES'
129 | else:
130 | raise ValueError("Please rename the column for cell type as 'CELL_TYPE'")
131 |
132 | if 'CELL_ID' not in df.columns:
133 | warnings.warn("Cannot find column for cell id, using index as cell id")
134 | df['CELL_ID'] = df.index
135 | _df = df[['CELL_ID', cell_type_column]]
136 | _df.columns = ['CELL_ID', 'CELL_TYPE'] # rename columns for clarity
137 | return _df
138 |
139 |
140 | def load_cell_biomarker_expression(cell_biomarker_expression_file):
141 | """Load cell biomarker expression from file
142 |
143 | Args:
144 | cell_biomarker_expression_file (str): path to csv file containing cell biomarker expression
145 |
146 | Returns:
147 | pd.DataFrame: dataframe containing cell biomarker expression,
148 | columns ['CELL_ID', 'BM-', 'BM-', ...]
149 | """
150 | df = pd.read_csv(cell_biomarker_expression_file)
151 | df.columns = [c.upper() for c in df.columns]
152 | biomarkers = sorted([c for c in df.columns if c != 'CELL_ID'])
153 |     for bm in list(biomarkers):  # iterate over a copy so removing a column does not skip the next one
154 | if df[bm].dtype not in [np.dtype(int), np.dtype(float), np.dtype('float64')]:
155 | warnings.warn("Skipping column %s as it is not numeric" % bm)
156 | biomarkers.remove(bm)
157 |
158 | if 'CELL_ID' not in df.columns:
159 | warnings.warn("Cannot find column for cell id, using index as cell id")
160 | df['CELL_ID'] = df.index
161 | _df = df[['CELL_ID'] + biomarkers]
162 | _df.columns = ['CELL_ID'] + ['BM-%s' % bm for bm in biomarkers]
163 | return _df
164 |
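# A minimal illustration of the input expected by load_cell_biomarker_expression
# (the biomarker names below are hypothetical, not required by the code):
#
#     CELL_ID,CD3,CD20,PANCK
#     0,0.12,0.05,1.31
#     1,0.02,0.88,0.04
#
# Column names are upper-cased, non-numeric columns are dropped with a warning,
# and the returned dataframe has columns ['CELL_ID'] plus 'BM-'-prefixed
# biomarker columns (e.g. 'BM-CD3').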
165 |
166 | def load_cell_features(cell_features_file):
167 | """Load additional cell features from file
168 |
169 | Args:
170 | cell_features_file (str): path to csv file containing additional cell features
171 |
172 | Returns:
173 | pd.DataFrame: dataframe containing cell features
174 | columns ['CELL_ID', '', '', ...]
175 | """
176 | df = pd.read_csv(cell_features_file)
177 | df.columns = [c.upper() for c in df.columns]
178 |
179 | feature_columns = sorted([c for c in df.columns if c != 'CELL_ID'])
180 |     for feat in list(feature_columns):  # iterate over a copy so removing a column does not skip the next one
181 | if df[feat].dtype not in [np.dtype(int), np.dtype(float), np.dtype('float64')]:
182 | warnings.warn("Skipping column %s as it is not numeric" % feat)
183 | feature_columns.remove(feat)
184 |
185 | if 'CELL_ID' not in df.columns:
186 | warnings.warn("Cannot find column for cell id, using index as cell id")
187 | df['CELL_ID'] = df.index
188 |
189 | return df[['CELL_ID'] + feature_columns]
190 |
191 |
192 | def read_raw_voronoi(voronoi_file):
193 | """Read raw coordinates of voronoi polygons from file
194 |
195 | Args:
196 | voronoi_file (str): path to the voronoi polygon file
197 |
198 | Returns:
199 | voronoi_polygons (list): list of voronoi polygons,
200 | represented by the coordinates of their exterior vertices
201 | """
202 | if voronoi_file.endswith('json'):
203 | with open(voronoi_file) as f:
204 | raw_voronoi_polygons = json.load(f)
205 | elif voronoi_file.endswith('.pkl'):
206 | with open(voronoi_file, 'rb') as f:
207 | raw_voronoi_polygons = pickle.load(f)
208 |
209 | voronoi_polygons = []
210 | for i, polygon in enumerate(raw_voronoi_polygons):
211 | if isinstance(polygon, list):
212 | polygon = np.array(polygon).reshape((-1, 2))
213 | elif isinstance(polygon, dict):
214 | assert len(polygon) == 1
215 | polygon = list(polygon.values())[0]
216 | polygon = np.array(polygon).reshape((-1, 2))
217 | voronoi_polygons.append(polygon)
218 | return voronoi_polygons
219 |
220 |
221 | def calcualte_voronoi_from_coords(x, y, xmax=None, ymax=None):
222 | """Calculate voronoi polygons from a set of points
223 |
224 | Points are assumed to have coordinates in ([0, xmax], [0, ymax])
225 |
226 | Args:
227 | x (array-like): x coordinates of points
228 | y (array-like): y coordinates of points
229 | xmax (float): maximum x coordinate
230 | ymax (float): maximum y coordinate
231 |
232 | Returns:
233 | voronoi_polygons (list): list of voronoi polygons,
234 | represented by the coordinates of their exterior vertices
235 | """
236 | from geovoronoi import voronoi_regions_from_coords
237 | from shapely import geometry
238 | xmax = 1.01 * max(x) if xmax is None else xmax
239 | ymax = 1.01 * max(y) if ymax is None else ymax
240 | boundary = geometry.Polygon([[0, 0], [xmax, 0], [xmax, ymax], [0, ymax]])
241 | coords = np.stack([
242 | np.array(x).reshape((-1,)),
243 | np.array(y).reshape((-1,))], 1)
244 | region_polys, _ = voronoi_regions_from_coords(coords, boundary)
245 | voronoi_polygons = [np.array(list(region_polys[k].exterior.coords)) for k in region_polys]
246 | return voronoi_polygons
247 |
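# Note on calcualte_voronoi_from_coords: geovoronoi and shapely are imported
# lazily inside the function and are only needed when no precomputed voronoi
# file is supplied. A minimal hypothetical call (coordinates are made up):
#
#     polygons = calcualte_voronoi_from_coords([10.0, 80.0, 45.0], [20.0, 30.0, 90.0])
#     # -> list of (N_i, 2) arrays, one voronoi polygon per input point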
248 |
249 | def build_graph_from_cell_coords(cell_data, voronoi_polygons):
250 | """Construct a networkx graph based on cell coordinates
251 |
252 | Args:
253 | cell_data (pd.DataFrame): dataframe containing cell data,
254 | columns ['CELL_ID', 'X', 'Y', ...]
255 | voronoi_polygons (list): list of voronoi polygons,
256 | represented by the coordinates of their exterior vertices
257 |
258 | Returns:
259 | G (nx.Graph): full cellular graph of the region
260 | """
261 | save_polygon = True
262 |     if len(cell_data) != len(voronoi_polygons):
263 | warnings.warn("Number of cells does not match number of voronoi polygons")
264 | save_polygon = False
265 |
266 | coord_ar = np.array(cell_data[['CELL_ID', 'X', 'Y']])
267 | G = nx.Graph()
268 | node_to_cell_mapping = {}
269 | for i, row in enumerate(coord_ar):
270 | vp = voronoi_polygons[i] if save_polygon else None
271 | G.add_node(i, voronoi_polygon=vp)
272 | node_to_cell_mapping[i] = row[0]
273 |
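    # Delaunay triangulation over the cell centroids; any two cells that share
    # a simplex become connected. Each vertex also appears in its own simplices,
    # so self-loop edges (i, i) are added as well.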
274 | dln = Delaunay(coord_ar[:, 1:3])
275 | neighbors = [set() for _ in range(len(coord_ar))]
276 | for t in dln.simplices:
277 | for v in t:
278 | neighbors[v].update(t)
279 |
280 | for i, ns in enumerate(neighbors):
281 | for n in ns:
282 | G.add_edge(int(i), int(n))
283 |
284 | return G, node_to_cell_mapping
285 |
286 |
287 | def build_graph_from_voronoi_polygons(voronoi_polygons, radius_relaxation=RADIUS_RELAXATION):
288 | """Construct a networkx graph based on voronoi polygons
289 |
290 | Args:
291 | voronoi_polygons (list): list of voronoi polygons,
292 | represented by the coordinates of their exterior vertices
293 |
294 | Returns:
295 | G (nx.Graph): full cellular graph of the region
296 | """
297 | G = nx.Graph()
298 |
299 | polygon_vertices = []
300 | vertice_identities = []
301 | for i, polygon in enumerate(voronoi_polygons):
302 | G.add_node(i, voronoi_polygon=polygon)
303 | polygon_vertices.append(polygon)
304 | vertice_identities.append(np.ones((polygon.shape[0],)) * i)
305 |
306 | polygon_vertices = np.concatenate(polygon_vertices, 0)
307 | vertice_identities = np.concatenate(vertice_identities, 0).astype(int)
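    # Two polygons are connected when a vertex of one lies on (or within
    # RADIUS_RELAXATION of) the boundary of the other; testing with both a
    # positive and a negative radius tolerates small numerical offsets between
    # vertices that adjacent polygons share.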
308 | for i, polygon in enumerate(voronoi_polygons):
309 | path = mplPath.Path(polygon)
310 | points_inside = np.where(path.contains_points(polygon_vertices, radius=radius_relaxation) +
311 | path.contains_points(polygon_vertices, radius=-radius_relaxation))[0]
312 | id_inside = set(vertice_identities[points_inside])
313 | for j in id_inside:
314 | if j > i:
315 | G.add_edge(int(i), int(j))
316 | return G
317 |
318 |
319 | def build_voronoi_polygon_to_cell_mapping(G, voronoi_polygons, cell_data):
320 | """Construct 1-to-1 mapping between voronoi polygons and cells
321 |
322 | Args:
323 | G (nx.Graph): full cellular graph of the region
324 | voronoi_polygons (list): list of voronoi polygons,
325 | represented by the coordinates of their exterior vertices
326 | cell_data (pd.DataFrame): dataframe containing cellular data
327 |
328 | Returns:
329 | voronoi_polygon_to_cell_mapping (dict): 1-to-1 mapping between
330 | polygon index (also node index in `G`) and cell id
331 | """
332 | cell_coords = np.array(list(zip(cell_data['X'], cell_data['Y']))).reshape((-1, 2))
333 | # Fetch all cells within each polygon
334 | cells_in_polygon = {}
335 | for i, polygon in enumerate(voronoi_polygons):
336 | path = mplPath.Path(polygon)
337 | _cell_ids = cell_data.iloc[np.where(path.contains_points(cell_coords))[0]]
338 | _cells = list(_cell_ids[['CELL_ID', 'X', 'Y']].values)
339 | cells_in_polygon[i] = _cells
340 |
341 | def get_point_reflection(c1, c2, c3):
342 | # Reflection of point c1 across line defined by c2 & c3
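        # For the line y = m*x + c through c2 and c3, the foot of the
        # perpendicular from c1 has x-coordinate d = (x1 + (y1 - c)*m) / (1 + m^2),
        # so the mirror image is (2*d - x1, 2*(m*d + c) - y1). Vertical lines
        # (x2 == x3) are handled separately below.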
343 | x1, y1 = c1
344 | x2, y2 = c2
345 | x3, y3 = c3
346 | if x2 == x3:
347 | return (2 * x2 - x1, y1)
348 | m = (y3 - y2) / (x3 - x2)
349 | c = (x3 * y2 - x2 * y3) / (x3 - x2)
350 | d = (float(x1) + (float(y1) - c) * m) / (1 + m**2)
351 | x4 = 2 * d - x1
352 | y4 = 2 * d * m - y1 + 2 * c
353 | return (x4, y4)
354 |
355 | # Establish 1-to-1 mapping between polygons and cell ids
356 | voronoi_polygon_to_cell_mapping = {}
357 | for i, polygon in enumerate(voronoi_polygons):
358 | path = mplPath.Path(polygon)
359 | if len(cells_in_polygon[i]) == 1:
360 | # A single polygon contains a single cell centroid, assign cell id
361 | voronoi_polygon_to_cell_mapping[i] = cells_in_polygon[i][0][0]
362 |
363 | elif len(cells_in_polygon[i]) == 0:
364 | # Skipping polygons that do not contain any cell centroids
365 | continue
366 |
367 | else:
368 | # A single polygon contains multiple cell centroids
369 | polygon_edges = [(polygon[_i], polygon[_i + 1]) for _i in range(-1, len(polygon) - 1)]
370 | # Use the reflection of neighbor polygon's center cell
371 | neighbor_cells = sum([cells_in_polygon[j] for j in G.neighbors(i)], [])
372 | reflection_points = np.concatenate(
373 | [[get_point_reflection(cell[1:], edge[0], edge[1]) for edge in polygon_edges]
374 | for cell in neighbor_cells], 0)
375 | reflection_points = reflection_points[np.where(path.contains_points(reflection_points))]
376 | # Reflection should be very close to the center cell
377 | dists = [((reflection_points - c[1:])**2).sum(1).min(0) for c in cells_in_polygon[i]]
378 | if not np.min(dists) < 0.01:
379 | warnings.warn("Cannot find the exact center cell for polygon %d" % i)
380 | voronoi_polygon_to_cell_mapping[i] = cells_in_polygon[i][np.argmin(dists)][0]
381 |
382 | return voronoi_polygon_to_cell_mapping
383 |
384 |
385 | def assign_attributes(G, cell_data, node_to_cell_mapping, _assert=True):
386 | """Assign node and edge attributes to the cellular graph
387 |
388 | Args:
389 | G (nx.Graph): full cellular graph of the region
390 | cell_data (pd.DataFrame): dataframe containing cellular data
391 | node_to_cell_mapping (dict): 1-to-1 mapping between
392 | node index in `G` and cell id
393 |
394 | Returns:
395 | nx.Graph: populated cellular graph
396 | """
397 | if _assert:
398 | assert set(G.nodes) == set(node_to_cell_mapping.keys())
399 | biomarkers = sorted([c for c in cell_data.columns if c.startswith('BM-')])
400 |
401 | additional_features = sorted([
402 | c for c in cell_data.columns if c not in biomarkers + ['CELL_ID', 'X', 'Y', 'CELL_TYPE']])
403 |
404 | cell_to_node_mapping = {v: k for k, v in node_to_cell_mapping.items()}
405 | node_properties = {}
406 | for _, cell_row in cell_data.iterrows():
407 | cell_id = cell_row['CELL_ID']
408 | if cell_id not in cell_to_node_mapping:
409 | continue
410 | node_index = cell_to_node_mapping[cell_id]
411 | p = {"cell_id": cell_id}
412 | p["center_coord"] = (cell_row['X'], cell_row['Y'])
413 | if "CELL_TYPE" in cell_row:
414 | p["cell_type"] = cell_row["CELL_TYPE"]
415 | else:
416 | p["cell_type"] = "Unassigned"
417 | biomarker_expression_dict = {bm.split('BM-')[1]: cell_row[bm] for bm in biomarkers}
418 | p["biomarker_expression"] = biomarker_expression_dict
419 | for feat_name in additional_features:
420 | p[feat_name] = cell_row[feat_name]
421 | node_properties[node_index] = p
422 |
423 | G = G.subgraph(node_properties.keys())
424 | nx.set_node_attributes(G, node_properties)
425 |
426 | # Add distance, edge type (by thresholding) to edge feature
427 | edge_properties = get_edge_type(G)
428 | nx.set_edge_attributes(G, edge_properties)
429 |
430 | return G
431 |
432 |
433 | def get_edge_type(G, neighbor_edge_cutoff=NEIGHBOR_EDGE_CUTOFF):
434 | """Define neighbor vs distant edges based on distance
435 |
436 | Args:
437 | G (nx.Graph): full cellular graph of the region
438 | neighbor_edge_cutoff (float): distance cutoff for neighbor edges.
439 | By default we use 55 pixels (~20 um)
440 |
441 | Returns:
442 | dict: edge properties
443 | """
444 | edge_properties = {}
445 | for (i, j) in G.edges:
446 | ci = G.nodes[i]['center_coord']
447 | cj = G.nodes[j]['center_coord']
448 | dist = np.linalg.norm(np.array(ci) - np.array(cj), ord=2)
449 | edge_properties[(i, j)] = {
450 | "distance": dist,
451 | "edge_type": "neighbor" if dist < neighbor_edge_cutoff else "distant"
452 | }
471 | return edge_properties
472 |
473 |
474 | def merge_cell_dataframes(df1, df2):
475 | """Merge two cell dataframes on shared rows (cells)"""
476 | if set(df2['CELL_ID']) != set(df1['CELL_ID']):
477 | warnings.warn("Cell ids in the two dataframes do not match")
478 | shared_cell_ids = set(df2['CELL_ID']).intersection(set(df1['CELL_ID']))
479 | df1 = df1[df1['CELL_ID'].isin(shared_cell_ids)]
480 | df1 = df1.merge(df2, on='CELL_ID')
481 | return df1
482 |
483 |
484 | def construct_graph_for_region(region_id,
485 | cell_coords_file=None,
486 | cell_types_file=None,
487 | cell_biomarker_expression_file=None,
488 | cell_features_file=None,
489 | voronoi_file=None,
490 | graph_source='polygon',
491 | graph_output=None,
492 | voronoi_polygon_img_output=None,
493 | graph_img_output=None,
494 | common_cell_type_dict=None,
495 | common_biomarker_list=None,
496 | figsize=10):
497 | """Construct cellular graph for a region
498 |
499 | Args:
500 | region_id (str): region id
501 | cell_coords_file (str): path to csv file containing cell coordinates
502 | cell_types_file (str): path to csv file containing cell types/annotations
503 | cell_biomarker_expression_file (str): path to csv file containing cell biomarker expression
504 | cell_features_file (str): path to csv file containing additional cell features
505 | Note that features stored in this file can only be numeric and
506 | will be saved and used as is.
507 | voronoi_file (str): path to the voronoi coordinates file
508 | graph_source (str): source of edges in the graph, either "polygon" or "cell"
509 | graph_output (str): path for saving cellular graph as gpickle
510 | voronoi_polygon_img_output (str): path for saving voronoi image
511 |         graph_img_output (list): two paths for saving the dot-line graph images, one per layer
512 | figsize (int): figure size for plotting
513 |
514 | Returns:
515 |         dict: multiplex network {'layer_1': spatial cellular graph, 'layer_2': cell-type graph}
516 | """
517 | assert cell_coords_file is not None, "cell coordinates must be provided"
518 | cell_data = load_cell_coords(cell_coords_file)
519 |
520 | if voronoi_file is None:
521 | # Calculate voronoi polygons based on cell coordinates
522 | voronoi_polygons = calcualte_voronoi_from_coords(cell_data['X'], cell_data['Y'])
523 | else:
524 | # Load voronoi polygons from file
525 | voronoi_polygons = read_raw_voronoi(voronoi_file)
526 |
527 | if cell_types_file is not None:
528 | # Load cell types
529 | cell_types = load_cell_types(cell_types_file)
530 |         if common_cell_type_dict is not None:
531 | cell_types['CELL_TYPE'] = cell_types['CELL_TYPE'].map(common_cell_type_dict)
532 | cell_data = merge_cell_dataframes(cell_data, cell_types)
533 |
534 | if cell_biomarker_expression_file is not None:
535 | # Load cell biomarker expression
536 | cell_expression = load_cell_biomarker_expression(cell_biomarker_expression_file)
537 |         if common_biomarker_list is not None:
538 | columns_to_keep = ['CELL_ID'] + ['BM-' + suffix.upper() for suffix in common_biomarker_list if suffix != 'CELL_ID']
539 | cell_expression = cell_expression[columns_to_keep]
540 | cell_data = merge_cell_dataframes(cell_data, cell_expression)
541 |
542 | if cell_features_file is not None:
543 | # Load additional cell features
544 | additional_cell_features = load_cell_features(cell_features_file)
545 | cell_data = merge_cell_dataframes(cell_data, additional_cell_features)
546 |
547 | if graph_source == 'polygon':
548 | # Build initial cellular graph
549 | G = build_graph_from_voronoi_polygons(voronoi_polygons)
550 | # Construct matching between voronoi polygons and cells
551 | node_to_cell_mapping = build_voronoi_polygon_to_cell_mapping(G, voronoi_polygons, cell_data)
552 | # Prune graph to contain only voronoi polygons that have corresponding cells
553 | G = G.subgraph(node_to_cell_mapping.keys())
554 | elif graph_source == 'cell':
555 | G, node_to_cell_mapping = build_graph_from_cell_coords(cell_data, voronoi_polygons)
556 | else:
557 | raise ValueError("graph_source must be either 'polygon' or 'cell'")
558 |
559 | # Assign attributes to cellular graph
560 | G = assign_attributes(G, cell_data, node_to_cell_mapping)
561 | G.region_id = region_id
562 |
563 | # Build Multiplex Network
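    # Layer 1 is the spatial graph built above; layer 2 shares the same nodes
    # (and node attributes) but connects every pair of cells of the same cell
    # type, so each cell type forms a clique.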
564 | cell_to_node_mapping = {value: key for key, value in node_to_cell_mapping.items()}
565 | cell_type_graph = nx.Graph()
566 | cell_type_graph.add_nodes_from(G.nodes(data=True)) # This copies nodes with their features
567 |
568 | grouped = cell_data.groupby('CELL_TYPE')
569 |
570 | for cell_type, group in grouped:
571 | nodes = group['CELL_ID']
572 | nodes = [cell_to_node_mapping[cell_id] for cell_id in nodes if cell_id in cell_to_node_mapping] # Convert cell IDs to node indices
573 | edges = combinations(nodes, 2) # Create combinations of nodes
574 | cell_type_graph.add_edges_from(edges)
575 |
576 | # Assign attributes to cell type graph
577 | cell_type_graph = assign_attributes(cell_type_graph, cell_data, node_to_cell_mapping, _assert=False)
578 | cell_type_graph.region_id = region_id
579 |
580 | multiplex_network = {
581 | 'layer_1': G,
582 | 'layer_2': cell_type_graph
583 | }
584 |
585 | # Visualization of cellular graph
586 | if voronoi_polygon_img_output is not None:
587 | plt.clf()
588 | plt.figure(figsize=(figsize, figsize))
589 | plot_voronoi_polygons(G)
590 | plt.axis('scaled')
591 | plt.savefig(voronoi_polygon_img_output, dpi=300, bbox_inches='tight')
592 |
593 | if graph_img_output is not None:
594 | plt.clf()
595 | plt.figure(figsize=(figsize, figsize))
596 | plot_graph(G)
597 | plt.axis('scaled')
598 | plt.savefig(graph_img_output[0], dpi=300, bbox_inches='tight')
599 | plt.clf()
600 | plt.figure(figsize=(figsize, figsize))
601 | plot_graph(cell_type_graph, cell_type=True)
602 | plt.axis('scaled')
603 | plt.savefig(graph_img_output[1], dpi=300, bbox_inches='tight')
604 |
605 | # Save graph to file
606 | if graph_output is not None:
607 | with open(graph_output, 'wb') as f:
608 | pickle.dump(multiplex_network, f)
609 |
610 | return multiplex_network
611 |
612 |
613 | if __name__ == "__main__":
614 | raw_data_root = "data/voronoi/"
615 | nx_graph_root = "data/example_dataset/graph"
616 | fig_save_root = "data/example_dataset/fig"
617 | os.makedirs(nx_graph_root, exist_ok=True)
618 | os.makedirs(fig_save_root, exist_ok=True)
619 |
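    # Expected per-region files under raw_data_root (names follow the pattern
    # assumed below; "reg001" is a hypothetical region id):
    #   reg001.cell_data.csv     - cell coordinates (CELL_ID, X, Y)
    #   reg001.cell_types.csv    - cell type annotations
    #   reg001.expression.csv    - cell biomarker expression
    #   reg001.cell_features.csv - additional numeric cell features
    #   reg001.json              - voronoi polygon coordinates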
620 | region_ids = sorted(set(f.split('.')[0] for f in os.listdir(raw_data_root)))
621 |
622 | for region_id in region_ids:
623 | print("Processing %s" % region_id)
624 | cell_coords_file = os.path.join(raw_data_root, "%s.cell_data.csv" % region_id)
625 | cell_types_file = os.path.join(raw_data_root, "%s.cell_types.csv" % region_id)
626 | cell_biomarker_expression_file = os.path.join(raw_data_root, "%s.expression.csv" % region_id)
627 | cell_features_file = os.path.join(raw_data_root, "%s.cell_features.csv" % region_id)
628 | voronoi_file = os.path.join(raw_data_root, "%s.json" % region_id)
629 |
630 | voronoi_img_output = os.path.join(fig_save_root, "%s_voronoi.png" % region_id)
631 |         graph_img_output = [os.path.join(fig_save_root, "%s_graph_layer%d.png" % (region_id, i)) for i in (1, 2)]  # one image per layer
632 | graph_output = os.path.join(nx_graph_root, "%s.gpkl" % region_id)
633 |
634 | if not os.path.exists(graph_output):
635 |             multiplex_network = construct_graph_for_region(
636 | region_id,
637 | cell_coords_file=cell_coords_file,
638 | cell_types_file=cell_types_file,
639 | cell_biomarker_expression_file=cell_biomarker_expression_file,
640 | cell_features_file=cell_features_file,
641 | voronoi_file=voronoi_file,
642 | graph_output=graph_output,
643 | voronoi_polygon_img_output=voronoi_img_output,
644 | graph_img_output=graph_img_output,
645 | figsize=10)
646 |
--------------------------------------------------------------------------------