├── .gitignore
├── README.md
├── assets
│   └── infoalign.png
├── ckpt
│   └── pretrain.yaml
├── configures
│   ├── arguments.py
│   └── finetune.yaml
├── convert_to_dataset.py
├── dataset
│   ├── __init__.py
│   ├── context_graph.py
│   ├── create_datasets.py
│   ├── data_utils.py
│   ├── prediction_molecule.py
│   ├── pretrain_context.py
│   ├── pretrain_molecule.py
│   └── retrieval.py
├── main.py
├── models
│   ├── conv.py
│   └── gnn.py
├── raw_data
│   ├── biogenadme
│   │   └── raw
│   │       ├── ADME_public_set_3521.csv
│   │       └── assays.csv.gz
│   ├── broad6k
│   │   └── raw
│   │       ├── CP-Bray.csv.gz
│   │       ├── GE.csv.gz
│   │       ├── assays.csv.gz
│   │       └── structure.csv.gz
│   ├── chembl2k
│   │   └── raw
│   │       ├── CP-JUMP.csv.gz
│   │       ├── GE.csv.gz
│   │       ├── assays.csv.gz
│   │       └── structure.csv.gz
│   └── moltoxcast
│       └── raw
│           └── assays.csv.gz
├── requirements.txt
└── utils
    ├── __init__.py
    ├── misc.py
    └── train_funcs.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | *.log
4 | npm-debug.log*
5 | yarn-debug.log*
6 | yarn-error.log*
7 |
8 | # Virtualenv
9 | *.pyc
10 | *.pyo
11 | *.pyd
12 | .Python
13 | env/
14 | venv/
15 | .python-version
16 | pip-log.txt
17 | pip-delete-this-directory.txt
18 | htmlcov/
19 | .tox/
20 | .coverage
21 | .coverage.*
22 | .cache
23 | nosetests.xml
24 | coverage.xml
25 | *.cover
26 | .hypothesis/
27 | .pytest_cache/
28 |
29 | # Jupyter Notebook
30 | .ipynb_checkpoints
31 |
32 | # macOS specific
33 | .DS_Store
34 |
35 | # IDE specific
36 | .idea/
37 | .vscode/
38 | *.swp
39 | *.swo
40 | *~
41 |
42 |
43 | ## just in case
44 | **/node_modules
45 | **/dist
46 | **/build
47 | **/coverage
48 | **/.cache
49 | **/.env
50 | **/.env.local
51 | **/.env.development.local
52 | **/.env.test.local
53 | **/.env.production.local
54 | **/instance
55 | **/.venv
56 | **/env
57 | **/venv
58 | **/.flaskenv
59 | **/*.pyc
60 | **/*.pyo
61 | **/*.pyd
62 | **/.Python
63 | **/env/
64 | **/venv/
65 | **/.python-version
66 | **/pip-log.txt
67 | **/pip-delete-this-directory.txt
68 | **/htmlcov/
69 | **/.tox/
70 | **/.coverage
71 | **/.coverage.*
72 | **/.cache
73 | **/nosetests.xml
74 | **/coverage.xml
75 | **/*.cover
76 | **/.hypothesis/
77 | **/.pytest_cache/
78 | **/.ipynb_checkpoints
79 | **/.DS_Store
80 | **/.idea/
81 | **/.vscode/
82 | **/*.swp
83 | **/*.swo
84 | **/__pycache__/
85 |
86 | # customized
87 | **/*.pt
88 | **/*.npz
89 | *.pt
90 | *.npz
91 | **/pretrain*
92 | **/*.zip
93 | requirements copy.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Molecular Representation in a Cell
2 |
3 | **InfoAlign** learns molecular representations from bottleneck information derived from molecular structures, cell morphology, and gene expression. For more details, please refer to our [paper](https://arxiv.org/abs/2406.12056v3).
4 |
5 |
6 |
7 |
8 |
9 |
10 | ---
11 | ## Update on March 6, 2025:
12 | - Added `convert_to_dataset.py`, which can be used to convert CSV files into a dataset for fine-tuning. For example:
13 | ```bash
14 | python convert_to_dataset.py --task_name chembl2k_sub --csv_path raw_data/chembl2k/raw/assays.csv.gz --smiles_column smiles --property_columns ABCB1 ABL1
15 | ```
16 | ## Update on Oct 14, 2024:
17 | - All packages can now be installed with `pip install -r requirements.txt`!
18 | - We have **automated** the model and data download process for ML developers. The InfoAlign model can now be trained with a single command!
19 | - We have created the `infoalign` package, which can be installed via `pip install infoalign`. For more details, refer to: [https://github.com/liugangcode/infoalign-package](https://github.com/liugangcode/infoalign-package).
20 |
21 | ---
22 |
23 |
24 | ## Requirements
25 |
26 | This project was developed and tested with the following versions:
27 |
28 | - **Python**: 3.11.7
29 | - **PyTorch**: 2.2.0+cu118
30 | - **Torch Geometric**: 2.6.1
31 |
32 | All dependencies are listed in the `requirements.txt` file.
33 |
34 | ### Setup Instructions
35 |
36 | 1. **Create a Conda Environment**:
37 |    ```bash
38 |    conda create --name infoalign python=3.11.7
39 |    ```
40 |
41 | 2. **Activate the Environment**:
42 |    ```bash
43 |    conda activate infoalign
44 |    ```
45 |
46 | 3. **Install Dependencies**:
47 |    ```bash
48 |    pip install -r requirements.txt
49 |    ```
50 |
51 | ## Usage
52 |
53 | ### Fine-tuning
54 |
55 | A pretrained checkpoint is available for download from [Hugging Face](https://huggingface.co/liuganghuggingface/InfoAlign-Pretrained). Use the following commands for fine-tuning and inference; by default, the pretrained model is downloaded automatically to `ckpt/pretrain.pt`.
56 |
57 | ```bash
58 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-chembl2k
59 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-broad6k
60 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-biogenadme
61 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-moltoxcast
62 | ```
63 |
64 | Alternatively, you can download the model weights manually and place `pretrain.pt` in the `ckpt` folder together with its YAML configuration file (`pretrain.yaml`).
65 |
66 | **Note**: If you wish to access the cell morphology and gene expression features in the ChEMBL2k and Broad6K datasets for baseline evaluation, visit our [Hugging Face repository](https://huggingface.co/liuganghuggingface/InfoAlign-Pretrained) to download these features.
67 |
68 | ### Pretraining
69 |
70 | To pretrain the model from scratch, execute the following command:
71 |
72 | ```bash
73 | python main.py --model-path "ckpt/pretrain.pt" --lr 1e-4 --wdecay 1e-8 --batch-size 3072
74 | ```
75 |
76 | This will automatically download the pretraining dataset from [Hugging Face](https://huggingface.co/datasets/liuganghuggingface/InfoAlign-Data). If you prefer to download the dataset manually, place all pretraining data files in the `raw_data/pretrain/raw` folder.
77 |
78 | The pretrained model will be saved in the `ckpt` folder as `pretrain.pt`.
79 |
80 | ---
81 |
82 | ## Data source
83 |
84 | For readers interested in data collection, here are the sources:
85 |
86 | 1. **Cell Morphology Data**
87 |    - JUMP dataset: The data are from "JUMP Cell Painting dataset: morphological impact of 136,000 chemical and genetic perturbations" and can be downloaded [here](https://github.com/jump-cellpainting/datasets/blob/1c245002cbcaea9156eea56e61baa52ad8307db3/profile_index.csv). It contains cell morphology profiles for both chemical and genetic perturbations.
88 |    - Bray's dataset: "A dataset of images and morphological profiles of 30,000 small-molecule treatments using the Cell Painting assay". Download from [GigaDB](http://gigadb.org/dataset/100351). A processed version is available on [Zenodo](https://zenodo.org/records/7589312).
89 |
90 | 2. **Gene Expression Data**
91 |    - LINCS L1000 gene expression data from the paper "Drug-induced adverse events prediction with the LINCS L1000 data": [Data](https://maayanlab.net/SEP-L1000/#download).
92 |
93 | 3. **Relationships**
94 |    - Gene-gene and gene-compound relationships from Hetionet: [Data](https://github.com/hetio/hetionet).
95 |
96 | ## Citation
97 |
98 | If you find this repository useful, please cite our paper:
99 |
100 | ```
101 | @article{liu2024learning,
102 | title={Learning Molecular Representation in a Cell},
103 | author={Liu, Gang and Seal, Srijit and Arevalo, John and Liang, Zhenwen and Carpenter, Anne E and Jiang, Meng and Singh, Shantanu},
104 | journal={arXiv preprint arXiv:2406.12056},
105 | year={2024}
106 | }
107 | ```
108 |
109 |
--------------------------------------------------------------------------------
/assets/infoalign.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/assets/infoalign.png
--------------------------------------------------------------------------------
/ckpt/pretrain.yaml:
--------------------------------------------------------------------------------
1 | emb_dim: 300
2 | model: gin-virtual
3 | norm_layer: batch_norm
4 | num_layer: 5
5 | prior: 1.0e-09
6 | readout: sum
7 | threshold: 0.8
8 | walk_length: 4
9 |
--------------------------------------------------------------------------------
/configures/arguments.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 |
4 | model_hyperparams = [
5 |     "emb_dim",
6 |     "model",
7 |     "num_layer",
8 |     "readout",
9 |     "norm_layer",
10 |     "threshold",
11 |     "walk_length",
12 |     "prior",
13 | ]
14 |
15 | def load_arguments_from_yaml(filename, model_only=False):
16 |     with open(filename, "r") as file:
17 |         config = yaml.safe_load(file)
18 |     if model_only:
19 |         config = {k: v for k, v in config.items() if k in model_hyperparams}
20 |     # Otherwise the full config is returned as loaded (re-reading the
21 |     # already consumed file handle would yield None).
22 |     return config
23 |
24 |
25 | def save_arguments_to_yaml(args, filename, model_only=False):
26 |     if model_only:
27 |         args = {k: v for k, v in vars(args).items() if k in model_hyperparams}
28 |     else:
29 |         args = vars(args)
30 |
31 |     with open(filename, "w") as f:
32 |         yaml.dump(args, f)
33 |
34 |
35 | def get_args():
36 |     parser = argparse.ArgumentParser(
37 |         description="Learning molecular representation in a cell"
38 |     )
39 |     parser.add_argument(
40 |         "--gpu-id", type=int, default=0, help="which gpu to use if any (default: 0)"
41 |     )
42 |     parser.add_argument(
43 |         "--num-workers", type=int, default=0, help="number of workers for data loader"
44 |     )
45 |     parser.add_argument(
46 |         "--no-print", action="store_true", default=False, help="don't use progress bar"
47 |     )
48 |
49 |     parser.add_argument("--dataset", default="pretrain", type=str, help="dataset name, e.g. pretrain or finetune-chembl2k (default: pretrain)")
50 |
51 |     # model
52 |     parser.add_argument(
53 |         "--model",
54 |         type=str,
55 |         default="gin-virtual",
56 |         help="GNN type: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)",
57 |     )
58 |     parser.add_argument(
59 |         "--readout", type=str, default="sum", help="graph readout (default: sum)"
60 |     )
61 |     parser.add_argument(
62 |         "--norm-layer",
63 |         type=str,
64 |         default="batch_norm",
65 |         help="normalization layer used in the GNN (default: batch_norm)",
66 |     )
67 |     parser.add_argument(
68 |         "--drop-ratio", type=float, default=0.5, help="dropout ratio (default: 0.5)"
69 |     )
70 |     parser.add_argument(
71 |         "--num-layer",
72 |         type=int,
73 |         default=5,
74 |         help="number of GNN message passing layers (default: 5)",
75 |     )
76 |     parser.add_argument(
77 |         "--emb-dim",
78 |         type=int,
79 |         default=300,
80 |         help="dimensionality of hidden units in GNNs (default: 300)",
81 |     )
82 |     # training
83 |     ## pretraining
84 |     parser.add_argument(
85 |         "--walk-length",
86 |         type=int,
87 |         default=4,
88 |         help="pretraining context length (default: 4)",
89 |     )
90 |     parser.add_argument(
91 |         "--threshold",
92 |         type=float,
93 |         default=0.8,
94 |         help="minimum similarity threshold for context graph (default: 0.8)",
95 |     )
96 |     # prior
97 |     parser.add_argument(
98 |         "--prior",
99 |         type=float,
100 |         default=1e-9,
101 |         help="loss weight of the prior term (default: 1e-9)",
102 |     )
103 |
104 |     ## other
105 |     parser.add_argument(
106 |         "--batch-size",
107 |         type=int,
108 |         default=5120,
109 |         help="input batch size for training (default: 5120)",
110 |     )
111 |     parser.add_argument(
112 |         "--lr",
113 |         "--learning-rate",
114 |         type=float,
115 |         default=1e-3,
116 |         help="Learning rate (default: 1e-3)",
117 |     )
118 |     parser.add_argument("--wdecay", default=1e-5, type=float, help="weight decay (default: 1e-5)")
119 |     parser.add_argument(
120 |         "--epochs", type=int, default=300, help="number of epochs to train (default: 300)"
121 |     )
122 |     parser.add_argument(
123 |         "--initw-name",
124 |         type=str,
125 |         default="default",
126 |         help="method to initialize the model parameters",
127 |     )
128 |     parser.add_argument(
129 |         "--model-path",
130 |         type=str,
131 |         default="ckpt/pretrain.pt",
132 |         help="path to the pretrained model",
133 |     )
134 |     parser.add_argument(
135 |         "--patience", type=int, default=50, help="patience for early stopping (default: 50)"
136 |     )
137 |
138 |     args = parser.parse_args()
139 |     print("no print", args.no_print)
140 |
141 |     ## n_steps for solver
142 |     args.n_steps = 1
143 |     return args
144 |
--------------------------------------------------------------------------------
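`load_arguments_from_yaml(..., model_only=True)` returns only the architecture keys listed in `model_hyperparams`. One plausible use (an assumption, since `main.py` is not shown here) is to override the command-line defaults with the values stored in `ckpt/pretrain.yaml`, so that a fine-tuned network matches the checkpoint's architecture. The sketch below is dependency-free: a plain dict stands in for the parsed YAML, and `merge_model_config` is a hypothetical helper, not part of the repository.

```python
import argparse

# Same keys as `model_hyperparams` in configures/arguments.py
MODEL_HYPERPARAMS = ["emb_dim", "model", "num_layer", "readout",
                     "norm_layer", "threshold", "walk_length", "prior"]


def merge_model_config(args: argparse.Namespace, config: dict) -> argparse.Namespace:
    """Hypothetical helper: copy architecture keys from a checkpoint config onto args."""
    merged = argparse.Namespace(**vars(args))  # copy, so the caller's namespace is untouched
    for key, value in config.items():
        if key in MODEL_HYPERPARAMS:  # training keys such as lr are left alone
            setattr(merged, key, value)
    return merged


parser = argparse.ArgumentParser()
parser.add_argument("--emb-dim", type=int, default=128)   # dest becomes "emb_dim"
parser.add_argument("--num-layer", type=int, default=3)   # dest becomes "num_layer"
parser.add_argument("--lr", type=float, default=1e-3)
args = parser.parse_args([])  # empty list: use defaults instead of sys.argv

# emb_dim and num_layer as in ckpt/pretrain.yaml; "lr" is ignored (not a model key)
checkpoint_config = {"emb_dim": 300, "num_layer": 5, "lr": 0.5}
merged = merge_model_config(args, checkpoint_config)
print(merged.emb_dim, merged.num_layer, merged.lr)  # 300 5 0.001
```

Restricting the merge to architecture keys lets learning rate, batch size, and similar training settings still come from the command line.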
/configures/finetune.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 5120
2 | dataset: pretrain
3 | drop_ratio: 0.5
4 | emb_dim: 300
5 | epochs: 300
6 | gpu_id: 0
7 | initw_name: default
8 | lr: 0.01
9 | model: gin-virtual
10 | model_path: ckpt/pretrain.pt
11 | n_steps: 1
12 | no_print: false
13 | norm_layer: batch_norm
14 | num_layer: 5
15 | num_workers: 0
16 | patience: 50
17 | readout: sum
18 | trails: 2
19 | walk_length: 3
20 | wdecay: 1.0e-05
21 |
--------------------------------------------------------------------------------
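The `threshold` hyperparameter (0.8 in `ckpt/pretrain.yaml`, described in `configures/arguments.py` as the minimum similarity threshold for the context graph) goes together with similarity helpers such as `l1_similarity` in `dataset/data_utils.py`. How the pretraining dataset actually applies the threshold is not shown in this dump, so the sketch assumes that a pair of profiles becomes an edge when its similarity is at least the threshold. It re-implements the `l1_similarity` recipe (row z-scoring, L1 distance, `1 / (1 + d)`, zero diagonal) with plain lists; `l1_similarity_py` and `threshold_edges` are hypothetical names.

```python
from itertools import combinations
from statistics import fmean, pstdev
from typing import List, Sequence, Tuple


def standardize_rows(matrix: Sequence[Sequence[float]]) -> List[List[float]]:
    """Z-score each row with the population std (numpy's default ddof=0)."""
    out = []
    for row in matrix:
        mu, sd = fmean(row), pstdev(row)
        out.append([(v - mu) / (sd + 1e-10) for v in row])
    return out


def l1_similarity_py(matrix: Sequence[Sequence[float]]) -> List[List[float]]:
    """sim[i][j] = 1 / (1 + L1 distance of standardized rows); diagonal stays 0."""
    z = standardize_rows(matrix)
    n = len(z)
    sim = [[0.0] * n for _ in range(n)]
    for i, j in combinations(range(n), 2):
        d = sum(abs(a - b) for a, b in zip(z[i], z[j]))
        sim[i][j] = sim[j][i] = 1.0 / (1.0 + d)
    return sim


def threshold_edges(sim: List[List[float]], threshold: float) -> List[Tuple[int, int]]:
    """Hypothetical rule: link each unordered pair whose similarity reaches the threshold."""
    return [(i, j) for i, j in combinations(range(len(sim)), 2) if sim[i][j] >= threshold]


# Toy profiles: rows 0 and 1 differ only in scale, row 2 is row 0 reversed
profiles = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 2.0, 1.0]]
sim = l1_similarity_py(profiles)
print(threshold_edges(sim, threshold=0.8))  # [(0, 1)]
```

Because each row is standardized first, profiles that differ only in overall scale (rows 0 and 1) come out almost perfectly similar, while the reversed profile falls far below 0.8.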
/convert_to_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import pandas as pd
5 | from pathlib import Path
6 |
7 |
8 | class CSVToDataset:
9 |     """
10 |     A class to convert CSV files containing SMILES and property data to a standardized dataset format.
11 |
12 |     The class extracts the SMILES column and specified property columns from a CSV file,
13 |     renames the SMILES column to `smiles`, and saves the result with metadata under raw_data/<task_name>/raw.
14 |     """
15 |
16 |     def __init__(self, task_name):
17 |         """
18 |         Initialize the converter with a task name.
19 |
20 |         Args:
21 |             task_name (str): Name of the task/dataset
22 |         """
23 |         self.task_name = task_name
24 |         self.output_dir = Path(f"raw_data/{task_name}/raw")
25 |         self.output_dir.mkdir(parents=True, exist_ok=True)
26 |
27 |     def convert(self, csv_path, smiles_column, property_columns, eval_metric):
28 |         """
29 |         Convert a CSV file to the standardized dataset format.
30 |
31 |         Args:
32 |             csv_path (str): Path to the input CSV file
33 |             smiles_column (str): Name of the column containing SMILES strings
34 |             property_columns (list): List of column names for properties to extract
35 |             eval_metric (str): Evaluation metric stored in meta.json ("roc_auc" or "avg_mae")
36 |         Returns:
37 |             bool: True if conversion was successful, False otherwise
38 |         """
39 |         if not os.path.exists(csv_path):
40 |             print(f"Error: File {csv_path} does not exist.")
41 |             return False
42 |         df = pd.read_csv(csv_path)
43 |
44 |         # Validate columns exist in the dataframe
45 |         missing_columns = []
46 |         if smiles_column not in df.columns:
47 |             missing_columns.append(smiles_column)
48 |
49 |         for col in property_columns:
50 |             if col not in df.columns:
51 |                 missing_columns.append(col)
52 |
53 |         if missing_columns:
54 |             print(f"Error: The following columns are missing from the CSV: {', '.join(missing_columns)}")
55 |             return False
56 |
57 |         # Extract and rename columns
58 |         selected_columns = [smiles_column] + property_columns
59 |         result_df = df[selected_columns].copy()
60 |         result_df.rename(columns={smiles_column: 'smiles'}, inplace=True)
61 |
62 |         # Save the processed data
63 |         output_path = self.output_dir / 'assays.csv.gz'
64 |         result_df.to_csv(output_path, index=False, compression='gzip')
65 |
66 |         # Create and save metadata
67 |         self.num_tasks = len(property_columns)
68 |         metadata = {
69 |             'num_tasks': self.num_tasks,
70 |             'start_column': 1,
71 |             'eval_metric': eval_metric
72 |         }
73 |
74 |         with open(self.output_dir / 'meta.json', 'w') as f:
75 |             json.dump(metadata, f, indent=2)
76 |
77 |         print(f"Successfully processed {csv_path}")
78 |         print(f"Saved data to {output_path}")
79 |         print(f"Number of compounds: {len(result_df)}")
80 |         print(f"Number of tasks: {self.num_tasks}")
81 |
82 |         return True
83 |
84 |
85 | if __name__ == "__main__":
86 |     import argparse
87 |
88 |     parser = argparse.ArgumentParser(description='Convert CSV files to standardized dataset format')
89 |     parser.add_argument('--task_name', type=str, required=True, help='Name of the task/dataset')
90 |     parser.add_argument('--csv_path', type=str, required=True, help='Path to the input CSV file')
91 |     parser.add_argument('--smiles_column', type=str, required=True, help='Name of the column containing SMILES strings')
92 |     parser.add_argument('--property_columns', type=str, nargs='+', required=True,
93 |                         help='Names of columns containing property values to extract')
94 |     parser.add_argument('--eval_metric', type=str, default='roc_auc', help='Evaluation metric: roc_auc or avg_mae (default: roc_auc)')
95 |
96 |     args = parser.parse_args()
97 |     assert args.eval_metric in ["roc_auc", "avg_mae"]
98 |
99 |     converter = CSVToDataset(args.task_name)
100 |     success = converter.convert(
101 |         csv_path=args.csv_path,
102 |         smiles_column=args.smiles_column,
103 |         property_columns=args.property_columns,
104 |         eval_metric=args.eval_metric
105 |     )
106 |
107 |     sys.exit(0 if success else 1)
108 |
--------------------------------------------------------------------------------
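`CSVToDataset.convert` writes two files to `raw_data/<task_name>/raw`: `assays.csv.gz` (the `smiles` column first, then one column per property) and `meta.json` with `num_tasks`, `start_column: 1`, and `eval_metric`. The sketch below mirrors that layout with only the standard library (`csv`, `gzip`, `json`) in a temporary directory and reads it back, using `start_column` to locate the label columns. `write_dataset` and `read_labels` are hypothetical helpers; the real script uses pandas, and how the training code consumes `meta.json` is not shown here.

```python
import csv
import gzip
import json
import tempfile
from pathlib import Path


def write_dataset(out_dir: Path, rows: list, smiles_column: str,
                  property_columns: list, eval_metric: str = "roc_auc") -> None:
    """Hypothetical stdlib mirror of CSVToDataset.convert: keep SMILES + properties."""
    out_dir.mkdir(parents=True, exist_ok=True)
    with gzip.open(out_dir / "assays.csv.gz", "wt", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"] + property_columns)  # SMILES column renamed to "smiles"
        for row in rows:
            writer.writerow([row[smiles_column]] + [row[c] for c in property_columns])
    meta = {"num_tasks": len(property_columns), "start_column": 1, "eval_metric": eval_metric}
    (out_dir / "meta.json").write_text(json.dumps(meta, indent=2))


def read_labels(out_dir: Path):
    """Read back SMILES and label columns, locating labels via start_column."""
    meta = json.loads((out_dir / "meta.json").read_text())
    start = meta["start_column"]
    with gzip.open(out_dir / "assays.csv.gz", "rt", newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        records = list(reader)
    smiles = [r[0] for r in records]
    labels = [r[start:start + meta["num_tasks"]] for r in records]
    return header, smiles, labels


# Toy input: an extra "note" column is dropped; an empty string marks a missing label
rows = [
    {"compound": "CCO", "ABCB1": "1", "ABL1": "0", "note": "x"},
    {"compound": "c1ccccc1", "ABCB1": "0", "ABL1": "", "note": "y"},
]
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "raw_data" / "demo_task" / "raw"
    write_dataset(out, rows, smiles_column="compound", property_columns=["ABCB1", "ABL1"])
    header, smiles, labels = read_labels(out)

print(header)  # ['smiles', 'ABCB1', 'ABL1']
print(labels)  # [['1', '0'], ['0', '']]
```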
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/context_graph.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 | import networkx as nx
7 |
8 | from collections import defaultdict
9 | from typing import Any, Dict, List, Literal, Optional, Union
10 |
11 | import torch
12 | from torch import Tensor
13 |
14 | import torch_geometric
15 | from torch_geometric.data import Data
16 |
17 | def from_networkx(
18 |     G: Any,
19 |     group_node_attrs: Optional[Union[List[str], Literal['all']]] = None,
20 |     group_edge_attrs: Optional[Union[List[str], Literal['all']]] = None,
21 | ) -> 'torch_geometric.data.Data':
22 |     r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
23 |     :class:`torch_geometric.data.Data` instance.
24 |
25 |     Args:
26 |         G (networkx.Graph or networkx.DiGraph): A networkx graph.
27 |         group_node_attrs (List[str] or "all", optional): The node attributes to
28 |             be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
29 |         group_edge_attrs (List[str] or "all", optional): The edge attributes to
30 |             be concatenated and added to :obj:`data.edge_attr`.
31 |             (default: :obj:`None`)
32 |
33 |     .. note::
34 |
35 |         All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
36 |         be numeric.
37 |
38 |     Examples:
39 |         >>> edge_index = torch.tensor([
40 |         ...     [0, 1, 1, 2, 2, 3],
41 |         ...     [1, 0, 2, 1, 3, 2],
42 |         ... ])
43 |         >>> data = Data(edge_index=edge_index, num_nodes=4)
44 |         >>> g = to_networkx(data)
45 |         >>> # A `Data` object is returned
46 |         >>> from_networkx(g)
47 |         Data(edge_index=[2, 6], num_nodes=4)
48 |     """
49 |
50 |     G = G.to_directed() if not nx.is_directed(G) else G
51 |
52 |     mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))
53 |     edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)
54 |     for i, (src, dst) in enumerate(G.edges()):
55 |         edge_index[0, i] = mapping[src]
56 |         edge_index[1, i] = mapping[dst]
57 |
58 |     data_dict: Dict[str, Any] = defaultdict(list)
59 |     data_dict['edge_index'] = edge_index
60 |
61 |     node_attrs: List[str] = []
62 |     if G.number_of_nodes() > 0:
63 |         node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())
64 |
65 |     edge_attrs: List[str] = []
66 |     if G.number_of_edges() > 0:
67 |         edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())
68 |
69 |     if group_node_attrs is not None and not isinstance(group_node_attrs, list):
70 |         group_node_attrs = node_attrs
71 |
72 |     if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
73 |         group_edge_attrs = edge_attrs
74 |
75 |     for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
76 |         if set(feat_dict.keys()) != set(node_attrs):
77 |             raise ValueError('Not all nodes contain the same attributes')
78 |         for key, value in feat_dict.items():
79 |             data_dict[str(key)].append(value)
80 |
81 |     for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
82 |         if set(feat_dict.keys()) != set(edge_attrs):
83 |             raise ValueError('Not all edges contain the same attributes')
84 |         for key, value in feat_dict.items():
85 |             key = f'edge_{key}' if key in node_attrs else key
86 |             data_dict[str(key)].append(value)
87 |
88 |     for key, value in G.graph.items():
89 |         if key == 'node_default' or key == 'edge_default':
90 |             continue  # Do not load default attributes.
91 |         key = f'graph_{key}' if key in node_attrs else key
92 |         data_dict[str(key)] = value
93 |
94 |     for key, value in data_dict.items():
95 |         if isinstance(value, (tuple, list)) and isinstance(value[0], Tensor):
96 |             data_dict[key] = torch.stack(value, dim=0).to(torch.float32)  # torch.stack takes no dtype argument
97 |         elif isinstance(value, (tuple, list)) and isinstance(value[0], np.ndarray):
98 |             data_dict[key] = torch.tensor(np.stack(value), dtype=torch.float32)
99 |         else:
100 |             try:
101 |                 data_dict[key] = torch.as_tensor(np.array(value))
102 |             except Exception:
103 |                 pass
104 |
105 |     data = Data.from_dict(data_dict)
106 |
107 |     if group_node_attrs is not None:
108 |         xs = []
109 |         for key in group_node_attrs:
110 |             x = data[key]
111 |             x = x.view(-1, 1) if x.dim() <= 1 else x
112 |             xs.append(x)
113 |             del data[key]
114 |         data.x = torch.cat(xs, dim=-1)
115 |
116 |     if group_edge_attrs is not None:
117 |         xs = []
118 |         for key in group_edge_attrs:
119 |             key = f'edge_{key}' if key in node_attrs else key
120 |             x = data[key]
121 |             x = x.view(-1, 1) if x.dim() <= 1 else x
122 |             xs.append(x)
123 |             del data[key]
124 |         data.edge_attr = torch.cat(xs, dim=-1)
125 |
126 |     if data.x is None and data.pos is None:
127 |         data.num_nodes = G.number_of_nodes()
128 |
129 |     return data
130 |
131 | if __name__ == "__main__":
132 |     pass
--------------------------------------------------------------------------------
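The first steps of `from_networkx` do two things: an undirected graph is turned into a directed one (each undirected edge becomes two arcs), and arbitrary node keys are mapped to contiguous indices `0..n-1` in node-iteration order before the `2 x E` `edge_index` is filled. The sketch below reproduces only that indexing logic with plain lists, without networkx or torch; `to_edge_index` and the string node keys are hypothetical.

```python
from typing import Hashable, List, Sequence, Tuple


def to_edge_index(nodes: Sequence[Hashable],
                  edges: Sequence[Tuple[Hashable, Hashable]],
                  directed: bool = False) -> List[List[int]]:
    """Build a 2 x E COO edge list with nodes relabeled to 0..n-1.

    Mirrors two steps of from_networkx: G.to_directed() for undirected
    graphs, then mapping = dict(zip(G.nodes(), range(n))).
    """
    mapping = {node: idx for idx, node in enumerate(nodes)}
    arcs = []
    for src, dst in edges:
        arcs.append((src, dst))
        if not directed:
            arcs.append((dst, src))  # one arc per direction for undirected edges
    return [[mapping[s] for s, _ in arcs], [mapping[d] for _, d in arcs]]


# Hypothetical context-graph keys: molecules and genes identified by strings
nodes = ["mol_A", "gene_X", "mol_B"]
edges = [("mol_A", "gene_X"), ("gene_X", "mol_B")]
print(to_edge_index(nodes, edges))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

networkx may enumerate the arcs of `G.to_directed()` in a different column order; the set of arcs, and therefore the graph, is the same.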
/dataset/create_datasets.py:
--------------------------------------------------------------------------------
1 | def get_data(args, load_path, transform="pyg"):
2 |     assert transform in [
3 |         None,
4 |         "fingerprint",
5 |         "smiles",
6 |         "pyg",
7 |         "morphology",
8 |         "expression",
9 |     ]
10 |     pretrained = args.dataset == "pretrain"
11 |
12 |     if pretrained:
13 |         assert transform == "pyg"
14 |         from .pretrain_molecule import PretrainMoleculeDataset
15 |         from .pretrain_context import PretrainContextDataset
16 |
17 |         molecule = PretrainMoleculeDataset(root=load_path)
18 |         context = PretrainContextDataset(root=load_path, pre_transform=args.threshold)
19 |         return molecule, context
20 |
21 |     if args.dataset.startswith("finetune"):
22 |         data_name = args.dataset.split("-")[1]
23 |     else:
24 |         data_name = args.dataset
25 |
26 |     # if data_name in ["broad6k", "chembl2k", "biogenadme", "moltoxcast"]:
27 |     assert transform in ["fingerprint", "smiles", "morphology", "expression", "pyg"]
28 |     if transform == "pyg":
29 |         from .prediction_molecule import PygPredictionMoleculeDataset
30 |
31 |         return PygPredictionMoleculeDataset(name=data_name, root=load_path)
32 |     else:
33 |         from .prediction_molecule import PredictionMoleculeDataset
34 |
35 |         return PredictionMoleculeDataset(
36 |             name=data_name, root=load_path, transform=transform
37 |         )
39 |
--------------------------------------------------------------------------------
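`get_data` routes on `args.dataset`: the literal `pretrain` returns a (molecule, context) pair, while a value such as `finetune-chembl2k` is reduced to the data name `chembl2k` before a prediction dataset is built; any other value is used as the data name directly. The sketch below isolates that name-resolution step; `resolve_dataset` and the mode labels are hypothetical. It splits on the first hyphen only, which keeps a task name containing a hyphen whole, whereas `split("-")[1]` would keep only the part up to the next hyphen. Presumably, the `chembl2k_sub` task from the README's conversion example would be selected as `finetune-chembl2k_sub`.

```python
from typing import Tuple


def resolve_dataset(dataset_arg: str) -> Tuple[str, str]:
    """Hypothetical helper: return (mode, data_name) for a --dataset value."""
    if dataset_arg == "pretrain":
        return "pretrain", "pretrain"
    if dataset_arg.startswith("finetune"):
        # split on the first hyphen only: "finetune-my-task" -> "my-task"
        return "prediction", dataset_arg.split("-", 1)[1]
    return "prediction", dataset_arg  # plain names are used as-is


for value in ["pretrain", "finetune-chembl2k", "finetune-chembl2k_sub", "broad6k"]:
    print(value, "->", resolve_dataset(value))
```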
/dataset/data_utils.py:
--------------------------------------------------------------------------------
1 | from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
2 |
3 | import torch
4 | import numpy as np
5 |
6 | import copy
7 | import pathlib
8 | import pandas as pd
9 | from tqdm import tqdm
10 | from torch_geometric.data import Data
11 |
12 | from rdkit import Chem
13 | from rdkit.Chem import AllChem
14 |
15 | def smiles2graph(smiles_string):
16 |     """
17 |     Converts a SMILES string to a graph dictionary
18 |     :input: SMILES string (str)
19 |     :return: dict with edge_index, edge_feat, node_feat and num_nodes, or None if the SMILES cannot be processed
20 |     """
21 |     try:
22 |         mol = Chem.MolFromSmiles(smiles_string)
23 |         # atoms
24 |         atom_features_list = []
25 |         # atom_label = []
26 |         for atom in mol.GetAtoms():
27 |             atom_features_list.append(atom_to_feature_vector(atom))
28 |             # atom_label.append(atom.GetSymbol())
29 |
30 |         x = np.array(atom_features_list, dtype=np.int64)
31 |         # atom_label = np.array(atom_label, dtype=np.str)
32 |
33 |         # bonds
34 |         num_bond_features = 3  # bond type, bond stereo, is_conjugated
35 |         if len(mol.GetBonds()) > 0:  # mol has bonds
36 |             edges_list = []
37 |             edge_features_list = []
38 |             for bond in mol.GetBonds():
39 |                 i = bond.GetBeginAtomIdx()
40 |                 j = bond.GetEndAtomIdx()
41 |
42 |                 edge_feature = bond_to_feature_vector(bond)
43 |                 # add edges in both directions
44 |                 edges_list.append((i, j))
45 |                 edge_features_list.append(edge_feature)
46 |                 edges_list.append((j, i))
47 |                 edge_features_list.append(edge_feature)
48 |
49 |             # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
50 |             edge_index = np.array(edges_list, dtype=np.int64).T
51 |
52 |             # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
53 |             edge_attr = np.array(edge_features_list, dtype=np.int64)
54 |
55 |         else:  # mol has no bonds
56 |             edge_index = np.empty((2, 0), dtype=np.int64)
57 |             edge_attr = np.empty((0, num_bond_features), dtype=np.int64)
58 |
59 |         graph = dict()
60 |         graph["edge_index"] = edge_index
61 |         graph["edge_feat"] = edge_attr
62 |         graph["node_feat"] = x
63 |         graph["num_nodes"] = len(x)
64 |
65 |         return graph
66 |
67 |     except Exception:  # e.g. MolFromSmiles returned None for an invalid SMILES
68 |         return None
69 |
70 |
71 | def read_graph_list(mol_df, keep_id=False):
72 |
73 |     mol_list = mol_df["smiles"].tolist()
74 |     ids_list = mol_df["mol_id"].tolist()
75 |
76 |     graph_list = []
77 |     total_length = len(mol_list)
78 |     with tqdm(total=total_length, desc="Processing molecules") as pbar:
79 |         for index, smiles_str in enumerate(mol_list):
80 |             graph_dict = smiles2graph(smiles_str)
81 |             if keep_id:
82 |                 graph_dict["type"] = ids_list[index]
83 |             graph_list.append(graph_dict)
84 |             pbar.update(1)
85 |
86 |     pyg_graph_list = []
87 |     print("Converting graphs into PyG objects...")
88 |     for graph in graph_list:
89 |         g = Data()
90 |         g.num_nodes = graph["num_nodes"]  # PyG 2.x attribute (was __num_nodes__ in PyG 1.x)
91 |         g.edge_index = torch.from_numpy(graph["edge_index"])
92 |         del graph["num_nodes"]
93 |         del graph["edge_index"]
94 |
95 |         if graph["edge_feat"] is not None:
96 |             g.edge_attr = torch.from_numpy(graph["edge_feat"])
97 |             del graph["edge_feat"]
98 |
99 |         if graph["node_feat"] is not None:
100 |             g.x = torch.from_numpy(graph["node_feat"])
101 |             del graph["node_feat"]
102 |
103 |         graph_type = graph.pop("type", None)  # "type" only exists when keep_id=True
104 |         if graph_type is not None:
105 |             g.type = graph_type
106 |
107 |         addition_prop = copy.deepcopy(graph)
108 |         for key in addition_prop.keys():
109 |             g[key] = torch.tensor(graph[key])
110 |             del graph[key]
111 |
112 |         pyg_graph_list.append(g)
113 |
114 |     return pyg_graph_list
115 |
116 |
117 | ## utils to create prediction context (modality) graph ##
118 |
119 | from rdkit.Chem.Scaffolds import MurckoScaffold
120 | from rdkit.Chem import DataStructs
121 | from collections import defaultdict
122 | from joblib import Parallel, delayed
123 |
124 | from sklearn.decomposition import PCA
125 | from sklearn.cluster import KMeans
126 | from scipy.spatial import distance
127 | from scipy.sparse import csr_matrix
128 | from sklearn.metrics.pairwise import cosine_similarity
129 |
130 | import networkx as nx
131 | from copy import deepcopy
132 |
133 | def get_scaffold(mol):
134 |     """Extracts the Murcko Scaffold from a molecule."""
135 |     scaffold = MurckoScaffold.GetScaffoldForMol(mol)
136 |     return Chem.MolToSmiles(scaffold)
137 |
138 |
139 | def parallel_scaffold_computation(molecule, molecule_id):
140 |     """Computes scaffold for a single molecule in parallel."""
141 |     scaffold = get_scaffold(molecule)
142 |     return scaffold, molecule, molecule_id
143 |
144 |
145 | def cluster_molecules_by_scaffold(
146 |     molecules, all_data_id, n_jobs=-1, remove_single=True, flatten_id=True
147 | ):
148 |     """Clusters molecules based on their scaffolds using parallel processing."""
149 |     # Ensure molecules and IDs are paired correctly
150 |     paired_results = Parallel(n_jobs=n_jobs)(
151 |         delayed(parallel_scaffold_computation)(mol, molecule_id)
152 |         for mol, molecule_id in zip(molecules, all_data_id)
153 |     )
154 |
155 |     # Initialize dictionaries for batches and IDs
156 |     batch = defaultdict(list)
157 |     batched_data_id = defaultdict(list)
158 |
159 |     # Process results to fill batch and batched_data_id dictionaries
160 |     for scaffold, mol, molecule_id in paired_results:
161 |         batch[scaffold].append(mol)
162 |         batched_data_id[scaffold].append(molecule_id)
163 |
164 |     # Optionally remove clusters with only one molecule
165 |     if remove_single:
166 |         batch = {scaffold: mols for scaffold, mols in batch.items() if len(mols) > 1}
167 |         batched_data_id = {
168 |             scaffold: ids for scaffold, ids in batched_data_id.items() if len(ids) > 1
169 |         }
170 |
171 |     # Convert dictionaries to lists for output
172 |     scaffolds = list(batch.keys())
173 |     batch = list(batch.values())
174 |     batched_data_id = list(batched_data_id.values())
175 |     if flatten_id:
176 |         batched_data_id = [idd for cluster_ids in batched_data_id for idd in cluster_ids]
177 |         batched_data_id = np.array(batched_data_id)
178 |
179 |     return scaffolds, batch, batched_data_id
180 |
181 |
182 | def calculate_mol_similarity(fingerprint1, fingerprint2):
183 |     """Wrapper function to calculate Tanimoto similarity between two fingerprints."""
184 |     return DataStructs.TanimotoSimilarity(fingerprint1, fingerprint2)
185 |
186 |
187 | def pairwise_mol_similarity(mol_list, n_jobs=1):
188 |     """Computes the symmetric Tanimoto similarity matrix (zero diagonal) of a molecule cluster from Morgan fingerprints (radius 2, 1024 bits), optionally on multiple CPUs."""
189 |     fingerprints = [
190 |         AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in mol_list
191 |     ]
192 |     n = len(fingerprints)
193 |     similarity_matrix = np.zeros((n, n))
194 |
195 |     # Define a task for each pair of molecules to calculate similarity
196 |     def compute_similarity(i, j):
197 |         if i < j:  # Avoid redundant calculations and diagonal
198 |             return i, j, calculate_mol_similarity(fingerprints[i], fingerprints[j])
199 |         else:
200 |             return i, j, 0  # No calculation needed, fill with zeros
201 |
202 |     # Use Parallel and delayed to compute similarities in parallel
203 |     results = Parallel(n_jobs=n_jobs)(
204 |         delayed(compute_similarity)(i, j) for i in range(n) for j in range(i, n)
205 |     )
206 |
207 |     # Fill the similarity matrix with results
208 |     for i, j, sim in results:
209 |         similarity_matrix[i, j] = similarity_matrix[j, i] = sim
210 |
211 |     return similarity_matrix
212 |
213 |
214 | def perform_pca_and_kmeans(features, all_data_id, n_components=2, n_clusters=100, return_pca_feature=False):
215 |     features = np.nan_to_num(features)
216 |     pca = PCA(n_components=n_components)
217 |     pca_features = pca.fit_transform(features)
218 |     # print('explained_variance_ratio_ in perform_pca_and_kmeans', np.cumsum(pca.explained_variance_ratio_))
219 |
220 |     kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pca_features)
221 |
222 |     batch = []
223 |     batched_data_id = []
224 |     for i in range(n_clusters):
225 |         if return_pca_feature:
226 |             batch.append(pca_features[kmeans.labels_ == i])
227 |         else:
228 |             batch.append(features[kmeans.labels_ == i])
229 |         batched_data_id.append(all_data_id[kmeans.labels_ == i])
230 |
231 |     batched_data_id = [idd for cluster_ids in batched_data_id for idd in cluster_ids]
232 |     batched_data_id = np.array(batched_data_id)
233 |
234 |     return batch, batched_data_id
235 |
236 |
237 | # calculate the similarity between the data points in the same batch and finally merge the similarity matrix
238 | def l1_similarity(matrix):
239 | matrix = np.nan_to_num(matrix)
240 | matrix = (matrix - matrix.mean(axis=1)[:, None]) / (matrix.std(axis=1)[:, None] + 1e-10)
241 | dist = distance.pdist(matrix, "cityblock") # L1 distance
242 | dist = distance.squareform(dist) # Convert to square form
243 | sim = 1 / (1 + dist) # Convert distance to similarity
244 | np.fill_diagonal(sim, 0) # Set diagonal elements to 0
245 | return sim
246 |
247 |
248 | def l2_similarity(matrix):
249 | matrix = np.nan_to_num(matrix)
250 | matrix = (matrix - matrix.mean(axis=1)[:, None]) / (matrix.std(axis=1)[:, None] + 1e-10)
251 | dist = distance.pdist(matrix, "euclidean") # L2 distance
252 | dist = distance.squareform(dist) # Convert to square form
253 | sim = 1 / (1 + dist) # Convert distance to similarity
254 | np.fill_diagonal(sim, 0) # Set diagonal elements to 0
255 | return sim
256 |
257 |
258 | def pairwise_cosine_similarity(matrix):
259 | matrix = np.nan_to_num(matrix)
260 | matrix = (matrix - matrix.mean(axis=1)[:, None]) / (matrix.std(axis=1)[:, None] + 1e-10)
261 | sim = cosine_similarity(matrix)
262 | np.fill_diagonal(sim, 0) # Set diagonal elements to 0
263 | return sim
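The three helpers above z-score each row before comparing rows. Two consequences are worth knowing: cosine similarity between z-scored rows equals the Pearson correlation of the raw rows (the `1e-10` guard only rescales each row, which cosine ignores), and the L1/L2 variants map a distance d to 1 / (1 + d), so identical rows score 1 before the diagonal is zeroed. The NumPy-only sketch below checks both; `zscore_rows` is a hypothetical helper, and the L1 distance uses broadcasting instead of `scipy.spatial.distance.pdist`.

```python
import numpy as np

def zscore_rows(m):
    """Per-row standardization, as in l1_similarity / l2_similarity / pairwise_cosine_similarity."""
    m = np.nan_to_num(m)
    return (m - m.mean(axis=1)[:, None]) / (m.std(axis=1)[:, None] + 1e-10)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))
z = zscore_rows(x)

# Cosine similarity of z-scored rows equals the row-wise Pearson correlation.
unit = z / np.linalg.norm(z, axis=1, keepdims=True)
cos = unit @ unit.T
print(np.allclose(cos, np.corrcoef(x), atol=1e-6))

# L1 distance turned into a similarity; the diagonal is 1 before fill_diagonal(sim, 0).
l1 = np.abs(z[:, None, :] - z[None, :, :]).sum(axis=-1)
sim_l1 = 1.0 / (1.0 + l1)
print(np.allclose(np.diag(sim_l1), 1.0))
```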
264 |
265 |
266 | def batch_similarity(batches, similarity_func):
267 | # Calculate the similarity for each batch
268 | batch_sims = [
269 | similarity_func(batch)
270 | for batch in tqdm(batches, desc="Calculating similarities")
271 | ]
272 |
273 | # Find the total number of nonzero elements
274 | total_nonzero = sum(np.count_nonzero(arr) for arr in batch_sims)
275 |
276 | # Initialize arrays to store row, col, and data for csr_matrix construction
277 | rows = np.zeros(total_nonzero, dtype=np.int32)
278 | cols = np.zeros(total_nonzero, dtype=np.int32)
279 | data = np.zeros(total_nonzero, dtype=batch_sims[0].dtype)
280 |
281 | current_idx = 0
282 | current_row = 0
283 | current_col = 0
284 | # Loop through each batch similarity array
285 | for arr in batch_sims:
286 | rows_batch, cols_batch = np.nonzero(arr)
287 | num_nonzero = len(rows_batch)
288 | rows[current_idx : current_idx + num_nonzero] = rows_batch + current_row
289 | cols[current_idx : current_idx + num_nonzero] = cols_batch + current_col
290 | data[current_idx : current_idx + num_nonzero] = arr[rows_batch, cols_batch]
291 | current_idx += num_nonzero
292 | current_row += arr.shape[0]
293 | current_col += arr.shape[1]
294 |
295 | # Construct the block-diagonal csr_matrix (pairs from different batches get no entry)
296 | merged_sim = csr_matrix(
297 | (data, (rows, cols)),
298 | shape=(
299 | sum(arr.shape[0] for arr in batch_sims),
300 | sum(arr.shape[1] for arr in batch_sims),
301 | ),
302 | )
303 |
304 | return merged_sim
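`batch_similarity` places each per-batch matrix on the diagonal of one large sparse matrix, offsetting rows and columns by the sizes of the preceding batches. As far as the stored values go, that is a block-diagonal matrix, so SciPy's `scipy.sparse.block_diag` can serve as a cross-check. The sketch below assumes square per-batch matrices (as produced by the similarity functions); `merge_blocks` is a hypothetical name.

```python
import numpy as np
from scipy.sparse import block_diag, csr_matrix

def merge_blocks(batch_sims):
    """Block-diagonal merge of per-batch similarity matrices; cross-batch entries stay 0."""
    return csr_matrix(block_diag(batch_sims, format="csr"))

a = np.array([[0.0, 0.8], [0.8, 0.0]])
b = np.array([[0.0, 0.5, 0.0], [0.5, 0.0, 0.9], [0.0, 0.9, 0.0]])
merged = merge_blocks([a, b])
print(merged.shape)
print(merged.toarray())
```

A consequence of this layout is that two samples assigned to different clusters or scaffolds can never be connected by a similarity edge, however alike they are.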
305 |
306 |
307 | def direct_similarity(matrix, similarity_func):
308 | sim = similarity_func(matrix)
309 | np.fill_diagonal(sim, 0) # Set diagonal elements to 0
310 | # convert to csr_matrix
311 | sim = csr_matrix(sim)
312 | return sim
313 |
314 | def determine_threshold(similarity, min_threshold, high_threshold=0.99, target_sparsity=0.995):
315 | low_threshold = min_threshold
316 | threshold = (high_threshold + low_threshold) / 2.0 # Start in the middle
317 | current_sparsity = 1 - np.count_nonzero(similarity > threshold) / similarity.size
318 |
319 | while low_threshold < high_threshold - 0.001: # Continue until the interval is small
320 | if current_sparsity > target_sparsity:
321 | high_threshold = threshold # Move the upper limit down
322 | else:
323 | low_threshold = threshold # Move the lower limit up
324 |
325 | threshold = (high_threshold + low_threshold) / 2.0 # Recalculate the middle
326 | current_sparsity = 1 - np.count_nonzero(similarity > threshold) / similarity.size
327 |
328 | return threshold
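`determine_threshold` bisects on the threshold until the fraction of values at or below it (its "sparsity") matches `target_sparsity`. The result is therefore close to the `target_sparsity` quantile of the values, clipped to `[min_threshold, high_threshold]`. Note that `create_nx_graph` passes `sim.data`, so the quantile is taken over the stored nonzero similarities, not over the full matrix. The sketch below mirrors the loop with a hypothetical name, `bisect_threshold`, and compares it with `np.quantile` on synthetic uniform data.

```python
import numpy as np

def bisect_threshold(values, low, high=0.99, target_sparsity=0.995, tol=0.001):
    """Mirror of determine_threshold: find t such that fraction(values <= t) ~= target_sparsity."""
    def sparsity(t):
        return 1 - np.count_nonzero(values > t) / values.size

    t = (low + high) / 2.0
    while low < high - tol:
        if sparsity(t) > target_sparsity:
            high = t  # more values pruned than wanted: lower the threshold
        else:
            low = t   # not sparse enough yet: raise the threshold
        t = (low + high) / 2.0
    return t

values = np.random.default_rng(0).uniform(size=10_000)
t = bisect_threshold(values, low=0.6, target_sparsity=0.9)
print(round(t, 3), round(float(np.quantile(values, 0.9)), 3))

# When the target quantile lies below `low`, the search settles at `low`.
print(round(bisect_threshold(values, low=0.6, target_sparsity=0.5), 3))
```

Under these assumptions, `np.clip(np.quantile(values, target_sparsity), low, high)` would be a closed-form approximation of the same threshold.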
329 |
330 | def filter_similarity_and_get_ids(similarity_matrix, threshold, data_id):
331 | filtered_similarity = similarity_matrix.copy()
332 | filtered_similarity.data[filtered_similarity.data < threshold] = 0
333 | filtered_similarity.eliminate_zeros()
334 |
335 | row, col = filtered_similarity.nonzero()
336 |
337 | s_node_id = data_id[row]
338 | t_node_id = data_id[col]
339 | edge_weight = filtered_similarity.data
340 | edge_weight = np.round(edge_weight, 2)
341 |
342 | return s_node_id, t_node_id, edge_weight
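`filter_similarity_and_get_ids` turns a sparse similarity matrix into edge lists: entries below the threshold are zeroed and dropped, then row and column positions are mapped back to node ids. The self-contained mirror below (hypothetical name `edges_above`) shows the output on a tiny symmetric matrix. It relies on `eliminate_zeros()` so that the order of `.data` matches the order returned by `.nonzero()`.

```python
import numpy as np
from scipy.sparse import csr_matrix

def edges_above(sim, threshold, ids):
    """Keep entries >= threshold and map matrix positions to node ids."""
    sim = sim.copy()
    sim.data[sim.data < threshold] = 0
    sim.eliminate_zeros()          # keeps .data aligned with .nonzero()
    rows, cols = sim.nonzero()
    return ids[rows], ids[cols], np.round(sim.data, 2)

ids = np.array(["a", "b", "c"])
sim = csr_matrix(np.array([[0.0, 0.9, 0.2],
                           [0.9, 0.0, 0.7],
                           [0.2, 0.7, 0.0]]))
src, dst, w = edges_above(sim, 0.5, ids)
print(list(zip(src.tolist(), dst.tolist(), w.tolist())))
```

Because the matrix is symmetric, each pair appears in both directions; an undirected `nx.Graph` collapses them into one edge, and the second `add_edge` call rewrites the same weight.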
343 |
344 | def merge_features_and_dataframes(df1, df2, features1, features2, col_id1, col_id2, connect_col=None):
345 | if connect_col is None:
346 | df1['connect_id'] = df1[col_id1].astype(str) + '_' + df1[col_id2].astype(str)
347 | df2['connect_id'] = df2[col_id1].astype(str) + '_' + df2[col_id2].astype(str)
348 | else:
349 | df1['connect_id'] = df1[connect_col].astype(str)
350 | df2['connect_id'] = df2[connect_col].astype(str)
351 | index1 = f'index_{col_id1}'
352 | index2 = f'index_{col_id2}'
353 | df1[index1] = df1.index
354 | df2[index2] = df2.index
355 |
356 | merged_df = pd.merge(df1[['connect_id', index1, col_id1]], df2[['connect_id', index2, col_id2]], on='connect_id', how='outer')
357 | if connect_col is None:
358 | merged_df[[col_id1, col_id2]] = merged_df['connect_id'].str.split('_', expand=True)
359 | merged_features = []
360 | # Iterate over each row in the merged dataframe to concatenate features
361 | for _, row in merged_df.iterrows():
362 | feature1_row = np.nan * np.ones(features1.shape[1]) if np.isnan(row[index1]) else features1[int(row[index1])]
363 | feature2_row = np.nan * np.ones(features2.shape[1]) if np.isnan(row[index2]) else features2[int(row[index2])]
364 | merged_features.append(np.concatenate([feature1_row, feature2_row]))
365 |
366 | merged_features = np.vstack(merged_features)
367 | return merged_df, merged_features
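`merge_features_and_dataframes` performs an outer join on a shared key and fills the missing side of each row with NaN, building the result row by row with `iterrows`. A dictionary-based sketch of the same alignment is shown below. It assumes keys are unique within each input, and it keeps keys in first-seen order, whereas `pd.merge(how="outer")` sorts the join keys; `outer_join_features` is a hypothetical helper.

```python
import numpy as np

def outer_join_features(keys1, feats1, keys2, feats2):
    """Align two feature matrices on shared keys; the missing side of a row becomes NaN."""
    keys = list(dict.fromkeys(list(keys1) + list(keys2)))  # union of keys, first-seen order
    pos1 = {k: i for i, k in enumerate(keys1)}
    pos2 = {k: i for i, k in enumerate(keys2)}
    d1, d2 = feats1.shape[1], feats2.shape[1]
    out = np.full((len(keys), d1 + d2), np.nan)
    for r, k in enumerate(keys):
        if k in pos1:
            out[r, :d1] = feats1[pos1[k]]
        if k in pos2:
            out[r, d1:] = feats2[pos2[k]]
    return keys, out

keys, merged = outer_join_features(
    ["m1", "m2"], np.array([[1.0], [2.0]]),
    ["m2", "m3"], np.array([[20.0, 21.0], [30.0, 31.0]]),
)
print(keys)
print(merged)
```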
368 |
369 | def minmax_normalize(data, min_val=None, max_val=None):
370 | """
371 | Normalizes the data using min-max normalization, handling NaN values in min and max calculations.
372 |
373 | Parameters:
374 | - data: The data to be normalized (numpy array).
375 | - min_val: Optional. Precomputed minimum values for each feature. If not provided, computed from data.
376 | - max_val: Optional. Precomputed maximum values for each feature. If not provided, computed from data.
377 |
378 | Returns:
379 | - Normalized data, used min_val, used max_val.
380 | """
381 | # Calculate min and max values if not provided, ignoring NaN values
382 | if min_val is None or max_val is None:
383 | min_val = np.nanmin(data, axis=0)
384 | max_val = np.nanmax(data, axis=0)
385 | else:
386 | assert min_val.shape[0] == data.shape[1], "min_val must match the number of features in data"
387 | assert max_val.shape[0] == data.shape[1], "max_val must match the number of features in data"
388 |
389 | # Handle cases where min and max are equal
390 | equal_indices = min_val == max_val
391 | min_val = np.where(equal_indices, min_val - 1e-6, min_val) # new arrays: caller-supplied bounds are not mutated
392 | max_val = np.where(equal_indices, max_val + 1e-6, max_val)
393 |
394 | # Normalize, handling divisions by zero or where max_val equals min_val
395 | normalized_data = np.where(
396 | max_val - min_val == 0,
397 | 0,
398 | (data - min_val) / (max_val - min_val)
399 | )
400 |
401 | return normalized_data, min_val, max_val
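`minmax_normalize` is meant to be fitted once and then reused by passing the returned bounds back as `min_val` / `max_val`. The sketch below splits that into hypothetical `fit_bounds` / `apply_bounds` helpers to show three behaviors: a constant training column maps to 0.5 (because its bounds are widened by ±1e-6), values outside the training range fall outside [0, 1] (there is no clipping), and NaN inputs stay NaN.

```python
import numpy as np

def fit_bounds(train):
    """Column-wise NaN-aware min/max, widened by 1e-6 where a column is constant."""
    lo, hi = np.nanmin(train, axis=0), np.nanmax(train, axis=0)
    const = lo == hi
    return np.where(const, lo - 1e-6, lo), np.where(const, hi + 1e-6, hi)

def apply_bounds(x, lo, hi):
    """Min-max scale x with precomputed bounds; no clipping to [0, 1]."""
    return (x - lo) / (hi - lo)

train = np.array([[0.0, 5.0], [10.0, 5.0]])   # second column is constant
test = np.array([[5.0, 5.0], [20.0, np.nan]])
lo, hi = fit_bounds(train)
print(apply_bounds(test, lo, hi))              # approximately [[0.5, 0.5], [2.0, nan]]
```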
402 |
403 | def create_nx_graph(folder, min_thres=0.6, min_sparsity=0.995, top_compound_gene_express=0.05):
404 |
405 | structure_df = pd.read_csv(f"{folder}/structure.csv.gz")
406 | mol_df = structure_df.drop_duplicates(subset="mol_id")
407 | smiles_list = mol_df["smiles"].tolist()
408 | all_mol_id = mol_df["mol_id"].tolist()
409 | molecules = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]
410 | structure_feature = np.array(
411 | [AllChem.GetMorganFingerprintAsBitVect(m, 4, nBits=1024) for m in molecules]
412 | )
413 | # load cell nodes and features
414 | cp_bray_df = pd.read_csv(f"{folder}/CP-Bray.csv.gz")
415 | cp_jump_df = pd.read_csv(f"{folder}/CP-JUMP.csv.gz")
416 | cp_bray_feature = np.load(f"{folder}/CP-Bray_feature.npz")["data"]
417 | cp_jump_feature = np.load(f"{folder}/CP-JUMP_feature.npz")["data"]
418 | bray_dim, jump_dim = cp_bray_feature.shape[1], cp_jump_feature.shape[1]
419 |
420 | cell_df, cell_feature = merge_features_and_dataframes(
421 | df1=cp_bray_df, df2=cp_jump_df,
422 | features1=cp_bray_feature, features2=cp_jump_feature,
423 | col_id1='cell_bid', col_id2='cell_jid',
424 | connect_col='mol_id'
425 | )
426 | cell_df = cell_df.rename(columns={'connect_id': 'mol_id'})
427 | cell_df['cell_id'] = ['c' + str(i) for i in range(1, len(cell_df) + 1)]
428 | cell_df = cell_df[['cell_id', 'mol_id', 'index_cell_bid', 'index_cell_jid']]
429 |
430 | # load gene nodes and features
431 | gc_df = pd.read_csv(f"{folder}/G-CRISPR.csv.gz")
432 | gc_feature = np.load(f"{folder}/G-CRISPR_feature.npz")["data"]
433 | go_df = pd.read_csv(f"{folder}/G-ORF.csv.gz")
434 | go_feature = np.load(f"{folder}/G-ORF_feature.npz")["data"]
435 |
436 | gene_df, gene_feature = merge_features_and_dataframes(
437 | df1=gc_df, df2=go_df,
438 | features1=gc_feature, features2=go_feature,
439 | col_id1='ncbi_gene_id', col_id2='mol_id',
440 | connect_col=None
441 | )
442 | gene_df['gene_id'] = gene_df['ncbi_gene_id'].apply(lambda x: 'g' + str(x))
443 | gene_df = gene_df[['gene_id', 'mol_id', 'index_ncbi_gene_id', 'index_mol_id']]
444 |
445 | # load gene expression
446 | express_df = pd.read_csv(f"{folder}/GE.csv.gz")
447 | express_feature = np.load(f"{folder}/GE_feature.npz")["data"]
448 |
449 | # load gene-gene interaction
450 | gg_df = pd.read_csv(f"{folder}/G-G.csv.gz")
451 |
452 | nid_to_feature_id = {}
453 | # Map each node id to its row in the matching feature matrix (mol, express, gene, cell)
454 | for df, col_id in [
455 | (mol_df, "mol_id"),
456 | (express_df, "express_id"),
457 | (gene_df, "gene_id"),
458 | (cell_df, "cell_id"),
459 | ]:
460 | for i, nid in enumerate(df[col_id]):
461 | nid_to_feature_id[nid] = i
462 |
463 | target_mol = structure_feature
464 | target_gene, min_gene, max_gene = minmax_normalize(gene_feature)
465 | target_cell, min_cell, max_cell = minmax_normalize(cell_feature)
466 | target_express, min_express, max_express = minmax_normalize(express_feature)
467 |
468 | mol_dim, gene_dim, cell_dim, express_dim = (
469 | structure_feature.shape[1],
470 | gene_feature.shape[1],
471 | cell_feature.shape[1],
472 | express_feature.shape[1],
473 | )
474 |
475 | #### molecular similarity
476 | scaffold_names, batch_mol_feature, batched_mol_id = cluster_molecules_by_scaffold(
477 | molecules, all_mol_id
478 | )
479 | sim_mol = batch_similarity(batch_mol_feature, pairwise_mol_similarity)
480 |
481 | #### gene similarity
482 | batched_gene_feature, batched_gene_id = perform_pca_and_kmeans(
483 | gene_feature, gene_df["gene_id"].values
484 | )
485 | sim_gene = batch_similarity(batched_gene_feature, pairwise_cosine_similarity)
486 |
487 | #### cell similarity
488 | batched_cell_feature, batched_cell_id = perform_pca_and_kmeans(
489 | cell_feature[:, bray_dim:], cell_df["cell_id"].values
490 | )
491 | sim_cell = batch_similarity(batched_cell_feature, pairwise_cosine_similarity)
492 |
493 | #### gene expression similarity
494 | sim_express = direct_similarity(express_feature, pairwise_cosine_similarity)
495 | direct_express_id = express_df["express_id"].values
496 |
497 | ########## create graph ########
498 | G = nx.Graph()
499 |
500 | mol_ids = structure_df["mol_id"].values
501 | gene_ids = gene_df["gene_id"].values
502 | cell_ids = cell_df["cell_id"].values
503 | express_ids = express_df["express_id"].values
504 |
505 | # Add molecule nodes
506 | mol_nodes = [
507 | (
508 | mol_ids[idx],
509 | dict(
510 | type=mol_ids[idx],
511 | mol_target=target_mol[nid_to_feature_id[mol_id]],
512 | gene_target=np.full(gene_dim, np.nan),
513 | cell_target=np.full(cell_dim, np.nan),
514 | express_target=np.full(express_dim, np.nan)
515 | ),
516 | )
517 | for idx, mol_id in enumerate(mol_ids)
518 | ]
519 | G.add_nodes_from(mol_nodes)
520 | print("Count nodes after adding molecule nodes:", G.number_of_nodes())
521 |
522 | # Add gene nodes (merged CRISPR and ORF profiles)
523 | gene_nodes = [
524 | (
525 | gene_ids[idx],
526 | dict(
527 | type=gene_ids[idx],
528 | mol_target=np.full(mol_dim, np.nan),
529 | gene_target=target_gene[nid_to_feature_id[gene_id]],
530 | cell_target=np.full(cell_dim, np.nan),
531 | express_target=np.full(express_dim, np.nan),
532 | ),
533 | )
534 | for idx, gene_id in enumerate(gene_ids)
535 | ]
536 | G.add_nodes_from(gene_nodes)
537 | print("Count nodes after adding gene nodes:", G.number_of_nodes())
538 |
539 | cell_nodes = [
540 | (
541 | cell_ids[idx],
542 | dict(
543 | type=cell_ids[idx],
544 | mol_target=np.full(mol_dim, np.nan),
545 | gene_target=np.full(gene_dim, np.nan),
546 | cell_target=target_cell[nid_to_feature_id[cell_id]],
547 | express_target=np.full(express_dim, np.nan),
548 | ),
549 | )
550 | for idx, cell_id in enumerate(cell_ids)
551 | ]
552 | G.add_nodes_from(cell_nodes)
553 | print("Count nodes after adding cell nodes:", G.number_of_nodes())
554 |
555 | # Add gene expression nodes
556 | express_nodes = [
557 | (
558 | express_ids[idx],
559 | dict(
560 | type=express_ids[idx],
561 | mol_target=np.full(mol_dim, np.nan),
562 | gene_target=np.full(gene_dim, np.nan),
563 | cell_target=np.full(cell_dim, np.nan),
564 | express_target=target_express[nid_to_feature_id[express_id]],
565 |
566 | ),
567 | )
568 | for idx, express_id in enumerate(express_ids)
569 | ]
570 | G.add_nodes_from(express_nodes)
571 | print("Count nodes after adding gene expression nodes:", G.number_of_nodes())
572 |
573 | G.add_edges_from(zip(gene_df["gene_id"], gene_df["mol_id"]), weight=1)
574 | print("Count of edges after adding mol-gene:", G.number_of_edges())
575 | G.add_edges_from(zip(cell_df["mol_id"], cell_df["cell_id"]), weight=1)
576 | print("Count of edges after adding mol-cell:", G.number_of_edges())
577 | G.add_edges_from(zip(express_df["express_id"], express_df["mol_id"]), weight=1)
578 | print("Count of edges after adding mol-express:", G.number_of_edges())
579 |
580 | # add gene-gene edges
581 | gene_source_prefixed = "g" + gg_df["source_id"].astype(str)
582 | gene_target_prefixed = "g" + gg_df["target_id"].astype(str)
583 |
584 | ## Keep only gene-gene edges whose endpoints both exist in gene_df["gene_id"]
585 | valid_gene_sources = gene_source_prefixed.isin(gene_df["gene_id"])
586 | valid_gene_targets = gene_target_prefixed.isin(gene_df["gene_id"])
587 | gene_edges_to_add = zip(
588 | gene_source_prefixed[valid_gene_sources & valid_gene_targets],
589 | gene_target_prefixed[valid_gene_sources & valid_gene_targets],
590 | )
591 | G.add_edges_from(gene_edges_to_add, weight=1)
592 | print("Count of edges after adding gene-gene in the graph:", G.number_of_edges())
593 |
594 | # Helper: add edges from parallel arrays of source ids, target ids and weights
595 | def add_edges_with_weight(G, s_node, t_node, edge_weight):
596 | # Per-edge weights: add_edges_from(..., weight=edge_weight) would attach the whole array to every edge
597 | for i, (s, t) in enumerate(zip(s_node, t_node)):
598 | weight = edge_weight[i]
599 | G.add_edge(s, t, weight=weight)
600 |
601 | L1k_idmaps = pd.read_csv(f'{folder}/L1k_idmaps.csv')
602 | sorted_indices = np.argsort(np.abs(express_feature).flatten())
603 | start_index = int(len(sorted_indices) * (1-top_compound_gene_express))
604 | express_thre = np.abs(express_feature).flatten()[sorted_indices[start_index]]
605 | top_percent_indices = sorted_indices[start_index:]
606 | row_indices, col_indices = np.unravel_index(top_percent_indices, express_feature.shape)
607 | if len(row_indices) < 1000:
608 | top_percent_indices = sorted_indices[-1000:]
609 | row_indices, col_indices = np.unravel_index(top_percent_indices, express_feature.shape)
610 |
611 | expressed_mol = express_df['mol_id'].iloc[row_indices].values
612 | expressed_gene = L1k_idmaps['ncbi_gene_id'].iloc[col_indices].values
613 | expressed_gene = np.array([f'g{int(gid)}' for gid in expressed_gene])
614 | expressed_weight = np.abs(express_feature[row_indices, col_indices])
615 | expressed_gene_series = pd.Series(expressed_gene)
616 | valid_express_gene = expressed_gene_series.isin(gene_df["gene_id"]).values
617 | add_edges_with_weight(G, expressed_mol[valid_express_gene], expressed_gene[valid_express_gene], expressed_weight[valid_express_gene])
618 | print(f"Count of edges after adding top {round(top_compound_gene_express * 100, 2)} % mol-gene edges from expression with threshold {round(express_thre, 4)}:", G.number_of_edges())
619 |
620 | thre_cell = determine_threshold(sim_cell.data, min_thres, target_sparsity=min_sparsity)
621 | thre_gene = determine_threshold(sim_gene.data, min_thres, target_sparsity=min_sparsity)
622 | thre_express = determine_threshold(sim_express.data, min_thres, target_sparsity=min_sparsity)
623 | thre_mol = determine_threshold(sim_mol.data, min_thres, target_sparsity=min_sparsity)
624 | thres_dict = {'cell': thre_cell, 'gene': thre_gene, 'express': thre_express, 'mol': thre_mol}
625 |
626 | ## filter similarity and get ids
627 | mol_s_node, mol_t_node, mol_edge_weight = filter_similarity_and_get_ids(
628 | sim_mol, thres_dict['mol'], batched_mol_id
629 | )
630 | gene_s_node, gene_t_node, gene_edge_weight = filter_similarity_and_get_ids(
631 | sim_gene, thres_dict['gene'], batched_gene_id
632 | )
633 | cell_s_node, cell_t_node, cell_edge_weight = filter_similarity_and_get_ids(
634 | sim_cell, thres_dict['cell'], batched_cell_id
635 | )
636 | express_s_node, express_t_node, express_edge_weight = filter_similarity_and_get_ids(
637 | sim_express, thres_dict['express'], direct_express_id
638 | )
639 | add_edges_with_weight(G, mol_s_node, mol_t_node, mol_edge_weight)
640 | print(f"Count of edges after adding mol-mol sim with threshold {round(thres_dict['mol'], 2)}:", G.number_of_edges())
641 |
642 | add_edges_with_weight(G, gene_s_node, gene_t_node, gene_edge_weight)
643 | print(f"Count of edges after adding gene-gene sim with threshold {round(thres_dict['gene'], 2)}:", G.number_of_edges())
644 |
645 | add_edges_with_weight(G, cell_s_node, cell_t_node, cell_edge_weight)
646 | print(f"Count of edges after adding cell-cell sim with threshold {round(thres_dict['cell'], 2)}:", G.number_of_edges())
647 |
648 | add_edges_with_weight(G, express_s_node, express_t_node, express_edge_weight)
649 | print(f"Count of edges after adding express-express sim with threshold {round(thres_dict['express'], 2)}:", G.number_of_edges())
650 |
651 | nan_nodes = [n for n in G.nodes() if pd.isnull(n)]
652 | G.remove_nodes_from(nan_nodes)
653 |
654 | gene_bound = {"min": min_gene, "max": max_gene}
655 | cell_bound = {"min": min_cell, "max": max_cell}
656 | express_bound = {"min": min_express, "max": max_express}
657 |
658 | # Store the min/max normalization bounds as graph-level attributes
659 | G.graph["gene_bound"] = gene_bound
660 | G.graph["cell_bound"] = cell_bound
661 | G.graph["express_bound"] = express_bound
662 |
663 | return G
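`create_nx_graph` adds mol-gene edges for the largest-magnitude entries of the expression matrix: it argsorts `np.abs(express_feature).flatten()`, keeps the top `top_compound_gene_express` fraction (at least 1000 entries), and maps flat indices back to (molecule row, L1000 gene column) with `np.unravel_index`. `np.argsort` sorts NaN last, so any NaN in `express_feature` would be selected as a "top" entry. The sketch below is a variant that treats NaN as 0 before ranking. That choice is an assumption, not what the repository does, and `top_fraction_entries` / `min_count` are hypothetical names.

```python
import numpy as np

def top_fraction_entries(matrix, top_fraction=0.05, min_count=1000):
    """Return (rows, cols, |value|) for the largest-|value| entries of a 2-D matrix.
    Assumption: NaN entries are ranked as 0 so they are never selected first."""
    mag = np.abs(np.nan_to_num(matrix)).ravel()
    order = np.argsort(mag)                        # ascending magnitude
    start = int(len(order) * (1 - top_fraction))
    selected = order[start:]
    if len(selected) < min_count:                  # fall back to a minimum edge count
        selected = order[-min_count:]
    rows, cols = np.unravel_index(selected, matrix.shape)
    return rows, cols, mag[selected]

m = np.array([[0.1, -0.9, 0.3],
              [0.5, 0.0, -0.2]])
rows, cols, w = top_fraction_entries(m, top_fraction=0.5, min_count=1)
print(rows.tolist(), cols.tolist(), w.tolist())
```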
664 |
665 |
666 | import networkx as nx
667 | from collections import defaultdict
668 | from typing import Any, Dict, List, Literal, Optional, Union
669 |
670 | import torch
671 | from torch import Tensor
672 |
673 | import torch_geometric
674 | from torch_geometric.data import Data
675 |
676 |
677 | def from_networkx(
678 | G: Any,
679 | group_node_attrs: Optional[Union[List[str], Literal["all"]]] = None,
680 | group_edge_attrs: Optional[Union[List[str], Literal["all"]]] = None,
681 | ) -> "torch_geometric.data.Data":
682 | r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
683 | :class:`torch_geometric.data.Data` instance.
684 |
685 | Args:
686 | G (networkx.Graph or networkx.DiGraph): A networkx graph.
687 | group_node_attrs (List[str] or "all", optional): The node attributes to
688 | be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
689 | group_edge_attrs (List[str] or "all", optional): The edge attributes to
690 | be concatenated and added to :obj:`data.edge_attr`.
691 | (default: :obj:`None`)
692 |
693 | .. note::
694 |
695 | All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
696 | be numeric.
697 |
698 | Examples:
699 | >>> edge_index = torch.tensor([
700 | ... [0, 1, 1, 2, 2, 3],
701 | ... [1, 0, 2, 1, 3, 2],
702 | ... ])
703 | >>> data = Data(edge_index=edge_index, num_nodes=4)
704 | >>> g = to_networkx(data)
705 | >>> # A `Data` object is returned
706 | >>> from_networkx(g)
707 | Data(edge_index=[2, 6], num_nodes=4)
708 | """
709 |
710 | G = G.to_directed() if not nx.is_directed(G) else G
711 |
712 | mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))
713 | edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)
714 | for i, (src, dst) in enumerate(G.edges()):
715 | edge_index[0, i] = mapping[src]
716 | edge_index[1, i] = mapping[dst]
717 |
718 | data_dict: Dict[str, Any] = defaultdict(list)
719 | data_dict["edge_index"] = edge_index
720 |
721 | node_attrs: List[str] = []
722 | if G.number_of_nodes() > 0:
723 | node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())
724 |
725 | edge_attrs: List[str] = []
726 | if G.number_of_edges() > 0:
727 | edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())
728 |
729 | if group_node_attrs is not None and not isinstance(group_node_attrs, list):
730 | group_node_attrs = node_attrs
731 |
732 | if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
733 | group_edge_attrs = edge_attrs
734 |
735 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
736 | if set(feat_dict.keys()) != set(node_attrs):
737 | raise ValueError(
738 | f"Not all nodes contain the same attributes: node {i} has "
739 | f"{list(feat_dict.keys())}, expected {node_attrs}"
740 | )
741 | for key, value in feat_dict.items():
742 | data_dict[str(key)].append(value)
743 |
744 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
745 | if set(feat_dict.keys()) != set(edge_attrs):
746 | raise ValueError("Not all edges contain the same attributes")
747 | for key, value in feat_dict.items():
748 | key = f"edge_{key}" if key in node_attrs else key
749 | data_dict[str(key)].append(value)
750 |
751 | for key, value in G.graph.items():
752 | if key == "node_default" or key == "edge_default":
753 | continue # Do not load default attributes.
754 | key = f"graph_{key}" if key in node_attrs else key
755 | data_dict[str(key)] = value
756 |
757 | for key, value in data_dict.items():
758 | if isinstance(value, (tuple, list)) and isinstance(value[0], Tensor):
759 | data_dict[key] = torch.stack(value, dim=0).to(torch.float32) # torch.stack has no dtype argument
760 | elif isinstance(value, (tuple, list)) and isinstance(value[0], np.ndarray):
761 | data_dict[key] = torch.tensor(np.stack(value), dtype=torch.float32)
762 | else:
763 | try:
764 | data_dict[key] = torch.as_tensor(np.array(value))
765 | except Exception:
766 | pass
767 |
768 | data = Data.from_dict(data_dict)
769 |
770 | if group_node_attrs is not None:
771 | xs = []
772 | for key in group_node_attrs:
773 | x = data[key]
774 | x = x.view(-1, 1) if x.dim() <= 1 else x
775 | xs.append(x)
776 | del data[key]
777 | data.x = torch.cat(xs, dim=-1)
778 |
779 | if group_edge_attrs is not None:
780 | xs = []
781 | for key in group_edge_attrs:
782 | key = f"edge_{key}" if key in node_attrs else key
783 | x = data[key]
784 | x = x.view(-1, 1) if x.dim() <= 1 else x
785 | xs.append(x)
786 | del data[key]
787 | data.edge_attr = torch.cat(xs, dim=-1)
788 |
789 | if data.x is None and data.pos is None:
790 | data.num_nodes = G.number_of_nodes()
791 |
792 | return data
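Two steps of the vendored `from_networkx` matter most for this graph: an undirected `nx.Graph` is converted with `to_directed()` first, so every undirected edge contributes two columns to `edge_index`, and per-node NumPy array attributes are stacked into float32 rows. The sketch below reproduces both steps with NetworkX and NumPy only (no torch); the `feat` attribute is hypothetical.

```python
import networkx as nx
import numpy as np

G = nx.path_graph(4)                                # undirected path 0-1-2-3
for n in G.nodes:
    G.nodes[n]["feat"] = np.full(3, float(n))       # hypothetical per-node array attribute

D = G.to_directed()                                 # each undirected edge becomes two arcs
mapping = {node: i for i, node in enumerate(D.nodes())}
edge_index = np.array([[mapping[u] for u, _ in D.edges()],
                       [mapping[v] for _, v in D.edges()]])
x = np.stack([attrs["feat"] for _, attrs in D.nodes(data=True)]).astype(np.float32)
print(edge_index.shape, x.shape)
```

The doubling explains the docstring example, where a 3-edge path over 4 nodes yields `edge_index=[2, 6]`.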
793 |
794 |
795 |
796 | ##### scaffold splitting
797 |
798 | def scaffold_split(train_df, train_ratio=0.6, valid_ratio=0.15, test_ratio=0.25):
799 | """Splits a dataframe of molecules into scaffold-based clusters."""
800 | # Get smiles from the dataframe
801 | train_smiles_list = train_df["smiles"]
802 |
803 | indices = list(range(len(train_smiles_list)))
804 | train_mol_list = [Chem.MolFromSmiles(smiles) for smiles in train_smiles_list]
805 | scaffold_names, _, batched_id = cluster_molecules_by_scaffold(train_mol_list, indices, remove_single=False, flatten_id=False)
806 |
807 | train_cutoff = int(train_ratio * len(train_df))
808 | valid_cutoff = int(valid_ratio * len(train_df)) + train_cutoff
809 | train_inds, valid_inds, test_inds = [], [], []
810 | inds_all = deepcopy(batched_id)
811 | rng = np.random.RandomState(3) # local RNG: same order as np.random.seed(3), without resetting the global seed
812 | rng.shuffle(inds_all)
813 | idx_count = 0
814 | for inds_list in inds_all:
815 | for ind in inds_list:
816 | if idx_count < train_cutoff:
817 | train_inds.append(ind)
818 | elif idx_count < valid_cutoff:
819 | valid_inds.append(ind)
820 | else:
821 | test_inds.append(ind)
822 | idx_count += 1
823 |
824 | return train_inds, valid_inds, test_inds
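`scaffold_split` shuffles whole scaffold groups, then walks through them molecule by molecule and assigns each index by comparing a running count with the train and valid cutoffs. Because the cutoff is checked per molecule, a group that straddles a cutoff is divided between two subsets, so the split is scaffold-ordered rather than strictly scaffold-disjoint. The compact sketch below reproduces that assignment; `split_groups` is a hypothetical name and the groups are made up.

```python
import numpy as np

def split_groups(groups, train_ratio=0.6, valid_ratio=0.15, seed=3):
    """Shuffle groups, flatten them, and cut by running count (as scaffold_split does)."""
    n = sum(len(g) for g in groups)
    train_cut = int(train_ratio * n)
    valid_cut = int(valid_ratio * n) + train_cut
    order = list(groups)                       # shallow copy; the input is not reordered
    np.random.RandomState(seed).shuffle(order)
    flat = [i for g in order for i in g]
    return flat[:train_cut], flat[train_cut:valid_cut], flat[valid_cut:]

groups = [[0, 1, 2], [3, 4], [5], [6, 7, 8, 9]]
train, valid, test = split_groups(groups)
print(len(train), len(valid), len(test))
```

If strictly disjoint scaffolds were required, one option would be to decide the subset per group (for example, by the running count before the group starts), at the cost of less exact ratios.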
--------------------------------------------------------------------------------
/dataset/prediction_molecule.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | import os.path as osp
4 | import json
5 | import pandas as pd
6 | import numpy as np
7 | import torch
8 |
9 | from torch_geometric.data import InMemoryDataset, Data
10 | from .data_utils import smiles2graph, scaffold_split
11 |
12 | class PygPredictionMoleculeDataset(InMemoryDataset):
13 | def __init__(
14 | self, name="chembl2k", root="raw_data", transform=None, pre_transform=None
15 | ):
16 | self.name = name
17 | self.root = osp.join(root, name)
18 | self.task_type = 'finetune'
19 |
20 | self.eval_metric = "roc_auc"
21 | if name == "chembl2k":
22 | self.num_tasks = 41
23 | self.start_column = 4
24 | elif name == "broad6k":
25 | self.num_tasks = 32
26 | self.start_column = 2
27 | elif name == "biogenadme":
28 | self.num_tasks = 6
29 | self.start_column = 4
30 | self.eval_metric = "avg_mae"
31 | elif name == "moltoxcast":
32 | self.num_tasks = 617
33 | self.start_column = 2
34 | else:
35 | meta_path = osp.join(self.root, "raw", "meta.json")
36 | if os.path.exists(meta_path):
37 | with open(meta_path, "r") as f:
38 | meta = json.load(f)
39 | self.num_tasks = meta["num_tasks"]
40 | self.start_column = meta["start_column"]
41 | self.eval_metric = meta["eval_metric"]
42 | else:
43 | raise ValueError("Invalid dataset name")
44 |
45 | super(PygPredictionMoleculeDataset, self).__init__(
46 | self.root, transform, pre_transform
47 | )
48 | self.data, self.slices = torch.load(self.processed_paths[0])
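Any dataset name other than `chembl2k`, `broad6k`, `biogenadme` and `moltoxcast` is resolved through `<root>/<name>/raw/meta.json`. `PygPredictionMoleculeDataset` reads `num_tasks`, `start_column` and `eval_metric` from it, while `PredictionMoleculeDataset` reads only `num_tasks` and `start_column` (its `eval_metric` stays `"roc_auc"`). A minimal sketch for writing such a file follows; the helper `write_meta` and the dataset name `my_assays` are hypothetical.

```python
import json
import os
import tempfile

def write_meta(root, name, num_tasks, start_column, eval_metric="roc_auc"):
    """Write <root>/<name>/raw/meta.json with the keys the dataset classes look up."""
    raw_dir = os.path.join(root, name, "raw")
    os.makedirs(raw_dir, exist_ok=True)
    path = os.path.join(raw_dir, "meta.json")
    meta = {"num_tasks": num_tasks, "start_column": start_column, "eval_metric": eval_metric}
    with open(path, "w") as f:
        json.dump(meta, f, indent=2)
    return path

with tempfile.TemporaryDirectory() as tmp:
    path = write_meta(tmp, "my_assays", num_tasks=3, start_column=2)
    with open(path) as f:
        print(json.load(f))
```

`start_column` should point at the first label column of `assays.csv.gz`; every column from there to the end is read as a task.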
49 |
50 | def get_idx_split(self):
51 | path = osp.join(self.root, "split", "scaffold")
52 |
53 | if os.path.isfile(os.path.join(path, "split_dict.pt")):
54 | return torch.load(os.path.join(path, "split_dict.pt"))
55 | else:
56 | print("Initializing split...")
57 | data_df = pd.read_csv(osp.join(self.raw_dir, "assays.csv.gz"))
58 | train_idx, valid_idx, test_idx = scaffold_split(data_df)
59 | train_idx = torch.tensor(train_idx, dtype=torch.long)
60 | valid_idx = torch.tensor(valid_idx, dtype=torch.long)
61 | test_idx = torch.tensor(test_idx, dtype=torch.long)
62 | os.makedirs(path, exist_ok=True)
63 | torch.save(
64 | {"train": train_idx, "valid": valid_idx, "test": test_idx},
65 | os.path.join(path, "split_dict.pt"),
66 | )
67 | return {"train": train_idx, "valid": valid_idx, "test": test_idx}
68 |
69 | @property
70 | def raw_file_names(self):
71 | return ["assays.csv.gz"]
72 |
73 | @property
74 | def processed_file_names(self):
75 | return ["geometric_data_processed.pt"]
76 |
77 | def download(self):
78 | assert os.path.exists(
79 | os.path.join(self.raw_dir, "assays.csv.gz")
80 | ), f"assays.csv.gz does not exist in {self.raw_dir}"
81 |
82 | def process(self):
83 | data_df = pd.read_csv(osp.join(self.raw_dir, "assays.csv.gz"))
84 |
85 | pyg_graph_list = []
86 | for idx, row in data_df.iterrows():
87 | smiles = row["smiles"]
88 | graph = smiles2graph(smiles)
89 |
90 | g = Data()
91 | g.num_nodes = graph["num_nodes"]
92 | g.edge_index = torch.from_numpy(graph["edge_index"])
93 |
94 | del graph["num_nodes"]
95 | del graph["edge_index"]
96 |
97 | if graph["edge_feat"] is not None:
98 | g.edge_attr = torch.from_numpy(graph["edge_feat"])
99 | del graph["edge_feat"]
100 |
101 | if graph["node_feat"] is not None:
102 | g.x = torch.from_numpy(graph["node_feat"])
103 | del graph["node_feat"]
104 |
105 | try:
106 | g.fp = torch.tensor(graph["fp"], dtype=torch.int8).view(1, -1)
107 | del graph["fp"]
108 | except:
109 | pass
110 |
111 | y = []
112 | for col in range(self.start_column, len(row)):
113 | y.append(float(row.iloc[col]))
114 |
115 | g.y = torch.tensor(y, dtype=torch.float32).view(1, -1)
116 | pyg_graph_list.append(g)
117 |
118 | pyg_graph_list = (
119 | pyg_graph_list
120 | if self.pre_transform is None
121 | else self.pre_transform(pyg_graph_list)
122 | )
123 | print("Saving...")
124 | torch.save(self.collate(pyg_graph_list), self.processed_paths[0])
125 |
126 | def __repr__(self):
127 | return "{}()".format(self.__class__.__name__)
128 |
129 |
130 | class PredictionMoleculeDataset(object):
131 | def __init__(self, name="chembl2k", root="raw_data", transform="smiles"):
132 |
133 | assert transform in [
134 | "smiles",
135 | "fingerprint",
136 | "morphology",
137 | "expression",
138 | ], "Invalid transform type"
139 |
140 | self.name = name
141 | self.folder = osp.join(root, name)
142 | self.transform = transform
143 | self.raw_data = os.path.join(self.folder, "raw", "assays.csv.gz")
144 |
145 | self.eval_metric = "roc_auc"
146 | if name == "chembl2k":
147 | self.num_tasks = 41
148 | self.start_column = 4
149 | elif name == "broad6k":
150 | self.num_tasks = 32
151 | self.start_column = 2
152 | elif "moltoxcast" in self.name:
153 | self.num_tasks = 617
154 | self.start_column = 2
155 | elif name == "biogenadme":
156 | self.num_tasks = 6
157 | self.start_column = 4
158 | self.eval_metric = "avg_mae"
159 | else:
160 | meta_path = osp.join(self.folder, "raw", "meta.json")
161 | if os.path.exists(meta_path):
162 | with open(meta_path, "r") as f:
163 | meta = json.load(f)
164 | self.num_tasks = meta["num_tasks"]
165 | self.start_column = meta["start_column"]
166 | else:
167 | raise ValueError("Invalid dataset name")
168 |
169 | super(PredictionMoleculeDataset, self).__init__()
170 | if transform == "smiles":
171 | self.prepare_smiles()
172 | elif transform == "fingerprint":
173 | self.prepare_fingerprints()
174 | elif transform in ["morphology", "expression"]:
175 | self.prepare_other_modality()
176 |
177 | def get_idx_split(self, to_list=False):
178 | path = osp.join(self.folder, "split", "scaffold")
179 | if os.path.isfile(os.path.join(path, "split_dict.pt")):
180 | split_dict = torch.load(os.path.join(path, "split_dict.pt"))
181 | else:
182 | data_df = pd.read_csv(self.raw_data)
183 | train_idx, valid_idx, test_idx = scaffold_split(data_df)
184 | train_idx = torch.tensor(train_idx, dtype=torch.long)
185 | valid_idx = torch.tensor(valid_idx, dtype=torch.long)
186 | test_idx = torch.tensor(test_idx, dtype=torch.long)
187 |
188 | os.makedirs(path, exist_ok=True)
189 | torch.save(
190 | {"train": train_idx, "valid": valid_idx, "test": test_idx},
191 | os.path.join(path, "split_dict.pt"),
192 | )
193 | split_dict = {"train": train_idx, "valid": valid_idx, "test": test_idx}
194 |
195 | if to_list:
196 | split_dict = {k: v.tolist() for k, v in split_dict.items()}
197 | return split_dict
198 |
199 | def prepare_other_modality(self):
200 | assert os.path.exists(
201 | self.raw_data
202 | ), f"{self.raw_data} does not exist"
203 | data_df = pd.read_csv(self.raw_data)
204 |
205 | processed_dir = osp.join(self.folder, "processed")
206 | os.makedirs(processed_dir, exist_ok=True)
207 |
208 | if self.transform == "morphology":
209 | if self.name == "chembl2k":
210 | feature_df = pd.read_csv(
211 | os.path.join(self.folder, "raw", "CP-JUMP.csv.gz"),
212 | compression="gzip",
213 | )
214 | feature_arr = np.load(
215 | os.path.join(self.folder, "raw", "CP-JUMP_feature.npz")
216 | )["data"]
217 | else:
218 | feature_df = pd.read_csv(
219 | os.path.join(self.folder, "raw", "CP-Bray.csv.gz"),
220 | compression="gzip",
221 | )
222 | feature_arr = np.load(
223 | os.path.join(self.folder, "raw", "CP-Bray_feature.npz")
224 | )["data"]
225 | else:
226 | feature_df = pd.read_csv(
227 | os.path.join(self.folder, "raw", "GE.csv.gz"), compression="gzip"
228 | )
229 | feature_arr = np.load(os.path.join(self.folder, "raw", "GE_feature.npz"))[
230 | "data"
231 | ]
232 |
233 | if not osp.exists(osp.join(processed_dir, f"processed_{self.transform}.pt")):
234 | x_list = []
235 | y_list = []
236 | feature_dim = feature_arr.shape[1]
237 | for idx, row in data_df.iterrows():
238 | if len(feature_df[feature_df["inchikey"] == row["inchikey"]]) == 0:
239 | x_list.append(torch.tensor([float("nan")] * feature_dim))
240 | else:
241 | x_tensor = torch.tensor(
242 | feature_arr[
243 | feature_df[
244 | feature_df["inchikey"] == row["inchikey"]
245 | ].index.tolist()[0]
246 | ],
247 | dtype=torch.float32,
248 | )
249 | x_list.append(x_tensor)
250 |
251 | y = []
252 | for col in range(self.start_column, len(row)):
253 | y.append(float(row.iloc[col]))
254 | y = torch.tensor(y, dtype=torch.float32)
255 | y_list.append(y)
256 |
257 | x_list = torch.stack(x_list, dim=0)
258 | y_list = torch.stack(y_list, dim=0)
259 | torch.save(
260 | (x_list, y_list),
261 | osp.join(processed_dir, f"processed_{self.transform}.pt"),
262 | )
263 | else:
264 | x_list, y_list = torch.load(
265 | osp.join(processed_dir, f"processed_{self.transform}.pt")
266 | )
267 |
268 | self.data = x_list
269 | self.labels = y_list
270 |
271 | def prepare_smiles(self):
272 | assert os.path.exists(
273 | self.raw_data
274 | ), f"{self.raw_data} does not exist"
275 | data_df = pd.read_csv(self.raw_data)
276 |
277 | processed_dir = osp.join(self.folder, "processed")
278 | os.makedirs(processed_dir, exist_ok=True)
279 | x_list = []
280 | y_list = []
281 | for idx, row in data_df.iterrows():
282 | smiles = row["smiles"]
283 | x_list.append(smiles)
284 | y = []
285 | for col in range(self.start_column, len(row)):
286 | y.append(float(row.iloc[col]))
287 | y = torch.tensor(y, dtype=torch.float32)
288 | y_list.append(y)
289 |
290 | self.data = x_list
291 | self.labels = y_list
292 |
293 | def prepare_fingerprints(self):
294 | assert os.path.exists(
295 | self.raw_data
296 | ), f"{self.raw_data} does not exist"
297 | data_df = pd.read_csv(self.raw_data)
298 |
299 | processed_dir = osp.join(self.folder, "processed")
300 | os.makedirs(processed_dir, exist_ok=True)
301 |
302 | if not osp.exists(osp.join(processed_dir, "processed_fp.pt")):
303 | print("Processing fingerprints...")
304 | from rdkit import Chem
305 | from rdkit.Chem import AllChem
306 |
307 | x_list = []
308 | y_list = []
309 | for idx, row in data_df.iterrows():
310 | smiles = row["smiles"]
311 | mol = Chem.MolFromSmiles(smiles)
312 | x = torch.tensor(
313 | list(AllChem.GetMorganFingerprintAsBitVect(mol, 2)),
314 | dtype=torch.float32,
315 | )
316 | x_list.append(x)
317 | y = []
318 | for col in range(self.start_column, len(row)):
319 | y.append(float(row.iloc[col]))
320 | y = torch.tensor(y, dtype=torch.float32)
321 | y_list.append(y)
322 |
323 | x_list = torch.stack(x_list, dim=0)
324 | y_list = torch.stack(y_list, dim=0)
325 | torch.save((x_list, y_list), osp.join(processed_dir, "processed_fp.pt"))
326 | else:
327 | x_list, y_list = torch.load(osp.join(processed_dir, "processed_fp.pt"))
328 |
329 | self.data = x_list
330 | self.labels = y_list
331 |
332 | def __getitem__(self, idx):
333 | """Get datapoint(s) with index(indices)"""
334 |
335 | if isinstance(idx, (int, np.integer)):
336 | return self.data[idx], self.labels[idx]
337 | elif isinstance(idx, (list, np.ndarray)):
338 | return [self.data[i] for i in idx], [self.labels[i] for i in idx]
339 | elif isinstance(idx, torch.LongTensor):
340 | return self.data[idx], self.labels[idx]
341 |
342 | raise IndexError("Not supported index {}.".format(type(idx).__name__))
343 |
344 | def __len__(self):
345 | return len(self.data)
346 |
347 | def __repr__(self):
348 | return "{}({})".format(self.__class__.__name__, len(self))
349 |
350 |
351 | if __name__ == "__main__":
352 | pass
353 |
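In `prepare_other_modality`, every assay row filters the whole feature table twice (`feature_df["inchikey"] == row["inchikey"]`), so the cost grows with (number of assay rows) × (number of feature rows). Below is a minimal stdlib sketch of the same alignment that uses a lookup table built in one pass. `first_index` and `align_features` are hypothetical helpers, not part of this repository. The sketch assumes `feature_df` has a default `RangeIndex`, so the label returned by `.index.tolist()[0]` is also the row position in `feature_arr`.

```python
import math
from typing import Dict, List, Sequence


def first_index(keys: Sequence[str]) -> Dict[str, int]:
    """Map each key to the position of its first occurrence."""
    lookup: Dict[str, int] = {}
    for pos, key in enumerate(keys):
        lookup.setdefault(key, pos)  # keep the first match, like `.index.tolist()[0]`
    return lookup


def align_features(
    query_keys: Sequence[str],
    feature_keys: Sequence[str],
    feature_rows: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Return one feature row per query key; unmatched keys get an all-NaN row."""
    dim = len(feature_rows[0])
    lookup = first_index(feature_keys)  # built once instead of once per query row
    aligned = []
    for key in query_keys:
        pos = lookup.get(key)
        if pos is None:
            aligned.append([math.nan] * dim)  # same NaN fill as the original loop
        else:
            aligned.append([float(v) for v in feature_rows[pos]])
    return aligned


if __name__ == "__main__":
    rows = align_features(["a", "c", "b"], ["b", "a", "b"], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    print(rows)  # [[3.0, 4.0], [nan, nan], [1.0, 2.0]]
```

In the dataset class, each aligned row would still be wrapped with `torch.tensor(..., dtype=torch.float32)` before stacking.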
--------------------------------------------------------------------------------
/dataset/pretrain_context.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import logging
5 | import pandas as pd
6 | import os.path as osp
7 | from torch_geometric.data import InMemoryDataset
8 |
9 | from .data_utils import create_nx_graph, from_networkx
10 |
11 | # from data_utils import create_nx_graph, from_networkx
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | class PretrainContextDataset(InMemoryDataset):
16 | def __init__(
17 | self, name="pretrain", root="raw_data", transform=None, pre_transform=None
18 | ):
19 | """
20 | - name (str): name of the pretraining dataset: pretrain
21 | - root (str): root directory to store the dataset folder
22 | - transform, pre_transform (optional): transform/pre-transform graph objects
23 | """
24 | self.name = name
25 | self.dir_name = "_".join(name.split("-"))
26 | self.original_root = root
27 | self.root = osp.join(root, self.dir_name)
28 | self.processed_root = osp.join(osp.abspath(self.root))
29 | self.threshold = pre_transform
30 | self.nxg_name = f"nxg_mint{pre_transform}"
31 |
32 | super(PretrainContextDataset, self).__init__(self.processed_root, None, None)
33 | self.data, self.slices = torch.load(self.processed_paths[0])
34 | @property
35 | def processed_file_names(self):
36 | if self.threshold is not None:
37 | return [f"context_{self.nxg_name}.pt"]
38 | else:
39 | return ["context_data_processed.pt"]
40 |
41 | def process(self):
42 | threshold = self.threshold if self.threshold is not None else 0.6
43 | folder = osp.join(self.root, "raw")
44 | nxg_name = f"{folder}/{self.nxg_name}.pickle"
45 | if not os.path.exists(nxg_name):
46 | G = create_nx_graph(folder, min_thres=threshold, top_compound_gene_express=0.01)
47 | with open(nxg_name, "wb") as f:
48 | pickle.dump(G, f)
49 | else:
50 | G = pd.read_pickle(nxg_name)
51 |
52 | pyg_graph = from_networkx(G)
53 | torch.save(self.collate([pyg_graph]), self.processed_paths[0])
54 |
55 |
56 | # main file
57 | if __name__ == "__main__":
58 | pass
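`PretrainContextDataset.process` builds the NetworkX context graph only when `nxg_mint{threshold}.pickle` is missing. Otherwise it reuses the pickled copy. The following stdlib sketch isolates that cache-or-build pattern. `load_or_build` and the toy `build` callable are hypothetical stand-ins for the code around `create_nx_graph`. The sketch reads the cache back with `pickle.load` rather than `pd.read_pickle`; both can load a file written by `pickle.dump`.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


def load_or_build(cache_path: str, build: Callable[[], Any]) -> Any:
    """Return the cached object at cache_path, building and pickling it on a miss."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    obj = build()
    with open(cache_path, "wb") as f:
        pickle.dump(obj, f)
    return obj


if __name__ == "__main__":
    calls = []

    def build():
        calls.append(1)  # count how often the expensive build runs
        return {"nodes": ["c1", "g1"], "edges": [("c1", "g1")]}

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "nxg_mint0.6.pickle")  # mirrors f"nxg_mint{threshold}"
        first = load_or_build(path, build)
        second = load_or_build(path, build)
        print(first == second, len(calls))  # True 1
```

The file name embeds the threshold, so a different threshold triggers a rebuild. When `pre_transform` is `None`, however, the name becomes `nxg_mintNone`, even though the graph is built with the 0.6 default.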
--------------------------------------------------------------------------------
/dataset/pretrain_molecule.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import pandas as pd
5 | import os.path as osp
6 | from torch_geometric.data import InMemoryDataset
7 | from sklearn.model_selection import train_test_split
8 |
9 | from .data_utils import read_graph_list
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | class PretrainMoleculeDataset(InMemoryDataset):
14 | def __init__(self, name='pretrain', root ='raw_data', transform=None, pre_transform = None):
15 | '''
16 | - name (str): name of the pretraining dataset: pretrain_all
17 | - root (str): root directory to store the dataset folder
18 | - transform, pre_transform (optional): transform/pre-transform graph objects
19 | '''
20 | self.name = name
21 | self.dir_name = '_'.join(name.split('-'))
22 | self.original_root = root
23 | self.root = osp.join(root, self.dir_name)
24 | self.processed_root = osp.join(osp.abspath(self.root))
25 |
26 | self.num_tasks = 1
27 | self.eval_metric = 'customize'
28 | self.task_type = 'pretrain'
29 | self.__num_classes__ = -1
30 | self.binary = False
31 |
32 | super(PretrainMoleculeDataset, self).__init__(self.processed_root, transform, pre_transform)
33 | self.data, self.slices = torch.load(self.processed_paths[0])
34 | self.total_data_len = self.__len__()
35 |
36 | def get_idx_split(self):
37 | full_idx = list(range(self.total_data_len))
38 | train_idx, valid_idx, test_idx = full_idx, [], []
39 | return {'train': torch.tensor(train_idx, dtype = torch.long), 'valid': torch.tensor(valid_idx, dtype = torch.long), 'test': torch.tensor(test_idx, dtype = torch.long)}
40 |
41 | @property
42 | def processed_file_names(self):
43 | return ['mol_data_processed.pt']
44 |
45 | def process(self):
46 |
47 | mol_data_path = osp.join(self.root, 'raw', 'structure.csv.gz')
48 | print('Processing molecule data from file:', mol_data_path)
49 |
50 | mol_df = pd.read_csv(mol_data_path, compression='gzip')
51 | mol_df = mol_df.drop_duplicates(subset="mol_id")
52 | data_list = read_graph_list(mol_df, keep_id=True)
53 |
54 | self.total_data_len = len(data_list)
55 |
56 | print('Pretrain molecule data loading finished with length ', self.total_data_len)
57 |
58 | if self.pre_transform is not None:
59 | data_list = [self.pre_transform(data) for data in data_list]
60 | data, slices = self.collate(data_list)
61 | torch.save((data, slices), self.processed_paths[0])
62 |
63 |
64 |
65 |
66 | # main file
67 | if __name__ == "__main__":
68 | pass
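`PretrainMoleculeDataset` does two things with the index. `process` keeps only the first row per `mol_id`, because `drop_duplicates(subset="mol_id")` defaults to `keep="first"`. `get_idx_split` then assigns every molecule to training and leaves validation and test empty. Below is a stdlib sketch of both behaviours. `dedupe_first` and `pretrain_split` are hypothetical names, and plain lists stand in for the DataFrame and the `torch.long` tensors.

```python
from typing import Dict, List, Sequence


def dedupe_first(records: Sequence[dict], key: str) -> List[dict]:
    """Keep the first record for each value of `key`, preserving input order."""
    seen = set()
    kept = []
    for rec in records:
        if rec[key] not in seen:
            seen.add(rec[key])
            kept.append(rec)
    return kept


def pretrain_split(n: int) -> Dict[str, List[int]]:
    """All indices go to training; validation and test stay empty."""
    return {"train": list(range(n)), "valid": [], "test": []}


if __name__ == "__main__":
    rows = [
        {"mol_id": "m1", "smiles": "CCO"},
        {"mol_id": "m2", "smiles": "c1ccccc1"},
        {"mol_id": "m1", "smiles": "OCC"},
    ]
    unique = dedupe_first(rows, "mol_id")
    print([r["smiles"] for r in unique])  # ['CCO', 'c1ccccc1']
    print(pretrain_split(len(unique)))  # {'train': [0, 1], 'valid': [], 'test': []}
```

`main.py` sets `args.num_trained = len(split_idx["train"])`, so the entire deduplicated set determines the number of optimisation steps during pretraining.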
--------------------------------------------------------------------------------
/dataset/retrieval.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | import os.path as osp
4 |
5 | import pandas as pd
6 | import numpy as np
7 | import torch
8 |
9 | from torch_geometric.data import InMemoryDataset, Data
10 | from .data_utils import smiles2graph, scaffold_split
11 | # from data_utils import smiles2graph, scaffold_split
12 |
13 | class PygRetrievalMoleculeDataset(InMemoryDataset):
14 | def __init__(
15 | self, name="chembl2k", root="raw_data", transform=None, pre_transform=None
16 | ):
17 | self.name = name
18 | self.root = osp.join(root, name)
19 | self.task_type = 'ranking'
20 |
21 | self.eval_metric = "rank"
22 | if name == "chembl2k":
23 | self.num_tasks = 41
24 | self.start_column = 4
25 | self.target_file = 'CP-JUMP_feature.npz'
26 | self.mol_file = 'CP-JUMP.csv.gz'
27 | elif name == "broad6k":
28 | self.num_tasks = 32
29 | self.start_column = 2
30 | self.target_file = 'CP-Bray_feature.npz'
31 | self.mol_file = 'CP-Bray.csv.gz'
32 | else:
33 | raise ValueError("Invalid dataset name")
34 |
35 | super(PygRetrievalMoleculeDataset, self).__init__(
36 | self.root, transform, pre_transform
37 | )
38 | self.data, self.slices = torch.load(self.processed_paths[0])
39 |
40 | @property
41 | def raw_file_names(self):
42 | return [self.target_file, self.mol_file]
43 |
44 | @property
45 | def processed_file_names(self):
46 | return [f"ret_processed_{self.name}_pyg.pt"]
47 |
48 | def download(self):
49 | assert os.path.exists(
50 | osp.join(self.raw_dir, self.target_file)
51 | ), f" {osp.join(self.raw_dir, self.target_file)} does not exist"
52 | assert os.path.exists(
53 | osp.join(self.raw_dir, self.mol_file)
54 | ), f" {osp.join(self.raw_dir, self.mol_file)} does not exist"
55 |
56 | def process(self):
57 | target_reps = np.load(osp.join(self.raw_dir, self.target_file))['data']
58 | mol_df = pd.read_csv(osp.join(self.raw_dir, self.mol_file))
59 |
60 | if self.name == 'broad6k':
61 | for i in range(target_reps.shape[1]):
62 | mol_df['feature_'+str(i)] = target_reps[:, i]
63 | smiles_dict = mol_df.drop_duplicates('inchikey').set_index('inchikey')['smiles'].to_dict()
64 | feature_cols = [col for col in mol_df.columns if 'feature_' in col]
65 | df_subset = mol_df[['inchikey'] + feature_cols]
66 | df_subset = df_subset.groupby('inchikey').median().reset_index()
67 | df_subset['smiles'] = df_subset['inchikey'].map(smiles_dict)
68 |
69 | target_reps = df_subset[feature_cols].values
70 | mol_df = df_subset.drop(columns=feature_cols)
71 |
72 | mol_smiles_list = mol_df['smiles'].tolist()
73 | target_reps = torch.tensor(target_reps, dtype=torch.float32)
74 |
75 | pyg_graph_list = []
76 | for idx, smiles in enumerate(mol_smiles_list):
77 | graph = smiles2graph(smiles)
78 |
79 | g = Data()
80 | g.num_nodes = graph["num_nodes"]
81 | g.edge_index = torch.from_numpy(graph["edge_index"])
82 |
83 | del graph["num_nodes"]
84 | del graph["edge_index"]
85 |
86 | if graph["edge_feat"] is not None:
87 | g.edge_attr = torch.from_numpy(graph["edge_feat"])
88 | del graph["edge_feat"]
89 |
90 | if graph["node_feat"] is not None:
91 | g.x = torch.from_numpy(graph["node_feat"])
92 | del graph["node_feat"]
93 |
94 | try:
95 | g.fp = torch.tensor(graph["fp"], dtype=torch.int8).view(1, -1)
96 | del graph["fp"]
97 | except:
98 | pass
99 |
100 | g.target = target_reps[idx].view(1, -1)
101 |
102 | pyg_graph_list.append(g)
103 |
104 | print("Saving...")
105 | torch.save(self.collate(pyg_graph_list), self.processed_paths[0])
106 |
107 | def __repr__(self):
108 | return "{}()".format(self.__class__.__name__)
109 |
110 |
111 | class RetrievalMoleculeDataset(object):
112 | def __init__(self, name="chembl2k", root="raw_data", transform="smiles"):
113 |
114 | assert transform in [
115 | "smiles",
116 | "fingerprint",
117 | "pyg",
118 | ], "Invalid transform type"
119 |
120 | self.name = name
121 | self.folder = osp.join(root, name)
122 | self.transform = transform
123 |
124 | self.eval_metric = "rank"
125 | if name == "chembl2k":
126 | self.num_tasks = 41
127 | self.start_column = 4
128 | target_file = 'CP-JUMP_feature.npz'
129 | mol_file = 'CP-JUMP.csv.gz'
130 | elif name == "broad6k":
131 | self.num_tasks = 32
132 | self.start_column = 2
133 | target_file = 'CP-Bray_feature.npz'
134 | mol_file = 'CP-Bray.csv.gz'
135 | else:
136 | raise ValueError("Invalid dataset name")
137 |
138 | self.target_file = os.path.join(self.folder, "raw", target_file)
139 | self.mol_file = os.path.join(self.folder, "raw", mol_file)
140 |
141 | super(RetrievalMoleculeDataset, self).__init__()
142 | self.prepare_data()
143 |
144 | def prepare_data(self):
145 | processed_dir = osp.join(self.folder, "processed")
146 | processed_file = osp.join(processed_dir, f"ret_processed_{self.name}_{self.transform}.pt")
147 | if not osp.exists(processed_file):
148 | assert os.path.exists(
149 | self.target_file
150 | ), f" {self.target_file} does not exist"
151 | target_reps = np.load(self.target_file)['data']
152 | assert os.path.exists(
153 | self.mol_file
154 | ), f" {self.mol_file} does not exist"
155 | mol_df = pd.read_csv(self.mol_file)
156 |
157 | if self.name == 'broad6k':
158 | for i in range(target_reps.shape[1]):
159 | mol_df['feature_'+str(i)] = target_reps[:, i]
160 | smiles_dict = mol_df.drop_duplicates('inchikey').set_index('inchikey')['smiles'].to_dict()
161 | feature_cols = [col for col in mol_df.columns if 'feature_' in col]
162 | df_subset = mol_df[['inchikey'] + feature_cols]
163 | df_subset = df_subset.groupby('inchikey').median().reset_index()
164 | df_subset['smiles'] = df_subset['inchikey'].map(smiles_dict)
165 |
166 | target_reps = df_subset[feature_cols].values
167 | mol_df = df_subset.drop(columns=feature_cols)
168 |
169 | mol_smiles_list = mol_df['smiles'].tolist()
170 | inchikey_list = mol_df['inchikey'].tolist()
171 | target_reps = torch.tensor(target_reps, dtype=torch.float32)
172 |
173 | if self.transform == "smiles":
174 | self.data = mol_smiles_list
175 | self.target = target_reps
176 | self.inchikey_list = inchikey_list
177 | torch.save((mol_smiles_list, target_reps, inchikey_list), processed_file)
178 | elif self.transform == "fingerprint":
179 | from rdkit import Chem
180 | from rdkit.Chem import AllChem
181 | x_list = []
182 | for smiles in mol_smiles_list:
183 | mol = Chem.MolFromSmiles(smiles)
184 | x = torch.tensor(
185 | list(AllChem.GetMorganFingerprintAsBitVect(mol, 2)),
186 | dtype=torch.float32,
187 | )
188 | x_list.append(x)
189 | x_list = torch.stack(x_list, dim=0)
190 | self.data = x_list
191 | self.target = target_reps
192 | self.inchikey_list = inchikey_list
193 | torch.save((x_list, target_reps, inchikey_list), processed_file)
194 | else:
195 | raise ValueError("Invalid transform type")
196 | else:
197 | self.data, self.target, self.inchikey_list = torch.load(processed_file)
198 |
199 | def __getitem__(self, idx):
200 | """Get datapoint(s) with index(indices)"""
201 |
202 | if isinstance(idx, slice):
203 | return self.data[idx], self.target[idx], self.inchikey_list[idx]
204 |
205 | raise IndexError("Not supported index {}.".format(type(idx).__name__))
206 |
207 | def __len__(self):
208 | return len(self.data)
209 |
210 | def __repr__(self):
211 | return "{}({})".format(self.__class__.__name__, len(self))
212 |
213 |
214 | if __name__ == "__main__":
215 | pass
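For `broad6k`, both `PygRetrievalMoleculeDataset.process` and `RetrievalMoleculeDataset.prepare_data` collapse replicate profiles in three steps:

1. Feature columns are grouped by `inchikey`.
2. Each column is reduced with the median.
3. The SMILES is taken from the first row seen for that key.

The stdlib sketch below follows these steps. `median_by_key` is a hypothetical stand-in for the pandas `groupby('inchikey').median()` step. Because pandas sorts group keys by default, the sketch also returns keys in sorted order. It assumes there are no missing feature values: pandas' `median` skips NaN, but `statistics.median` does not.

```python
import statistics
from typing import Dict, List, Sequence, Tuple


def median_by_key(
    keys: Sequence[str],
    smiles: Sequence[str],
    features: Sequence[Sequence[float]],
) -> Tuple[List[str], List[str], List[List[float]]]:
    """Collapse replicate rows per key: column-wise median features, first-seen SMILES."""
    groups: Dict[str, List[Sequence[float]]] = {}
    first_smiles: Dict[str, str] = {}
    for key, smi, row in zip(keys, smiles, features):
        groups.setdefault(key, []).append(row)
        first_smiles.setdefault(key, smi)  # like drop_duplicates('inchikey') before to_dict()

    out_keys = sorted(groups)  # pandas groupby sorts group keys by default
    out_feats = [
        [statistics.median(col) for col in zip(*groups[k])]  # column-wise median
        for k in out_keys
    ]
    out_smiles = [first_smiles[k] for k in out_keys]
    return out_keys, out_smiles, out_feats


if __name__ == "__main__":
    k, s, f = median_by_key(
        ["B", "A", "B", "B"],
        ["CCN", "CCO", "NCC", "CCN"],
        [[1.0, 10.0], [5.0, 5.0], [3.0, 30.0], [2.0, 20.0]],
    )
    print(k, s, f)  # ['A', 'B'] ['CCO', 'CCN'] [[5.0, 5.0], [2.0, 20.0]]
```

The SMILES list and the target matrix come from the same aggregated table, so row `i` of `target_reps` stays aligned with `mol_smiles_list[i]`.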
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings("ignore", category=UserWarning)
4 |
5 | import math
6 | import logging
7 |
8 | import numpy as np
9 | import torch
10 | import torch.optim as optim
11 | from torch.optim.lr_scheduler import LambdaLR
12 | from torch_geometric.loader import DataLoader
13 | from torch.nn.utils import parameters_to_vector, vector_to_parameters
14 |
15 | from configures.arguments import (
16 | load_arguments_from_yaml,
17 | save_arguments_to_yaml,
18 | get_args,
19 | )
20 | from dataset.create_datasets import get_data
21 | from utils import validate, init_weights, save_prediction
22 | from utils.train_funcs import pretrain_func, finetune_func
23 |
24 |
25 | def get_logger(name, logfile=None):
26 | """Create a logger that writes to the console and, if logfile is given, to that file."""
27 | logger = logging.getLogger(name)
28 | # clear handlers if they were created in other runs
29 | if logger.hasHandlers():
30 | logger.handlers.clear()
31 | logger.setLevel(logging.DEBUG)
32 | # create formatter
33 | formatter = logging.Formatter("%(asctime)s - %(message)s")
34 | # create console handler and add it to the logger
35 | ch = logging.StreamHandler()
36 | ch.setLevel(logging.DEBUG)
37 | ch.setFormatter(formatter)
38 | logger.addHandler(ch)
39 | # create file handler and add it to the logger when logfile is not None
40 | if logfile is not None:
41 | fh = logging.FileHandler(logfile)
42 | fh.setFormatter(formatter)
43 | fh.setLevel(logging.DEBUG)
44 | logger.addHandler(fh)
45 | logger.propagate = False
46 | return logger
47 |
48 |
49 | def seed_torch(seed=0):
50 | print("Seed", seed)
51 | torch.manual_seed(seed)
52 | torch.cuda.manual_seed(seed)
53 | torch.backends.cudnn.benchmark = False
54 | torch.backends.cudnn.deterministic = True
55 |
56 |
57 | def get_cosine_schedule_with_warmup(optimizer,
58 | num_warmup_steps,
59 | num_training_steps,
60 | num_cycles=7./16.,
61 | last_epoch=-1):
62 | def _lr_lambda(current_step):
63 | if current_step < num_warmup_steps:
64 | return float(current_step) / float(max(1, num_warmup_steps))
65 | no_progress = float(current_step - num_warmup_steps) / \
66 | float(max(1, num_training_steps - num_warmup_steps))
67 | return max(0, math.cos(math.pi * num_cycles * no_progress))
68 |
69 | return LambdaLR(optimizer, _lr_lambda, last_epoch)
70 |
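`get_cosine_schedule_with_warmup` returns a `LambdaLR` with a two-phase learning-rate multiplier:

- During warmup, the multiplier rises linearly.
- After warmup, it follows `cos(pi * num_cycles * progress)`, where `progress` is the fraction of post-warmup steps completed.

`num_cycles` defaults to 7/16, which is below 1/2, so the multiplier never reaches zero. At the last step it equals `cos(7 * pi / 16)`, about 0.195. The function below reproduces the arithmetic of `_lr_lambda` without PyTorch so the curve can be checked directly. `lr_multiplier` is a hypothetical name.

```python
import math


def lr_multiplier(step: int, num_warmup_steps: int, num_training_steps: int,
                  num_cycles: float = 7.0 / 16.0) -> float:
    """Same arithmetic as `_lr_lambda` in main.py."""
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))  # linear warmup
    progress = float(step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    return max(0.0, math.cos(math.pi * num_cycles * progress))  # truncated cosine decay


if __name__ == "__main__":
    total = 100  # stands in for args.epochs * args.steps
    for step in (0, 25, 50, 75, 100):
        print(step, round(lr_multiplier(step, 0, total), 4))
```

`main.py` passes `num_warmup_steps=0`, so training starts at the full learning rate and ends near 19.5% of it. This assumes the scheduler is stepped once per batch inside `pretrain_func`/`finetune_func`, which is not shown here.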
71 |
72 | def main(args, seed):
73 | device = torch.device("cuda", args.gpu_id)
74 | args.n_gpu = torch.cuda.device_count()
75 | args.device = device
76 |
77 | if args.dataset == "pretrain":
78 | dataset, context_graph = get_data(args, "./raw_data", transform="pyg")
79 | context_graph = context_graph[0]
80 | else:
81 | dataset = get_data(args, "./raw_data", transform="pyg")
82 | context_graph = None
83 |
84 | split_idx = dataset.get_idx_split()
85 | args.num_trained = len(split_idx["train"])
86 | args.task_type = dataset.task_type
87 | args.steps = args.num_trained // args.batch_size + 1
88 |
89 | train_loader = DataLoader(
90 | dataset[split_idx["train"]],
91 | batch_size=args.batch_size,
92 | shuffle=True,
93 | num_workers=args.num_workers,
94 | )
95 |
96 | if args.dataset == "pretrain":
97 | from models.gnn import GNN
98 | from torch.distributions import Normal, Independent
99 |
100 | test_loader = None
101 | model = GNN(
102 | gnn_type=args.model,
103 | num_tasks=dataset.num_tasks,
104 | num_layer=args.num_layer,
105 | emb_dim=args.emb_dim,
106 | drop_ratio=args.drop_ratio,
107 | graph_pooling=args.readout,
108 | norm_layer=args.norm_layer,
109 | ).to(device)
110 | init_weights(model, args.initw_name, init_gain=0.02)
111 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
112 |
113 | prior_mu = torch.zeros(args.emb_dim).to(device)
114 | prior_sigma = torch.ones(args.emb_dim).to(device)
115 | args.prior_dist = Independent(Normal(loc=prior_mu, scale=prior_sigma), 1)
116 |
117 | elif args.dataset.startswith("finetune"):
118 | from models.gnn import FineTuneGNN
119 |
120 | valid_loader = DataLoader(
121 | dataset[split_idx["valid"]],
122 | batch_size=args.batch_size,
123 | shuffle=False,
124 | num_workers=args.num_workers,
125 | )
126 | test_loader = DataLoader(
127 | dataset[split_idx["test"]],
128 | batch_size=args.batch_size,
129 | shuffle=False,
130 | num_workers=args.num_workers,
131 | )
132 | model = FineTuneGNN(
133 | gnn_type=args.model,
134 | num_tasks=dataset.num_tasks,
135 | num_layer=args.num_layer,
136 | emb_dim=args.emb_dim,
137 | drop_ratio=args.drop_ratio,
138 | graph_pooling=args.readout,
139 | norm_layer=args.norm_layer,
140 | ).to(device)
141 | model.load_pretrained_graph_encoder(args.model_path)
142 | model.freeze_graph_encoder()
143 | optimizer = optim.Adam(
144 | model.task_decoder.parameters(), lr=args.lr, weight_decay=args.wdecay
145 | )
146 | else:
147 | raise ValueError("Invalid dataset name")
148 |
149 | # scheduler = None
150 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epochs * args.steps)
151 |
152 | logging.warning(f"device: {args.device}, n_gpu: {args.n_gpu}")
153 | logger.info(dict(args._get_kwargs()))
154 | logger.info(model)
155 | logger.info("***** Running training *****")
156 | logger.info(
157 | f" Task = {args.dataset}@{args.num_trained}/{len(split_idx['valid'])}/{len(split_idx['test'])}"
158 | )
159 | logger.info(f" Num Epochs = {args.epochs}")
160 | logger.info(f" Total train batch size = {args.batch_size}")
161 | logger.info(f" Total optimization steps = {args.epochs * args.steps}")
162 |
163 | train_loaders = {"train_iter": iter(train_loader), "train_loader": train_loader}
164 |
165 | best_train, best_valid, best_test, best_count = None, None, None, None
166 | best_epoch = 0
167 | loss_tots = []
168 | if args.dataset == "pretrain": # later args.finetune
169 | for epoch in range(0, args.epochs):
170 | loss, train_loaders = pretrain_func(
171 | args, model, train_loaders, context_graph, optimizer, scheduler, epoch
172 | )
173 | loss_tots.append(loss)
174 | if epoch == args.epochs - 1:
175 | torch.save(model.state_dict(), args.model_path)
176 | yaml_path = args.model_path.replace(".pt", ".yaml")
177 | save_arguments_to_yaml(args, yaml_path, model_only=True)
178 | logger.info(
179 | f"Finished Training \n Model saved at {args.model_path} and Arguments saved at {yaml_path} with loss {loss_tots}"
180 | )
181 |
182 | elif args.dataset.startswith("finetune"):
183 | args.task_type = (
184 | "regression" if "mae" in dataset.eval_metric else "classification"
185 | )
186 | best_params = None
187 | for epoch in range(0, args.epochs):
188 | train_loaders = finetune_func(
189 | args, model, train_loaders, optimizer, scheduler, epoch
190 | )
191 | valid_perf = validate(args, model, valid_loader)
192 |
193 | if epoch > 0:
194 | is_improved = (
195 | valid_perf[dataset.eval_metric] < best_valid
196 | if args.task_type == "regression"
197 | else valid_perf[dataset.eval_metric] > best_valid
198 | )
199 | if epoch == 0 or is_improved:
200 | train_perf = validate(args, model, train_loader)
201 | test_perf = validate(args, model, test_loader)
202 | best_params = parameters_to_vector(model.parameters())
203 | best_valid = valid_perf[dataset.eval_metric]
204 | best_test = test_perf[dataset.eval_metric]
205 | best_train = train_perf[dataset.eval_metric]
206 | best_epoch = epoch
207 | best_count = test_perf.get("count", None)
208 | if best_count is None:
209 | best_count = test_perf.get("mae_list", None)
210 | if not args.no_print:
211 | logger.info(
212 | "Update Epoch {}: best_train: {:.4f} best_valid: {:.4f}, best_test: {:.4f}".format(
213 | epoch, best_train, best_valid, best_test
214 | )
215 | )
216 | if best_count is not None and args.task_type == "classification":
217 | outstr = "Best Count: "
218 | for key, value in best_count.items():
219 | sum_num = int(np.nansum(value))
220 | nan_num = sum(np.isnan(value))
221 | outstr += f"{key}: {sum_num/len(value):.4f} (nan {nan_num} / {len(value)}), "
222 | logger.info(outstr)
223 | else:
224 | if not args.no_print:
225 | logger.info(
226 | "Epoch {}: best_valid: {:.4f}, current_valid: {:.4f}, patience: {}/{}".format(
227 | epoch,
228 | best_valid,
229 | valid_perf[dataset.eval_metric],
230 | epoch - best_epoch,
231 | args.patience,
232 | )
233 | )
234 | if epoch - best_epoch > args.patience:
235 | break
236 |
237 | logger.info(
238 | "Finished. \n {}-{} Best validation epoch {} with metric {}, train {:.4f}, valid {:.4f}, test {:.4f}".format(
239 | args.dataset, args.pretrain_name, best_epoch, dataset.eval_metric, best_train, best_valid, best_test
240 | )
241 | )
242 | vector_to_parameters(best_params, model.parameters())
243 | save_prediction(model, device, test_loader, dataset, args.output_dir, seed)
244 |
245 | return (
246 | args.pretrain_name,
247 | args.dataset,
248 | dataset.eval_metric,
249 | best_train,
250 | best_valid,
251 | best_test,
252 | best_epoch,
253 | best_count,
254 | )
255 |
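The fine-tuning loop in `main` keeps the parameters from the epoch with the best validation metric. Lower is better when `eval_metric` contains `mae`; higher is better otherwise. The loop stops once `epoch - best_epoch > args.patience`. Below is a stdlib sketch of that selection rule applied to a precomputed list of validation scores. `select_best_epoch` is a hypothetical helper, and it omits the parameter snapshot taken with `parameters_to_vector`.

```python
from typing import List, Tuple


def select_best_epoch(valid_scores: List[float], patience: int,
                      lower_is_better: bool) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch_run) using the same rule as the loop in main()."""
    best_epoch, best_valid = 0, None
    last_epoch = -1
    for epoch, score in enumerate(valid_scores):
        last_epoch = epoch
        if epoch == 0:
            improved = True  # the first epoch always initialises the best result
        elif lower_is_better:
            improved = score < best_valid
        else:
            improved = score > best_valid
        if improved:
            best_epoch, best_valid = epoch, score
        elif epoch - best_epoch > patience:
            break  # no improvement for more than `patience` epochs
    return best_epoch, last_epoch


if __name__ == "__main__":
    print(select_best_epoch([0.70, 0.75, 0.74, 0.73, 0.72, 0.80], patience=2, lower_is_better=False))
```

Consider validation AUCs `[0.70, 0.75, 0.74, 0.73, 0.72, 0.80]` with `patience=2`. The loop stops after epoch 4 and keeps epoch 1, so the later 0.80 is never reached. A small patience saves epochs at the cost of missing late improvements.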
256 |
257 | if __name__ == "__main__":
258 | import os
259 | import pandas as pd
260 |
261 | args = get_args()
262 | log_path = args.model_path.replace(".pt", ".log")
263 |
264 | pretrain_name = args.model_path.split("/")[-1]
265 | pretrain_name = pretrain_name.split(".")[0]
266 | args.pretrain_name = pretrain_name
267 |
268 | if args.dataset.startswith("finetune"):
269 | args.output_dir = f"output/{args.dataset}/{pretrain_name}"
270 | yaml_path = args.model_path.replace(".pt", ".yaml")
271 |
272 | # Check if args.model_path exists
273 | if not os.path.exists(args.model_path):
274 | from huggingface_hub import hf_hub_download
275 | os.makedirs(os.path.dirname(args.model_path), exist_ok=True)
276 | hf_hub_download(repo_id="liuganghuggingface/InfoAlign-Pretrained",
277 | filename="pretrain.pt",
278 | local_dir=os.path.dirname(args.model_path),
279 | local_dir_use_symlinks=False)
280 | config_path = hf_hub_download(repo_id="liuganghuggingface/InfoAlign-Pretrained",
281 | filename="config.yaml",
282 | local_dir=os.path.dirname(yaml_path),
283 | local_dir_use_symlinks=False)
284 |
285 | # Rename the downloaded config file to pretrain.yaml
286 | new_yaml_path = os.path.join(os.path.dirname(yaml_path), "pretrain.yaml")
287 | os.rename(config_path, new_yaml_path)
288 |
289 | print('args.model_path', args.model_path)
290 | print('args.yaml_path', yaml_path)
291 |
292 | config = load_arguments_from_yaml(yaml_path, model_only=True)
293 | for arg, value in config.items():
294 | setattr(args, arg, value)
295 | log_path = log_path + ".finetune"
296 | else:
297 | log_path = log_path + ".pretrain"
298 |
299 | # Define the repository ID and local directory
300 | repo_id = "liuganghuggingface/InfoAlign-Data"
301 | local_dir = "raw_data/pretrain/raw"
302 |
303 | # Check if the local directory exists
304 | if not os.path.exists(local_dir):
305 | from huggingface_hub import hf_hub_download, HfApi
306 | import os
307 |
308 | os.makedirs(local_dir, exist_ok=True)
309 |
310 | try:
311 | # Use HfApi to list files
312 | api = HfApi()
313 | all_files = api.list_repo_files(repo_id, repo_type="dataset")
314 |
315 | # Filter for files in the pretrain_raw folder
316 | pretrain_raw_files = [f for f in all_files if f.startswith("pretrain_raw/")]
317 |
318 | # Download each file
319 | for file in pretrain_raw_files:
320 | # Extract the filename from the path
321 | filename = os.path.basename(file)
322 |
323 | hf_hub_download(repo_id=repo_id,
324 | filename=file,
325 | repo_type="dataset",
326 | local_dir=local_dir,
327 | local_dir_use_symlinks=False)
328 |
329 | # Rename the file to remove the 'pretrain_raw/' prefix
330 | old_path = os.path.join(local_dir, file)
331 | new_path = os.path.join(local_dir, filename)
332 | os.rename(old_path, new_path)
333 |
334 | print(f"Downloaded {len(pretrain_raw_files)} files to {local_dir}")
335 | except Exception as e:
336 | print(f"Error downloading dataset: {str(e)}")
337 | print("Please check your internet connection and ensure you have the necessary permissions.")
338 | print("If the issue persists, you may need to log in using `huggingface-cli login`")
339 | else:
340 | print(f"Directory {local_dir} already exists. Skipping download.")
341 |
342 | # logger = get_logger(__name__, logfile=log_path)
343 | logger = get_logger(__name__)
344 | args.logger = logger
345 | print(vars(args))
346 |
347 | if args.dataset.startswith("pretrain"):
348 | main(args, 0)
349 | else:
350 | df = pd.DataFrame()
351 | for i in range(5):
352 | model, dataset, metric, train, valid, test, epoch, count = main(args, i)
353 | if "auc" in metric:
354 | new_results = {
355 | "model": model,
356 | "dataset": dataset,
357 | "seed": i,
358 | "metric": metric,
359 | "train": train,
360 | "valid": valid,
361 | "test": test,
362 | "epoch": epoch,
363 | "suc_80": round(np.nansum(count[80]) / len(count[80]), 4),
364 | "suc_85": round(np.nansum(count[85]) / len(count[85]), 4),
365 | "suc_90": round(np.nansum(count[90]) / len(count[90]), 4),
366 | "suc_95": round(np.nansum(count[95]) / len(count[95]), 4),
367 | "thr_80": count[80],
368 | "thr_85": count[85],
369 | "thr_90": count[90],
370 | "thr_95": count[95],
371 | }
372 | else:
373 | mae_list = count
374 | new_results = {
375 | "model": model,
376 | "dataset": dataset,
377 | "seed": i,
378 | "metric": metric,
379 | "train": train,
380 | "valid": valid,
381 | "test": test,
382 | "epoch": epoch,
383 | "mae_1": mae_list[0],
384 | "mae_2": mae_list[1],
385 | "mae_3": mae_list[2],
386 | "mae_4": mae_list[3],
387 | "mae_5": mae_list[4],
388 | "mae_6": mae_list[5],
389 | }
390 | df = pd.concat([df, pd.DataFrame([new_results])], ignore_index=True)
391 |
392 | summary_each = f"output/{args.dataset}/summary_each.csv"
393 | if os.path.exists(summary_each):
394 | df.to_csv(summary_each, mode="a", header=False, index=False)
395 | else:
396 | df.to_csv(summary_each, index=False)
397 | print(df)
398 |
399 | # Calculate mean and std
400 | if "auc" in metric:
401 | cols = [
402 | "model",
403 | "dataset",
404 | "metric",
405 | "train",
406 | "valid",
407 | "test",
408 | "suc_80",
409 | "suc_85",
410 | "suc_90",
411 | "suc_95",
412 | ]
413 | else:
414 | cols = [
415 | "model",
416 | "dataset",
417 | "metric",
418 | "train",
419 | "valid",
420 | "test",
421 | "mae_1",
422 | "mae_2",
423 | "mae_3",
424 | "mae_4",
425 | "mae_5",
426 | "mae_6",
427 | ]
428 | df_mean = df[cols].groupby(["model", "dataset", "metric"]).mean().round(4)
429 | df_std = df[cols].groupby(["model", "dataset", "metric"]).std().round(4)
430 |
431 | df_mean = df_mean.reset_index()
432 | df_std = df_std.reset_index()
433 | df_summary = df_mean[["model", "dataset", "metric"]].copy()
434 | if "auc" in metric:
435 | for col in [
436 | "train",
437 | "valid",
438 | "test",
439 | "suc_80",
440 | "suc_85",
441 | "suc_90",
442 | "suc_95",
443 | ]:
444 | df_summary[col] = (
445 | df_mean[col].astype(str) + "±" + df_std[col].astype(str)
446 | )
447 | else:
448 | for col in [
449 | "train",
450 | "valid",
451 | "test",
452 | "mae_1",
453 | "mae_2",
454 | "mae_3",
455 | "mae_4",
456 | "mae_5",
457 | "mae_6",
458 | ]:
459 | df_summary[col] = (
460 | df_mean[col].astype(str) + "±" + df_std[col].astype(str)
461 | )
462 |
463 | summary_all = f"output/{args.dataset}/summary_all.csv"
464 | if os.path.exists(summary_all):
465 | df_summary.to_csv(summary_all, mode="a", header=False, index=False)
466 | else:
467 | df_summary.to_csv(summary_all, index=False)
468 | print(df_summary)
469 |
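After the five seeds, `main.py` reports each column as `mean±std`, rounded to four decimals. pandas' `DataFrame.std` uses the sample standard deviation (`ddof=1`). For AUC-style metrics, each `suc_*` value is computed as `np.nansum(count[t]) / len(count[t])`. NaN entries therefore count in the denominator but not the numerator. Below is a stdlib sketch of both computations. `success_rate` and `mean_pm_std` are hypothetical helpers.

```python
import math
import statistics
from typing import Sequence


def success_rate(values: Sequence[float]) -> float:
    """Sum of non-NaN entries over all entries, like np.nansum(v) / len(v)."""
    total = sum(v for v in values if not math.isnan(v))
    return round(total / len(values), 4)


def mean_pm_std(per_seed: Sequence[float]) -> str:
    """Format per-seed results as 'mean±std' using the sample standard deviation."""
    mean = round(statistics.mean(per_seed), 4)
    std = round(statistics.stdev(per_seed), 4)  # ddof=1, matching pandas .std()
    return f"{mean}±{std}"


if __name__ == "__main__":
    print(success_rate([1.0, 0.0, math.nan, 1.0]))  # 0.5
    print(mean_pm_std([0.80, 0.82, 0.84]))
```

`statistics.stdev` needs at least two values, which matches the five-seed loop. pandas would instead return NaN for a single seed.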
--------------------------------------------------------------------------------
/models/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.utils import degree
4 | from torch_geometric.nn.norm import GraphNorm, PairNorm, MessageNorm, DiffGroupNorm, InstanceNorm, LayerNorm, GraphSizeNorm
5 | from torch_geometric.nn import MessagePassing
6 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
7 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
8 |
9 | full_atom_feature_dims = get_atom_feature_dims()
10 | full_bond_feature_dims = get_bond_feature_dims()
11 |
12 | class GINConv(MessagePassing):
13 | def __init__(self, emb_dim):
14 | '''
15 | emb_dim (int): node embedding dimensionality
16 | '''
17 |
18 | super(GINConv, self).__init__(aggr = "add")
19 |
20 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
21 | self.eps = torch.nn.Parameter(torch.Tensor([0]))
22 |
23 | self.bond_encoder = BondEncoder(emb_dim = emb_dim)
24 |
25 | def forward(self, x, edge_index, edge_attr):
26 | edge_embedding = self.bond_encoder(edge_attr)
27 | out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
28 |
29 | return out
30 |
31 | def message(self, x_j, edge_attr):
32 | return F.relu(x_j + edge_attr)
33 |
34 | def update(self, aggr_out):
35 | return aggr_out
36 |
37 | class GCNConv(MessagePassing):
38 | def __init__(self, emb_dim):
39 | super(GCNConv, self).__init__(aggr='add')
40 |
41 | self.linear = torch.nn.Linear(emb_dim, emb_dim)
42 | self.root_emb = torch.nn.Embedding(1, emb_dim)
43 | self.bond_encoder = BondEncoder(emb_dim = emb_dim)
44 |
45 | def forward(self, x, edge_index, edge_attr):
46 | x = self.linear(x)
47 | edge_embedding = self.bond_encoder(edge_attr)
48 |
49 | row, col = edge_index
50 |
51 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
52 | deg = degree(row, x.size(0), dtype = x.dtype) + 1
53 | deg_inv_sqrt = deg.pow(-0.5)
54 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
55 |
56 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
57 |
58 | return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)
59 |
60 | def message(self, x_j, edge_attr, norm):
61 | return norm.view(-1, 1) * F.relu(x_j + edge_attr)
62 |
63 | def update(self, aggr_out):
64 | return aggr_out
65 |
66 |
67 | ### GNN to generate node embedding
68 | class GNN_node(torch.nn.Module):
69 | """
70 | Output:
71 | node representations
72 | """
73 | def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_name = 'gin', norm_layer = 'batch_norm'):
74 | '''
75 | emb_dim (int): node embedding dimensionality
76 | num_layer (int): number of GNN message passing layers
77 |
78 | '''
79 |
80 | super(GNN_node, self).__init__()
81 | self.num_layer = num_layer
82 | self.drop_ratio = drop_ratio
83 | self.JK = JK
84 | ### add residual connection or not
85 | self.residual = residual
86 | self.norm_layer = norm_layer
87 |
88 | if self.num_layer < 2:
89 | raise ValueError("Number of GNN layers must be greater than 1.")
90 |
91 | self.atom_encoder = AtomEncoder(emb_dim)
92 | self.bond_encoder = BondEncoder(emb_dim)
93 |
94 | ###List of GNNs
95 | self.convs = torch.nn.ModuleList()
96 | self.batch_norms = torch.nn.ModuleList()
97 |
98 | for layer in range(num_layer):
99 | if gnn_name == 'gin':
100 | self.convs.append(GINConv(emb_dim))
101 | elif gnn_name == 'gcn':
102 | self.convs.append(GCNConv(emb_dim))
103 | else:
104 | raise ValueError('Undefined GNN type called {}'.format(gnn_name))
105 |
106 | if norm_layer.split('_')[0] == 'batch':
107 | if norm_layer.split('_')[-1] == 'notrack':
108 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim, track_running_stats=False, affine=False))
109 | else:
110 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
111 | elif norm_layer.split('_')[0] == 'instance':
112 | self.batch_norms.append(InstanceNorm(emb_dim))
113 | elif norm_layer.split('_')[0] == 'layer':
114 | self.batch_norms.append(LayerNorm(emb_dim))
115 | elif norm_layer.split('_')[0] == 'graph':
116 | self.batch_norms.append(GraphNorm(emb_dim))
117 | elif norm_layer.split('_')[0] == 'size':
118 | self.batch_norms.append(GraphSizeNorm())
119 | elif norm_layer.split('_')[0] == 'pair':
120 | self.batch_norms.append(PairNorm())  # PairNorm's first argument is `scale`, not the feature dimension
121 | elif norm_layer.split('_')[0] == 'group':
122 | self.batch_norms.append(DiffGroupNorm(emb_dim, groups=4))
123 | else:
124 | raise ValueError('Undefined normalization layer called {}'.format(norm_layer))
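125 | if norm_layer.split('_')[1:2] == ['size']:  # optional extra GraphSizeNorm, e.g. 'batch_size'; slicing avoids IndexError for names without '_'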
126 | self.graph_size_norm = GraphSizeNorm()
127 |
128 | def forward(self, batched_data):
129 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
130 |
131 | h_list = [self.atom_encoder(x)]
132 | for layer in range(self.num_layer):
133 |
134 | h = self.convs[layer](h_list[layer], edge_index, edge_attr)
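135 | if self.norm_layer.split('_')[0] in ('batch', 'group'):
136 | h = self.batch_norms[layer](h)  # BatchNorm1d and DiffGroupNorm take only the node features
137 | else:
138 | h = self.batch_norms[layer](h, batch)  # graph-aware norms also need the batch vector
139 | if self.norm_layer.split('_')[1:2] == ['size']:
140 | h = self.graph_size_norm(h, batch)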
141 |
142 | if layer == self.num_layer - 1:
143 | #remove relu for the last layer
144 | h = F.dropout(h, self.drop_ratio, training = self.training)
145 | else:
146 | h = F.relu(h)
147 | h = F.dropout(h, self.drop_ratio, training = self.training)
148 | if self.residual:
149 | h = h + h_list[layer]
150 |
151 | h_list.append(h)
152 |
153 | ### Jumping-knowledge (JK) readout: "last" keeps the final layer, "sum" adds all layer outputs
154 | if self.JK == "last":
155 | node_representation = h_list[-1]
156 | elif self.JK == "sum":
157 | node_representation = 0
158 | for layer in range(self.num_layer + 1):
159 | node_representation += h_list[layer]
160 |
161 | return node_representation, h_list
162 |
163 |
164 | ### Virtual GNN to generate node embedding
165 | class GNN_node_Virtualnode(torch.nn.Module):
166 | """
167 | Output:
168 | node representations
169 | """
170 | def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_name = 'gin', norm_layer = 'batch_norm'):
171 | '''
172 | emb_dim (int): node embedding dimensionality
173 | '''
174 |
175 | super(GNN_node_Virtualnode, self).__init__()
176 | self.num_layer = num_layer
177 | self.drop_ratio = drop_ratio
178 | self.JK = JK
179 | ### add residual connection or not
180 | self.residual = residual
181 | self.norm_layer = norm_layer
182 |
183 | if self.num_layer < 2:
184 | raise ValueError("Number of GNN layers must be greater than 1.")
185 |
186 | self.atom_encoder = AtomEncoder(emb_dim)
187 | self.bond_encoder = BondEncoder(emb_dim)
188 |
189 | ### set the initial virtual node embedding to 0.
190 | self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
191 | torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
192 |
193 | ### List of GNNs
194 | self.convs = torch.nn.ModuleList()
195 | ### batch norms applied to node embeddings
196 | self.batch_norms = torch.nn.ModuleList()
197 |
198 | ### List of MLPs to transform virtual node at every layer
199 | self.mlp_virtualnode_list = torch.nn.ModuleList()
200 |
201 | for layer in range(num_layer):
202 | if gnn_name == 'gin':
203 | self.convs.append(GINConv(emb_dim))
204 | elif gnn_name == 'gcn':
205 | self.convs.append(GCNConv(emb_dim))
206 | else:
207 | raise ValueError('Undefined GNN type called {}'.format(gnn_name))
208 |
209 | if norm_layer.split('_')[0] == 'batch':
210 | if norm_layer.split('_')[-1] == 'notrack':
211 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim, track_running_stats=False, affine=False))
212 | else:
213 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
214 | elif norm_layer.split('_')[0] == 'instance':
215 | self.batch_norms.append(InstanceNorm(emb_dim))
216 | elif norm_layer.split('_')[0] == 'layer':
217 | self.batch_norms.append(LayerNorm(emb_dim))
218 | elif norm_layer.split('_')[0] == 'graph':
219 | self.batch_norms.append(GraphNorm(emb_dim))
220 | elif norm_layer.split('_')[0] == 'size':
221 | self.batch_norms.append(GraphSizeNorm())
222 | elif norm_layer.split('_')[0] == 'pair':
223 | self.batch_norms.append(PairNorm())  # PairNorm's first argument is `scale`, not the feature dimension
224 | elif norm_layer.split('_')[0] == 'group':
225 | self.batch_norms.append(DiffGroupNorm(emb_dim, groups=4))
226 | else:
227 | raise ValueError('Undefined normalization layer called {}'.format(norm_layer))
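228 | if norm_layer.split('_')[1:2] == ['size']:  # optional extra GraphSizeNorm, e.g. 'batch_size'; slicing avoids IndexError for names without '_'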
229 | self.graph_size_norm = GraphSizeNorm()
230 | for layer in range(num_layer - 1):
231 | self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \
232 | torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU()))
233 |
234 |
235 | def forward(self, batched_data):
236 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
237 | ### virtual node embeddings for graphs
238 | virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1, dtype=torch.long, device=edge_index.device))
239 |
240 | h_list = [self.atom_encoder(x)]
241 | for layer in range(self.num_layer):
242 | ### add message from virtual nodes to graph nodes
243 | h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
244 | ### Message passing among graph nodes
245 | h = self.convs[layer](h_list[layer], edge_index, edge_attr)
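246 | if self.norm_layer.split('_')[0] in ('batch', 'group'):
247 | h = self.batch_norms[layer](h)  # BatchNorm1d and DiffGroupNorm take only the node features
248 | else:
249 | h = self.batch_norms[layer](h, batch)  # graph-aware norms also need the batch vector
250 | if self.norm_layer.split('_')[1:2] == ['size']:
251 | h = self.graph_size_norm(h, batch)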
252 |
253 | if layer == self.num_layer - 1:
254 | h = F.dropout(h, self.drop_ratio, training = self.training)
255 | else:
256 | h = F.relu(h)
257 | h = F.dropout(h, self.drop_ratio, training = self.training)
258 |
259 | if self.residual:
260 | h = h + h_list[layer]
261 |
262 | h_list.append(h)
263 |
264 | ### update the virtual nodes
265 | if layer < self.num_layer - 1:
266 | ### add message from graph nodes to virtual nodes
267 | virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
268 | if self.residual:
269 | virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
270 | else:
271 | virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
272 |
273 | ### Jumping-knowledge (JK) readout: "last" keeps the final layer, "sum" adds all layer outputs
274 | if self.JK == "last":
275 | node_representation = h_list[-1]
276 | elif self.JK == "sum":
277 | node_representation = 0
278 | for layer in range(self.num_layer + 1):
279 | node_representation += h_list[layer]
280 |
281 | return node_representation, h_list
282 |
283 |
284 | class AtomEncoder(torch.nn.Module):
285 |
286 | def __init__(self, emb_dim):
287 | super(AtomEncoder, self).__init__()
288 |
289 | self.atom_embedding_list = torch.nn.ModuleList()
290 |
291 | for i, dim in enumerate(full_atom_feature_dims):
292 | emb = torch.nn.Embedding(dim, emb_dim, max_norm=1)
293 | torch.nn.init.xavier_uniform_(emb.weight.data)
294 | self.atom_embedding_list.append(emb)
295 |
296 | def forward(self, x):
297 | x_embedding = 0
298 | for i in range(x.shape[1]):
299 | x_embedding += self.atom_embedding_list[i](x[:,i])
300 |
301 | return x_embedding
302 |
303 |
304 | class BondEncoder(torch.nn.Module):
305 |
306 | def __init__(self, emb_dim):
307 | super(BondEncoder, self).__init__()
308 |
309 | self.bond_embedding_list = torch.nn.ModuleList()
310 |
311 | for i, dim in enumerate(full_bond_feature_dims):
312 | emb = torch.nn.Embedding(dim, emb_dim, max_norm=1)
313 | torch.nn.init.xavier_uniform_(emb.weight.data)
314 | self.bond_embedding_list.append(emb)
315 |
316 | def forward(self, edge_attr):
317 | bond_embedding = 0
318 | for i in range(edge_attr.shape[1]):
319 | bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])
320 |
321 | return bond_embedding
--------------------------------------------------------------------------------
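In `GCNConv.forward` above, each edge message `relu(x_j + edge_attr)` is scaled by `deg[row] ** -0.5 * deg[col] ** -0.5`, where the degree counts edges by source node plus one for an implicit self-loop, and a root term `relu(x + root_emb.weight) / deg` is added afterwards. The sketch below retraces that arithmetic with one float per node so the numbers can be followed by hand. It is an illustration under simplifying assumptions, not a replacement for the layer: it drops the learned `self.linear` projection, uses a scalar in place of each bond embedding, and the names `gcn_step` and `_relu` are hypothetical. It assumes PyG's default `source_to_target` flow, where messages come from `edge_index[0]` and are summed at `edge_index[1]`.

```python
import math


def _relu(v):
    return max(v, 0.0)


def gcn_step(x, edges, edge_attr, root):
    """Scalar walk-through of the GCNConv.forward arithmetic (no self.linear).

    x: one float feature per node
    edges: (source, target) pairs, i.e. the columns of edge_index
    edge_attr: one float per edge, standing in for the bond embedding
    root: float standing in for root_emb.weight
    """
    n = len(x)
    # degree over the source row, +1 for the implicit self-loop
    deg = [1.0] * n
    for src, _ in edges:
        deg[src] += 1.0
    out = [0.0] * n
    for (src, dst), e in zip(edges, edge_attr):
        # deg_inv_sqrt[row] * deg_inv_sqrt[col]
        norm = 1.0 / math.sqrt(deg[src] * deg[dst])
        # "add" aggregation: the message from the source lands on the target
        out[dst] += norm * _relu(x[src] + e)
    # root term: relu(x + root_emb) scaled by 1/deg
    return [out[i] + _relu(x[i] + root) / deg[i] for i in range(n)]


# two nodes joined by an undirected edge (stored in both directions)
print(gcn_step([1.0, 2.0], [(0, 1), (1, 0)], [0.0, 0.0], 0.0))  # [1.5, 1.5]
```

In the example each node has degree 2 (one neighbour plus the self-loop), so every edge weight is 1/2 and the root term is halved as well.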
/models/gnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
5 | from .conv import GNN_node, GNN_node_Virtualnode
6 |
7 | from torch.distributions import Normal, Independent
8 | from torch.nn.functional import softplus
9 |
10 | class GNN(torch.nn.Module):
11 | def __init__(
12 | self,
13 | num_tasks=None, # to remove
14 | num_layer=5,
15 | emb_dim=300,
16 | gnn_type="gin",
17 | drop_ratio=0.5,
18 | graph_pooling="max",
19 | norm_layer="batch_norm",
20 | decoder_dims=[1024, 1111, 862, 1783, 966, 978],
21 | # mol, gene (gc, go), cell (bray, jump), express
22 | ):
23 | super(GNN, self).__init__()
24 | self.num_layer = num_layer
25 | self.drop_ratio = drop_ratio
26 | self.graph_pooling = graph_pooling
27 | if self.num_layer < 2:
28 | raise ValueError("Number of GNN layers must be greater than 1.")
29 | ### GNN to generate node embeddings
30 | gnn_name = gnn_type.split("-")[0]
31 | if "virtual" in gnn_type:
32 | self.graph_encoder = GNN_node_Virtualnode(
33 | num_layer,
34 | emb_dim,
35 | JK="last",
36 | drop_ratio=drop_ratio,
37 | residual=True,
38 | gnn_name=gnn_name,
39 | norm_layer=norm_layer,
40 | )
41 | else:
42 | self.graph_encoder = GNN_node(
43 | num_layer,
44 | emb_dim,
45 | JK="last",
46 | drop_ratio=drop_ratio,
47 | residual=True,
48 | gnn_name=gnn_name,
49 | norm_layer=norm_layer,
50 | )
51 |
52 | ### Pooling function to generate whole-graph embeddings
53 | if graph_pooling == "sum":
54 | self.pool = global_add_pool
55 | elif graph_pooling == "mean":
56 | self.pool = global_mean_pool
57 | elif graph_pooling == "max":
58 | self.pool = global_max_pool
59 | else:
60 | raise ValueError("Invalid graph pooling type.")
61 |
62 | self.dist_net = nn.Sequential(
63 | nn.SiLU(),
64 | nn.Linear(emb_dim, 2 * emb_dim, bias=True)
65 | )
66 |
67 | self.decoder_list = nn.ModuleList()
68 | for out_dim in decoder_dims:
69 | self.decoder_list.append(MLP(emb_dim, hidden_features=emb_dim * 4, out_features=out_dim))
70 |
71 |
72 | def forward(self, batched_data):
73 | h_node, _ = self.graph_encoder(batched_data)
74 | h_graph = self.pool(h_node, batched_data.batch)
75 |
76 | mu, sigma = self.dist_net(h_graph).chunk(2, dim=1)
77 | sigma = softplus(sigma) + 1e-7
78 | p_z_given_x = Independent(Normal(loc=mu, scale=sigma), 1)
79 |
80 | out = []
81 | out.append(p_z_given_x)
82 | # p, mol, gene (gc, go), cell (bray, jump), express
83 | for decoder in self.decoder_list:
84 | out.append(decoder(mu))
85 | out_gene = torch.cat((out[2], out[3]), dim=1)
86 | out_cell = torch.cat((out[4], out[5]), dim=1)
87 | return [out[0], out[1], out_gene, out_cell, out[6]]
88 |
89 | # Fine-tuning model: same graph encoder, pooling and dist_net as GNN, plus a new task-specific MLP head
90 |
91 | class FineTuneGNN(nn.Module):
92 | def __init__(
93 | self,
94 | num_tasks=None,
95 | num_layer=5,
96 | emb_dim=300,
97 | gnn_type="gin",
98 | drop_ratio=0.5,
99 | graph_pooling="max",
100 | norm_layer="batch_norm",
101 | ):
102 | super(FineTuneGNN, self).__init__()
103 |
104 | ### GNN to generate node embeddings
105 | gnn_name = gnn_type.split("-")[0]
106 | if "virtual" in gnn_type:
107 | self.graph_encoder = GNN_node_Virtualnode(
108 | num_layer,
109 | emb_dim,
110 | JK="last",
111 | drop_ratio=drop_ratio,
112 | residual=True,
113 | gnn_name=gnn_name,
114 | norm_layer=norm_layer,
115 | )
116 | else:
117 | self.graph_encoder = GNN_node(
118 | num_layer,
119 | emb_dim,
120 | JK="last",
121 | drop_ratio=drop_ratio,
122 | residual=True,
123 | gnn_name=gnn_name,
124 | norm_layer=norm_layer,
125 | )
126 | ### Pooling function to generate whole-graph embeddings
127 | if graph_pooling == "sum":
128 | self.pool = global_add_pool
129 | elif graph_pooling == "mean":
130 | self.pool = global_mean_pool
131 | elif graph_pooling == "max":
132 | self.pool = global_max_pool
133 | else:
134 | raise ValueError("Invalid graph pooling type.")
135 |
136 | self.dist_net = nn.Sequential(
137 | nn.SiLU(),
138 | nn.Linear(emb_dim, 2 * emb_dim, bias=True)
139 | )
140 |
141 | # self.task_decoder = nn.Linear(emb_dim, num_tasks)
142 | self.task_decoder = MLP(emb_dim, hidden_features=4 * emb_dim, out_features=num_tasks)
143 |
144 | def forward(self, batched_data):
145 | h_node, _ = self.graph_encoder(batched_data)
146 | h_graph = self.pool(h_node, batched_data.batch)
147 |
148 | mu, _ = self.dist_net(h_graph).chunk(2, dim=1)
149 | task_out = self.task_decoder(mu)
150 | return task_out
151 |
152 | def load_pretrained_graph_encoder(self, model_path):
153 | saved_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
154 | graph_encoder_state_dict = {key: value for key, value in saved_state_dict.items() if key.startswith('graph_encoder.')}
155 | graph_encoder_state_dict = {key.replace('graph_encoder.', ''): value for key, value in graph_encoder_state_dict.items()}
156 | self.graph_encoder.load_state_dict(graph_encoder_state_dict)
157 | # Load dist_net state dictionary
158 | dist_net_state_dict = {key: value for key, value in saved_state_dict.items() if key.startswith('dist_net.')}
159 | dist_net_state_dict = {key.replace('dist_net.', ''): value for key, value in dist_net_state_dict.items()}
160 | self.dist_net.load_state_dict(dist_net_state_dict)
161 | self.freeze_graph_encoder()  # after loading, only task_decoder stays trainable
162 |
163 | def freeze_graph_encoder(self):
164 | for param in self.graph_encoder.parameters():
165 | param.requires_grad = False
166 | for param in self.dist_net.parameters():
167 | param.requires_grad = False
168 |
169 | class MLP(nn.Module):
170 | def __init__(
171 | self,
172 | in_features,
173 | hidden_features=None,
174 | out_features=None,
175 | act_layer=nn.GELU,
176 | bias=True,
177 | drop=0.5,
178 | ):
179 | super().__init__()
180 | out_features = out_features or in_features
181 | hidden_features = hidden_features or in_features
182 | linear_layer = nn.Linear
183 |
184 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias)
185 | self.bn1 = nn.BatchNorm1d(hidden_features)
186 | self.act = act_layer()
187 | self.drop1 = nn.Dropout(drop)
188 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias)
189 | self.drop2 = nn.Dropout(drop)
190 |
191 | def forward(self, x):
192 | x = self.fc1(x)
193 | x = self.bn1(x)
194 | x = self.act(x)
195 | x = self.drop1(x)
196 | x = self.fc2(x)
197 | # x = self.drop2(x)
198 | return x
--------------------------------------------------------------------------------
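`GNN.forward` and `FineTuneGNN.forward` above split the `dist_net` output into `mu` and a raw scale, pass the scale through `softplus` (plus `1e-7`) to keep it positive, and wrap both in `Independent(Normal(mu, sigma), 1)`; during pretraining, `prior_criterion` in `utils/train_funcs.py` penalizes a one-sample estimate of `log q(z|x) - log p(z)`. The plain-Python sketch below mirrors that arithmetic for a single graph. It assumes the prior `args.prior_dist` is a standard normal N(0, I), which is an assumption because its construction is not shown in this excerpt, and the helper names are hypothetical.

```python
import math
import random


def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def gaussian_head(h):
    """Split a 2*d output into (mu, sigma), like .chunk(2, dim=1) plus softplus in GNN.forward."""
    d = len(h) // 2
    mu = h[:d]
    sigma = [softplus(r) + 1e-7 for r in h[d:]]
    return mu, sigma


def diag_normal_log_prob(z, mu, sigma):
    """Log density of a diagonal Gaussian summed over dimensions,
    i.e. what Independent(Normal(mu, sigma), 1).log_prob returns."""
    return sum(
        -0.5 * ((zi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for zi, m, s in zip(z, mu, sigma)
    )


def single_sample_kl(mu, sigma, rng):
    """One reparameterized sample of log q(z|x) - log p(z), with p = N(0, I) (assumed prior)."""
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]  # z = mu + sigma * eps
    zeros, ones = [0.0] * len(z), [1.0] * len(z)
    return diag_normal_log_prob(z, mu, sigma) - diag_normal_log_prob(z, zeros, ones)


mu, sigma = gaussian_head([0.5, -0.5, 0.0, 0.0])
print(mu, sigma)  # sigma entries are softplus(0) + 1e-7, about 0.6931
print(single_sample_kl(mu, sigma, random.Random(0)))
```

Averaged over many draws, the estimate approaches the closed-form KL, which for sigma = 1 reduces to sum(mu**2) / 2. Note that the decoders in `GNN.forward` read `mu` directly rather than a sample.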
/raw_data/biogenadme/raw/assays.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/biogenadme/raw/assays.csv.gz
--------------------------------------------------------------------------------
/raw_data/broad6k/raw/CP-Bray.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/CP-Bray.csv.gz
--------------------------------------------------------------------------------
/raw_data/broad6k/raw/GE.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/GE.csv.gz
--------------------------------------------------------------------------------
/raw_data/broad6k/raw/assays.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/assays.csv.gz
--------------------------------------------------------------------------------
/raw_data/broad6k/raw/structure.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/structure.csv.gz
--------------------------------------------------------------------------------
/raw_data/chembl2k/raw/CP-JUMP.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/CP-JUMP.csv.gz
--------------------------------------------------------------------------------
/raw_data/chembl2k/raw/GE.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/GE.csv.gz
--------------------------------------------------------------------------------
/raw_data/chembl2k/raw/assays.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/assays.csv.gz
--------------------------------------------------------------------------------
/raw_data/chembl2k/raw/structure.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/structure.csv.gz
--------------------------------------------------------------------------------
/raw_data/moltoxcast/raw/assays.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/moltoxcast/raw/assays.csv.gz
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Install PyTorch with CUDA 11.8
2 | -f https://download.pytorch.org/whl/cu118/torch_stable.html
3 | torch==2.2.0+cu118
4 |
5 | # Install PyTorch Geometric and related packages
6 | -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
7 | torch_geometric==2.6.1
8 | torch_cluster==1.6.3
9 |
10 | # Other dependencies
11 | huggingface_hub==0.22.2
12 | joblib==1.3.2
13 | networkx==3.2.1
14 | ogb==1.3.6
15 | pandas==2.2.3
16 | PyYAML==6.0.2
17 | rdkit==2023.9.5
18 | scikit_learn==1.4.1.post1
19 | scipy==1.14.1
20 | tqdm==4.66.2
21 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 |
4 | import torch
5 | import numpy as np
6 | from sklearn.metrics import (
7 | mean_absolute_error,
8 | roc_auc_score,
9 | )
10 | __all__ = ["validate", "init_weights", "AverageMeter", "save_prediction"]
11 |
12 | def eval_func(pred, true, reduction=True):
13 | unique_values = np.unique(true[~np.isnan(true)])
14 | task_type = 'classification' if set(unique_values).issubset({0, 1}) else 'regression'
15 |
16 | if task_type == 'classification':
17 | per_task = []  # one entry per task, NaN where ROC-AUC is undefined
18 | count_dict = {80: [], 85: [], 90: [], 95: []}
19 | for i in range(true.shape[1]):
20 | if np.sum(true[:, i] == 1) > 0 and np.sum(true[:, i] == 0) > 0:
21 | is_labeled = ~np.isnan(true[:, i])  # ignore missing (NaN) labels
22 | score = roc_auc_score(true[is_labeled, i], pred[is_labeled, i])
23 | for threshold in count_dict.keys():
24 | count_dict[threshold].append(int(score >= threshold / 100))
25 | else:  # ROC-AUC needs both classes; keep a NaN placeholder so per-task results stay aligned with assay names
26 | score = np.nan
27 | for threshold in count_dict.keys():
28 | count_dict[threshold].append(np.nan)
29 | per_task.append(score)
30 | rocauc_list = [s for s in per_task if not np.isnan(s)]
31 | if len(rocauc_list) == 0:
32 | raise RuntimeError(
33 | "No positively labeled data available. Cannot compute ROC-AUC.")
34 | return {"roc_auc": sum(rocauc_list) / len(rocauc_list), "count": count_dict} if reduction else {"roc_auc": per_task}
35 |
36 | elif task_type == 'regression':
37 | mae_list = []
38 | for i in range(true.shape[1]):
39 | is_labeled = ~np.isnan(true[:, i])
40 | mae_score = mean_absolute_error(true[is_labeled, i], pred[is_labeled, i])
41 | mae_list.append(mae_score)
42 | return {"avg_mae": np.mean(mae_list), "mae_list": mae_list}
43 |
44 | def save_prediction(model, device, loader, dataset, output_dir, seed):
45 | y_true = []
46 | y_pred = []
47 | model.eval()
48 | for step, batch in enumerate(loader):
49 | batch = batch.to(device)
50 | if batch.x.shape[0] == 1:
51 | pass
52 | else:
53 | with torch.no_grad():
54 | pred = model(batch)
55 | y_true.append(batch.y.view(pred.shape).detach().cpu())
56 | y_pred.append(pred.detach().cpu())
57 | y_true = torch.cat(y_true, dim=0).numpy()
58 | y_pred = torch.cat(y_pred, dim=0).numpy()
59 |
60 | assay_path = f"raw_data/{dataset.name}/raw/assays.csv.gz"
61 | assay_df = pd.read_csv(assay_path, compression="gzip")
62 | assay_names = assay_df.columns[dataset.start_column :]
63 |
64 | y_pred[np.isnan(y_true)] = np.nan
65 | os.makedirs(output_dir, exist_ok=True)
66 |
67 | df_pred = pd.DataFrame(y_pred, columns=[name + "_pred" for name in assay_names])
68 | df_true = pd.DataFrame(y_true, columns=[name + "_true" for name in assay_names])
69 | df = pd.concat([df_pred, df_true], axis=1)
70 | df.to_csv(os.path.join(output_dir, f"preds-{seed}.csv"), index=False)
71 |
72 | returned_dict = eval_func(y_pred, y_true, reduction=False)
73 | results = returned_dict.get('roc_auc', None)
74 | if results is None:
75 | results = returned_dict.get('mae_list', None)
76 | if results is None:
77 | raise ValueError("Invalid task type")
78 | results_dict = dict(zip(assay_names, results))
79 | sorted_dict = dict(
80 | sorted(results_dict.items(), key=lambda item: item[1], reverse=True)
81 | )
82 | df_sorted = pd.DataFrame(
83 | list(sorted_dict.items()), columns=["Assay Name", "Result"]
84 | )
85 | df_sorted.to_csv(os.path.join(output_dir, f"result-{seed}.csv"), index=False)
86 |
87 | def validate(args, model, loader):
88 | y_true = []
89 | y_pred = []
90 | device = args.device
91 | model.eval()
92 | for step, batch in enumerate(loader):
93 | batch = batch.to(device)
94 | if batch.x.shape[0] == 1:
95 | pass
96 | else:
97 | with torch.no_grad():
98 | pred = model(batch)
99 | y_true.append(batch.y.view(pred.shape).detach().cpu())
100 | y_pred.append(pred.detach().cpu())
101 | y_true = torch.cat(y_true, dim=0).numpy()
102 | y_pred = torch.cat(y_pred, dim=0).numpy()
103 |
104 | return eval_func(y_pred, y_true)
105 |
106 |
107 | def init_weights(net, init_type="normal", init_gain=0.02):
108 | """Initialize network weights.
109 | Parameters:
110 | net (network) -- network to be initialized
111 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal | default (keep PyTorch's defaults)
112 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
113 | """
114 |
115 | def init_func(m): # define the initialization function
116 | classname = m.__class__.__name__
117 | if hasattr(m, "weight") and (
118 | classname.find("Conv") != -1 or classname.find("Linear") != -1
119 | ):
120 | if init_type == "normal":
121 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
122 | elif init_type == "xavier":
123 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
124 | elif init_type == "kaiming":
125 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
126 | elif init_type == "orthogonal":
127 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
128 | elif init_type == "default":
129 | pass
130 | else:
131 | raise NotImplementedError(
132 | "initialization method [%s] is not implemented" % init_type
133 | )
134 | if hasattr(m, "bias") and m.bias is not None:
135 | torch.nn.init.constant_(m.bias.data, 0.0)
136 | elif (
137 | classname.find("BatchNorm2d") != -1
138 | ): # only BatchNorm2d matches, so the BatchNorm1d layers used here keep PyTorch's defaults; a BN weight is a vector, so only a normal init applies
139 | torch.nn.init.normal_(m.weight.data, 1.0, init_gain)
140 | torch.nn.init.constant_(m.bias.data, 0.0)
141 |
142 | print("initialize network with %s" % init_type)
143 | net.apply(init_func) # apply the initialization function
144 |
145 | class AverageMeter(object):
146 | """Computes and stores the average and current value
147 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
148 | """
149 |
150 | def __init__(self):
151 | self.reset()
152 |
153 | def reset(self):
154 | self.val = 0
155 | self.avg = 0
156 | self.sum = 0
157 | self.count = 0
158 |
159 | def update(self, val, n=1):
160 | self.val = val
161 | self.sum += val * n
162 | self.count += n
163 | self.avg = self.sum / self.count
164 |
165 |
166 | def log_base(base, x):
167 | return np.log(x) / np.log(base)
168 |
169 |
170 | def _eval_rocauc(y_true, y_pred):
171 | """
172 | compute ROC-AUC averaged across tasks
173 | """
174 | rocauc_list = []
175 | for i in range(y_true.shape[1]):
176 | # AUC is only defined when the task has at least one positive and one negative label.
177 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
178 | # ignore nan values
179 | is_labeled = y_true[:, i] == y_true[:, i]
180 | rocauc_list.append(
181 | roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i])
182 | )
183 |
184 | if len(rocauc_list) == 0:
185 | raise RuntimeError(
186 | "No positively labeled data available. Cannot compute ROC-AUC."
187 | )
188 | return {"rocauc": sum(rocauc_list) / len(rocauc_list)}
189 |
190 |
191 | if __name__ == "__main__":
192 | pass
193 |
--------------------------------------------------------------------------------
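`eval_func` above scores each task separately, drops NaN (missing) labels, skips tasks that lack either class, and averages the remaining ROC-AUC values. The dependency-free sketch below reproduces that per-task masking. It computes ROC-AUC as the fraction of (positive, negative) pairs ranked correctly, with ties counting 1/2 (the Mann-Whitney formulation), which gives the same value as `sklearn.metrics.roc_auc_score` for binary labels. The function names and the nested-list input are assumptions made to avoid a NumPy dependency.

```python
import math


def pairwise_auc(labels, scores):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def masked_mean_auc(true, pred):
    """true/pred: rows are samples, columns are tasks; NaN in true marks a missing label."""
    per_task = []
    for t in range(len(true[0])):
        # keep only the labeled rows for this task
        labeled = [(row[t], p[t]) for row, p in zip(true, pred) if not math.isnan(row[t])]
        ys = [y for y, _ in labeled]
        if 1 in ys and 0 in ys:  # AUC needs both classes
            per_task.append(pairwise_auc(ys, [s for _, s in labeled]))
    if not per_task:
        raise RuntimeError("No task has both positive and negative labels.")
    return sum(per_task) / len(per_task), per_task


nan = float("nan")
true = [[1, 0], [0, nan], [1, 1]]
pred = [[0.9, 0.2], [0.1, 0.7], [0.4, 0.1]]
print(masked_mean_auc(true, pred))  # (0.5, [1.0, 0.0])
```

In the example, the second task ignores the NaN row, leaving one negative (score 0.2) ranked above one positive (score 0.1), so its ROC-AUC is 0.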
/utils/train_funcs.py:
--------------------------------------------------------------------------------
1 | import time
2 | from tqdm import tqdm
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch_cluster import random_walk
7 |
8 | from .misc import AverageMeter
9 |
10 | cls_criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
11 | reg_criterion = torch.nn.L1Loss(reduction="none")
12 |
13 | def prior_criterion(args, pz_given_x):
14 | z = pz_given_x.rsample()
15 | loss = pz_given_x.log_prob(z) - args.prior_dist.log_prob(z)
16 | return loss.mean()
17 |
18 | def pretrain_loss_ce(pred, target, weight, nan_mask):
19 | expanded_weight = weight.unsqueeze(2).expand_as(target)
20 | loss = F.binary_cross_entropy_with_logits(
21 | pred[~nan_mask],
22 | target[~nan_mask],
23 | weight=expanded_weight[~nan_mask],
24 | reduction="none",
25 | )
26 | return loss.mean()
27 |
28 | def pretrain_func(
29 | args, model, train_loaders, context_graph, optimizer, scheduler, epoch
30 | ):
31 |
32 | criterion = pretrain_loss_ce
33 | if not args.no_print:
34 | p_bar = tqdm(range(args.steps))
35 | batch_time = AverageMeter()
36 | (losses_tot, losses_prior, losses_mol, losses_gene, losses_cell, losses_exp) = (
37 | AverageMeter(),
38 | AverageMeter(),
39 | AverageMeter(),
40 | AverageMeter(),
41 | AverageMeter(),
42 | AverageMeter(),
43 | )
44 | device = args.device
45 | model.train()
46 |
47 | context_node_type = context_graph.type
48 | mol_target = context_graph.mol_target
49 | gene_target = context_graph.gene_target
50 | cell_target = context_graph.cell_target
51 | express_target = context_graph.express_target
52 | edge_weight = context_graph.weight
53 | edge_weight = torch.cat(  # trailing 0 so an edge index of -1 (walk could not continue) maps to weight 0
54 | [
55 | edge_weight,
56 | torch.zeros(1, dtype=edge_weight.dtype, device=edge_weight.device),
57 | ]
58 | )
59 | index_map = {value: index for index, value in enumerate(context_node_type)}
60 |
61 | for batch_idx in range(args.steps):
62 | end = time.time()
63 | model.zero_grad()
64 | try:
65 | batched_data = next(train_loaders["train_iter"])
66 | except (StopIteration, KeyError):  # iterator exhausted or not created yet
67 | train_loaders["train_iter"] = iter(train_loaders["train_loader"])
68 | batched_data = next(train_loaders["train_iter"])
69 | batched_data = batched_data.to(device)
70 |
71 | start_indices = [index_map[value] for value in batched_data.type]
72 | start_indices = torch.tensor(start_indices).long()
73 | start_indices = start_indices.view(-1, 1).repeat(1, 1).view(-1)
74 | batched_walk, batched_edge_seq = random_walk(
75 | context_graph.edge_index[0],
76 | context_graph.edge_index[1],
77 | start_indices,
78 | args.walk_length,
79 | num_nodes=context_graph.num_nodes,
80 | return_edge_indices=True,
81 | )
82 |
83 | # batched_walk = batched_walk[:, 1:]
84 | batched_path_weight = (
85 | edge_weight[batched_edge_seq]
86 | .view(-1, args.walk_length)
87 | .cumprod(dim=-1)
88 | .to(device)
89 | )
90 | ## prepend weight 1 for the start node itself (walk step 0)
91 | batched_path_weight = torch.cat(
92 | [
93 | torch.ones(
94 | batched_path_weight.size(0), 1, device=batched_path_weight.device
95 | ),
96 | batched_path_weight,
97 | ],
98 | dim=1,
99 | )
100 |
101 | batched_mol_target = mol_target[batched_walk].to(device)
102 | batched_gene_target = gene_target[batched_walk].to(device)
103 | batched_cell_target = cell_target[batched_walk].to(device)
104 | batched_express_target = express_target[batched_walk].to(device)
105 |
106 | nan_mask_mol = torch.isnan(batched_mol_target)
107 | nan_mask_gene = torch.isnan(batched_gene_target)
108 | nan_mask_cell = torch.isnan(batched_cell_target)
109 | nan_mask_express = torch.isnan(batched_express_target)
110 |
111 | if batched_data.x.shape[0] == 1 or batched_data.batch[-1] == 0:
112 | continue
113 | else:
114 | pred_list = model(batched_data)
115 | pz_given_x, pred_mol, pred_gene, pred_cell, pred_express = pred_list
116 |
117 | pred_mol = pred_mol.unsqueeze(1).expand(-1, batched_mol_target.size(1), -1)
118 | pred_gene = pred_gene.unsqueeze(1).expand(
119 | -1, batched_gene_target.size(1), -1
120 | )
121 | pred_cell = pred_cell.unsqueeze(1).expand(
122 | -1, batched_cell_target.size(1), -1
123 | )
124 | pred_express = pred_express.unsqueeze(1).expand(
125 | -1, batched_express_target.size(1), -1
126 | )
127 |
128 | loss_prior = prior_criterion(args, pz_given_x)
129 |
130 | loss_mol = criterion(
131 | pred_mol, batched_mol_target, batched_path_weight, nan_mask_mol
132 | )
133 | loss_gene = criterion(
134 | pred_gene, batched_gene_target, batched_path_weight, nan_mask_gene
135 | )
136 | loss_cell = criterion(
137 | pred_cell, batched_cell_target, batched_path_weight, nan_mask_cell
138 | )
139 | loss_exp = criterion(
140 | pred_express,
141 | batched_express_target,
142 | batched_path_weight,
143 | nan_mask_express,
144 | )
145 | loss = args.prior * loss_prior + loss_mol + loss_gene + loss_cell + loss_exp
146 |
147 | loss.backward()
148 | optimizer.step()
149 | scheduler.step()
150 | losses_tot.update(loss.item())
151 | losses_prior.update(args.prior * loss_prior.item())
152 | losses_mol.update(loss_mol.item())
153 | losses_gene.update(loss_gene.item())
154 | losses_cell.update(loss_cell.item())
155 | losses_exp.update(loss_exp.item())
156 |
157 | batch_time.update(time.time() - end)
158 | end = time.time()
159 | if not args.no_print:
160 | log_message = "Train Epoch: {epoch}/{epochs:3}. Iter: {batch:2}/{iter:2}. LR: {lr:1}e-4. Batch: {bt:.1f}s. Loss (Total): {loss_tot:.2f}. Loss (prior): {loss_prior:.4f}. Loss (mol): {loss_mol:.4f}. Loss (gene): {loss_gene:.4f}. Loss (cell): {loss_cell:.4f}. Loss (express): {loss_exp:.4f}.".format(
161 | epoch=epoch + 1,
162 | epochs=args.epochs,
163 | batch=batch_idx + 1,
164 | iter=args.steps,
165 | lr=args.lr * 10000,
166 | bt=batch_time.avg,
167 | loss_tot=losses_tot.avg,
168 | loss_prior=losses_prior.avg,
169 | loss_mol=losses_mol.avg,
170 | loss_gene=losses_gene.avg,
171 | loss_cell=losses_cell.avg,
172 | loss_exp=losses_exp.avg,
173 | )
174 | p_bar.set_description(log_message)
175 | p_bar.update()
176 | if not args.no_print:
177 | p_bar.close()
178 | args.logger.info(log_message)
179 |
180 | return losses_tot.avg, train_loaders
181 |
182 |
183 | def finetune_func(args, model, train_loaders, optimizer, scheduler, epoch):
184 |     if args.task_type == "regression":
185 |         criterion = reg_criterion
186 |     else:
187 |         criterion = cls_criterion
188 |     if not args.no_print:
189 |         p_bar = tqdm(range(args.steps))
190 |     batch_time = AverageMeter()
191 |     data_time = AverageMeter()
192 |     losses = AverageMeter()
193 |     device = args.device
194 |     model.train()
195 |     for batch_idx in range(args.steps):
196 |         end = time.time()
197 |         model.zero_grad()
198 |         try:
199 |             batched_data = next(train_loaders["train_iter"])
200 |         except Exception:  # e.g. StopIteration once the loader is exhausted
201 |             train_loaders["train_iter"] = iter(train_loaders["train_loader"])
202 |             batched_data = next(train_loaders["train_iter"])
203 |         batched_data = batched_data.to(device)
204 |         targets = batched_data.y.to(torch.float32)
205 |         is_labeled = targets == targets  # False where the label is NaN (missing)
206 |         if batched_data.x.shape[0] == 1 or batched_data.batch[-1] == 0:
207 |             continue  # skip batches with a single node or a single graph
208 |         else:
209 |             preds = model(batched_data)
210 |             loss = criterion(
211 |                 preds.view(targets.size()).to(torch.float32)[is_labeled],
212 |                 targets[is_labeled],
213 |             ).mean()
214 |             loss.backward()
215 |             optimizer.step()
216 |             # scheduler.step()
217 |             losses.update(loss.item())
218 |             batch_time.update(time.time() - end)
219 |             end = time.time()
220 |         if not args.no_print:
221 |             p_bar.set_description(
222 |                 "Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.8f}. Batch: {bt:.3f}s. Loss: {loss:.4f}. ".format(
223 |                     epoch=epoch + 1,
224 |                     epochs=args.epochs,
225 |                     batch=batch_idx + 1,
226 |                     iter=args.steps,
227 |                     # lr=scheduler.get_last_lr()[0],
228 |                     lr=args.lr,
229 |                     bt=batch_time.avg,
230 |                     loss=losses.avg,
231 |                 )
232 |             )
233 |             p_bar.update()
234 |     if not args.no_print:
235 |         p_bar.close()
236 |
237 |     return train_loaders
238 |
--------------------------------------------------------------------------------
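In `pretrain_func`, each molecule's walk over the context graph becomes a row of per-position weights: the weights of the traversed edges are multiplied cumulatively, a 1.0 is prepended for the start node, and a trailing 0.0 is appended to `edge_weight` so that an edge index of -1 (no edge available) zeroes the rest of the walk. The pure-Python sketch below reproduces that arithmetic for a single walk. The name `walk_path_weights` is hypothetical, and reading -1 as "no edge" is assumed to match what `torch_cluster.random_walk(..., return_edge_indices=True)` returns.

```python
from typing import List


def walk_path_weights(edge_weight: List[float], edge_seq: List[int]) -> List[float]:
    """Cumulative path weight at each position of one random walk (sketch).

    edge_seq[i] is the index of the edge taken at step i; -1 is assumed to
    mean that no edge was available at that step.
    """
    # Append 0.0 so that index -1 selects it, as the training loop does.
    padded = list(edge_weight) + [0.0]
    weights = [1.0]  # the start node counts as position 0 with weight 1.0
    running = 1.0
    for e in edge_seq:
        running *= padded[e]  # padded[-1] == 0.0 for a missing edge
        weights.append(running)
    return weights
```

The result has `len(edge_seq) + 1` entries, matching the `walk_length + 1` node positions of each walk. In the training loop these weights are passed to `pretrain_loss_ce` together with the NaN masks; that loss is defined outside this excerpt.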
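`finetune_func` relies on NaN being the only floating-point value that is not equal to itself: `targets == targets` is False exactly where a label is missing, so only labeled entries reach the criterion before `.mean()`. A minimal sketch of the same masking in plain Python follows. Squared error stands in for the repository's `reg_criterion` (an assumption, since that function is defined elsewhere), and `masked_mse` is a hypothetical name.

```python
from typing import Sequence


def masked_mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error over the entries whose target is not NaN (sketch)."""
    # t == t is False only for NaN, i.e. for a missing label.
    labeled = [(p, t) for p, t in zip(preds, targets) if t == t]
    if not labeled:
        # The PyTorch loop would instead take .mean() of an empty tensor (NaN).
        raise ValueError("batch contains no labeled targets")
    return sum((p - t) ** 2 for p, t in labeled) / len(labeled)
```

For example, `masked_mse([1.0, 2.0, 3.0], [1.0, float("nan"), 5.0])` ignores the middle entry and returns `(0 + 4) / 2 = 2.0`.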
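Both training loops run a fixed number of steps (`args.steps`) per epoch and restart the data loader whenever its iterator runs out, so an "epoch" here is decoupled from the loader's length. The helper below isolates that pattern and catches only `StopIteration`. The name `next_batch` is hypothetical, and it additionally assumes `"train_iter"` may be absent or `None` before the first call, which may differ from how `main.py` initialises `train_loaders`.

```python
from typing import Any, Dict


def next_batch(loaders: Dict[str, Any]) -> Any:
    """Return the next batch, restarting the loader when it is exhausted.

    Uses the same keys as the training loops: "train_loader" is any
    re-iterable (e.g. a DataLoader) and "train_iter" is the live iterator.
    Assumption: "train_iter" may be missing or None before the first call.
    """
    it = loaders.get("train_iter")
    if it is not None:
        try:
            return next(it)
        except StopIteration:
            pass  # loader exhausted: fall through and start a new pass
    loaders["train_iter"] = iter(loaders["train_loader"])
    return next(loaders["train_iter"])
```

With `loaders = {"train_loader": [1, 2], "train_iter": None}`, five calls return `1, 2, 1, 2, 1`. Catching only `StopIteration` keeps genuine data-loading errors and `KeyboardInterrupt` from being silently turned into a loader restart.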