├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── anno └── B_anno.csv ├── data ├── A.tsv └── labels.tsv ├── datasets ├── __init__.py ├── a_dataset.py ├── ab_dataset.py ├── abc_dataset.py ├── b_dataset.py ├── basic_dataset.py ├── c_dataset.py └── dataloader_prefetch.py ├── documents ├── preprint.pdf └── supplementary.pdf ├── environment.yml ├── models ├── __init__.py ├── basic_model.py ├── losses.py ├── networks.py ├── vae_alltask_gn_model.py ├── vae_alltask_model.py ├── vae_basic_model.py ├── vae_classifier_model.py ├── vae_multitask_gn_model.py ├── vae_multitask_model.py ├── vae_regression_model.py └── vae_survival_model.py ├── params ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── basic_params.cpython-36.pyc ├── basic_params.py ├── test_params.py ├── train_params.py └── train_test_params.py ├── requirements.txt ├── test.py ├── train.py ├── train_test.py └── util ├── __init__.py ├── metrics.py ├── preprocess.py ├── util.py └── visualizer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .idea/ 3 | .DS_Store 4 | checkpoints/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Xiaoyu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit 
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OmiEmbed 2 | ***Please also have a look at our brand new omics-to-omics DL framework 👀:*** 3 | [OmiTrans](https://github.com/zhangxiaoyu11/OmiTrans) 4 | 5 | [![DOI](https://zenodo.org/badge/334077812.svg)](https://zenodo.org/badge/latestdoi/334077812) 6 | [![Codacy Badge](https://app.codacy.com/project/badge/Grade/ce304bf91b534e26b310b3c50072e8ae)](https://www.codacy.com/gh/zhangxiaoyu11/OmiEmbed/dashboard?utm_source=github.com&utm_medium=referral&utm_content=zhangxiaoyu11/OmiEmbed&utm_campaign=Badge_Grade) 7 | [![GitHub license](https://img.shields.io/github/license/zhangxiaoyu11/OmiEmbed.svg)](https://github.com/zhangxiaoyu11/OmiEmbed/blob/main/LICENSE) 8 | 
![Safe](https://img.shields.io/badge/Stay-Safe-red?logo=data:image/svg%2bxml;base64,PHN2ZyBpZD0iTGF5ZXJfMSIgZW5hYmxlLWJhY2tncm91bmQ9Im5ldyAwIDAgNTEwIDUxMCIgaGVpZ2h0PSI1MTIiIHZpZXdCb3g9IjAgMCA1MTAgNTEwIiB3aWR0aD0iNTEyIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPjxnPjxnPjxwYXRoIGQ9Im0xNzQuNjEgMzAwYy0yMC41OCAwLTQwLjU2IDYuOTUtNTYuNjkgMTkuNzJsLTExMC4wOSA4NS43OTd2MTA0LjQ4M2g1My41MjlsNzYuNDcxLTY1aDEyNi44MnYtMTQ1eiIgZmlsbD0iI2ZmZGRjZSIvPjwvZz48cGF0aCBkPSJtNTAyLjE3IDI4NC43MmMwIDguOTUtMy42IDE3Ljg5LTEwLjc4IDI0LjQ2bC0xNDguNTYgMTM1LjgyaC03OC4xOHYtODVoNjguMThsMTE0LjM0LTEwMC4yMWMxMi44Mi0xMS4yMyAzMi4wNi0xMC45MiA0NC41LjczIDcgNi41NSAxMC41IDE1LjM4IDEwLjUgMjQuMnoiIGZpbGw9IiNmZmNjYmQiLz48cGF0aCBkPSJtMzMyLjgzIDM0OS42M3YxMC4zN2gtNjguMTh2LTYwaDE4LjU1YzI3LjQxIDAgNDkuNjMgMjIuMjIgNDkuNjMgNDkuNjN6IiBmaWxsPSIjZmZjY2JkIi8+PHBhdGggZD0ibTM5OS44IDc3LjN2OC4wMWMwIDIwLjY1LTguMDQgNDAuMDctMjIuNjQgNTQuNjdsLTExMi41MSAxMTIuNTF2LTIyNi42NmwzLjE4LTMuMTljMTQuNi0xNC42IDM0LjAyLTIyLjY0IDU0LjY3LTIyLjY0IDQyLjYyIDAgNzcuMyAzNC42OCA3Ny4zIDc3LjN6IiBmaWxsPSIjZDAwMDUwIi8+PHBhdGggZD0ibTI2NC42NSAyNS44M3YyMjYuNjZsLTExMi41MS0xMTIuNTFjLTE0LjYtMTQuNi0yMi42NC0zNC4wMi0yMi42NC01NC42N3YtOC4wMWMwLTQyLjYyIDM0LjY4LTc3LjMgNzcuMy03Ny4zIDIwLjY1IDAgNDAuMDYgOC4wNCA1NC42NiAyMi42NHoiIGZpbGw9IiNmZjRhNGEiLz48cGF0aCBkPSJtMjEyLjgzIDM2MC4xMnYzMGg1MS44MnYtMzB6IiBmaWxsPSIjZmZjY2JkIi8+PHBhdGggZD0ibTI2NC42NSAzNjAuMTJ2MzBoMzYuMTRsMzIuMDQtMzB6IiBmaWxsPSIjZmZiZGE5Ii8+PC9nPjwvc3ZnPg==) 9 | [![GitHub Repo stars](https://img.shields.io/github/stars/zhangxiaoyu11/OmiEmbed?style=social)](https://github.com/zhangxiaoyu11/OmiEmbed/stargazers) 10 | [![GitHub forks](https://img.shields.io/github/forks/zhangxiaoyu11/OmiEmbed?style=social)](https://github.com/zhangxiaoyu11/OmiEmbed/network/members) 11 | 12 | **OmiEmbed: A Unified Multi-task Deep Learning Framework for Multi-omics Data** 13 | 14 | **Xiaoyu Zhang** (x.zhang18@imperial.ac.uk) 15 | 16 | Data Science Institute, Imperial College London 17 | 18 | ## Introduction 19 | 20 | OmiEmbed is a unified 
framework for deep learning-based omics data analysis, which supports: 21 | 22 | 1. Multi-omics integration 23 | 2. Dimensionality reduction 24 | 3. Omics embedding learning 25 | 4. Tumour type classification 26 | 5. Phenotypic feature reconstruction 27 | 6. Survival prediction 28 | 7. Multi-task learning for aforementioned tasks 29 | 30 | Paper Link: [https://doi.org/10.3390/cancers13123047](https://doi.org/10.3390/cancers13123047) 31 | 32 | ## Getting Started 33 | 34 | ### Prerequisites 35 | - CPU or NVIDIA GPU + CUDA CuDNN 36 | - [Python](https://www.python.org/downloads) 3.6+ 37 | - Python Package Manager 38 | - [Anaconda](https://docs.anaconda.com/anaconda/install) 3 (recommended) 39 | - or [pip](https://pip.pypa.io/en/stable/installing/) 21.0+ 40 | - Python Packages 41 | - [PyTorch](https://pytorch.org/get-started/locally) 1.2+ 42 | - TensorBoard 1.10+ 43 | - Tables 3.6+ 44 | - scikit-survival 0.6+ 45 | - prefetch-generator 1.0+ 46 | - [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) 2.7+ 47 | 48 | ### Installation 49 | - Clone the repo 50 | ```bash 51 | git clone https://github.com/zhangxiaoyu11/OmiEmbed.git 52 | cd OmiEmbed 53 | ``` 54 | - Install the dependencies 55 | - For conda users 56 | ```bash 57 | conda env create -f environment.yml 58 | conda activate omiembed 59 | ``` 60 | - For pip users 61 | ```bash 62 | pip install -r requirements.txt 63 | ``` 64 | 65 | ### Try it out 66 | - Train and test using the built-in sample dataset with the default settings 67 | ```bash 68 | python train_test.py 69 | ``` 70 | - Check the output files 71 | ```bash 72 | cd checkpoints/test/ 73 | ``` 74 | - Visualise the metrics and losses 75 | ```bash 76 | tensorboard --logdir=tb_log --bind_all 77 | ``` 78 | 79 | ## Citation 80 | If you use this code in your research, please cite our paper. 
81 | ```bibtex 82 | @Article{OmiEmbed2021, 83 | AUTHOR = {Zhang, Xiaoyu and Xing, Yuting and Sun, Kai and Guo, Yike}, 84 | TITLE = {OmiEmbed: A Unified Multi-Task Deep Learning Framework for Multi-Omics Data}, 85 | JOURNAL = {Cancers}, 86 | VOLUME = {13}, 87 | YEAR = {2021}, 88 | NUMBER = {12}, 89 | ARTICLE-NUMBER = {3047}, 90 | ISSN = {2072-6694}, 91 | DOI = {10.3390/cancers13123047} 92 | } 93 | ``` 94 | 95 | ## OmiTrans 96 | ***Please also have a look at our brand new omics-to-omics DL freamwork 👀:*** 97 | [OmiTrans](https://github.com/zhangxiaoyu11/OmiTrans) 98 | 99 | ## License 100 | This source code is licensed under the [MIT](https://github.com/zhangxiaoyu11/OmiEmbed/blob/main/LICENSE) license. 101 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package about data loading and data preprocessing 3 | """ 4 | import os 5 | import torch 6 | import importlib 7 | import numpy as np 8 | import pandas as pd 9 | from util import util 10 | from datasets.basic_dataset import BasicDataset 11 | from datasets.dataloader_prefetch import DataLoaderPrefetch 12 | from torch.utils.data import Subset 13 | from sklearn.model_selection import train_test_split 14 | 15 | 16 | def find_dataset_using_name(dataset_mode): 17 | """ 18 | Get the dataset of certain mode 19 | """ 20 | dataset_filename = "datasets." 
+ dataset_mode + "_dataset" 21 | datasetlib = importlib.import_module(dataset_filename) 22 | 23 | # Instantiate the dataset class 24 | dataset = None 25 | # Change the name format to corresponding class name 26 | target_dataset_name = dataset_mode.replace('_', '') + 'dataset' 27 | for name, cls in datasetlib.__dict__.items(): 28 | if name.lower() == target_dataset_name.lower() \ 29 | and issubclass(cls, BasicDataset): 30 | dataset = cls 31 | 32 | if dataset is None: 33 | raise NotImplementedError("In %s.py, there should be a subclass of BasicDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 34 | 35 | return dataset 36 | 37 | 38 | def create_dataset(param): 39 | """ 40 | Create a dataset given the parameters. 41 | """ 42 | dataset_class = find_dataset_using_name(param.omics_mode) 43 | # Get an instance of this dataset class 44 | dataset = dataset_class(param) 45 | print("Dataset [%s] was created" % type(dataset).__name__) 46 | 47 | return dataset 48 | 49 | 50 | class CustomDataLoader: 51 | """ 52 | Create a dataloader for certain dataset. 
53 | """ 54 | def __init__(self, dataset, param, shuffle=True, enable_drop_last=False): 55 | self.dataset = dataset 56 | self.param = param 57 | 58 | drop_last = False 59 | if enable_drop_last: 60 | if len(dataset) % param.batch_size < 3*len(param.gpu_ids): 61 | drop_last = True 62 | 63 | # Create dataloader for this dataset 64 | self.dataloader = DataLoaderPrefetch( 65 | dataset, 66 | batch_size=param.batch_size, 67 | shuffle=shuffle, 68 | num_workers=int(param.num_threads), 69 | drop_last=drop_last, 70 | pin_memory=param.set_pin_memory 71 | ) 72 | 73 | def __len__(self): 74 | """Return the number of data in the dataset""" 75 | return len(self.dataset) 76 | 77 | def __iter__(self): 78 | """Return a batch of data""" 79 | for i, data in enumerate(self.dataloader): 80 | yield data 81 | 82 | def get_A_dim(self): 83 | """Return the dimension of first input omics data type""" 84 | return self.dataset.A_dim 85 | 86 | def get_B_dim(self): 87 | """Return the dimension of second input omics data type""" 88 | return self.dataset.B_dim 89 | 90 | def get_omics_dims(self): 91 | """Return a list of omics dimensions""" 92 | return self.dataset.omics_dims 93 | 94 | def get_class_num(self): 95 | """Return the number of classes for the downstream classification task""" 96 | return self.dataset.class_num 97 | 98 | def get_values_max(self): 99 | """Return the maximum target value of the dataset""" 100 | return self.dataset.values_max 101 | 102 | def get_values_min(self): 103 | """Return the minimum target value of the dataset""" 104 | return self.dataset.values_min 105 | 106 | def get_survival_T_max(self): 107 | """Return the maximum T of the dataset""" 108 | return self.dataset.survival_T_max 109 | 110 | def get_survival_T_min(self): 111 | """Return the minimum T of the dataset""" 112 | return self.dataset.survival_T_min 113 | 114 | def get_sample_list(self): 115 | """Return the sample list of the dataset""" 116 | return self.dataset.sample_list 117 | 118 | 119 | def 
create_single_dataloader(param, shuffle=True, enable_drop_last=False): 120 | """ 121 | Create a single dataloader 122 | """ 123 | dataset = create_dataset(param) 124 | dataloader = CustomDataLoader(dataset, param, shuffle=shuffle, enable_drop_last=enable_drop_last) 125 | sample_list = dataset.sample_list 126 | 127 | return dataloader, sample_list 128 | 129 | 130 | def create_separate_dataloader(param): 131 | """ 132 | Create set of dataloader (train, val, test). 133 | """ 134 | full_dataset = create_dataset(param) 135 | full_size = len(full_dataset) 136 | full_idx = np.arange(full_size) 137 | 138 | if param.not_stratified: 139 | train_idx, test_idx = train_test_split(full_idx, 140 | test_size=param.test_ratio, 141 | train_size=param.train_ratio, 142 | shuffle=True) 143 | else: 144 | if param.downstream_task == 'classification': 145 | targets = full_dataset.labels_array 146 | elif param.downstream_task == 'survival': 147 | targets = full_dataset.survival_E_array 148 | if param.stratify_label: 149 | targets = full_dataset.labels_array 150 | elif param.downstream_task == 'multitask': 151 | targets = full_dataset.labels_array 152 | elif param.downstream_task == 'alltask': 153 | targets = full_dataset.labels_array[0] 154 | train_idx, test_idx = train_test_split(full_idx, 155 | test_size=param.test_ratio, 156 | train_size=param.train_ratio, 157 | shuffle=True, 158 | stratify=targets) 159 | 160 | val_idx = list(set(full_idx) - set(train_idx) - set(test_idx)) 161 | 162 | train_dataset = Subset(full_dataset, train_idx) 163 | val_dataset = Subset(full_dataset, val_idx) 164 | test_dataset = Subset(full_dataset, test_idx) 165 | 166 | full_dataloader = CustomDataLoader(full_dataset, param) 167 | train_dataloader = CustomDataLoader(train_dataset, param, enable_drop_last=True) 168 | val_dataloader = CustomDataLoader(val_dataset, param, shuffle=False) 169 | test_dataloader = CustomDataLoader(test_dataset, param, shuffle=False) 170 | 171 | return full_dataloader, train_dataloader, 
val_dataloader, test_dataloader 172 | 173 | 174 | def load_file(param, file_name): 175 | """ 176 | Load data according to the format. 177 | """ 178 | if param.file_format == 'tsv': 179 | file_path = os.path.join(param.data_root, file_name + '.tsv') 180 | print('Loading data from ' + file_path) 181 | df = pd.read_csv(file_path, sep='\t', header=0, index_col=0, na_filter=param.detect_na) 182 | elif param.file_format == 'csv': 183 | file_path = os.path.join(param.data_root, file_name + '.csv') 184 | print('Loading data from ' + file_path) 185 | df = pd.read_csv(file_path, header=0, index_col=0, na_filter=param.detect_na) 186 | elif param.file_format == 'hdf': 187 | file_path = os.path.join(param.data_root, file_name + '.h5') 188 | print('Loading data from ' + file_path) 189 | df = pd.read_hdf(file_path, header=0, index_col=0) 190 | else: 191 | raise NotImplementedError('File format %s is supported' % param.file_format) 192 | return df 193 | 194 | 195 | def get_survival_y_true(param, T, E): 196 | """ 197 | Get y_true for survival prediction based on T and E 198 | """ 199 | # Get T_max 200 | if param.survival_T_max == -1: 201 | T_max = T.max() 202 | else: 203 | T_max = param.survival_T_max 204 | 205 | # Get time points 206 | time_points = util.get_time_points(T_max, param.time_num) 207 | 208 | # Get the y_true 209 | y_true = [] 210 | for i, (t, e) in enumerate(zip(T, E)): 211 | y_true_i = np.zeros(param.time_num + 1) 212 | dist_to_time_points = [abs(t - point) for point in time_points[:-1]] 213 | time_index = np.argmin(dist_to_time_points) 214 | # if this is a uncensored data point 215 | if e == 1: 216 | y_true_i[time_index] = 1 217 | y_true.append(y_true_i) 218 | # if this is a censored data point 219 | else: 220 | y_true_i[time_index:] = 1 221 | y_true.append(y_true_i) 222 | y_true = torch.Tensor(y_true) 223 | 224 | return y_true 225 | -------------------------------------------------------------------------------- /datasets/a_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os.path 2 | from datasets import load_file 3 | from datasets import get_survival_y_true 4 | from datasets.basic_dataset import BasicDataset 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | 9 | 10 | class ADataset(BasicDataset): 11 | """ 12 | A dataset class for gene expression dataset. 13 | File should be prepared as '/path/to/data/A.tsv'. 14 | For each omics file, each columns should be each sample and each row should be each molecular feature. 15 | """ 16 | 17 | def __init__(self, param): 18 | """ 19 | Initialize this dataset class. 20 | """ 21 | BasicDataset.__init__(self, param) 22 | self.omics_dims = [] 23 | 24 | # Load data for A 25 | A_df = load_file(param, 'A') 26 | # Get the sample list 27 | if param.use_sample_list: 28 | sample_list_path = os.path.join(param.data_root, 'sample_list.tsv') # get the path of sample list 29 | self.sample_list = np.loadtxt(sample_list_path, delimiter='\t', dtype=': initialize the class, first call BasicDataset.__init__(self, param). 13 | -- <__len__>: return the size of dataset. 14 | -- <__getitem__>: get a data point. 15 | """ 16 | 17 | def __init__(self, param): 18 | """ 19 | Initialize the class, save the parameters in the class 20 | """ 21 | self.param = param 22 | self.sample_list = None 23 | 24 | @abstractmethod 25 | def __len__(self): 26 | """Return the total number of samples in the dataset.""" 27 | return 0 28 | 29 | @abstractmethod 30 | def __getitem__(self, index): 31 | """ 32 | Return a data point and its metadata information. 33 | Parameters: 34 | index - - a integer for data indexing 35 | Returns: 36 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 
37 | """ 38 | pass 39 | -------------------------------------------------------------------------------- /datasets/c_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from datasets import load_file 3 | from datasets import get_survival_y_true 4 | from datasets.basic_dataset import BasicDataset 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | 9 | 10 | class CDataset(BasicDataset): 11 | """ 12 | A dataset class for miRNA expression dataset. 13 | File should be prepared as '/path/to/data/C.tsv'. 14 | For each omics file, each columns should be each sample and each row should be each molecular feature. 15 | """ 16 | 17 | def __init__(self, param): 18 | """ 19 | Initialize this dataset class. 20 | """ 21 | BasicDataset.__init__(self, param) 22 | self.omics_dims = [] 23 | self.omics_dims.append(None) # First dimension is for gene expression (A) 24 | self.omics_dims.append(None) # Second dimension is for DNA methylation (B) 25 | 26 | # Load data for C 27 | C_df = load_file(param, 'C') 28 | # Get the sample list 29 | if param.use_sample_list: 30 | sample_list_path = os.path.join(param.data_root, 'sample_list.tsv') # get the path of sample list 31 | self.sample_list = np.loadtxt(sample_list_path, delimiter='\t', dtype=' of the model class.""" 33 | model_class = find_model_using_name(model_name) 34 | return model_class.modify_commandline_parameters 35 | 36 | 37 | def create_model(param): 38 | """ 39 | Create a model given the parameters 40 | """ 41 | model = find_model_using_name(param.model) 42 | # Initialize the model 43 | instance = model(param) 44 | print('Model [%s] was created' % type(instance).__name__) 45 | return instance 46 | -------------------------------------------------------------------------------- /models/basic_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from abc import 
ABC, abstractmethod 5 | from . import networks 6 | from collections import OrderedDict 7 | 8 | 9 | class BasicModel(ABC): 10 | """ 11 | This class is an abstract base class for models. 12 | To create a subclass, you need to implement the following five functions: 13 | -- <__init__>: Initialize the class, first call BasicModel.__init__(self, param) 14 | -- : Add model-specific parameters, and rewrite default values for existing parameters 15 | -- : Unpack input data from the output dictionary of the dataloader 16 | -- : Get the reconstructed omics data and results for the downstream task 17 | -- : Calculate losses, gradients and update network parameters 18 | """ 19 | 20 | def __init__(self, param): 21 | """ 22 | Initialize the BaseModel class 23 | """ 24 | self.param = param 25 | self.gpu_ids = param.gpu_ids 26 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 27 | self.save_dir = os.path.join(param.checkpoints_dir, param.experiment_name) # save all the checkpoints to save_dir, and this is where to load the models 28 | self.load_net_dir = os.path.join(param.checkpoints_dir, param.experiment_to_load) # load pretrained networks from certain experiment folder 29 | self.isTrain = param.isTrain 30 | self.phase = 'p1' 31 | self.epoch = 1 32 | self.iter = 0 33 | 34 | # Improve the performance if the dimensionality and shape of the input data keep the same 35 | torch.backends.cudnn.benchmark = True 36 | 37 | self.plateau_metric = 0 # used for learning rate policy 'plateau' 38 | 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.metric_names = [] 42 | self.optimizers = [] 43 | self.schedulers = [] 44 | 45 | self.latent = None 46 | self.loss_embed = None 47 | self.loss_down = None 48 | self.loss_All = None 49 | 50 | @staticmethod 51 | def modify_commandline_parameters(parser, is_train): 52 | """ 53 | Add model-specific parameters, and rewrite default values for existing 
parameters. 54 | 55 | Parameters: 56 | parser -- original parameter parser 57 | is_train (bool) -- whether it is currently training phase or test phase. Use this flag to add or change training-specific or test-specific parameters. 58 | 59 | Returns: 60 | The modified parser. 61 | """ 62 | return parser 63 | 64 | @abstractmethod 65 | def set_input(self, input_dict): 66 | """ 67 | Unpack input data from the output dictionary of the dataloader 68 | 69 | Parameters: 70 | input_dict (dict): include the data tensor and its label 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def forward(self): 76 | """ 77 | Run forward pass 78 | """ 79 | pass 80 | 81 | @abstractmethod 82 | def cal_losses(self): 83 | """ 84 | Calculate losses 85 | """ 86 | pass 87 | 88 | @abstractmethod 89 | def update(self): 90 | """ 91 | Calculate losses, gradients and update network weights; called in every training iteration 92 | """ 93 | pass 94 | 95 | def setup(self, param): 96 | """ 97 | Load and print networks, create schedulers 98 | """ 99 | if self.isTrain: 100 | self.print_networks(param) 101 | # For every optimizer we have a scheduler 102 | self.schedulers = [networks.get_scheduler(optimizer, param) for optimizer in self.optimizers] 103 | 104 | # Loading the networks 105 | if not self.isTrain or param.continue_train: 106 | self.load_networks(param.epoch_to_load) 107 | 108 | def update_learning_rate(self): 109 | """ 110 | Update learning rates for all the networks 111 | Called at the end of each epoch 112 | """ 113 | lr = self.optimizers[0].param_groups[0]['lr'] 114 | 115 | for scheduler in self.schedulers: 116 | if self.param.lr_policy == 'plateau': 117 | scheduler.step(self.plateau_metric) 118 | else: 119 | scheduler.step() 120 | 121 | return lr 122 | 123 | 124 | def print_networks(self, param): 125 | """ 126 | Print the total number of parameters in the network and network architecture if detail is true 127 | Save the networks information to the disk 128 | """ 129 | message = 
'\n----------------------Networks Information----------------------' 130 | for model_name in self.model_names: 131 | if isinstance(model_name, str): 132 | net = getattr(self, 'net' + model_name) 133 | num_params = 0 134 | for parameter in net.parameters(): 135 | num_params += parameter.numel() 136 | if param.detail: 137 | message += '\n' + str(net) 138 | message += '\n[Network {:s}] Total number of parameters : {:.3f} M'.format(model_name, num_params / 1e6) 139 | message += '\n----------------------------------------------------------------\n' 140 | 141 | # Save the networks information to the disk 142 | net_info_filename = os.path.join(param.checkpoints_dir, param.experiment_name, 'net_info.txt') 143 | with open(net_info_filename, 'w') as log_file: 144 | log_file.write(message) 145 | 146 | print(message) 147 | 148 | def save_networks(self, epoch): 149 | """ 150 | Save all the networks to the disk. 151 | 152 | Parameters: 153 | epoch (str) -- current epoch 154 | """ 155 | for name in self.model_names: 156 | if isinstance(name, str): 157 | save_filename = '{:s}_net_{:s}.pth'.format(epoch, name) 158 | save_path = os.path.join(self.save_dir, save_filename) 159 | # Use the str to get the attribute aka the network (self.netG / self.netD) 160 | net = getattr(self, 'net' + name) 161 | # If we use multi GPUs and apply the data parallel 162 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 163 | torch.save(net.module.cpu().state_dict(), save_path) 164 | net.cuda(self.gpu_ids[0]) 165 | else: 166 | torch.save(net.cpu().state_dict(), save_path) 167 | 168 | def load_networks(self, epoch): 169 | """ 170 | Load networks at specified epoch from the disk. 
171 | 172 | Parameters: 173 | epoch (str) -- Which epoch to load 174 | """ 175 | for model_name in self.model_names: 176 | if isinstance(model_name, str): 177 | load_filename = '{:s}_net_{:s}.pth'.format(epoch, model_name) 178 | load_path = os.path.join(self.load_net_dir, load_filename) 179 | # Use the str to get the attribute aka the network (self.netG / self.netD) 180 | net = getattr(self, 'net' + model_name) 181 | # If we use multi GPUs and apply the data parallel 182 | if isinstance(net, torch.nn.DataParallel): 183 | net = net.module 184 | print('Loading the model from %s' % load_path) 185 | state_dict = torch.load(load_path, map_location=self.device) 186 | if hasattr(state_dict, '_metadata'): 187 | del state_dict._metadata 188 | 189 | net.load_state_dict(state_dict) 190 | 191 | def set_train(self): 192 | """ 193 | Set train mode for networks 194 | """ 195 | for model_name in self.model_names: 196 | if isinstance(model_name, str): 197 | # Use the str to get the attribute aka the network (self.netXXX) 198 | net = getattr(self, 'net' + model_name) 199 | net.train() 200 | 201 | def set_eval(self): 202 | """ 203 | Set eval mode for networks 204 | """ 205 | for model_name in self.model_names: 206 | if isinstance(model_name, str): 207 | # Use the str to get the attribute aka the network (self.netG / self.netD) 208 | net = getattr(self, 'net' + model_name) 209 | net.eval() 210 | 211 | def test(self): 212 | """ 213 | Forward in testing to get the output tensors 214 | """ 215 | with torch.no_grad(): 216 | self.forward() 217 | self.cal_losses() 218 | 219 | def init_output_dict(self): 220 | """ 221 | initialize a dictionary for downstream task output 222 | """ 223 | output_dict = OrderedDict() 224 | output_names = [] 225 | if self.param.downstream_task == 'classification': 226 | output_names = ['index', 'y_true', 'y_pred', 'y_prob'] 227 | elif self.param.downstream_task == 'regression': 228 | output_names = ['index', 'y_true', 'y_pred'] 229 | elif 
self.param.downstream_task == 'survival': 230 | output_names = ['index', 'y_true_E', 'y_true_T', 'survival', 'risk', 'y_out'] 231 | elif self.param.downstream_task == 'multitask' or self.param.downstream_task == 'alltask': 232 | output_names = ['index', 'y_true_E', 'y_true_T', 'survival', 'risk', 'y_out_sur', 'y_true_cla', 'y_pred_cla', 233 | 'y_prob_cla', 'y_true_reg', 'y_pred_reg'] 234 | for name in output_names: 235 | output_dict[name] = None 236 | 237 | return output_dict 238 | 239 | def update_output_dict(self, output_dict): 240 | """ 241 | output_dict (OrderedDict) -- the output dictionary to be updated 242 | """ 243 | down_output = self.get_down_output() 244 | output_names = [] 245 | if self.param.downstream_task == 'classification': 246 | output_names = ['index', 'y_true', 'y_pred', 'y_prob'] 247 | elif self.param.downstream_task == 'regression': 248 | output_names = ['index', 'y_true', 'y_pred'] 249 | elif self.param.downstream_task == 'survival': 250 | output_names = ['index', 'y_true_E', 'y_true_T', 'survival', 'risk', 'y_out'] 251 | elif self.param.downstream_task == 'multitask' or self.param.downstream_task == 'alltask': 252 | output_names = ['index', 'y_true_E', 'y_true_T', 'survival', 'risk', 'y_out_sur', 'y_true_cla', 253 | 'y_pred_cla', 'y_prob_cla', 'y_true_reg', 'y_pred_reg'] 254 | 255 | for name in output_names: 256 | if output_dict[name] is None: 257 | output_dict[name] = down_output[name] 258 | else: 259 | if self.param.downstream_task == 'alltask' and name in ['y_true_cla', 'y_pred_cla', 'y_prob_cla']: 260 | for i in range(self.param.task_num-2): 261 | output_dict[name][i] = torch.cat((output_dict[name][i], down_output[name][i])) 262 | else: 263 | output_dict[name] = torch.cat((output_dict[name], down_output[name])) 264 | 265 | def init_losses_dict(self): 266 | """ 267 | initialize a losses dictionary 268 | """ 269 | losses_dict = OrderedDict() 270 | for name in self.loss_names: 271 | if isinstance(name, str): 272 | losses_dict[name] = [] 273 
| return losses_dict 274 | 275 | def update_losses_dict(self, losses_dict, actual_batch_size): 276 | """ 277 | losses_dict (OrderedDict) -- the losses dictionary to be updated 278 | actual_batch_size (int) -- actual batch size for loss normalization 279 | """ 280 | for name in self.loss_names: 281 | if isinstance(name, str): 282 | if self.param.reduction == 'sum': 283 | losses_dict[name].append(float(getattr(self, 'loss_' + name))/actual_batch_size) 284 | elif self.param.reduction == 'mean': 285 | losses_dict[name].append(float(getattr(self, 'loss_' + name))) 286 | 287 | def init_metrics_dict(self): 288 | """ 289 | initialize a metrics dictionary 290 | """ 291 | metrics_dict = OrderedDict() 292 | for name in self.metric_names: 293 | if isinstance(name, str): 294 | metrics_dict[name] = None 295 | return metrics_dict 296 | 297 | def update_metrics_dict(self, metrics_dict): 298 | """ 299 | metrics_dict (OrderedDict) -- the metrics dictionary to be updated 300 | """ 301 | for name in self.metric_names: 302 | if isinstance(name, str): 303 | metrics_dict[name] = getattr(self, 'metric_' + name) 304 | 305 | def init_log_dict(self): 306 | """ 307 | initialize losses and metrics dictionary 308 | """ 309 | output_dict = self.init_output_dict() 310 | losses_dict = self.init_losses_dict() 311 | metrics_dict = self.init_metrics_dict() 312 | return output_dict, losses_dict, metrics_dict 313 | 314 | def update_log_dict(self, output_dict, losses_dict, metrics_dict, actual_batch_size): 315 | """ 316 | output_dict (OrderedDict) -- the output dictionary to be updated 317 | losses_dict (OrderedDict) -- the losses dictionary to be updated 318 | metrics_dict (OrderedDict) -- the metrics dictionary to be updated 319 | actual_batch_size (int) -- actual batch size for loss normalization 320 | """ 321 | self.update_output_dict(output_dict) 322 | self.calculate_current_metrics(output_dict) 323 | self.update_losses_dict(losses_dict, actual_batch_size) 324 | 
self.update_metrics_dict(metrics_dict) 325 | 326 | def init_latent_dict(self): 327 | """ 328 | initialize and return an empty latent space array and an empty index array 329 | """ 330 | latent_dict = OrderedDict() 331 | latent_dict['index'] = np.zeros(shape=[0]) 332 | latent_dict['latent'] = np.zeros(shape=[0, self.param.latent_space_dim]) 333 | return latent_dict 334 | 335 | def update_latent_dict(self, latent_dict): 336 | """ 337 | update the latent dict 338 | latent_dict (OrderedDict) 339 | """ 340 | with torch.no_grad(): 341 | current_latent_array = self.latent.cpu().numpy() 342 | latent_dict['latent'] = np.concatenate((latent_dict['latent'], current_latent_array)) 343 | current_index_array = self.data_index.cpu().numpy() 344 | latent_dict['index'] = np.concatenate((latent_dict['index'], current_index_array)) 345 | return latent_dict 346 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_loss_func(loss_name, reduction='mean'): 6 | """ 7 | Return the loss function. 
8 | Parameters: 9 | loss_name (str) -- the name of the loss function: BCE | MSE | L1 | CE 10 | reduction (str) -- the reduction method applied to the loss function: sum | mean 11 | """ 12 | if loss_name == 'BCE': 13 | return nn.BCEWithLogitsLoss(reduction=reduction) 14 | elif loss_name == 'MSE': 15 | return nn.MSELoss(reduction=reduction) 16 | elif loss_name == 'L1': 17 | return nn.L1Loss(reduction=reduction) 18 | elif loss_name == 'CE': 19 | return nn.CrossEntropyLoss(reduction=reduction) 20 | else: 21 | raise NotImplementedError('Loss function %s is not found' % loss_name) 22 | 23 | 24 | def kl_loss(mean, log_var, reduction='mean'): 25 | part_loss = 1 + log_var - mean.pow(2) - log_var.exp() 26 | if reduction == 'mean': 27 | loss = -0.5 * torch.mean(part_loss) 28 | else: 29 | loss = -0.5 * torch.sum(part_loss) 30 | return loss 31 | 32 | 33 | def MTLR_survival_loss(y_pred, y_true, E, tri_matrix, reduction='mean'): 34 | """ 35 | Compute the MTLR survival loss 36 | """ 37 | # Get censored index and uncensored index 38 | censor_idx = [] 39 | uncensor_idx = [] 40 | for i in range(len(E)): 41 | # If this is a uncensored data point 42 | if E[i] == 1: 43 | # Add to uncensored index list 44 | uncensor_idx.append(i) 45 | else: 46 | # Add to censored index list 47 | censor_idx.append(i) 48 | 49 | # Separate y_true and y_pred 50 | y_pred_censor = y_pred[censor_idx] 51 | y_true_censor = y_true[censor_idx] 52 | y_pred_uncensor = y_pred[uncensor_idx] 53 | y_true_uncensor = y_true[uncensor_idx] 54 | 55 | # Calculate likelihood for censored datapoint 56 | phi_censor = torch.exp(torch.mm(y_pred_censor, tri_matrix)) 57 | reduc_phi_censor = torch.sum(phi_censor * y_true_censor, dim=1) 58 | 59 | # Calculate likelihood for uncensored datapoint 60 | phi_uncensor = torch.exp(torch.mm(y_pred_uncensor, tri_matrix)) 61 | reduc_phi_uncensor = torch.sum(phi_uncensor * y_true_uncensor, dim=1) 62 | 63 | # Likelihood normalisation 64 | z_censor = torch.exp(torch.mm(y_pred_censor, tri_matrix)) 65 
| reduc_z_censor = torch.sum(z_censor, dim=1) 66 | z_uncensor = torch.exp(torch.mm(y_pred_uncensor, tri_matrix)) 67 | reduc_z_uncensor = torch.sum(z_uncensor, dim=1) 68 | 69 | # MTLR loss 70 | loss = - (torch.sum(torch.log(reduc_phi_censor)) + torch.sum(torch.log(reduc_phi_uncensor)) - torch.sum(torch.log(reduc_z_censor)) - torch.sum(torch.log(reduc_z_uncensor))) 71 | 72 | if reduction == 'mean': 73 | loss = loss / E.shape[0] 74 | 75 | return loss 76 | -------------------------------------------------------------------------------- /models/vae_alltask_gn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .basic_model import BasicModel 4 | from . import networks 5 | from . import losses 6 | from torch.nn import functional as F 7 | from sklearn import metrics 8 | 9 | 10 | class VaeAlltaskGNModel(BasicModel): 11 | """ 12 | This class implements the VAE multitasking model with GradNorm (all tasks), using the VAE framework with the multiple downstream tasks. 
13 | """ 14 | @staticmethod 15 | def modify_commandline_parameters(parser, is_train=True): 16 | # Downstream task network 17 | parser.set_defaults(net_down='multi_FC_alltask') 18 | # Survival prediction related 19 | parser.add_argument('--survival_loss', type=str, default='MTLR', help='choose the survival loss') 20 | parser.add_argument('--survival_T_max', type=float, default=-1, help='maximum T value for survival prediction task') 21 | parser.add_argument('--time_num', type=int, default=256, help='number of time intervals in the survival model') 22 | # Classification related 23 | parser.add_argument('--class_num', type=int, default=0, help='the number of classes for the classification task') 24 | # Regression related 25 | parser.add_argument('--regression_scale', type=int, default=1, help='normalization scale for y in regression task') 26 | parser.add_argument('--dist_loss', type=str, default='L1', help='choose the distance loss for regression task, options: [MSE | L1]') 27 | # GradNorm ralated 28 | parser.add_argument('--alpha', type=float, default=1.5, help='the additional hyperparameter for GradNorm') 29 | parser.add_argument('--lr_gn', type=float, default=1e-3, help='the learning rate for GradNorm') 30 | parser.add_argument('--k_survival', type=float, default=1.0, help='initial weight for the survival loss') 31 | parser.add_argument('--k_classifier', type=float, default=1.0, help='initial weight for the classifier loss') 32 | parser.add_argument('--k_regression', type=float, default=1.0, help='initial weight for the regression loss') 33 | # Number of tasks 34 | parser.add_argument('--task_num', type=int, default=7, help='the number of downstream tasks') 35 | return parser 36 | 37 | def __init__(self, param): 38 | """ 39 | Initialize the VAE_multitask class. 40 | """ 41 | BasicModel.__init__(self, param) 42 | # specify the training losses you want to print out. 
43 | if param.omics_mode == 'abc': 44 | self.loss_names = ['recon_A', 'recon_B', 'recon_C', 'kl'] 45 | if param.omics_mode == 'ab': 46 | self.loss_names = ['recon_A', 'recon_B', 'kl'] 47 | elif param.omics_mode == 'b': 48 | self.loss_names = ['recon_B', 'kl'] 49 | elif param.omics_mode == 'a': 50 | self.loss_names = ['recon_A', 'kl'] 51 | elif param.omics_mode == 'c': 52 | self.loss_names = ['recon_C', 'kl'] 53 | self.loss_names.extend(['survival', 'classifier_1', 'classifier_2', 'classifier_3', 'classifier_4', 'classifier_5', 'regression', 'gradient', 'w_sur', 'w_cla_1', 'w_cla_2', 'w_cla_3', 'w_cla_4', 'w_cla_5', 'w_reg']) 54 | # specify the models you want to save to the disk and load. 55 | self.model_names = ['All'] 56 | 57 | # input tensor 58 | self.input_omics = [] 59 | self.data_index = None # The indexes of input data 60 | self.survival_T = None 61 | self.survival_E = None 62 | self.y_true = None 63 | self.label = None 64 | self.value = None 65 | 66 | # output tensor 67 | self.z = None 68 | self.recon_omics = None 69 | self.mean = None 70 | self.log_var = None 71 | self.y_out_sur = None 72 | self.y_out_cla = None 73 | self.y_out_reg = None 74 | 75 | # specify the metrics you want to print out. 
76 | self.metric_names = ['accuracy_1', 'accuracy_2', 'accuracy_3', 'accuracy_4', 'accuracy_5', 'rmse'] 77 | 78 | # define the network 79 | self.netAll = networks.define_net(param.net_VAE, param.net_down, param.omics_dims, param.omics_mode, 80 | param.norm_type, param.filter_num, param.conv_k_size, param.leaky_slope, 81 | param.dropout_p, param.latent_space_dim, param.class_num, param.time_num, param.task_num, 82 | param.init_type, param.init_gain, self.gpu_ids) 83 | 84 | # define the reconstruction loss 85 | self.lossFuncRecon = losses.get_loss_func(param.recon_loss, param.reduction) 86 | # define the classification loss 87 | self.lossFuncClass = losses.get_loss_func('CE', param.reduction) 88 | # define the regression distance loss 89 | self.lossFuncDist = losses.get_loss_func(param.dist_loss, param.reduction) 90 | self.loss_recon_A = None 91 | self.loss_recon_B = None 92 | self.loss_recon_C = None 93 | self.loss_recon = None 94 | self.loss_kl = None 95 | self.loss_survival = None 96 | self.loss_classifier_1 = None 97 | self.loss_classifier_2 = None 98 | self.loss_classifier_3 = None 99 | self.loss_classifier_4 = None 100 | self.loss_classifier_5 = None 101 | self.loss_regression = None 102 | self.loss_gradient = 0 103 | 104 | self.loss_w_sur = None 105 | self.loss_w_cla_1 = None 106 | self.loss_w_cla_2 = None 107 | self.loss_w_cla_3 = None 108 | self.loss_w_cla_4 = None 109 | self.loss_w_cla_5 = None 110 | self.loss_w_reg = None 111 | 112 | self.task_losses = None 113 | self.weighted_losses = None 114 | self.initial_losses = None 115 | 116 | self.metric_accuracy_1 = None 117 | self.metric_accuracy_2 = None 118 | self.metric_accuracy_3 = None 119 | self.metric_accuracy_4 = None 120 | self.metric_accuracy_5 = None 121 | self.metric_rmse = None 122 | 123 | if param.survival_loss == 'MTLR': 124 | self.tri_matrix_1 = self.get_tri_matrix(dimension_type=1) 125 | self.tri_matrix_2 = self.get_tri_matrix(dimension_type=2) 126 | 127 | # Weights of multiple downstream tasks 
128 | self.loss_weights = nn.Parameter(torch.ones(param.task_num, requires_grad=True, device=self.device)) 129 | 130 | if self.isTrain: 131 | # Set the optimizer 132 | self.optimizer_All = torch.optim.Adam([{'params': self.netAll.parameters(), 'lr': param.lr, 'betas': (param.beta1, 0.999), 'weight_decay': param.weight_decay}, 133 | {'params': self.loss_weights, 'lr': param.lr_gn}]) 134 | self.optimizers.append(self.optimizer_All) 135 | 136 | def set_input(self, input_dict): 137 | """ 138 | Unpack input data from the output dictionary of the dataloader 139 | 140 | Parameters: 141 | input_dict (dict): include the data tensor and its index. 142 | """ 143 | self.input_omics = [] 144 | for i in range(0, 3): 145 | if i == 1 and self.param.ch_separate: 146 | input_B = [] 147 | for ch in range(0, 23): 148 | input_B.append(input_dict['input_omics'][1][ch].to(self.device)) 149 | self.input_omics.append(input_B) 150 | else: 151 | self.input_omics.append(input_dict['input_omics'][i].to(self.device)) 152 | 153 | self.data_index = input_dict['index'] 154 | self.survival_T = input_dict['survival_T'].to(self.device) 155 | self.survival_E = input_dict['survival_E'].to(self.device) 156 | self.y_true = input_dict['y_true'].to(self.device) 157 | self.label = [] 158 | for i in range(self.param.task_num - 2): 159 | self.label.append(input_dict['label'][i].to(self.device)) 160 | self.value = input_dict['value'].to(self.device) 161 | 162 | def forward(self): 163 | # Get the output tensor 164 | self.z, self.recon_omics, self.mean, self.log_var, self.y_out_sur, self.y_out_cla, self.y_out_reg = self.netAll(self.input_omics) 165 | # define the latent 166 | self.latent = self.mean 167 | 168 | def cal_losses(self): 169 | """Calculate losses""" 170 | # Calculate the reconstruction loss for A 171 | if self.param.omics_mode == 'a' or self.param.omics_mode == 'ab' or self.param.omics_mode == 'abc': 172 | self.loss_recon_A = self.lossFuncRecon(self.recon_omics[0], self.input_omics[0]) 173 | else: 
174 | self.loss_recon_A = 0 175 | # Calculate the reconstruction loss for B 176 | if self.param.omics_mode == 'b' or self.param.omics_mode == 'ab' or self.param.omics_mode == 'abc': 177 | if self.param.ch_separate: 178 | recon_omics_B = torch.cat(self.recon_omics[1], -1) 179 | input_omics_B = torch.cat(self.input_omics[1], -1) 180 | self.loss_recon_B = self.lossFuncRecon(recon_omics_B, input_omics_B) 181 | else: 182 | self.loss_recon_B = self.lossFuncRecon(self.recon_omics[1], self.input_omics[1]) 183 | else: 184 | self.loss_recon_B = 0 185 | # Calculate the reconstruction loss for C 186 | if self.param.omics_mode == 'c' or self.param.omics_mode == 'abc': 187 | self.loss_recon_C = self.lossFuncRecon(self.recon_omics[2], self.input_omics[2]) 188 | else: 189 | self.loss_recon_C = 0 190 | # Overall reconstruction loss 191 | if self.param.reduction == 'sum': 192 | self.loss_recon = self.loss_recon_A + self.loss_recon_B + self.loss_recon_C 193 | elif self.param.reduction == 'mean': 194 | self.loss_recon = (self.loss_recon_A + self.loss_recon_B + self.loss_recon_C) / self.param.omics_num 195 | # Calculate the kl loss 196 | self.loss_kl = losses.kl_loss(self.mean, self.log_var, self.param.reduction) 197 | # Calculate the overall vae loss (embedding loss) 198 | # LOSS EMBED 199 | self.loss_embed = self.loss_recon + self.param.k_kl * self.loss_kl 200 | 201 | # Calculate the survival loss 202 | if self.param.survival_loss == 'MTLR': 203 | self.loss_survival = losses.MTLR_survival_loss(self.y_out_sur, self.y_true, self.survival_E, self.tri_matrix_1, self.param.reduction) 204 | # Calculate the classification loss 205 | self.loss_classifier_1 = self.lossFuncClass(self.y_out_cla[0], self.label[0]) 206 | self.loss_classifier_2 = self.lossFuncClass(self.y_out_cla[1], self.label[1]) 207 | self.loss_classifier_3 = self.lossFuncClass(self.y_out_cla[2], self.label[2]) 208 | self.loss_classifier_4 = self.lossFuncClass(self.y_out_cla[3], self.label[3]) 209 | self.loss_classifier_5 = 
self.lossFuncClass(self.y_out_cla[4], self.label[4]) 210 | # Calculate the regression loss 211 | self.loss_regression = self.lossFuncDist(self.y_out_reg.squeeze().type(torch.float32), (self.value / self.param.regression_scale).type(torch.float32)) 212 | # Calculate the weighted downstream losses 213 | # Add initial weights 214 | self.task_losses = torch.stack([self.param.k_survival * self.loss_survival, self.param.k_classifier * self.loss_classifier_1, self.param.k_classifier * self.loss_classifier_2, self.param.k_classifier * self.loss_classifier_3, self.param.k_classifier * self.loss_classifier_4, self.param.k_classifier * self.loss_classifier_5, self.param.k_regression * self.loss_regression]) 215 | self.weighted_losses = self.loss_weights * self.task_losses 216 | 217 | # LOSS DOWN 218 | self.loss_down = self.weighted_losses.sum() 219 | 220 | self.loss_All = self.param.k_embed * self.loss_embed + self.loss_down 221 | 222 | # Log the loss weights 223 | self.loss_w_sur = self.loss_weights[0] * self.param.k_survival 224 | self.loss_w_cla_1 = self.loss_weights[1] * self.param.k_classifier 225 | self.loss_w_cla_2 = self.loss_weights[2] * self.param.k_classifier 226 | self.loss_w_cla_3 = self.loss_weights[3] * self.param.k_classifier 227 | self.loss_w_cla_4 = self.loss_weights[4] * self.param.k_classifier 228 | self.loss_w_cla_5 = self.loss_weights[5] * self.param.k_classifier 229 | self.loss_w_reg = self.loss_weights[6] * self.param.k_regression 230 | 231 | def update(self): 232 | if self.phase == 'p1': 233 | self.forward() 234 | self.optimizer_All.zero_grad() # Set gradients to zero 235 | self.cal_losses() # Calculate losses 236 | self.loss_embed.backward() # Backpropagation 237 | self.optimizer_All.step() # Update weights 238 | elif self.phase == 'p2': 239 | self.forward() 240 | self.optimizer_All.zero_grad() # Set gradients to zero 241 | self.cal_losses() # Calculate losses 242 | self.loss_down.backward() # Backpropagation 243 | self.optimizer_All.step() # Update 
weights 244 | elif self.phase == 'p3': 245 | self.forward() 246 | self.cal_losses() # Calculate losses 247 | self.optimizer_All.zero_grad() # Set gradients to zero 248 | 249 | # Calculate the GradNorm gradients 250 | if isinstance(self.netAll, torch.nn.DataParallel): 251 | W = list(self.netAll.module.get_last_encode_layer().parameters()) 252 | else: 253 | W = list(self.netAll.get_last_encode_layer().parameters()) 254 | grad_norms = [] 255 | for weight, loss in zip(self.loss_weights, self.task_losses): 256 | grad = torch.autograd.grad(loss, W, retain_graph=True) 257 | grad_norms.append(torch.norm(weight * grad[0])) 258 | grad_norms = torch.stack(grad_norms) 259 | 260 | if self.iter == 0: 261 | self.initial_losses = self.task_losses.detach() 262 | 263 | # Calculate the constant targets 264 | with torch.no_grad(): 265 | # loss ratios 266 | loss_ratios = self.task_losses / self.initial_losses 267 | # inverse training rate 268 | inverse_train_rates = loss_ratios / loss_ratios.mean() 269 | constant_terms = grad_norms.mean() * (inverse_train_rates ** self.param.alpha) 270 | 271 | # Calculate the gradient loss 272 | self.loss_gradient = (grad_norms - constant_terms).abs().sum() 273 | # Set the gradients of weights 274 | loss_weights_grad = torch.autograd.grad(self.loss_gradient, self.loss_weights)[0] 275 | 276 | self.loss_All.backward() 277 | 278 | self.loss_weights.grad = loss_weights_grad 279 | 280 | self.optimizer_All.step() # Update weights 281 | 282 | # Re-normalize the losses weights 283 | with torch.no_grad(): 284 | normalize_coeff = len(self.loss_weights) / self.loss_weights.sum() 285 | self.loss_weights.data = self.loss_weights.data * normalize_coeff 286 | 287 | def get_down_output(self): 288 | """ 289 | Get output from downstream task 290 | """ 291 | with torch.no_grad(): 292 | index = self.data_index 293 | # Survival 294 | y_true_E = self.survival_E 295 | y_true_T = self.survival_T 296 | y_out_sur = self.y_out_sur 297 | predict = self.predict_risk() 298 | # 
density = predict['density'] 299 | survival = predict['survival'] 300 | # hazard = predict['hazard'] 301 | risk = predict['risk'] 302 | 303 | # Classification 304 | y_prob_cla = [] 305 | y_pred_cla = [] 306 | y_true_cla = [] 307 | for i in range(self.param.task_num - 2): 308 | y_prob_cla.append(F.softmax(self.y_out_cla[i], dim=1)) 309 | _, y_pred_cla_i = torch.max(y_prob_cla[i], 1) 310 | y_pred_cla.append(y_pred_cla_i) 311 | y_true_cla.append(self.label[i]) 312 | 313 | # Regression 314 | y_true_reg = self.value 315 | y_pred_reg = self.y_out_reg * self.param.regression_scale 316 | 317 | return {'index': index, 'y_true_E': y_true_E, 'y_true_T': y_true_T, 'survival': survival, 'risk': risk, 318 | 'y_out_sur': y_out_sur, 'y_true_cla': y_true_cla, 'y_pred_cla': y_pred_cla, 319 | 'y_prob_cla': y_prob_cla, 'y_true_reg': y_true_reg, 'y_pred_reg': y_pred_reg} 320 | 321 | def calculate_current_metrics(self, output_dict): 322 | """ 323 | Calculate current metrics 324 | """ 325 | self.metric_accuracy_1 = (output_dict['y_true_cla'][0] == output_dict['y_pred_cla'][0]).sum().item() / len( 326 | output_dict['y_true_cla'][0]) 327 | self.metric_accuracy_2 = (output_dict['y_true_cla'][1] == output_dict['y_pred_cla'][1]).sum().item() / len( 328 | output_dict['y_true_cla'][1]) 329 | self.metric_accuracy_3 = (output_dict['y_true_cla'][2] == output_dict['y_pred_cla'][2]).sum().item() / len( 330 | output_dict['y_true_cla'][2]) 331 | self.metric_accuracy_4 = (output_dict['y_true_cla'][3] == output_dict['y_pred_cla'][3]).sum().item() / len( 332 | output_dict['y_true_cla'][3]) 333 | self.metric_accuracy_5 = (output_dict['y_true_cla'][4] == output_dict['y_pred_cla'][4]).sum().item() / len( 334 | output_dict['y_true_cla'][4]) 335 | 336 | y_true_reg = output_dict['y_true_reg'].cpu().numpy() 337 | y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy() 338 | self.metric_rmse = metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False) 339 | 340 | def get_tri_matrix(self, 
dimension_type=1): 341 | """ 342 | Get tensor of the triangular matrix 343 | """ 344 | if dimension_type == 1: 345 | ones_matrix = torch.ones(self.param.time_num, self.param.time_num + 1, device=self.device) 346 | else: 347 | ones_matrix = torch.ones(self.param.time_num + 1, self.param.time_num + 1, device=self.device) 348 | tri_matrix = torch.tril(ones_matrix) 349 | return tri_matrix 350 | 351 | def predict_risk(self): 352 | """ 353 | Predict the density, survival and hazard function, as well as the risk score 354 | """ 355 | if self.param.survival_loss == 'MTLR': 356 | phi = torch.exp(torch.mm(self.y_out_sur, self.tri_matrix_1)) 357 | div = torch.repeat_interleave(torch.sum(phi, 1).reshape(-1, 1), phi.shape[1], dim=1) 358 | 359 | density = phi / div 360 | survival = torch.mm(density, self.tri_matrix_2) 361 | hazard = density[:, :-1] / survival[:, 1:] 362 | 363 | cumulative_hazard = torch.cumsum(hazard, dim=1) 364 | risk = torch.sum(cumulative_hazard, 1) 365 | 366 | return {'density': density, 'survival': survival, 'hazard': hazard, 'risk': risk} 367 | -------------------------------------------------------------------------------- /models/vae_alltask_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .vae_basic_model import VaeBasicModel 3 | from . import networks 4 | from . import losses 5 | from torch.nn import functional as F 6 | from sklearn import metrics 7 | 8 | 9 | class VaeAlltaskModel(VaeBasicModel): 10 | """ 11 | This class implements the VAE multitasking model with all downstream tasks (5 classifiers + 1 regressor + 1 survival predictor), using the VAE framework with the multiple downstream tasks. 
12 | """ 13 | @staticmethod 14 | def modify_commandline_parameters(parser, is_train=True): 15 | # Downstream task network 16 | parser.set_defaults(net_down='multi_FC_alltask') 17 | # Survival prediction related 18 | parser.add_argument('--survival_loss', type=str, default='MTLR', help='choose the survival loss') 19 | parser.add_argument('--survival_T_max', type=float, default=-1, help='maximum T value for survival prediction task') 20 | parser.add_argument('--time_num', type=int, default=256, help='number of time intervals in the survival model') 21 | # Classification related 22 | parser.add_argument('--class_num', type=int, default=0, help='the number of classes for the classification task') 23 | # Regression related 24 | parser.add_argument('--regression_scale', type=int, default=1, help='normalization scale for y in regression task') 25 | parser.add_argument('--dist_loss', type=str, default='L1', help='choose the distance loss for regression task, options: [MSE | L1]') 26 | # Loss combined 27 | parser.add_argument('--k_survival', type=float, default=1, 28 | help='weight for the survival loss') 29 | parser.add_argument('--k_classifier', type=float, default=1, 30 | help='weight for the classifier loss') 31 | parser.add_argument('--k_regression', type=float, default=1, 32 | help='weight for the regression loss') 33 | # Number of tasks 34 | parser.add_argument('--task_num', type=int, default=7, 35 | help='the number of downstream tasks') 36 | return parser 37 | 38 | def __init__(self, param): 39 | """ 40 | Initialize the VAE_multitask class. 41 | """ 42 | VaeBasicModel.__init__(self, param) 43 | # specify the training losses you want to print out. 44 | self.loss_names.extend(['survival', 'classifier_1', 'classifier_2', 'classifier_3', 'classifier_4', 'classifier_5', 'regression']) 45 | # specify the metrics you want to print out. 
46 | self.metric_names = ['accuracy_1', 'accuracy_2', 'accuracy_3', 'accuracy_4', 'accuracy_5', 'rmse'] 47 | # input tensor 48 | self.survival_T = None 49 | self.survival_E = None 50 | self.y_true = None 51 | self.label = None 52 | self.value = None 53 | # output tensor 54 | self.y_out_sur = None 55 | self.y_out_cla = None 56 | self.y_out_reg = None 57 | # define the network 58 | self.netDown = networks.define_down(param.net_down, param.norm_type, param.leaky_slope, param.dropout_p, 59 | param.latent_space_dim, param.class_num, param.time_num, param.task_num, param.init_type, 60 | param.init_gain, self.gpu_ids) 61 | # define the classification loss 62 | self.lossFuncClass = losses.get_loss_func('CE', param.reduction) 63 | # define the regression distance loss 64 | self.lossFuncDist = losses.get_loss_func(param.dist_loss, param.reduction) 65 | self.loss_survival = None 66 | self.loss_classifier_1 = None 67 | self.loss_classifier_2 = None 68 | self.loss_classifier_3 = None 69 | self.loss_classifier_4 = None 70 | self.loss_classifier_5 = None 71 | self.loss_regression = None 72 | self.metric_accuracy_1 = None 73 | self.metric_accuracy_2 = None 74 | self.metric_accuracy_3 = None 75 | self.metric_accuracy_4 = None 76 | self.metric_accuracy_5 = None 77 | self.metric_rmse = None 78 | 79 | if param.survival_loss == 'MTLR': 80 | self.tri_matrix_1 = self.get_tri_matrix(dimension_type=1) 81 | self.tri_matrix_2 = self.get_tri_matrix(dimension_type=2) 82 | 83 | if self.isTrain: 84 | # Set the optimizer 85 | self.optimizer_Down = torch.optim.Adam(self.netDown.parameters(), lr=param.lr, betas=(param.beta1, 0.999), weight_decay=param.weight_decay) 86 | # optimizer list was already defined in BaseModel 87 | self.optimizers.append(self.optimizer_Down) 88 | 89 | def set_input(self, input_dict): 90 | """ 91 | Unpack input data from the output dictionary of the dataloader 92 | 93 | Parameters: 94 | input_dict (dict): include the data tensor and its index. 
95 | """ 96 | VaeBasicModel.set_input(self, input_dict) 97 | self.survival_T = input_dict['survival_T'].to(self.device) 98 | self.survival_E = input_dict['survival_E'].to(self.device) 99 | self.y_true = input_dict['y_true'].to(self.device) 100 | self.label = [] 101 | for i in range(self.param.task_num-2): 102 | self.label.append(input_dict['label'][i].to(self.device)) 103 | self.value = input_dict['value'].to(self.device) 104 | 105 | def forward(self): 106 | # Get the output tensor 107 | VaeBasicModel.forward(self) 108 | self.y_out_sur, self.y_out_cla, self.y_out_reg = self.netDown(self.latent) 109 | 110 | def cal_losses(self): 111 | """Calculate losses""" 112 | VaeBasicModel.cal_losses(self) 113 | # Calculate the survival loss 114 | if self.param.survival_loss == 'MTLR': 115 | self.loss_survival = losses.MTLR_survival_loss(self.y_out_sur, self.y_true, self.survival_E, self.tri_matrix_1, self.param.reduction) 116 | # Calculate the classification loss 117 | self.loss_classifier_1 = self.lossFuncClass(self.y_out_cla[0], self.label[0]) 118 | self.loss_classifier_2 = self.lossFuncClass(self.y_out_cla[1], self.label[1]) 119 | self.loss_classifier_3 = self.lossFuncClass(self.y_out_cla[2], self.label[2]) 120 | self.loss_classifier_4 = self.lossFuncClass(self.y_out_cla[3], self.label[3]) 121 | self.loss_classifier_5 = self.lossFuncClass(self.y_out_cla[4], self.label[4]) 122 | # Calculate the regression loss 123 | self.loss_regression = self.lossFuncDist(self.y_out_reg.squeeze().type(torch.float32), (self.value / self.param.regression_scale).type(torch.float32)) 124 | # LOSS DOWN 125 | self.loss_down = self.param.k_survival * self.loss_survival + self.param.k_classifier * self.loss_classifier_1 + self.param.k_classifier * self.loss_classifier_2 + self.param.k_classifier * self.loss_classifier_3 + self.param.k_classifier * self.loss_classifier_4 + self.param.k_classifier * self.loss_classifier_5 + self.param.k_regression * self.loss_regression 126 | 127 | self.loss_All = 
self.param.k_embed * self.loss_embed + self.loss_down 128 | 129 | def update(self): 130 | VaeBasicModel.update(self) 131 | 132 | def get_down_output(self): 133 | """ 134 | Get output from downstream task 135 | """ 136 | with torch.no_grad(): 137 | index = self.data_index 138 | # Survival 139 | y_true_E = self.survival_E 140 | y_true_T = self.survival_T 141 | y_out_sur = self.y_out_sur 142 | predict = self.predict_risk() 143 | # density = predict['density'] 144 | survival = predict['survival'] 145 | # hazard = predict['hazard'] 146 | risk = predict['risk'] 147 | 148 | # Classification 149 | y_prob_cla = [] 150 | y_pred_cla = [] 151 | y_true_cla = [] 152 | for i in range(self.param.task_num-2): 153 | y_prob_cla.append(F.softmax(self.y_out_cla[i], dim=1)) 154 | _, y_pred_cla_i = torch.max(y_prob_cla[i], 1) 155 | y_pred_cla.append(y_pred_cla_i) 156 | y_true_cla.append(self.label[i]) 157 | 158 | # Regression 159 | y_true_reg = self.value 160 | y_pred_reg = self.y_out_reg * self.param.regression_scale 161 | 162 | return {'index': index, 'y_true_E': y_true_E, 'y_true_T': y_true_T, 'survival': survival, 'risk': risk, 'y_out_sur': y_out_sur, 'y_true_cla': y_true_cla, 'y_pred_cla': y_pred_cla, 'y_prob_cla': y_prob_cla, 'y_true_reg': y_true_reg, 'y_pred_reg': y_pred_reg} 163 | 164 | def calculate_current_metrics(self, output_dict): 165 | """ 166 | Calculate current metrics 167 | """ 168 | self.metric_accuracy_1 = (output_dict['y_true_cla'][0] == output_dict['y_pred_cla'][0]).sum().item() / len(output_dict['y_true_cla'][0]) 169 | self.metric_accuracy_2 = (output_dict['y_true_cla'][1] == output_dict['y_pred_cla'][1]).sum().item() / len(output_dict['y_true_cla'][1]) 170 | self.metric_accuracy_3 = (output_dict['y_true_cla'][2] == output_dict['y_pred_cla'][2]).sum().item() / len(output_dict['y_true_cla'][2]) 171 | self.metric_accuracy_4 = (output_dict['y_true_cla'][3] == output_dict['y_pred_cla'][3]).sum().item() / len(output_dict['y_true_cla'][3]) 172 | self.metric_accuracy_5 = 
(output_dict['y_true_cla'][4] == output_dict['y_pred_cla'][4]).sum().item() / len(output_dict['y_true_cla'][4]) 173 | 174 | y_true_reg = output_dict['y_true_reg'].cpu().numpy() 175 | y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy() 176 | self.metric_rmse = metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False) 177 | 178 | def get_tri_matrix(self, dimension_type=1): 179 | """ 180 | Get tensor of the triangular matrix 181 | """ 182 | if dimension_type == 1: 183 | ones_matrix = torch.ones(self.param.time_num, self.param.time_num + 1, device=self.device) 184 | else: 185 | ones_matrix = torch.ones(self.param.time_num + 1, self.param.time_num + 1, device=self.device) 186 | tri_matrix = torch.tril(ones_matrix) 187 | return tri_matrix 188 | 189 | def predict_risk(self): 190 | """ 191 | Predict the density, survival and hazard function, as well as the risk score 192 | """ 193 | if self.param.survival_loss == 'MTLR': 194 | phi = torch.exp(torch.mm(self.y_out_sur, self.tri_matrix_1)) 195 | div = torch.repeat_interleave(torch.sum(phi, 1).reshape(-1, 1), phi.shape[1], dim=1) 196 | 197 | density = phi / div 198 | survival = torch.mm(density, self.tri_matrix_2) 199 | hazard = density[:, :-1] / survival[:, 1:] 200 | 201 | cumulative_hazard = torch.cumsum(hazard, dim=1) 202 | risk = torch.sum(cumulative_hazard, 1) 203 | 204 | return {'density': density, 'survival': survival, 'hazard': hazard, 'risk': risk} 205 | -------------------------------------------------------------------------------- /models/vae_basic_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .basic_model import BasicModel 3 | from . import networks 4 | from . import losses 5 | 6 | 7 | class VaeBasicModel(BasicModel): 8 | """ 9 | This is the basic VAE model class, called by all other VAE son classes. 10 | """ 11 | 12 | def __init__(self, param): 13 | """ 14 | Initialize the VAE basic class. 
15 | """ 16 | BasicModel.__init__(self, param) 17 | # specify the training losses you want to print out. 18 | if param.omics_mode == 'abc': 19 | self.loss_names = ['recon_A', 'recon_B', 'recon_C', 'kl'] 20 | if param.omics_mode == 'ab': 21 | self.loss_names = ['recon_A', 'recon_B', 'kl'] 22 | elif param.omics_mode == 'b': 23 | self.loss_names = ['recon_B', 'kl'] 24 | elif param.omics_mode == 'a': 25 | self.loss_names = ['recon_A', 'kl'] 26 | elif param.omics_mode == 'c': 27 | self.loss_names = ['recon_C', 'kl'] 28 | # specify the models you want to save to the disk and load. 29 | self.model_names = ['Embed', 'Down'] 30 | 31 | # input tensor 32 | self.input_omics = [] 33 | self.data_index = None # The indexes of input data 34 | 35 | # output tensor 36 | self.z = None 37 | self.recon_omics = None 38 | self.mean = None 39 | self.log_var = None 40 | 41 | # define the network 42 | self.netEmbed = networks.define_VAE(param.net_VAE, param.omics_dims, param.omics_mode, 43 | param.norm_type, param.filter_num, param.conv_k_size, param.leaky_slope, 44 | param.dropout_p, param.latent_space_dim, param.init_type, param.init_gain, 45 | self.gpu_ids) 46 | 47 | # define the reconstruction loss 48 | self.lossFuncRecon = losses.get_loss_func(param.recon_loss, param.reduction) 49 | 50 | self.loss_recon_A = None 51 | self.loss_recon_B = None 52 | self.loss_recon_C = None 53 | self.loss_recon = None 54 | self.loss_kl = None 55 | 56 | if self.isTrain: 57 | # Set the optimizer 58 | # netEmbed and netDown can set to different initial learning rate 59 | self.optimizer_Embed = torch.optim.Adam(self.netEmbed.parameters(), lr=param.lr, betas=(param.beta1, 0.999), weight_decay=param.weight_decay) 60 | # optimizer list was already defined in BaseModel 61 | self.optimizers.append(self.optimizer_Embed) 62 | 63 | self.optimizer_Down = None 64 | 65 | def set_input(self, input_dict): 66 | """ 67 | Unpack input data from the output dictionary of the dataloader 68 | 69 | Parameters: 70 | input_dict 
(dict): include the data tensor and its index. 71 | """ 72 | self.input_omics = [] 73 | for i in range(0, 3): 74 | if i == 1 and self.param.ch_separate: 75 | input_B = [] 76 | for ch in range(0, 23): 77 | input_B.append(input_dict['input_omics'][1][ch].to(self.device)) 78 | self.input_omics.append(input_B) 79 | else: 80 | self.input_omics.append(input_dict['input_omics'][i].to(self.device)) 81 | 82 | self.data_index = input_dict['index'] 83 | 84 | def forward(self): 85 | # Get the output tensor 86 | self.z, self.recon_omics, self.mean, self.log_var = self.netEmbed(self.input_omics) 87 | # define the latent 88 | if self.phase == 'p1' or self.phase == 'p3': 89 | self.latent = self.mean 90 | elif self.phase == 'p2': 91 | self.latent = self.mean.detach() 92 | 93 | def cal_losses(self): 94 | """Calculate losses""" 95 | # Calculate the reconstruction loss for A 96 | if self.param.omics_mode == 'a' or self.param.omics_mode == 'ab' or self.param.omics_mode == 'abc': 97 | self.loss_recon_A = self.lossFuncRecon(self.recon_omics[0], self.input_omics[0]) 98 | else: 99 | self.loss_recon_A = 0 100 | # Calculate the reconstruction loss for B 101 | if self.param.omics_mode == 'b' or self.param.omics_mode == 'ab' or self.param.omics_mode == 'abc': 102 | if self.param.ch_separate: 103 | recon_omics_B = torch.cat(self.recon_omics[1], -1) 104 | input_omics_B = torch.cat(self.input_omics[1], -1) 105 | self.loss_recon_B = self.lossFuncRecon(recon_omics_B, input_omics_B) 106 | else: 107 | self.loss_recon_B = self.lossFuncRecon(self.recon_omics[1], self.input_omics[1]) 108 | else: 109 | self.loss_recon_B = 0 110 | # Calculate the reconstruction loss for C 111 | if self.param.omics_mode == 'c' or self.param.omics_mode == 'abc': 112 | self.loss_recon_C = self.lossFuncRecon(self.recon_omics[2], self.input_omics[2]) 113 | else: 114 | self.loss_recon_C = 0 115 | # Overall reconstruction loss 116 | if self.param.reduction == 'sum': 117 | self.loss_recon = self.loss_recon_A + self.loss_recon_B + 
self.loss_recon_C 118 | elif self.param.reduction == 'mean': 119 | self.loss_recon = (self.loss_recon_A + self.loss_recon_B + self.loss_recon_C) / self.param.omics_num 120 | # Calculate the kl loss 121 | self.loss_kl = losses.kl_loss(self.mean, self.log_var, self.param.reduction) 122 | # Calculate the overall vae loss (embedding loss) 123 | # LOSS EMBED 124 | self.loss_embed = self.loss_recon + self.param.k_kl * self.loss_kl 125 | 126 | def update(self): 127 | if self.phase == 'p1': 128 | self.forward() 129 | self.optimizer_Embed.zero_grad() # Set gradients to zero 130 | self.cal_losses() # Calculate losses 131 | self.loss_embed.backward() # Backpropagation 132 | self.optimizer_Embed.step() # Update weights 133 | elif self.phase == 'p2': 134 | self.forward() 135 | self.optimizer_Down.zero_grad() # Set gradients to zero 136 | self.cal_losses() # Calculate losses 137 | self.loss_down.backward() # Backpropagation 138 | self.optimizer_Down.step() # Update weights 139 | elif self.phase == 'p3': 140 | self.forward() 141 | self.optimizer_Embed.zero_grad() # Set gradients to zero 142 | self.optimizer_Down.zero_grad() 143 | self.cal_losses() # Calculate losses 144 | self.loss_All.backward() # Backpropagation 145 | self.optimizer_Embed.step() # Update weights 146 | self.optimizer_Down.step() 147 | -------------------------------------------------------------------------------- /models/vae_classifier_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .vae_basic_model import VaeBasicModel 3 | from . import networks 4 | from . import losses 5 | from torch.nn import functional as F 6 | 7 | 8 | class VaeClassifierModel(VaeBasicModel): 9 | """ 10 | This class implements the VAE classifier model, using the VAE framework with the classification downstream task. 
11 | """ 12 | 13 | @staticmethod 14 | def modify_commandline_parameters(parser, is_train=True): 15 | # changing the default values of parameters to match the vae regression model 16 | parser.add_argument('--class_num', type=int, default=0, 17 | help='the number of classes for the classification task') 18 | return parser 19 | 20 | def __init__(self, param): 21 | """ 22 | Initialize the VAE_classifier class. 23 | """ 24 | VaeBasicModel.__init__(self, param) 25 | # specify the training losses you want to print out. 26 | self.loss_names.append('classifier') 27 | # specify the metrics you want to print out. 28 | self.metric_names = ['accuracy'] 29 | # input tensor 30 | self.label = None 31 | # output tensor 32 | self.y_out = None 33 | # define the network 34 | self.netDown = networks.define_down(param.net_down, param.norm_type, param.leaky_slope, param.dropout_p, 35 | param.latent_space_dim, param.class_num, None, None, param.init_type, 36 | param.init_gain, self.gpu_ids) 37 | # define the classification loss 38 | self.lossFuncClass = losses.get_loss_func('CE', param.reduction) 39 | self.loss_classifier = None 40 | self.metric_accuracy = None 41 | 42 | if self.isTrain: 43 | # Set the optimizer 44 | self.optimizer_Down = torch.optim.Adam(self.netDown.parameters(), lr=param.lr, betas=(param.beta1, 0.999), weight_decay=param.weight_decay) 45 | # optimizer list was already defined in BaseModel 46 | self.optimizers.append(self.optimizer_Down) 47 | 48 | def set_input(self, input_dict): 49 | """ 50 | Unpack input data from the output dictionary of the dataloader 51 | 52 | Parameters: 53 | input_dict (dict): include the data tensor and its index. 
54 | """ 55 | VaeBasicModel.set_input(self, input_dict) 56 | self.label = input_dict['label'].to(self.device) 57 | 58 | def forward(self): 59 | VaeBasicModel.forward(self) 60 | # Get the output tensor 61 | self.y_out = self.netDown(self.latent) 62 | 63 | def cal_losses(self): 64 | """Calculate losses""" 65 | VaeBasicModel.cal_losses(self) 66 | # Calculate the classification loss (downstream loss) 67 | self.loss_classifier = self.lossFuncClass(self.y_out, self.label) 68 | # LOSS DOWN 69 | self.loss_down = self.loss_classifier 70 | 71 | self.loss_All = self.param.k_embed * self.loss_embed + self.loss_down 72 | 73 | def update(self): 74 | VaeBasicModel.update(self) 75 | 76 | def get_down_output(self): 77 | """ 78 | Get output from downstream task 79 | """ 80 | with torch.no_grad(): 81 | y_prob = F.softmax(self.y_out, dim=1) 82 | _, y_pred = torch.max(y_prob, 1) 83 | 84 | index = self.data_index 85 | y_true = self.label 86 | 87 | return {'index': index, 'y_true': y_true, 'y_pred': y_pred, 'y_prob': y_prob} 88 | 89 | def calculate_current_metrics(self, output_dict): 90 | """ 91 | Calculate current metrics 92 | """ 93 | self.metric_accuracy = (output_dict['y_true'] == output_dict['y_pred']).sum().item() / len(output_dict['y_true']) 94 | -------------------------------------------------------------------------------- /models/vae_multitask_gn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .basic_model import BasicModel 4 | from . import networks 5 | from . import losses 6 | from torch.nn import functional as F 7 | from sklearn import metrics 8 | 9 | 10 | class VaeMultitaskGNModel(BasicModel): 11 | """ 12 | This class implements the VAE multitasking model with GradNorm, using the VAE framework with the multiple downstream tasks. 
13 | """ 14 | @staticmethod 15 | def modify_commandline_parameters(parser, is_train=True): 16 | # Downstream task network 17 | parser.set_defaults(net_down='multi_FC_multitask') 18 | # Survival prediction related 19 | parser.add_argument('--survival_loss', type=str, default='MTLR', help='choose the survival loss') 20 | parser.add_argument('--survival_T_max', type=float, default=-1, help='maximum T value for survival prediction task') 21 | parser.add_argument('--time_num', type=int, default=256, help='number of time intervals in the survival model') 22 | # Classification related 23 | parser.add_argument('--class_num', type=int, default=0, help='the number of classes for the classification task') 24 | # Regression related 25 | parser.add_argument('--regression_scale', type=int, default=1, help='normalization scale for y in regression task') 26 | parser.add_argument('--dist_loss', type=str, default='L1', help='choose the distance loss for regression task, options: [MSE | L1]') 27 | # GradNorm ralated 28 | parser.add_argument('--alpha', type=float, default=1.5, help='the additional hyperparameter for GradNorm') 29 | parser.add_argument('--lr_gn', type=float, default=1e-3, help='the learning rate for GradNorm') 30 | parser.add_argument('--k_survival', type=float, default=1.0, help='initial weight for the survival loss') 31 | parser.add_argument('--k_classifier', type=float, default=1.0, help='initial weight for the classifier loss') 32 | parser.add_argument('--k_regression', type=float, default=1.0, help='initial weight for the regression loss') 33 | return parser 34 | 35 | def __init__(self, param): 36 | """ 37 | Initialize the VAE_multitask class. 38 | """ 39 | BasicModel.__init__(self, param) 40 | # specify the training losses you want to print out. 
41 | if param.omics_mode == 'abc': 42 | self.loss_names = ['recon_A', 'recon_B', 'recon_C', 'kl'] 43 | if param.omics_mode == 'ab': 44 | self.loss_names = ['recon_A', 'recon_B', 'kl'] 45 | elif param.omics_mode == 'b': 46 | self.loss_names = ['recon_B', 'kl'] 47 | elif param.omics_mode == 'a': 48 | self.loss_names = ['recon_A', 'kl'] 49 | elif param.omics_mode == 'c': 50 | self.loss_names = ['recon_C', 'kl'] 51 | self.loss_names.extend(['survival', 'classifier', 'regression', 'gradient', 'w_sur', 'w_cla', 'w_reg']) 52 | # specify the models you want to save to the disk and load. 53 | self.model_names = ['All'] 54 | 55 | # input tensor 56 | self.input_omics = [] 57 | self.data_index = None # The indexes of input data 58 | self.survival_T = None 59 | self.survival_E = None 60 | self.y_true = None 61 | self.label = None 62 | self.value = None 63 | 64 | # output tensor 65 | self.z = None 66 | self.recon_omics = None 67 | self.mean = None 68 | self.log_var = None 69 | self.y_out_sur = None 70 | self.y_out_cla = None 71 | self.y_out_reg = None 72 | 73 | # specify the metrics you want to print out. 
74 | self.metric_names = ['accuracy', 'rmse'] 75 | 76 | # define the network 77 | self.netAll = networks.define_net(param.net_VAE, param.net_down, param.omics_dims, param.omics_mode, 78 | param.norm_type, param.filter_num, param.conv_k_size, param.leaky_slope, 79 | param.dropout_p, param.latent_space_dim, param.class_num, param.time_num, None, 80 | param.init_type, param.init_gain, self.gpu_ids) 81 | 82 | # define the reconstruction loss 83 | self.lossFuncRecon = losses.get_loss_func(param.recon_loss, param.reduction) 84 | # define the classification loss 85 | self.lossFuncClass = losses.get_loss_func('CE', param.reduction) 86 | # define the regression distance loss 87 | self.lossFuncDist = losses.get_loss_func(param.dist_loss, param.reduction) 88 | self.loss_recon_A = None 89 | self.loss_recon_B = None 90 | self.loss_recon_C = None 91 | self.loss_recon = None 92 | self.loss_kl = None 93 | self.loss_survival = None 94 | self.loss_classifier = None 95 | self.loss_regression = None 96 | self.loss_gradient = 0 97 | 98 | self.loss_w_sur = None 99 | self.loss_w_cla = None 100 | self.loss_w_reg = None 101 | 102 | self.task_losses = None 103 | self.weighted_losses = None 104 | self.initial_losses = None 105 | 106 | self.metric_accuracy = None 107 | self.metric_rmse = None 108 | 109 | if param.survival_loss == 'MTLR': 110 | self.tri_matrix_1 = self.get_tri_matrix(dimension_type=1) 111 | self.tri_matrix_2 = self.get_tri_matrix(dimension_type=2) 112 | 113 | # Weights of multiple downstream tasks 114 | self.loss_weights = nn.Parameter(torch.ones(3, requires_grad=True, device=self.device)) 115 | 116 | if self.isTrain: 117 | # Set the optimizer 118 | self.optimizer_All = torch.optim.Adam([{'params': self.netAll.parameters(), 'lr': param.lr, 'betas': (param.beta1, 0.999), 'weight_decay': param.weight_decay}, 119 | {'params': self.loss_weights, 'lr': param.lr_gn}]) 120 | self.optimizers.append(self.optimizer_All) 121 | 122 | def set_input(self, input_dict): 123 | """ 124 | Unpack 
input data from the output dictionary of the dataloader 125 | 126 | Parameters: 127 | input_dict (dict): include the data tensor and its index. 128 | """ 129 | self.input_omics = [] 130 | for i in range(0, 3): 131 | if i == 1 and self.param.ch_separate: 132 | input_B = [] 133 | for ch in range(0, 23): 134 | input_B.append(input_dict['input_omics'][1][ch].to(self.device)) 135 | self.input_omics.append(input_B) 136 | else: 137 | self.input_omics.append(input_dict['input_omics'][i].to(self.device)) 138 | 139 | self.data_index = input_dict['index'] 140 | self.survival_T = input_dict['survival_T'].to(self.device) 141 | self.survival_E = input_dict['survival_E'].to(self.device) 142 | self.y_true = input_dict['y_true'].to(self.device) 143 | self.label = input_dict['label'].to(self.device) 144 | self.value = input_dict['value'].to(self.device) 145 | 146 | def forward(self): 147 | # Get the output tensor 148 | self.z, self.recon_omics, self.mean, self.log_var, self.y_out_sur, self.y_out_cla, self.y_out_reg = self.netAll(self.input_omics) 149 | # define the latent 150 | self.latent = self.mean 151 | 152 | def cal_losses(self): 153 | """Calculate losses""" 154 | # Calculate the reconstruction loss for A 155 | if self.param.omics_mode == 'a' or self.param.omics_mode == 'ab' or self.param.omics_mode == 'abc': 156 | self.loss_recon_A = self.lossFuncRecon(self.recon_omics[0], self.input_omics[0]) 157 | else: 158 | self.loss_recon_A = 0 159 | # Calculate the reconstruction loss for B 160 | if self.param.omics_mode == 'b' or self.param.omics_mode == 'ab' or self.param.omics_mode == 'abc': 161 | if self.param.ch_separate: 162 | recon_omics_B = torch.cat(self.recon_omics[1], -1) 163 | input_omics_B = torch.cat(self.input_omics[1], -1) 164 | self.loss_recon_B = self.lossFuncRecon(recon_omics_B, input_omics_B) 165 | else: 166 | self.loss_recon_B = self.lossFuncRecon(self.recon_omics[1], self.input_omics[1]) 167 | else: 168 | self.loss_recon_B = 0 169 | # Calculate the reconstruction 
loss for C 170 | if self.param.omics_mode == 'c' or self.param.omics_mode == 'abc': 171 | self.loss_recon_C = self.lossFuncRecon(self.recon_omics[2], self.input_omics[2]) 172 | else: 173 | self.loss_recon_C = 0 174 | # Overall reconstruction loss 175 | if self.param.reduction == 'sum': 176 | self.loss_recon = self.loss_recon_A + self.loss_recon_B + self.loss_recon_C 177 | elif self.param.reduction == 'mean': 178 | self.loss_recon = (self.loss_recon_A + self.loss_recon_B + self.loss_recon_C) / self.param.omics_num 179 | # Calculate the kl loss 180 | self.loss_kl = losses.kl_loss(self.mean, self.log_var, self.param.reduction) 181 | # Calculate the overall vae loss (embedding loss) 182 | # LOSS EMBED 183 | self.loss_embed = self.loss_recon + self.param.k_kl * self.loss_kl 184 | 185 | # Calculate the survival loss 186 | if self.param.survival_loss == 'MTLR': 187 | self.loss_survival = losses.MTLR_survival_loss(self.y_out_sur, self.y_true, self.survival_E, self.tri_matrix_1, self.param.reduction) 188 | # Calculate the classification loss 189 | self.loss_classifier = self.lossFuncClass(self.y_out_cla, self.label) 190 | # Calculate the regression loss 191 | self.loss_regression = self.lossFuncDist(self.y_out_reg.squeeze().type(torch.float32), (self.value / self.param.regression_scale).type(torch.float32)) 192 | # Calculate the weighted downstream losses 193 | # Add initial weights 194 | self.task_losses = torch.stack([self.param.k_survival * self.loss_survival, self.param.k_classifier * self.loss_classifier, self.param.k_regression * self.loss_regression]) 195 | self.weighted_losses = self.loss_weights * self.task_losses 196 | 197 | # LOSS DOWN 198 | self.loss_down = self.weighted_losses.sum() 199 | 200 | self.loss_All = self.param.k_embed * self.loss_embed + self.loss_down 201 | 202 | # Log the loss weights 203 | self.loss_w_sur = self.loss_weights[0] * self.param.k_survival 204 | self.loss_w_cla = self.loss_weights[1] * self.param.k_classifier 205 | self.loss_w_reg = 
self.loss_weights[2] * self.param.k_regression 206 | 207 | def update(self): 208 | if self.phase == 'p1': 209 | self.forward() 210 | self.optimizer_All.zero_grad() # Set gradients to zero 211 | self.cal_losses() # Calculate losses 212 | self.loss_embed.backward() # Backpropagation 213 | self.optimizer_All.step() # Update weights 214 | elif self.phase == 'p2': 215 | self.forward() 216 | self.optimizer_All.zero_grad() # Set gradients to zero 217 | self.cal_losses() # Calculate losses 218 | self.loss_down.backward() # Backpropagation 219 | self.optimizer_All.step() # Update weights 220 | elif self.phase == 'p3': 221 | self.forward() 222 | self.cal_losses() # Calculate losses 223 | self.optimizer_All.zero_grad() # Set gradients to zero 224 | 225 | # Calculate the GradNorm gradients 226 | if isinstance(self.netAll, torch.nn.DataParallel): 227 | W = list(self.netAll.module.get_last_encode_layer().parameters()) 228 | else: 229 | W = list(self.netAll.get_last_encode_layer().parameters()) 230 | grad_norms = [] 231 | for weight, loss in zip(self.loss_weights, self.task_losses): 232 | grad = torch.autograd.grad(loss, W, retain_graph=True) 233 | grad_norms.append(torch.norm(weight * grad[0])) 234 | grad_norms = torch.stack(grad_norms) 235 | 236 | if self.iter == 0: 237 | self.initial_losses = self.task_losses.detach() 238 | 239 | # Calculate the constant targets 240 | with torch.no_grad(): 241 | # loss ratios 242 | loss_ratios = self.task_losses / self.initial_losses 243 | # inverse training rate 244 | inverse_train_rates = loss_ratios / loss_ratios.mean() 245 | constant_terms = grad_norms.mean() * (inverse_train_rates ** self.param.alpha) 246 | 247 | # Calculate the gradient loss 248 | self.loss_gradient = (grad_norms - constant_terms).abs().sum() 249 | # Set the gradients of weights 250 | loss_weights_grad = torch.autograd.grad(self.loss_gradient, self.loss_weights)[0] 251 | 252 | self.loss_All.backward() 253 | 254 | self.loss_weights.grad = loss_weights_grad 255 | 256 | 
self.optimizer_All.step() # Update weights 257 | 258 | # Re-normalize the losses weights 259 | with torch.no_grad(): 260 | normalize_coeff = len(self.loss_weights) / self.loss_weights.sum() 261 | self.loss_weights.data = self.loss_weights.data * normalize_coeff 262 | 263 | def get_down_output(self): 264 | """ 265 | Get output from downstream task 266 | """ 267 | with torch.no_grad(): 268 | index = self.data_index 269 | # Survival 270 | y_true_E = self.survival_E 271 | y_true_T = self.survival_T 272 | y_out_sur = self.y_out_sur 273 | predict = self.predict_risk() 274 | # density = predict['density'] 275 | survival = predict['survival'] 276 | # hazard = predict['hazard'] 277 | risk = predict['risk'] 278 | 279 | # Classification 280 | y_prob_cla = F.softmax(self.y_out_cla, dim=1) 281 | _, y_pred_cla = torch.max(y_prob_cla, 1) 282 | y_true_cla = self.label 283 | 284 | # Regression 285 | y_true_reg = self.value 286 | y_pred_reg = self.y_out_reg * self.param.regression_scale 287 | 288 | return {'index': index, 'y_true_E': y_true_E, 'y_true_T': y_true_T, 'survival': survival, 'risk': risk, 'y_out_sur': y_out_sur, 'y_true_cla': y_true_cla, 'y_pred_cla': y_pred_cla, 'y_prob_cla': y_prob_cla, 'y_true_reg': y_true_reg, 'y_pred_reg': y_pred_reg} 289 | 290 | def calculate_current_metrics(self, output_dict): 291 | """ 292 | Calculate current metrics 293 | """ 294 | self.metric_accuracy = (output_dict['y_true_cla'] == output_dict['y_pred_cla']).sum().item() / len(output_dict['y_true_cla']) 295 | 296 | y_true_reg = output_dict['y_true_reg'].cpu().numpy() 297 | y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy() 298 | self.metric_rmse = metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False) 299 | 300 | def get_tri_matrix(self, dimension_type=1): 301 | """ 302 | Get tensor of the triangular matrix 303 | """ 304 | if dimension_type == 1: 305 | ones_matrix = torch.ones(self.param.time_num, self.param.time_num + 1, device=self.device) 306 | else: 307 | ones_matrix 
= torch.ones(self.param.time_num + 1, self.param.time_num + 1, device=self.device) 308 | tri_matrix = torch.tril(ones_matrix) 309 | return tri_matrix 310 | 311 | def predict_risk(self): 312 | """ 313 | Predict the density, survival and hazard function, as well as the risk score 314 | """ 315 | if self.param.survival_loss == 'MTLR': 316 | phi = torch.exp(torch.mm(self.y_out_sur, self.tri_matrix_1)) 317 | div = torch.repeat_interleave(torch.sum(phi, 1).reshape(-1, 1), phi.shape[1], dim=1) 318 | 319 | density = phi / div 320 | survival = torch.mm(density, self.tri_matrix_2) 321 | hazard = density[:, :-1] / survival[:, 1:] 322 | 323 | cumulative_hazard = torch.cumsum(hazard, dim=1) 324 | risk = torch.sum(cumulative_hazard, 1) 325 | 326 | return {'density': density, 'survival': survival, 'hazard': hazard, 'risk': risk} 327 | -------------------------------------------------------------------------------- /models/vae_multitask_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .vae_basic_model import VaeBasicModel 3 | from . import networks 4 | from . import losses 5 | from torch.nn import functional as F 6 | from sklearn import metrics 7 | 8 | 9 | class VaeMultitaskModel(VaeBasicModel): 10 | """ 11 | This class implements the VAE multitasking model, using the VAE framework with the multiple downstream tasks. 
12 | """ 13 | 14 | @staticmethod 15 | def modify_commandline_parameters(parser, is_train=True): 16 | # Downstream task network 17 | parser.set_defaults(net_down='multi_FC_multitask') 18 | # Survival prediction related 19 | parser.add_argument('--survival_loss', type=str, default='MTLR', help='choose the survival loss') 20 | parser.add_argument('--survival_T_max', type=float, default=-1, help='maximum T value for survival prediction task') 21 | parser.add_argument('--time_num', type=int, default=256, help='number of time intervals in the survival model') 22 | # Classification related 23 | parser.add_argument('--class_num', type=int, default=0, help='the number of classes for the classification task') 24 | # Regression related 25 | parser.add_argument('--regression_scale', type=int, default=1, help='normalization scale for y in regression task') 26 | parser.add_argument('--dist_loss', type=str, default='L1', help='choose the distance loss for regression task, options: [MSE | L1]') 27 | # Loss combined 28 | parser.add_argument('--k_survival', type=float, default=1, 29 | help='weight for the survival loss') 30 | parser.add_argument('--k_classifier', type=float, default=1, 31 | help='weight for the classifier loss') 32 | parser.add_argument('--k_regression', type=float, default=1, 33 | help='weight for the regression loss') 34 | return parser 35 | 36 | def __init__(self, param): 37 | """ 38 | Initialize the VAE_multitask class. 39 | """ 40 | VaeBasicModel.__init__(self, param) 41 | # specify the training losses you want to print out. 42 | self.loss_names.extend(['survival', 'classifier', 'regression']) 43 | # specify the metrics you want to print out. 
44 | self.metric_names = ['accuracy', 'rmse'] 45 | # input tensor 46 | self.survival_T = None 47 | self.survival_E = None 48 | self.y_true = None 49 | self.label = None 50 | self.value = None 51 | # output tensor 52 | self.y_out_sur = None 53 | self.y_out_cla = None 54 | self.y_out_reg = None 55 | # define the network 56 | self.netDown = networks.define_down(param.net_down, param.norm_type, param.leaky_slope, param.dropout_p, 57 | param.latent_space_dim, param.class_num, param.time_num, None, param.init_type, 58 | param.init_gain, self.gpu_ids) 59 | # define the classification loss 60 | self.lossFuncClass = losses.get_loss_func('CE', param.reduction) 61 | # define the regression distance loss 62 | self.lossFuncDist = losses.get_loss_func(param.dist_loss, param.reduction) 63 | self.loss_survival = None 64 | self.loss_classifier = None 65 | self.loss_regression = None 66 | self.metric_accuracy = None 67 | self.metric_rmse = None 68 | 69 | if param.survival_loss == 'MTLR': 70 | self.tri_matrix_1 = self.get_tri_matrix(dimension_type=1) 71 | self.tri_matrix_2 = self.get_tri_matrix(dimension_type=2) 72 | 73 | if self.isTrain: 74 | # Set the optimizer 75 | self.optimizer_Down = torch.optim.Adam(self.netDown.parameters(), lr=param.lr, betas=(param.beta1, 0.999), weight_decay=param.weight_decay) 76 | # optimizer list was already defined in BaseModel 77 | self.optimizers.append(self.optimizer_Down) 78 | 79 | def set_input(self, input_dict): 80 | """ 81 | Unpack input data from the output dictionary of the dataloader 82 | 83 | Parameters: 84 | input_dict (dict): include the data tensor and its index. 
85 | """ 86 | VaeBasicModel.set_input(self, input_dict) 87 | self.survival_T = input_dict['survival_T'].to(self.device) 88 | self.survival_E = input_dict['survival_E'].to(self.device) 89 | self.y_true = input_dict['y_true'].to(self.device) 90 | self.label = input_dict['label'].to(self.device) 91 | self.value = input_dict['value'].to(self.device) 92 | 93 | def forward(self): 94 | # Get the output tensor 95 | VaeBasicModel.forward(self) 96 | self.y_out_sur, self.y_out_cla, self.y_out_reg = self.netDown(self.latent) 97 | 98 | def cal_losses(self): 99 | """Calculate losses""" 100 | VaeBasicModel.cal_losses(self) 101 | # Calculate the survival loss 102 | if self.param.survival_loss == 'MTLR': 103 | self.loss_survival = losses.MTLR_survival_loss(self.y_out_sur, self.y_true, self.survival_E, self.tri_matrix_1, self.param.reduction) 104 | # Calculate the classification loss 105 | self.loss_classifier = self.lossFuncClass(self.y_out_cla, self.label) 106 | # Calculate the regression loss 107 | self.loss_regression = self.lossFuncDist(self.y_out_reg.squeeze().type(torch.float32), (self.value / self.param.regression_scale).type(torch.float32)) 108 | # LOSS DOWN 109 | self.loss_down = self.param.k_survival * self.loss_survival + self.param.k_classifier * self.loss_classifier + self.param.k_regression * self.loss_regression 110 | 111 | self.loss_All = self.param.k_embed * self.loss_embed + self.loss_down 112 | 113 | def update(self): 114 | VaeBasicModel.update(self) 115 | 116 | def get_down_output(self): 117 | """ 118 | Get output from downstream task 119 | """ 120 | with torch.no_grad(): 121 | index = self.data_index 122 | # Survival 123 | y_true_E = self.survival_E 124 | y_true_T = self.survival_T 125 | y_out_sur = self.y_out_sur 126 | predict = self.predict_risk() 127 | # density = predict['density'] 128 | survival = predict['survival'] 129 | # hazard = predict['hazard'] 130 | risk = predict['risk'] 131 | 132 | # Classification 133 | y_prob_cla = F.softmax(self.y_out_cla, 
dim=1) 134 | _, y_pred_cla = torch.max(y_prob_cla, 1) 135 | y_true_cla = self.label 136 | 137 | # Regression 138 | y_true_reg = self.value 139 | y_pred_reg = self.y_out_reg * self.param.regression_scale 140 | 141 | return {'index': index, 'y_true_E': y_true_E, 'y_true_T': y_true_T, 'survival': survival, 'risk': risk, 'y_out_sur': y_out_sur, 'y_true_cla': y_true_cla, 'y_pred_cla': y_pred_cla, 'y_prob_cla': y_prob_cla, 'y_true_reg': y_true_reg, 'y_pred_reg': y_pred_reg} 142 | 143 | def calculate_current_metrics(self, output_dict): 144 | """ 145 | Calculate current metrics 146 | """ 147 | self.metric_accuracy = (output_dict['y_true_cla'] == output_dict['y_pred_cla']).sum().item() / len(output_dict['y_true_cla']) 148 | 149 | y_true_reg = output_dict['y_true_reg'].cpu().numpy() 150 | y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy() 151 | self.metric_rmse = metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False) 152 | 153 | def get_tri_matrix(self, dimension_type=1): 154 | """ 155 | Get tensor of the triangular matrix 156 | """ 157 | if dimension_type == 1: 158 | ones_matrix = torch.ones(self.param.time_num, self.param.time_num + 1, device=self.device) 159 | else: 160 | ones_matrix = torch.ones(self.param.time_num + 1, self.param.time_num + 1, device=self.device) 161 | tri_matrix = torch.tril(ones_matrix) 162 | return tri_matrix 163 | 164 | def predict_risk(self): 165 | """ 166 | Predict the density, survival and hazard function, as well as the risk score 167 | """ 168 | if self.param.survival_loss == 'MTLR': 169 | phi = torch.exp(torch.mm(self.y_out_sur, self.tri_matrix_1)) 170 | div = torch.repeat_interleave(torch.sum(phi, 1).reshape(-1, 1), phi.shape[1], dim=1) 171 | 172 | density = phi / div 173 | survival = torch.mm(density, self.tri_matrix_2) 174 | hazard = density[:, :-1] / survival[:, 1:] 175 | 176 | cumulative_hazard = torch.cumsum(hazard, dim=1) 177 | risk = torch.sum(cumulative_hazard, 1) 178 | 179 | return {'density': density, 
'survival': survival, 'hazard': hazard, 'risk': risk} 180 | -------------------------------------------------------------------------------- /models/vae_regression_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn import metrics 3 | from .vae_basic_model import VaeBasicModel 4 | from . import networks 5 | from . import losses 6 | 7 | 8 | class VaeRegressionModel(VaeBasicModel): 9 | """ 10 | This class implements the VAE regression model, using the VAE framework with the regression downstream task. 11 | """ 12 | 13 | @staticmethod 14 | def modify_commandline_parameters(parser, is_train=True): 15 | # changing the default values of parameters to match the vae regression model 16 | parser.set_defaults(net_down='multi_FC_regression', not_stratified=True) 17 | parser.add_argument('--regression_scale', type=int, default=1, 18 | help='normalization scale for y in regression task') 19 | parser.add_argument('--dist_loss', type=str, default='L1', 20 | help='choose the distance loss for regression task, options: [MSE | L1]') 21 | return parser 22 | 23 | def __init__(self, param): 24 | """ 25 | Initialize the VAE_regression class. 26 | """ 27 | VaeBasicModel.__init__(self, param) 28 | # specify the training losses you want to print out. 29 | self.loss_names.append('distance') 30 | # specify the metrics you want to print out. 
31 | self.metric_names = ['rmse'] 32 | # input tensor 33 | self.value = None 34 | # output tensor 35 | self.y_out = None 36 | # define the network 37 | self.netDown = networks.define_down(param.net_down, param.norm_type, param.leaky_slope, param.dropout_p, 38 | param.latent_space_dim, None, None, None, param.init_type, 39 | param.init_gain, self.gpu_ids) 40 | # define the distance loss 41 | self.lossFuncDist = losses.get_loss_func(param.dist_loss, param.reduction) 42 | self.loss_distance = None 43 | self.metric_rmse = None 44 | 45 | if self.isTrain: 46 | # Set the optimizer 47 | self.optimizer_Down = torch.optim.Adam(self.netDown.parameters(), lr=param.lr, betas=(param.beta1, 0.999), weight_decay=param.weight_decay) 48 | # optimizer list was already defined in BaseModel 49 | self.optimizers.append(self.optimizer_Down) 50 | 51 | def set_input(self, input_dict): 52 | """ 53 | Unpack input data from the output dictionary of the dataloader 54 | 55 | Parameters: 56 | input_dict (dict): include the data tensor and its index. 
57 | """ 58 | VaeBasicModel.set_input(self, input_dict) 59 | self.value = input_dict['value'].to(self.device) 60 | 61 | def forward(self): 62 | VaeBasicModel.forward(self) 63 | # Get the output tensor 64 | self.y_out = self.netDown(self.latent) 65 | 66 | def cal_losses(self): 67 | """Calculate losses""" 68 | VaeBasicModel.cal_losses(self) 69 | # Calculate the regression distance loss (downstream loss) 70 | self.loss_distance = self.lossFuncDist(self.y_out.squeeze().type(torch.float32), (self.value / self.param.regression_scale).type(torch.float32)) 71 | # LOSS DOWN 72 | self.loss_down = self.loss_distance 73 | 74 | self.loss_All = self.param.k_embed * self.loss_embed + self.loss_down 75 | 76 | def update(self): 77 | VaeBasicModel.update(self) 78 | 79 | def get_down_output(self): 80 | """ 81 | Get output from downstream task 82 | """ 83 | with torch.no_grad(): 84 | index = self.data_index 85 | y_true = self.value 86 | y_pred = self.y_out * self.param.regression_scale 87 | 88 | return {'index': index, 'y_true': y_true, 'y_pred': y_pred} 89 | 90 | def calculate_current_metrics(self, output_dict): 91 | """ 92 | Calculate current metrics 93 | """ 94 | y_true = output_dict['y_true'].cpu().numpy() 95 | y_pred = output_dict['y_pred'].cpu().detach().numpy() 96 | 97 | self.metric_rmse = metrics.mean_squared_error(y_true, y_pred, squared=False) 98 | 99 | -------------------------------------------------------------------------------- /models/vae_survival_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .vae_basic_model import VaeBasicModel 3 | from . import networks 4 | from . import losses 5 | 6 | 7 | class VaeSurvivalModel(VaeBasicModel): 8 | """ 9 | This class implements the VAE survival model, using the VAE framework with the survival prediction downstream task. 
10 | """ 11 | 12 | @staticmethod 13 | def modify_commandline_parameters(parser, is_train=True): 14 | # changing the default values of parameters to match the vae survival prediction model 15 | parser.set_defaults(net_down='multi_FC_survival') 16 | parser.add_argument('--survival_loss', type=str, default='MTLR', help='choose the survival loss') 17 | parser.add_argument('--survival_T_max', type=float, default=-1, help='maximum T value for survival prediction task') 18 | parser.add_argument('--time_num', type=int, default=256, help='number of time intervals in the survival model') 19 | parser.add_argument('--stratify_label', action='store_true', help='load extra label for stratified dataset separation') 20 | return parser 21 | 22 | def __init__(self, param): 23 | """ 24 | Initialize the VAE_survival class. 25 | """ 26 | VaeBasicModel.__init__(self, param) 27 | # specify the training losses you want to print out. 28 | self.loss_names.append('survival') 29 | # specify the metrics you want to print out. 
30 | self.metric_names = [] 31 | # input tensor 32 | self.survival_T = None 33 | self.survival_E = None 34 | self.y_true = None 35 | # output tensor 36 | self.y_out = None 37 | # define the network 38 | self.netDown = networks.define_down(param.net_down, param.norm_type, param.leaky_slope, param.dropout_p, 39 | param.latent_space_dim, None, param.time_num, None, param.init_type, 40 | param.init_gain, self.gpu_ids) 41 | self.loss_survival = None 42 | 43 | if param.survival_loss == 'MTLR': 44 | self.tri_matrix_1 = self.get_tri_matrix(dimension_type=1) 45 | self.tri_matrix_2 = self.get_tri_matrix(dimension_type=2) 46 | 47 | if self.isTrain: 48 | # Set the optimizer 49 | self.optimizer_Down = torch.optim.Adam(self.netDown.parameters(), lr=param.lr, betas=(param.beta1, 0.999), weight_decay=param.weight_decay) 50 | # optimizer list was already defined in BaseModel 51 | self.optimizers.append(self.optimizer_Down) 52 | 53 | def set_input(self, input_dict): 54 | """ 55 | Unpack input data from the output dictionary of the dataloader 56 | 57 | Parameters: 58 | input_dict (dict): include the data tensor and its index. 
59 | """ 60 | VaeBasicModel.set_input(self, input_dict) 61 | self.survival_T = input_dict['survival_T'].to(self.device) 62 | self.survival_E = input_dict['survival_E'].to(self.device) 63 | self.y_true = input_dict['y_true'].to(self.device) 64 | 65 | def forward(self): 66 | VaeBasicModel.forward(self) 67 | # Get the output tensor 68 | self.y_out = self.netDown(self.latent) 69 | 70 | def cal_losses(self): 71 | """Calculate losses""" 72 | VaeBasicModel.cal_losses(self) 73 | # Calculate the survival loss (downstream loss) 74 | if self.param.survival_loss == 'MTLR': 75 | self.loss_survival = losses.MTLR_survival_loss(self.y_out, self.y_true, self.survival_E, self.tri_matrix_1, self.param.reduction) 76 | # LOSS DOWN 77 | self.loss_down = self.loss_survival 78 | 79 | self.loss_All = self.param.k_embed * self.loss_embed + self.loss_down 80 | 81 | def update(self): 82 | VaeBasicModel.update(self) 83 | 84 | def get_down_output(self): 85 | """ 86 | Get output from downstream task 87 | """ 88 | with torch.no_grad(): 89 | index = self.data_index 90 | y_true_E = self.survival_E 91 | y_true_T = self.survival_T 92 | y_out = self.y_out 93 | 94 | predict = self.predict_risk() 95 | # density = predict['density'] 96 | survival = predict['survival'] 97 | # hazard = predict['hazard'] 98 | risk = predict['risk'] 99 | 100 | return {'index': index, 'y_true_E': y_true_E, 'y_true_T': y_true_T, 'survival': survival, 'risk': risk, 'y_out': y_out} 101 | 102 | def calculate_current_metrics(self, output_dict): 103 | """ 104 | Calculate current metrics 105 | """ 106 | pass 107 | 108 | def get_tri_matrix(self, dimension_type=1): 109 | """ 110 | Get tensor of the triangular matrix 111 | """ 112 | if dimension_type == 1: 113 | ones_matrix = torch.ones(self.param.time_num, self.param.time_num + 1, device=self.device) 114 | else: 115 | ones_matrix = torch.ones(self.param.time_num + 1, self.param.time_num + 1, device=self.device) 116 | tri_matrix = torch.tril(ones_matrix) 117 | return tri_matrix 118 | 
119 | def predict_risk(self): 120 | """ 121 | Predict the density, survival and hazard function, as well as the risk score 122 | """ 123 | if self.param.survival_loss == 'MTLR': 124 | phi = torch.exp(torch.mm(self.y_out, self.tri_matrix_1)) 125 | div = torch.repeat_interleave(torch.sum(phi, 1).reshape(-1, 1), phi.shape[1], dim=1) 126 | 127 | density = phi / div 128 | survival = torch.mm(density, self.tri_matrix_2) 129 | hazard = density[:, :-1] / survival[:, 1:] 130 | 131 | cumulative_hazard = torch.cumsum(hazard, dim=1) 132 | risk = torch.sum(cumulative_hazard, 1) 133 | 134 | return {'density': density, 'survival': survival, 'hazard': hazard, 'risk': risk} 135 | -------------------------------------------------------------------------------- /params/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxiaoyu11/OmiEmbed/0abf210325ea0d23948823bbf673bf5695e3b9b2/params/__init__.py -------------------------------------------------------------------------------- /params/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxiaoyu11/OmiEmbed/0abf210325ea0d23948823bbf673bf5695e3b9b2/params/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /params/__pycache__/basic_params.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxiaoyu11/OmiEmbed/0abf210325ea0d23948823bbf673bf5695e3b9b2/params/__pycache__/basic_params.cpython-36.pyc -------------------------------------------------------------------------------- /params/basic_params.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import torch 4 | import os 5 | import models 6 | from util import util 7 | 8 | 9 | class 
BasicParams: 10 | """ 11 | This class define the console parameters 12 | """ 13 | 14 | def __init__(self): 15 | """ 16 | Reset the class. Indicates the class hasn't been initialized 17 | """ 18 | self.initialized = False 19 | self.isTrain = True 20 | self.isTest = True 21 | 22 | def initialize(self, parser): 23 | """ 24 | Define the common console parameters 25 | """ 26 | parser.add_argument('--gpu_ids', type=str, default='0', 27 | help='which GPU would like to use: e.g. 0 or 0,1, -1 for CPU') 28 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', 29 | help='models, settings and intermediate results are saved in folder in this directory') 30 | parser.add_argument('--experiment_name', type=str, default='test', 31 | help='name of the folder in the checkpoint directory') 32 | 33 | # Dataset parameters 34 | parser.add_argument('--omics_mode', type=str, default='a', 35 | help='omics types would like to use in the model, options: [abc | ab | a | b | c]') 36 | parser.add_argument('--data_root', type=str, default='./data', 37 | help='path to input data') 38 | parser.add_argument('--batch_size', type=int, default=32, 39 | help='input data batch size') 40 | parser.add_argument('--num_threads', default=0, type=int, 41 | help='number of threads for loading data') 42 | parser.add_argument('--set_pin_memory', action='store_true', 43 | help='set pin_memory in the dataloader to increase data loading performance') 44 | parser.add_argument('--not_stratified', action='store_true', 45 | help='do not apply the stratified mode in train/test split if set true') 46 | parser.add_argument('--use_sample_list', action='store_true', 47 | help='provide a subset sample list of the dataset, store in the path data_root/sample_list.tsv, if False use all the samples') 48 | parser.add_argument('--use_feature_lists', action='store_true', 49 | help='provide feature lists of the input omics data, e.g. 
data_root/feature_list_A.tsv, if False use all the features') 50 | parser.add_argument('--detect_na', action='store_true', 51 | help='detect missing value markers during data loading, stay False can improve the loading performance') 52 | parser.add_argument('--file_format', type=str, default='tsv', 53 | help='file format of the omics data, options: [tsv | csv | hdf]') 54 | 55 | # Model parameters 56 | parser.add_argument('--model', type=str, default='vae_classifier', 57 | help='chooses which model want to use, options: [vae_classifier | vae_regression | vae_survival | vae_multitask]') 58 | parser.add_argument('--net_VAE', type=str, default='fc_sep', 59 | help='specify the backbone of the VAE, default is the one dimensional CNN, options: [conv_1d | fc_sep | fc]') 60 | parser.add_argument('--net_down', type=str, default='multi_FC_classifier', 61 | help='specify the backbone of the downstream task network, default is the multi-layer FC classifier, options: [multi_FC_classifier | multi_FC_regression | multi_FC_survival | multi_FC_multitask]') 62 | parser.add_argument('--norm_type', type=str, default='batch', 63 | help='the type of normalization applied to the model, default to use batch normalization, options: [batch | instance | none ]') 64 | parser.add_argument('--filter_num', type=int, default=8, 65 | help='number of filters in the last convolution layer in the generator') 66 | parser.add_argument('--conv_k_size', type=int, default=9, 67 | help='the kernel size of convolution layer, default kernel size is 9, the kernel is one dimensional.') 68 | parser.add_argument('--dropout_p', type=float, default=0.2, 69 | help='probability of an element to be zeroed in a dropout layer, default is 0 which means no dropout.') 70 | parser.add_argument('--leaky_slope', type=float, default=0.2, 71 | help='the negative slope of the Leaky ReLU activation function') 72 | parser.add_argument('--latent_space_dim', type=int, default=128, 73 | help='the dimensionality of the latent space') 
74 | parser.add_argument('--seed', type=int, default=42, 75 | help='random seed') 76 | parser.add_argument('--init_type', type=str, default='normal', 77 | help='choose the method of network initialization, options: [normal | xavier_normal | xavier_uniform | kaiming_normal | kaiming_uniform | orthogonal]') 78 | parser.add_argument('--init_gain', type=float, default=0.02, 79 | help='scaling factor for normal, xavier and orthogonal initialization methods') 80 | 81 | # Loss parameters 82 | parser.add_argument('--recon_loss', type=str, default='BCE', 83 | help='chooses the reconstruction loss function, options: [BCE | MSE | L1]') 84 | parser.add_argument('--reduction', type=str, default='mean', 85 | help='chooses the reduction to apply to the loss function, options: [sum | mean]') 86 | parser.add_argument('--k_kl', type=float, default=0.01, 87 | help='weight for the kl loss') 88 | parser.add_argument('--k_embed', type=float, default=0.001, 89 | help='weight for the embedding loss') 90 | 91 | # Other parameters 92 | parser.add_argument('--deterministic', action='store_true', 93 | help='make the model deterministic for reproduction if set true') 94 | parser.add_argument('--detail', action='store_true', 95 | help='print more detailed information if set true') 96 | parser.add_argument('--epoch_to_load', type=str, default='latest', 97 | help='the epoch number to load, set latest to load latest cached model') 98 | parser.add_argument('--experiment_to_load', type=str, default='test', 99 | help='the experiment to load') 100 | 101 | self.initialized = True # set the initialized to True after we define the parameters of the project 102 | return parser 103 | 104 | def get_params(self): 105 | """ 106 | Initialize our parser with basic parameters once. 107 | Add additional model-specific parameters. 
108 | """ 109 | if not self.initialized: # check if this object has been initialized 110 | # if not create a new parser object 111 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 112 | # use our method to initialize the parser with the predefined arguments 113 | parser = self.initialize(parser) 114 | 115 | # get the basic parameters 116 | param, _ = parser.parse_known_args() 117 | 118 | # modify model-related parser options 119 | model_name = param.model 120 | model_param_setter = models.get_param_setter(model_name) 121 | parser = model_param_setter(parser, self.isTrain) 122 | 123 | # save and return the parser 124 | self.parser = parser 125 | return parser.parse_args() 126 | 127 | def print_params(self, param): 128 | """ 129 | Print welcome words and command line parameters. 130 | Save the command line parameters in a txt file to the disk 131 | """ 132 | message = '' 133 | message += '\nWelcome to OmiEmbed\nby Xiaoyu Zhang x.zhang18@imperial.ac.uk\n\n' 134 | message += '-----------------------Running Parameters-----------------------\n' 135 | for key, value in sorted(vars(param).items()): 136 | comment = '' 137 | default = self.parser.get_default(key) 138 | if value != default: 139 | comment = '\t[default: %s]' % str(default) 140 | message += '{:>18}: {:<15}{}\n'.format(str(key), str(value), comment) 141 | message += '----------------------------------------------------------------\n' 142 | print(message) 143 | 144 | # Save the running parameters setting in the disk 145 | experiment_dir = os.path.join(param.checkpoints_dir, param.experiment_name) 146 | util.mkdir(experiment_dir) 147 | file_name = os.path.join(experiment_dir, 'cmd_parameters.txt') 148 | with open(file_name, 'w') as param_file: 149 | now = time.strftime('%c') 150 | param_file.write('{:s}\n'.format(now)) 151 | param_file.write(message) 152 | param_file.write('\n') 153 | 154 | def parse(self): 155 | """ 156 | Parse the parameters of our project. 
Set up GPU device. Print the welcome words and list parameters in the console. 157 | """ 158 | param = self.get_params() # get the parameters to the object param 159 | param.isTrain = self.isTrain 160 | param.isTest = self.isTest 161 | 162 | # Print welcome words and command line parameters 163 | self.print_params(param) 164 | 165 | # Set the internal parameters 166 | # epoch_num: the total epoch number 167 | if self.isTrain: 168 | param.epoch_num = param.epoch_num_p1 + param.epoch_num_p2 + param.epoch_num_p3 169 | # downstream_task: for the classification task a labels.tsv file is needed, for the regression task a values.tsv file is needed 170 | if param.model == 'vae_classifier': 171 | param.downstream_task = 'classification' 172 | elif param.model == 'vae_regression': 173 | param.downstream_task = 'regression' 174 | elif param.model == 'vae_survival': 175 | param.downstream_task = 'survival' 176 | elif param.model == 'vae_multitask' or param.model == 'vae_multitask_gn': 177 | param.downstream_task = 'multitask' 178 | elif param.model == 'vae_alltask' or param.model == 'vae_alltask_gn': 179 | param.downstream_task = 'alltask' 180 | else: 181 | raise NotImplementedError('Model name [%s] is not recognized' % param.model) 182 | # add_channel: add one extra dimension of channel for the input data, used for convolution layer 183 | # ch_separate: separate the DNA methylation matrix base on the chromosome 184 | if param.net_VAE == 'conv_1d': 185 | param.add_channel = True 186 | param.ch_separate = False 187 | elif param.net_VAE == 'fc_sep': 188 | param.add_channel = False 189 | param.ch_separate = True 190 | elif param.net_VAE == 'fc': 191 | param.add_channel = False 192 | param.ch_separate = False 193 | else: 194 | raise NotImplementedError('VAE model name [%s] is not recognized' % param.net_VAE) 195 | # omics_num: the number of omics types 196 | param.omics_num = len(param.omics_mode) 197 | 198 | # Set up GPU 199 | str_gpu_ids = param.gpu_ids.split(',') 200 | 
param.gpu_ids = [] 201 | for str_gpu_id in str_gpu_ids: 202 | int_gpu_id = int(str_gpu_id) 203 | if int_gpu_id >= 0: 204 | param.gpu_ids.append(int_gpu_id) 205 | if len(param.gpu_ids) > 0: 206 | torch.cuda.set_device(param.gpu_ids[0]) 207 | 208 | self.param = param 209 | return self.param 210 | -------------------------------------------------------------------------------- /params/test_params.py: -------------------------------------------------------------------------------- 1 | from .basic_params import BasicParams 2 | 3 | 4 | class TestParams(BasicParams): 5 | """ 6 | This class is a son class of BasicParams. 7 | This class includes parameters for testing and parameters inherited from the father class. 8 | """ 9 | def initialize(self, parser): 10 | parser = BasicParams.initialize(self, parser) 11 | 12 | # Testing parameters 13 | parser.add_argument('--save_latent_space', action='store_true', help='save the latent space of input data to disc') 14 | 15 | # Logging and visualization 16 | parser.add_argument('--print_freq', type=int, default=1, 17 | help='frequency of showing results on console') 18 | 19 | self.isTrain = False 20 | self.isTest = True 21 | return parser 22 | -------------------------------------------------------------------------------- /params/train_params.py: -------------------------------------------------------------------------------- 1 | from .basic_params import BasicParams 2 | 3 | 4 | class TrainParams(BasicParams): 5 | """ 6 | This class is a son class of BasicParams. 7 | This class includes parameters for training and parameters inherited from the father class. 
8 | """ 9 | def initialize(self, parser): 10 | parser = BasicParams.initialize(self, parser) 11 | 12 | # Training parameters 13 | parser.add_argument('--epoch_num_p1', type=int, default=50, 14 | help='epoch number for phase 1') 15 | parser.add_argument('--epoch_num_p2', type=int, default=50, 16 | help='epoch number for phase 2') 17 | parser.add_argument('--epoch_num_p3', type=int, default=100, 18 | help='epoch number for phase 3') 19 | parser.add_argument('--lr', type=float, default=1e-4, 20 | help='initial learning rate') 21 | parser.add_argument('--beta1', type=float, default=0.5, 22 | help='momentum term of adam') 23 | parser.add_argument('--lr_policy', type=str, default='linear', 24 | help='The learning rate policy for the scheduler. [linear | step | plateau | cosine]') 25 | parser.add_argument('--epoch_count', type=int, default=1, 26 | help='the starting epoch count, default start from 1') 27 | parser.add_argument('--epoch_num_decay', type=int, default=50, 28 | help='Number of epoch to linearly decay learning rate to zero (lr_policy == linear)') 29 | parser.add_argument('--decay_step_size', type=int, default=50, 30 | help='The original learning rate multiply by a gamma every decay_step_size epoch (lr_policy == step)') 31 | parser.add_argument('--weight_decay', type=float, default=1e-4, 32 | help='weight decay (L2 penalty)') 33 | 34 | # Network saving and loading parameters 35 | parser.add_argument('--continue_train', action='store_true', 36 | help='load the latest model and continue training') 37 | parser.add_argument('--save_model', action='store_true', 38 | help='save the model during training') 39 | parser.add_argument('--save_epoch_freq', type=int, default=-1, 40 | help='frequency of saving checkpoints at the end of epochs, -1 means only save the last epoch') 41 | 42 | # Logging and visualization 43 | parser.add_argument('--print_freq', type=int, default=1, 44 | help='frequency of showing results on console') 45 | parser.add_argument('--save_latent_space', 
action='store_true', 46 | help='save the latent space of input data to disc') 47 | 48 | self.isTrain = True 49 | self.isTest = False 50 | return parser 51 | -------------------------------------------------------------------------------- /params/train_test_params.py: -------------------------------------------------------------------------------- 1 | from .basic_params import BasicParams 2 | 3 | 4 | class TrainTestParams(BasicParams): 5 | """ 6 | This class is a son class of BasicParams. 7 | This class includes parameters for training & testing and parameters inherited from the father class. 8 | """ 9 | def initialize(self, parser): 10 | parser = BasicParams.initialize(self, parser) 11 | 12 | # Training parameters 13 | parser.add_argument('--epoch_num_p1', type=int, default=50, 14 | help='epoch number for phase 1') 15 | parser.add_argument('--epoch_num_p2', type=int, default=50, 16 | help='epoch number for phase 2') 17 | parser.add_argument('--epoch_num_p3', type=int, default=100, 18 | help='epoch number for phase 3') 19 | parser.add_argument('--lr', type=float, default=1e-4, 20 | help='initial learning rate') 21 | parser.add_argument('--beta1', type=float, default=0.5, 22 | help='momentum term of adam') 23 | parser.add_argument('--lr_policy', type=str, default='linear', 24 | help='The learning rate policy for the scheduler. 
[linear | step | plateau | cosine]') 25 | parser.add_argument('--epoch_count', type=int, default=1, 26 | help='the starting epoch count, default start from 1') 27 | parser.add_argument('--epoch_num_decay', type=int, default=50, 28 | help='Number of epoch to linearly decay learning rate to zero (lr_policy == linear)') 29 | parser.add_argument('--decay_step_size', type=int, default=50, 30 | help='The original learning rate multiply by a gamma every decay_step_size epoch (lr_policy == step)') 31 | parser.add_argument('--weight_decay', type=float, default=1e-4, 32 | help='weight decay (L2 penalty)') 33 | 34 | # Network saving and loading parameters 35 | parser.add_argument('--continue_train', action='store_true', 36 | help='load the latest model and continue training') 37 | parser.add_argument('--save_model', action='store_true', 38 | help='save the model during training') 39 | parser.add_argument('--save_epoch_freq', type=int, default=-1, 40 | help='frequency of saving checkpoints at the end of epochs, -1 means only save the last epoch') 41 | 42 | # Logging and visualization 43 | parser.add_argument('--print_freq', type=int, default=1, 44 | help='frequency of showing results on console') 45 | 46 | # Dataset parameters 47 | parser.add_argument('--train_ratio', type=float, default=0.8, 48 | help='ratio of training set in the full dataset') 49 | parser.add_argument('--test_ratio', type=float, default=0.2, 50 | help='ratio of testing set in the full dataset') 51 | 52 | self.isTrain = True 53 | self.isTest = True 54 | return parser 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2.0 2 | tensorboard>=1.10.0 3 | prefetch_generator>=1.0.0 4 | tables>=3.6.0 5 | scikit-survival>=0.6.0 6 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Separated testing for OmiEmbed 3 | """ 4 | import time 5 | from util import util 6 | from params.test_params import TestParams 7 | from datasets import create_single_dataloader 8 | from models import create_model 9 | from util.visualizer import Visualizer 10 | 11 | if __name__ == '__main__': 12 | # Get testing parameter 13 | param = TestParams().parse() 14 | if param.deterministic: 15 | util.setup_seed(param.seed) 16 | 17 | # Dataset related 18 | dataloader, sample_list = create_single_dataloader(param, shuffle=False) # No shuffle for testing 19 | print('The size of testing set is {}'.format(len(dataloader))) 20 | # Get sample list for the dataset 21 | param.sample_list = dataloader.get_sample_list() 22 | # Get the dimension of input omics data 23 | param.omics_dims = dataloader.get_omics_dims() 24 | if param.downstream_task == 'classification' or param.downstream_task == 'multitask': 25 | # Get the number of classes for the classification task 26 | if param.class_num == 0: 27 | param.class_num = dataloader.get_class_num() 28 | print('The number of classes: {}'.format(param.class_num)) 29 | if param.downstream_task == 'regression' or param.downstream_task == 'multitask': 30 | # Get the range of the target values 31 | values_min = dataloader.get_values_min() 32 | values_max = dataloader.get_values_max() 33 | if param.regression_scale == 1: 34 | param.regression_scale = values_max 35 | print('The range of the target values is [{}, {}]'.format(values_min, values_max)) 36 | if param.downstream_task == 'survival' or param.downstream_task == 'multitask': 37 | # Get the range of T 38 | survival_T_min = dataloader.get_survival_T_min() 39 | survival_T_max = dataloader.get_survival_T_max() 40 | if param.survival_T_max == -1: 41 | param.survival_T_max = survival_T_max 42 | print('The range of survival T is [{}, {}]'.format(survival_T_min, survival_T_max)) 43 | 44 | # Model related 45 
| model = create_model(param) # Create a model given param.model and other parameters 46 | model.setup(param) # Regular setup for the model: load and print networks, create schedulers 47 | visualizer = Visualizer(param) # Create a visualizer to print results 48 | 49 | # TESTING 50 | model.set_eval() 51 | test_start_time = time.time() # Start time of testing 52 | output_dict, losses_dict, metrics_dict = model.init_log_dict() # Initialize the log dictionaries 53 | if param.save_latent_space: 54 | latent_dict = model.init_latent_dict() 55 | 56 | # Start testing loop 57 | for i, data in enumerate(dataloader): 58 | dataset_size = len(dataloader) 59 | actual_batch_size = len(data['index']) 60 | model.set_input(data) # Unpack input data from the output dictionary of the dataloader 61 | model.test() # Run forward to get the output tensors 62 | model.update_log_dict(output_dict, losses_dict, metrics_dict, actual_batch_size) # Update the log dictionaries 63 | if param.save_latent_space: 64 | latent_dict = model.update_latent_dict(latent_dict) # Update the latent space array 65 | if i % param.print_freq == 0: # Print testing log 66 | visualizer.print_test_log(param.epoch_to_load, i, losses_dict, metrics_dict, param.batch_size, dataset_size) 67 | 68 | test_time = time.time() - test_start_time 69 | visualizer.print_test_summary(param.epoch_to_load, losses_dict, output_dict, test_time) 70 | visualizer.save_output_dict(output_dict) 71 | if param.save_latent_space: 72 | visualizer.save_latent_space(latent_dict, sample_list) 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Separated training for OmiEmbed 3 | """ 4 | import time 5 | import warnings 6 | from util import util 7 | from params.train_params import TrainParams 8 | from datasets import create_single_dataloader 9 | from models import create_model 10 | from util.visualizer import 
Visualizer 11 | 12 | 13 | if __name__ == "__main__": 14 | warnings.filterwarnings('ignore') 15 | # Get parameters 16 | param = TrainParams().parse() 17 | if param.deterministic: 18 | util.setup_seed(param.seed) 19 | 20 | # Dataset related 21 | dataloader, sample_list = create_single_dataloader(param, enable_drop_last=True) 22 | print('The size of training set is {}'.format(len(dataloader))) 23 | # Get the dimension of input omics data 24 | param.omics_dims = dataloader.get_omics_dims() 25 | if param.downstream_task in ['classification', 'multitask', 'alltask']: 26 | # Get the number of classes for the classification task 27 | if param.class_num == 0: 28 | param.class_num = dataloader.get_class_num() 29 | if param.downstream_task != 'alltask': 30 | print('The number of classes: {}'.format(param.class_num)) 31 | if param.downstream_task in ['regression', 'multitask', 'alltask']: 32 | # Get the range of the target values 33 | values_min = dataloader.get_values_min() 34 | values_max = dataloader.get_values_max() 35 | if param.regression_scale == 1: 36 | param.regression_scale = values_max 37 | print('The range of the target values is [{}, {}]'.format(values_min, values_max)) 38 | if param.downstream_task in ['survival', 'multitask', 'alltask']: 39 | # Get the range of T 40 | survival_T_min = dataloader.get_survival_T_min() 41 | survival_T_max = dataloader.get_survival_T_max() 42 | if param.survival_T_max == -1: 43 | param.survival_T_max = survival_T_max 44 | print('The range of survival T is [{}, {}]'.format(survival_T_min, survival_T_max)) 45 | 46 | # Model related 47 | model = create_model(param) # Create a model given param.model and other parameters 48 | model.setup(param) # Regular setup for the model: load and print networks, create schedulers 49 | visualizer = Visualizer(param) # Create a visualizer to print results 50 | 51 | # Start the epoch loop 52 | visualizer.print_phase(model.phase) 53 | for epoch in range(param.epoch_count, param.epoch_num + 1): # outer 
loop for different epochs 54 | epoch_start_time = time.time() # Start time of this epoch 55 | model.epoch = epoch 56 | # TRAINING 57 | model.set_train() # Set train mode for training 58 | iter_load_start_time = time.time() # Start time of data loading for this iteration 59 | output_dict, losses_dict, metrics_dict = model.init_log_dict() # Initialize the log dictionaries 60 | if epoch == param.epoch_num_p1 + 1: 61 | model.phase = 'p2' # Change to supervised phase 62 | visualizer.print_phase(model.phase) 63 | if epoch == param.epoch_num_p1 + param.epoch_num_p2 + 1: 64 | model.phase = 'p3' # Change to supervised phase 65 | visualizer.print_phase(model.phase) 66 | if param.save_latent_space and epoch == param.epoch_num: 67 | latent_dict = model.init_latent_dict() 68 | 69 | # Start training loop 70 | for i, data in enumerate(dataloader): # Inner loop for different iteration within one epoch 71 | model.iter = i 72 | dataset_size = len(dataloader) 73 | actual_batch_size = len(data['index']) 74 | iter_start_time = time.time() # Timer for computation per iteration 75 | if i % param.print_freq == 0: 76 | load_time = iter_start_time - iter_load_start_time # Data loading time for this iteration 77 | model.set_input(data) # Unpack input data from the output dictionary of the dataloader 78 | model.update() # Calculate losses, gradients and update network parameters 79 | model.update_log_dict(output_dict, losses_dict, metrics_dict, actual_batch_size) # Update the log dictionaries 80 | if param.save_latent_space and epoch == param.epoch_num: 81 | latent_dict = model.update_latent_dict(latent_dict) # Update the latent space array 82 | if i % param.print_freq == 0: # Print training losses and save logging information to the disk 83 | comp_time = time.time() - iter_start_time # Computational time for this iteration 84 | visualizer.print_train_log(epoch, i, losses_dict, metrics_dict, load_time, comp_time, param.batch_size, dataset_size) 85 | iter_load_start_time = time.time() 86 | 87 
| # Model saving 88 | if param.save_model: 89 | if param.save_epoch_freq == -1: # Only save networks during last epoch 90 | if epoch == param.epoch_num: 91 | print('Saving the model at the end of epoch {:d}'.format(epoch)) 92 | model.save_networks(str(epoch)) 93 | elif epoch % param.save_epoch_freq == 0: # Save both the generator and the discriminator every epochs 94 | print('Saving the model at the end of epoch {:d}'.format(epoch)) 95 | # model.save_networks('latest') 96 | model.save_networks(str(epoch)) 97 | 98 | train_time = time.time() - epoch_start_time 99 | current_lr = model.update_learning_rate() # update learning rates at the end of each epoch 100 | visualizer.print_train_summary(epoch, losses_dict, output_dict, train_time, current_lr) 101 | 102 | if param.save_latent_space and epoch == param.epoch_num: 103 | visualizer.save_latent_space(latent_dict, sample_list) 104 | -------------------------------------------------------------------------------- /train_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training and testing for OmiEmbed 3 | """ 4 | import time 5 | import warnings 6 | from util import util 7 | from params.train_test_params import TrainTestParams 8 | from datasets import create_separate_dataloader 9 | from models import create_model 10 | from util.visualizer import Visualizer 11 | 12 | 13 | if __name__ == "__main__": 14 | warnings.filterwarnings('ignore') 15 | full_start_time = time.time() 16 | # Get parameters 17 | param = TrainTestParams().parse() 18 | if param.deterministic: 19 | util.setup_seed(param.seed) 20 | 21 | # Dataset related 22 | full_dataloader, train_dataloader, val_dataloader, test_dataloader = create_separate_dataloader(param) 23 | print('The size of training set is {}'.format(len(train_dataloader))) 24 | # Get sample list for the dataset 25 | param.sample_list = full_dataloader.get_sample_list() 26 | # Get the dimension of input omics data 27 | param.omics_dims = 
full_dataloader.get_omics_dims() 28 | if param.downstream_task in ['classification', 'multitask', 'alltask']: 29 | # Get the number of classes for the classification task 30 | if param.class_num == 0: 31 | param.class_num = full_dataloader.get_class_num() 32 | if param.downstream_task != 'alltask': 33 | print('The number of classes: {}'.format(param.class_num)) 34 | if param.downstream_task in ['regression', 'multitask', 'alltask']: 35 | # Get the range of the target values 36 | values_min = full_dataloader.get_values_min() 37 | values_max = full_dataloader.get_values_max() 38 | if param.regression_scale == 1: 39 | param.regression_scale = values_max 40 | print('The range of the target values is [{}, {}]'.format(values_min, values_max)) 41 | if param.downstream_task in ['survival', 'multitask', 'alltask']: 42 | # Get the range of T 43 | survival_T_min = full_dataloader.get_survival_T_min() 44 | survival_T_max = full_dataloader.get_survival_T_max() 45 | if param.survival_T_max == -1: 46 | param.survival_T_max = survival_T_max 47 | print('The range of survival T is [{}, {}]'.format(survival_T_min, survival_T_max)) 48 | 49 | # Model related 50 | model = create_model(param) # Create a model given param.model and other parameters 51 | model.setup(param) # Regular setup for the model: load and print networks, create schedulers 52 | visualizer = Visualizer(param) # Create a visualizer to print results 53 | 54 | # Start the epoch loop 55 | visualizer.print_phase(model.phase) 56 | for epoch in range(param.epoch_count, param.epoch_num + 1): # outer loop for different epochs 57 | epoch_start_time = time.time() # Start time of this epoch 58 | model.epoch = epoch 59 | # TRAINING 60 | model.set_train() # Set train mode for training 61 | iter_load_start_time = time.time() # Start time of data loading for this iteration 62 | output_dict, losses_dict, metrics_dict = model.init_log_dict() # Initialize the log dictionaries 63 | if epoch == param.epoch_num_p1 + 1: 64 | model.phase = 
'p2' # Change to supervised phase 65 | visualizer.print_phase(model.phase) 66 | if epoch == param.epoch_num_p1 + param.epoch_num_p2 + 1: 67 | model.phase = 'p3' # Change to supervised phase 68 | visualizer.print_phase(model.phase) 69 | 70 | # Start training loop 71 | for i, data in enumerate(train_dataloader): # Inner loop for different iteration within one epoch 72 | model.iter = i 73 | dataset_size = len(train_dataloader) 74 | actual_batch_size = len(data['index']) 75 | iter_start_time = time.time() # Timer for computation per iteration 76 | if i % param.print_freq == 0: 77 | load_time = iter_start_time - iter_load_start_time # Data loading time for this iteration 78 | model.set_input(data) # Unpack input data from the output dictionary of the dataloader 79 | model.update() # Calculate losses, gradients and update network parameters 80 | model.update_log_dict(output_dict, losses_dict, metrics_dict, actual_batch_size) # Update the log dictionaries 81 | if i % param.print_freq == 0: # Print training losses and save logging information to the disk 82 | comp_time = time.time() - iter_start_time # Computational time for this iteration 83 | visualizer.print_train_log(epoch, i, losses_dict, metrics_dict, load_time, comp_time, param.batch_size, dataset_size) 84 | iter_load_start_time = time.time() 85 | 86 | # Model saving 87 | if param.save_model: 88 | if param.save_epoch_freq == -1: # Only save networks during last epoch 89 | if epoch == param.epoch_num: 90 | print('Saving the model at the end of epoch {:d}'.format(epoch)) 91 | model.save_networks(str(epoch)) 92 | elif epoch % param.save_epoch_freq == 0: # Save both the generator and the discriminator every epochs 93 | print('Saving the model at the end of epoch {:d}'.format(epoch)) 94 | # model.save_networks('latest') 95 | model.save_networks(str(epoch)) 96 | 97 | train_time = time.time() - epoch_start_time 98 | current_lr = model.update_learning_rate() # update learning rates at the end of each epoch 99 | 
visualizer.print_train_summary(epoch, losses_dict, output_dict, train_time, current_lr) 100 | 101 | # TESTING 102 | model.set_eval() # Set eval mode for testing 103 | test_start_time = time.time() # Start time of testing 104 | output_dict, losses_dict, metrics_dict = model.init_log_dict() # Initialize the log dictionaries 105 | 106 | # Start testing loop 107 | for i, data in enumerate(test_dataloader): 108 | dataset_size = len(test_dataloader) 109 | actual_batch_size = len(data['index']) 110 | model.set_input(data) # Unpack input data from the output dictionary of the dataloader 111 | model.test() # Run forward to get the output tensors 112 | model.update_log_dict(output_dict, losses_dict, metrics_dict, actual_batch_size) # Update the log dictionaries 113 | if i % param.print_freq == 0: # Print testing log 114 | visualizer.print_test_log(epoch, i, losses_dict, metrics_dict, param.batch_size, dataset_size) 115 | 116 | test_time = time.time() - test_start_time 117 | visualizer.print_test_summary(epoch, losses_dict, output_dict, test_time) 118 | if epoch == param.epoch_num: 119 | visualizer.save_output_dict(output_dict) 120 | 121 | full_time = time.time() - full_start_time 122 | print('Full running time: {:.3f}s'.format(full_time)) 123 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxiaoyu11/OmiEmbed/0abf210325ea0d23948823bbf673bf5695e3b9b2/util/__init__.py -------------------------------------------------------------------------------- /util/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some metrics 3 | """ 4 | import numpy as np 5 | # from lifelines.utils import concordance_index 6 | # from pysurvival.utils._metrics import _concordance_index 7 | from sksurv.metrics import concordance_index_censored 8 | from sksurv.metrics import 
integrated_brier_score 9 | 10 | 11 | def c_index(true_T, true_E, pred_risk, include_ties=True): 12 | """ 13 | Calculate c-index for survival prediction downstream task 14 | """ 15 | # Ordering true_T, true_E and pred_score in descending order according to true_T 16 | order = np.argsort(-true_T) 17 | 18 | true_T = true_T[order] 19 | true_E = true_E[order] 20 | pred_risk = pred_risk[order] 21 | 22 | # Calculating the c-index 23 | # result = concordance_index(true_T, -pred_risk, true_E) 24 | # result = _concordance_index(pred_risk, true_T, true_E, include_ties)[0] 25 | result = concordance_index_censored(true_E.astype(bool), true_T, pred_risk)[0] 26 | 27 | return result 28 | 29 | 30 | def ibs(true_T, true_E, pred_survival, time_points): 31 | """ 32 | Calculate integrated brier score for survival prediction downstream task 33 | """ 34 | true_E_bool = true_E.astype(bool) 35 | true = np.array([(true_E_bool[i], true_T[i]) for i in range(len(true_E))], dtype=[('event', np.bool_), ('time', np.float32)]) 36 | 37 | # time points must be within the range of T 38 | min_T = true_T.min() 39 | max_T = true_T.max() 40 | valid_index = [] 41 | for i in range(len(time_points)): 42 | if min_T <= time_points[i] <= max_T: 43 | valid_index.append(i) 44 | time_points = time_points[valid_index] 45 | pred_survival = pred_survival[:, valid_index] 46 | 47 | result = integrated_brier_score(true, true, pred_survival, time_points) 48 | 49 | return result 50 | -------------------------------------------------------------------------------- /util/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some omics data preprocess functions 3 | """ 4 | import pandas as pd 5 | 6 | 7 | def separate_B(B_df_single): 8 | """ 9 | Separate the DNA methylation dataframe into subsets according to their targeting chromosomes 10 | 11 | Parameters: 12 | B_df_single(DataFrame) -- a dataframe that contains the single DNA methylation matrix 13 | 14 | 
Return: 15 | B_df_list(list) -- a list with 23 subset dataframe 16 | B_dim(list) -- the dims of each chromosome 17 | """ 18 | anno = pd.read_csv('./anno/B_anno.csv', dtype={'CHR': str}, index_col=0) 19 | anno_contain = anno.loc[B_df_single.index, :] 20 | print('Separating B.tsv according the targeting chromosome...') 21 | B_df_list, B_dim_list = [], [] 22 | ch_id = list(range(1, 23)) 23 | ch_id.append('X') 24 | for ch in ch_id: 25 | ch_index = anno_contain[anno_contain.CHR == str(ch)].index 26 | ch_df = B_df_single.loc[ch_index, :] 27 | B_df_list.append(ch_df) 28 | B_dim_list.append(len(ch_df)) 29 | 30 | return B_df_list, B_dim_list 31 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some simple helper functions 3 | """ 4 | import os 5 | import shutil 6 | import torch 7 | import random 8 | import numpy as np 9 | 10 | 11 | def mkdir(path): 12 | """ 13 | Create a empty directory in the disk if it didn't exist 14 | 15 | Parameters: 16 | path(str) -- a directory path we would like to create 17 | """ 18 | if not os.path.exists(path): 19 | os.makedirs(path) 20 | 21 | 22 | def clear_dir(path): 23 | """ 24 | delete all files in a path 25 | 26 | Parameters: 27 | path(str) -- a directory path that we would like to delete all files in it 28 | """ 29 | if os.path.exists(path): 30 | shutil.rmtree(path, ignore_errors=True) 31 | os.makedirs(path, exist_ok=True) 32 | 33 | 34 | def setup_seed(seed): 35 | """ 36 | setup seed to make the experiments deterministic 37 | 38 | Parameters: 39 | seed(int) -- the random seed 40 | """ 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | torch.backends.cudnn.deterministic = True 46 | 47 | 48 | def get_time_points(T_max, time_num, extra_time_percent=0.1): 49 | """ 50 | Get time points for the MTLR model 51 | """ 52 | # Get time 
points in the time axis 53 | time_points = np.linspace(0, T_max * (1 + extra_time_percent), time_num + 1) 54 | 55 | return time_points 56 | --------------------------------------------------------------------------------