├── model_rna_atac ├── processed_data │ └── transcription_factor │ │ ├── shared_cells.pkl │ │ ├── diff_atac_shared_cells.npz │ │ └── diff_expr_shared_cells.npz ├── README.md ├── configs │ ├── cross_modal_supervision_100.yaml │ ├── cross_modal_supervision_50.yaml │ └── cross_modal_supervision_0.yaml ├── load_model.py ├── train.py ├── networks.py ├── trainer.py └── utils.py ├── LICENSE.txt ├── README.md ├── environment2.yml ├── environment.yml ├── dataloader.py ├── train_ae.py ├── model.py └── train_rna_image.py /model_rna_atac/processed_data/transcription_factor/shared_cells.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uhlerlab/cross-modal-autoencoders/HEAD/model_rna_atac/processed_data/transcription_factor/shared_cells.pkl -------------------------------------------------------------------------------- /model_rna_atac/processed_data/transcription_factor/diff_atac_shared_cells.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uhlerlab/cross-modal-autoencoders/HEAD/model_rna_atac/processed_data/transcription_factor/diff_atac_shared_cells.npz -------------------------------------------------------------------------------- /model_rna_atac/processed_data/transcription_factor/diff_expr_shared_cells.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uhlerlab/cross-modal-autoencoders/HEAD/model_rna_atac/processed_data/transcription_factor/diff_expr_shared_cells.npz -------------------------------------------------------------------------------- /model_rna_atac/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Domain Translation between Single-Cell Imaging and Sequencing Data using Autoencoders 2 | 3 | Code for training RNA-seq - ATAC-seq models. 
4 | 5 | ## Usage instructions 6 | 7 | ```bash 8 | python train.py --config configs/cross_modal_supervision_100.yaml --output_path outputs 9 | ``` 10 | 11 | ## Expected output 12 | 13 | PyTorch model checkpoint files in outputs directory that can be loaded as well as logs. 14 | 15 | An example for loading the model & data is provided in `load_model.py` 16 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Karren Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Domain Translation between Single-Cell Imaging and Sequencing Data using Autoencoders 2 | 3 | This is the accompanying code for the paper, "Multi-Domain Translation between Single-Cell Imaging and Sequencing Data using Autoencoders" ([bioRxiv](https://www.biorxiv.org/content/10.1101/2019.12.13.875922v1.full)) 4 | 5 | The preprocessing scripts for the raw data files can be found at this repo ([link](https://github.com/SaradhaVenkatachalapathy/Radial_chromatin_packing_immune_cells)). 6 | 7 | The preprocessed data files can be downloaded from Dropbox ([link](https://www.dropbox.com/sh/6bt78xzzdju6wew/AABqINyjX7O5uRa0pILXfLDna?dl=0)). 8 | 9 | 10 | ## 1. Installation instructions 11 | 12 | Packages are listed in environment.yml file and can be installed using Anaconda/Miniconda: 13 | 14 | ```bash 15 | conda env create -f environment.yml 16 | conda activate pytorch 17 | ``` 18 | This code was tested on NVIDIA GTX 1080TI GPU. 19 | 20 | ## 2. Usage instructions 21 | 22 | Training the image autoencoder with classifier in latent space: 23 | 24 | ``` 25 | python train_ae.py --save-dir --conditional 26 | ``` 27 | 28 | Integrating the RNA autoencoder with conditional discriminator in latent space: 29 | 30 | ``` 31 | python train_rna_image.py --save-dir --pretrained-file --conditional-adv 32 | ``` 33 | 34 | ## 3. Expected output 35 | 36 | Output is log file and PyTorch checkpoint files when code is run on gene expression and imaging data. 
37 | -------------------------------------------------------------------------------- /environment2.yml: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - blas=1.0 8 | - ca-certificates=2020.7.22 9 | - certifi=2020.6.20 10 | - cudatoolkit=10.1.243 11 | - freetype=2.10.2 12 | - intel-openmp=2020.2 13 | - joblib=0.16.0 14 | - jpeg=9b 15 | - lcms2=2.11 16 | - ld_impl_linux-64=2.33.1 17 | - libedit=3.1.20191231 18 | - libffi=3.3 19 | - libgcc-ng=9.1.0 20 | - libgfortran-ng=7.3.0 21 | - libpng=1.6.37 22 | - libstdcxx-ng=9.1.0 23 | - libtiff=4.1.0 24 | - lz4-c=1.9.2 25 | - mkl=2020.2 26 | - mkl-service=2.3.0 27 | - mkl_fft=1.1.0 28 | - mkl_random=1.1.1 29 | - ncurses=6.2 30 | - ninja=1.10.1 31 | - numpy=1.19.1 32 | - numpy-base=1.19.1 33 | - olefile=0.46 34 | - openssl=1.1.1g 35 | - pandas=1.1.1 36 | - pillow=7.2.0 37 | - pip=20.2.2 38 | - python=3.8.5 39 | - python-dateutil=2.8.1 40 | - pytorch=1.6.0 41 | - pytz=2020.1 42 | - readline=8.0 43 | - scikit-learn=0.23.2 44 | - scipy=1.5.2 45 | - setuptools=49.6.0 46 | - six=1.15.0 47 | - sqlite=3.33.0 48 | - threadpoolctl=2.1.0 49 | - tk=8.6.10 50 | - torchvision=0.7.0 51 | - wheel=0.35.1 52 | - xz=5.2.5 53 | - zlib=1.2.11 54 | - zstd=1.4.5 55 | - pip: 56 | - cycler==0.10.0 57 | - decorator==4.4.2 58 | - imageio==2.9.0 59 | - kiwisolver==1.2.0 60 | - matplotlib==3.3.2 61 | - networkx==2.5 62 | - opencv-python==4.4.0.42 63 | - pyparsing==2.4.7 64 | - pywavelets==1.1.1 65 | - scikit-image==0.17.2 66 | - tifffile==2020.9.3 67 | -------------------------------------------------------------------------------- /model_rna_atac/configs/cross_modal_supervision_100.yaml: -------------------------------------------------------------------------------- 1 | # Model config. 2 | # This code is based on https://github.com/NVlabs/MUNIT. 
3 | # 4 | # logger options 5 | snapshot_save_iter: 250 # How often do you want to save trained models (now in epochs) 6 | log_iter: 100 # How often do you want to log the training stats (now in epochs) 7 | 8 | 9 | # optimization options 10 | max_iter: 4000 # maximum number of training iterations 11 | batch_size: 32 # batch size 12 | log_data : True # take a log1p of the data 13 | normalize_data: True # normalize the data (after the log, if applicable) 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 10 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0 # weight of KL loss for reconstruction 26 | 27 | 28 | supervise: 1 # fraction to supervise 29 | super_w: 0.1 # weight of supervision loss 30 | 31 | # model options 32 | 33 | shared_layer: True 34 | gen: 35 | dim: 100 # hidden layer 36 | latent: 50 # latent layer size 37 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 38 | dis: 39 | dim: 100 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | 44 | # data options 45 | input_dim_a: 815 # input dim of a 46 | input_dim_b: 2613 # input dim of b 47 | -------------------------------------------------------------------------------- /model_rna_atac/configs/cross_modal_supervision_50.yaml: -------------------------------------------------------------------------------- 1 | # Model config. 2 | # This code is based on https://github.com/NVlabs/MUNIT. 
3 | # 4 | # logger options 5 | snapshot_save_iter: 250 # How often do you want to save trained models (now in epochs) 6 | log_iter: 100 # How often do you want to log the training stats (now in epochs) 7 | 8 | 9 | # optimization options 10 | max_iter: 4000 # maximum number of training iterations 11 | batch_size: 32 # batch size 12 | log_data : True # take a log1p of the data 13 | normalize_data: True # normalize the data (after the log, if applicable) 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 10 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0 # weight of KL loss for reconstruction 26 | 27 | 28 | supervise: 0.5 #fraction to supervise 29 | super_w: 0.1 # weight of supervision loss 30 | 31 | # model options 32 | 33 | shared_layer: True 34 | gen: 35 | dim: 100 # hidden layer 36 | latent: 50 # latent layer size 37 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 38 | dis: 39 | dim: 100 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | 44 | # data options 45 | input_dim_a: 815 # input dim of a 46 | input_dim_b: 2613 # input dim of b 47 | -------------------------------------------------------------------------------- /model_rna_atac/configs/cross_modal_supervision_0.yaml: -------------------------------------------------------------------------------- 1 | # Model config. 2 | # This code is based on https://github.com/NVlabs/MUNIT. 
3 | # 4 | # logger options 5 | snapshot_save_iter: 250 # How often do you want to save trained models (now in epochs) 6 | log_iter: 100 # How often do you want to log the training stats (now in epochs) 7 | 8 | 9 | # optimization options 10 | max_iter: 4000 # maximum number of training iterations 11 | batch_size: 32 # batch size 12 | log_data : True # take a log1p of the data 13 | normalize_data: True # normalize the data (after the log, if applicable) 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 10 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0 # weight of KL loss for reconstruction 26 | 27 | 28 | supervise: .01 # fraction to supervise (note: for 0% supervision still specifying != 0 is OK since the weight of supervision loss = 0, this is to avoid errors) 29 | super_w: 0 # weight of supervision loss 30 | 31 | # model options 32 | 33 | shared_layer: True 34 | gen: 35 | dim: 100 # hidden layer 36 | latent: 50 # latent layer size 37 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 38 | dis: 39 | dim: 100 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | 44 | # data options 45 | input_dim_a: 815 # input dim of a 46 | input_dim_b: 2613 # input dim of b 47 | -------------------------------------------------------------------------------- /model_rna_atac/load_model.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import NearestNeighbors 3 | from collections import defaultdict 4 | from sklearn import metrics 5 | import torch 6 | from scipy import sparse 7 | import torch.utils.data as utils 8 | from torch.autograd import Variable 9 | torch.cuda.set_device(1) 10 | import pickle 11 | import pandas as pd 12 | import os 13 | import sys 14 | from utils import load_data_for_latent_space_plot, get_all_data_loaders, get_config, get_model_list 15 | from trainer import Trainer 16 | from sklearn.preprocessing import StandardScaler 17 | 18 | def load_data(isatac=False, data_size=1874, for_training=True, supervise=[], drop=False): 19 | DATA_DIRECTORY = "processed_data/transcription_factor/" 20 | log_data = True 21 | normalize_data = True 22 | 23 | if isatac: 24 | f = DATA_DIRECTORY + "diff_atac_shared_cells.npz" 25 | else: 26 | f = DATA_DIRECTORY + "diff_expr_shared_cells.npz" 27 | 28 | data = sparse.load_npz(f).T.todense() 29 | 30 | if drop: 31 | print("drop") 32 | #threshold = 0.01 if isatac else 0.1 33 | threshold = 0 if isatac else 0.1 34 | acceptable = np.count_nonzero(data, axis=0) > threshold * len(data) 35 | data = data[:, acceptable.flatten().tolist()[0]] 36 | 37 | if log_data: 38 | data = np.log1p(data) 39 | if for_training: 40 | print("Taking log of data..") 41 | 42 | if normalize_data: 43 | scaler = StandardScaler() 44 | training_data = data 45 | scaler.fit(training_data) 46 | data = scaler.transform(data) 47 | 48 | return Variable(torch.from_numpy(data).float()).cuda() 49 | 50 | 51 | def get_unit_model_output(data_a, data_b, name): 52 | 53 | config = get_config("configs/%s.yaml"%name) 54 | trainer = Trainer(config) 55 | last_model_name = get_model_list("outputs/outputs/%s/checkpoints/"%name, "gen") 56 | state_dict = torch.load(last_model_name) 57 | trainer.gen_a.load_state_dict(state_dict['a']) 58 | trainer.gen_b.load_state_dict(state_dict['b']) 
59 | trainer.cuda() 60 | 61 | latent_a = trainer.gen_a.enc(data_a).data.cpu().numpy() 62 | latent_b = trainer.gen_b.enc(data_b).data.cpu().numpy() 63 | 64 | return latent_a, latent_b 65 | 66 | 67 | def main(): 68 | data_a = load_data(isatac=True) 69 | data_b = load_data(isatac=False) 70 | 71 | names = ["cross_modal_supervision_100"] 72 | labels = ['Cross-modal autoencoders 100%'] 73 | 74 | atac_seq_proj = {} 75 | expr_seq_proj = {} 76 | for name, label in zip(names, labels): 77 | latent_a, latent_b = get_unit_model_output(data_a, data_b, name) 78 | atac_seq_proj[label] = latent_a 79 | expr_seq_proj[label] = latent_b 80 | 81 | print(atac_seq_proj) 82 | print(expr_seq_proj) 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /model_rna_atac/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main functionality for starting training. 3 | This code is based on https://github.com/NVlabs/MUNIT. 
4 | """ 5 | import torch 6 | 7 | torch.cuda.set_device(0) 8 | from utils import get_all_data_loaders, prepare_sub_folder, write_loss, get_config, save_plots, load_supervision, \ 9 | write_knn 10 | import argparse 11 | from torch.autograd import Variable 12 | from trainer import Trainer 13 | import torch.backends.cudnn as cudnn 14 | 15 | try: 16 | from itertools import izip as zip 17 | except ImportError: 18 | pass 19 | import os 20 | import sys 21 | import tensorboardX 22 | import shutil 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--config', type=str, help='Path to the config file.') 26 | parser.add_argument('--output_path', type=str, default='.', help="outputs path") 27 | parser.add_argument("--resume", action="store_true") 28 | opts = parser.parse_args() 29 | 30 | cudnn.benchmark = True 31 | 32 | # Load experiment setting 33 | config = get_config(opts.config) 34 | max_iter = config['max_iter'] 35 | 36 | trainer = Trainer(config) 37 | 38 | trainer.cuda() 39 | 40 | train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config) 41 | super_a, super_b = load_supervision(config, supervise=config["supervise"]) 42 | 43 | 44 | # Setup logger and output folders 45 | model_name = os.path.splitext(os.path.basename(opts.config))[0] 46 | train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name)) 47 | output_directory = os.path.join(opts.output_path + "/outputs", model_name) 48 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 49 | shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder 50 | 51 | # Start training 52 | iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0 53 | num_disc = 1 if "num_disc" not in config else config["num_disc"] 54 | num_gen = 1 if "num_gen" not in config else config["num_gen"] 55 | 56 | while True: 57 | for it, (images_a, images_b) in 
enumerate(zip(train_loader_a, train_loader_b)): 58 | trainer.update_learning_rate() 59 | labels_a, labels_b = Variable(images_a[1]).cuda(), Variable(images_b[1]).cuda() 60 | images_a, images_b = Variable(images_a[0]).cuda(), Variable(images_b[0]).cuda() 61 | # Main training code 62 | 63 | for _ in range(num_disc): 64 | trainer.dis_update(images_a, images_b, config) 65 | for _ in range(num_gen): 66 | trainer.gen_update(images_a, images_b, labels_a, labels_b, super_a, super_b, config, variational=False) 67 | torch.cuda.synchronize() 68 | 69 | # Dump training stats in log file 70 | if (iterations + 1) % config['log_iter'] == 0: 71 | # print("Iteration: %08d/%08d" % (iterations + 1, max_iter)) 72 | write_loss(iterations, trainer, train_writer) 73 | write_knn(trainer, image_directory, str(iterations)) 74 | 75 | # Save network weights 76 | if (iterations + 1) % config['snapshot_save_iter'] == 0: 77 | trainer.save(checkpoint_directory, iterations) 78 | save_plots(trainer, image_directory, str(iterations)) 79 | 80 | iterations += 1 81 | if iterations >= max_iter: 82 | sys.exit('Finish training') 83 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - attrs=19.1.0=py37_1 9 | - backcall=0.1.0=py37_0 10 | - blas=1.0=mkl 11 | - bleach=3.1.0=py37_0 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2019.5.15=1 14 | - cairo=1.14.12=h8948797_3 15 | - certifi=2019.6.16=py37_1 16 | - cffi=1.11.5=py37he75722e_1 17 | - cloudpickle=0.5.5=py_0 18 | - cuda92=1.0=0 19 | - cudatoolkit=10.0.130=0 20 | - cycler=0.10.0=py_1 21 | - dask-core=0.19.1=py_0 22 | - dbus=1.13.2=h714fa37_1 23 | - decorator=4.3.0=py_0 24 | - defusedxml=0.5.0=py37_1 25 | - entrypoints=0.3=py37_0 26 | - expat=2.2.5=hfc679d8_2 27 | - ffmpeg=4.0=hcdf2ecd_0 28 | - 
fontconfig=2.13.0=h65d0f4c_6 29 | - freeglut=3.0.0=hf484d3e_5 30 | - freetype=2.9.1=h8a8886c_1 31 | - glib=2.56.2=hd408876_0 32 | - gmp=6.1.2=h6c8ec71_1 33 | - graphite2=1.3.13=h23475e2_0 34 | - gst-plugins-base=1.14.0=hbbd80ab_1 35 | - gstreamer=1.14.0=hb453b48_1 36 | - harfbuzz=1.8.8=hffaf4a1_0 37 | - hdf5=1.10.2=hba1933b_1 38 | - icu=58.2=hfc679d8_0 39 | - imageio=2.3.0=py_1 40 | - intel-openmp=2018.0.3=0 41 | - ipykernel=5.1.0=py37h39e3cac_0 42 | - ipython=7.4.0=py37h39e3cac_0 43 | - ipython_genutils=0.2.0=py37_0 44 | - jasper=2.0.14=h07fcdf6_1 45 | - jedi=0.13.3=py37_0 46 | - jinja2=2.10.1=py37_0 47 | - jpeg=9b=h024ee3a_2 48 | - jsonschema=3.0.1=py37_0 49 | - jupyter_client=5.2.4=py37_0 50 | - jupyter_core=4.4.0=py37_0 51 | - kiwisolver=1.0.1=py37h2d50403_2 52 | - libedit=3.1.20170329=h6b74fdf_2 53 | - libffi=3.2.1=hd88cf55_4 54 | - libgcc-ng=8.2.0=hdf63c60_1 55 | - libgfortran-ng=7.3.0=hdf63c60_0 56 | - libglu=9.0.0=hf484d3e_1 57 | - libiconv=1.15=h470a237_3 58 | - libopencv=3.4.2=hb342d67_1 59 | - libopus=1.3=h7b6447c_0 60 | - libpng=1.6.36=hbc83047_0 61 | - libsodium=1.0.16=h1bed415_0 62 | - libstdcxx-ng=8.2.0=hdf63c60_1 63 | - libtiff=4.0.9=he85c1e1_2 64 | - libuuid=2.32.1=h470a237_2 65 | - libvpx=1.7.0=h439df22_0 66 | - libxcb=1.13=h470a237_2 67 | - libxml2=2.9.8=h422b904_4 68 | - markupsafe=1.1.1=py37h7b6447c_0 69 | - matplotlib=2.2.3=py37h8e2386c_0 70 | - mistune=0.8.4=py37h7b6447c_0 71 | - mkl=2018.0.3=1 72 | - mkl_fft=1.0.4=py37h4414c95_1 73 | - mkl_random=1.0.1=py37h4414c95_1 74 | - nb_conda=2.2.1=py37_0 75 | - nb_conda_kernels=2.2.0=py37_1 76 | - nbconvert=5.4.1=py37_3 77 | - nbformat=4.4.0=py37_0 78 | - ncurses=6.1=hf484d3e_0 79 | - networkx=2.1=py_1 80 | - ninja=1.8.2=py37h6bb024c_1 81 | - notebook=5.7.8=py37_0 82 | - numpy=1.15.1=py37h3b04361_0 83 | - numpy-base=1.15.1=py37h81de0dd_0 84 | - olefile=0.45.1=py37_0 85 | - opencv=3.4.2=py37h6fd60c2_1 86 | - openssl=1.1.1d=h7b6447c_1 87 | - pandas=0.25.1=py37he6710b0_0 88 | - pandoc=2.2.3.2=0 89 | - 
pandocfilters=1.4.2=py37_1 90 | - parso=0.4.0=py_0 91 | - pcre=8.42=h439df22_0 92 | - pexpect=4.7.0=py37_0 93 | - pickleshare=0.7.5=py37_0 94 | - pillow=5.2.0=py37heded4f4_0 95 | - pip=10.0.1=py37_0 96 | - pixman=0.38.0=h7b6447c_0 97 | - prometheus_client=0.6.0=py37_0 98 | - prompt_toolkit=2.0.9=py37_0 99 | - pthread-stubs=0.4=h470a237_1 100 | - ptyprocess=0.6.0=py37_0 101 | - py-opencv=3.4.2=py37hb342d67_1 102 | - pycparser=2.18=py37_1 103 | - pygments=2.3.1=py37_0 104 | - pyparsing=2.2.0=py_1 105 | - pyqt=5.6.0=py37h8210e8a_7 106 | - pyrsistent=0.14.11=py37h7b6447c_0 107 | - python=3.7.2=h0371630_0 108 | - python-dateutil=2.7.3=py_0 109 | - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0 110 | - pytz=2018.5=py_0 111 | - pywavelets=1.0.0=py37hdd07704_0 112 | - qt=5.6.3=h8bf5577_3 113 | - readline=7.0=h7b6447c_5 114 | - scikit-image=0.14.0=py37hf484d3e_1 115 | - scikit-learn=0.19.1=py37hedc7406_0 116 | - send2trash=1.5.0=py37_0 117 | - setuptools=40.2.0=py37_0 118 | - sip=4.18.1=py37hfc679d8_0 119 | - six=1.11.0=py37_1 120 | - sqlite=3.26.0=h7b6447c_0 121 | - terminado=0.8.2=py37_0 122 | - testpath=0.4.2=py37_0 123 | - tk=8.6.8=hbc83047_0 124 | - toolz=0.9.0=py_0 125 | - torchvision=0.4.0=py37_cu100 126 | - traitlets=4.3.2=py37_0 127 | - wcwidth=0.1.7=py37_0 128 | - webencodings=0.5.1=py37_1 129 | - wheel=0.31.1=py37_0 130 | - xorg-libxau=1.0.8=h470a237_6 131 | - xorg-libxdmcp=1.1.2=h470a237_7 132 | - xz=5.2.4=h14c3975_4 133 | - zeromq=4.3.1=he6710b0_3 134 | - zlib=1.2.11=ha838bed_2 135 | - pip: 136 | - chardet==3.0.4 137 | - idna==2.7 138 | - pyzmq==17.1.2 139 | - requests==2.19.1 140 | - scipy==1.1.0 141 | - torchfile==0.1.0 142 | - tornado==5.1 143 | - urllib3==1.23 144 | - visdom==0.1.8.5 145 | - websocket-client==0.53.0 146 | 147 | -------------------------------------------------------------------------------- /model_rna_atac/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines models. 
3 | This code is based on https://github.com/NVlabs/MUNIT. 4 | """ 5 | from torch import nn 6 | from torch.autograd import Variable 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | try: 11 | from itertools import izip as zip 12 | except ImportError: # will be 3.x series 13 | pass 14 | 15 | 16 | ################################################################################## 17 | # Discriminator 18 | ################################################################################## 19 | 20 | class Discriminator(nn.Module): 21 | 22 | def __init__(self, input_dim, params): 23 | super(Discriminator, self).__init__() 24 | self.gan_type = params['gan_type'] 25 | self.dim = params['dim'] 26 | self.norm = params['norm'] 27 | self.input_dim = input_dim 28 | self.net = self._make_net() 29 | 30 | def _make_net(self): 31 | return nn.Sequential( 32 | nn.Linear(self.input_dim, self.input_dim), 33 | nn.LeakyReLU(0.2, inplace=True), 34 | nn.Linear(self.input_dim, self.dim), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | nn.Linear(self.dim, 1), 37 | nn.LeakyReLU(0.2, inplace=True), 38 | ) 39 | 40 | def forward(self, x): 41 | return self.net(x) 42 | 43 | def calc_dis_loss(self, input_fake, input_real): 44 | # calculate the loss to train D 45 | outs0 = [self.forward(input_fake)] 46 | outs1 = [self.forward(input_real)] 47 | loss = 0 48 | 49 | for it, (out0, out1) in enumerate(zip(outs0, outs1)): 50 | loss += torch.mean((out0 - 0) ** 2) + torch.mean((out1 - 1) ** 2) 51 | return loss 52 | 53 | def calc_gen_loss(self, input_fake): 54 | # calculate the loss to train G 55 | outs0 = [self.forward(input_fake)] 56 | loss = 0 57 | for it, (out0) in enumerate(outs0): 58 | # 1 = real data 59 | loss += torch.mean((out0 - 1) ** 2) 60 | return loss 61 | 62 | def calc_gen_loss_reverse(self, input_real): 63 | # calculate the loss to train G 64 | outs0 = [self.forward(input_real)] 65 | loss = 0 66 | for it, (out0) in enumerate(outs0): 67 | # 0 = fake data 68 | loss += torch.mean((out0 - 0) 
** 2) 69 | return loss 70 | 71 | def calc_gen_loss_half(self, input_fake): 72 | # calculate the loss to train G 73 | outs0 = [self.forward(input_fake)] 74 | loss = 0 75 | for it, (out0) in enumerate(outs0): 76 | if self.gan_type == 'lsgan': 77 | loss += torch.mean((out0 - 0.5) ** 2) 78 | elif self.gan_type == 'nsgan': 79 | all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False) 80 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) 81 | else: 82 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 83 | return loss 84 | 85 | 86 | ################################################################################## 87 | # Generator 88 | ################################################################################## 89 | 90 | class VAEGen_MORE_LAYERS(nn.Module): 91 | # VAE architecture 92 | def __init__(self, input_dim, params, shared_layer=False): 93 | super(VAEGen_MORE_LAYERS, self).__init__() 94 | self.dim = params['dim'] 95 | self.latent = params['latent'] 96 | self.input_dim = input_dim 97 | 98 | encoder_layers = [nn.Linear(self.input_dim, self.input_dim), 99 | nn.LeakyReLU(0.2, inplace=True), 100 | nn.Linear(self.input_dim, self.input_dim), 101 | nn.LeakyReLU(0.2, inplace=True), 102 | nn.Linear(self.input_dim, self.input_dim), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | nn.Linear(self.input_dim, self.dim), 105 | nn.LeakyReLU(0.2, inplace=True)] 106 | 107 | decoder_layers = [nn.LeakyReLU(0.2, inplace=True), 108 | nn.Linear(self.dim, self.input_dim), 109 | nn.LeakyReLU(0.2, inplace=True), 110 | nn.Linear(self.input_dim, self.input_dim), 111 | nn.LeakyReLU(0.2, inplace=True), 112 | nn.Linear(self.input_dim, self.input_dim), 113 | nn.LeakyReLU(0.2, inplace=True), 114 | nn.Linear(self.input_dim, self.input_dim), 115 | nn.LeakyReLU(0.2, inplace=True)] 116 | 117 | if shared_layer: 118 | encoder_layers += [shared_layer["enc"], nn.LeakyReLU(0.2, inplace=True)] 119 | decoder_layers = [shared_layer["dec"]] + decoder_layers 120 | 
else: 121 | encoder_layers += [nn.Linear(self.dim, self.latent), nn.LeakyReLU(0.2, inplace=True)] 122 | decoder_layers = [nn.Linear(self.latent, self.dim)] + decoder_layers 123 | self.enc = nn.Sequential(*encoder_layers) 124 | self.dec = nn.Sequential(*decoder_layers) 125 | 126 | def forward(self, images): 127 | # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. 128 | hiddens = self.encode(images) 129 | if self.training == True: 130 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 131 | images_recon = self.decode(hiddens + noise) 132 | else: 133 | images_recon = self.decode(hiddens) 134 | return images_recon, hiddens 135 | 136 | def encode(self, images): 137 | hiddens = self.enc(images) 138 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 139 | return hiddens, noise 140 | 141 | def decode(self, hiddens): 142 | images = self.dec(hiddens) 143 | return images 144 | 145 | ################################################################################## 146 | # Classifier 147 | ################################################################################## 148 | 149 | class Classifier(nn.Module): 150 | def __init__(self, input_dim): 151 | super(Classifier, self).__init__() 152 | self.input_dim = input_dim 153 | self.net = self._make_net() 154 | 155 | self.cel = nn.CrossEntropyLoss() 156 | 157 | def _make_net(self): 158 | return nn.Sequential( 159 | nn.Linear(self.input_dim, 3) 160 | ) 161 | 162 | def forward(self, x): 163 | return self.net(x) 164 | 165 | def class_loss(self, input, target): 166 | return self.cel(input, target) 167 | 168 | 169 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | import numpy as np 5 | 
import pandas as pd 6 | from sklearn.model_selection import train_test_split 7 | import cv2 8 | from PIL import Image 9 | from skimage import io 10 | 11 | import os 12 | 13 | class ToTensorNormalize(object): 14 | """Convert ndarrays in sample to Tensors.""" 15 | 16 | def __call__(self, sample): 17 | image_tensor = sample['image_tensor'] 18 | 19 | # rescale by maximum and minimum of the image tensor 20 | minX = image_tensor.min() 21 | maxX = image_tensor.max() 22 | image_tensor=(image_tensor-minX)/(maxX-minX) 23 | 24 | # resize the inputs 25 | # torch image tensor expected for 3D operations is (N, C, D, H, W) 26 | image_tensor = image_tensor.max(axis=0) 27 | image_tensor = cv2.resize(image_tensor, dsize=(64, 64), interpolation=cv2.INTER_CUBIC) 28 | image_tensor = np.clip(image_tensor, 0, 1) 29 | return torch.from_numpy(image_tensor).view(1, 64, 64) 30 | 31 | class NucleiDatasetNew(Dataset): 32 | def __init__(self, datadir, mode='train', transform=ToTensorNormalize()): 33 | self.datadir = datadir 34 | self.mode = mode 35 | self.images = self.load_images() 36 | self.transform = transform 37 | self.threshold = 0.74 38 | 39 | # Utility function to load images from a HDF5 file 40 | def load_images(self): 41 | # load labels 42 | label_data = pd.read_csv(os.path.join(self.datadir, "ratio.csv")) 43 | label_data_2 = pd.read_csv(os.path.join(self.datadir, "protein_ratios_full.csv")) 44 | label_data = label_data.merge(label_data_2, how='inner', on='Label') 45 | label_dict = {name: (float(ratio), np.abs(int(cl)-2)) for (name, ratio, cl) in zip(list(label_data['Label']), list(label_data['Cor/RPL']), list(label_data['mycl']))} 46 | label_dict_2 = {name: np.abs(int(cl)-2) for (name, cl) in zip(list(label_data_2['Label']), list(label_data_2['mycl']))} 47 | del label_data 48 | del label_data_2 49 | 50 | # load images 51 | images_train = [] 52 | images_test = [] 53 | 54 | for f in os.listdir(os.path.join(self.datadir, "images")): 55 | basename = os.path.splitext(f)[0] 56 | fname = 
os.path.join(os.path.join(self.datadir, "images"), f) 57 | if basename in label_dict.keys(): 58 | images_test.append({'name': basename, 'label': label_dict[basename][0], 'image_tensor': np.float32(io.imread(fname)), 'binary_label': label_dict[basename][1]}) 59 | else: 60 | try: 61 | images_train.append({'name': basename, 'label': -1, 'image_tensor': np.float32(io.imread(fname)), 'binary_label': label_dict_2[basename]}) 62 | except Exception as e: 63 | pass 64 | 65 | if self.mode == 'train': 66 | return images_train 67 | elif self.mode == 'test': 68 | return images_test 69 | else: 70 | raise KeyError("Mode %s is invalid, must be 'train' or 'test'" % self.mode) 71 | 72 | def __len__(self): 73 | return len(self.images) 74 | 75 | def __getitem__(self, idx): 76 | sample = self.images[idx] 77 | 78 | if self.transform: 79 | # transform the tensor and the particular z-slice 80 | image_tensor = self.transform(sample) 81 | return {'image_tensor': image_tensor, 'name': sample['name'], 'label': sample['label'], 'binary_label': sample['binary_label']} 82 | return sample 83 | 84 | class ATAC_Dataset(Dataset): 85 | def __init__(self, datadir): 86 | self.datadir = datadir 87 | self.atac_data, self.labels = self._load_atac_data() 88 | 89 | def __len__(self): 90 | return len(self.atac_data) 91 | 92 | def __getitem__(self, idx): 93 | atac_sample = self.atac_data[idx] 94 | cluster = self.labels[idx] 95 | return {'tensor': torch.from_numpy(atac_sample).float(), 'binary_label': int(cluster)} 96 | 97 | def _load_atac_data(self): 98 | data = pd.read_csv(os.path.join(self.datadir, "df_peak_counts_names_nCD4_seuratnorm.csv"), index_col=0) 99 | data = data.transpose() 100 | labels = pd.read_csv(os.path.join(self.datadir, "clustlabels_peak_counts_names_nCD4_seurat_n_2.csv"), index_col=0) 101 | 102 | data = labels.merge(data, left_index=True, right_index=True) 103 | data = data.values 104 | 105 | return data[:,1:], data[:,0] 106 | 107 | 108 | class RNA_Dataset(Dataset): 109 | def 
__init__(self, datadir): 110 | self.datadir = datadir 111 | self.rna_data, self.labels = self._load_rna_data() 112 | 113 | def __len__(self): 114 | return len(self.rna_data) 115 | 116 | def __getitem__(self, idx): 117 | rna_sample = self.rna_data[idx] 118 | cluster = self.labels[idx] 119 | coro1a = rna_sample[5849] 120 | rpl10a = rna_sample[2555] 121 | return {'tensor': torch.from_numpy(rna_sample).float(), 'coro1a': coro1a, 'rpl10a': rpl10a, 'label': coro1a/rpl10a, 'binary_label': int(cluster)} 122 | 123 | def _load_rna_data(self): 124 | data = pd.read_csv(os.path.join(self.datadir, "filtered_lognuminorm_pc_rp_7633genes_1396cellsnCD4.csv"), index_col=0) 125 | data = data.transpose() 126 | labels = pd.read_csv(os.path.join(self.datadir, "labels_nCD4_corrected.csv"), index_col=0) 127 | 128 | data = labels.merge(data, left_index=True, right_index=True) 129 | data = data.values 130 | 131 | return data[:,1:], np.abs(data[:,0]-1) 132 | 133 | 134 | def print_nuclei_names(): 135 | dataset = NucleiDatasetNew(datadir="data/nuclear_crops_all_experiments", mode='test') 136 | for sample in dataset: 137 | print(sample['name']) 138 | 139 | def test_nuclei_dataset(): 140 | dataset = NucleiDatasetNew(datadir="data/nuclear_crops_all_experiments", mode='train') 141 | print(len(dataset)) 142 | sample = dataset[0] 143 | print(sample['image_tensor'].shape) 144 | print(sample['binary_label']) 145 | 146 | labels = 0 147 | for sample in dataset: 148 | labels += sample['binary_label'] 149 | print(labels) 150 | 151 | def test_atac_loader(): 152 | dataset = ATAC_Dataset(datadir="data/atac_seq_data") 153 | print(len(dataset)) 154 | sample = dataset[0] 155 | print(torch.max(sample['tensor'])) 156 | print(sample['tensor'].shape) 157 | for k in sample.keys(): 158 | print(k) 159 | print(sample[k]) 160 | 161 | def test_rna_loader(): 162 | dataset = RNA_Dataset(datadir="data/nCD4_gene_exp_matrices") 163 | print(len(dataset)) 164 | sample = dataset[0] 165 | print(torch.max(sample['tensor'])) 166 | 
print(sample['tensor'].shape) 167 | for k in sample.keys(): 168 | print(k) 169 | print(sample[k]) 170 | 171 | if __name__ == "__main__": 172 | test_nuclei_dataset() 173 | -------------------------------------------------------------------------------- /train_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from torch.autograd import Variable 4 | from torch.utils.data import DataLoader 5 | 6 | from dataloader import NucleiDatasetNew as NucleiDataset 7 | import model as AENet 8 | 9 | import argparse 10 | import numpy as np 11 | import sys 12 | import os 13 | import imageio 14 | 15 | # adapted from pytorch/examples/vae and ethanluoyc/pytorch-vae 16 | 17 | # parse arguments 18 | def setup_args(): 19 | 20 | options = argparse.ArgumentParser() 21 | 22 | options.add_argument('--save-dir', action="store", dest="save_dir") 23 | options.add_argument('-pt', action="store", dest="pretrained_file", default=None) 24 | options.add_argument('-bs', action="store", dest="batch_size", default = 128, type = int) 25 | options.add_argument('-ds', action="store", dest="datadir", default = "data/nuclear_crops_all_experiments/") 26 | 27 | options.add_argument('-iter', action="store", dest="max_iter", default = 800, type = int) 28 | options.add_argument('-lr', action="store", dest="lr", default=1e-3, type = float) 29 | options.add_argument('-nz', action="store", dest="nz", default=128, type = int) 30 | options.add_argument('-lamb', action="store", dest="lamb", default=0.0000001, type = float) 31 | options.add_argument('-lamb2', action="store", dest="lamb2", default=0.001, type = float) 32 | options.add_argument('--conditional', action="store_true") 33 | 34 | return options.parse_args() 35 | 36 | args = setup_args() 37 | os.makedirs(args.save_dir, exist_ok=True) 38 | with open(os.path.join(args.save_dir, "log.txt"), 'w') as f: 39 | print(args, file=f) 40 | 41 | # retrieve dataloader 42 | trainset = 
NucleiDataset(datadir=args.datadir, mode='train') 43 | testset = NucleiDataset(datadir=args.datadir, mode='test') 44 | 45 | train_loader = DataLoader(trainset, batch_size=args.batch_size, drop_last=False, shuffle=True) 46 | test_loader = DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False) 47 | 48 | print('Data loaded') 49 | 50 | model = AENet.VAE(latent_variable_size=args.nz, batchnorm=True) 51 | if args.conditional: 52 | netCondClf = AENet.Simple_Classifier(nz=args.nz) 53 | 54 | if args.pretrained_file is not None: 55 | model.load_state_dict(torch.load(args.pretrained_file)) 56 | print("Pre-trained model loaded") 57 | sys.stdout.flush() 58 | 59 | CE_weights = torch.FloatTensor([4.5, 0.5]) 60 | 61 | if torch.cuda.is_available(): 62 | print('Using GPU') 63 | model.cuda() 64 | CE_weights = CE_weights.cuda() 65 | if args.conditional: 66 | netCondClf.cuda() 67 | 68 | CE = nn.CrossEntropyLoss(CE_weights) 69 | 70 | if args.conditional: 71 | optimizer = optim.Adam(list(model.parameters())+list(netCondClf.parameters()), lr = args.lr) 72 | else: 73 | optimizer = optim.Adam([{'params': model.parameters()}], lr = args.lr) 74 | 75 | def loss_function(recon_x, x, mu, logvar, latents): 76 | MSE = nn.MSELoss() 77 | lloss = MSE(recon_x,x) 78 | 79 | if args.lamb>0: 80 | KL_loss = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 81 | lloss = lloss + args.lamb*KL_loss 82 | 83 | return lloss 84 | 85 | def train(epoch): 86 | model.train() 87 | if args.conditional: 88 | netCondClf.train() 89 | 90 | train_loss = 0 91 | total_clf_loss = 0 92 | 93 | for batch_idx, samples in enumerate(train_loader): 94 | 95 | inputs = Variable(samples['image_tensor']) 96 | if torch.cuda.is_available(): 97 | inputs = inputs.cuda() 98 | 99 | optimizer.zero_grad() 100 | recon_inputs, latents, mu, logvar = model(inputs) 101 | loss = loss_function(recon_inputs, inputs, mu, logvar, latents) 102 | train_loss += loss.data.item() * inputs.size(0) 103 | 104 | if args.conditional: 105 
| targets = Variable(samples['binary_label']) 106 | if torch.cuda.is_available(): 107 | targets = targets.cuda() 108 | clf_outputs = netCondClf(latents) 109 | class_clf_loss = CE(clf_outputs, targets.view(-1).long()) 110 | loss += args.lamb2 * class_clf_loss 111 | total_clf_loss += class_clf_loss.data.item() * inputs.size(0) 112 | 113 | loss.backward() 114 | optimizer.step() 115 | 116 | with open(os.path.join(args.save_dir, "log.txt"), 'a') as f: 117 | print('Epoch: {} Average loss: {:.15f} Clf loss: {:.15f} '.format(epoch, train_loss / len(train_loader.dataset), total_clf_loss / len(train_loader.dataset)), file=f) 118 | 119 | def test(epoch): 120 | model.eval() 121 | if args.conditional: 122 | netCondClf.eval() 123 | 124 | test_loss = 0 125 | total_clf_loss = 0 126 | 127 | for i, samples in enumerate(test_loader): 128 | 129 | inputs = Variable(samples['image_tensor']) 130 | if torch.cuda.is_available(): 131 | inputs = inputs.cuda() 132 | 133 | recon_inputs, latents, mu, logvar = model(inputs) 134 | 135 | loss = loss_function(recon_inputs, inputs, mu, logvar, latents) 136 | test_loss += loss.data.item() * inputs.size(0) 137 | 138 | if args.conditional: 139 | targets = Variable(samples['binary_label']) 140 | if torch.cuda.is_available(): 141 | targets = targets.cuda() 142 | clf_outputs = netCondClf(latents) 143 | class_clf_loss = CE(clf_outputs, targets.view(-1).long()) 144 | total_clf_loss += class_clf_loss.data.item() * inputs.size(0) 145 | 146 | test_loss /= len(test_loader.dataset) 147 | total_clf_loss /= len(test_loader.dataset) 148 | 149 | with open(os.path.join(args.save_dir, "log.txt"), 'a') as f: 150 | print('Test set loss: {:.15f} Test clf loss: {:.15f}'.format(test_loss, total_clf_loss), file=f) 151 | 152 | return test_loss 153 | 154 | 155 | def save(epoch): 156 | model_dir = os.path.join(args.save_dir, "models") 157 | os.makedirs(model_dir, exist_ok=True) 158 | torch.save(model.cpu().state_dict(), os.path.join(model_dir, str(epoch)+".pth")) 159 | if 
torch.cuda.is_available(): 160 | model.cuda() 161 | 162 | def generate_image(epoch): 163 | img_dir = os.path.join(args.save_dir, "images") 164 | os.makedirs(img_dir, exist_ok=True) 165 | model.eval() 166 | 167 | for i in range(5): 168 | samples = train_loader.dataset[np.random.randint(30)] 169 | inputs = samples['image_tensor'] 170 | inputs = Variable(inputs.view(1,1,64,64)) 171 | 172 | if torch.cuda.is_available(): 173 | inputs = inputs.cuda() 174 | 175 | recon_inputs, _, _, _ = model(inputs) 176 | 177 | imageio.imwrite(os.path.join(img_dir, "Train_epoch_%s_inputs_%s.jpg" % (epoch, i)), np.uint8(inputs.cpu().data.view(64,64).numpy()*255)) 178 | imageio.imwrite(os.path.join(img_dir, "Train_epoch_%s_recon_%s.jpg" % (epoch, i)), np.uint8(recon_inputs.cpu().data.view(64,64).numpy()*255)) 179 | 180 | samples = test_loader.dataset[np.random.randint(30)] 181 | inputs = samples['image_tensor'] 182 | inputs = Variable(inputs.view(1,1,64,64)) 183 | 184 | if torch.cuda.is_available(): 185 | inputs = inputs.cuda() 186 | 187 | recon_inputs, _, _, _ = model(inputs) 188 | 189 | imageio.imwrite(os.path.join(img_dir, "Test_epoch_%s_inputs_%s.jpg" % (epoch, i)), np.uint8(inputs.cpu().data.view(64,64).numpy()*255)) 190 | imageio.imwrite(os.path.join(img_dir, "Test_epoch_%s_recon_%s.jpg" % (epoch, i)), np.uint8(recon_inputs.cpu().data.view(64,64).numpy()*255)) 191 | 192 | # main training loop 193 | generate_image(0) 194 | save(0) 195 | 196 | _ = test(0) 197 | 198 | for epoch in range(args.max_iter): 199 | print(epoch) 200 | train(epoch) 201 | _ = test(epoch) 202 | 203 | if epoch % 10 == 1: 204 | generate_image(epoch) 205 | save(epoch) 206 | -------------------------------------------------------------------------------- /model_rna_atac/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer class for model training. 3 | This code is based on https://github.com/NVlabs/MUNIT. 
4 | """ 5 | from networks import Discriminator, Classifier 6 | from networks import VAEGen_MORE_LAYERS as VAEGen 7 | from utils import weights_init, get_model_list, get_scheduler 8 | from torch.autograd import Variable 9 | import torch 10 | import torch.nn as nn 11 | import os 12 | 13 | 14 | class Trainer(nn.Module): 15 | def __init__(self, hyperparameters): 16 | super(Trainer, self).__init__() 17 | lr = hyperparameters['lr'] 18 | # Initiate the networks 19 | shared_layer = False 20 | if "shared_layer" in hyperparameters and hyperparameters["shared_layer"]: 21 | shared_layer = {} 22 | shared_layer["dec"] = nn.Linear(hyperparameters['gen']['latent'], hyperparameters['gen']['dim']) 23 | shared_layer["enc"] = nn.Linear(hyperparameters['gen']['dim'], hyperparameters['gen']['latent']) 24 | 25 | self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'], 26 | shared_layer) # auto-encoder for domain a 27 | self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'], 28 | shared_layer) # auto-encoder for domain b 29 | self.dis_latent = Discriminator(hyperparameters['gen']['latent'], 30 | hyperparameters['dis']) # discriminator for latent space 31 | 32 | self.classifier = Classifier(hyperparameters['gen']['latent']) # classifier on the latent space 33 | 34 | # Setup the optimizers 35 | beta1 = hyperparameters['beta1'] 36 | beta2 = hyperparameters['beta2'] 37 | dis_params = list(self.dis_latent.parameters()) 38 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) + list(self.classifier.parameters()) 39 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], 40 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 41 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], 42 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 43 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) 44 | self.gen_scheduler = 
get_scheduler(self.gen_opt, hyperparameters) 45 | 46 | # Network weight initialization 47 | self.apply(weights_init(hyperparameters['init'])) 48 | self.dis_latent.apply(weights_init('gaussian')) 49 | 50 | def recon_criterion(self, input, target): 51 | return torch.mean(torch.abs(input - target)) 52 | 53 | def super_criterion(self, input, target): 54 | return torch.mean(torch.abs(input - target)) 55 | 56 | def forward(self, x_a, x_b): 57 | self.eval() 58 | h_a, _ = self.gen_a.encode(x_a) 59 | h_b, _ = self.gen_b.encode(x_b) 60 | x_ba = self.gen_a.decode(h_b) 61 | x_ab = self.gen_b.decode(h_a) 62 | self.train() 63 | return x_ab, x_ba 64 | 65 | def gen_update(self, x_a, x_b, a_labels, b_labels, super_a, super_b, hyperparameters, variational=True): 66 | true_samples = Variable( 67 | torch.randn(200, hyperparameters['gen']['latent']), 68 | requires_grad=False 69 | ).cuda() 70 | 71 | self.gen_opt.zero_grad() 72 | # encode 73 | h_a, n_a = self.gen_a.encode(x_a) 74 | h_b, n_b = self.gen_b.encode(x_b) 75 | # decode (within domain) 76 | if variational: 77 | h_a = h_a + n_a 78 | h_b = h_b + n_b 79 | 80 | x_a_recon = self.gen_a.decode(h_a) 81 | x_b_recon = self.gen_b.decode(h_b) 82 | 83 | # decode (cross domain) 84 | x_ba = self.gen_a.decode(h_b) 85 | x_ab = self.gen_b.decode(h_a) 86 | # encode again 87 | h_b_recon, n_b_recon = self.gen_a.encode(x_ba) 88 | h_a_recon, n_a_recon = self.gen_b.encode(x_ab) 89 | # decode again (if needed) 90 | if variational: 91 | h_a_recon = h_a_recon + n_a_recon 92 | h_b_recon = h_b_recon + n_b_recon 93 | 94 | classes_a = self.classifier.forward(h_a) 95 | classes_b = self.classifier.forward(h_b) 96 | 97 | # reconstruction loss 98 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) 99 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) 100 | 101 | # GAN loss 102 | self.loss_latent_a = self.dis_latent.calc_gen_loss(h_a) 103 | self.loss_latent_b = self.dis_latent.calc_gen_loss_reverse(h_b) 104 | 105 | # Classification Loss 
106 | self.loss_class_a = self.classifier.class_loss(classes_a, a_labels) 107 | self.loss_class_b = self.classifier.class_loss(classes_b, b_labels) 108 | 109 | # supervision 110 | s_a, n_a = self.gen_a.encode(super_a) 111 | s_b, n_b = self.gen_b.encode(super_b) 112 | 113 | self.loss_supervision = self.super_criterion(s_a, s_b) 114 | 115 | class_weight = hyperparameters['gan_w'] if "class_w" not in hyperparameters else hyperparameters["class_w"] 116 | 117 | # total loss 118 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_latent_a + \ 119 | hyperparameters['gan_w'] * self.loss_latent_b + \ 120 | class_weight * self.loss_class_a + \ 121 | class_weight * self.loss_class_b + \ 122 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ 123 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ 124 | hyperparameters['super_w'] * self.loss_supervision 125 | 126 | if variational: 127 | self.loss_gen_total += hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \ 128 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b 129 | 130 | self.loss_gen_total.backward() 131 | self.gen_opt.step() 132 | 133 | def sample(self, x_a, x_b): 134 | self.eval() 135 | x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], [] 136 | for i in range(x_a.size(0)): 137 | h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0)) 138 | h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0)) 139 | x_a_recon.append(self.gen_a.decode(h_a)) 140 | x_b_recon.append(self.gen_b.decode(h_b)) 141 | x_ba.append(self.gen_a.decode(h_b)) 142 | x_ab.append(self.gen_b.decode(h_a)) 143 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) 144 | x_ba = torch.cat(x_ba) 145 | x_ab = torch.cat(x_ab) 146 | self.train() 147 | return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba 148 | 149 | def dis_update(self, x_a, x_b, hyperparameters): 150 | self.dis_opt.zero_grad() 151 | # encode 152 | h_a, n_a = self.gen_a.encode(x_a) 153 | h_b, n_b = self.gen_b.encode(x_b) 154 | # D loss 155 | self.loss_dis_latent 
= self.dis_latent.calc_dis_loss(h_a, h_b) 156 | self.loss_dis_total = hyperparameters['gan_w'] * (self.loss_dis_latent) 157 | self.loss_dis_total.backward() 158 | self.dis_opt.step() 159 | 160 | def update_learning_rate(self): 161 | if self.dis_scheduler is not None: 162 | self.dis_scheduler.step() 163 | if self.gen_scheduler is not None: 164 | self.gen_scheduler.step() 165 | 166 | def resume(self, checkpoint_dir, hyperparameters): 167 | # Load generators 168 | last_model_name = get_model_list(checkpoint_dir, "gen") 169 | state_dict = torch.load(last_model_name) 170 | self.gen_a.load_state_dict(state_dict['a']) 171 | self.gen_b.load_state_dict(state_dict['b']) 172 | iterations = int(last_model_name[-11:-3]) 173 | # Load discriminators 174 | last_model_name = get_model_list(checkpoint_dir, "dis") 175 | state_dict = torch.load(last_model_name) 176 | self.dis_latent.load_state_dict(state_dict['latent']) 177 | # Load optimizers 178 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) 179 | self.dis_opt.load_state_dict(state_dict['dis']) 180 | self.gen_opt.load_state_dict(state_dict['gen']) 181 | # Reinitilize schedulers 182 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) 183 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) 184 | print('Resume from iteration %d' % iterations) 185 | return iterations 186 | 187 | def save(self, snapshot_dir, iterations): 188 | # Save generators, discriminators, and optimizers 189 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 190 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) 191 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt') 192 | torch.save( 193 | {'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict(), "classifier": self.classifier.state_dict()}, 194 | gen_name) 195 | torch.save({'latent': self.dis_latent.state_dict()}, dis_name) 196 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': 
self.dis_opt.state_dict()}, opt_name) 197 | 198 | 199 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | # adapted from pytorch/examples/vae and ethanluoyc/pytorch-vae 6 | 7 | class ImageClassifier(nn.Module): 8 | def __init__(self, latent_variable_size, pretrained, nout=2): 9 | super(ImageClassifier, self).__init__() 10 | self.latent_variable_size = latent_variable_size 11 | self.feature_extractor = pretrained 12 | self.classifier = nn.Linear(latent_variable_size, nout) 13 | 14 | def forward(self, x): 15 | x, _ = self.feature_extractor.encode(x) 16 | x = self.classifier(x.view(-1, self.latent_variable_size)) 17 | return x 18 | 19 | class VAE(nn.Module): 20 | def __init__(self, nc=1, ngf=128, ndf=128, latent_variable_size=128, imsize=64, batchnorm=False): 21 | super(VAE, self).__init__() 22 | 23 | self.nc = nc 24 | self.ngf = ngf 25 | self.ndf = ndf 26 | self.imsize = imsize 27 | self.latent_variable_size = latent_variable_size 28 | self.batchnorm = batchnorm 29 | 30 | self.encoder = nn.Sequential( 31 | # input is 3 x 64 x 64 32 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 33 | nn.LeakyReLU(0.2, inplace=True), 34 | # state size. (ndf) x 32 x 32 35 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 36 | nn.BatchNorm2d(ndf * 2), 37 | nn.LeakyReLU(0.2, inplace=True), 38 | # state size. (ndf*2) x 16 x 16 39 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 40 | nn.BatchNorm2d(ndf * 4), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | # state size. (ndf*4) x 8 x 8 43 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 44 | nn.BatchNorm2d(ndf * 8), 45 | nn.LeakyReLU(0.2, inplace=True), 46 | # state size. (ndf*8) x 4 x 4 47 | nn.Conv2d(ndf * 8, ndf * 8, 4, 2, 1, bias=False), 48 | nn.BatchNorm2d(ndf * 8), 49 | nn.LeakyReLU(0.2, inplace=True), 50 | # state size. 
(ndf*8) x 2 x 2 51 | ) 52 | 53 | self.fc1 = nn.Linear(ndf*8*2*2, latent_variable_size) 54 | self.fc2 = nn.Linear(ndf*8*2*2, latent_variable_size) 55 | 56 | # decoder 57 | 58 | self.decoder = nn.Sequential( 59 | # input is Z, going into a convolution 60 | # state size. (ngf*8) x 2 x 2 61 | nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False), 62 | nn.BatchNorm2d(ngf * 8), 63 | nn.LeakyReLU(0.2, inplace=True), 64 | # state size. (ngf*8) x 4 x 4 65 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 66 | nn.BatchNorm2d(ngf * 4), 67 | nn.LeakyReLU(0.2, inplace=True), 68 | # state size. (ngf*4) x 8 x 8 69 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 70 | nn.BatchNorm2d(ngf * 2), 71 | nn.LeakyReLU(0.2, inplace=True), 72 | # state size. (ngf*2) x 16 x 16 73 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 74 | nn.BatchNorm2d(ngf), 75 | nn.LeakyReLU(0.2, inplace=True), 76 | # state size. (ngf) x 32 x 32 77 | nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), 78 | nn.Sigmoid(), 79 | # state size. 
class FC_VAE(nn.Module):
    """Fully connected variational Autoencoder."""

    def __init__(self, n_input, nz, n_hidden=1024):
        super(FC_VAE, self).__init__()
        self.nz = nz
        self.n_input = n_input
        self.n_hidden = n_hidden

        # Deep MLP encoder trunk (note: the first block is Linear/ReLU/BN,
        # later blocks Linear/BN/ReLU -- kept exactly as designed).
        self.encoder = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(n_hidden),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, n_hidden),
        )

        # Heads producing the posterior mean (fc1) and log-variance (fc2).
        self.fc1 = nn.Linear(n_hidden, nz)
        self.fc2 = nn.Linear(n_hidden, nz)

        self.decoder = nn.Sequential(
            nn.Linear(nz, n_hidden),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(n_hidden),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, n_input),
        )

    def encode(self, x):
        """Return (mu, logvar) of the latent posterior for `x`."""
        trunk = self.encoder(x)
        return self.fc1(trunk), self.fc2(trunk)

    def reparametrize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) with the reparameterization trick."""
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        """Autoencode `x`; returns (reconstruction, z, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), z, mu, logvar

    def get_latent_var(self, x):
        """Sample a latent code for `x`."""
        mu, logvar = self.encode(x)
        return self.reparametrize(mu, logvar)

    def generate(self, z):
        """Decode a latent code into data space."""
        return self.decode(z)
213 | nn.Linear(n_hidden, n_hidden), 214 | nn.BatchNorm1d(n_hidden), 215 | nn.ReLU(inplace=True), 216 | nn.Linear(n_hidden, nz), 217 | ) 218 | 219 | self.decoder = nn.Sequential(nn.Linear(nz, n_hidden), 220 | nn.ReLU(inplace=True), 221 | nn.BatchNorm1d(n_hidden), 222 | nn.Linear(n_hidden, n_hidden), 223 | nn.BatchNorm1d(n_hidden), 224 | nn.ReLU(inplace=True), 225 | nn.Linear(n_hidden, n_hidden), 226 | nn.BatchNorm1d(n_hidden), 227 | nn.ReLU(inplace=True), 228 | nn.Linear(n_hidden, n_hidden), 229 | nn.BatchNorm1d(n_hidden), 230 | nn.ReLU(inplace=True), 231 | nn.Linear(n_hidden, n_input), 232 | ) 233 | 234 | def forward(self, x): 235 | encoding = self.encoder(x) 236 | decoding = self.decoder(encoding) 237 | return encoding, decoding 238 | 239 | class FC_Classifier(nn.Module): 240 | """Latent space discriminator""" 241 | def __init__(self, nz, n_hidden=1024, n_out=2): 242 | super(FC_Classifier, self).__init__() 243 | self.nz = nz 244 | self.n_hidden = n_hidden 245 | self.n_out = n_out 246 | 247 | self.net = nn.Sequential( 248 | nn.Linear(nz, n_hidden), 249 | nn.ReLU(inplace=True), 250 | nn.Linear(n_hidden, n_hidden), 251 | nn.ReLU(inplace=True), 252 | # nn.Linear(n_hidden, n_hidden), 253 | # nn.ReLU(inplace=True), 254 | # nn.Linear(n_hidden, n_hidden), 255 | # nn.ReLU(inplace=True), 256 | nn.Linear(n_hidden, n_hidden), 257 | nn.ReLU(inplace=True), 258 | nn.Linear(n_hidden,n_out) 259 | ) 260 | 261 | def forward(self, x): 262 | return self.net(x) 263 | 264 | class Simple_Classifier(nn.Module): 265 | """Latent space discriminator""" 266 | def __init__(self, nz, n_out=2): 267 | super(Simple_Classifier, self).__init__() 268 | self.nz = nz 269 | self.n_out = n_out 270 | 271 | self.net = nn.Sequential( 272 | nn.Linear(nz, n_out), 273 | ) 274 | 275 | def forward(self, x): 276 | return self.net(x) 277 | 278 | -------------------------------------------------------------------------------- /train_rna_image.py: 
def setup_args():
    """Parse command-line arguments for cross-modal RNA/image training."""
    parser = argparse.ArgumentParser()

    # save and directory options
    parser.add_argument('-sd', '--save-dir', action="store", dest="save_dir")
    parser.add_argument('--save-freq', action="store", dest="save_freq", default=20, type=int)
    parser.add_argument('--pretrained-file', action="store")

    # training parameters
    parser.add_argument('-bs', '--batch-size', action="store", dest="batch_size", default=32, type=int)
    parser.add_argument('-w', '--num-workers', action="store", dest="num_workers", default=10, type=int)
    parser.add_argument('-lrAE', '--learning-rate-AE', action="store", dest="learning_rate_AE", default=1e-4, type=float)
    parser.add_argument('-lrD', '--learning-rate-D', action="store", dest="learning_rate_D", default=1e-4, type=float)
    parser.add_argument('-e', '--max-epochs', action="store", dest="max_epochs", default=1000, type=int)
    parser.add_argument('-wd', '--weight-decay', action="store", dest="weight_decay", default=0, type=float)
    parser.add_argument('--train-imagenet', action="store_true")
    parser.add_argument('--conditional', action="store_true")
    parser.add_argument('--conditional-adv', action="store_true")

    # loss-weight hyperparameters
    parser.add_argument('--alpha', action="store", default=0.1, type=float)
    parser.add_argument('--beta', action="store", default=1., type=float)
    parser.add_argument('--lamb', action="store", default=0.00000001, type=float)
    parser.add_argument('--latent-dims', action="store", default=128, type=int)

    # gpu options (store_false: GPU use defaults to True, flag disables it)
    parser.add_argument('-gpu', '--use-gpu', action="store_false", dest="use_gpu")

    return parser.parse_args()
if args.conditional:
    opt_netCondClf = optim.Adam(netCondClf.parameters(), lr=args.learning_rate_AE)

# loss criteria
criterion_reconstruct = nn.MSELoss()
criterion_classify = nn.CrossEntropyLoss()

# setup logger: record the configuration and architectures once per run
with open(os.path.join(args.save_dir, 'log.txt'), 'w') as f:
    print(args, file=f)
    print(netRNA, file=f)
    print(netImage, file=f)
    print(netClf, file=f)
    if args.conditional:
        print(netCondClf, file=f)

# define helper train functions

def compute_KL_loss(mu, logvar):
    """Return the VAE KL(q(z|x) || N(0, I)) term scaled by args.lamb (0 if lamb<=0)."""
    if args.lamb > 0:
        KLloss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return args.lamb * KLloss
    return 0


def train_autoencoders(rna_inputs, image_inputs, rna_class_labels=None, image_class_labels=None):
    """One optimizer step for the autoencoders with the discriminator frozen.

    Returns a dict of batch-summed loss components (floats).
    """
    netRNA.train()
    if args.train_imagenet:
        netImage.train()
    else:
        netImage.eval()
    netClf.eval()
    if args.conditional:
        netCondClf.train()

    # process input data
    rna_inputs, image_inputs = Variable(rna_inputs), Variable(image_inputs)

    if args.use_gpu:
        rna_inputs, image_inputs = rna_inputs.cuda(), image_inputs.cuda()

    # reset parameter gradients
    netRNA.zero_grad()

    # forward pass
    rna_recon, rna_latents, rna_mu, rna_logvar = netRNA(rna_inputs)
    image_recon, image_latents, image_mu, image_logvar = netImage(image_inputs)

    if args.conditional_adv:
        # BUG FIX: only move labels to the GPU when the models are there;
        # the original called .cuda() unconditionally.
        if args.use_gpu:
            rna_class_labels, image_class_labels = rna_class_labels.cuda(), image_class_labels.cuda()
        rna_scores = netClf(torch.cat((rna_latents, rna_class_labels.float().view(-1, 1).expand(-1, 10)), dim=1))
        image_scores = netClf(torch.cat((image_latents, image_class_labels.float().view(-1, 1).expand(-1, 10)), dim=1))
    else:
        rna_scores = netClf(rna_latents)
        image_scores = netClf(image_latents)

    rna_labels = torch.zeros(rna_scores.size(0),).long()
    image_labels = torch.ones(image_scores.size(0),).long()

    if args.conditional:
        rna_class_scores = netCondClf(rna_latents)
        image_class_scores = netCondClf(image_latents)

    if args.use_gpu:
        rna_labels, image_labels = rna_labels.cuda(), image_labels.cuda()
        if args.conditional:
            rna_class_labels, image_class_labels = rna_class_labels.cuda(), image_class_labels.cuda()

    # compute losses; the adversarial term scores RNA latents against the
    # *image* label (and vice versa) so the encoders are pushed to fool netClf
    rna_recon_loss = criterion_reconstruct(rna_inputs, rna_recon)
    image_recon_loss = criterion_reconstruct(image_inputs, image_recon)
    kl_loss = compute_KL_loss(rna_mu, rna_logvar) + compute_KL_loss(image_mu, image_logvar)
    clf_loss = 0.5 * criterion_classify(rna_scores, image_labels) + 0.5 * criterion_classify(image_scores, rna_labels)
    loss = args.alpha * (rna_recon_loss + image_recon_loss) + clf_loss + kl_loss

    if args.conditional:
        clf_class_loss = 0.5 * criterion_classify(rna_class_scores, rna_class_labels) \
            + 0.5 * criterion_classify(image_class_scores, image_class_labels)
        loss += args.beta * clf_class_loss

    # backpropagate and update model
    loss.backward()
    opt_netRNA.step()
    if args.conditional:
        opt_netCondClf.step()
    if args.train_imagenet:
        opt_netImage.step()

    summary_stats = {'rna_recon_loss': rna_recon_loss.item() * rna_scores.size(0),
                     'image_recon_loss': image_recon_loss.item() * image_scores.size(0),
                     'clf_loss': clf_loss.item() * (rna_scores.size(0) + image_scores.size(0))}
    if args.conditional:
        summary_stats['clf_class_loss'] = clf_class_loss.item() * (rna_scores.size(0) + image_scores.size(0))

    return summary_stats


def train_classifier(rna_inputs, image_inputs, rna_class_labels=None, image_class_labels=None):
    """One optimizer step for the adversarial discriminator (encoders frozen).

    Returns a dict with the batch-summed loss and per-domain correct counts.
    """
    netRNA.eval()
    netImage.eval()
    netClf.train()

    # process input data
    rna_inputs, image_inputs = Variable(rna_inputs), Variable(image_inputs)

    if args.use_gpu:
        rna_inputs, image_inputs = rna_inputs.cuda(), image_inputs.cuda()

    # reset parameter gradients
    netClf.zero_grad()

    # forward pass
    _, rna_latents, _, _ = netRNA(rna_inputs)
    _, image_latents, _, _ = netImage(image_inputs)

    if args.conditional_adv:
        # BUG FIX: guard GPU transfer on args.use_gpu (see train_autoencoders).
        if args.use_gpu:
            rna_class_labels, image_class_labels = rna_class_labels.cuda(), image_class_labels.cuda()
        rna_scores = netClf(torch.cat((rna_latents, rna_class_labels.float().view(-1, 1).expand(-1, 10)), dim=1))
        image_scores = netClf(torch.cat((image_latents, image_class_labels.float().view(-1, 1).expand(-1, 10)), dim=1))
    else:
        rna_scores = netClf(rna_latents)
        image_scores = netClf(image_latents)

    rna_labels = torch.zeros(rna_scores.size(0),).long()
    image_labels = torch.ones(image_scores.size(0),).long()

    if args.use_gpu:
        rna_labels, image_labels = rna_labels.cuda(), image_labels.cuda()

    # here the discriminator is trained with the TRUE domain labels
    clf_loss = 0.5 * criterion_classify(rna_scores, rna_labels) + 0.5 * criterion_classify(image_scores, image_labels)
    loss = clf_loss

    # backpropagate and update model
    loss.backward()
    opt_netClf.step()

    # BUG FIX: detach the loss with .item() as everywhere else; the original
    # stored a live tensor (keeping the graph alive and polluting the logs).
    summary_stats = {'clf_loss': clf_loss.item() * (rna_scores.size(0) + image_scores.size(0)),
                     'rna_accuracy': accuracy(rna_scores, rna_labels), 'rna_n_samples': rna_scores.size(0),
                     'image_accuracy': accuracy(image_scores, image_labels), 'image_n_samples': image_scores.size(0)}

    return summary_stats


def accuracy(output, target):
    """Return the COUNT of correct argmax predictions (not a rate; callers divide)."""
    pred = output.argmax(dim=1).view(-1)
    correct = pred.eq(target.view(-1)).float().sum().item()
    return correct


def generate_image(epoch):
    """Write a few RNA->image translations and image reconstructions to disk."""
    img_dir = os.path.join(args.save_dir, "images")
    os.makedirs(img_dir, exist_ok=True)
    netRNA.eval()
    netImage.eval()

    for i in range(5):
        samples = genomics_loader.dataset[np.random.randint(30)]
        rna_inputs = samples['tensor']
        rna_inputs = Variable(rna_inputs.unsqueeze(0))
        samples = image_loader.dataset[np.random.randint(30)]
        image_inputs = samples['image_tensor']
        image_inputs = Variable(image_inputs.unsqueeze(0))

        # BUG FIX: follow args.use_gpu, not torch.cuda.is_available();
        # otherwise inputs land on the GPU while the models stay on the CPU.
        if args.use_gpu:
            rna_inputs = rna_inputs.cuda()
            image_inputs = image_inputs.cuda()

        _, rna_latents, _, _ = netRNA(rna_inputs)
        recon_inputs = netImage.decode(rna_latents)
        imageio.imwrite(os.path.join(img_dir, "epoch_%s_trans_%s.jpg" % (epoch, i)),
                        np.uint8(recon_inputs.cpu().data.view(64, 64).numpy() * 255))
        recon_images, _, _, _ = netImage(image_inputs)
        imageio.imwrite(os.path.join(img_dir, "epoch_%s_recon_%s.jpg" % (epoch, i)),
                        np.uint8(recon_images.cpu().data.view(64, 64).numpy() * 255))


### main training loop
for epoch in range(args.max_epochs):
    print(epoch)

    if epoch % args.save_freq == 0:
        generate_image(epoch)

    recon_rna_loss = 0
    recon_image_loss = 0
    clf_loss = 0
    clf_class_loss = 0
    AE_clf_loss = 0

    n_rna_correct = 0
    n_rna_total = 0
    n_atac_correct = 0
    n_atac_total = 0

    # alternate one AE step and one discriminator step per batch pair
    for idx, (rna_samples, image_samples) in enumerate(zip(genomics_loader, image_loader)):
        rna_inputs = rna_samples['tensor']
        image_inputs = image_samples['image_tensor']

        if args.conditional or args.conditional_adv:
            rna_labels = rna_samples['binary_label']
            image_labels = image_samples['binary_label']
            out = train_autoencoders(rna_inputs, image_inputs, rna_labels, image_labels)
        else:
            out = train_autoencoders(rna_inputs, image_inputs)

        recon_rna_loss += out['rna_recon_loss']
        recon_image_loss += out['image_recon_loss']
        AE_clf_loss += out['clf_loss']

        if args.conditional:
            clf_class_loss += out['clf_class_loss']

        if args.conditional_adv:
            out = train_classifier(rna_inputs, image_inputs, rna_labels, image_labels)
        else:
            out = train_classifier(rna_inputs, image_inputs)

        clf_loss += out['clf_loss']
        n_rna_correct += out['rna_accuracy']
        n_rna_total += out['rna_n_samples']
        n_atac_correct += out['image_accuracy']
        n_atac_total += out['image_n_samples']

    recon_rna_loss /= n_rna_total
    # BUG FIX: the image reconstruction loss was logged un-normalized while
    # the RNA loss was averaged, making the two columns incomparable.
    recon_image_loss /= n_atac_total
    clf_loss /= n_rna_total + n_atac_total
    AE_clf_loss /= n_rna_total + n_atac_total

    if args.conditional:
        clf_class_loss /= n_rna_total + n_atac_total

    with open(os.path.join(args.save_dir, 'log.txt'), 'a') as f:
        print('Epoch: ', epoch, ', rna recon loss: %.8f' % float(recon_rna_loss),
              ', image recon loss: %.8f' % float(recon_image_loss),
              ', AE clf loss: %.8f' % float(AE_clf_loss), ', clf loss: %.8f' % float(clf_loss),
              ', clf class loss: %.8f' % float(clf_class_loss),
              ', clf accuracy RNA: %.4f' % float(n_rna_correct / n_rna_total),
              ', clf accuracy ATAC: %.4f' % float(n_atac_correct / n_atac_total), file=f)

    # save model (checkpoints are written from the CPU, then moved back)
    if epoch % args.save_freq == 0:
        torch.save(netRNA.cpu().state_dict(), os.path.join(args.save_dir, "netRNA_%s.pth" % epoch))
        torch.save(netImage.cpu().state_dict(), os.path.join(args.save_dir, "netImage_%s.pth" % epoch))
        torch.save(netClf.cpu().state_dict(), os.path.join(args.save_dir, "netClf_%s.pth" % epoch))
        if args.conditional:
            torch.save(netCondClf.cpu().state_dict(), os.path.join(args.save_dir, "netCondClf_%s.pth" % epoch))

        if args.use_gpu:
            netRNA.cuda()
            netClf.cuda()
            netImage.cuda()
            if args.conditional:
                netCondClf.cuda()

# ---------------------------------------------------------------------------
# model_rna_atac/utils.py
# ---------------------------------------------------------------------------
"""
Utils for data loading and model training.
This code is based on https://github.com/NVlabs/MUNIT.
"""

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import lr_scheduler
import torch

# BUG FIX: only pin a CUDA device when one exists; the original unconditional
# call crashed module import on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
import os
import math
import yaml
import numpy as np
import torch.nn.init as init

from scipy import sparse
from scipy.stats import percentileofscore
import torch.utils.data as utils
import pickle
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from random import sample
from sklearn import metrics

# Populated by get_all_data_loaders(); read by the plotting helpers below.
CONF = {}
data_a = None
data_b = None


directory_prefix = "processed_data/"
DATA_DIRECTORY = directory_prefix + "transcription_factor/"
# Fixed held-out cell indices (out of the 1874 shared sciCAR cells).
TEST_SET = [1652, 1563, 1823, 1252, 1453, 189, 1063, 331, 998, 161, 1103,
            1158, 1595, 1459, 892, 1694, 671, 469, 486, 1537, 1308, 960,
            138, 966, 987, 1448, 1466, 1478, 232, 704, 84, 737, 252,
            256, 62, 1439, 336, 1170, 1786, 1277, 1819, 1096, 508, 462,
            1456, 1129, 240, 352, 1716, 629, 593, 951, 1840, 212, 512,
            1172, 980, 1090, 750, 728, 783, 788, 1000, 498, 5, 569,
            572, 50, 1662, 375, 661, 1778, 235, 1607, 110, 1632, 816,
            209, 1798, 1174, 193, 1362, 310, 342, 98, 1538, 405, 1161,
            1310, 1240, 143, 586, 970, 100, 1679, 604, 700, 549, 1464,
            712, 654, 763, 562, 1323, 1445, 150, 507, 956, 1444, 795,
            394, 1530, 895, 582, 274, 350, 459, 57, 384, 446, 828,
            270, 370, 1510, 300, 1101, 1428, 1561, 1857, 1035, 982, 1276,
            63, 780, 1111, 952, 1347, 268, 421, 1574, 1309, 1168, 1060,
            1566, 804, 769, 1528, 743, 494, 847, 1071, 523, 1011, 914,
            1645, 558, 889, 653, 425, 1863, 844, 812, 1859, 1225, 0,
            1582, 170, 1015, 1242, 1826, 1067, 147, 1651, 884, 1628, 1433,
            165, 976, 45, 1838, 602, 28, 1029, 989, 1725, 1724, 936,
            1082, 1442, 307, 1669, 1791, 1553, 1720, 211, 61, 709, 890,
            86, 148, 1159, 675, 1241, 311, 1254, 1175, 990, 306, 1497,
            385, 1514, 499, 168, 374, 747, 1083, 243, 627, 1869, 1619,
            321, 1012, 1868, 864, 393, 1437, 1806, 25, 320, 111, 1598,
            1526, 873, 972, 59, 217, 434, 341, 557, 1135, 60, 1361,
            117, 1543, 407, 665, 1118, 219, 1251, 713, 688, 1304, 482,
            802, 1380, 349, 1221, 785, 1495, 285, 1208, 1192, 770, 679,
            6, 1021, 518, 305, 1492, 805, 135, 1586, 214, 870, 1421,
            1081, 777, 242, 1274, 1712, 1860, 447, 391, 1661, 766, 840,
            1450, 1220, 1766, 1629, 1653, 191, 1590, 162, 862, 1365, 344,
            87, 1673, 1209, 382, 345, 1069, 1400, 1585, 631, 456, 1267,
            1138, 390, 827, 908, 639, 1649, 1845, 142, 1684, 781, 15,
            301, 1288, 1719, 943, 68, 1298, 626, 1621, 617, 746, 146,
            1482, 745, 403, 556, 899, 220, 471, 651, 1018, 1717]
TEST_SET = list(set(TEST_SET))  # de-duplicate
TRAINING_SET = [x for x in range(1874) if x not in TEST_SET]


with open(DATA_DIRECTORY + "shared_cells.pkl", 'rb') as f:
    shared = pickle.load(f)
gene_exp_cells = pd.read_csv(DATA_DIRECTORY + 'GSM3271040_RNA_sciCAR_A549_cell.txt', index_col=0)
TREATMENT = gene_exp_cells.loc[shared]["treatment_time"]


def load_data(conf, isatac=False, data_size=1874, for_training=True, supervise=[]):
    """Load the ATAC or expression matrix for the shared cells.

    conf keys used (all optional): log_data, normalize_data, drop.
    supervise: indices of cells to return as a supervised subset (read-only,
    so the mutable default is harmless).

    Returns: a CUDA tensor of all cells when for_training=False; a CUDA tensor
    of the supervised subset when supervise is non-empty; otherwise the
    (train, test) CPU tensor pair split by TRAINING_SET / TEST_SET.
    """
    print(DATA_DIRECTORY)
    log_data = conf.get("log_data", False)
    normalize_data = conf.get("normalize_data", False)
    drop = conf.get("drop", False)

    if isatac:
        f = DATA_DIRECTORY + "diff_atac_shared_cells.npz"
    else:
        f = DATA_DIRECTORY + "diff_expr_shared_cells.npz"

    # cells x features dense matrix (np.matrix from .todense())
    data = sparse.load_npz(f).T.todense()
    # assert (len(data) == data_size)

    if drop:
        # keep only features non-zero in more than `threshold` of the cells
        threshold = 0.01 if isatac else 0.1
        acceptable = np.count_nonzero(data, axis=0) > threshold * len(data)
        data = data[:, acceptable.flatten().tolist()[0]]

    if log_data:
        data = np.log1p(data)
        if for_training:
            print("Taking log of data..")

    if normalize_data:
        scaler = StandardScaler()
        training_data = data[TRAINING_SET, :]
        # fit on the training split only to avoid test-set leakage
        scaler.fit(training_data)
        if for_training:
            print("Normalizing the data..")
        data = scaler.transform(data)

    if not for_training:
        return Variable(torch.from_numpy(data).float()).cuda()
    elif supervise:
        supervised_data = data[supervise, :]
        assert (len(supervised_data) == len(supervise))
        return Variable(torch.from_numpy(supervised_data).float()).cuda()
    else:
        training_data = data[TRAINING_SET, :]
        test_data = data[TEST_SET, :]
        return torch.from_numpy(training_data).float(), torch.from_numpy(test_data).float()


def load_supervision(conf, supervise=0):
    """Sample a fraction `supervise` of training cells and return the paired
    (ATAC, expression) CUDA tensors for those same cells."""
    s = sample(TRAINING_SET, k=int(supervise * len(TRAINING_SET)))
    supervise_a = load_data(conf, isatac=True, supervise=s)
    supervise_b = load_data(conf, isatac=False, supervise=s)
    return supervise_a, supervise_b


def get_all_data_loaders(conf):
    """Build train/test DataLoaders for both modalities (a=ATAC, b=expression).

    Also stashes conf and the full matrices in module globals for plotting.
    """
    global CONF
    global data_a
    global data_b
    CONF = conf

    data_a = load_data_for_latent_space_plot(isatac=True)    # a is atac
    data_b = load_data_for_latent_space_plot(isatac=False)   # b is expression

    # collapse treatment time 3 into class 2 so labels are {0, 1, 2}
    labels = [i if i != 3 else 2 for i in TREATMENT]
    training_labels = torch.from_numpy(np.array(labels)[TRAINING_SET]).long()
    test_labels = torch.from_numpy(np.array(labels)[TEST_SET]).long()

    assert 1 in training_labels and 2 in training_labels and 0 in training_labels

    train, test = load_data(conf, isatac=True)
    batch_size = conf['batch_size']

    train_dataset = utils.TensorDataset(train, training_labels)
    train_loader_a = utils.DataLoader(train_dataset, batch_size=batch_size)

    test_dataset = utils.TensorDataset(test, test_labels)
    test_loader_a = utils.DataLoader(test_dataset, batch_size=batch_size)

    train, test = load_data(conf, isatac=False)

    train_dataset = utils.TensorDataset(train, training_labels)
    train_loader_b = utils.DataLoader(train_dataset, batch_size=batch_size)

    test_dataset = utils.TensorDataset(test, test_labels)
    test_loader_b = utils.DataLoader(test_dataset, batch_size=batch_size)

    return train_loader_a, train_loader_b, test_loader_a, test_loader_b


def get_config(config):
    """Parse a YAML config file into a dict."""
    # BUG FIX: yaml.load() without a Loader is deprecated (and removed in
    # PyYAML >= 6, hence the old pip-pin note); safe_load handles these
    # plain-scalar configs identically and removes the version pin.
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)


def prepare_sub_folder(output_directory):
    """Create (if needed) and return the checkpoints/ and images/ subdirs."""
    image_directory = os.path.join(output_directory, 'images')
    if not os.path.exists(image_directory):
        print("Creating directory: {}".format(image_directory))
        os.makedirs(image_directory)
    checkpoint_directory = os.path.join(output_directory, 'checkpoints')
    if not os.path.exists(checkpoint_directory):
        print("Creating directory: {}".format(checkpoint_directory))
        os.makedirs(checkpoint_directory)
    return checkpoint_directory, image_directory


def write_loss(iterations, trainer, train_writer):
    """Log every scalar trainer attribute whose name mentions loss/grad/nwd."""
    members = [attr for attr in dir(trainer)
               if not callable(getattr(trainer, attr)) and not attr.startswith("__") and (
                   'loss' in attr or 'grad' in attr or 'nwd' in attr)]
    for m in members:
        train_writer.add_scalar(m, getattr(trainer, m), iterations + 1)


# Get model list for resume
def get_model_list(dirname, key):
    """Return the lexicographically last checkpoint file in dirname matching
    `key` (and containing ".pt"), or None if there is none."""
    if os.path.exists(dirname) is False:
        return None
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
                  os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    # BUG FIX: the original tested `gen_models is None`, which a list
    # comprehension never is, so an empty directory crashed on [-1].
    if not gen_models:
        return None
    gen_models.sort()
    last_model_name = gen_models[-1]
    return last_model_name


def get_scheduler(optimizer, hyperparameters, iterations=-1):
    """Build an LR scheduler from hyperparameters (None means constant LR)."""
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        scheduler = None  # constant scheduler
    elif hyperparameters['lr_policy'] == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
                                        gamma=hyperparameters['gamma'], last_epoch=iterations)
    else:
        # BUG FIX: the original *returned* the exception instance, so callers
        # silently received a NotImplementedError object as their "scheduler".
        raise NotImplementedError('learning rate policy [%s] is not implemented'
                                  % hyperparameters['lr_policy'])
    return scheduler


def weights_init(init_type='gaussian'):
    """Return an initializer closure for nn.Module.apply()."""
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    return init_fun


# Code for plotting latent space.
def load_data_for_latent_space_plot(isatac=False):
    """Load the full (train+test) matrix as a CUDA tensor for plotting."""
    conf = CONF
    return load_data(conf, isatac=isatac, for_training=False)


def plot_pca(a, b, outname1=None, outname2=None, outname=None, scale=True):
    """Write three PCA scatter plots of the stacked latents [b; a]:
    outname1 = first half (b) colored by treatment time,
    outname2 = second half (a) colored by treatment time,
    outname  = both halves colored by modality (purple=b, yellow=a)."""
    matrix = np.vstack((b, a))
    pca = PCA(n_components=2)
    scaled = matrix.copy()
    if scale:
        scaler = StandardScaler()
        scaled = scaler.fit_transform(matrix)

    comp = pca.fit_transform(scaled)

    half = len(a)
    fig, ax = matplotlib.pyplot.subplots()
    sc = ax.scatter(comp[:, 0][0:half], comp[:, 1][0:half], c=TREATMENT.values, s=1)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    cbar = fig.colorbar(sc)
    cbar.ax.set_ylabel('Treatment time')
    plt.savefig(outname1)
    plt.close("all")

    fig, ax = matplotlib.pyplot.subplots()
    sc = ax.scatter(comp[:, 0][half:], comp[:, 1][half:], c=TREATMENT.values, s=1)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    cbar = fig.colorbar(sc)
    cbar.ax.set_ylabel('Treatment time')
    plt.savefig(outname2)
    plt.close("all")

    fig, ax = matplotlib.pyplot.subplots()
    # make atac yellow
    colors = ['purple'] * len(TREATMENT.values) + ["yellow"] * len(TREATMENT.values)
    sc = ax.scatter(comp[:, 0], comp[:, 1], c=colors, s=1)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    plt.savefig(outname)
    plt.close("all")


def plot_pca_both_spaces(a, b, outname, scale=True):
    """Write a single PCA scatter of [b; a] colored by modality."""
    matrix = np.vstack((b, a))
    pca = PCA(n_components=2)
    scaled = matrix.copy()
    if scale:
        scaler = StandardScaler()
        scaled = scaler.fit_transform(matrix)

    comp = pca.fit_transform(scaled)

    fig, ax = matplotlib.pyplot.subplots()
    # make atac yellow (purple=b, yellow=a)
    colors = ['purple'] * len(TREATMENT.values) + ["yellow"] * len(TREATMENT.values)
    sc = ax.scatter(comp[:, 0], comp[:, 1], c=colors, s=1)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    # NOTE(review): no artists are labeled, so this legend() is a no-op that
    # only emits a warning; kept for parity with the original behavior.
    ax.legend()
    plt.savefig(outname)
    plt.close("all")


def save_plots(trainer, directory, suffix):
    """Encode the cached data with the current encoders and dump PCA plots."""
    latent_a = trainer.gen_a.enc(data_a).data.cpu().numpy()
    latent_b = trainer.gen_b.enc(data_b).data.cpu().numpy()

    plot_pca(latent_a, latent_b,
             os.path.join(directory, "_a_" + suffix + ".png"),
             os.path.join(directory, "_b_" + suffix + ".png"),
             os.path.join(directory, "both_" + suffix + ".png"))


def write_knn(trainer, directory, suffix):
    """Compute and append kNN alignment accuracies (k=5, 50) to a log file."""
    latent_a = trainer.gen_a.enc(data_a).data.cpu().numpy()
    latent_b = trainer.gen_b.enc(data_b).data.cpu().numpy()

    for k in [5, 50]:
        accuracy_a_train, accuracy_a_test = knn_accuracy(latent_a, latent_b, k)
        accuracy_b_train, accuracy_b_test = knn_accuracy(latent_b, latent_a, k)
        output = "Iteration: {}\n {}NN accuracy A: train: {} test: {}\n {}NN accuracy B: train: {} test: {}\n".format(
            suffix, str(k), accuracy_a_train, accuracy_a_test, str(k), accuracy_b_train, accuracy_b_test)
        print(output)
        with open(os.path.join(directory, "knn_accuracy.txt"), "a") as myfile:
            myfile.write(output)


def knn_accuracy(A, B, k):
    """For each point i, check whether the nearest A-neighbor of B[i] lies
    among the k nearest A-neighbors of A[i]; return the hit rates over the
    train and test splits as (train_acc, test_acc)."""
    # BUG FIX: modern scikit-learn requires n_neighbors as a keyword, and
    # fit() takes only the data matrix — the original passed k positionally
    # to both (where fit's second positional is the ignored `y`).
    nn = NearestNeighbors(n_neighbors=k, metric="l1")
    nn.fit(A)
    transp_nearest_neighbor = nn.kneighbors(B, 1, return_distance=False)
    actual_nn = nn.kneighbors(A, k, return_distance=False)
    train_correct = 0
    test_correct = 0

    # set for O(1) membership instead of scanning the TEST_SET list per point
    test_lookup = set(TEST_SET)
    for i in range(len(transp_nearest_neighbor)):
        if transp_nearest_neighbor[i] not in actual_nn[i]:
            continue
        elif i in test_lookup:
            test_correct += 1
        else:
            train_correct += 1

    return train_correct / len(TRAINING_SET), test_correct / len(TEST_SET)