├── .gitignore ├── README.md ├── assets └── infoalign.png ├── ckpt └── pretrain.yaml ├── configures ├── arguments.py └── finetune.yaml ├── convert_to_dataset.py ├── dataset ├── __init__.py ├── context_graph.py ├── create_datasets.py ├── data_utils.py ├── prediction_molecule.py ├── pretrain_context.py ├── pretrain_molecule.py └── retrieval.py ├── main.py ├── models ├── conv.py └── gnn.py ├── raw_data ├── biogenadme │ └── raw │ │ ├── ADME_public_set_3521.csv │ │ └── assays.csv.gz ├── broad6k │ └── raw │ │ ├── CP-Bray.csv.gz │ │ ├── GE.csv.gz │ │ ├── assays.csv.gz │ │ └── structure.csv.gz ├── chembl2k │ └── raw │ │ ├── CP-JUMP.csv.gz │ │ ├── GE.csv.gz │ │ ├── assays.csv.gz │ │ └── structure.csv.gz └── moltoxcast │ └── raw │ └── assays.csv.gz ├── requirements.txt └── utils ├── __init__.py ├── misc.py └── train_funcs.py /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | 8 | # Virtualenv 9 | *.pyc 10 | *.pyo 11 | *.pyd 12 | .Python 13 | env/ 14 | venv/ 15 | .python-version 16 | pip-log.txt 17 | pip-delete-this-directory.txt 18 | htmlcov/ 19 | .tox/ 20 | .coverage 21 | .coverage.* 22 | .cache 23 | nosetests.xml 24 | coverage.xml 25 | *.cover 26 | .hypothesis/ 27 | .pytest_cache/ 28 | 29 | # Jupyter Notebook 30 | .ipynb_checkpoints 31 | 32 | # macOS specific 33 | .DS_Store 34 | 35 | # IDE specific 36 | .idea/ 37 | .vscode/ 38 | *.swp 39 | *.swo 40 | *~ 41 | 42 | 43 | ## just in case 44 | **/node_modules 45 | **/dist 46 | **/build 47 | **/coverage 48 | **/.cache 49 | **/.env 50 | **/.env.local 51 | **/.env.development.local 52 | **/.env.test.local 53 | **/.env.production.local 54 | **/instance 55 | **/.venv 56 | **/env 57 | **/venv 58 | **/.flaskenv 59 | **/*.pyc 60 | **/*.pyo 61 | **/*.pyd 62 | **/.Python 63 | **/env/ 64 | **/venv/ 65 | **/.python-version 66 | **/pip-log.txt 67 | **/pip-delete-this-directory.txt 68 | **/htmlcov/ 69 
| **/.tox/ 70 | **/.coverage 71 | **/.coverage.* 72 | **/.cache 73 | **/nosetests.xml 74 | **/coverage.xml 75 | **/*.cover 76 | **/.hypothesis/ 77 | **/.pytest_cache/ 78 | **/.ipynb_checkpoints 79 | **/.DS_Store 80 | **/.idea/ 81 | **/.vscode/ 82 | **/*.swp 83 | **/*.swo 84 | **/__pycache__/ 85 | 86 | # customized 87 | **/*.pt 88 | **/*.npz 89 | *.pt 90 | *.npz 91 | **/pretrain* 92 | **/*.zip 93 | requirements copy.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Molecular Representation in a Cell 2 | 3 | **InfoAlign** learns molecular representations from bottleneck information derived from molecular structures, cell morphology, and gene expressions. For more details, please refer to our [paper](https://arxiv.org/abs/2406.12056v3). 4 | 5 |

6 | InfoAlign 7 |

---
## Update on March 6, 2025:
- Added `convert_to_dataset.py`, which can be used to convert CSV files into a dataset for fine-tuning. For example:

```bash
python convert_to_dataset.py --task_name chembl2k_sub --csv_path raw_data/chembl2k/raw/assays.csv.gz --smiles_column smiles --property_columns ABCB1 ABL1
```

## Update on Oct 14, 2024:
- All packages can now be installed with `pip install -r requirements.txt`!
- We have **automated** the model and data download process for ML developers. The InfoAlign model can now be trained with a single command!
- We have created the `infoalign` package, which can be installed via `pip install infoalign`. For more details, refer to: [https://github.com/liugangcode/infoalign-package](https://github.com/liugangcode/infoalign-package).

---


## Requirements

This project was developed and tested with the following versions:

- **Python**: 3.11.7
- **PyTorch**: 2.2.0+cu118
- **Torch Geometric**: 2.6.1

All dependencies are listed in the `requirements.txt` file.

### Setup Instructions

1. **Create a Conda Environment**:
   ```bash
   conda create --name infoalign python=3.11.7
   ```

2. **Activate the Environment**:
   ```bash
   conda activate infoalign
   ```

3. **Install Dependencies**:
   ```bash
   pip install -r requirements.txt
   ```

## Usage

### Fine-tuning

We provide a pretrained checkpoint available for download from [Hugging Face](https://huggingface.co/liuganghuggingface/InfoAlign-Pretrained). For fine-tuning and inference, use the following commands. The pretrained model will be automatically downloaded to the `ckpt/pretrain.pt` file by default.
56 | 57 | ```bash 58 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-chembl2k 59 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-broad6k 60 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-biogenadme 61 | python main.py --model-path ckpt/pretrain.pt --dataset finetune-moltoxcast 62 | ``` 63 | 64 | Alternatively, you can manually download the model weights and place the `pretrain.pt` file under the `ckpt` folder along with its corresponding YAML configuration file. 65 | 66 | **Note**: If you wish to access the cell morphology and gene expression features in the ChEMBL2k and Broad6K datasets for baseline evaluation, visit our [Hugging Face repository](https://huggingface.co/liuganghuggingface/InfoAlign-Pretrained) to download these features. 67 | 68 | ### Pretraining 69 | 70 | To pretrain the model from scratch, execute the following command: 71 | 72 | ```bash 73 | python main.py --model-path "ckpt/pretrain.pt" --lr 1e-4 --wdecay 1e-8 --batch-size 3072 74 | ``` 75 | 76 | This will automatically download the pretraining dataset from [Hugging Face](https://huggingface.co/datasets/liuganghuggingface/InfoAlign-Data). If you prefer to download the dataset manually, place all pretraining data files in the `raw_data/pretrain/raw` folder. 77 | 78 | The pretrained model will be saved in the `ckpt` folder as `pretrain.pt`. 79 | 80 | --- 81 | 82 | ## Data source 83 | 84 | For readers interested in data collection, here are the sources: 85 | 86 | 1. **Cell Morphology Data** 87 | - JUMP dataset: The data are from "JUMP Cell Painting dataset: morphological impact of 136,000 chemical and genetic perturbations" and can be downloaded [here](https://github.com/jump-cellpainting/datasets/blob/1c245002cbcaea9156eea56e61baa52ad8307db3/profile_index.csv). The dataset includes chemical and genetic perturbations for cell morphology features. 
88 | - Bray's dataset: "A dataset of images and morphological profiles of 30,000 small-molecule treatments using the Cell Painting assay". Download from [GigaDB](http://gigadb.org/dataset/100351). Processed version available on [Zenodo](https://zenodo.org/records/7589312). 89 | 90 | 2. **Gene Expression Data** 91 | - LINCS L1000 gene expression data from the paper "Drug-induced adverse events prediction with the LINCS L1000 data": [Data](https://maayanlab.net/SEP-L1000/#download). 92 | 93 | 3. **Relationships** 94 | - Gene-gene, gene-compound relationships from Hetionet: [Data](https://github.com/hetio/hetionet). 95 | 96 | # Citation 97 | 98 | If you find this repository useful, please cite our paper: 99 | 100 | ``` 101 | @article{liu2024learning, 102 | title={Learning Molecular Representation in a Cell}, 103 | author={Liu, Gang and Seal, Srijit and Arevalo, John and Liang, Zhenwen and Carpenter, Anne E and Jiang, Meng and Singh, Shantanu}, 104 | journal={arXiv preprint arXiv:2406.12056}, 105 | year={2024} 106 | } 107 | ``` 108 | 109 | -------------------------------------------------------------------------------- /assets/infoalign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/assets/infoalign.png -------------------------------------------------------------------------------- /ckpt/pretrain.yaml: -------------------------------------------------------------------------------- 1 | emb_dim: 300 2 | model: gin-virtual 3 | norm_layer: batch_norm 4 | num_layer: 5 5 | prior: 1.0e-09 6 | readout: sum 7 | threshold: 0.8 8 | walk_length: 4 9 | -------------------------------------------------------------------------------- /configures/arguments.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | 4 | model_hyperparams = [ 5 | "emb_dim", 6 | "model", 7 | "num_layer", 
def load_arguments_from_yaml(filename, model_only=False):
    """Load a configuration mapping from a YAML file.

    Args:
        filename: Path of the YAML file to read.
        model_only: When true, keep only the keys listed in the module-level
            ``model_hyperparams``; otherwise return every key in the file.

    Returns:
        The parsed configuration, filtered when ``model_only`` is set. An
        empty YAML document yields an empty dict.
    """
    # Parse exactly once. The stream is consumed by this call and closed when
    # the ``with`` block exits, so a second ``yaml.safe_load(file)`` can never
    # return the configuration again.
    with open(filename, "r") as file:
        config = yaml.safe_load(file) or {}
    if model_only:
        config = {k: v for k, v in config.items() if k in model_hyperparams}
    return config
parser.add_argument( 77 | "--emb-dim", 78 | type=int, 79 | default=300, 80 | help="dimensionality of hidden units in GNNs (default: 300)", 81 | ) 82 | # training 83 | ## pretraining 84 | parser.add_argument( 85 | "--walk-length", 86 | type=int, 87 | default=4, 88 | help="pretraining context length", 89 | ) 90 | parser.add_argument( 91 | "--threshold", 92 | type=float, 93 | default=0.8, 94 | help="minimum similarity threshold for context graph", 95 | ) 96 | # prior 97 | parser.add_argument( 98 | "--prior", 99 | type=float, 100 | default=1e-9, 101 | help="loss weight to prior", 102 | ) 103 | 104 | ## other 105 | parser.add_argument( 106 | "--batch-size", 107 | type=int, 108 | default=5120, 109 | help="input batch size for training (default: 256)", 110 | ) 111 | parser.add_argument( 112 | "--lr", 113 | "--learning-rate", 114 | type=float, 115 | default=1e-3, 116 | help="Learning rate (default: 1e-3)", 117 | ) 118 | parser.add_argument("--wdecay", default=1e-5, type=float, help="weight decay") 119 | parser.add_argument( 120 | "--epochs", type=int, default=300, help="number of epochs to train" 121 | ) 122 | parser.add_argument( 123 | "--initw-name", 124 | type=str, 125 | default="default", 126 | help="method to initialize the model paramter", 127 | ) 128 | parser.add_argument( 129 | "--model-path", 130 | type=str, 131 | default="ckpt/pretrain.pt", 132 | help="path to the pretrained model", 133 | ) 134 | parser.add_argument( 135 | "--patience", type=int, default=50, help="patience for early stop" 136 | ) 137 | 138 | args = parser.parse_args() 139 | print("no print", args.no_print) 140 | 141 | ## n_steps for solver 142 | args.n_steps = 1 143 | return args 144 | -------------------------------------------------------------------------------- /configures/finetune.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 5120 2 | dataset: pretrain 3 | drop_ratio: 0.5 4 | emb_dim: 300 5 | epochs: 300 6 | gpu_id: 0 7 | initw_name: 
class CSVToDataset:
    """
    A class to convert CSV files containing SMILES and property data to a standardized dataset format.

    The class extracts the SMILES column and specified property columns from a CSV file,
    renames the SMILES column to ``smiles``, and saves the result as
    ``raw_data/<task_name>/raw/assays.csv.gz`` together with a ``meta.json`` file.
    """

    def __init__(self, task_name):
        """
        Initialize the converter with a task name and create its output directory.

        Args:
            task_name (str): Name of the task/dataset
        """
        self.task_name = task_name
        self.output_dir = Path(f"raw_data/{task_name}/raw")
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def convert(self, csv_path, smiles_column, property_columns, eval_metric="roc_auc"):
        """
        Convert a CSV file to the standardized dataset format.

        Args:
            csv_path (str): Path to the input CSV file
            smiles_column (str): Name of the column containing SMILES strings
            property_columns (list): Column names for properties to extract. Any
                sequence of names (or a single name) is accepted.
            eval_metric (str): Evaluation metric stored in ``meta.json``
                (default: ``"roc_auc"``, matching the command-line default)

        Returns:
            bool: True if conversion was successful, False otherwise
        """
        # Normalise to a list so tuples or a lone column name work with the
        # column selection below instead of raising TypeError.
        if isinstance(property_columns, str):
            property_columns = [property_columns]
        else:
            property_columns = list(property_columns)

        # isfile (not exists) so that a directory path is reported cleanly
        # instead of failing inside pandas.
        if not os.path.isfile(csv_path):
            print(f"Error: File {csv_path} does not exist.")
            return False
        df = pd.read_csv(csv_path)

        # Validate columns exist in the dataframe
        selected_columns = [smiles_column, *property_columns]
        missing_columns = [col for col in selected_columns if col not in df.columns]
        if missing_columns:
            print(f"Error: The following columns are missing from the CSV: {', '.join(missing_columns)}")
            return False

        # Extract and rename columns
        result_df = df[selected_columns].rename(columns={smiles_column: 'smiles'})

        # Save the processed data
        output_path = self.output_dir / 'assays.csv.gz'
        result_df.to_csv(output_path, index=False, compression='gzip')

        # Create and save metadata; property columns start right after 'smiles'
        self.num_tasks = len(property_columns)
        metadata = {
            'num_tasks': self.num_tasks,
            'start_column': 1,
            'eval_metric': eval_metric
        }

        with open(self.output_dir / 'meta.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Successfully processed {csv_path}")
        print(f"Saved data to {output_path}")
        print(f"Number of compounds: {len(result_df)}")
        print(f"Number of tasks: {self.num_tasks}")

        return True
def from_networkx(
    G: Any,
    group_node_attrs: Optional[Union[List[str], Literal['all']]] = None,
    group_edge_attrs: Optional[Union[List[str], Literal['all']]] = None,
) -> 'torch_geometric.data.Data':
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Undirected graphs are converted to directed ones first. Attributes whose
    values are tensors or NumPy arrays are stacked along a new first dimension
    and cast to :obj:`torch.float32`; every other attribute is converted with
    :func:`torch.as_tensor` when possible and left untouched otherwise.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or "all", optional): The node attributes to
            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
        group_edge_attrs (List[str] or "all", optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    Returns:
        torch_geometric.data.Data: The converted graph.

    Raises:
        ValueError: If the nodes (or the edges) do not all carry the same set
            of attribute names.

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.

    Examples:
        >>> edge_index = torch.tensor([
        ...     [0, 1, 1, 2, 2, 3],
        ...     [1, 0, 2, 1, 3, 2],
        ... ])
        >>> data = Data(edge_index=edge_index, num_nodes=4)
        >>> g = to_networkx(data)
        >>> # A `Data` object is returned
        >>> from_networkx(g)
        Data(edge_index=[2, 6], num_nodes=4)
    """

    G = G.to_directed() if not nx.is_directed(G) else G

    # Relabel nodes to consecutive integers in iteration order.
    mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))
    edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)
    for i, (src, dst) in enumerate(G.edges()):
        edge_index[0, i] = mapping[src]
        edge_index[1, i] = mapping[dst]

    data_dict: Dict[str, Any] = defaultdict(list)
    data_dict['edge_index'] = edge_index

    # The first node / edge defines the attribute names every other one must have.
    node_attrs: List[str] = []
    if G.number_of_nodes() > 0:
        node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())

    edge_attrs: List[str] = []
    if G.number_of_edges() > 0:
        edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())

    # Any non-list value (e.g. "all") selects every available attribute.
    if group_node_attrs is not None and not isinstance(group_node_attrs, list):
        group_node_attrs = node_attrs

    if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
        group_edge_attrs = edge_attrs

    for _, feat_dict in G.nodes(data=True):
        if set(feat_dict.keys()) != set(node_attrs):
            raise ValueError('Not all nodes contain the same attributes')
        for key, value in feat_dict.items():
            data_dict[str(key)].append(value)

    for _, _, feat_dict in G.edges(data=True):
        if set(feat_dict.keys()) != set(edge_attrs):
            raise ValueError('Not all edges contain the same attributes')
        for key, value in feat_dict.items():
            # Disambiguate edge attributes that clash with a node attribute name.
            key = f'edge_{key}' if key in node_attrs else key
            data_dict[str(key)].append(value)

    for key, value in G.graph.items():
        if key == 'node_default' or key == 'edge_default':
            continue  # Do not load default attributes.
        key = f'graph_{key}' if key in node_attrs else key
        data_dict[str(key)] = value

    for key, value in data_dict.items():
        # Guard against empty sequences before peeking at the first element.
        is_filled_seq = isinstance(value, (tuple, list)) and len(value) > 0
        if is_filled_seq and isinstance(value[0], Tensor):
            # torch.stack takes no ``dtype`` argument; cast after stacking.
            data_dict[key] = torch.stack(value, dim=0).to(torch.float32)
        elif is_filled_seq and isinstance(value[0], np.ndarray):
            data_dict[key] = torch.tensor(np.stack(value), dtype=torch.float32)
        else:
            try:
                data_dict[key] = torch.as_tensor(np.array(value))
            except Exception:
                # Best effort: values that cannot become tensors stay as they are.
                pass

    data = Data.from_dict(data_dict)

    if group_node_attrs is not None:
        xs = []
        for key in group_node_attrs:
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.x = torch.cat(xs, dim=-1)

    if group_edge_attrs is not None:
        xs = []
        for key in group_edge_attrs:
            key = f'edge_{key}' if key in node_attrs else key
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.edge_attr = torch.cat(xs, dim=-1)

    if data.x is None and data.pos is None:
        data.num_nodes = G.number_of_nodes()

    return data
def get_data(args, load_path, transform="pyg"):
    """Build the dataset(s) selected by ``args.dataset``.

    Args:
        args: Namespace with a ``dataset`` attribute; ``threshold`` is also
            read when ``dataset == "pretrain"``.
        load_path: Root directory handed to the dataset classes.
        transform: Representation to load. Pretraining requires ``"pyg"``.

    Returns:
        For ``"pretrain"``: a ``(molecule, context)`` pair of datasets.
        Otherwise a single prediction dataset, named by stripping an optional
        ``"finetune-"`` prefix from ``args.dataset``.

    Raises:
        AssertionError: If ``transform`` is not supported for the dataset.
    """
    assert transform in [
        None,
        "fingerprint",
        "smiles",
        "pyg",
        "morphology",
        "expression",
    ]

    if args.dataset == "pretrain":
        assert transform == "pyg"
        from .pretrain_molecule import PretrainMoleculeDataset
        from .pretrain_context import PretrainContextDataset

        molecule = PretrainMoleculeDataset(root=load_path)
        context = PretrainContextDataset(root=load_path, pre_transform=args.threshold)
        return molecule, context

    # Strip only the leading "finetune-" marker: splitting on every "-" would
    # truncate dataset names that contain a hyphen and would raise IndexError
    # when there is no hyphen at all.
    prefix = "finetune-"
    if args.dataset.startswith(prefix):
        data_name = args.dataset[len(prefix):]
    else:
        data_name = args.dataset

    assert transform in ["fingerprint", "smiles", "morphology", "expression", "pyg"]
    if transform == "pyg":
        from .prediction_molecule import PygPredictionMoleculeDataset

        return PygPredictionMoleculeDataset(name=data_name, root=load_path)

    from .prediction_molecule import PredictionMoleculeDataset

    return PredictionMoleculeDataset(
        name=data_name, root=load_path, transform=transform
    )
def read_graph_list(mol_df, keep_id=False):
    """Convert the SMILES strings of a dataframe into PyG ``Data`` graphs.

    Args:
        mol_df: Dataframe with a ``smiles`` column and, when ``keep_id`` is
            true, a ``mol_id`` column.
        keep_id: When true, store each row's ``mol_id`` on its graph as
            ``g.type``.

    Returns:
        list: One ``Data`` object per row of ``mol_df``, in row order.

    Raises:
        ValueError: If ``smiles2graph`` cannot convert one of the SMILES strings.
    """
    mol_list = mol_df["smiles"].tolist()
    # Only read the id column when it is requested, so frames without
    # ``mol_id`` still work with keep_id=False.
    ids_list = mol_df["mol_id"].tolist() if keep_id else None

    graph_list = []
    with tqdm(total=len(mol_list), desc="Processing molecules") as pbar:
        for index, smiles_str in enumerate(mol_list):
            graph_dict = smiles2graph(smiles_str)
            if graph_dict is None:
                # smiles2graph signals failure with None; fail with a clear
                # message instead of an opaque TypeError further down.
                raise ValueError(
                    f"Could not convert SMILES at row {index} to a graph: {smiles_str!r}"
                )
            if keep_id:
                graph_dict["type"] = ids_list[index]
            graph_list.append(graph_dict)
            pbar.update(1)

    pyg_graph_list = []
    print("Converting graphs into PyG objects...")
    for graph in graph_list:
        g = Data()
        g.__num_nodes__ = graph.pop("num_nodes")
        g.edge_index = torch.from_numpy(graph.pop("edge_index"))

        edge_feat = graph.pop("edge_feat", None)
        if edge_feat is not None:
            g.edge_attr = torch.from_numpy(edge_feat)

        node_feat = graph.pop("node_feat", None)
        if node_feat is not None:
            g.x = torch.from_numpy(node_feat)

        # "type" only exists when keep_id=True; pop with a default avoids the
        # KeyError the plain lookup raised for keep_id=False.
        mol_type = graph.pop("type", None)
        if mol_type is not None:
            g.type = mol_type

        # Any remaining entries are attached as tensors.
        for key, value in graph.items():
            g[key] = torch.tensor(value)

        pyg_graph_list.append(g)

    return pyg_graph_list
def calculate_mol_similarity(fingerprint1, fingerprint2):
    """Return the Tanimoto similarity of two fingerprints.

    Thin wrapper around ``DataStructs.TanimotoSimilarity`` so the computation
    can be passed around as a plain function.
    """
    tanimoto = DataStructs.TanimotoSimilarity(fingerprint1, fingerprint2)
    return tanimoto
def perform_pca_and_kmeans(features, all_data_id, n_components=2, n_clusters=100, return_pca_feature=False):
    """Group samples into K-Means clusters computed on PCA-reduced features.

    NaNs in ``features`` are replaced with ``np.nan_to_num`` before PCA.

    Args:
        features: Feature matrix, one row per sample.
        all_data_id: Array of sample ids, indexed with a boolean mask per cluster.
        n_components: Number of PCA components.
        n_clusters: Number of K-Means clusters (``random_state=0``).
        return_pca_feature: Return the PCA features per cluster instead of the
            NaN-cleaned input features.

    Returns:
        tuple: ``(batch, batched_data_id)`` where ``batch`` is a list with one
        feature array per cluster and ``batched_data_id`` is a flat array of
        ids ordered cluster by cluster.
    """
    features = np.nan_to_num(features)
    reduced = PCA(n_components=n_components).fit_transform(features)
    labels = KMeans(n_clusters=n_clusters, random_state=0).fit(reduced).labels_

    # Rows handed back per cluster come from one of the two matrices.
    source = reduced if return_pca_feature else features

    batch = []
    ids_per_cluster = []
    for cluster in range(n_clusters):
        members = labels == cluster
        batch.append(source[members])
        ids_per_cluster.append(all_data_id[members])

    # Flatten the per-cluster id arrays, keeping cluster order.
    flat_ids = [idd for cluster_ids in ids_per_cluster for idd in cluster_ids]
    batched_data_id = np.array(flat_ids)

    return batch, batched_data_id
def l2_similarity(matrix):
    """Pairwise similarity ``1 / (1 + euclidean distance)`` between rows.

    NaNs are replaced with ``np.nan_to_num`` and each row is standardised
    (zero mean, unit standard deviation, with a 1e-10 guard on the divisor)
    before distances are computed. The diagonal of the result is set to 0.
    """
    cleaned = np.nan_to_num(matrix)
    row_mean = cleaned.mean(axis=1)[:, None]
    row_std = cleaned.std(axis=1)[:, None]
    standardized = (cleaned - row_mean) / (row_std + 1e-10)

    # Condensed pairwise L2 distances expanded to a square matrix.
    pairwise = distance.squareform(distance.pdist(standardized, "euclidean"))

    sim = 1 / (1 + pairwise)
    np.fill_diagonal(sim, 0)  # no self-similarity
    return sim
def determine_threshold(similarity, min_threshold, high_threshold=0.99, target_sparsity=0.995):
    """Binary-search a similarity cutoff that approaches a target sparsity.

    Sparsity at a cutoff is the fraction of entries of ``similarity`` that are
    NOT strictly greater than the cutoff. The search bisects the interval
    ``[min_threshold, high_threshold]`` until it is at most 0.001 wide and
    returns the midpoint of the final interval.

    Parameters:
    - similarity: numpy array of similarity values (any shape).
    - min_threshold: lower bound of the search interval.
    - high_threshold: upper bound of the search interval.
    - target_sparsity: desired fraction of entries at or below the cutoff.

    Returns:
    - The selected threshold. For an empty ``similarity`` there is nothing to
      filter, so ``min_threshold`` is returned instead of dividing by zero.
    """
    low_threshold = min_threshold
    total = similarity.size
    if total == 0:
        return low_threshold

    def sparsity_at(cutoff):
        return 1 - np.count_nonzero(similarity > cutoff) / total

    threshold = (high_threshold + low_threshold) / 2.0  # Start in the middle
    while low_threshold < high_threshold - 0.001:  # Continue until the interval is small
        if sparsity_at(threshold) > target_sparsity:
            high_threshold = threshold  # Too sparse: move the upper limit down
        else:
            low_threshold = threshold  # Too dense: move the lower limit up
        threshold = (high_threshold + low_threshold) / 2.0

    return threshold
def minmax_normalize(data, min_val=None, max_val=None):
    """
    Normalizes the data using min-max normalization, handling NaN values in min and max calculations.

    Parameters:
    - data: The data to be normalized (numpy array, one column per feature).
    - min_val: Optional. Precomputed minimum values for each feature. If not provided, computed from data.
    - max_val: Optional. Precomputed maximum values for each feature. If not provided, computed from data.

    Returns:
    - Normalized data, used min_val, used max_val.

    The returned bounds are fresh float arrays: caller-supplied ``min_val`` /
    ``max_val`` are never modified. Where a feature's minimum equals its
    maximum, the returned bounds are widened by 1e-6 on each side so the
    division stays finite.
    """
    # Each bound is resolved independently, so supplying only one of them is honoured.
    if min_val is None:
        min_val = np.nanmin(data, axis=0)
    else:
        min_val = np.asarray(min_val)
        assert min_val.shape[0] == data.shape[1], "min_val must match the number of features in data"
    if max_val is None:
        max_val = np.nanmax(data, axis=0)
    else:
        max_val = np.asarray(max_val)
        assert max_val.shape[0] == data.shape[1], "max_val must match the number of features in data"

    # Work on float copies: the in-place widening below must not leak into the
    # caller's arrays, and must not be truncated away by an integer dtype.
    min_val = np.array(min_val, dtype=float)
    max_val = np.array(max_val, dtype=float)

    # Handle cases where min and max are equal
    equal_indices = min_val == max_val
    min_val[equal_indices] -= 1e-6
    max_val[equal_indices] += 1e-6

    # A zero span can survive the widening when 1e-6 is below the float
    # resolution of the bound; those features are mapped to 0.
    span = max_val - min_val
    with np.errstate(divide="ignore", invalid="ignore"):
        normalized_data = np.where(span == 0, 0, (data - min_val) / span)

    return normalized_data, min_val, max_val
col_id1='cell_bid', col_id2='cell_jid', 424 | connect_col='mol_id' 425 | ) 426 | cell_df = cell_df.rename(columns={'connect_id': 'mol_id'}) 427 | cell_df['cell_id'] = ['c' + str(i) for i in range(1, len(cell_df) + 1)] 428 | cell_df = cell_df[['cell_id', 'mol_id', 'index_cell_bid', 'index_cell_jid']] 429 | 430 | # load gene nodes and features 431 | gc_df = pd.read_csv(f"{folder}/G-CRISPR.csv.gz") 432 | gc_feature = np.load(f"{folder}/G-CRISPR_feature.npz")["data"] 433 | go_df = pd.read_csv(f"{folder}/G-ORF.csv.gz") 434 | go_feature = np.load(f"{folder}/G-ORF_feature.npz")["data"] 435 | 436 | gene_df, gene_feature = merge_features_and_dataframes( 437 | df1=gc_df, df2=go_df, 438 | features1=gc_feature, features2=go_feature, 439 | col_id1='ncbi_gene_id', col_id2='mol_id', 440 | connect_col=None 441 | ) 442 | gene_df['gene_id'] = gene_df['ncbi_gene_id'].apply(lambda x: 'g' + str(x)) 443 | gene_df = gene_df[['gene_id', 'mol_id', 'index_ncbi_gene_id', 'index_mol_id']] 444 | 445 | # load gene expression 446 | express_df = pd.read_csv(f"{folder}/GE.csv.gz") 447 | express_feature = np.load(f"{folder}/GE_feature.npz")["data"] 448 | 449 | # load gene-gene interaction 450 | gg_df = pd.read_csv(f"{folder}/G-G.csv.gz") 451 | 452 | nid_to_feature_id = {} 453 | # Combine the loop iterations for different dataframes into a single loop 454 | for df, col_id in [ 455 | (mol_df, "mol_id"), 456 | (express_df, "express_id"), 457 | (gene_df, "gene_id"), 458 | (cell_df, "cell_id"), 459 | ]: 460 | for i, nid in enumerate(df[col_id]): 461 | nid_to_feature_id[nid] = i 462 | 463 | target_mol = structure_feature 464 | target_gene, min_gene, max_gene = minmax_normalize(gene_feature) 465 | target_cell, min_cell, max_cell = minmax_normalize(cell_feature) 466 | target_express, min_express, max_express = minmax_normalize(express_feature) 467 | 468 | mol_dim, gene_dim, cell_dim, express_dim = ( 469 | structure_feature.shape[1], 470 | gene_feature.shape[1], 471 | cell_feature.shape[1], 472 | 
express_feature.shape[1], 473 | ) 474 | 475 | #### molecular similarity 476 | scaffold_names, batch_mol_feature, batched_mol_id = cluster_molecules_by_scaffold( 477 | molecules, all_mol_id 478 | ) 479 | sim_mol = batch_similarity(batch_mol_feature, pairwise_mol_similarity) 480 | 481 | #### gene similarity 482 | batched_gene_feature, batched_gene_id = perform_pca_and_kmeans( 483 | gene_feature, gene_df["gene_id"].values 484 | ) 485 | sim_gene = batch_similarity(batched_gene_feature, pairwise_cosine_similarity) 486 | 487 | #### cell similarity 488 | batched_cell_feature, batched_cell_id = perform_pca_and_kmeans( 489 | cell_feature[:, bray_dim:], cell_df["cell_id"].values 490 | ) 491 | sim_cell = batch_similarity(batched_cell_feature, pairwise_cosine_similarity) 492 | 493 | #### gene expression similarity 494 | sim_express = direct_similarity(express_feature, pairwise_cosine_similarity) 495 | direct_express_id = express_df["express_id"].values 496 | 497 | ########## create graph ######## 498 | G = nx.Graph() 499 | 500 | mol_ids = structure_df["mol_id"].values 501 | gene_ids = gene_df["gene_id"].values 502 | cell_ids = cell_df["cell_id"].values 503 | express_ids = express_df["express_id"].values 504 | 505 | # Add molecule nodes 506 | mol_nodes = [ 507 | ( 508 | mol_ids[idx], 509 | dict( 510 | type=mol_ids[idx], 511 | mol_target=target_mol[nid_to_feature_id[mol_id]], 512 | gene_target=np.full(gene_dim, np.nan), 513 | cell_target=np.full(cell_dim, np.nan), 514 | express_target=np.full(express_dim, np.nan) 515 | ), 516 | ) 517 | for idx, mol_id in enumerate(mol_ids) 518 | ] 519 | G.add_nodes_from(mol_nodes) 520 | print("Count nodes after adding molecule nodes:", G.number_of_nodes()) 521 | 522 | # Add gene crispr nodes 523 | gene_nodes = [ 524 | ( 525 | gene_ids[idx], 526 | dict( 527 | type=gene_ids[idx], 528 | mol_target=np.full(mol_dim, np.nan), 529 | gene_target=target_gene[nid_to_feature_id[gene_id]], 530 | cell_target=np.full(cell_dim, np.nan), 531 | 
express_target=np.full(express_dim, np.nan), 532 | ), 533 | ) 534 | for idx, gene_id in enumerate(gene_ids) 535 | ] 536 | G.add_nodes_from(gene_nodes) 537 | print("Count nodes after adding gene nodes:", G.number_of_nodes()) 538 | 539 | cell_nodes = [ 540 | ( 541 | cell_ids[idx], 542 | dict( 543 | type=cell_ids[idx], 544 | mol_target=np.full(mol_dim, np.nan), 545 | gene_target=np.full(gene_dim, np.nan), 546 | cell_target=target_cell[nid_to_feature_id[cell_id]], 547 | express_target=np.full(express_dim, np.nan), 548 | ), 549 | ) 550 | for idx, cell_id in enumerate(cell_ids) 551 | ] 552 | G.add_nodes_from(cell_nodes) 553 | print("Count nodes after adding cell nodes:", G.number_of_nodes()) 554 | 555 | # Add gene expression nodes 556 | express_nodes = [ 557 | ( 558 | express_ids[idx], 559 | dict( 560 | type=express_ids[idx], 561 | mol_target=np.full(mol_dim, np.nan), 562 | gene_target=np.full(gene_dim, np.nan), 563 | cell_target=np.full(cell_dim, np.nan), 564 | express_target=target_express[nid_to_feature_id[express_id]], 565 | 566 | ), 567 | ) 568 | for idx, express_id in enumerate(express_ids) 569 | ] 570 | G.add_nodes_from(express_nodes) 571 | print("Count nodes after adding gene expression nodes:", G.number_of_nodes()) 572 | 573 | G.add_edges_from(zip(gene_df["gene_id"], gene_df["mol_id"]), weight=1) 574 | print("Count of edges after adding mol-gene:", G.number_of_edges()) 575 | G.add_edges_from(zip(cell_df["mol_id"], cell_df["cell_id"]), weight=1) 576 | print("Count of edges after adding mol-cell:", G.number_of_edges()) 577 | G.add_edges_from(zip(express_df["express_id"], express_df["mol_id"]), weight=1) 578 | print("Count of edges after adding mol-express:", G.number_of_edges()) 579 | 580 | # add gene-gene edges 581 | gene_source_prefixed = "g" + gg_df["source_id"].astype(str) 582 | gene_target_prefixed = "g" + gg_df["target_id"].astype(str) 583 | 584 | ## Filter the 'go' prefixed edges where both nodes exist in go_df["orf_id"] 585 | valid_gene_sources = 
gene_source_prefixed.isin(gene_df["gene_id"]) 586 | valid_gene_targets = gene_target_prefixed.isin(gene_df["gene_id"]) 587 | gene_edges_to_add = zip( 588 | gene_source_prefixed[valid_gene_sources & valid_gene_targets], 589 | gene_target_prefixed[valid_gene_sources & valid_gene_targets], 590 | ) 591 | G.add_edges_from(gene_edges_to_add, weight=1) 592 | print("Count of edges after adding gene-gene in the graph:", G.number_of_edges()) 593 | 594 | # Add edges between cell according to the similarity matrix 595 | def add_edges_with_weight(G, s_node, t_node, edge_weight): 596 | # G.add_edges_from(zip(s_node, t_node), weight=edge_weight) 597 | for i, (s, t) in enumerate(zip(s_node, t_node)): 598 | weight = edge_weight[i] 599 | G.add_edge(s, t, weight=weight) 600 | 601 | L1k_idmaps = pd.read_csv(f'{folder}/L1k_idmaps.csv') 602 | sorted_indices = np.argsort(np.abs(express_feature).flatten()) 603 | start_index = int(len(sorted_indices) * (1-top_compound_gene_express)) 604 | express_thre = np.abs(express_feature).flatten()[sorted_indices[start_index]] 605 | top_percent_indices = sorted_indices[start_index:] 606 | row_indices, col_indices = np.unravel_index(top_percent_indices, express_feature.shape) 607 | if len(row_indices) < 1000: 608 | top_percent_indices = sorted_indices[-1000:] 609 | row_indices, col_indices = np.unravel_index(top_percent_indices, express_feature.shape) 610 | 611 | expressed_mol = express_df['mol_id'].iloc[row_indices].values 612 | expressed_gene = L1k_idmaps['ncbi_gene_id'].iloc[col_indices].values 613 | expressed_gene = np.array([f'g{int(gid)}' for gid in expressed_gene]) 614 | expressed_weight = np.abs(express_feature[row_indices, col_indices]) 615 | expressed_gene_series = pd.Series(expressed_gene) 616 | valid_express_gene = expressed_gene_series.isin(gene_df["gene_id"]).values 617 | add_edges_with_weight(G, expressed_mol[valid_express_gene], expressed_gene[valid_express_gene], expressed_weight[valid_express_gene]) 618 | print(f"Count of edges after 
adding top {round(top_compound_gene_express * 100, 2)} % gene-gene from expression with threshold {round(express_thre, 4)}:", G.number_of_edges()) 619 | 620 | thre_cell = determine_threshold(sim_cell.data, min_thres, target_sparsity=min_sparsity) 621 | thre_gene = determine_threshold(sim_gene.data, min_thres, target_sparsity=min_sparsity) 622 | thre_express = determine_threshold(sim_express.data, min_thres, target_sparsity=min_sparsity) 623 | thre_mol = determine_threshold(sim_mol.data, min_thres, target_sparsity=min_sparsity) 624 | thres_dict = {'cell': thre_cell, 'gene': thre_gene, 'express': thre_express, 'mol': thre_mol} 625 | 626 | ## filter similarity and get ids 627 | mol_s_node, mol_t_node, mol_edge_weight = filter_similarity_and_get_ids( 628 | sim_mol, thres_dict['mol'], batched_mol_id 629 | ) 630 | gene_s_node, gene_t_node, gene_edge_weight = filter_similarity_and_get_ids( 631 | sim_gene, thres_dict['gene'], batched_gene_id 632 | ) 633 | cell_s_node, cell_t_node, cell_edge_weight = filter_similarity_and_get_ids( 634 | sim_cell, thres_dict['cell'], batched_cell_id 635 | ) 636 | express_s_node, express_t_node, express_edge_weight = filter_similarity_and_get_ids( 637 | sim_express, thres_dict['express'], direct_express_id 638 | ) 639 | add_edges_with_weight(G, mol_s_node, mol_t_node, mol_edge_weight) 640 | print(f"Count of edges after adding mol-mol sim with threshold {round(thres_dict['mol'], 2)}:", G.number_of_edges()) 641 | 642 | add_edges_with_weight(G, gene_s_node, gene_t_node, gene_edge_weight) 643 | print(f"Count of edges after adding gene-gene sim with threshold {round(thres_dict['gene'], 2)}:", G.number_of_edges()) 644 | 645 | add_edges_with_weight(G, cell_s_node, cell_t_node, cell_edge_weight) 646 | print(f"Count of edges after adding cell-cell sim with threshold {round(thres_dict['cell'], 2)}:", G.number_of_edges()) 647 | 648 | add_edges_with_weight(G, express_s_node, express_t_node, express_edge_weight) 649 | print(f"Count of edges after adding 
def from_networkx(
    G: Any,
    group_node_attrs: Optional[Union[List[str], Literal["all"]]] = None,
    group_edge_attrs: Optional[Union[List[str], Literal["all"]]] = None,
) -> "torch_geometric.data.Data":
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Node, edge and graph attributes are collected per key. Attributes whose
    per-item values are tensors or numpy arrays are stacked along a new first
    dimension and cast to :obj:`torch.float32`; other values are converted
    with :func:`torch.as_tensor` when possible and kept unchanged otherwise.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or "all", optional): The node attributes to
            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
        group_edge_attrs (List[str] or "all", optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    Raises:
        ValueError: If nodes (or edges) do not all carry the same attribute
            keys as the first node (or edge).

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.

    Examples:
        >>> edge_index = torch.tensor([
        ...     [0, 1, 1, 2, 2, 3],
        ...     [1, 0, 2, 1, 3, 2],
        ... ])
        >>> data = Data(edge_index=edge_index, num_nodes=4)
        >>> g = to_networkx(data)
        >>> # A `Data` object is returned
        >>> from_networkx(g)
        Data(edge_index=[2, 6], num_nodes=4)
    """

    # Undirected graphs are expanded so every edge appears in both directions.
    G = G.to_directed() if not nx.is_directed(G) else G

    # Node labels are mapped to consecutive integers in iteration order.
    mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))
    edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)
    for i, (src, dst) in enumerate(G.edges()):
        edge_index[0, i] = mapping[src]
        edge_index[1, i] = mapping[dst]

    data_dict: Dict[str, Any] = defaultdict(list)
    data_dict["edge_index"] = edge_index

    # The first node / edge defines the attribute schema all others must match.
    node_attrs: List[str] = []
    if G.number_of_nodes() > 0:
        node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())

    edge_attrs: List[str] = []
    if G.number_of_edges() > 0:
        edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())

    # Any non-list value (i.e. "all") selects every attribute.
    if group_node_attrs is not None and not isinstance(group_node_attrs, list):
        group_node_attrs = node_attrs

    if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
        group_edge_attrs = edge_attrs

    for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
        if set(feat_dict.keys()) != set(node_attrs):
            raise ValueError(
                "Not all nodes contain the same attributes "
                f"(node at position {i} has {sorted(map(str, feat_dict))}, "
                f"expected {sorted(map(str, node_attrs))})"
            )
        for key, value in feat_dict.items():
            data_dict[str(key)].append(value)

    for _, _, feat_dict in G.edges(data=True):
        if set(feat_dict.keys()) != set(edge_attrs):
            raise ValueError("Not all edges contain the same attributes")
        for key, value in feat_dict.items():
            # Avoid clashing with a node attribute of the same name.
            key = f"edge_{key}" if key in node_attrs else key
            data_dict[str(key)].append(value)

    for key, value in G.graph.items():
        if key == "node_default" or key == "edge_default":
            continue  # Do not load default attributes.
        key = f"graph_{key}" if key in node_attrs else key
        data_dict[str(key)] = value

    for key, value in data_dict.items():
        # Guard against empty sequences before peeking at the first element.
        is_sequence = isinstance(value, (tuple, list)) and len(value) > 0
        if is_sequence and isinstance(value[0], Tensor):
            # torch.stack has no dtype argument, so cast after stacking.
            data_dict[key] = torch.stack(value, dim=0).to(torch.float32)
        elif is_sequence and isinstance(value[0], np.ndarray):
            data_dict[key] = torch.tensor(np.stack(value), dtype=torch.float32)
        else:
            try:
                data_dict[key] = torch.as_tensor(np.array(value))
            except Exception:
                # Deliberate best-effort: values that cannot become tensors
                # (e.g. strings or dicts) are stored unchanged.
                pass

    data = Data.from_dict(data_dict)

    if group_node_attrs is not None:
        xs = []
        for key in group_node_attrs:
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.x = torch.cat(xs, dim=-1)

    if group_edge_attrs is not None:
        xs = []
        for key in group_edge_attrs:
            key = f"edge_{key}" if key in node_attrs else key
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.edge_attr = torch.cat(xs, dim=-1)

    # Without x / pos the node count cannot be inferred, so store it explicitly.
    if data.x is None and data.pos is None:
        data.num_nodes = G.number_of_nodes()

    return data
class PygPredictionMoleculeDataset(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs built from ``assays.csv.gz``.

    Each row of the raw CSV becomes one ``Data`` object: the graph comes from
    ``smiles2graph(row["smiles"])`` and the label vector ``y`` from every
    column at position ``start_column`` and beyond.

    Built-in dataset names carry their own metadata; any other name must
    provide ``<root>/<name>/raw/meta.json`` with ``num_tasks`` and
    ``start_column`` (``eval_metric`` is optional and defaults to
    ``"roc_auc"``).
    """

    # name -> (num_tasks, start_column, eval_metric)
    _BUILTIN_DATASETS = {
        "chembl2k": (41, 4, "roc_auc"),
        "broad6k": (32, 2, "roc_auc"),
        "biogenadme": (6, 4, "avg_mae"),
        "moltoxcast": (617, 2, "roc_auc"),
    }

    def __init__(
        self, name="chembl2k", root="raw_data", transform=None, pre_transform=None
    ):
        """Resolve dataset metadata, then load (or build) the processed file.

        Raises:
            ValueError: if ``name`` is not built in and no ``meta.json`` exists.
        """
        self.name = name
        self.root = osp.join(root, name)
        self.task_type = 'finetune'

        self.eval_metric = "roc_auc"
        if name in self._BUILTIN_DATASETS:
            self.num_tasks, self.start_column, self.eval_metric = (
                self._BUILTIN_DATASETS[name]
            )
        else:
            meta_path = osp.join(self.root, "raw", "meta.json")
            if not os.path.exists(meta_path):
                raise ValueError("Invalid dataset name")
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.num_tasks = meta["num_tasks"]
            self.start_column = meta["start_column"]
            # Optional, like in PredictionMoleculeDataset: fall back to the default.
            self.eval_metric = meta.get("eval_metric", self.eval_metric)

        super().__init__(self.root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def get_idx_split(self):
        """Return the scaffold split as ``{"train", "valid", "test"}`` index tensors.

        The split is read from ``split/scaffold/split_dict.pt`` when present;
        otherwise it is computed with ``scaffold_split`` and saved there.
        """
        path = osp.join(self.root, "split", "scaffold")
        split_file = os.path.join(path, "split_dict.pt")

        if os.path.isfile(split_file):
            return torch.load(split_file)

        print("Initializing split...")
        data_df = pd.read_csv(osp.join(self.raw_dir, "assays.csv.gz"))
        train_idx, valid_idx, test_idx = scaffold_split(data_df)
        split_dict = {
            "train": torch.tensor(train_idx, dtype=torch.long),
            "valid": torch.tensor(valid_idx, dtype=torch.long),
            "test": torch.tensor(test_idx, dtype=torch.long),
        }
        os.makedirs(path, exist_ok=True)
        torch.save(split_dict, split_file)
        return split_dict

    @property
    def raw_file_names(self):
        return ["assays.csv.gz"]

    @property
    def processed_file_names(self):
        return ["geometric_data_processed.pt"]

    def download(self):
        """Nothing is downloaded; only check that the raw file is in place."""
        assert os.path.exists(
            os.path.join(self.raw_dir, "assays.csv.gz")
        ), f"assays.csv.gz does not exist in {self.raw_dir}"

    def process(self):
        """Convert every CSV row into a ``Data`` graph and save the collated result."""
        data_df = pd.read_csv(osp.join(self.raw_dir, "assays.csv.gz"))

        pyg_graph_list = []
        for idx, row in data_df.iterrows():
            smiles = row["smiles"]
            graph = smiles2graph(smiles)

            g = Data()
            g.num_nodes = graph["num_nodes"]
            g.edge_index = torch.from_numpy(graph["edge_index"])

            del graph["num_nodes"]
            del graph["edge_index"]

            if graph["edge_feat"] is not None:
                g.edge_attr = torch.from_numpy(graph["edge_feat"])
                del graph["edge_feat"]

            if graph["node_feat"] is not None:
                g.x = torch.from_numpy(graph["node_feat"])
                del graph["node_feat"]

            # The fingerprint is best-effort: a missing or unconvertible "fp"
            # leaves the graph without one. Only conversion-related errors are
            # swallowed, so interrupts still propagate.
            try:
                g.fp = torch.tensor(graph["fp"], dtype=torch.int8).view(1, -1)
                del graph["fp"]
            except (KeyError, TypeError, ValueError, RuntimeError):
                pass

            # Labels are all columns from `start_column` to the end of the row.
            y = [float(row.iloc[col]) for col in range(self.start_column, len(row))]

            g.y = torch.tensor(y, dtype=torch.float32).view(1, -1)
            pyg_graph_list.append(g)

        # NOTE: pre_transform receives the whole list, not individual graphs.
        pyg_graph_list = (
            pyg_graph_list
            if self.pre_transform is None
            else self.pre_transform(pyg_graph_list)
        )
        print("Saving...")
        torch.save(self.collate(pyg_graph_list), self.processed_paths[0])

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
self.prepare_fingerprints() 174 | elif transform in ["morphology", "expression"]: 175 | self.prepare_other_modality() 176 | 177 | def get_idx_split(self, to_list=False): 178 | path = osp.join(self.folder, "split", "scaffold") 179 | if os.path.isfile(os.path.join(path, "split_dict.pt")): 180 | split_dict = torch.load(os.path.join(path, "split_dict.pt")) 181 | else: 182 | data_df = pd.read_csv(self.raw_data) 183 | train_idx, valid_idx, test_idx = scaffold_split(data_df) 184 | train_idx = torch.tensor(train_idx, dtype=torch.long) 185 | valid_idx = torch.tensor(valid_idx, dtype=torch.long) 186 | test_idx = torch.tensor(test_idx, dtype=torch.long) 187 | 188 | os.makedirs(path, exist_ok=True) 189 | torch.save( 190 | {"train": train_idx, "valid": valid_idx, "test": test_idx}, 191 | os.path.join(path, "split_dict.pt"), 192 | ) 193 | split_dict = {"train": train_idx, "valid": valid_idx, "test": test_idx} 194 | 195 | if to_list: 196 | split_dict = {k: v.tolist() for k, v in split_dict.items()} 197 | return split_dict 198 | 199 | def prepare_other_modality(self): 200 | assert os.path.exists( 201 | self.raw_data 202 | ), f" {self.raw_data} assays.csv.gz does not exist" 203 | data_df = pd.read_csv(self.raw_data) 204 | 205 | processed_dir = osp.join(self.folder, "processed") 206 | os.makedirs(processed_dir, exist_ok=True) 207 | 208 | if self.transform == "morphology": 209 | if self.name == "chembl2k": 210 | feature_df = pd.read_csv( 211 | os.path.join(self.folder, "raw", "CP-JUMP.csv.gz"), 212 | compression="gzip", 213 | ) 214 | feature_arr = np.load( 215 | os.path.join(self.folder, "raw", "CP-JUMP_feature.npz") 216 | )["data"] 217 | else: 218 | feature_df = pd.read_csv( 219 | os.path.join(self.folder, "raw", "CP-Bray.csv.gz"), 220 | compression="gzip", 221 | ) 222 | feature_arr = np.load( 223 | os.path.join(self.folder, "raw", "CP-Bray_feature.npz") 224 | )["data"] 225 | else: 226 | feature_df = pd.read_csv( 227 | os.path.join(self.folder, "raw", "GE.csv.gz"), 
compression="gzip" 228 | ) 229 | feature_arr = np.load(os.path.join(self.folder, "raw", "GE_feature.npz"))[ 230 | "data" 231 | ] 232 | 233 | if not osp.exists(osp.join(processed_dir, f"processed_{self.transform}.pt")): 234 | x_list = [] 235 | y_list = [] 236 | feature_dim = feature_arr.shape[1] 237 | for idx, row in data_df.iterrows(): 238 | if len(feature_df[feature_df["inchikey"] == row["inchikey"]]) == 0: 239 | x_list.append(torch.tensor([float("nan")] * feature_dim)) 240 | else: 241 | x_tensor = torch.tensor( 242 | feature_arr[ 243 | feature_df[ 244 | feature_df["inchikey"] == row["inchikey"] 245 | ].index.tolist()[0] 246 | ], 247 | dtype=torch.float32, 248 | ) 249 | x_list.append(x_tensor) 250 | 251 | y = [] 252 | for col in range(self.start_column, len(row)): 253 | y.append(float(row.iloc[col])) 254 | y = torch.tensor(y, dtype=torch.float32) 255 | y_list.append(y) 256 | 257 | x_list = torch.stack(x_list, dim=0) 258 | y_list = torch.stack(y_list, dim=0) 259 | torch.save( 260 | (x_list, y_list), 261 | osp.join(processed_dir, f"processed_{self.transform}.pt"), 262 | ) 263 | else: 264 | x_list, y_list = torch.load( 265 | osp.join(processed_dir, f"processed_{self.transform}.pt") 266 | ) 267 | 268 | self.data = x_list 269 | self.labels = y_list 270 | 271 | def prepare_smiles(self): 272 | assert os.path.exists( 273 | self.raw_data 274 | ), f" {self.raw_data} assays.csv.gz does not exist" 275 | data_df = pd.read_csv(self.raw_data) 276 | 277 | processed_dir = osp.join(self.folder, "processed") 278 | os.makedirs(processed_dir, exist_ok=True) 279 | x_list = [] 280 | y_list = [] 281 | for idx, row in data_df.iterrows(): 282 | smiles = row["smiles"] 283 | x_list.append(smiles) 284 | y = [] 285 | for col in range(self.start_column, len(row)): 286 | y.append(float(row.iloc[col])) 287 | y = torch.tensor(y, dtype=torch.float32) 288 | y_list.append(y) 289 | 290 | self.data = x_list 291 | self.labels = y_list 292 | 293 | def prepare_fingerprints(self): 294 | assert 
os.path.exists( 295 | self.raw_data 296 | ), f" {self.raw_data} assays.csv.gz does not exist" 297 | data_df = pd.read_csv(self.raw_data) 298 | 299 | processed_dir = osp.join(self.folder, "processed") 300 | os.makedirs(processed_dir, exist_ok=True) 301 | 302 | if not osp.exists(osp.join(processed_dir, "processed_fp.pt")): 303 | print("Processing fingerprints...") 304 | from rdkit import Chem 305 | from rdkit.Chem import AllChem 306 | 307 | x_list = [] 308 | y_list = [] 309 | for idx, row in data_df.iterrows(): 310 | smiles = row["smiles"] 311 | mol = Chem.MolFromSmiles(smiles) 312 | x = torch.tensor( 313 | list(AllChem.GetMorganFingerprintAsBitVect(mol, 2)), 314 | dtype=torch.float32, 315 | ) 316 | x_list.append(x) 317 | y = [] 318 | for col in range(self.start_column, len(row)): 319 | y.append(float(row.iloc[col])) 320 | y = torch.tensor(y, dtype=torch.float32) 321 | y_list.append(y) 322 | 323 | x_list = torch.stack(x_list, dim=0) 324 | y_list = torch.stack(y_list, dim=0) 325 | torch.save((x_list, y_list), osp.join(processed_dir, "processed_fp.pt")) 326 | else: 327 | x_list, y_list = torch.load(osp.join(processed_dir, "processed_fp.pt")) 328 | 329 | self.data = x_list 330 | self.labels = y_list 331 | 332 | def __getitem__(self, idx): 333 | """Get datapoint(s) with index(indices)""" 334 | 335 | if isinstance(idx, (int, np.integer)): 336 | return self.data[idx], self.labels[idx] 337 | elif isinstance(idx, (list, np.ndarray)): 338 | return [self.data[i] for i in idx], [self.labels[i] for i in idx] 339 | elif isinstance(idx, torch.LongTensor): 340 | return self.data[idx], self.labels[idx] 341 | 342 | raise IndexError("Not supported index {}.".format(type(idx).__name__)) 343 | 344 | def __len__(self): 345 | return len(self.data) 346 | 347 | def __repr__(self): 348 | return "{}({})".format(self.__class__.__name__, len(self)) 349 | 350 | 351 | if __name__ == "__main__": 352 | pass 353 | -------------------------------------------------------------------------------- 
class PretrainContextDataset(InMemoryDataset):
    """In-memory dataset holding the single pretraining context graph.

    The graph is built once with ``create_nx_graph`` from the files under
    ``<root>/<name>/raw``, cached there as a pickle, converted with
    ``from_networkx`` and stored as a collated one-element dataset.
    """

    def __init__(
        self, name="pretrain", root="raw_data", transform=None, pre_transform=None
    ):
        """
        - name (str): name of the pretraining dataset; dashes are replaced by
          underscores to form the folder name
        - root (str): root directory that contains the dataset folder
        - transform: accepted for signature compatibility; not forwarded to
          the base class
        - pre_transform: NOT a callable here. It is used as the edge threshold
          handed to ``create_nx_graph`` and is embedded in the cache file
          names; ``None`` selects the default threshold of 0.6 in ``process``
        """
        self.name = name
        self.dir_name = "_".join(name.split("-"))
        self.original_root = root
        self.root = osp.join(root, self.dir_name)
        self.processed_root = osp.abspath(self.root)
        self.threshold = pre_transform
        self.nxg_name = f"nxg_mint{pre_transform}"

        # The base class receives no transforms: ``pre_transform`` carries the
        # threshold and must not be called on graph objects.
        super(PretrainContextDataset, self).__init__(self.processed_root, None, None)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """File name of the processed graph; depends on the threshold."""
        # Exposed as a property, matching the other dataset classes in this
        # package and the base-class convention.
        if self.threshold is not None:
            return [f"context_{self.nxg_name}.pt"]
        return ["context_data_processed.pt"]

    def process(self):
        """Build (or reload) the networkx context graph and save it for PyG."""
        threshold = self.threshold if self.threshold is not None else 0.6
        folder = osp.join(self.root, "raw")
        nxg_name = f"{folder}/{self.nxg_name}.pickle"
        if not os.path.exists(nxg_name):
            G = create_nx_graph(
                folder, min_thres=threshold, top_compound_gene_express=0.01
            )
            with open(nxg_name, "wb") as f:
                pickle.dump(G, f)
        else:
            # NOTE: unpickling executes arbitrary code; this file is a local
            # cache written by the branch above and must stay trusted.
            with open(nxg_name, "rb") as f:
                G = pickle.load(f)

        pyg_graph = from_networkx(G)
        torch.save(self.collate([pyg_graph]), self.processed_paths[0])
class PretrainMoleculeDataset(InMemoryDataset):
    """In-memory dataset of molecule graphs used for pretraining.

    Graphs are produced by ``read_graph_list`` from
    ``<root>/<name>/raw/structure.csv.gz`` and stored as
    ``mol_data_processed.pt``.
    """

    def __init__(self, name='pretrain', root='raw_data', transform=None, pre_transform=None):
        '''
            - name (str): name of the pretraining dataset; dashes become
              underscores in the folder name
            - root (str): root directory that contains the dataset folder
            - transform, pre_transform (optional): transform/pre-transform graph objects
        '''
        self.name = name
        self.dir_name = name.replace('-', '_')
        self.original_root = root
        self.root = osp.join(root, self.dir_name)
        self.processed_root = osp.abspath(self.root)

        # Fixed task metadata read by the training code.
        self.num_tasks = 1
        self.eval_metric = 'customize'
        self.task_type = 'pretrain'
        self.__num_classes__ = '-1'
        self.binary = 'False'

        super().__init__(self.processed_root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        self.total_data_len = len(self)

    def get_idx_split(self):
        """Return index tensors: every sample is 'train'; 'valid'/'test' are empty."""
        assignment = {
            'train': list(range(self.total_data_len)),
            'valid': [],
            'test': [],
        }
        return {
            split: torch.tensor(indices, dtype=torch.long)
            for split, indices in assignment.items()
        }

    @property
    def processed_file_names(self):
        return ['mol_data_processed.pt']

    def process(self):
        """Read the structure table, build graphs, and save the collated result."""
        structure_csv = osp.join(self.root, 'raw', 'structure.csv.gz')
        print('Processing molecule data at folder: ', structure_csv)

        # One graph per unique mol_id.
        molecules = pd.read_csv(structure_csv, compression='gzip').drop_duplicates(subset="mol_id")
        graphs = read_graph_list(molecules, keep_id=True)

        self.total_data_len = len(graphs)
        print('Pretrain molecule data loading finished with length ', self.total_data_len)

        if self.pre_transform is not None:
            graphs = [self.pre_transform(graph) for graph in graphs]

        collated = self.collate(graphs)
        data, slices = collated
        torch.save((data, slices), self.processed_paths[0])
def download(self):
    """Check that both raw input files exist under ``self.raw_dir``.

    Nothing is fetched; the method only verifies ``self.target_file`` and
    then ``self.mol_file``.

    Raises:
        AssertionError: naming the first of the two files that is missing.
    """
    for file_name in (self.target_file, self.mol_file):
        assert os.path.exists(
            osp.join(self.raw_dir, file_name)
        ), f" {osp.join(self.raw_dir, file_name)} does not exist"
class RetrievalMoleculeDataset(object):
    """Retrieval dataset stored as plain lists/tensors (no graph objects).

    After construction the instance exposes three aligned containers:

    - ``data``: SMILES strings (``transform="smiles"``) or a float32 tensor
      of Morgan fingerprint bits (``transform="fingerprint"``);
    - ``target``: float32 tensor with one target representation per molecule;
    - ``inchikey_list``: the ``inchikey`` column of the molecule table.

    The result is cached in
    ``<root>/<name>/processed/ret_processed_<name>_<transform>.pt``.
    """

    def __init__(self, name="chembl2k", root="raw_data", transform="smiles"):
        """
        - name (str): ``"chembl2k"`` or ``"broad6k"``
        - root (str): directory that contains the ``<name>/raw`` folder
        - transform (str): ``"smiles"`` or ``"fingerprint"``; ``"pyg"`` passes
          the initial check but is rejected by ``prepare_data`` with
          ``ValueError``

        Raises ``AssertionError`` for an unknown transform and ``ValueError``
        for an unknown dataset name.
        """

        assert transform in [
            "smiles",
            "fingerprint",
            "pyg",
        ], "Invalid transform type"

        self.name = name
        self.folder = osp.join(root, name)
        self.transform = transform

        self.eval_metric = "rank"
        if name == "chembl2k":
            self.num_tasks = 41
            self.start_column = 4
            target_file = 'CP-JUMP_feature.npz'
            mol_file = 'CP-JUMP.csv.gz'
        elif name == "broad6k":
            self.num_tasks = 32
            self.start_column = 2
            target_file = 'CP-Bray_feature.npz'
            mol_file = 'CP-Bray.csv.gz'
        else:
            raise ValueError("Invalid dataset name")

        self.target_file = os.path.join(self.folder, "raw", target_file)
        self.mol_file = os.path.join(self.folder, "raw", mol_file)

        super(RetrievalMoleculeDataset, self).__init__()
        self.prepare_data()

    @staticmethod
    def _median_per_inchikey(mol_df, target_reps):
        """Collapse replicate rows: one row per inchikey with the median target."""
        for i in range(target_reps.shape[1]):
            mol_df['feature_' + str(i)] = target_reps[:, i]
        smiles_dict = (
            mol_df.drop_duplicates('inchikey').set_index('inchikey')['smiles'].to_dict()
        )
        feature_cols = [col for col in mol_df.columns if 'feature_' in col]
        df_subset = mol_df[['inchikey'] + feature_cols]
        df_subset = df_subset.groupby('inchikey').median().reset_index()
        df_subset['smiles'] = df_subset['inchikey'].map(smiles_dict)
        return df_subset.drop(columns=feature_cols), df_subset[feature_cols].values

    @staticmethod
    def _morgan_fingerprints(smiles_list):
        """Return a (num_molecules, num_bits) float32 tensor of Morgan bits (radius 2)."""
        from rdkit import Chem
        from rdkit.Chem import AllChem

        rows = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure with None; fail with
                # the offending string instead of an opaque RDKit error.
                raise ValueError(f"Could not parse SMILES: {smiles!r}")
            rows.append(
                torch.tensor(
                    list(AllChem.GetMorganFingerprintAsBitVect(mol, 2)),
                    dtype=torch.float32,
                )
            )
        return torch.stack(rows, dim=0)

    def prepare_data(self):
        """Populate ``data``/``target``/``inchikey_list`` from cache or raw files.

        Raises ``AssertionError`` when a raw file is missing and
        ``ValueError`` when ``self.transform`` has no in-memory form.
        """
        processed_dir = osp.join(self.folder, "processed")
        processed_file = osp.join(
            processed_dir, f"ret_processed_{self.name}_{self.transform}.pt"
        )
        if osp.exists(processed_file):
            self.data, self.target, self.inchikey_list = torch.load(processed_file)
            return

        assert os.path.exists(
            self.target_file
        ), f" {self.target_file} does not exist"
        # Close the archive once the array has been read out of it.
        with np.load(self.target_file) as archive:
            target_reps = archive['data']
        assert os.path.exists(
            self.mol_file
        ), f" {self.mol_file} does not exist"
        mol_df = pd.read_csv(self.mol_file)

        if self.name == 'broad6k':
            mol_df, target_reps = self._median_per_inchikey(mol_df, target_reps)

        mol_smiles_list = mol_df['smiles'].tolist()
        inchikey_list = mol_df['inchikey'].tolist()
        target_reps = torch.tensor(target_reps, dtype=torch.float32)

        if self.transform == "smiles":
            data = mol_smiles_list
        elif self.transform == "fingerprint":
            data = self._morgan_fingerprints(mol_smiles_list)
        else:
            raise ValueError("Invalid transform type")

        self.data = data
        self.target = target_reps
        self.inchikey_list = inchikey_list

        # The processed folder may not exist yet on a fresh checkout.
        os.makedirs(processed_dir, exist_ok=True)
        torch.save((data, target_reps, inchikey_list), processed_file)

    def __getitem__(self, idx):
        """Get datapoint(s) with a slice or a single integer index.

        Integer support also makes plain ``for item in dataset`` iteration
        work: an out-of-range integer raises ``IndexError`` from the
        underlying container, which ends the iteration.
        """
        if isinstance(idx, (slice, int, np.integer)):
            return self.data[idx], self.target[idx], self.inchikey_list[idx]

        raise IndexError("Not supported index {}.".format(type(idx).__name__))

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, len(self))
def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    """Create a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier is:

    - ``step / max(1, num_warmup_steps)`` while ``step < num_warmup_steps``;
    - ``max(0, cos(pi * num_cycles * progress))`` afterwards, where
      ``progress`` is the fraction of the post-warmup steps completed
      (it exceeds 1 if stepping continues past ``num_training_steps``).

    Args:
        optimizer: optimizer whose learning rates are scheduled.
        num_warmup_steps: number of warmup steps.
        num_training_steps: total number of steps, warmup included.
        num_cycles: scale applied to ``pi * progress`` inside the cosine.
        last_epoch: forwarded to ``LambdaLR``.

    Returns:
        The configured ``LambdaLR`` scheduler.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _lr_lambda(current_step):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            return float(current_step) / warmup_span

        no_progress = float(current_step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * num_cycles * no_progress)
        # Clamp so the multiplier never becomes negative.
        return max(0, cosine)

    return LambdaLR(optimizer, _lr_lambda, last_epoch)
prior_sigma = torch.ones(args.emb_dim).to(device) 115 | args.prior_dist = Independent(Normal(loc=prior_mu, scale=prior_sigma), 1) 116 | 117 | elif args.dataset.startswith("finetune"): 118 | from models.gnn import FineTuneGNN 119 | 120 | valid_loader = DataLoader( 121 | dataset[split_idx["valid"]], 122 | batch_size=args.batch_size, 123 | shuffle=False, 124 | num_workers=args.num_workers, 125 | ) 126 | test_loader = DataLoader( 127 | dataset[split_idx["test"]], 128 | batch_size=args.batch_size, 129 | shuffle=False, 130 | num_workers=args.num_workers, 131 | ) 132 | model = FineTuneGNN( 133 | gnn_type=args.model, 134 | num_tasks=dataset.num_tasks, 135 | num_layer=args.num_layer, 136 | emb_dim=args.emb_dim, 137 | drop_ratio=args.drop_ratio, 138 | graph_pooling=args.readout, 139 | norm_layer=args.norm_layer, 140 | ).to(device) 141 | model.load_pretrained_graph_encoder(args.model_path) 142 | model.freeze_graph_encoder() 143 | optimizer = optim.Adam( 144 | model.task_decoder.parameters(), lr=args.lr, weight_decay=args.wdecay 145 | ) 146 | else: 147 | raise ValueError("Invalid dataset name") 148 | 149 | # scheduler = None 150 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epochs * args.steps) 151 | 152 | logging.warning(f"device: {args.device}, " f"n_gpu: {args.n_gpu}, ") 153 | logger.info(dict(args._get_kwargs())) 154 | logger.info(model) 155 | logger.info("***** Running training *****") 156 | logger.info( 157 | f" Task = {args.dataset}@{args.num_trained}/{len(split_idx['valid'])}/{len(split_idx['test'])}" 158 | ) 159 | logger.info(f" Num Epochs = {args.epochs}") 160 | logger.info(f" Total train batch size = {args.batch_size}") 161 | logger.info(f" Total optimization steps = {args.epochs * args.steps}") 162 | 163 | train_loaders = {"train_iter": iter(train_loader), "train_loader": train_loader} 164 | 165 | best_train, best_valid, best_test, best_count = None, None, None, None 166 | best_epoch = 0 167 | loss_tots = [] 168 | if args.dataset == "pretrain": # 
later args.finetune 169 | for epoch in range(0, args.epochs): 170 | loss, train_loaders = pretrain_func( 171 | args, model, train_loaders, context_graph, optimizer, scheduler, epoch 172 | ) 173 | loss_tots.append(loss) 174 | if epoch == args.epochs - 1: 175 | torch.save(model.state_dict(), args.model_path) 176 | yaml_path = args.model_path.replace(".pt", ".yaml") 177 | save_arguments_to_yaml(args, yaml_path, model_only=True) 178 | logger.info( 179 | f"Finished Training \n Model saved at {args.model_path} and Arguments saved at {yaml_path} with loss {loss_tots}" 180 | ) 181 | 182 | elif args.dataset.startswith("finetune"): 183 | args.task_type = ( 184 | "regression" if "mae" in dataset.eval_metric else "classification" 185 | ) 186 | best_params = None 187 | for epoch in range(0, args.epochs): 188 | train_loaders = finetune_func( 189 | args, model, train_loaders, optimizer, scheduler, epoch 190 | ) 191 | valid_perf = validate(args, model, valid_loader) 192 | 193 | if epoch > 0: 194 | is_improved = ( 195 | valid_perf[dataset.eval_metric] < best_valid 196 | if args.task_type == "regression" 197 | else valid_perf[dataset.eval_metric] > best_valid 198 | ) 199 | if epoch == 0 or is_improved: 200 | train_perf = validate(args, model, train_loader) 201 | test_perf = validate(args, model, test_loader) 202 | best_params = parameters_to_vector(model.parameters()) 203 | best_valid = valid_perf[dataset.eval_metric] 204 | best_test = test_perf[dataset.eval_metric] 205 | best_train = train_perf[dataset.eval_metric] 206 | best_epoch = epoch 207 | best_count = test_perf.get("count", None) 208 | if best_count is None: 209 | best_count = test_perf.get("mae_list", None) 210 | if not args.no_print: 211 | logger.info( 212 | "Update Epoch {}: best_train: {:.4f} best_valid: {:.4f}, best_test: {:.4f}".format( 213 | epoch, best_train, best_valid, best_test 214 | ) 215 | ) 216 | if best_count is not None and args.task_type == "classification": 217 | outstr = "Best Count: " 218 | for key, value 
in best_count.items(): 219 | sum_num = int(np.nansum(value)) 220 | nan_num = sum(np.isnan(value)) 221 | outstr += f"{key}: {sum_num/len(value):.4f} (nan {sum(np.isnan(value))} / {len(value)}), " 222 | logger.info(outstr) 223 | else: 224 | if not args.no_print: 225 | logger.info( 226 | "Epoch {}: best_valid: {:.4f}, current_valid: {:.4f}, patience: {}/{}".format( 227 | epoch, 228 | best_valid, 229 | valid_perf[dataset.eval_metric], 230 | epoch - best_epoch, 231 | args.patience, 232 | ) 233 | ) 234 | if epoch - best_epoch > args.patience: 235 | break 236 | 237 | logger.info( 238 | "Finished. \n {}-{} Best validation epoch {} with metric {}, train {:.4f}, valid {:.4f}, test {:.4f}".format( 239 | args.dataset, args.pretrain_name, best_epoch, dataset.eval_metric, best_train, best_valid, best_test 240 | ) 241 | ) 242 | vector_to_parameters(best_params, model.parameters()) 243 | save_prediction(model, device, test_loader, dataset, args.output_dir, seed) 244 | 245 | return ( 246 | args.pretrain_name, 247 | args.dataset, 248 | dataset.eval_metric, 249 | best_train, 250 | best_valid, 251 | best_test, 252 | best_epoch, 253 | best_count, 254 | ) 255 | 256 | 257 | if __name__ == "__main__": 258 | import os 259 | import pandas as pd 260 | 261 | args = get_args() 262 | log_path = args.model_path.replace(".pt", ".log") 263 | 264 | pretrain_name = args.model_path.split("/")[-1] 265 | pretrain_name = pretrain_name.split(".")[0] 266 | args.pretrain_name = pretrain_name 267 | 268 | if args.dataset.startswith("finetune"): 269 | args.output_dir = f"output/{args.dataset}/{pretrain_name}" 270 | yaml_path = args.model_path.replace(".pt", ".yaml") 271 | 272 | # Check if args.model_path exists 273 | if not os.path.exists(args.model_path): 274 | from huggingface_hub import hf_hub_download 275 | os.makedirs(os.path.dirname(args.model_path), exist_ok=True) 276 | hf_hub_download(repo_id="liuganghuggingface/InfoAlign-Pretrained", 277 | filename="pretrain.pt", 278 | 
local_dir=os.path.dirname(args.model_path), 279 | local_dir_use_symlinks=False) 280 | config_path = hf_hub_download(repo_id="liuganghuggingface/InfoAlign-Pretrained", 281 | filename="config.yaml", 282 | local_dir=os.path.dirname(yaml_path), 283 | local_dir_use_symlinks=False) 284 | 285 | # Rename the downloaded config file to pretrain.yaml 286 | new_yaml_path = os.path.join(os.path.dirname(yaml_path), "pretrain.yaml") 287 | os.rename(config_path, new_yaml_path) 288 | 289 | print('args.model_path', args.model_path) 290 | print('args.yaml_path', yaml_path) 291 | 292 | config = load_arguments_from_yaml(yaml_path, model_only=True) 293 | for arg, value in config.items(): 294 | setattr(args, arg, value) 295 | log_path = log_path + ".finetune" 296 | else: 297 | log_path = log_path + ".pretrain" 298 | 299 | # Define the repository ID and local directory 300 | repo_id = "liuganghuggingface/InfoAlign-Data" 301 | local_dir = "raw_data/pretrain/raw" 302 | 303 | # Check if the local directory exists 304 | if not os.path.exists(local_dir): 305 | from huggingface_hub import hf_hub_download, HfApi 306 | import os 307 | 308 | os.makedirs(local_dir, exist_ok=True) 309 | 310 | try: 311 | # Use HfApi to list files 312 | api = HfApi() 313 | all_files = api.list_repo_files(repo_id, repo_type="dataset") 314 | 315 | # Filter for files in the pretrain_raw folder 316 | pretrain_raw_files = [f for f in all_files if f.startswith("pretrain_raw/")] 317 | 318 | # Download each file 319 | for file in pretrain_raw_files: 320 | # Extract the filename from the path 321 | filename = os.path.basename(file) 322 | 323 | hf_hub_download(repo_id=repo_id, 324 | filename=file, 325 | repo_type="dataset", 326 | local_dir=local_dir, 327 | local_dir_use_symlinks=False) 328 | 329 | # Rename the file to remove the 'pretrain_raw/' prefix 330 | old_path = os.path.join(local_dir, file) 331 | new_path = os.path.join(local_dir, filename) 332 | os.rename(old_path, new_path) 333 | 334 | print(f"Downloaded 
{len(pretrain_raw_files)} files to {local_dir}") 335 | except Exception as e: 336 | print(f"Error downloading dataset: {str(e)}") 337 | print("Please check your internet connection and ensure you have the necessary permissions.") 338 | print("If the issue persists, you may need to log in using `huggingface-cli login`") 339 | else: 340 | print(f"Directory {local_dir} already exists. Skipping download.") 341 | 342 | # logger = get_logger(__name__, logfile=log_path) 343 | logger = get_logger(__name__) 344 | args.logger = logger 345 | print(vars(args)) 346 | 347 | if args.dataset.startswith("pretrain"): 348 | main(args, 0) 349 | else: 350 | df = pd.DataFrame() 351 | for i in range(5): 352 | model, dataset, metric, train, valid, test, epoch, count = main(args, i) 353 | if "auc" in metric: 354 | new_results = { 355 | "model": model, 356 | "dataset": dataset, 357 | "seed": i, 358 | "metric": metric, 359 | "train": train, 360 | "valid": valid, 361 | "test": test, 362 | "epoch": epoch, 363 | "suc_80": round(np.nansum(count[80]) / len(count[80]), 4), 364 | "suc_85": round(np.nansum(count[85]) / len(count[85]), 4), 365 | "suc_90": round(np.nansum(count[90]) / len(count[90]), 4), 366 | "suc_95": round(np.nansum(count[95]) / len(count[95]), 4), 367 | "thr_80": count[80], 368 | "thr_85": count[85], 369 | "thr_90": count[90], 370 | "thr_95": count[95], 371 | } 372 | else: 373 | mae_list = count 374 | new_results = { 375 | "model": model, 376 | "dataset": dataset, 377 | "seed": i, 378 | "metric": metric, 379 | "train": train, 380 | "valid": valid, 381 | "test": test, 382 | "epoch": epoch, 383 | "mae_1": mae_list[0], 384 | "mae_2": mae_list[1], 385 | "mae_3": mae_list[2], 386 | "mae_4": mae_list[3], 387 | "mae_5": mae_list[4], 388 | "mae_6": mae_list[5], 389 | } 390 | df = pd.concat([df, pd.DataFrame([new_results])], ignore_index=True) 391 | 392 | summary_each = f"output/{args.dataset}/summary_each.csv" 393 | if os.path.exists(summary_each): 394 | df.to_csv(summary_each, mode="a", 
header=False, index=False) 395 | else: 396 | df.to_csv(summary_each, index=False) 397 | print(df) 398 | 399 | # Calculate mean and std 400 | if "auc" in metric: 401 | cols = [ 402 | "model", 403 | "dataset", 404 | "metric", 405 | "train", 406 | "valid", 407 | "test", 408 | "suc_80", 409 | "suc_85", 410 | "suc_90", 411 | "suc_95", 412 | ] 413 | else: 414 | cols = [ 415 | "model", 416 | "dataset", 417 | "metric", 418 | "train", 419 | "valid", 420 | "test", 421 | "mae_1", 422 | "mae_2", 423 | "mae_3", 424 | "mae_4", 425 | "mae_5", 426 | "mae_6", 427 | ] 428 | df_mean = df[cols].groupby(["model", "dataset", "metric"]).mean().round(4) 429 | df_std = df[cols].groupby(["model", "dataset", "metric"]).std().round(4) 430 | 431 | df_mean = df_mean.reset_index() 432 | df_std = df_std.reset_index() 433 | df_summary = df_mean[["model", "dataset", "metric"]].copy() 434 | if "auc" in metric: 435 | for col in [ 436 | "train", 437 | "valid", 438 | "test", 439 | "suc_80", 440 | "suc_85", 441 | "suc_90", 442 | "suc_95", 443 | ]: 444 | df_summary[col] = ( 445 | df_mean[col].astype(str) + "±" + df_std[col].astype(str) 446 | ) 447 | else: 448 | for col in [ 449 | "train", 450 | "valid", 451 | "test", 452 | "mae_1", 453 | "mae_2", 454 | "mae_3", 455 | "mae_4", 456 | "mae_5", 457 | "mae_6", 458 | ]: 459 | df_summary[col] = ( 460 | df_mean[col].astype(str) + "±" + df_std[col].astype(str) 461 | ) 462 | 463 | summary_all = f"output/{args.dataset}/summary_all.csv" 464 | if os.path.exists(summary_all): 465 | df_summary.to_csv(summary_all, mode="a", header=False, index=False) 466 | else: 467 | df_summary.to_csv(summary_all, index=False) 468 | print(df_summary) 469 | -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.utils import degree 4 | from torch_geometric.nn.norm import GraphNorm, 
class GINConv(MessagePassing):
    """GIN-style convolution with sum aggregation and encoded edge attributes."""

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super().__init__(aggr="add")

        hidden_dim = 2 * emb_dim
        mlp_layers = [
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        ]
        self.mlp = torch.nn.Sequential(*mlp_layers)
        # Learnable weight on the node's own features, initialised to zero.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return mlp((1 + eps) * x + summed neighbour messages)."""
        edge_embedding = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        return self.mlp(self_term + neighbour_term)

    # NOTE: the parameter names of message/update are part of the
    # MessagePassing contract and must not be renamed.
    def message(self, x_j, edge_attr):
        """Message on one edge: ReLU of source features plus edge embedding."""
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out
### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Stack of message-passing layers that turns atom features into node
    embeddings.

    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_name = 'gin', norm_layer = 'batch_norm'):
        '''
            num_layer (int): number of GNN message passing layers (at least 2)
            emb_dim (int): node embedding dimensionality
            drop_ratio (float): dropout probability applied after every layer
            JK (str): "last" or "sum"; how per-layer outputs are combined
            residual (bool): add each layer's input to its output
            gnn_name (str): 'gin' or 'gcn'
            norm_layer (str): '<kind>' or '<kind>_<suffix>...'; kind is one of
                batch, instance, layer, graph, size, pair, group. A final
                'notrack' suffix (batch only) disables running statistics and
                affine parameters; a second component equal to 'size' adds an
                extra GraphSizeNorm after the main normalisation.
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual
        self.norm_layer = norm_layer

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)
        # Not used in forward (each conv owns its bond encoder); kept so the
        # module's registered parameters stay the same.
        self.bond_encoder = BondEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_name == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_name == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_name))

            self.batch_norms.append(self._build_norm(norm_layer, emb_dim))

        if self._has_size_suffix(norm_layer):
            self.graph_size_norm = GraphSizeNorm()

    @staticmethod
    def _has_size_suffix(norm_layer):
        """True when the second '_'-separated component of the name is 'size'."""
        parts = norm_layer.split('_')
        # A name without any underscore (e.g. 'batch') simply has no suffix.
        return len(parts) > 1 and parts[1] == 'size'

    @staticmethod
    def _build_norm(norm_layer, emb_dim):
        """Create one normalisation module for the given norm_layer name."""
        parts = norm_layer.split('_')
        kind = parts[0]
        if kind == 'batch':
            if parts[-1] == 'notrack':
                return torch.nn.BatchNorm1d(emb_dim, track_running_stats=False, affine=False)
            return torch.nn.BatchNorm1d(emb_dim)
        if kind == 'instance':
            return InstanceNorm(emb_dim)
        if kind == 'layer':
            return LayerNorm(emb_dim)
        if kind == 'graph':
            return GraphNorm(emb_dim)
        if kind == 'size':
            return GraphSizeNorm()
        if kind == 'pair':
            return PairNorm(emb_dim)
        if kind == 'group':
            return DiffGroupNorm(emb_dim, groups=4)
        raise ValueError('Undefined normalization layer called {}'.format(norm_layer))

    def forward(self, batched_data):
        """Return (node_representation, h_list).

        h_list holds the encoded input followed by the output of every layer;
        node_representation is its last entry (JK == "last") or the sum of
        all entries (JK == "sum"). Raises ValueError for any other JK.
        """
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        # Derived from self.norm_layer on every call instead of cached, so no
        # new instance attribute is required.
        is_batch_norm = self.norm_layer.split('_')[0] == 'batch'
        use_size_norm = self._has_size_suffix(self.norm_layer)

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            if is_batch_norm:
                h = self.batch_norms[layer](h)
            else:
                # The non-batch norms are called with the graph assignment too.
                h = self.batch_norms[layer](h, batch)
            if use_size_norm:
                h = self.graph_size_norm(h, batch)

            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.relu(h)
                h = F.dropout(h, self.drop_ratio, training = self.training)
            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]
        else:
            # Fail with a clear message rather than an unbound local below.
            raise ValueError('Undefined JK mode called {}'.format(self.JK))

        return node_representation, h_list
169 | """ 170 | def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_name = 'gin', norm_layer = 'batch_norm'): 171 | ''' 172 | emb_dim (int): node embedding dimensionality 173 | ''' 174 | 175 | super(GNN_node_Virtualnode, self).__init__() 176 | self.num_layer = num_layer 177 | self.drop_ratio = drop_ratio 178 | self.JK = JK 179 | ### add residual connection or not 180 | self.residual = residual 181 | self.norm_layer = norm_layer 182 | 183 | if self.num_layer < 2: 184 | raise ValueError("Number of GNN layers must be greater than 1.") 185 | 186 | self.atom_encoder = AtomEncoder(emb_dim) 187 | self.bond_encoder = BondEncoder(emb_dim) 188 | 189 | ### set the initial virtual node embedding to 0. 190 | self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim) 191 | torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) 192 | 193 | ### List of GNNs 194 | self.convs = torch.nn.ModuleList() 195 | ### batch norms applied to node embeddings 196 | self.batch_norms = torch.nn.ModuleList() 197 | 198 | ### List of MLPs to transform virtual node at every layer 199 | self.mlp_virtualnode_list = torch.nn.ModuleList() 200 | 201 | for layer in range(num_layer): 202 | if gnn_name == 'gin': 203 | self.convs.append(GINConv(emb_dim)) 204 | elif gnn_name == 'gcn': 205 | self.convs.append(GCNConv(emb_dim)) 206 | else: 207 | raise ValueError('Undefined GNN type called {}'.format(gnn_name)) 208 | 209 | if norm_layer.split('_')[0] == 'batch': 210 | if norm_layer.split('_')[-1] == 'notrack': 211 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim, track_running_stats=False, affine=False)) 212 | else: 213 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 214 | elif norm_layer.split('_')[0] == 'instance': 215 | self.batch_norms.append(InstanceNorm(emb_dim)) 216 | elif norm_layer.split('_')[0] == 'layer': 217 | self.batch_norms.append(LayerNorm(emb_dim)) 218 | elif norm_layer.split('_')[0] == 'graph': 219 | 
self.batch_norms.append(GraphNorm(emb_dim)) 220 | elif norm_layer.split('_')[0] == 'size': 221 | self.batch_norms.append(GraphSizeNorm()) 222 | elif norm_layer.split('_')[0] == 'pair': 223 | self.batch_norms.append(PairNorm(emb_dim)) 224 | elif norm_layer.split('_')[0] == 'group': 225 | self.batch_norms.append(DiffGroupNorm(emb_dim, groups=4)) 226 | else: 227 | raise ValueError('Undefined normalization layer called {}'.format(norm_layer)) 228 | if norm_layer.split('_')[1] == 'size': 229 | self.graph_size_norm = GraphSizeNorm() 230 | for layer in range(num_layer - 1): 231 | self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \ 232 | torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU())) 233 | 234 | 235 | def forward(self, batched_data): 236 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 237 | ### virtual node embeddings for graphs 238 | virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) 239 | 240 | h_list = [self.atom_encoder(x)] 241 | for layer in range(self.num_layer): 242 | ### add message from virtual nodes to graph nodes 243 | h_list[layer] = h_list[layer] + virtualnode_embedding[batch] 244 | ### Message passing among graph nodes 245 | h = self.convs[layer](h_list[layer], edge_index, edge_attr) 246 | if self.norm_layer.split('_')[0] == 'batch': 247 | h = self.batch_norms[layer](h) 248 | else: 249 | h = self.batch_norms[layer](h, batch) 250 | if self.norm_layer.split('_')[1] == 'size': 251 | h = self.graph_size_norm(h, batch) 252 | 253 | if layer == self.num_layer - 1: 254 | h = F.dropout(h, self.drop_ratio, training = self.training) 255 | else: 256 | h = F.relu(h) 257 | h = F.dropout(h, self.drop_ratio, training = self.training) 258 | 259 | if self.residual: 260 | h = h + h_list[layer] 261 | 262 
| h_list.append(h) 263 | 264 | ### update the virtual nodes 265 | if layer < self.num_layer - 1: 266 | ### add message from graph nodes to virtual nodes 267 | virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding 268 | if self.residual: 269 | virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) 270 | else: 271 | virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) 272 | 273 | ### Different implementations of Jk-concat 274 | if self.JK == "last": 275 | node_representation = h_list[-1] 276 | elif self.JK == "sum": 277 | node_representation = 0 278 | for layer in range(self.num_layer + 1): 279 | node_representation += h_list[layer] 280 | 281 | return node_representation, h_list 282 | 283 | 284 | class AtomEncoder(torch.nn.Module): 285 | 286 | def __init__(self, emb_dim): 287 | super(AtomEncoder, self).__init__() 288 | 289 | self.atom_embedding_list = torch.nn.ModuleList() 290 | 291 | for i, dim in enumerate(full_atom_feature_dims): 292 | emb = torch.nn.Embedding(dim, emb_dim, max_norm=1) 293 | torch.nn.init.xavier_uniform_(emb.weight.data) 294 | self.atom_embedding_list.append(emb) 295 | 296 | def forward(self, x): 297 | x_embedding = 0 298 | for i in range(x.shape[1]): 299 | x_embedding += self.atom_embedding_list[i](x[:,i]) 300 | 301 | return x_embedding 302 | 303 | 304 | class BondEncoder(torch.nn.Module): 305 | 306 | def __init__(self, emb_dim): 307 | super(BondEncoder, self).__init__() 308 | 309 | self.bond_embedding_list = torch.nn.ModuleList() 310 | 311 | for i, dim in enumerate(full_bond_feature_dims): 312 | emb = torch.nn.Embedding(dim, emb_dim, max_norm=1) 313 | torch.nn.init.xavier_uniform_(emb.weight.data) 314 | self.bond_embedding_list.append(emb) 315 | 316 | def forward(self, edge_attr): 317 | bond_embedding = 0 318 | for i 
in range(edge_attr.shape[1]): 319 | bond_embedding += self.bond_embedding_list[i](edge_attr[:,i]) 320 | 321 | return bond_embedding -------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 5 | from .conv import GNN_node, GNN_node_Virtualnode 6 | 7 | from torch.distributions import Normal, Independent 8 | from torch.nn.functional import softplus 9 | 10 | class GNN(torch.nn.Module): 11 | def __init__( 12 | self, 13 | num_tasks=None, # to remove 14 | num_layer=5, 15 | emb_dim=300, 16 | gnn_type="gin", 17 | drop_ratio=0.5, 18 | graph_pooling="max", 19 | norm_layer="batch_norm", 20 | decoder_dims=[1024, 1111, 862, 1783, 966, 978], 21 | # mol, gene (gc, go), cell (bray, jump), express 22 | ): 23 | super(GNN, self).__init__() 24 | self.num_layer = num_layer 25 | self.drop_ratio = drop_ratio 26 | self.graph_pooling = graph_pooling 27 | if self.num_layer < 2: 28 | raise ValueError("Number of GNN layers must be greater than 1.") 29 | ### GNN to generate node embeddings 30 | gnn_name = gnn_type.split("-")[0] 31 | if "virtual" in gnn_type: 32 | self.graph_encoder = GNN_node_Virtualnode( 33 | num_layer, 34 | emb_dim , 35 | JK="last", 36 | drop_ratio=drop_ratio, 37 | residual=True, 38 | gnn_name=gnn_name, 39 | norm_layer=norm_layer, 40 | ) 41 | else: 42 | self.graph_encoder = GNN_node( 43 | num_layer, 44 | emb_dim, 45 | JK="last", 46 | drop_ratio=drop_ratio, 47 | residual=True, 48 | gnn_name=gnn_name, 49 | norm_layer=norm_layer, 50 | ) 51 | 52 | ### Poolinwg function to generate whole-graph embeddings 53 | if graph_pooling == "sum": 54 | self.pool = global_add_pool 55 | elif graph_pooling == "mean": 56 | self.pool = global_mean_pool 57 | elif graph_pooling == "max": 58 | self.pool = global_max_pool 59 | else: 60 | raise 
ValueError("Invalid graph pooling type.") 61 | 62 | self.dist_net = nn.Sequential( 63 | nn.SiLU(), 64 | nn.Linear(emb_dim, 2 * emb_dim, bias=True) 65 | ) 66 | 67 | self.decoder_list = nn.ModuleList() 68 | for out_dim in decoder_dims: 69 | self.decoder_list.append(MLP(emb_dim, hidden_features=emb_dim * 4, out_features=out_dim)) 70 | 71 | 72 | def forward(self, batched_data): 73 | h_node, _ = self.graph_encoder(batched_data) 74 | h_graph = self.pool(h_node, batched_data.batch) 75 | 76 | mu, sigma = self.dist_net(h_graph).chunk(2, dim=1) 77 | sigma = softplus(sigma) + 1e-7 78 | p_z_given_x = Independent(Normal(loc=mu, scale=sigma), 1) 79 | 80 | out = [] 81 | out.append(p_z_given_x) 82 | # p, mol, gene (gc, go), cell (bray, jump), express 83 | for decoder in self.decoder_list: 84 | out.append(decoder(mu)) 85 | out_gene = torch.cat((out[2], out[3]), dim=1) 86 | out_cell = torch.cat((out[4], out[5]), dim=1) 87 | return [out[0], out[1], out_gene, out_cell, out[6]] 88 | 89 | # define a new finetune model with the same architecture of GNN with a new MLP 90 | 91 | class FineTuneGNN(nn.Module): 92 | def __init__( 93 | self, 94 | num_tasks=None, 95 | num_layer=5, 96 | emb_dim=300, 97 | gnn_type="gin", 98 | drop_ratio=0.5, 99 | graph_pooling="max", 100 | norm_layer="batch_norm", 101 | ): 102 | super(FineTuneGNN, self).__init__() 103 | 104 | ### GNN to generate node embeddings 105 | gnn_name = gnn_type.split("-")[0] 106 | if "virtual" in gnn_type: 107 | self.graph_encoder = GNN_node_Virtualnode( 108 | num_layer, 109 | emb_dim, 110 | JK="last", 111 | drop_ratio=drop_ratio, 112 | residual=True, 113 | gnn_name=gnn_name, 114 | norm_layer=norm_layer, 115 | ) 116 | else: 117 | self.graph_encoder = GNN_node( 118 | num_layer, 119 | emb_dim, 120 | JK="last", 121 | drop_ratio=drop_ratio, 122 | residual=True, 123 | gnn_name=gnn_name, 124 | norm_layer=norm_layer, 125 | ) 126 | ### Poolinwg function to generate whole-graph embeddings 127 | if graph_pooling == "sum": 128 | self.pool = 
global_add_pool 129 | elif graph_pooling == "mean": 130 | self.pool = global_mean_pool 131 | elif graph_pooling == "max": 132 | self.pool = global_max_pool 133 | else: 134 | raise ValueError("Invalid graph pooling type.") 135 | 136 | self.dist_net = nn.Sequential( 137 | nn.SiLU(), 138 | nn.Linear(emb_dim, 2 * emb_dim, bias=True) 139 | ) 140 | 141 | # self.task_decoder = nn.Linear(emb_dim, num_tasks) 142 | self.task_decoder = MLP(emb_dim, hidden_features=4 * emb_dim, out_features=num_tasks) 143 | 144 | def forward(self, batched_data): 145 | h_node, _ = self.graph_encoder(batched_data) 146 | h_graph = self.pool(h_node, batched_data.batch) 147 | 148 | mu, _ = self.dist_net(h_graph).chunk(2, dim=1) 149 | task_out = self.task_decoder(mu) 150 | return task_out 151 | 152 | def load_pretrained_graph_encoder(self, model_path): 153 | saved_state_dict = torch.load(model_path, map_location=torch.device('cpu')) 154 | graph_encoder_state_dict = {key: value for key, value in saved_state_dict.items() if key.startswith('graph_encoder.')} 155 | graph_encoder_state_dict = {key.replace('graph_encoder.', ''): value for key, value in graph_encoder_state_dict.items()} 156 | self.graph_encoder.load_state_dict(graph_encoder_state_dict) 157 | # Load dist_net state dictionary 158 | dist_net_state_dict = {key: value for key, value in saved_state_dict.items() if key.startswith('dist_net.')} 159 | dist_net_state_dict = {key.replace('dist_net.', ''): value for key, value in dist_net_state_dict.items()} 160 | self.dist_net.load_state_dict(dist_net_state_dict) 161 | self.freeze_graph_encoder() 162 | 163 | def freeze_graph_encoder(self): 164 | for param in self.graph_encoder.parameters(): 165 | param.requires_grad = False 166 | for param in self.dist_net.parameters(): 167 | param.requires_grad = False 168 | 169 | class MLP(nn.Module): 170 | def __init__( 171 | self, 172 | in_features, 173 | hidden_features=None, 174 | out_features=None, 175 | act_layer=nn.GELU, 176 | bias=True, 177 | drop=0.5, 178 
| ): 179 | super().__init__() 180 | out_features = out_features or in_features 181 | hidden_features = hidden_features or in_features 182 | linear_layer = nn.Linear 183 | 184 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias) 185 | self.bn1 = nn.BatchNorm1d(hidden_features) 186 | self.act = act_layer() 187 | self.drop1 = nn.Dropout(drop) 188 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias) 189 | self.drop2 = nn.Dropout(drop) 190 | 191 | def forward(self, x): 192 | x = self.fc1(x) 193 | x = self.bn1(x) 194 | x = self.act(x) 195 | x = self.drop1(x) 196 | x = self.fc2(x) 197 | # x = self.drop2(x) 198 | return x -------------------------------------------------------------------------------- /raw_data/biogenadme/raw/assays.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/biogenadme/raw/assays.csv.gz -------------------------------------------------------------------------------- /raw_data/broad6k/raw/CP-Bray.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/CP-Bray.csv.gz -------------------------------------------------------------------------------- /raw_data/broad6k/raw/GE.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/GE.csv.gz -------------------------------------------------------------------------------- /raw_data/broad6k/raw/assays.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/assays.csv.gz 
-------------------------------------------------------------------------------- /raw_data/broad6k/raw/structure.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/broad6k/raw/structure.csv.gz -------------------------------------------------------------------------------- /raw_data/chembl2k/raw/CP-JUMP.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/CP-JUMP.csv.gz -------------------------------------------------------------------------------- /raw_data/chembl2k/raw/GE.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/GE.csv.gz -------------------------------------------------------------------------------- /raw_data/chembl2k/raw/assays.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/assays.csv.gz -------------------------------------------------------------------------------- /raw_data/chembl2k/raw/structure.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liugangcode/InfoAlign/eed96f36503a8a427fac4f382ee2001f957dd42a/raw_data/chembl2k/raw/structure.csv.gz -------------------------------------------------------------------------------- /raw_data/moltoxcast/raw/assays.csv.gz: -------------------------------------------------------------------------------- 
# ---------------------------------------------------------------------------
# utils/misc.py -- evaluation, prediction export and small training helpers
# ---------------------------------------------------------------------------

__all__ = ["validate", "init_weights", "AverageMeter", "save_prediction"]


def eval_func(pred, true, reduction=True):
    """Score predictions column by column.

    The task type is inferred from the labels: if every non-NaN value of
    `true` is 0 or 1 the columns are scored with ROC-AUC, otherwise with mean
    absolute error. NaN labels are ignored within each column.

    Returns:
        classification, reduction=True:
            {"roc_auc": mean score, "count": {threshold: [0/1/nan per column]}}
        classification, reduction=False:
            {"roc_auc": [score per evaluable column],
             "task_index": [column index of each of those scores]}
        regression (reduction is ignored):
            {"avg_mae": mean, "mae_list": [per column]}

    Raises:
        RuntimeError: classification labels where no column contains both a
            0 and a 1.
    """
    unique_values = np.unique(true[~np.isnan(true)])
    is_classification = set(unique_values).issubset({0, 1})

    if is_classification:
        thresholds = (80, 85, 90, 95)
        count_dict = {threshold: [] for threshold in thresholds}
        rocauc_list = []
        # Columns without both classes are skipped, so positions in
        # rocauc_list do not line up with column indices; task_index records
        # which column every score belongs to.
        task_index = []
        for i in range(true.shape[1]):
            column = true[:, i]
            if np.sum(column == 1) > 0 and np.sum(column == 0) > 0:
                is_labeled = column == column  # False exactly where the label is NaN
                score = roc_auc_score(column[is_labeled], pred[is_labeled, i])
                for threshold in thresholds:
                    count_dict[threshold].append(int(score >= threshold / 100))
                rocauc_list.append(score)
                task_index.append(i)
            else:
                for threshold in thresholds:
                    count_dict[threshold].append(np.nan)

        if len(rocauc_list) == 0:
            raise RuntimeError(
                "No positively labeled data available. Cannot compute ROC-AUC."
            )
        if reduction:
            return {"roc_auc": sum(rocauc_list) / len(rocauc_list), "count": count_dict}
        return {"roc_auc": rocauc_list, "task_index": task_index}

    mae_list = []
    for i in range(true.shape[1]):
        is_labeled = ~np.isnan(true[:, i])
        mae_list.append(mean_absolute_error(true[is_labeled, i], pred[is_labeled, i]))
    return {"avg_mae": np.mean(mae_list), "mae_list": mae_list}


def _collect_predictions(model, device, loader):
    """Run model over loader in eval mode without gradients.

    Batches whose node-feature matrix has a single row are skipped.
    Returns (y_true, y_pred) as numpy arrays concatenated along dim 0, with
    the labels reshaped to the prediction shape.
    """
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if batch.x.shape[0] == 1:
                continue
            pred = model(batch)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())
    return torch.cat(y_true, dim=0).numpy(), torch.cat(y_pred, dim=0).numpy()


def save_prediction(model, device, loader, dataset, output_dir, seed):
    """Write per-sample predictions and per-assay scores as CSV files.

    Creates output_dir if needed and writes
      preds-{seed}.csv   "<assay>_pred" and "<assay>_true" columns
      result-{seed}.csv  one score per assay, sorted in descending order
    Assay names are the columns of raw_data/<dataset.name>/raw/assays.csv.gz
    from dataset.start_column onwards. For classification, assays lacking
    both classes have no score and are left out of the result file.
    """
    y_true, y_pred = _collect_predictions(model, device, loader)

    assay_path = f"raw_data/{dataset.name}/raw/assays.csv.gz"
    assay_df = pd.read_csv(assay_path, compression="gzip")
    assay_names = assay_df.columns[dataset.start_column :]

    # blank out predictions that have no ground truth to compare against
    y_pred[np.isnan(y_true)] = np.nan
    os.makedirs(output_dir, exist_ok=True)

    df_pred = pd.DataFrame(y_pred, columns=[name + "_pred" for name in assay_names])
    df_true = pd.DataFrame(y_true, columns=[name + "_true" for name in assay_names])
    df = pd.concat([df_pred, df_true], axis=1)
    df.to_csv(os.path.join(output_dir, f"preds-{seed}.csv"), index=False)

    returned_dict = eval_func(y_pred, y_true, reduction=False)
    results = returned_dict.get('roc_auc', None)
    if results is not None:
        # Pair each score with the assay it was computed for; zipping against
        # the full name list mislabels every score after a skipped assay.
        scored_names = [assay_names[i] for i in returned_dict["task_index"]]
    else:
        results = returned_dict.get('mae_list', None)
        if results is None:
            raise ValueError("Invalid task type")
        scored_names = assay_names
    results_dict = dict(zip(scored_names, results))
    sorted_dict = dict(
        sorted(results_dict.items(), key=lambda item: item[1], reverse=True)
    )
    df_sorted = pd.DataFrame(
        list(sorted_dict.items()), columns=["Assay Name", "Result"]
    )
    df_sorted.to_csv(os.path.join(output_dir, f"result-{seed}.csv"), index=False)


def validate(args, model, loader):
    """Evaluate model on loader (device taken from args.device); see eval_func
    for the returned dictionary."""
    y_true, y_pred = _collect_predictions(model, args.device, loader)
    return eval_func(y_pred, y_true)


def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming | orthogonal | default
                             ("default" leaves the weights untouched)
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Only modules whose class name contains "Conv" or "Linear" and that have a
    `weight` attribute are re-initialised; their bias, when present, is set
    to zero for every init_type. An unknown init_type raises
    NotImplementedError when the first such module is visited.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, "weight") and ("Conv" in classname or "Linear" in classname):
            if init_type == "normal":
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            elif init_type == "default":
                pass
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # NOTE(review): only BatchNorm2d is matched, so BatchNorm1d layers
            # are left untouched -- confirm that this is intended.
            torch.nn.init.normal_(m.weight.data, 1.0, init_gain)
            torch.nn.init.constant_(m.bias.data, 0.0)

    print("initialize network with %s" % init_type)
    net.apply(init_func)  # apply the initialization function


class AverageMeter(object):
    """Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the weighted sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record val as observed n times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def log_base(base, x):
    """Return the logarithm of x to the given base."""
    return np.log(x) / np.log(base)


def _eval_rocauc(y_true, y_pred):
    """
    compute ROC-AUC averaged across tasks
    """
    rocauc_list = []
    for i in range(y_true.shape[1]):
        # AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
            # ignore nan values
            is_labeled = y_true[:, i] == y_true[:, i]
            rocauc_list.append(
                roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i])
            )

    if len(rocauc_list) == 0:
        raise RuntimeError(
            "No positively labeled data available. Cannot compute ROC-AUC."
        )
    return {"rocauc": sum(rocauc_list) / len(rocauc_list)}


# ---------------------------------------------------------------------------
# utils/train_funcs.py -- one-epoch training loops
# ---------------------------------------------------------------------------

cls_criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
reg_criterion = torch.nn.L1Loss(reduction="none")


def prior_criterion(args, pz_given_x):
    """Single-sample estimate of log q(z|x) - log p(z), averaged over the
    batch, with z drawn from q by rsample and p taken from args.prior_dist."""
    z = pz_given_x.rsample()
    loss = pz_given_x.log_prob(z) - args.prior_dist.log_prob(z)
    return loss.mean()


def pretrain_loss_ce(pred, target, weight, nan_mask):
    """Weighted binary cross-entropy (with logits) over the non-NaN targets.

    `weight` gets a trailing dimension and is broadcast to the shape of
    `target`. When every target is masked out, a zero that is still attached
    to `pred` is returned; the mean over an empty selection would be NaN and
    would turn the accumulated epoch loss into NaN.
    """
    valid = ~nan_mask
    if not valid.any():
        return pred.sum() * 0.0
    expanded_weight = weight.unsqueeze(2).expand_as(target)
    loss = F.binary_cross_entropy_with_logits(
        pred[valid],
        target[valid],
        weight=expanded_weight[valid],
        reduction="none",
    )
    return loss.mean()


def _next_batch(train_loaders):
    """Return the next batch, restarting the iterator when it is exhausted.

    train_loaders holds the re-iterable "train_loader" and the live iterator
    "train_iter" (created here when missing or None). Only StopIteration
    triggers a restart; any other error raised while fetching a batch
    propagates to the caller.
    """
    iterator = train_loaders.get("train_iter")
    if iterator is not None:
        try:
            return next(iterator)
        except StopIteration:
            pass
    train_loaders["train_iter"] = iter(train_loaders["train_loader"])
    return next(train_loaders["train_iter"])


def pretrain_func(
    args, model, train_loaders, context_graph, optimizer, scheduler, epoch
):
    """Run args.steps pretraining steps.

    For every batch a random walk of length args.walk_length is started in
    context_graph from each molecule's node; the model's four modality
    predictions are scored against the targets of every node on the walk,
    weighted by the cumulative product of the edge weights along the walk.

    Returns (average total loss over the executed steps, train_loaders).
    """
    criterion = pretrain_loss_ce
    if not args.no_print:
        p_bar = tqdm(range(args.steps))
    batch_time = AverageMeter()
    losses_tot = AverageMeter()
    losses_prior = AverageMeter()
    losses_mol = AverageMeter()
    losses_gene = AverageMeter()
    losses_cell = AverageMeter()
    losses_exp = AverageMeter()
    device = args.device
    model.train()

    context_targets = (
        context_graph.mol_target,
        context_graph.gene_target,
        context_graph.cell_target,
        context_graph.express_target,
    )
    # One extra zero weight at the end, reachable as index -1.
    # presumably random_walk reports -1 for a step without an edge, so such
    # steps zero out the rest of the path weight -- confirm against torch_cluster
    edge_weight = context_graph.weight
    edge_weight = torch.cat(
        [
            edge_weight,
            torch.zeros(1, dtype=edge_weight.dtype, device=edge_weight.device),
        ]
    )
    index_map = {value: index for index, value in enumerate(context_graph.type)}

    log_message = None  # stays None when every step is skipped
    for batch_idx in range(args.steps):
        end = time.time()
        model.zero_grad()
        batched_data = _next_batch(train_loaders).to(device)

        start_indices = torch.tensor(
            [index_map[value] for value in batched_data.type]
        ).long()
        batched_walk, batched_edge_seq = random_walk(
            context_graph.edge_index[0],
            context_graph.edge_index[1],
            start_indices,
            args.walk_length,
            num_nodes=context_graph.num_nodes,
            return_edge_indices=True,
        )

        batched_path_weight = (
            edge_weight[batched_edge_seq]
            .view(-1, args.walk_length)
            .cumprod(dim=-1)
            .to(device)
        )
        # the start node itself counts with weight 1
        batched_path_weight = torch.cat(
            [
                torch.ones(
                    batched_path_weight.size(0), 1, device=batched_path_weight.device
                ),
                batched_path_weight,
            ],
            dim=1,
        )

        walk_targets = [target[batched_walk].to(device) for target in context_targets]
        nan_masks = [torch.isnan(target) for target in walk_targets]

        # skip batches with a single node or a single graph
        if batched_data.x.shape[0] == 1 or batched_data.batch[-1] == 0:
            continue

        pz_given_x, pred_mol, pred_gene, pred_cell, pred_express = model(batched_data)
        loss_prior = prior_criterion(args, pz_given_x)
        # the same prediction is compared with every node along the walk
        loss_mol, loss_gene, loss_cell, loss_exp = [
            criterion(
                pred.unsqueeze(1).expand(-1, target.size(1), -1),
                target,
                batched_path_weight,
                mask,
            )
            for pred, target, mask in zip(
                (pred_mol, pred_gene, pred_cell, pred_express),
                walk_targets,
                nan_masks,
            )
        ]
        loss = args.prior * loss_prior + loss_mol + loss_gene + loss_cell + loss_exp

        loss.backward()
        optimizer.step()
        scheduler.step()
        losses_tot.update(loss.item())
        losses_prior.update(args.prior * loss_prior.item())
        losses_mol.update(loss_mol.item())
        losses_gene.update(loss_gene.item())
        losses_cell.update(loss_cell.item())
        losses_exp.update(loss_exp.item())

        batch_time.update(time.time() - end)
        end = time.time()
        if not args.no_print:
            log_message = (
                "Train Epoch: {epoch}/{epochs:3}. Iter: {batch:2}/{iter:2}. "
                "LR: {lr:1}e-4. Batch: {bt:.1f}s. Loss (Total): {loss_tot:.2f}. "
                "Loss (prior): {loss_prior:.4f}. Loss (mol): {loss_mol:.4f}. "
                "Loss (gene): {loss_gene:.4f}. Loss (cell): {loss_cell:.4f}. "
                "Loss (express): {loss_exp:.4f}."
            ).format(
                epoch=epoch + 1,
                epochs=args.epochs,
                batch=batch_idx + 1,
                iter=args.steps,
                lr=args.lr * 10000,
                bt=batch_time.avg,
                loss_tot=losses_tot.avg,
                loss_prior=losses_prior.avg,
                loss_mol=losses_mol.avg,
                loss_gene=losses_gene.avg,
                loss_cell=losses_cell.avg,
                loss_exp=losses_exp.avg,
            )
            p_bar.set_description(log_message)
            p_bar.update()
    if not args.no_print:
        p_bar.close()
        if log_message is not None:
            args.logger.info(log_message)

    return losses_tot.avg, train_loaders


def finetune_func(args, model, train_loaders, optimizer, scheduler, epoch):
    """Run args.steps fine-tuning steps and return train_loaders.

    Uses L1 loss when args.task_type is "regression" and BCE-with-logits
    otherwise; NaN labels are excluded from the loss. `scheduler` is accepted
    for signature compatibility but not stepped.
    """
    if args.task_type == "regression":
        criterion = reg_criterion
    else:
        criterion = cls_criterion
    if not args.no_print:
        p_bar = tqdm(range(args.steps))
    batch_time = AverageMeter()
    losses = AverageMeter()
    device = args.device
    model.train()
    for batch_idx in range(args.steps):
        end = time.time()
        model.zero_grad()
        batched_data = _next_batch(train_loaders).to(device)
        targets = batched_data.y.to(torch.float32)
        is_labeled = targets == targets  # False exactly where the label is NaN
        # skip batches with a single node or a single graph
        if batched_data.x.shape[0] == 1 or batched_data.batch[-1] == 0:
            continue
        # a batch without any label would produce a NaN loss (mean over an
        # empty selection) and turn the running average into NaN
        if not is_labeled.any():
            continue

        preds = model(batched_data)
        loss = criterion(
            preds.view(targets.size()).to(torch.float32)[is_labeled],
            targets[is_labeled],
        ).mean()
        loss.backward()
        optimizer.step()
        losses.update(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()
        if not args.no_print:
            p_bar.set_description(
                "Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.8f}. Batch: {bt:.3f}s. Loss: {loss:.4f}. ".format(
                    epoch=epoch + 1,
                    epochs=args.epochs,
                    batch=batch_idx + 1,
                    iter=args.steps,
                    lr=args.lr,
                    bt=batch_time.avg,
                    loss=losses.avg,
                )
            )
            p_bar.update()
    if not args.no_print:
        p_bar.close()

    return train_loaders