├── proteinbert ├── shared_utils │ ├── __init__.py │ ├── .gitignore │ ├── reference_genome.py │ └── util.py ├── existing_model_loading.py ├── conv_and_global_attention_model.py ├── model_generation.py ├── uniref_dataset.py └── pretraining.py ├── .gitignore ├── __init__.py ├── Sample_data └── input.csv ├── Dockerfile ├── tokenization.py ├── README.md ├── finetuning.py ├── Analysis_and_Figures ├── Attention_analysis │ ├── Fig5.ipynb │ ├── Fig3.ipynb │ └── Fig4.ipynb └── Fig2_and_Suppl.ipynb ├── Practice.ipynb └── test.py /proteinbert/shared_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | inputs 2 | outputs 3 | __pycache__ 4 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from proteinbert.shared_utils.util import log 2 | from .tokenization import ADDED_TOKENS_PER_SEQ 3 | from proteinbert.model_generation import ModelGenerator, PretrainingModelGenerator, FinetuningModelGenerator, InputEncoder, load_pretrained_model_from_dump, tokenize_seqs 4 | from proteinbert.existing_model_loading import load_pretrained_model 5 | from .finetuning import finetune 6 | from .test import evaluate_by_len 7 | -------------------------------------------------------------------------------- /Sample_data/input.csv: -------------------------------------------------------------------------------- 1 | ,ID,FEATURES,database,target,subclass,mechanism,transferable,sequence 2 | 0,AAS90606.1,FEATURES,farme,aminoglycoside,AAC(6')-Iai,antibiotic inactivation,1,MNIFSLTRDNEHLIHQAAQLLMDAFHEHWPDAWPTFEEGWKEVHEMLEAERICRAAVDDNGNLLGLIGGIAGYDGNVWELHPLAVQPHLQGQGIGRALVEDFEEQVRLRGGLTITLGSDDEDDMTSLSDVDLHENPWEKIQKIRNFKGHPFTFYQKMGYVITGVVPDANGRGKPDILMSKRVG 3 | 1,ACH59003.1,FEATURES,farme,beta_lactam,mrdA,antibiotic inactivation,0,MTTHSPVPDVDEMTIQTRPGNAVQSSSGQVAHADRQRSDPSQVHQYSIQRTRIAERLETLIPWLSVAWFFGVLGLTARVTGGMIYTQRLIRCHTQLFGSYWTERLKHVSKRLRLSRSVRLLESSVVKVPTTIGWWRPVILVPGSVLSGLTPQQLELILAHELAHIRRHDYLINLFQVLVETLLFYHPAVWWISKQVRNERELVCDDMAVSVGGDPITYARALAKIERLRRETPVWALAADGGRLSKRIVRLIDSTQDSPRLPSMVVGFIMIGALFISIAVVQNVLSITKRSAQVAVAGATSVHQQLTPQSVQELIALDDTTGEDSEVRRISLAALGKREGAVIVMDPRTGRVYTIVNQEWAVRQSWQPASIIKLVTAAAALGEKVIQPSQPLRVSAKSRPLDLTEALALSSNPYFAFVGNGVGPDQIIKYAREFGLGERTGINYSQEGAGIIPGFSENLDVRFGATGEGVEATPIQLATLVSAVANGGQLVTPYVPHSSAESSETQPPVRRRIAIPPANLGLLMAGMIAAVDHGSGTGASDTSQIVAGKTGTFRDKTTNVGLFASYAPANDPRFVVVVVTRGQNESGPEAANVAGTIFRGLRNRS 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04 2 | 3 | ARG port 4 | ARG password 5 | 6 | ENV PORT=$port 7 | ENV PASSWORD=$password 8 | 9 | RUN apt-get update && apt-get upgrade -y 10 | RUN apt-get install -y libcairo2 build-essential libbz2-dev libdb-dev libreadline-dev libffi-dev libgdbm-dev liblzma-dev libncursesw5-dev libsqlite3-dev libssl-dev zlib1g-dev uuid-dev python3 python3-dev curl python3-pip 11 | RUN curl -kL https://bootstrap.pypa.io/get-pip.py | python3 12 | 13 | RUN pip install tensorflow-gpu==2.10.0 14 | RUN pip install tensorflow==2.10.0 15 | RUN pip install matplotlib 16 | RUN pip install jupyter 17 | RUN pip install numpy==1.21.0 18 | RUN pip install pandas 19 | 
RUN pip install scikit-learn 20 | RUN pip install biopython 21 | 22 | WORKDIR /home 23 | 24 | # RUN mkdir -p /root/.jupyter 25 | # COPY jupyter_notebook_config.py /root/.jupyter/jupyter_notebook_config.py 26 | 27 | EXPOSE $port 28 | ENTRYPOINT ["sh", "-c","jupyter notebook --port=$PORT --no-browser --allow-root --ip=0.0.0.0 --NotebookApp.token=$PASSWORD"] 29 | -------------------------------------------------------------------------------- /proteinbert/shared_utils/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /proteinbert/existing_model_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from urllib.parse import urlparse 4 | from urllib.request import urlopen 5 | 6 | from tensorflow import keras 7 | 8 | from . 
import conv_and_global_attention_model 9 | from .model_generation import load_pretrained_model_from_dump 10 | 11 | DEFAULT_LOCAL_MODEL_DUMP_DIR = 'proteinbert_models' 12 | DEFAULT_LOCAL_MODEL_DUMP_FILE_NAME = 'default.pkl' 13 | DEFAULT_REMOTE_MODEL_DUMP_URL = 'ftp://ftp.cs.huji.ac.il/users/nadavb/protein_bert/epoch_92400_sample_23500000.pkl' 14 | 15 | def load_pretrained_model(local_model_dump_dir = DEFAULT_LOCAL_MODEL_DUMP_DIR, local_model_dump_file_name = DEFAULT_LOCAL_MODEL_DUMP_FILE_NAME, \ 16 | remote_model_dump_url = DEFAULT_REMOTE_MODEL_DUMP_URL, download_model_dump_if_not_exists = True, validate_downloading = True, \ 17 | create_model_function = conv_and_global_attention_model.create_model, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, \ 18 | other_optimizer_kwargs = {}, annots_loss_weight = 1, load_optimizer_weights = False): 19 | 20 | #local_model_dump_dir = os.path.expanduser(local_model_dump_dir) 21 | dump_file_path = os.path.join(local_model_dump_dir, local_model_dump_file_name) 22 | print(dump_file_path) 23 | 24 | if not os.path.exists(dump_file_path) and download_model_dump_if_not_exists: 25 | 26 | if validate_downloading: 27 | print('aaa') 28 | print(f'Local model dump file {dump_file_path} doesn\'t exist. Will download {remote_model_dump_url} into {local_model_dump_dir}. Please approve or reject this ' + \ 29 | '(to exit and potentially call the function again with different parameters).') 30 | 31 | while True: 32 | 33 | user_input = input('Do you approve downloading the file into the specified directory? Please specify "Yes" or "No":') 34 | 35 | if user_input.lower() in {'yes', 'y'}: 36 | break 37 | elif user_input.lower() in {'no', 'n'}: 38 | raise ValueError('User wished to cancel.') 39 | 40 | downloaded_file_name = os.path.basename(urlparse(remote_model_dump_url).path) 41 | downloaded_file_path = os.path.join(local_model_dump_dir, downloaded_file_name) 42 | assert not os.path.exists(downloaded_file_path), 'Cannot download into an already existing file: %s' % downloaded_file_path 43 | 44 | with urlopen(remote_model_dump_url) as remote_file, open(downloaded_file_path, 'wb') as local_file: 45 | shutil.copyfileobj(remote_file, local_file) 46 | 47 | print('Downloaded file: %s' % downloaded_file_path) 48 | 49 | if downloaded_file_name != local_model_dump_file_name: 50 | os.symlink(downloaded_file_path, dump_file_path) 51 | print('Created: %s' % dump_file_path) 52 | 53 | return load_pretrained_model_from_dump(dump_file_path, create_model_function, create_model_kwargs = create_model_kwargs, optimizer_class = optimizer_class, lr = lr, \ 54 | other_optimizer_kwargs = other_optimizer_kwargs, annots_loss_weight = annots_loss_weight, load_optimizer_weights = load_optimizer_weights) -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from proteinbert.shared_utils.util import log 5 | 6 | ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY' 7 | ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>'] 8 | mechanism_labels = { 9 | 10 | 'antibiotic target alteration':0, 11 | 'antibiotic target replacement':1, 12 | 'antibiotic target protection':2, 13 | 'antibiotic inactivation':3, 14 | 'antibiotic efflux':4, 15 | 'others':5 16 | 17 | } 18 | 19 | # Each sequence is added <START> and <END> tokens 20 | ADDED_TOKENS_PER_SEQ = 2 21 | 22 | n_aas = len(ALL_AAS) 23 | aa_to_token_index = {aa: i for i, aa in enumerate(ALL_AAS)} 24 |
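# Descriptive note (added comment): the special tokens are assigned indices immediately after the amino-acid tokens (starting at n_aas); token_to_index below merges both mappings into a single vocabulary.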
additional_token_to_index = {token: i + n_aas for i, token in enumerate(ADDITIONAL_TOKENS)} 25 | token_to_index = {**aa_to_token_index, **additional_token_to_index} 26 | index_to_token = {index: token for token, index in token_to_index.items()} 27 | n_tokens = len(token_to_index) 28 | 29 | def tokenize_seq(seq): 30 | other_token_index = additional_token_to_index['<OTHER>'] 31 | return [additional_token_to_index['<START>']] + [aa_to_token_index.get(aa, other_token_index) for aa in parse_seq(seq)] + \ 32 | [additional_token_to_index['<END>']] 33 | 34 | def parse_seq(seq): 35 | if isinstance(seq, str): 36 | return seq 37 | elif isinstance(seq, bytes): 38 | return seq.decode('utf8') 39 | else: 40 | raise TypeError('Unexpected sequence type: %s' % type(seq)) 41 | 42 | def encode_dataset(dataset, input_encoder, mechanism_labels, seq_len = 512, needs_filtering = True, dataset_name = 'Dataset', verbose = True):#seq_len = 512 43 | 44 | seqs = dataset['sequence'] 45 | raw_Y = dataset['mechanism'] 46 | 47 | if needs_filtering: 48 | dataset = filter_dataset_by_len(dataset, seq_len = seq_len, dataset_name = dataset_name, verbose = verbose) 49 | seqs = dataset['sequence'] 50 | raw_Y = dataset['mechanism'] 51 | 52 | X = input_encoder.encode_X(seqs, seq_len) 53 | Y, sample_weigths = encode_Y(raw_Y, seq_len = seq_len, mechanism_labels = mechanism_labels) 54 | return X, Y, sample_weigths 55 | 56 | def encode_Y(raw_Y, seq_len = 512, mechanism_labels = mechanism_labels): 57 | return encode_categorical_Y(raw_Y, mechanism_labels), np.ones(len(raw_Y)) 58 | 59 | def encode_seq_Y(seqs, seq_len, is_binary, mechanism_labels): 60 | 61 | label_to_index = {str(label): i for i, label in enumerate(mechanism_labels)} 62 | 63 | Y = np.zeros((len(seqs), seq_len), dtype = int) 64 | sample_weigths = np.zeros((len(seqs), seq_len)) 65 | 66 | for i, seq in enumerate(seqs): 67 | 68 | for j, label in enumerate(seq): 69 | # +1 to account for the <START> token at the beginning. 70 | Y[i, j + 1] = label_to_index[label] 71 | 72 | sample_weigths[i, 1:(len(seq) + 1)] = 1 73 | 74 | if is_binary: 75 | Y = np.expand_dims(Y, axis = -1) 76 | sample_weigths = np.expand_dims(sample_weigths, axis = -1) 77 | 78 | return Y, sample_weigths 79 | 80 | def encode_categorical_Y(labels, mechanism_labels): 81 | 82 | label_to_index = {label: i for i, label in enumerate(mechanism_labels)} 83 | Y = np.zeros(len(labels), dtype = int) 84 | 85 | for i, label in enumerate(labels): 86 | Y[i] = label_to_index[label] 87 | 88 | return Y 89 | 90 | def filter_dataset_by_len(dataset, seq_len = 512, seq_col_name = 'sequence', dataset_name = 'Dataset', verbose = True): 91 | 92 | max_allowed_input_seq_len = seq_len - ADDED_TOKENS_PER_SEQ 93 | filtered_dataset = dataset[dataset[seq_col_name].str.len() <= max_allowed_input_seq_len] 94 | n_removed_records = len(dataset) - len(filtered_dataset) 95 | 96 | if verbose: 97 | log('%s: Filtered out %d of %d (%.1f%%) records of lengths exceeding %d.'
% (dataset_name, n_removed_records, len(dataset), 100 * n_removed_records / len(dataset), \ 98 | max_allowed_input_seq_len)) 99 | 100 | return filtered_dataset 101 | 102 | def split_dataset_by_len(dataset, seq_col_name = 'sequence', start_seq_len = 512, start_batch_size = 32, increase_factor = 2):#start_seq_len = 512 103 | 104 | seq_len = start_seq_len 105 | batch_size = start_batch_size 106 | 107 | while len(dataset) > 0: 108 | max_allowed_input_seq_len = seq_len - ADDED_TOKENS_PER_SEQ 109 | len_mask = (dataset[seq_col_name].str.len() <= max_allowed_input_seq_len) 110 | len_matching_dataset = dataset[len_mask] 111 | yield len_matching_dataset, seq_len, batch_size, len_matching_dataset.index.tolist() 112 | dataset = dataset[~len_mask] 113 | seq_len *= increase_factor 114 | batch_size = max(batch_size // increase_factor, 1) -------------------------------------------------------------------------------- /proteinbert/shared_utils/reference_genome.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .util import as_biopython_seq 4 | 5 | class GenomeReader: 6 | 7 | ''' 8 | An API to read sequences from the reference genome, assuming it's organized into per-chromosome file. 9 | A GenomeReader object is initialized from a directory containing the relevant FAST files. It will automatically detect all .fa and .fasta files 10 | within that directory. 11 | The reference genome sequences of all human chromosomes (chrXXX.fa.gz files) can be downloaded from UCSC's FTP site at: 12 | ftp://hgdownload.cse.ucsc.edu/goldenPath/hg19/chromosomes/ (for version hg19) 13 | ftp://hgdownload.cse.ucsc.edu/goldenPath/hg38/chromosomes/ (for version hg38/GRCh38) 14 | The chrXXX.fa.gz files need to be uncompressed to obtain chrXXX.fa files. 15 | IMPORTANT: In version hg19 there's an inconsistency in the reference genome of the M chromosome between UCSC and RegSeq/GENCODE, 16 | so the file chrM.fa is better to be taken from RefSeq (NC_012920.1) instead of UCSC, from: 17 | https://www.ncbi.nlm.nih.gov/sviewer/viewer.cgi?tool=portal&save=file&log$=seqview&db=nuccore&report=fasta&sort=&id=251831106&from=begin&to=end&maxplex=1 18 | In GRCh38 all UCSC files are fine. 19 | ''' 20 | 21 | def __init__(self, ref_genome_dir): 22 | self.ref_genome_dir = ref_genome_dir 23 | self.chromosome_readers = dict(self._create_chromosome_reader(file_name) for file_name in os.listdir(self.ref_genome_dir) \ 24 | if _is_ref_genome_file(file_name)) 25 | 26 | def read_seq(self, chromosome, start, end): 27 | try: 28 | return self.chromosome_readers[_find_chrom(chromosome, self.chromosome_readers.keys())].read_seq(start, end) 29 | except KeyError: 30 | raise ValueError('Chromosome "%s" was not found in the reference genome.' 
% chromosome) 31 | 32 | def close(self): 33 | for chromosome_reader in self.chromosome_readers.values(): 34 | chromosome_reader.file_handler.close() 35 | 36 | def __contains__(self, chromosome): 37 | return _find_chrom(chromosome, self.chromosome_readers.keys()) is not None 38 | 39 | def _create_chromosome_reader(self, file_name): 40 | chr_name = file_name.split('.')[0].replace('chr', '') 41 | f = open(os.path.join(self.ref_genome_dir, file_name), 'r') 42 | return chr_name, ChromosomeReader(f) 43 | 44 | class ChromosomeReader: 45 | 46 | def __init__(self, file_handler): 47 | self.file_handler = file_handler 48 | self.header_len = len(file_handler.readline()) 49 | self.line_len = len(file_handler.readline()) - 1 50 | 51 | def read_seq(self, start, end): 52 | 53 | absolute_start = self.convert_to_absolute_coordinate(start) 54 | absolute_length = self.convert_to_absolute_coordinate(end) - absolute_start + 1 55 | 56 | self.file_handler.seek(absolute_start) 57 | seq = self.file_handler.read(absolute_length).replace('\n', '').upper() 58 | return as_biopython_seq(seq) 59 | 60 | def convert_to_absolute_coordinate(self, position): 61 | position_zero_index = position - 1 62 | return self.header_len + position_zero_index + (position_zero_index // self.line_len) 63 | 64 | def _find_chrom(query_chr_name, available_chr_names): 65 | 66 | assert isinstance(query_chr_name, str), 'Unexpected chromosome type: %s' % type(query_chr_name) 67 | 68 | if query_chr_name.lower().startswith('chr'): 69 | query_chr_name = query_chr_name[3:] 70 | 71 | query_chr_name = query_chr_name.upper() 72 | 73 | for possible_chr_name in _find_synonymous_chr_names(query_chr_name): 74 | 75 | if possible_chr_name in available_chr_names: 76 | return possible_chr_name 77 | 78 | prefixed_possible_chr_name = 'chr%s' % possible_chr_name 79 | 80 | if prefixed_possible_chr_name in available_chr_names: 81 | return prefixed_possible_chr_name 82 | 83 | return None 84 | 85 | def _is_ref_genome_file(file_name): 86 | 87 | file_name = file_name.lower() 88 | 89 | for extension in _SUPPORTED_REF_GENOME_EXTENSIONS: 90 | if file_name.endswith(file_name): 91 | return True 92 | 93 | return False 94 | 95 | def _find_synonymous_chr_names(chr_name): 96 | 97 | for synonymous_chr_name_group in _SYNONYMOUS_CHR_NAME_GROUPS: 98 | if chr_name in synonymous_chr_name_group: 99 | return synonymous_chr_name_group 100 | 101 | # Single-digit numbers can either appear with or without a trailing 0. 102 | if chr_name.isdigit() and len(str(int(chr_name))) == 1: 103 | chr_number = str(int(chr_name)) 104 | return {chr_number, '0' + chr_number} 105 | 106 | return {chr_name} 107 | 108 | _SUPPORTED_REF_GENOME_EXTENSIONS = [ 109 | '.fa', 110 | '.fasta', 111 | ] 112 | 113 | _SYNONYMOUS_CHR_NAME_GROUPS = [ 114 | {'X', '23'}, 115 | {'Y', '24'}, 116 | {'XY', '25'}, 117 | {'M', 'MT', '26'}, 118 | ] 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARG-BERT 2 | The repository contains an implementation of ARG-BERT, a BERT model that predicts the resistance mechanism of antibiotic resistance genes. If you use this model, please cite our paper. If you have any problems or comments about this repository, please contact us. 3 | ## 1. Dependencies 4 | We have tested the model in the following environments. 
5 | ``` 6 | Linux: x86_64 7 | OS: Ubuntu 20.04.6 8 | GPU: NVIDIA A100 80G 9 | CUDA Version: 11.6 10 | Nvidia Driver Version: 510.47.03 11 | ``` 12 | All the necessary packages (see the Dockerfile) are provided in the Docker environment. Build and run it with the following commands: 13 | ``` 14 | docker build -t USERNAME/CONTAINERNAME --build-arg port=PORT --build-arg password=PASSWORD . 15 | docker login && docker push USERNAME/CONTAINERNAME 16 | docker pull USERNAME/CONTAINERNAME 17 | docker run -p PORT:PORT -it --gpus all --rm -v $PWD:/home USERNAME/CONTAINERNAME 18 | ``` 19 | ## 2. Dataset and Fine-tuning 20 | ### 2.1 Dataset 21 | We are unable to publish the HMD-ARG DB and the Low Homology Dataset, but the data format is illustrated by the sample data in `Sample_data`. 22 | 23 | We saved the larger files, such as the output results and the Attention values for all sequences in the HMD-ARG DB, at [https://waseda.box.com/v/ARG-BERT-suppl](https://waseda.box.com/v/ARG-BERT-suppl). 24 | You can run the scripts if you store `Prediction results` in `Analysis_and_Figures` in this repository and the contents of `Attention_analysis` in a directory of the same name in `Analysis_and_Figures`. 25 | 26 | ### 2.2 Fine-tuning 27 | Run `finetuning.py` to fine-tune ProteinBERT on ARGs using the following commands. 28 | 29 | If you would like to train with the HMD-ARG DB, run: 30 | ``` 31 | python3 finetuning.py \ 32 | --fold FOLD \ 33 | --gpu GPU \ 34 | --seed SEED 35 | ``` 36 | FOLD, GPU and SEED are integers (type int) indicating the fold index of the 5-fold CV, the GPU device to use and the random seed, respectively. 37 | 38 | Alternatively, if you would like to train with the Low Homology Dataset, run: 39 | ``` 40 | python3 finetuning.py \ 41 | --fold FOLD \ 42 | --use_LHD \ 43 | --threshold THRESHOLD \ 44 | --gpu GPU \ 45 | --seed SEED 46 | ``` 47 | THRESHOLD is a floating-point number (type float) indicating the sequence similarity threshold used when creating the LHD. 48 | 49 | ## 3. Test 50 | Run `test.py` with the following commands. 51 | 52 | If you would like to test with the HMD-ARG DB, run: 53 | ``` 54 | python3 test.py \ 55 | --fold FOLD \ 56 | --seed SEED 57 | ``` 58 | 59 | Alternatively, if you would like to test with the Low Homology Dataset, run: 60 | ``` 61 | python3 test.py \ 62 | --fold FOLD \ 63 | --use_LHD \ 64 | --threshold THRESHOLD \ 65 | --gpu GPU \ 66 | --seed SEED 67 | ``` 68 | 69 | To obtain the attention values for the input sequences, add the `--get_all_attention` (or `-attn`) flag in both cases. 70 | 71 | ## 4. Attention Analysis 72 | The code for the analysis is available in the Jupyter notebooks in `Analysis_and_Figures/Attention_analysis`. 73 | 74 | `Fig3.ipynb` carries out the experiments described in sections 2.6.1 and 3.2.1 of our paper and outputs Figure 3 and Figure S5. 75 | 76 | `Fig4.ipynb` and `Fig5.ipynb` carry out the experiments described in sections 2.6.2 and 3.2.2; `Fig4.ipynb` outputs Figure 4ac, S6a and S7a, and `Fig5.ipynb` outputs Figure 5. 77 | 78 | ## 5. License 79 | We used ProteinBERT, which is licensed under the MIT license; the copyright notice and permission notice for ProteinBERT are given below.
80 | 81 | ``` 82 | Copyright (C) 2022 Nadav Brandes 83 | 84 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 85 | 86 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 87 | 88 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE X CONSORTIUM BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 89 | 90 | Except as contained in this notice, the name of shall not be used in advertising or otherwise to promote the sale, use or other dealings in this Software without prior written authorization from . 91 | ``` 92 | 93 | ## 6. Citation 94 | If you would like to use ARG-BERT, please cite our paper 95 | ``` 96 | @article{yagimoto2024prediction, 97 | title={Prediction of antibiotic resistance mechanisms using a protein language model}, 98 | author={Yagimoto, Kanami and Hosoda, Shion and Sato, Miwa and Hamada, Michiaki}, 99 | journal={bioRxiv}, 100 | pages={2024--05}, 101 | year={2024}, 102 | publisher={Cold Spring Harbor Laboratory} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /finetuning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import argparse 5 | import pickle 6 | from sklearn.model_selection import train_test_split 7 | 8 | import tensorflow as tf 9 | from tensorflow import keras 10 | from tensorflow.keras import backend as K 11 | 12 | from proteinbert.existing_model_loading import load_pretrained_model 13 | from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs 14 | from proteinbert.shared_utils.util import log 15 | from proteinbert.model_generation import FinetuningModelGenerator 16 | from tokenization import index_to_token, ADDED_TOKENS_PER_SEQ, mechanism_labels, encode_dataset 17 | 18 | class Config_Finetuning: 19 | 20 | def __init__(self, args): 21 | 22 | 23 | self.mechanism_labels = mechanism_labels 24 | self.fold = args.fold 25 | self.use_LHD = args.use_LHD 26 | self.threshold = args.threshold 27 | self.gpu = args.gpu 28 | self.seed = args.seed 29 | 30 | def create_dataset_or_model_path(self, create_dataset_path): 31 | 32 | if self.use_LHD: 33 | sub_path = 'LHD/c%1f/fold_%d_%1f'%(self.fold, self.threshold) 34 | else: 35 | sub_path = 'HMDARG-DB/fold_%d'%self.fold 36 | 37 | if create_dataset_path: 38 | root_dir = 'inputs' 39 | return os.path.join(root_dir, '%s.train.csv' % sub_path) 40 | else: 41 | root_dir = 'outputs/finetuned_model' 42 | return os.path.join(root_dir,'%s.finetuned_model.h5' % sub_path) 43 | 44 | def set_gpu(self): 45 | 46 | if self.gpu != None: 47 | config = tf.compat.v1.ConfigProto( 48 | gpu_options=tf.compat.v1.GPUOptions( 49 | visible_device_list= str(self.gpu), # specify GPU number 50 | 
allow_growth=True, 51 | per_process_gpu_memory_fraction=0.2 52 | ) 53 | ) 54 | sess = tf.compat.v1.Session(config=config) 55 | 56 | 57 | def set_seed(self): 58 | if self.seed != None: 59 | return tf.random.set_seed(self.seed) 60 | 61 | 62 | def finetune(model_generator, input_encoder, config, train, valid = None, seq_len = 512, batch_size = 32, \ 63 | max_epochs_per_stage = 40, lr = None, begin_with_frozen_pretrained_layers = True, lr_with_frozen_pretrained_layers = None, n_final_epochs = 1, \ 64 | final_seq_len = 1024, final_lr = None, callbacks = []): 65 | print(seq_len) 66 | 67 | encoded_train_set, encoded_valid_set = encode_train_and_valid_sets(train, valid, input_encoder, config, seq_len) 68 | 69 | if begin_with_frozen_pretrained_layers: 70 | log('Training with frozen pretrained layers...') 71 | model_generator.train(encoded_train_set, encoded_valid_set, seq_len, batch_size, max_epochs_per_stage, lr = lr_with_frozen_pretrained_layers, \ 72 | callbacks = callbacks, freeze_pretrained_layers = True) 73 | 74 | log('Training the entire fine-tuned model...') 75 | model_generator.train(encoded_train_set, encoded_valid_set, seq_len, batch_size, max_epochs_per_stage, lr = lr, callbacks = callbacks, \ 76 | freeze_pretrained_layers = False) 77 | 78 | if n_final_epochs > 0: 79 | log('Training on final epochs of sequence length %d...' % final_seq_len) 80 | final_batch_size = max(int(batch_size / (final_seq_len / seq_len)), 1) 81 | encoded_train_set, encoded_valid_set = encode_train_and_valid_sets(train, valid, input_encoder, config, final_seq_len) 82 | model_generator.train(encoded_train_set, encoded_valid_set, final_seq_len, final_batch_size, n_final_epochs, lr = final_lr, callbacks = callbacks, \ 83 | freeze_pretrained_layers = False) 84 | 85 | model_generator.optimizer_weights = None 86 | 87 | def make_train_and_valid_sets(path): 88 | 89 | train_set = pd.read_csv(path, index_col = 0)#.dropna()#.drop_duplicates() 90 | train_set, valid_set = train_test_split(train_set, test_size = 0.1, random_state = 1) 91 | 92 | print(f'{len(train_set)} training set records, {len(valid_set)} validation set records') 93 | 94 | return train_set, valid_set 95 | 96 | 97 | def encode_train_and_valid_sets(train, valid, input_encoder, config, seq_len): 98 | 99 | encoded_train_set = encode_dataset(train, input_encoder, config.mechanism_labels, seq_len = seq_len, needs_filtering = True, \ 100 | dataset_name = 'Training set') 101 | 102 | if valid is None: 103 | encoded_valid_set = None 104 | else: 105 | encoded_valid_set = encode_dataset(valid, input_encoder, mechanism_labels, seq_len = seq_len, needs_filtering = True, \ 106 | dataset_name = 'Validation set') 107 | 108 | return encoded_train_set, encoded_valid_set 109 | 110 | 111 | def main(config): 112 | 113 | config.set_gpu() 114 | config.set_seed() 115 | #output_spec = None 116 | 117 | dataset_path = config.create_dataset_or_model_path(create_dataset_path = True) 118 | train_set, valid_set = make_train_and_valid_sets(dataset_path) 119 | 120 | # Loading the pre-trained model and fine-tuning it on the loaded dataset 121 | 122 | pretrained_model_generator, input_encoder = load_pretrained_model(local_model_dump_dir = 'proteinbert/proteinbert_models', local_model_dump_file_name = 'default.pkl') 123 | 124 | # get_model_with_hidden_layers_as_outputs gives the model output access to the hidden layers (on top of the output) 125 | model_generator = FinetuningModelGenerator(pretrained_model_generator, pretraining_model_manipulation_function = \ 126 | 
get_model_with_hidden_layers_as_outputs, dropout_rate = 0.5) 127 | 128 | training_callbacks = [ 129 | keras.callbacks.ReduceLROnPlateau(patience = 1, factor = 0.25, min_lr = 1e-05, verbose = 1), 130 | keras.callbacks.EarlyStopping(patience = 2, restore_best_weights = True), 131 | ] 132 | 133 | finetune(model_generator, input_encoder, config, train_set, valid_set, \ 134 | seq_len = 512, batch_size = 32, max_epochs_per_stage = 40, lr = 1e-04, begin_with_frozen_pretrained_layers = True, \ 135 | lr_with_frozen_pretrained_layers = 1e-02, n_final_epochs = 1, final_seq_len = 1024, final_lr = 1e-05, callbacks = training_callbacks) 136 | 137 | finetuned_model = model_generator.create_model(seq_len = 1578) 138 | 139 | finetuned_model_path = config.create_dataset_or_model_path(create_dataset_path = False) 140 | with open(finetuned_model_path, 'wb') as f: 141 | pickle.dump((finetuned_model.get_weights(), finetuned_model.optimizer.get_weights()), f) 142 | 143 | if __name__ == "__main__": 144 | 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('-f', '--fold', type=int, help='The number of iterations in 5-fold CV.') 147 | parser.add_argument('-LHD', '--use_LHD', action='store_true', help='Whether you use Low Homology Dataset or not.', default=False) 148 | parser.add_argument('-t', '--threshold', type=float, help='Sequence similarity thresholds set when creating LHD.', default=0) 149 | parser.add_argument('-g', '--gpu', type=int, help='Assign the GPU devices you will use.', default=None) 150 | parser.add_argument('-s', '--seed', type=int, help='Set random seed.', default=None) 151 | config = Config_Finetuning(parser.parse_args()) 152 | 153 | main(config) -------------------------------------------------------------------------------- /Analysis_and_Figures/Attention_analysis/Fig5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# GO enrichment analysis" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import pandas as pd\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import requests\n", 21 | "from scipy.stats import fisher_exact" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "mechanism_list = [\n", 31 | "'antibiotic inactivation', \n", 32 | "'antibiotic target alteration',\n", 33 | "'antibiotic efflux', \n", 34 | "'antibiotic target replacement',\n", 35 | "'antibiotic target protection'\n", 36 | "]\n", 37 | "\n", 38 | "path = 'Results/'\n", 39 | "if not os.path.exists(path):\n", 40 | " print(\"make\" + path)\n", 41 | " os.makedirs(path)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## Run static tests." 
49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "def get_lists_of_GO(interpro_with_GO, mechanism):\n", 58 | " interpro_with_GO_per_mechanism = interpro_with_GO[interpro_with_GO['mechanism'] == mechanism].fillna('-')\n", 59 | " GO_list = []\n", 60 | " for GO in interpro_with_GO_per_mechanism['GO']:\n", 61 | " try:\n", 62 | " if GO != '-':\n", 63 | " GO_list += GO.split('|')\n", 64 | " except AttributeError :\n", 65 | " pass\n", 66 | " return_tuple = (GO_list,interpro_with_GO_per_mechanism)\n", 67 | " return return_tuple" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "interpro = pd.read_csv(path + 'result_attention-intensive_regions.csv')\n", 77 | "interpro_unique = interpro.drop('Start', axis=1).drop('End', axis=1).groupby(['ID','Accession'], as_index=False).first()\n", 78 | "interpro_with_GO_all = interpro_unique[(interpro_unique['GO'] != '-') & (interpro['GO'].notna())]\n", 79 | "interpro_with_GO_significant = interpro_with_GO_all[interpro_with_GO_all['Significance']]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "result_all = pd.DataFrame()\n", 89 | "num_of_test = 0\n", 90 | "for mechanism in mechanism_list:\n", 91 | " significant_GO_list,significant_interpro_per_mechanism = get_lists_of_GO(interpro_with_GO_significant, mechanism)\n", 92 | " all_GO_list,all_interpro_per_mechanism = get_lists_of_GO(interpro_with_GO_all,mechanism)\n", 93 | "\n", 94 | " GO_list = list(set(all_GO_list))\n", 95 | " result_dict = {}\n", 96 | " num_significant_regions = len(significant_interpro_per_mechanism)\n", 97 | " num_all_regions = len(interpro_per_mechanism)\n", 98 | " for GO in GO_list:\n", 99 | " \n", 100 | " num_withGO_significant = significant_interpro_per_mechanism['GO'].str.contains(GO).sum()\n", 101 | " num_withGO_ns = all_interpro_per_mechanism['GO'].str.contains(GO).sum() - num_withGO_significant\n", 102 | " num_withoutGO_significant = num_significant_regions - num_withGO_significant\n", 103 | " num_withoutGO_ns = num_all_regions - num_withGO_significant - num_withGO_ns - num_withoutGO_significant\n", 104 | "\n", 105 | " data = np.array([[num_withGO_significant, num_withGO_ns],[num_withoutGO_significant, num_withoutGO_ns]])\n", 106 | " result_dict[GO] = [fisher_exact(data,alternative='greater')[1]] + data.flatten().tolist()\n", 107 | "\n", 108 | " result_df = pd.DataFrame(result_dict, index=['p-value', 'w/ GO and significant','w/ GO and NOT significant','w/o GO and significant','w/o GO and NOT significant']).T\n", 109 | " result_df['mechanism']=mechanism\n", 110 | " result_all = pd.concat([result_all,result_df])\n", 111 | " num_of_test += len(GO_list)\n", 112 | " " 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Get GO terms from QuickGO" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "def get_go_term(go_id):\n", 129 | " api_url = f\"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/{go_id}\"\n", 130 | " response = requests.get(api_url)\n", 131 | " if response.status_code == 200:\n", 132 | " data = response.json()\n", 133 | " go_term = data['results'][0]['name']\n", 134 | " return go_term\n", 135 | " else:\n", 136 | " return None\n", 137 | "\n", 138 | 
"GO_list = list(map(get_go_term,result_all.index))\n", 139 | "result_all['GO term'] = GO_list" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "result_true = result_all[result_all['p-value']<0.05/num_of_test]\n", 149 | "result_false = result_all[result_all['p-value']>=0.05/num_of_test]\n", 150 | "result_all = pd.concat([result_true,result_false])\n", 151 | "result_all.to_csv(path + 'result_GO_analysis.csv')" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "## Visualization:Fig5" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "result_viz = result_true[result_true['w/ GO and significant']>150]\n", 168 | "result_viz['-log10(p-value)'] = -result_viz['p-value'].apply(np.log10)\n", 169 | "result_viz = result_viz.rename(columns ={'mechanism':'Resistance mechanism'}).loc[:,['GO term','Resistance mechanism','-log(p-value)']]" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "sns.set(style=\"whitegrid\") \n", 179 | "plt.figure(figsize=(15, 15))\n", 180 | "mechanism_list_new = mechanism_list[:3] + [mechanism_list[4]]\n", 181 | "sns.barplot(GO_viz['-log(p-value)'],GO_viz['GO term'], hue = GO_viz['Resistance mechanism'], hue_order= mechanism_list_new, palette = 'Set2', dodge = False)\n", 182 | "plt.legend(title=\"Resistance mechanism\", title_fontsize=\"xx-large\",fontsize=\"xx-large\")\n", 183 | "plt.yticks(fontsize=\"xx-large\")\n", 184 | "plt.xticks(fontsize=\"xx-large\")\n", 185 | "plt.xlabel('-log10(p-value)',fontsize=\"xx-large\")\n", 186 | "plt.ylabel('')" 187 | ] 188 | } 189 | ], 190 | "metadata": { 191 | "interpreter": { 192 | "hash": "4e4e8b0f9678671e7c1801361c888ac62514a7043b530e64d08c86727710599e" 193 | }, 194 | "kernelspec": { 195 | "display_name": "Python 3 (ipykernel)", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.8.10" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 4 214 | } 215 | -------------------------------------------------------------------------------- /Analysis_and_Figures/Attention_analysis/Fig3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Calculate amino acid conservation score." 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from Bio import AlignIO\n", 17 | "from Bio import SeqIO\n", 18 | "from Bio.Align.AlignInfo import PSSM\n", 19 | "from Bio.Align.AlignInfo import SummaryInfo\n", 20 | "import pandas as pd\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import seaborn as sns" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "target_dict = {\n", 32 | " 'blaOXA-114s': 'U3N8W9',#antibiotic inactivation ABW87257\n", 33 | " 'rpoB': 'NP_273190.1',# antibiotic target alteration AF:B4RQW2\n", 34 | " 'macB': 'A0A011P660', # antibiotic efflux\n", 35 | " 'tetW': 'ABN80187'# antibiotic target protection\n", 36 | "}\n", 37 | "gene_name = 'tetW'\n", 38 | "gene_id = target_dict[gene_name]\n", 39 | "\n", 40 | "if not os.path.exists(gene_name):\n", 41 | " os.makedirs(path + gene_name)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "fasta_file = gene_name + \"/clustalw.fasta\"\n", 51 | "\n", 52 | "align = AlignIO.read(fasta_file, \"fasta\")\n", 53 | "summary_align = SummaryInfo(align)\n", 54 | "\n", 55 | "for record in align:\n", 56 | " if record.id == gene_name:\n", 57 | " sequence = record.seq\n", 58 | " break\n", 59 | "\n", 60 | "freq = {}\n", 61 | "for aa in set(list(sequence)):\n", 62 | " if aa != '-':\n", 63 | " freq[aa] = 1/20\n", 64 | "\n", 65 | " " 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "pssm = summary_align.pos_specific_score_matrix(axis_seq = sequence, chars_to_ignore = ['-'])\n", 75 | "df_pssm = pd.DataFrame(index=pssm[0].keys(),columns=list(range(len(summary_align.dumb_consensus()))))\n", 76 | "n = 0\n", 77 | "for p in pssm:\n", 78 | " df_pssm[n] = p.values()\n", 79 | " n += 1\n", 80 | "\n", 81 | "from scipy.stats import entropy\n", 82 | "qk = [1/20] * 20\n", 83 | "info = []\n", 84 | "for n in df_pssm.columns:\n", 85 | " info.append(entropy(df_pssm[n], qk=qk, base = 2))\n", 86 | "\n", 87 | "amino_index = []\n", 88 | "for i,a in enumerate(list(sequence)):\n", 89 | " if a != '-':\n", 90 | " amino_index.append(i)\n", 91 | "info_content = []\n", 92 | "for a_i in amino_index:\n", 93 | " info_content.append(info[a_i])" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "path = '/'\n", 103 | "fold_table = pd.read_csv('../Sample_data/fold_table.csv', index_col=0)\n", 104 | "fold = fold_table[fold_table['ID']==gene_id]['fold'].tolist()[0]\n", 105 | "attention = pd.read_csv('attention/fold_'+str(fold)+'_attention.csv', index_col = 0)[gene_id].dropna().iloc[1:-1]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "conserv_score = pd.DataFrame({'Attention': attention,'Conservation score':info_content})\n", 115 | "threshold = conserv_score['Attention'].quantile(q=[0.33,0.66]).tolist()\n", 116 | "threshold" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def annotation(c):\n", 126 | " if c <= threshold[0]:\n", 127 | " return 'Low'\n", 128 | " elif threshold[0] <= c < threshold[1]:\n", 129 | " return 'Medium'\n", 
130 | " elif threshold[1] <= c:\n", 131 | " return 'High'\n", 132 | "conserv_score['Groups'] = list(map(annotation,conserv_score['Attention']))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "conserv_score['Groups'].unique()" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "plt.figure(figsize=(10, 7.5))\n", 151 | "\n", 152 | "\n", 153 | "categories = ['Low', 'Medium', 'High',]\n", 154 | "palette = sns.color_palette('Set2', n_colors=len(categories))\n", 155 | "line_label = ['33 percentile', '66 percentile']\n", 156 | "line_style = ['dashed','dashdot']\n", 157 | "\n", 158 | "sns.scatterplot(x='Attention', y='Conservation score', hue='Groups', data=conserv_score, palette = palette)\n", 159 | "for i,t in enumerate(threshold):\n", 160 | " plt.axvline(x = t, color='red', label = line_label[i], linestyle = line_style[i])\n", 161 | "plt.xlabel('Attention', fontsize = 30)\n", 162 | "plt.ylabel('Conservation score', fontsize = 30)\n", 163 | "plt.xticks(fontsize=18)\n", 164 | "plt.yticks(fontsize=18)\n", 165 | "plt.legend(fontsize = 20, markerscale = 3)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "import seaborn as sns\n", 175 | "import matplotlib.pyplot as plt\n", 176 | "import pandas as pd\n", 177 | "import seaborn as sns\n", 178 | "\n", 179 | "# A few helper functions:\n", 180 | "from statannotations.Annotator import Annotator\n", 181 | "from statannotations.stats.utils import check_alpha\n", 182 | "\n", 183 | "\n", 184 | "import numpy as np\n", 185 | "from scipy.stats import mannwhitneyu\n", 186 | "\n", 187 | "new_order = ['Low','High','Medium']\n", 188 | "new_palette = [palette[categories.index(category)] for category in new_order]\n", 189 | "\n", 190 | "# Putting the parameters in a dictionary avoids code duplication\n", 191 | "# since we use the same for `sns.boxplot` and `Annotator` calls\n", 192 | "plotting_parameters = {\n", 193 | " 'data': conserv_score,\n", 194 | " 'x': 'Groups',\n", 195 | " 'y': 'Conservation score',\n", 196 | " 'order': categories,\n", 197 | " #'color': color\n", 198 | " 'palette': new_palette,\n", 199 | "}\n", 200 | "\n", 201 | "pairs = [('Low', 'Medium'),\n", 202 | " ('Medium', 'High'),\n", 203 | " ('Low', 'High')]\n", 204 | "\n", 205 | "\n", 206 | "plt.rcParams[\"font.size\"] = 20\n", 207 | "fig = plt.figure(figsize=(10, 7.5))\n", 208 | "ax = fig.add_subplot(1, 1, 1)\n", 209 | "\n", 210 | "with sns.plotting_context('notebook', font_scale=1.4):\n", 211 | "\n", 212 | " # Plot with seaborn\n", 213 | " sns.violinplot(ax = ax, **plotting_parameters)\n", 214 | "\n", 215 | " # Add annotations\n", 216 | " annotator = Annotator(ax, pairs, **plotting_parameters)\n", 217 | " annotator.configure(test='Mann-Whitney', comparisons_correction=\"bonferroni\")\n", 218 | " _, corrected_results = annotator.apply_and_annotate()\n", 219 | " \n", 220 | "ax.set_xlabel(\"\")" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "interpreter": { 226 | "hash": "4e4e8b0f9678671e7c1801361c888ac62514a7043b530e64d08c86727710599e" 227 | }, 228 | "kernelspec": { 229 | "display_name": "Python 3 (ipykernel)", 230 | "language": "python", 231 | "name": "python3" 232 | }, 233 | "language_info": { 234 | "codemirror_mode": { 235 | "name": "ipython", 236 | "version": 3 237 | }, 238 | "file_extension": 
".py", 239 | "mimetype": "text/x-python", 240 | "name": "python", 241 | "nbconvert_exporter": "python", 242 | "pygments_lexer": "ipython3", 243 | "version": "3.8.10" 244 | } 245 | }, 246 | "nbformat": 4, 247 | "nbformat_minor": 4 248 | } 249 | -------------------------------------------------------------------------------- /proteinbert/conv_and_global_attention_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | import tensorflow.keras.backend as K 6 | 7 | class GlobalAttention(keras.layers.Layer): 8 | 9 | ''' 10 | Recevies two inputs: 11 | 1. A global representation (of some fixed dimension) 12 | 2. A sequence (of any length, and some fixed dimension) 13 | The global representation is used to construct a global query that attends to all the positions in the sequence (independently 14 | for any of the heads). 15 | ''' 16 | 17 | def __init__(self, n_heads, d_key, d_value, **kwargs): 18 | self.n_heads = n_heads 19 | self.d_key = d_key 20 | self.sqrt_d_key = np.sqrt(self.d_key) 21 | self.d_value = d_value 22 | self.d_output = n_heads * d_value 23 | super(GlobalAttention, self).__init__(**kwargs) 24 | 25 | def compute_output_shape(self, input_shapes): 26 | # input_shapes: (batch_size, d_global_input), (batch_size, length, d_seq_input) 27 | (batch_size, _), _ = input_shapes 28 | return (batch_size, self.d_output) 29 | 30 | def build(self, input_shapes): 31 | # input_shapes: (batch_size, d_global_input), (batch_size, length, d_seq_input) 32 | (_, self.d_global_input), (_, _, self.d_seq_input) = input_shapes 33 | # Wq: (n_heads, d_global_input, d_key) 34 | self.Wq = self.add_weight(name = 'Wq', shape = (self.n_heads, self.d_global_input, self.d_key), \ 35 | initializer = 'glorot_uniform', trainable = True) 36 | # Wk: (n_heads, d_seq_input, d_key) 37 | self.Wk = self.add_weight(name = 'Wk', shape = (self.n_heads, self.d_seq_input, self.d_key), \ 38 | initializer = 'glorot_uniform', trainable = True) 39 | # Wv: (n_heads, d_seq_input, d_value) 40 | self.Wv = self.add_weight(name = 'Wv', shape = (self.n_heads, self.d_seq_input, self.d_value), \ 41 | initializer = 'glorot_uniform', trainable = True) 42 | super(GlobalAttention, self).build(input_shapes) 43 | 44 | def call(self, inputs): 45 | 46 | # X: (batch_size, d_global_input) 47 | # S: (batch_size, length, d_seq_input) 48 | X, S = inputs 49 | _, length, _ = K.int_shape(S) 50 | 51 | # (batch_size, n_heads, length, d_value) 52 | VS = K.permute_dimensions(keras.activations.gelu(K.dot(S, self.Wv)), (0, 2, 1, 3)) 53 | # (batch_size * n_heads, length, d_value) 54 | VS_batched_heads = K.reshape(VS, (-1, length, self.d_value)) 55 | 56 | Z_batched_heads = self.calculate_attention(inputs) 57 | # (batch_size * n_heads, d_value) 58 | Y_batched_heads = K.batch_dot(Z_batched_heads, VS_batched_heads) 59 | # (batch_size, n_heads * d_value) 60 | Y = K.reshape(Y_batched_heads, (-1, self.d_output)) 61 | 62 | return Y 63 | 64 | def calculate_attention(self, inputs): 65 | 66 | # X: (batch_size, d_global_input) 67 | # S: (batch_size, length, d_seq_input) 68 | X, S = inputs 69 | _, length, _ = K.int_shape(S) 70 | 71 | # (batch_size, n_heads, d_key) 72 | QX = K.tanh(K.dot(X, self.Wq)) 73 | # (batch_size * n_heads, d_key) 74 | QX_batched_heads = K.reshape(QX, (-1, self.d_key)) 75 | 76 | # (batch_size, n_heads, d_key, length) 77 | KS = K.permute_dimensions(K.tanh(K.dot(S, self.Wk)), (0, 2, 3, 1)) 78 | # (batch_size * n_heads, d_key, length) 79 | 
KS_batched_heads = K.reshape(KS, (-1, self.d_key, length)) 80 | 81 | # (batch_size * n_heads, length) 82 | Z_batched_heads = K.softmax(K.batch_dot(QX_batched_heads, KS_batched_heads) / self.sqrt_d_key) 83 | return Z_batched_heads 84 | 85 | def create_model(seq_len, vocab_size, n_annotations, d_hidden_seq = 128, d_hidden_global = 512, n_blocks = 6, n_heads = 4, \ 86 | d_key = 64, conv_kernel_size = 9, wide_conv_dilation_rate = 5, activation = 'gelu'): 87 | 88 | ''' 89 | seq_len is required to create the model, but all the weights are independent of the length and can be re-used with 90 | different lengths. 91 | ''' 92 | 93 | assert d_hidden_global % n_heads == 0 94 | d_value = d_hidden_global // n_heads 95 | 96 | input_seq = keras.layers.Input(shape = (seq_len,), dtype = np.int32, name = 'input-seq') 97 | input_annoatations = keras.layers.Input(shape = (n_annotations,), dtype = np.float32, name = 'input-annotations') 98 | 99 | hidden_seq = keras.layers.Embedding(vocab_size, d_hidden_seq, name = 'embedding-seq-input')(input_seq) 100 | hidden_global = keras.layers.Dense(d_hidden_global, activation = activation, name = 'dense-global-input')(input_annoatations) 101 | 102 | for block_index in range(1, n_blocks + 1): 103 | 104 | seqed_global = keras.layers.Dense(d_hidden_seq, activation = activation, name = 'global-to-seq-dense-block%d' % block_index)(hidden_global) 105 | seqed_global = keras.layers.Reshape((1, d_hidden_seq), name = 'global-to-seq-reshape-block%d' % block_index)(seqed_global) 106 | 107 | narrow_conv_seq = keras.layers.Conv1D(filters = d_hidden_seq, kernel_size = conv_kernel_size, strides = 1, \ 108 | padding = 'same', dilation_rate = 1, activation = activation, name = 'narrow-conv-block%d' % block_index)(hidden_seq) 109 | wide_conv_seq = keras.layers.Conv1D(filters = d_hidden_seq, kernel_size = conv_kernel_size, strides = 1, \ 110 | padding = 'same', dilation_rate = wide_conv_dilation_rate, activation = activation, name = 'wide-conv-block%d' % \ 111 | block_index)(hidden_seq) 112 | 113 | hidden_seq = keras.layers.Add(name = 'seq-merge1-block%d' % block_index)([hidden_seq, seqed_global, narrow_conv_seq, wide_conv_seq]) 114 | hidden_seq = keras.layers.LayerNormalization(name = 'seq-merge1-norm-block%d' % block_index)(hidden_seq) 115 | 116 | dense_seq = keras.layers.Dense(d_hidden_seq, activation = activation, name = 'seq-dense-block%d' % block_index)(hidden_seq) 117 | hidden_seq = keras.layers.Add(name = 'seq-merge2-block%d' % block_index)([hidden_seq, dense_seq]) 118 | hidden_seq = keras.layers.LayerNormalization(name = 'seq-merge2-norm-block%d' % block_index)(hidden_seq) 119 | 120 | dense_global = keras.layers.Dense(d_hidden_global, activation = activation, name = 'global-dense1-block%d' % block_index)(hidden_global) 121 | attention = GlobalAttention(n_heads, d_key, d_value, name = 'global-attention-block%d' % block_index)([hidden_global, hidden_seq]) 122 | hidden_global = keras.layers.Add(name = 'global-merge1-block%d' % block_index)([hidden_global, dense_global, attention]) 123 | hidden_global = keras.layers.LayerNormalization(name = 'global-merge1-norm-block%d' % block_index)(hidden_global) 124 | 125 | dense_global = keras.layers.Dense(d_hidden_global, activation = activation, name = 'global-dense2-block%d' % block_index)(hidden_global) 126 | hidden_global = keras.layers.Add(name = 'global-merge2-block%d' % block_index)([hidden_global, dense_global]) 127 | hidden_global = keras.layers.LayerNormalization(name = 'global-merge2-norm-block%d' % block_index)(hidden_global) 128 | 
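# Descriptive note (added comment): after the final block, two output heads follow - a per-position softmax over the token vocabulary (sequence output) and a per-annotation sigmoid vector (global annotation output).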
129 | output_seq = keras.layers.Dense(vocab_size, activation = 'softmax', name = 'output-seq')(hidden_seq) 130 | output_annotations = keras.layers.Dense(n_annotations, activation = 'sigmoid', name = 'output-annotations')(hidden_global) 131 | 132 | return keras.models.Model(inputs = [input_seq, input_annoatations], outputs = [output_seq, output_annotations]) 133 | 134 | def get_model_with_hidden_layers_as_outputs(model): 135 | 136 | _, seq_len, _ = model.outputs[0].shape 137 | 138 | seq_layers = [layer.output for layer in model.layers if len(layer.output.shape) == 3 and \ 139 | tuple(layer.output.shape)[:2] == (None, seq_len) and (layer.name in ['input-seq-encoding', 'dense-seq-input', 'output-seq'] or \ 140 | isinstance(layer, keras.layers.LayerNormalization))] 141 | global_layers = [layer.output for layer in model.layers if len(layer.output.shape) == 2 and (layer.name in ['input_annoatations', \ 142 | 'dense-global-input', 'output-annotations'] or isinstance(layer, keras.layers.LayerNormalization))] 143 | 144 | concatenated_seq_output = keras.layers.Concatenate(name = 'all-seq-layers')(seq_layers) 145 | concatenated_global_output = keras.layers.Concatenate(name = 'all-global-layers')(global_layers) 146 | 147 | return keras.models.Model(inputs = model.inputs, outputs = [concatenated_seq_output, concatenated_global_output]) 148 | -------------------------------------------------------------------------------- /proteinbert/model_generation.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | import pickle 3 | 4 | import numpy as np 5 | 6 | from tensorflow import keras 7 | 8 | from proteinbert.shared_utils.util import log 9 | from tokenization import additional_token_to_index, n_tokens, tokenize_seq 10 | #from _new_finetuning import mechanism_labels 11 | 12 | 13 | class ModelGenerator: 14 | 15 | def __init__(self, optimizer_class = keras.optimizers.legacy.Adam, lr = 2e-04, other_optimizer_kwargs = {}, model_weights = None, optimizer_weights = None): 16 | self.optimizer_class = optimizer_class 17 | self.lr = lr 18 | self.other_optimizer_kwargs = other_optimizer_kwargs 19 | self.model_weights = model_weights 20 | self.optimizer_weights = optimizer_weights 21 | 22 | def train(self, encoded_train_set, encoded_valid_set, seq_len, batch_size, n_epochs, lr = None, callbacks = [], **create_model_kwargs): 23 | 24 | train_X, train_Y, train_sample_weigths = encoded_train_set 25 | self.dummy_epoch = (_slice_arrays(train_X, slice(0, 1)), _slice_arrays(train_Y, slice(0, 1))) 26 | model = self.create_model(seq_len, **create_model_kwargs) 27 | 28 | if lr is not None: 29 | model.optimizer.lr = lr 30 | 31 | model.fit(train_X, train_Y, sample_weight = train_sample_weigths, batch_size = batch_size, epochs = n_epochs, validation_data = encoded_valid_set, \ 32 | callbacks = callbacks) 33 | self.update_state(model) 34 | 35 | def update_state(self, model): 36 | self.model_weights = copy_weights(model.get_weights()) 37 | self.optimizer_weights = copy_weights(model.optimizer.get_weights()) 38 | 39 | def _init_weights(self, model): 40 | 41 | #if self.optimizer_weights is not None: 42 | # For some reason keras requires this strange little hack in order to properly initialize a new model's optimizer, so that 43 | # the optimizer's weights can be reloaded from an existing state. 
44 | #self._train_for_a_dummy_epoch(model) 45 | 46 | if self.model_weights is not None: 47 | 48 | model.set_weights(copy_weights(self.model_weights)) 49 | 50 | if self.optimizer_weights is not None: 51 | if len(self.optimizer_weights) == len(model.optimizer.get_weights()): 52 | model.optimizer.set_weights(copy_weights(self.optimizer_weights)) 53 | else: 54 | log('Incompatible number of optimizer weights - will not initialize them.') 55 | 56 | def _train_for_a_dummy_epoch(self, model): 57 | X, Y = self.dummy_epoch 58 | model.fit(X, Y, batch_size = 1, verbose = 0) 59 | 60 | class PretrainingModelGenerator(ModelGenerator): 61 | 62 | def __init__(self, create_model_function, n_annotations, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, other_optimizer_kwargs = {}, \ 63 | annots_loss_weight = 1, model_weights = None, optimizer_weights = None): 64 | 65 | ModelGenerator.__init__(self, optimizer_class = optimizer_class, lr = lr, other_optimizer_kwargs = other_optimizer_kwargs, model_weights = model_weights, \ 66 | optimizer_weights = optimizer_weights) 67 | 68 | self.create_model_function = create_model_function 69 | self.n_annotations = n_annotations 70 | self.create_model_kwargs = create_model_kwargs 71 | self.annots_loss_weight = annots_loss_weight 72 | 73 | def create_model(self, seq_len, compile = True, init_weights = True): 74 | 75 | clear_session() 76 | model = self.create_model_function(seq_len, n_tokens, self.n_annotations, **self.create_model_kwargs) 77 | 78 | if compile: 79 | model.compile(optimizer = self.optimizer_class(learning_rate = self.lr, **self.other_optimizer_kwargs), loss = ['sparse_categorical_crossentropy', 'binary_crossentropy'], \ 80 | loss_weights = [1, self.annots_loss_weight]) 81 | 82 | if init_weights: 83 | self._init_weights(model) 84 | 85 | return model 86 | 87 | class FinetuningModelGenerator(ModelGenerator): 88 | 89 | def __init__(self, pretraining_model_generator, pretraining_model_manipulation_function = None, dropout_rate = 0.5, optimizer_class = None, \ 90 | lr = None, other_optimizer_kwargs = None, model_weights = None, optimizer_weights = None): 91 | 92 | if other_optimizer_kwargs is None: 93 | if optimizer_class is None: 94 | other_optimizer_kwargs = pretraining_model_generator.other_optimizer_kwargs 95 | else: 96 | other_optimizer_kwargs = {} 97 | 98 | if optimizer_class is None: 99 | optimizer_class = pretraining_model_generator.optimizer_class 100 | 101 | if lr is None: 102 | lr = pretraining_model_generator.lr 103 | 104 | ModelGenerator.__init__(self, optimizer_class = optimizer_class, lr = lr, other_optimizer_kwargs = other_optimizer_kwargs, model_weights = model_weights, \ 105 | optimizer_weights = optimizer_weights) 106 | 107 | self.pretraining_model_generator = pretraining_model_generator 108 | self.pretraining_model_manipulation_function = pretraining_model_manipulation_function 109 | self.dropout_rate = dropout_rate 110 | 111 | def create_model(self, seq_len, freeze_pretrained_layers = False): 112 | 113 | model = self.pretraining_model_generator.create_model(seq_len, compile = False, init_weights = (self.model_weights is None)) 114 | 115 | if self.pretraining_model_manipulation_function is not None: 116 | model = self.pretraining_model_manipulation_function(model) 117 | 118 | if freeze_pretrained_layers: 119 | for layer in model.layers: 120 | layer.trainable = False 121 | 122 | model_inputs = model.input 123 | pretraining_output_seq_layer, pretraining_output_annotations_layer = model.output 124 | last_hidden_layer = 
pretraining_output_annotations_layer 125 | last_hidden_layer = keras.layers.Dropout(self.dropout_rate)(last_hidden_layer) 126 | 127 | output_layer = keras.layers.Dense(6, activation = 'softmax')(last_hidden_layer) # 6 output classes, matching the resistance-mechanism labels used for fine-tuning 128 | loss = 'sparse_categorical_crossentropy' 129 | model = keras.models.Model(inputs = model_inputs, outputs = output_layer) 130 | model.compile(loss = loss, optimizer = self.optimizer_class(learning_rate = self.lr, **self.other_optimizer_kwargs)) 131 | 132 | self._init_weights(model) 133 | 134 | return model 135 | 136 | class InputEncoder: 137 | 138 | def __init__(self, n_annotations): 139 | self.n_annotations = n_annotations 140 | 141 | def encode_X(self, seqs, seq_len): 142 | return [ 143 | tokenize_seqs(seqs, seq_len), 144 | np.zeros((len(seqs), self.n_annotations), dtype = np.int8) 145 | ] 146 | 147 | def load_pretrained_model_from_dump(dump_file_path, create_model_function, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, other_optimizer_kwargs = {}, annots_loss_weight = 1, load_optimizer_weights = False): 148 | 149 | with open(dump_file_path, 'rb') as f: 150 | n_annotations, model_weights, optimizer_weights = pickle.load(f) 151 | 152 | if not load_optimizer_weights: 153 | optimizer_weights = None 154 | 155 | model_generator = PretrainingModelGenerator(create_model_function, n_annotations, create_model_kwargs = create_model_kwargs, optimizer_class = optimizer_class, lr = lr, \ 156 | other_optimizer_kwargs = other_optimizer_kwargs, annots_loss_weight = annots_loss_weight, model_weights = model_weights, optimizer_weights = optimizer_weights) 157 | input_encoder = InputEncoder(n_annotations) 158 | 159 | return model_generator, input_encoder 160 | 161 | def tokenize_seqs(seqs, seq_len): 162 | # Note that tokenize_seq already adds <START> and <END> tokens. 163 | return np.array([seq_tokens + (seq_len - len(seq_tokens)) * [additional_token_to_index['<PAD>']] for seq_tokens in map(tokenize_seq, seqs)], dtype = np.int32) 164 | 165 | def clear_session(): 166 | import tensorflow.keras.backend as K 167 | K.clear_session() 168 | 169 | def copy_weights(weights): 170 | return [_copy_number_or_array(w) for w in weights] 171 | 172 | def _copy_number_or_array(variable): 173 | if isinstance(variable, np.ndarray): 174 | return variable.copy() 175 | elif isinstance(variable, Number): 176 | return variable 177 | else: 178 | raise TypeError('Unexpected type %s' % type(variable)) 179 | 180 | def _slice_arrays(arrays, slicing): 181 | if isinstance(arrays, list) or isinstance(arrays, tuple): 182 | return [array[slicing] for array in arrays] 183 | else: 184 | return arrays[slicing] 185 | -------------------------------------------------------------------------------- /Practice.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "696a404a-500e-490a-b3ec-5659169ba2ba", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "2024-05-03 09:49:46.623822: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", 14 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 15 | "2024-05-03 09:49:46.735325: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. 
You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 16 | "2024-05-03 09:49:46.768236: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 17 | "2024-05-03 09:49:47.267324: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", 18 | "2024-05-03 09:49:47.267383: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", 19 | "2024-05-03 09:49:47.267392: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", 20 | "17338 training set records, 1927 validation set records\n", 21 | "proteinbert/proteinbert_models/default.pkl\n", 22 | "512\n", 23 | "[2024_05_03-09:49:47] Training set: Filtered out 1649 of 17338 (9.5%) records of lengths exceeding 510.\n", 24 | "^C\n", 25 | "Traceback (most recent call last):\n", 26 | " File \"finetuning.py\", line 155, in \n", 27 | " main(config)\n", 28 | " File \"finetuning.py\", line 135, in main\n", 29 | " finetune(model_generator, input_encoder, config, train_set, valid_set, \\\n", 30 | " File \"finetuning.py\", line 69, in finetune\n", 31 | " encoded_train_set, encoded_valid_set = encode_train_and_valid_sets(train, valid, input_encoder, config, seq_len)\n", 32 | " File \"finetuning.py\", line 101, in encode_train_and_valid_sets\n", 33 | " encoded_train_set = encode_dataset(train, input_encoder, config.mechanism_labels, seq_len = seq_len, needs_filtering = True, \\\n", 34 | " File \"/home/ARG-BERT/tokenization.py\", line 52, in encode_dataset\n", 35 | " X = input_encoder.encode_X(seqs, seq_len)\n", 36 | " File \"/home/ARG-BERT/proteinbert/model_generation.py\", line 143, in encode_X\n", 37 | " tokenize_seqs(seqs, seq_len),\n", 38 | " File \"/home/ARG-BERT/proteinbert/model_generation.py\", line 163, in tokenize_seqs\n", 39 | " return np.array([seq_tokens + (seq_len - len(seq_tokens)) * [additional_token_to_index['']] for seq_tokens in map(tokenize_seq, seqs)], dtype = np.int32)\n", 40 | " File \"/home/ARG-BERT/proteinbert/model_generation.py\", line 163, in \n", 41 | " return np.array([seq_tokens + (seq_len - len(seq_tokens)) * [additional_token_to_index['']] for seq_tokens in map(tokenize_seq, seqs)], dtype = np.int32)\n", 42 | " File \"/home/ARG-BERT/tokenization.py\", line 31, in tokenize_seq\n", 43 | " return [additional_token_to_index['']] + [aa_to_token_index.get(aa, other_token_index) for aa in parse_seq(seq)] + \\\n", 44 | " File \"/home/ARG-BERT/tokenization.py\", line 31, in \n", 45 | " return [additional_token_to_index['']] + [aa_to_token_index.get(aa, other_token_index) for aa in parse_seq(seq)] + \\\n", 46 | "KeyboardInterrupt\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "! 
python3 finetuning.py \\\n", 52 | "--fold 1 \\\n", 53 | "--seed 4" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "id": "82c57535-18f4-45bb-9291-ca088dced246", 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "2024-05-03 10:04:23.425031: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", 67 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 68 | "2024-05-03 10:04:23.535188: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 69 | "2024-05-03 10:04:23.569050: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 70 | "2024-05-03 10:04:24.101725: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", 71 | "2024-05-03 10:04:24.101788: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", 72 | "2024-05-03 10:04:24.101795: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", 73 | "proteinbert/proteinbert_models/default.pkl\n", 74 | "2024-05-03 10:04:25.107630: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2006] Ignoring visible gpu device (device: 4, name: NVIDIA T400 4GB, pci bus id: 0000:e3:00.0, compute capability: 7.5) with core count: 6. The minimum required count is 8. 
You can adjust this requirement with the env var TF_MIN_GPU_MULTIPROCESSOR_COUNT.\n", 75 | "2024-05-03 10:04:25.107951: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", 76 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 77 | "2024-05-03 10:04:26.747667: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46694 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:31:00.0, compute capability: 8.6\n", 78 | "2024-05-03 10:04:26.748401: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 46694 MB memory: -> device: 1, name: NVIDIA RTX A6000, pci bus id: 0000:4b:00.0, compute capability: 8.6\n", 79 | "2024-05-03 10:04:26.749021: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 46694 MB memory: -> device: 2, name: NVIDIA RTX A6000, pci bus id: 0000:b1:00.0, compute capability: 8.6\n", 80 | "2024-05-03 10:04:26.749659: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 46694 MB memory: -> device: 3, name: NVIDIA RTX A6000, pci bus id: 0000:ca:00.0, compute capability: 8.6\n", 81 | "2024-05-03 10:04:30.950281: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.\n", 82 | "2024-05-03 10:04:31.159559: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8400\n", 83 | "136/136 [==============================] - 7s 26ms/step\n", 84 | "26/26 [==============================] - 2s 25ms/step\n", 85 | "9/9 [==============================] - 2s 25ms/step\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "! 
python3 test.py \\\n", 91 | "--fold 1 \\\n", 92 | "--seed 4" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "7132f4fc-71cc-450d-96c7-17f4cf8b2b7f", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [] 102 | } 103 | ], 104 | "metadata": { 105 | "kernelspec": { 106 | "display_name": "Python 3 (ipykernel)", 107 | "language": "python", 108 | "name": "python3" 109 | }, 110 | "language_info": { 111 | "codemirror_mode": { 112 | "name": "ipython", 113 | "version": 3 114 | }, 115 | "file_extension": ".py", 116 | "mimetype": "text/x-python", 117 | "name": "python", 118 | "nbconvert_exporter": "python", 119 | "pygments_lexer": "ipython3", 120 | "version": "3.8.10" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 5 125 | } 126 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | import os 5 | import pickle 6 | 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | 10 | from proteinbert.shared_utils.util import log 11 | from proteinbert.existing_model_loading import load_pretrained_model 12 | from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs 13 | from proteinbert.model_generation import InputEncoder,FinetuningModelGenerator 14 | from tokenization import ADDED_TOKENS_PER_SEQ, encode_dataset, split_dataset_by_len 15 | 16 | class Config_Test: 17 | 18 | def __init__(self, args): 19 | 20 | self.mechanism_labels = { 21 | 22 | 'antibiotic target alteration':0, 23 | 'antibiotic target replacement':1, 24 | 'antibiotic target protection':2, 25 | 'antibiotic inactivation':3, 26 | 'antibiotic efflux':4, 27 | 'others':5 28 | 29 | } 30 | self.fold = args.fold 31 | self.use_LHD = args.use_LHD 32 | self.threshold = args.threshold 33 | self.seed = args.seed 34 | self.get_all_attention = args.get_all_attention 35 | 36 | def create_input_path(self, create_dataset_path): 37 | 38 | if self.use_LHD: 39 | sub_path = 'LHD/c%1f/fold_%d_%1f'%(self.fold, self.threshold) 40 | else: 41 | sub_path = 'HMDARG-DB/fold_%d'%self.fold 42 | 43 | if create_dataset_path: 44 | root_dir = 'inputs' 45 | return os.path.join(root_dir, '%s.test.csv' % sub_path) 46 | else: 47 | root_dir = 'outputs/finetuned_model' 48 | return os.path.join(root_dir,'%s.finetuned_model.h5' % sub_path) 49 | 50 | def create_output_path(self, create_dataset_path): 51 | 52 | if self.use_LHD: 53 | sub_path = 'LHD/c%1f/fold_%d_%1f'%(self.fold, self.threshold) 54 | else: 55 | sub_path = 'HMDARG-DB/fold_%d'%self.fold 56 | 57 | if create_dataset_path: 58 | root_dir = 'outputs/Prediction results' 59 | if not os.path.exists(root_dir): 60 | os.makedirs('outputs/Prediction results/HMDARG-DB') 61 | return os.path.join(root_dir, '%s.test.csv' % sub_path) 62 | else: 63 | root_dir = 'outputs/attention' 64 | if not os.path.exists(root_dir): 65 | os.makedirs(root_dir) 66 | return os.path.join(root_dir, '%s.attn.csv' % sub_path) 67 | 68 | def n_mechanism_labels(self): 69 | return len(self.mechanism_labels) 70 | 71 | def set_seed(self): 72 | if self.seed != None: 73 | return tf.random.set_seed(self.seed) 74 | 75 | def evaluate_by_len(model_generator, input_encoder, config, df, start_seq_len = 512, start_batch_size = 32, increase_factor = 2):#output_spec, 76 | 77 | # assert model_generator.optimizer_weights is None 78 | 79 | results = [] 80 | results_names = [] 81 | y_trues = [] 82 
| y_preds = [] 83 | index_list = [] 84 | #inverse_UNIQUE_LABELS = {v:k for k,v in output_spec.unique_labels.items()} 85 | index_to_label = {} 86 | 87 | for len_matching_dataset, seq_len, batch_size, index in split_dataset_by_len(df, start_seq_len = start_seq_len, start_batch_size = start_batch_size, \ 88 | increase_factor = increase_factor): 89 | 90 | X, y_true, sample_weights = encode_dataset(len_matching_dataset, input_encoder, config.mechanism_labels, \ 91 | seq_len = seq_len, needs_filtering = False)#output_spec, 92 | 93 | assert set(np.unique(sample_weights)) <= {0.0, 1.0} 94 | y_mask = (sample_weights == 1) 95 | 96 | model = model_generator.create_model(seq_len) 97 | y_pred = model.predict(X, batch_size = batch_size) 98 | 99 | y_true = y_true[y_mask].flatten() 100 | y_pred = y_pred[y_mask].reshape((-1, y_pred.shape[-1])) 101 | #y_pred = y_pred.reshape((-1, y_pred.shape[-1])) 102 | y_trues.append(y_true) 103 | y_preds.append(y_pred) 104 | index_list += index 105 | 106 | y_true = np.concatenate(y_trues, axis = 0) 107 | y_pred = np.concatenate(y_preds, axis = 0) 108 | prediction, confusion_matrix = get_evaluation_results(y_true, y_pred, config, return_confusion_matrix = True) 109 | 110 | for i in list(range(len(df))): 111 | index_to_label[index_list[i]] = prediction[i] 112 | df = pd.concat([df,pd.DataFrame.from_dict(index_to_label, orient='index').sort_index()],axis = 1) 113 | 114 | return df, confusion_matrix 115 | 116 | def get_evaluation_results(y_true, y_pred, config, return_confusion_matrix = False):#output_spec, 117 | 118 | from scipy.stats import spearmanr 119 | from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix 120 | 121 | str_unique_labels = list(map(str, config.mechanism_labels)) 122 | y_pred_classes = y_pred.argmax(axis = -1) 123 | confusion_matrix = pd.DataFrame(confusion_matrix(y_true, y_pred_classes, labels = np.arange(config.n_mechanism_labels())), index = str_unique_labels, \ 124 | columns = str_unique_labels) 125 | 126 | if return_confusion_matrix: 127 | return y_pred_classes, confusion_matrix 128 | else: 129 | return results 130 | 131 | def calculate_attentions(model, input_encoder, seq, seq_len = None): 132 | 133 | from tensorflow.keras import backend as K 134 | from tokenization import index_to_token 135 | 136 | if seq_len is None: 137 | seq_len = len(seq) + 2 138 | 139 | X = input_encoder.encode_X([seq], seq_len) 140 | (X_seq,), _ = X 141 | seq_tokens = list(map(index_to_token.get, X_seq)) 142 | 143 | model_inputs = [layer.input for layer in model.layers if 'InputLayer' in str(type(layer))][::-1] 144 | model_attentions = [layer.calculate_attention(layer.input) for layer in model.layers if 'GlobalAttention' in str(type(layer))] 145 | invoke_model_attentions = K.function(model_inputs, model_attentions) 146 | attention_values = invoke_model_attentions(X) 147 | 148 | attention_labels = [] 149 | merged_attention_values = [] 150 | 151 | for attention_layer_index, attention_layer_values in enumerate(attention_values): 152 | for head_index, head_values in enumerate(attention_layer_values): 153 | attention_labels.append('Attention %d - head %d' % (attention_layer_index + 1, head_index + 1)) 154 | merged_attention_values.append(head_values) 155 | 156 | attention_values = np.array(merged_attention_values) 157 | 158 | return attention_values, seq_tokens, attention_labels 159 | 160 | def main(config): 161 | 162 | # Load an input dataset 163 | test_set_file_path = config.create_input_path(create_dataset_path = True) 164 | test_set = 
pd.read_csv(test_set_file_path, index_col = 0)#.dropna().drop_duplicates() 165 | 166 | # Load a fine-tuned model 167 | finetuned_model_path = config.create_input_path(create_dataset_path = False) 168 | with open(finetuned_model_path, 'rb') as f: 169 | model_weights, optimizer_weights = pickle.load(f) 170 | 171 | # Generate the fine-tuned model 172 | pretrained_model_generator, input_encoder = load_pretrained_model(local_model_dump_dir = 'proteinbert/proteinbert_models', local_model_dump_file_name = 'default.pkl') 173 | finetuned_model = FinetuningModelGenerator(pretrained_model_generator, pretraining_model_manipulation_function = \ 174 | get_model_with_hidden_layers_as_outputs, model_weights = model_weights, optimizer_weights = optimizer_weights, dropout_rate = 0.5) 175 | 176 | df, confusion_matrix = evaluate_by_len(finetuned_model, input_encoder, config, test_set, start_seq_len = 512, start_batch_size = 32) 177 | df = df.replace({v: k for k, v in config.mechanism_labels.items()}) 178 | df.to_csv(config.create_output_path(create_dataset_path = True)) 179 | 180 | if config.get_all_attention: 181 | attention_fold = {} 182 | for i in test_set.index: 183 | ID = test_set.iloc[i,0] 184 | seq = test_set.iloc[i,-1] 185 | seq_len = len(seq) + 2 186 | 187 | pretrained_model_generator, input_encoder = load_pretrained_model(local_model_dump_dir = 'proteinbert/proteinbert_models', local_model_dump_file_name = 'default.pkl') 188 | pretrained_model = pretrained_model_generator.create_model(seq_len) 189 | pretrained_attention_values, pretrained_seq_tokens, pretrained_attention_labels = calculate_attentions(pretrained_model, input_encoder, seq, \ 190 | seq_len = seq_len) 191 | 192 | #finetuned_model = keras.models.load_model(config.create_input_path(create_dataset_path = False)) 193 | finetuned_created_model = finetuned_model.create_model(seq_len) 194 | finetuned_attention_values, finetuned_seq_tokens, finetuned_attention_labels = calculate_attentions(finetuned_created_model, input_encoder, seq,\ 195 | seq_len = seq_len) 196 | 197 | attention = finetuned_attention_values - pretrained_attention_values 198 | attention_fold[ID] = attention[-4:,:].mean(axis = 0) 199 | 200 | attention_fold = pd.DataFrame.from_dict(attention_fold, orient='index').T 201 | attention_fold.to_csv(config.create_output_path(create_dataset_path = False)) 202 | 203 | if __name__ == "__main__": 204 | 205 | parser = argparse.ArgumentParser() 206 | parser.add_argument('-f', '--fold', type=int, help='The number of iterations in 5-fold CV.') 207 | parser.add_argument('-LHD', '--use_LHD', action='store_true', help='Whether you use Low Homology Dataset or not.') 208 | parser.add_argument('-t', '--threshold', type=float, help='Sequence similarity thresholds set when creating LHD.', default=0) 209 | parser.add_argument('-s', '--seed', type=int, help='Set random seed.', default=None) 210 | parser.add_argument('-attn', '--get_all_attention', action='store_true', help='Whether you need attention or not.', default=False) 211 | config = Config_Test(parser.parse_args()) 212 | 213 | main(config) 214 | -------------------------------------------------------------------------------- /Analysis_and_Figures/Attention_analysis/Fig4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "from scipy import stats\n", 12 | "import 
os\n", 13 | "\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import seaborn as sns" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "path = 'Results/'\n", 25 | "if not os.path.exists(path):\n", 26 | " print(\"make\" + path)\n", 27 | " os.makedirs(path)\n", 28 | "fold_table = pd.read_csv('../Sample_data/fold_table.csv')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Run statical tests." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "\"\"\"\n", 45 | "def make_dataset(fold,fold_table):\n", 46 | " header = ['ID', 'Length', 'Database', 'Accession', 'Signature description', 'Start', 'End', 'Date', 'InterPro accession', 'InterPro description', 'GO']\n", 47 | " col_names = list(range(14))\n", 48 | " interpro = pd.read_csv('../OUTPUT_DIR/fold_'+str(fold)+'_interpro.tsv', index_col = None, sep = '\\t', header= None, names=col_names)\n", 49 | " interpro = interpro.iloc[:,[0,2,3,4,5,6,7,10,11,12,13]]\n", 50 | " interpro.columns = header\n", 51 | " interpro['ID'] = interpro['ID'].apply(lambda x: x.split('|')[0])\n", 52 | " interpro = pd.merge(interpro, fold_table[['ID', 'mechanism']], on='ID', how='left')\n", 53 | " attention = pd.read_csv('../attention/fold_'+str(fold)+'_attention.csv', index_col=0)\n", 54 | " return interpro,attention\n", 55 | "\"\"\"\n", 56 | "\n", 57 | "def run_U_test_for_a_sequence(accession,attention_of_target_sequence,interpro_of_target_sequence):\n", 58 | " attention_of_target_region = []\n", 59 | " interpro_of_target_region = interpro_of_target_sequence[interpro_of_target_sequence['Accession'] == accession]\n", 60 | " for i in interpro_of_target_region.index:\n", 61 | " start = interpro_of_target_region.loc[i,'Start']\n", 62 | " end = interpro_of_target_region.loc[i,'End']\n", 63 | " attention_of_target_region += attention_of_target_sequence[start-1:end]\n", 64 | " stats_U, p = stats.mannwhitneyu(attention_of_target_region, attention_of_target_sequence, True, 'greater')\n", 65 | " return p\n", 66 | "\n", 67 | "def run_U_test_for_a_dataset(interpro,attention):\n", 68 | " result = pd.DataFrame()\n", 69 | " for ID in interpro['ID'].unique():\n", 70 | " #try:\n", 71 | " interpro_of_target_sequence = interpro[interpro['ID'] == ID]\n", 72 | " accessions_of_target_region = interpro_of_target_sequence['Accession'].tolist()\n", 73 | "\n", 74 | " attention_of_target_sequence = attention[ID].dropna().tolist()[:-1]\n", 75 | "\n", 76 | " p_value = pd.DataFrame({accession: run_U_test_for_a_sequence(accession, attention_of_target_sequence,interpro_of_target_sequence) for accession in accessions_of_target_region}, index =['U p-value']).T.reset_index().rename(columns = {'index':'Accession'})\n", 77 | " p_value['Significance'] = (p_value['U p-value']<(0.05/len(interpro_of_target_sequence))).tolist()\n", 78 | " interpro_of_target_sequence = pd.merge(interpro_of_target_sequence, p_value, on='Accession', how='left')\n", 79 | " result = pd.concat([result, interpro_of_target_sequence], axis=0)\n", 80 | " #except:\n", 81 | " # print(ID)\n", 82 | " return result" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "result_all = pd.DataFrame()\n", 92 | "\n", 93 | "for fold in range(5):\n", 94 | " interpro_per_fold = pd.read_csv(path + 
'Interpro/fold_'+str(fold)+'.interpro.csv', index_col = 0)\n", 95 | " attention = pd.read_csv(path + 'attention/fold_'+str(fold)+'.attention.csv', index_col =0)\n", 96 | " result_per_fold = run_U_test_for_a_dataset(interpro_per_fold,attention)\n", 97 | " result_all = pd.concat([result_all, result_per_fold])\n", 98 | "result_all.to_csv(path + 'result_attention-intensive_regions.csv')" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "# Visualization" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "### Decide the target sequence." 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "## Decide from its AMR Family\n", 122 | "\n", 123 | "familyname_to_id = {\n", 124 | " 'blaOXA-114s': 'U3N8W9',\n", 125 | " 'rpoB': 'NP_273190.1',\n", 126 | " 'macB': 'A0A011P660',\n", 127 | " 'tetW': 'ABN80187',\n", 128 | "}\n", 129 | "\n", 130 | "familyname = 'tetW'\n", 131 | "target_id = familyname_to_id[familyname]\n", 132 | "os.mkdir(path + familyname)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "## Decide from its resistance mechanism\n", 142 | "\n", 143 | "mechanism_to_id = {\n", 144 | " 'antibiotic inactivation': 'U3N8W9',\n", 145 | " 'antibiotic target alteration': 'NP_273190.1',\n", 146 | " 'antibiotic efflux': 'A0A011P660',\n", 147 | " 'antibiotic target protection': 'ABN80187',\n", 148 | "}\n", 149 | "\n", 150 | "mechanism = 'antibiotic target protection'\n", 151 | "target_id = mechanism_to_id[mechanism]\n" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "fold = fold_table[fold_table['ID']==target_id]['fold'].tolist()[0]\n", 161 | "print(fold, mechanism)\n", 162 | "\n", 163 | "interpro_1domain = pd.read_csv(path + 'result_attention-intensive_regions.csv', index_col=0)\n", 164 | "interpro_1domain = interpro_1domain[interpro_1domain['ID'] == target_id]\n", 165 | "interpro_1domain[interpro_1domain['Significance']].to_csv(path + familyname + '/' + familyname +'.csv')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "### Corresponding Attention-intensive areas and their positions." 
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "length = interpro_1domain['Length'].unique()[0]\n", 182 | "accession_list = interpro_1domain['Accession'].unique()\n", 183 | "position = pd.DataFrame(columns = accession_list, index = list(range(1,length+1)))\n", 184 | "\n", 185 | "for accession in accession_list:\n", 186 | " # print(accession)\n", 187 | " interpro_accession = interpro_1domain[interpro_1domain['Accession'] == accession]\n", 188 | " for l in range(len(interpro_accession)):\n", 189 | " start = interpro_accession.iloc[l,5]\n", 190 | " end = interpro_accession.iloc[l,6]\n", 191 | " # print(start,end)\n", 192 | " if interpro_accession.iloc[0,-1]:\n", 193 | " position.loc[start:end,accession] = -1\n", 194 | " else:\n", 195 | " position.loc[start:end,accession] = 1\n", 196 | "\n", 197 | "position.to_csv(path + familyname + '/position.csv')\n", 198 | "position = pd.read_csv(path + familyname + '/position.csv',index_col=0).T" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "### Visualization" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "attention = pd.read_csv('attention/fold_'+str(fold)+'.attention.csv', index_col = 0)\n", 215 | "attention_of_target_sequence = attention[target_id].dropna().iloc[1:-1]\n", 216 | "\n", 217 | "## Focus only on the Attention-intensive regions.\n", 218 | "position = position[position.sum(axis = 1)<0]" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "sns.set()\n", 228 | "fig, ax = plt.subplots(2, 1, figsize=(15, 10), sharex = True)\n", 229 | "\n", 230 | "sns.heatmap(position,cmap='bwr',ax = ax[0], cbar= False, yticklabels=False, xticklabels=True)\n", 231 | "ax[0].set_title('Attention-intensive regions', fontsize = 30)\n", 232 | "ax[0].tick_params(labelsize=30)\n", 233 | "ax[0].set_yticks([y + 0.5 for y in list(range(len(position.index)))])\n", 234 | "ax[0].set_yticklabels(position.index.tolist())\n", 235 | "\n", 236 | "\n", 237 | "attention_of_target_sequence.rolling(5, center=True).apply(lambda x: x.mean()).plot(fontsize = 10, legend=False)\n", 238 | "ax[1].set_title('Attention', fontsize = 30)\n", 239 | "ax[1].tick_params(labelsize=30)\n", 240 | "ax[1].set_xticks(range(0,len(position.T),100))\n", 241 | "ax[1].set_xticklabels(list(range(0,len(position.T),100)),rotation = 45)\n", 242 | "fig.savefig(path + familyname + '/' + familyname +'.png')" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [] 251 | } 252 | ], 253 | "metadata": { 254 | "interpreter": { 255 | "hash": "4e4e8b0f9678671e7c1801361c888ac62514a7043b530e64d08c86727710599e" 256 | }, 257 | "kernelspec": { 258 | "display_name": "Python 3 (ipykernel)", 259 | "language": "python", 260 | "name": "python3" 261 | }, 262 | "language_info": { 263 | "codemirror_mode": { 264 | "name": "ipython", 265 | "version": 3 266 | }, 267 | "file_extension": ".py", 268 | "mimetype": "text/x-python", 269 | "name": "python", 270 | "nbconvert_exporter": "python", 271 | "pygments_lexer": "ipython3", 272 | "version": "3.8.10" 273 | } 274 | }, 275 | "nbformat": 4, 276 | "nbformat_minor": 4 277 | } 278 | 
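The U test in the Fig4.ipynb cells above asks, for each InterPro-annotated region, whether the per-residue attention inside that region is higher than the attention over the whole sequence (one-sided Mann-Whitney U, judged against a Bonferroni-style cutoff of 0.05 divided by the number of InterPro rows for that sequence). Below is a minimal, self-contained sketch of that single-region comparison; the attention values, region boundaries and variable names are made up for illustration and are not part of the repository.

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
attention_of_target_sequence = rng.random(200).tolist()  # toy per-residue attention values
start, end = 50, 80                                      # toy 1-based region boundaries (InterPro 'Start'/'End')

# Attention restricted to the annotated region (1-based, end-inclusive slicing, as in the notebook).
attention_of_target_region = attention_of_target_sequence[start - 1:end]

# One-sided test: is attention inside the region stochastically greater than over the full sequence?
_, p = stats.mannwhitneyu(attention_of_target_region, attention_of_target_sequence, alternative = 'greater')
print('U test p-value: %.3g' % p)  # compare against a corrected threshold such as 0.05 / number_of_regions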
-------------------------------------------------------------------------------- /Analysis_and_Figures/Fig2_and_Suppl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import seaborn as sns\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "mechanism_list = ['antibiotic inactivation',\n", 14 | " 'antibiotic target alteration',\n", 15 | " 'antibiotic efflux',\n", 16 | " 'antibiotic target replacement',\n", 17 | " 'antibiotic target protection',\n", 18 | " 'others']" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# Fig.2" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "result_acc = pd.DataFrame()\n", 35 | "result_pre = pd.DataFrame()\n", 36 | "result_rec = pd.DataFrame()\n", 37 | "result_f1 = pd.DataFrame()\n", 38 | "\n", 39 | "for method in ['Proposed', 'LM-ARG', 'BLAST']:\n", 40 | " result_method = pd.read_csv('Prediction results/Raw_data/LHD_all_'+method+'.csv', index_col=0)\n", 41 | " result_acc[method] = result_method['Accuracy']\n", 42 | " result_pre[method] = result_method['Precision']\n", 43 | " result_rec[method] = result_method['Recall']\n", 44 | " result_f1[method] = result_method['F1 Score']" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "result_acc['threshold'] = result_method['threshold']\n", 54 | "result_acc['fold'] = result_method['fold']\n", 55 | "result_pre['threshold'] = result_method['threshold']\n", 56 | "result_pre['fold'] = result_method['fold']\n", 57 | "result_rec['threshold'] = result_method['threshold']\n", 58 | "result_rec['fold'] = result_method['fold']\n", 59 | "result_f1['threshold'] = result_method['threshold']\n", 60 | "result_f1['fold'] = result_method['fold']" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "result_for_heatmap_acc = result_acc[result_acc['fold'] == 'Average'].set_index('threshold').loc[:,['Proposed', 'LM-ARG', 'BLAST']]\n", 70 | "result_for_heatmap_pre = result_pre[result_pre['fold'] == 'Average'].set_index('threshold').loc[:,['Proposed', 'LM-ARG', 'BLAST']]\n", 71 | "result_for_heatmap_rec = result_rec[result_rec['fold'] == 'Average'].set_index('threshold').loc[:,['Proposed', 'LM-ARG', 'BLAST']]\n", 72 | "result_for_heatmap_f1 = result_f1[result_f1['fold'] == 'Average'].set_index('threshold').loc[:,['Proposed', 'LM-ARG', 'BLAST']]" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "category_order = ['F1 Score', 'Recall', 'Precision', 'Accuracy'][::-1]\n", 82 | "ylabel_heatmap = ['threshold','','','']\n", 83 | "\n", 84 | "fig_heat,ax_heat = plt.subplots(1, len(category_order), figsize=(18, 10), sharey=True)\n", 85 | "plt.rcParams[\"font.size\"] = 15\n", 86 | "cmap = plt.get_cmap(\"Set2\")\n", 87 | "sns.set(style=\"whitegrid\") \n", 88 | "sns.heatmap(result_for_heatmap_acc, vmax=1, vmin=0.70, annot=True, fmt='.3f', ax = ax_heat[0],annot_kws={\"size\": 20},cbar=False, cmap='pink_r')\n", 89 | "sns.heatmap(result_for_heatmap_pre, vmax=1, vmin=0.70, annot=True, fmt='.3f', ax = ax_heat[1],annot_kws={\"size\": 
20},cbar=False, cmap='pink_r')\n", 90 | "sns.heatmap(result_for_heatmap_rec, vmax=1, vmin=0.70, annot=True, fmt='.3f', ax = ax_heat[2],annot_kws={\"size\": 20},cbar=False, cmap='pink_r')\n", 91 | "sns.heatmap(result_for_heatmap_f1, vmax=1, vmin=0.70, annot=True, fmt='.3f', ax = ax_heat[3],annot_kws={\"size\": 20}, cbar_kws={ \"location\":\"right\"}, cmap='pink_r')\n", 92 | "\n", 93 | "for i in range(len(category_order)):\n", 94 | " indicate = category_order[i]\n", 95 | " \n", 96 | " ax_heat[i].set_xlabel(indicate, fontsize=20)\n", 97 | " ax_heat[i].set_ylabel(ylabel_heatmap[i], fontsize = 20)\n", 98 | " ax_heat[i].tick_params(axis='x', labelrotation=45, labelsize=20)\n", 99 | " ax_heat[i].tick_params(axis='y', labelsize=20)\n", 100 | " \n", 101 | "fig_heat.tight_layout()" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "# Fig.S1" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "fold_table = pd.read_csv('../Sample_data/fold_table.csv',index_col=0)\n", 118 | "mechanism_count = pd.DataFrame(fold_table['mechanism'].value_counts()).rename(columns = {'mechanism':'Resistance mechanism'})\n", 119 | "\n", 120 | "hmdargdb = pd.read_csv('../Sample_data/input.csv',index_col=0) ## Change the path to your data.\n", 121 | "length_list = []\n", 122 | "for m in mechanism_list:\n", 123 | " length_list.append(hmdargdb[hmdargdb['mechanism']==m]['Length'].tolist())" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "fig, ax = plt.subplots(2, 1, figsize=(15, 15))\n", 133 | "\n", 134 | "sns.barplot(mechanism_count['Resistance mechanism'],mechanism_count.index, palette = 'Set2', ax = ax[0])\n", 135 | "for i in range(len(mechanism_count.index)):\n", 136 | " ax[0].text(mechanism_count.iloc[i,0], i, mechanism_count.iloc[i,0],fontsize = 20)\n", 137 | "\n", 138 | "\n", 139 | "colors = [cmap(i) for i in range(len(set(mechanism_list)))]\n", 140 | "ax[1].hist(length_list, histtype='barstacked', label=mechanism_list, color = colors)\n", 141 | "ax[1].legend(title=\"Resistance mechanism\")\n", 142 | "plt.xlabel('Sequence length')\n", 143 | "plt.ylabel('# of sequences')" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "# Fig S2-4" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "category = 'mechanism' # mechanism or threshold\n", 160 | "dataset = 'hmdargdb' # LHD_0.4 or hmdargdb" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "fold_table = pd.read_csv('../Sample_data/fold_table.csv', index_col = 0)\n", 170 | "\n", 171 | "mechanism_count = fold_table['mechanism'].value_counts()\n", 172 | "\n", 173 | "c_list = [0.4,0.6,0.7,0.8,0.9]\n", 174 | "\n", 175 | "mechanism_count_dict = {}\n", 176 | "for m in mechanism_list[:-1]:\n", 177 | " mechanism_count_dict[m] = m+'\\n('+str(mechanism_count.loc[m])+')'" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "def make_dataset_for_heatmap(method,dataset,category):\n", 187 | " result_per_mechanism = pd.read_csv('Prediction results/'+category + '/'+dataset+'_' + method +'.csv',index_col=0)\n", 
188 | " if category == 'mechanism':\n", 189 | " result_for_heatmap = result_per_mechanism[result_per_mechanism['fold'] == 'Average'].iloc[:-1,1:].set_index(category).reindex(index=mechanism_list)[:-1].rename(index = mechanism_count_dict)\n", 190 | " else:\n", 191 | " result_for_heatmap = result_per_mechanism[result_per_mechanism['fold'] == 'Average'].iloc[:,1:].set_index(category)\n", 192 | " return result_for_heatmap" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "result = pd.read_csv('Prediction results/Raw_data/'+dataset+'_' + method +'.csv', index_col = 0).reset_index(drop = 'True')\n", 202 | "df = pd.DataFrame()\n", 203 | "\n", 204 | "category_order = ['F1 Score', 'Recall', 'Precision', 'Accuracy'][::-1]\n", 205 | "\n", 206 | "if category == 'mechanism':\n", 207 | " ylabel_heatmap = ['Resistance mechanism(# of sequence)','','','','']\n", 208 | "else:\n", 209 | " ylabel_heatmap = [category,'','','']\n", 210 | "\n", 211 | "if dataset == 'hmdargdb':\n", 212 | " method_list = ['Proposed', 'LM-ARG', 'HMD-ARG', 'BLAST', 'CARD-RGI']\n", 213 | "else:\n", 214 | " method_list = ['Proposed', 'LM-ARG', 'BLAST']\n", 215 | "\n", 216 | "fig, axis = plt.subplots(1, len(method_list), figsize=(15, 10), sharey=True)\n", 217 | "fig_heat,ax_heat = plt.subplots(1, len(method_list), figsize=(18, 10), sharey=True)\n", 218 | "plt.rcParams[\"font.size\"] = 15\n", 219 | "cmap = plt.get_cmap(\"Set2\")\n", 220 | "sns.set(style=\"whitegrid\") \n", 221 | "\n", 222 | "for i in range(len(method_list)):\n", 223 | " methods = method_list[i]\n", 224 | " \n", 225 | " sns.stripplot(x=\"metrics\", y=\"Value\", data=result[result['method'] == methods], jitter=True, alpha=0.7, ax=axis[i], order=category_order, size=10,color='black')\n", 226 | " \n", 227 | " category_means_proposed = result[result['method'] == methods].groupby('metrics')[['Value']].mean().reindex(category_order)\n", 228 | " category_means_proposed.plot(kind='bar', alpha=1, ax=axis[i], color=cmap(i))#color_list[i])\n", 229 | " df[methods] = category_means_proposed['Value']\n", 230 | " \n", 231 | " axis[i].set_xlabel(methods, fontsize=20)\n", 232 | " axis[i].set_ylabel('',fontsize=20)\n", 233 | " axis[i].set_ylim(0.6, 1)\n", 234 | " axis[i].tick_params(axis='x', labelrotation=45, labelsize=20)\n", 235 | " axis[i].tick_params(axis='y', labelsize=20)\n", 236 | " axis[i].legend('')\n", 237 | " \n", 238 | " sns.heatmap(make_dataset_for_heatmap(methods,dataset,category), vmax=1, vmin=0.70, annot=True, fmt='.3f', ax = ax_heat[i],annot_kws={\"size\": 13}, cbar_kws={ \"location\":\"top\"}, cmap='pink_r')\n", 239 | " ax_heat[i].set_xlabel(methods, fontsize=20)\n", 240 | " ax_heat[i].set_ylabel(ylabel_heatmap[i], fontsize = 20)\n", 241 | " ax_heat[i].tick_params(axis='x', labelrotation=45, labelsize=20)\n", 242 | " ax_heat[i].tick_params(axis='y', labelsize=13)\n", 243 | " \n", 244 | "df\n", 245 | "fig.tight_layout() \n", 246 | "fig_heat.tight_layout()\n", 247 | "plt.show()\n" 248 | ] 249 | } 250 | ], 251 | "metadata": { 252 | "interpreter": { 253 | "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" 254 | }, 255 | "kernelspec": { 256 | "display_name": "Python 3 (ipykernel)", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | 
"nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.8.10" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 4 275 | } 276 | -------------------------------------------------------------------------------- /proteinbert/uniref_dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | import gzip 3 | import json 4 | from collections import Counter 5 | import sqlite3 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import h5py 10 | from lxml import etree 11 | from pyfaidx import Faidx 12 | 13 | from .shared_utils.util import log, to_chunks 14 | 15 | class UnirefToSqliteParser: 16 | 17 | def __init__(self, uniref_xml_gz_file_path, go_annotations_meta, sqlite_file_path, verbose = True, log_progress_every = 1000, \ 18 | chunk_size = 100000): 19 | 20 | self.uniref_xml_gz_file_path = uniref_xml_gz_file_path 21 | self.go_annotations_meta = go_annotations_meta 22 | self.sqlite_conn = sqlite3.connect(sqlite_file_path) 23 | self.verbose = verbose 24 | self.log_progress_every = log_progress_every 25 | self.chunk_size = chunk_size 26 | 27 | self._process_go_annotations_meta() 28 | 29 | self.go_index_record_counter = Counter() 30 | self.unrecognized_go_annotations = Counter() 31 | self.n_records_with_any_go_annotation = 0 32 | 33 | self._chunk_indices = [] 34 | self._chunk_records = [] 35 | 36 | def parse(self): 37 | 38 | with gzip.open(self.uniref_xml_gz_file_path, 'rb') as f: 39 | context = etree.iterparse(f, tag = UnirefToSqliteParser._NAMESPACE_PREFIX + 'entry', events = ('end',)) 40 | _etree_fast_iter(context, self._process_entry) 41 | 42 | if len(self._chunk_records) > 0: 43 | self._save_current_chunk() 44 | 45 | if self.verbose: 46 | log('Ignored the following unrecognized GO annotations: %s' % self.unrecognized_go_annotations) 47 | log('Parsed %d records with any GO annotation.' 
% self.n_records_with_any_go_annotation) 48 | 49 | go_id_record_counter = pd.Series(self.go_index_record_counter) 50 | go_id_record_counter.index = [self.go_index_to_id[index] for index in go_id_record_counter.index] 51 | 52 | self.go_annotations_meta['count'] = go_id_record_counter.reindex(self.go_annotations_meta.index).fillna(0) 53 | self.go_annotations_meta['freq'] = self.go_annotations_meta['count'] / self.n_records_with_any_go_annotation 54 | 55 | if self.verbose: 56 | log('Done.') 57 | 58 | def _process_go_annotations_meta(self): 59 | self.go_annotation_to_all_ancestors = self.go_annotations_meta['all_ancestors'].to_dict() 60 | self.go_id_to_index = self.go_annotations_meta['index'].to_dict() 61 | self.go_index_to_id = self.go_annotations_meta.reset_index().set_index('index')['id'].to_dict() 62 | 63 | def _process_entry(self, i, event, entry): 64 | 65 | if self.verbose and i % self.log_progress_every == 0: 66 | log(i, end = '\r') 67 | 68 | repr_member, = entry.xpath(r'uniprot:representativeMember', namespaces = UnirefToSqliteParser._NAMESPACES) 69 | db_ref, = repr_member.xpath(r'uniprot:dbReference', namespaces = UnirefToSqliteParser._NAMESPACES) 70 | protein_name = db_ref.attrib['id'] 71 | 72 | try: 73 | taxonomy_element, = db_ref.xpath(r'uniprot:property[@type="NCBI taxonomy"]', namespaces = UnirefToSqliteParser._NAMESPACES) 74 | tax_id = int(taxonomy_element.attrib['value']) 75 | except: 76 | tax_id = np.nan 77 | 78 | extracted_go_annotations = {category: UnirefToSqliteParser._extract_go_category(entry, category) for category in UnirefToSqliteParser._GO_ANNOTATION_CATEGORIES} 79 | 80 | self._chunk_indices.append(i) 81 | self._chunk_records.append((tax_id, protein_name, extracted_go_annotations)) 82 | 83 | if len(self._chunk_records) >= self.chunk_size: 84 | self._save_current_chunk() 85 | 86 | def _save_current_chunk(self): 87 | 88 | chunk_records_df = pd.DataFrame(self._chunk_records, columns = ['tax_id', 'uniprot_name', 'go_annotations'], index = self._chunk_indices) 89 | 90 | chunk_records_df['flat_go_annotations'] = chunk_records_df['go_annotations'].apply(\ 91 | lambda go_annotations: list(sorted(set.union(*map(set, go_annotations.values()))))) 92 | chunk_records_df['n_go_annotations'] = chunk_records_df['flat_go_annotations'].apply(len) 93 | chunk_records_df['complete_go_annotation_indices'] = chunk_records_df['flat_go_annotations'].apply(self._get_complete_go_annotation_indices) 94 | chunk_records_df['n_complete_go_annotations'] = chunk_records_df['complete_go_annotation_indices'].apply(len) 95 | self.n_records_with_any_go_annotation += (chunk_records_df['n_complete_go_annotations'] > 0).sum() 96 | 97 | for complete_go_annotation_indices in chunk_records_df['complete_go_annotation_indices']: 98 | self.go_index_record_counter.update(complete_go_annotation_indices) 99 | 100 | chunk_records_df['go_annotations'] = chunk_records_df['go_annotations'].apply(json.dumps) 101 | chunk_records_df['flat_go_annotations'] = chunk_records_df['flat_go_annotations'].apply(json.dumps) 102 | chunk_records_df['complete_go_annotation_indices'] = chunk_records_df['complete_go_annotation_indices'].apply(json.dumps) 103 | chunk_records_df.to_sql('protein_annotations', self.sqlite_conn, if_exists = 'append') 104 | 105 | self._chunk_indices = [] 106 | self._chunk_records = [] 107 | 108 | def _get_complete_go_annotation_indices(self, go_annotations): 109 | complete_go_annotations = self._get_complete_go_annotations(go_annotations) 110 | return list(sorted(filter(None, map(self.go_id_to_index.get, 
go_annotations)))) 111 | 112 | def _get_complete_go_annotations(self, go_annotations): 113 | return set.union(set(), *[self._get_go_annotation_all_ancestors(annotation) for annotation in go_annotations]) 114 | 115 | def _get_go_annotation_all_ancestors(self, annotation): 116 | if annotation in self.go_annotation_to_all_ancestors: 117 | return self.go_annotation_to_all_ancestors[annotation] 118 | else: 119 | self.unrecognized_go_annotations[annotation] += 1 120 | return set() 121 | 122 | @staticmethod 123 | def _extract_go_category(entry, category): 124 | return list({property_element.attrib['value'] for property_element in entry.xpath(r'uniprot:property[@type="%s"]' % \ 125 | category, namespaces = UnirefToSqliteParser._NAMESPACES)}) 126 | 127 | _NAMESPACE_PREFIX = r'{http://uniprot.org/uniref}' 128 | _NAMESPACES = {'uniprot': r'http://uniprot.org/uniref'} 129 | 130 | _GO_ANNOTATION_CATEGORIES = [ 131 | 'GO Molecular Function', 132 | 'GO Biological Process', 133 | 'GO Cellular Component', 134 | ] 135 | 136 | def parse_go_annotations_meta(meta_file_path): 137 | 138 | ALL_FIELDS = ['id', 'name', 'namespace', 'def', 'is_a', 'synonym', 'alt_id', 'subset', 'is_obsolete', 'xref', \ 139 | 'relationship', 'intersection_of', 'disjoint_from', 'consider', 'comment', 'replaced_by', 'created_by', \ 140 | 'creation_date', 'property_value'] 141 | LIST_FIELDS = {'synonym', 'alt_id', 'subset', 'is_a', 'xref', 'relationship', 'disjoint_from', 'intersection_of', \ 142 | 'consider', 'property_value'} 143 | 144 | GO_ANNOTATION_PATTERN = re.compile(r'\[Term\]\n((?:\w+\: .*\n?)+)') 145 | FIELD_LINE_PATTERN = re.compile(r'(\w+)\: (.*)') 146 | 147 | with open(meta_file_path, 'r') as f: 148 | raw_go_meta = f.read() 149 | 150 | go_annotations_meta = [] 151 | 152 | for match in GO_ANNOTATION_PATTERN.finditer(raw_go_meta): 153 | 154 | raw_go_annotation = match.group(1) 155 | go_annotation = {field: [] for field in LIST_FIELDS} 156 | 157 | for line in raw_go_annotation.splitlines(): 158 | 159 | (field, value), = FIELD_LINE_PATTERN.findall(line) 160 | assert field in ALL_FIELDS 161 | 162 | if field in LIST_FIELDS: 163 | go_annotation[field].append(value) 164 | else: 165 | assert field not in go_annotation 166 | go_annotation[field] = value 167 | 168 | go_annotations_meta.append(go_annotation) 169 | 170 | go_annotations_meta = pd.DataFrame(go_annotations_meta, columns = ALL_FIELDS) 171 | go_annotations_meta['is_obsolete'] = go_annotations_meta['is_obsolete'].fillna(False) 172 | assert go_annotations_meta['id'].is_unique 173 | go_annotations_meta.set_index('id', drop = True, inplace = True) 174 | go_annotations_meta.insert(0, 'index', np.arange(len(go_annotations_meta))) 175 | _add_children_and_parents_to_go_annotations_meta(go_annotations_meta) 176 | 177 | return go_annotations_meta 178 | 179 | def create_h5_dataset(protein_annotations_sqlite_db_file_path, fasta_file_path, go_annotations_meta_csv_file_path, output_h5_file_path, shuffle = True, \ 180 | min_records_to_keep_annotation = 100, records_limit = None, save_chunk_size = 10000, verbose = True, log_progress_every = 10000): 181 | 182 | go_annotations_meta = pd.read_csv(go_annotations_meta_csv_file_path, usecols = ['id', 'index', 'count'], index_col = 0) 183 | annotation_counts = go_annotations_meta['count'] 184 | common_annotation_ids = np.array(sorted(annotation_counts[annotation_counts >= min_records_to_keep_annotation].index)) 185 | original_annotation_index_to_common_annotation_index = {go_annotations_meta.loc[annotation_id, 'index']: i for i, annotation_id in 
enumerate(common_annotation_ids)} 186 | 187 | if verbose: 188 | log('Will encode the %d most common annotations.' % len(common_annotation_ids)) 189 | 190 | n_seqs = sum(1 for _ in load_seqs_and_annotations(protein_annotations_sqlite_db_file_path, fasta_file_path, shuffle = False, records_limit = records_limit, \ 191 | verbose = verbose, log_progress_every = log_progress_every)) 192 | 193 | if verbose: 194 | log('Will create an h5 dataset of %d final sequences.' % n_seqs) 195 | 196 | with h5py.File(output_h5_file_path, 'w') as h5f: 197 | 198 | h5f.create_dataset('included_annotations', data = [annotation.encode('ascii') for annotation in common_annotation_ids], dtype = h5py.string_dtype()) 199 | uniprot_ids = h5f.create_dataset('uniprot_ids', shape = (n_seqs,), dtype = h5py.string_dtype()) 200 | seqs = h5f.create_dataset('seqs', shape = (n_seqs,), dtype = h5py.string_dtype()) 201 | seq_lengths = h5f.create_dataset('seq_lengths', shape = (n_seqs,), dtype = np.int32) 202 | annotation_masks = h5f.create_dataset('annotation_masks', shape = (n_seqs, len(common_annotation_ids)), dtype = bool) 203 | 204 | start_index = 0 205 | 206 | for seqs_and_annotations_chunk in to_chunks(load_seqs_and_annotations(protein_annotations_sqlite_db_file_path, fasta_file_path, shuffle = shuffle, \ 207 | records_limit = records_limit, verbose = verbose, log_progress_every = log_progress_every), save_chunk_size): 208 | 209 | end_index = start_index + len(seqs_and_annotations_chunk) 210 | uniprot_id_chunk, seq_chunk, annotation_indices_chunk = map(list, zip(*seqs_and_annotations_chunk)) 211 | 212 | uniprot_ids[start_index:end_index] = uniprot_id_chunk 213 | seqs[start_index:end_index] = seq_chunk 214 | seq_lengths[start_index:end_index] = list(map(len, seq_chunk)) 215 | annotation_masks[start_index:end_index, :] = _encode_annotations_as_a_binary_matrix(annotation_indices_chunk, original_annotation_index_to_common_annotation_index) 216 | 217 | start_index = end_index 218 | 219 | if verbose: 220 | log('Done.') 221 | 222 | def load_seqs_and_annotations(protein_annotations_sqlite_db_file_path, fasta_file_path, shuffle = True, records_limit = None, verbose = True, \ 223 | log_progress_every = 10000): 224 | 225 | if verbose: 226 | log('Loading %s records...' % ('all' if records_limit is None else records_limit)) 227 | 228 | conn = sqlite3.connect(protein_annotations_sqlite_db_file_path) 229 | raw_proteins_and_annotations = pd.read_sql_query('SELECT uniprot_name, complete_go_annotation_indices FROM protein_annotations' + ('' if records_limit is None else \ 230 | (' LIMIT %d' % records_limit)), conn) 231 | 232 | if verbose: 233 | log('Loaded %d proteins and their GO annotations (%d columns: %s)' % (raw_proteins_and_annotations.shape + (', '.join(raw_proteins_and_annotations.columns),))) 234 | 235 | if shuffle: 236 | raw_proteins_and_annotations = raw_proteins_and_annotations.sample(frac = 1, random_state = 0) 237 | 238 | if verbose: 239 | log('Loading Faidx (%s)...' 
% fasta_file_path) 240 | 241 | seqs_faidx = Faidx(fasta_file_path) 242 | 243 | if verbose: 244 | log('Finished loading Faidx.') 245 | 246 | n_failed = 0 247 | 248 | for i, (_, (uniprot_id, raw_go_annotation_indices)) in enumerate(raw_proteins_and_annotations.iterrows()): 249 | 250 | if verbose and i % log_progress_every == 0: 251 | log('%d/%d' % (i, len(raw_proteins_and_annotations)), end = '\r') 252 | 253 | seq_fasta_id = 'UniRef90_%s' % uniprot_id.split('_')[0] 254 | 255 | try: 256 | seq = str(seqs_faidx.fetch(seq_fasta_id, 1, seqs_faidx.index[seq_fasta_id].rlen)) 257 | yield uniprot_id, seq, json.loads(raw_go_annotation_indices) 258 | except KeyError: 259 | n_failed += 1 260 | 261 | if verbose: 262 | log('Finished. Failed finding the sequence for %d of %d records.' % (n_failed, len(raw_proteins_and_annotations))) 263 | 264 | def _add_children_and_parents_to_go_annotations_meta(go_annotations_meta): 265 | 266 | go_annotations_meta['direct_children'] = [set() for _ in range(len(go_annotations_meta))] 267 | go_annotations_meta['direct_parents'] = [set() for _ in range(len(go_annotations_meta))] 268 | 269 | for go_id, go_annotation in go_annotations_meta.iterrows(): 270 | for raw_is_a in go_annotation['is_a']: 271 | parent_id, parent_name = raw_is_a.split(' ! ') 272 | parent_go_annotation = go_annotations_meta.loc[parent_id] 273 | assert parent_go_annotation['name'] == parent_name 274 | go_annotation['direct_parents'].add(parent_id) 275 | parent_go_annotation['direct_children'].add(go_id) 276 | 277 | go_annotations_meta['all_ancestors'] = pd.Series(_get_index_to_all_ancestors(\ 278 | go_annotations_meta['direct_children'].to_dict(), \ 279 | go_annotations_meta[~go_annotations_meta['direct_parents'].apply(bool)].index)) 280 | go_annotations_meta['all_offspring'] = pd.Series(_get_index_to_all_ancestors(\ 281 | go_annotations_meta['direct_parents'].to_dict(), \ 282 | go_annotations_meta[~go_annotations_meta['direct_children'].apply(bool)].index)) 283 | 284 | def _get_index_to_all_ancestors(index_to_direct_children, root_indices): 285 | 286 | index_to_all_ancestors = {index: {index} for index in index_to_direct_children.keys()} 287 | indices_to_scan = set(root_indices) 288 | 289 | while indices_to_scan: 290 | 291 | scanned_child_indices = set() 292 | 293 | for index in indices_to_scan: 294 | for child_index in index_to_direct_children[index]: 295 | index_to_all_ancestors[child_index].update(index_to_all_ancestors[index]) 296 | scanned_child_indices.add(child_index) 297 | 298 | indices_to_scan = scanned_child_indices 299 | 300 | return index_to_all_ancestors 301 | 302 | def _encode_annotations_as_a_binary_matrix(records_annotations, annotation_to_index): 303 | 304 | annotation_masks = np.zeros((len(records_annotations), len(annotation_to_index)), dtype = bool) 305 | 306 | for i, record_annotations in enumerate(records_annotations): 307 | for annotation in record_annotations: 308 | if annotation in annotation_to_index: 309 | annotation_masks[i, annotation_to_index[annotation]] = True 310 | 311 | return annotation_masks 312 | 313 | def _etree_fast_iter(context, func, func_args = [], func_kwargs = {}, max_elements = None): 314 | ''' 315 | Based on: https://stackoverflow.com/questions/12160418/why-is-lxml-etree-iterparse-eating-up-all-my-memory 316 | http://lxml.de/parsing.html#modifying-the-tree 317 | Based on Liza Daly's fast_iter 318 | http://www.ibm.com/developerworks/xml/library/x-hiperfparse/ 319 | See also http://effbot.org/zone/element-iterparse.htm 320 | ''' 321 | for i, (event, elem) in 
enumerate(context): 322 | func(i, event, elem, *func_args, **func_kwargs) 323 | # It's safe to call clear() here because no descendants will be 324 | # accessed 325 | elem.clear() 326 | # Also eliminate now-empty references from the root node to elem 327 | for ancestor in elem.xpath('ancestor-or-self::*'): 328 | while ancestor.getprevious() is not None: 329 | del ancestor.getparent()[0] 330 | if max_elements is not None and i >= max_elements - 1: 331 | break 332 | del context 333 | -------------------------------------------------------------------------------- /proteinbert/pretraining.py: -------------------------------------------------------------------------------- 1 | import os 2 | import itertools 3 | from datetime import datetime, timedelta 4 | import pickle 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import h5py 9 | 10 | from tensorflow import keras 11 | 12 | from .shared_utils.util import log 13 | from .tokenization import ADDED_TOKENS_PER_SEQ, additional_token_to_index, n_tokens, tokenize_seq, parse_seq 14 | from .model_generation import PretrainingModelGenerator 15 | 16 | DEFAULT_EPISODE_SETTINGS = [ 17 | # seq_len, batch_size 18 | (128, 128), 19 | (512, 64), 20 | (1024, 32), 21 | ] 22 | 23 | def run_pretraining(create_model_function, epoch_generator, h5_dataset_file_path, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, \ 24 | other_optimizer_kwargs = {}, annots_loss_weight = 1, autosave_manager = None, weights_dir = None, resume_from = None, n_epochs = None, fit_callbacks = []): 25 | 26 | np.random.seed(0) 27 | 28 | with h5py.File(h5_dataset_file_path, 'r') as h5f: 29 | n_annotations = len(h5f['included_annotations']) 30 | 31 | model_generator = PretrainingModelGenerator(create_model_function, n_annotations, create_model_kwargs = create_model_kwargs, optimizer_class = optimizer_class, lr = lr, \ 32 | other_optimizer_kwargs = other_optimizer_kwargs, annots_loss_weight = annots_loss_weight) 33 | model_trainer = ModelTrainer(model_generator, epoch_generator, autosave_manager = autosave_manager, weights_dir = weights_dir, fit_callbacks = fit_callbacks) 34 | 35 | with h5py.File(h5_dataset_file_path, 'r') as h5f: 36 | model_trainer.setup(DatasetHandler(h5f), resume_from = resume_from) 37 | model_trainer.train(n_epochs = n_epochs) 38 | 39 | return model_trainer 40 | 41 | class ModelTrainer: 42 | 43 | def __init__(self, model_generator, epoch_generator, autosave_manager = None, weights_dir = None, fit_callbacks = []): 44 | 45 | self.model_generator = model_generator 46 | self.epoch_generator = epoch_generator 47 | self.autosave_manager = autosave_manager 48 | self.weights_dir = weights_dir 49 | self.fit_callbacks = fit_callbacks 50 | 51 | if self.autosave_manager is not None: 52 | self.autosave_manager.n_annotations = self.model_generator.n_annotations 53 | 54 | def setup(self, dataset_handler, resume_from = None): 55 | 56 | if resume_from is None: 57 | self.current_epoch_index = 0 58 | start_sample_index = 0 59 | resumed_weights_file_path = None 60 | else: 61 | self.current_epoch_index, start_sample_index = resume_from 62 | self.current_epoch_index += 1 63 | resumed_weights_file_path = os.path.join(self.weights_dir, 'epoch_%d_sample_%d.pkl' % resume_from) 64 | 65 | starting_episode = self.epoch_generator.setup(dataset_handler, start_sample_index) 66 | self.model_generator.dummy_epoch = self.epoch_generator.create_dummpy_epoch()[:2] 67 | log('Starting with episode with seq_len = %d.' 
% starting_episode.seq_len) 68 | 69 | if resumed_weights_file_path is not None: 70 | with open(resumed_weights_file_path, 'rb') as f: 71 | n_annotations, self.model_generator.model_weights, self.model_generator.optimizer_weights = pickle.load(f) 72 | assert n_annotations == self.model_generator.n_annotations 73 | log('Loaded weights from %s.' % resumed_weights_file_path) 74 | 75 | self.model = self.model_generator.create_model(starting_episode.seq_len) 76 | self.model.summary() 77 | 78 | def train(self, n_epochs = None, autosave = True): 79 | for _ in (itertools.count() if n_epochs is None else range(n_epochs)): 80 | self.train_next_epoch(autosave = autosave) 81 | 82 | def train_next_epoch(self, autosave = True): 83 | 84 | changed_episode, episode = self.epoch_generator.determine_episode_and_ready_next_epoch() 85 | 86 | if changed_episode: 87 | log('Starting a new episode with seq_len = %d.' % episode.seq_len) 88 | self.model_generator.dummy_epoch = self.epoch_generator.create_dummpy_epoch()[:2] 89 | self.model_generator.update_state(self.model) 90 | self.model = self.model_generator.create_model(episode.seq_len) 91 | 92 | X, Y, sample_weigths = self.epoch_generator.create_next_epoch() 93 | log('Epoch %d (current sample %d):' % (self.current_epoch_index, self.epoch_generator.current_sample_index)) 94 | self.model.fit(X, Y, sample_weight = sample_weigths, batch_size = episode.batch_size, callbacks = self.fit_callbacks) 95 | 96 | if autosave and self.autosave_manager is not None: 97 | self.autosave_manager.on_epoch_end(self.model, self.current_epoch_index, self.epoch_generator.current_sample_index) 98 | 99 | self.current_epoch_index += 1 100 | 101 | class EpochGenerator: 102 | 103 | def __init__(self, n_batches_per_epoch = 100, p_seq_noise = 0.05, p_no_input_annot = 0.5, p_annot_noise_positive = 0.25, \ 104 | p_annot_noise_negative = 1e-04, load_chunk_size = 100000, min_time_per_episode = timedelta(minutes = 15), \ 105 | episode_settings = DEFAULT_EPISODE_SETTINGS): 106 | 107 | self.n_batches_per_epoch = n_batches_per_epoch 108 | self.p_seq_noise = p_seq_noise 109 | self.p_no_input_annot = p_no_input_annot 110 | self.p_annot_noise_positive = p_annot_noise_positive 111 | self.p_annot_noise_negative = p_annot_noise_negative 112 | self.load_chunk_size = load_chunk_size 113 | self.min_time_per_episode = min_time_per_episode 114 | 115 | self.episode_managers = [EpisodeDataManager(seq_len, batch_size, self.n_batches_per_epoch) for seq_len, batch_size in \ 116 | episode_settings] 117 | self.episode_seq_lens = np.array([episode_manager.seq_len for episode_manager in self.episode_managers]) 118 | 119 | def setup(self, dataset_handler, start_sample_index = 0): 120 | self.dataset_handler = dataset_handler 121 | self.current_sample_index = start_sample_index % self.dataset_handler.total_size 122 | self._load_chunk() 123 | self._select_new_episode() 124 | return self._current_episode 125 | 126 | def determine_episode_and_ready_next_epoch(self): 127 | 128 | if self._episode_selection_time + self.min_time_per_episode <= datetime.now(): 129 | old_episode = self._current_episode 130 | self._select_new_episode() 131 | changed_episode = (self._current_episode is not old_episode) 132 | else: 133 | changed_episode = False 134 | 135 | while not self._current_episode.is_epoch_ready(): 136 | self._load_chunk() 137 | 138 | return changed_episode, self._current_episode 139 | 140 | def create_next_epoch(self): 141 | return self._encode_epoch(*self.create_next_epoch_Y()) 142 | 143 | def create_dummpy_epoch(self, size = 
1): 144 | return self._encode_epoch(*self.create_next_dummy_epoch_Y(size)) 145 | 146 | def create_next_epoch_Y(self): 147 | assert self._current_episode.is_epoch_ready() 148 | return self._current_episode.encode_next_epoch() 149 | 150 | def create_next_dummy_epoch_Y(self, size = 1): 151 | 152 | while not self._current_episode.is_epoch_ready(size): 153 | self._load_chunk() 154 | 155 | return self._current_episode.encode_dummy_epoch(size) 156 | 157 | def _select_new_episode(self): 158 | self._current_episode = max(self.episode_managers, key = lambda episode_manager: len(episode_manager.sample_cache)) 159 | self._episode_selection_time = datetime.now() 160 | 161 | def _load_chunk(self): 162 | 163 | chunk_sample_cache = self.dataset_handler[self.current_sample_index:(self.current_sample_index + self.load_chunk_size)] 164 | self.current_sample_index += self.load_chunk_size 165 | 166 | if self.current_sample_index >= self.dataset_handler.total_size: 167 | self.current_sample_index = 0 168 | 169 | self._assign_samples(chunk_sample_cache) 170 | 171 | def _assign_samples(self, sample_cache): 172 | 173 | seq_lens = np.array(list(map(len, sample_cache.seqs))) + ADDED_TOKENS_PER_SEQ 174 | assigned_episode_indices = self._select_episodes_to_assign(seq_lens) 175 | 176 | for episode_manager_index, episode_manager in enumerate(self.episode_managers): 177 | sample_indices_for_episode, = np.where(assigned_episode_indices == episode_manager_index) 178 | episode_manager.sample_cache.extend(sample_cache.slice_indices(sample_indices_for_episode)) 179 | 180 | def _select_episodes_to_assign(self, seq_lens, gamma = 1): 181 | # The smaller the distance between a sample's sequence length to an episode's maximum sequence length, the higher the chance 182 | # that it will be assigned to that episode. 
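        # For example, with gamma = 1 and episode seq_lens of (128, 512, 1024), a sample of length 500 has symmetric
        # length ratios of about (3.91, 1.02, 2.05), so its unnormalized weights exp(-gamma * ratio) are roughly
        # (0.02, 0.36, 0.13) and, after normalization, it lands in the 512-token episode with probability ~0.71.
        # The cumulative-probability comparison below then draws one episode index per sample in a vectorized way.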
183 | samples_by_episodes_seq_len_ratio = seq_lens.reshape(-1, 1) / self.episode_seq_lens.reshape(1, -1) 184 | samples_by_episodes_seq_len_symmetric_ratio = np.maximum(samples_by_episodes_seq_len_ratio, 1 / samples_by_episodes_seq_len_ratio) 185 | raw_samples_by_episodes_probs = np.exp(-gamma * samples_by_episodes_seq_len_symmetric_ratio) 186 | samples_by_episodes_probs = raw_samples_by_episodes_probs / raw_samples_by_episodes_probs.sum(axis = -1).reshape(-1, 1) 187 | samples_by_episodes_cum_probs = samples_by_episodes_probs.cumsum(axis = -1) 188 | assigned_episode_indices = (np.random.rand(len(seq_lens), 1) <= samples_by_episodes_cum_probs).argmax(axis = 1) 189 | return assigned_episode_indices 190 | 191 | def _encode_epoch(self, encoded_seqs, encoded_annotation_masks): 192 | 193 | seqs_noise_mask = np.random.choice([True, False], encoded_seqs.shape, p = [1 - self.p_seq_noise, self.p_seq_noise]) 194 | random_seq_tokens = np.random.randint(0, n_tokens, encoded_seqs.shape) 195 | noisy_encoded_seqs = np.where(seqs_noise_mask, encoded_seqs, random_seq_tokens) 196 | 197 | noisy_annotations_when_positive = np.random.choice([True, False], encoded_annotation_masks.shape, \ 198 | p = [1 - self.p_annot_noise_positive, self.p_annot_noise_positive]) 199 | noisy_annotations_when_negative = np.random.choice([True, False], encoded_annotation_masks.shape, \ 200 | p = [self.p_annot_noise_negative, 1 - self.p_annot_noise_negative]) 201 | noisy_annotation_masks = np.where(encoded_annotation_masks, noisy_annotations_when_positive, \ 202 | noisy_annotations_when_negative) 203 | noisy_annotation_masks[np.random.choice([True, False], len(noisy_annotation_masks), p = [self.p_no_input_annot, \ 204 | 1 - self.p_no_input_annot]), :] = False 205 | 206 | seq_weights = (encoded_seqs != additional_token_to_index['']).astype(float) 207 | # When a protein has no annotations at all, we don't know whether it's because such annotations don't exist or just not found, 208 | # so it's safer to set the loss weight of those annotations to zero. 
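        # For example, if the annotation masks of a three-record batch are [1, 0, 1], [0, 0, 0] and [0, 1, 0], the
        # weights below come out as [1.0, 0.0, 1.0], so the record with no annotations contributes nothing to the
        # annotation output's loss (it still contributes to the sequence output's loss via seq_weights above).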
209 | annotation_weights = encoded_annotation_masks.any(axis = -1).astype(float) 210 | 211 | X = [noisy_encoded_seqs, noisy_annotation_masks.astype(np.int8)] 212 | Y = [np.expand_dims(encoded_seqs, axis = -1), encoded_annotation_masks.astype(np.int8)] 213 | sample_weigths = [seq_weights, annotation_weights] 214 | 215 | return X, Y, sample_weigths 216 | 217 | class EpisodeDataManager: 218 | 219 | def __init__(self, seq_len, batch_size, n_batches_per_epoch): 220 | self.seq_len = seq_len 221 | self.batch_size = batch_size 222 | self.n_batches_per_epoch = n_batches_per_epoch 223 | self.epoch_size = self.n_batches_per_epoch * self.batch_size 224 | self.sample_cache = SampleCache() 225 | 226 | def is_epoch_ready(self, n_required_samples = None): 227 | return len(self.sample_cache) >= self._resolve_epoch_size(n_required_samples) 228 | 229 | def get_next_raw_epoch(self, size = None): 230 | return self.sample_cache.pop(self._resolve_epoch_size(size)) 231 | 232 | def peek_raw_epoch(self, size = None): 233 | return self.sample_cache.slice_first(self._resolve_epoch_size(size)) 234 | 235 | def encode_next_epoch(self, log_length_dist = True): 236 | 237 | seq_lengths, encoded_seqs, encoded_annotation_masks = self._encode_epoch(self.get_next_raw_epoch()) 238 | 239 | if log_length_dist: 240 | log('Epoch sequence length distribution (for seq_len = %d): %s' % (self.seq_len, \ 241 | ', '.join('%s: %s' % item for item in pd.Series(seq_lengths).describe().iteritems()))) 242 | 243 | return encoded_seqs, encoded_annotation_masks 244 | 245 | def encode_dummy_epoch(self, size = 1): 246 | seq_lengths, encoded_seqs, encoded_annotation_masks = self._encode_epoch(self.peek_raw_epoch(size)) 247 | return encoded_seqs, encoded_annotation_masks 248 | 249 | def _encode_epoch(self, epoch_sample_cache): 250 | 251 | pad_token_index = additional_token_to_index[''] 252 | tokenized_seqs = list(map(tokenize_seq, epoch_sample_cache.seqs)) 253 | seq_lengths = np.array(list(map(len, tokenized_seqs))) 254 | max_offsets = np.maximum(seq_lengths - self.seq_len, 0) 255 | chosen_offsets = (np.random.rand(self.epoch_size) * (max_offsets + 1)).astype(int) 256 | trimmed_tokenized_seqs = [seq_tokens[chosen_offset:(chosen_offset + self.seq_len)] for seq_tokens, chosen_offset in \ 257 | zip(tokenized_seqs, chosen_offsets)] 258 | encoded_seqs = np.array([seq_tokens + max(self.seq_len - len(seq_tokens), 0) * [pad_token_index] for seq_tokens in \ 259 | trimmed_tokenized_seqs]).astype(np.int8) 260 | 261 | encoded_annotation_masks = np.concatenate([annotation_mask.reshape(1, -1) for annotation_mask in \ 262 | epoch_sample_cache.annotation_masks], axis = 0).astype(bool) 263 | 264 | # We hide the annotations of test-set samples to avoid "cheating" on downstream fine-tuning tests. Note that by removing all of the annotations, 265 | # EpochGenerator._encode_epoch will then set the annotation_weights for these records to 0, meaning they will not be part of the loss function. 
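        # In other words, test-set records keep contributing to the sequence-reconstruction loss, but their zeroed
        # annotation masks (together with the annotation_weights logic in EpochGenerator._encode_epoch) exclude them
        # from the annotation loss.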
266 | encoded_annotation_masks[epoch_sample_cache.test_set_mask, :] = False 267 | 268 | return seq_lengths, encoded_seqs, encoded_annotation_masks 269 | 270 | def _resolve_epoch_size(self, size): 271 | if size is None: 272 | return self.epoch_size 273 | else: 274 | return size 275 | 276 | class DatasetHandler: 277 | 278 | def __init__(self, dataset_h5f): 279 | self.dataset_h5f = dataset_h5f 280 | self.total_size = len(dataset_h5f['seq_lengths']) 281 | 282 | def __getitem__(self, slicing): 283 | return SampleCache(list(map(parse_seq, self.dataset_h5f['seqs'][slicing])), self.dataset_h5f['annotation_masks'][slicing], \ 284 | self.dataset_h5f['test_set_mask'][slicing]) 285 | 286 | class SampleCache: 287 | 288 | def __init__(self, seqs = [], annotation_masks = [], test_set_mask = []): 289 | self.seqs = list(seqs) 290 | self.annotation_masks = list(annotation_masks) 291 | self.test_set_mask = list(test_set_mask) 292 | 293 | def extend(self, other_cache): 294 | self.seqs.extend(other_cache.seqs) 295 | self.annotation_masks.extend(other_cache.annotation_masks) 296 | self.test_set_mask.extend(other_cache.test_set_mask) 297 | 298 | def pop(self, n): 299 | popped_sample_cache = self.slice_first(n) 300 | self.seqs = self.seqs[n:] 301 | self.annotation_masks = self.annotation_masks[n:] 302 | self.test_set_mask = self.test_set_mask[n:] 303 | return popped_sample_cache 304 | 305 | def slice_first(self, n): 306 | return SampleCache(self.seqs[:n], self.annotation_masks[:n], self.test_set_mask[:n]) 307 | 308 | def slice_indices(self, indices): 309 | return SampleCache([self.seqs[i] for i in indices], [self.annotation_masks[i] for i in indices], \ 310 | [self.test_set_mask[i] for i in indices]) 311 | 312 | def __len__(self): 313 | assert len(self.seqs) == len(self.annotation_masks) == len(self.test_set_mask) 314 | return len(self.seqs) 315 | 316 | class AutoSaveManager: 317 | 318 | def __init__(self, directory, every_epochs_to_save = 10, every_saves_to_keep = 25): 319 | self.directory = directory 320 | self.every_epochs_to_save = every_epochs_to_save 321 | self.every_saves_to_keep = every_saves_to_keep 322 | self.last_saved_path_to_delete = None 323 | self.n_saves = 0 324 | 325 | def on_epoch_end(self, model, epoch_index, sample_index): 326 | 327 | if epoch_index % self.every_epochs_to_save != 0: 328 | return 329 | 330 | save_path = os.path.join(self.directory, 'epoch_%d_sample_%d.pkl' % (epoch_index, sample_index)) 331 | _save_model_state(model, self.n_annotations, save_path) 332 | self.n_saves += 1 333 | 334 | if self.last_saved_path_to_delete is not None: 335 | os.remove(self.last_saved_path_to_delete) 336 | 337 | if self.n_saves % self.every_saves_to_keep == 0: 338 | self.last_saved_path_to_delete = None 339 | else: 340 | self.last_saved_path_to_delete = save_path 341 | 342 | def _save_model_state(model, n_annotations, path): 343 | with open(path, 'wb') as f: 344 | pickle.dump((n_annotations, model.get_weights(), model.optimizer.get_weights()), f) 345 | -------------------------------------------------------------------------------- /proteinbert/shared_utils/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import re 4 | import gc 5 | import importlib 6 | from collections import defaultdict 7 | from functools import reduce 8 | from datetime import datetime, timedelta 9 | import json 10 | 11 | import numpy as np 12 | import pandas as pd 13 | 14 | 15 | ### Logging ### 16 | 17 | def log(*message, **kwargs): 18 | 19 | global _log_file 
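    # _log_file is a module-level handle that start_log() creates lazily; until then, messages only go to stdout.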
20 | 21 | end = kwargs.get('end', '\n') 22 | 23 | if len(message) == 1: 24 | message, = message 25 | 26 | full_message = '[%s] %s' % (format_now(), message) 27 | 28 | print(full_message, end = end) 29 | sys.stdout.flush() 30 | 31 | if log_file_open(): 32 | _log_file.write(full_message + end) 33 | _log_file.flush() 34 | 35 | def start_log(log_dir, log_file_base_name): 36 | 37 | global _log_file 38 | 39 | if not os.path.exists(log_dir): 40 | os.makedirs(log_dir) 41 | 42 | log_file_name = '%s__%d__%s.txt' % (log_file_base_name, os.getpid(), format_now()) 43 | 44 | if not log_file_open(): 45 | print('Creating log file: %s' % log_file_name) 46 | _log_file = open(os.path.join(log_dir, log_file_name), 'w') 47 | 48 | def close_log(): 49 | 50 | global _log_file 51 | 52 | if log_file_open(): 53 | _log_file.close() 54 | del _log_file 55 | 56 | def restart_log(): 57 | close_log() 58 | start_log() 59 | 60 | def log_file_open(): 61 | global _log_file 62 | return '_log_file' in globals() 63 | 64 | def create_time_measure_if_verbose(opening_statement, verbose): 65 | if verbose: 66 | return TimeMeasure(opening_statement) 67 | else: 68 | return DummyContext() 69 | 70 | 71 | ### General ### 72 | 73 | def get_nullable(value, default_value): 74 | if pd.isnull(value): 75 | return default_value 76 | else: 77 | return value 78 | 79 | 80 | ### Reflection ### 81 | 82 | def load_object(full_object_name): 83 | name_parts = full_object_name.split('.') 84 | object_name = name_parts[-1] 85 | module_name = '.'.join(name_parts[:-1]) 86 | module = importlib.import_module(module_name) 87 | return getattr(module, object_name) 88 | 89 | 90 | ### Strings ### 91 | 92 | def trim(string, max_length, trim_suffix = '...'): 93 | if len(string) <= max_length: 94 | return string 95 | else: 96 | return string[:(max_length - len(trim_suffix))] + trim_suffix 97 | 98 | def break_to_lines(text, max_line_len): 99 | 100 | lines = [''] 101 | 102 | for word in text.split(): 103 | 104 | if len(lines[-1]) + len(word) > max_line_len: 105 | lines.append('') 106 | 107 | if lines[-1] != '': 108 | lines[-1] += ' ' 109 | 110 | lines[-1] += word 111 | 112 | return '\n'.join(lines) 113 | 114 | 115 | ### IO ### 116 | 117 | def safe_symlink(src, dst, post_creation_hook = lambda created_symlink: None): 118 | if os.path.exists(dst): 119 | log('%s: already exists.' % dst) 120 | else: 121 | try: 122 | os.symlink(src, dst) 123 | post_creation_hook(dst) 124 | log('Created link: %s -> %s' % (src, dst)) 125 | except OSError as e: 126 | if e.errno == 17: 127 | log('%s: already exists after all.' 
% dst) 128 | else: 129 | raise e 130 | 131 | def safe_mkdir(path): 132 | try: 133 | os.mkdir(path) 134 | except OSError as e: 135 | assert 'File exists' in str(e), str(e) 136 | 137 | def format_size_in_bytes(size): 138 | 139 | UNIT_RATIO = 1024 140 | UNITS = ['B', 'KB', 'MB', 'GB', 'TB'] 141 | 142 | for unit_index in range(len(UNITS)): 143 | if size < UNIT_RATIO: 144 | break 145 | else: 146 | size /= UNIT_RATIO 147 | 148 | return '%.1f%s' % (size, UNITS[unit_index]) 149 | 150 | def get_recognized_files_in_dir(dir_path, file_parser, log_unrecognized_files = True): 151 | 152 | recognized_files = [] 153 | unrecognized_files = [] 154 | 155 | for file_name in os.listdir(dir_path): 156 | try: 157 | recognized_files.append((file_parser(file_name), file_name)) 158 | except: 159 | if log_unrecognized_files: 160 | unrecognized_files.append(file_name) 161 | 162 | if log_unrecognized_files and len(unrecognized_files) > 0: 163 | log('%s: %d unrecognized files: %s' % (dir_path, len(unrecognized_files), ', '.join(unrecognized_files))) 164 | 165 | return list(sorted(recognized_files)) 166 | 167 | def monitor_memory(min_bytes_to_log = 1e08, max_elements_to_check = 100, collect_gc = True, del_output_variables = True, \ 168 | list_like_types = [list, tuple, np.ndarray, pd.Series], dict_like_types = [dict, defaultdict]): 169 | 170 | already_monitored_object_ids = set() 171 | 172 | def _is_of_any_type(obj, types): 173 | 174 | for t in types: 175 | if isinstance(obj, t): 176 | return True 177 | 178 | return False 179 | 180 | def _check_len_limit(obj): 181 | try: 182 | return len(obj) <= max_elements_to_check 183 | except: 184 | return False 185 | 186 | def _log_object_if_needed(name, obj): 187 | 188 | size = sys.getsizeof(obj) 189 | 190 | if size >= min_bytes_to_log: 191 | log('%s: %s' % (name, format_size_in_bytes(size))) 192 | 193 | def _monitor_object(name, obj): 194 | if id(obj) not in already_monitored_object_ids: 195 | 196 | already_monitored_object_ids.add(id(obj)) 197 | _log_object_if_needed(name, obj) 198 | 199 | if _is_of_any_type(obj, list_like_types) and _check_len_limit(obj): 200 | for i, element in enumerate(obj): 201 | _monitor_object('%s[%d]' % (name, i), element) 202 | 203 | if _is_of_any_type(obj, dict_like_types) and _check_len_limit(obj): 204 | for key, value in obj.items(): 205 | _monitor_object('%s[%s]' % (name, repr(key)), value) 206 | 207 | 208 | for module_name, module in sys.modules.items(): 209 | for variable_name in dir(module): 210 | 211 | full_variable_name = variable_name if module_name == '__main__' else '%s.%s' % (module_name, variable_name) 212 | _monitor_object(full_variable_name, getattr(module, variable_name)) 213 | 214 | if del_output_variables and module_name == '__main__' and re.match(r'^_[\d_]+$', variable_name): 215 | delattr(module, variable_name) 216 | 217 | if del_output_variables: 218 | sys.modules['__main__'].Out = dict() 219 | sys.modules['__main__']._oh = dict() 220 | 221 | if collect_gc: 222 | gc.collect() 223 | 224 | 225 | ### Date & time ### 226 | 227 | def format_now(): 228 | return datetime.now().strftime('%Y_%m_%d-%H:%M:%S') 229 | 230 | 231 | ### Iterators & collections ### 232 | 233 | def compare_list_against_collection(input_list, collection): 234 | collection_set = set(collection) 235 | return [element for element in input_list if element in collection_set], [element for element in input_list if element not in collection_set] 236 | 237 | def get_chunk_slice(size, n_chunks, chunk_index): 238 | assert size >= n_chunks 239 | chunk_size = size / n_chunks 240 
| start_index = int(chunk_index * chunk_size) 241 | end_index = int((chunk_index + 1) * chunk_size) 242 | return start_index, end_index 243 | 244 | def get_chunk_intervals(size, chunk_size): 245 | for start_index in range(0, size, chunk_size): 246 | end_index = min(start_index + chunk_size, size) 247 | yield start_index, end_index 248 | 249 | def to_chunks(iterable, chunk_size): 250 | 251 | chunk = [] 252 | 253 | for element in iterable: 254 | 255 | chunk.append(element) 256 | 257 | if len(chunk) >= chunk_size: 258 | yield chunk 259 | chunk = [] 260 | 261 | if len(chunk) > 0: 262 | yield chunk 263 | 264 | def get_job_and_subjob_indices(n_jobs, n_tasks, task_index): 265 | 266 | ''' 267 | For example, if there are 170 tasks for working on 50 jobs, than each job will be divided to 3-4 tasks. 268 | Since 170 % 50 = 20, the 20 first jobs will receive 4 tasks and the last 30 jobs will receive only 3 tasks. 269 | In total, the first 80 tasks will be dedicated to jobs with 4 tasks each, and the 90 last tasks will be 270 | dedicated to jobs with 3 tasks each. Hence, tasks 0-3 will go to job 0, tasks 4-7 will go to job 1, and so on; 271 | tasks 80-82 will go to job 21, tasks 83-85 will job to job 22, and so on. 272 | ''' 273 | 274 | assert n_tasks >= n_jobs 275 | n_tasks_in_unprivileged_jobs = n_tasks // n_jobs 276 | n_tasks_in_privileged_jobs = n_tasks_in_unprivileged_jobs + 1 277 | n_privileged_jobs = n_tasks % n_jobs 278 | n_tasks_of_privileged_jobs = n_tasks_in_privileged_jobs * n_privileged_jobs 279 | 280 | if task_index < n_tasks_of_privileged_jobs: 281 | job_index = task_index // n_tasks_in_privileged_jobs 282 | index_within_job = task_index % n_tasks_in_privileged_jobs 283 | n_tasks_in_job = n_tasks_in_privileged_jobs 284 | else: 285 | task_index_in_unprivileged_group = task_index - n_tasks_of_privileged_jobs 286 | job_index = n_privileged_jobs + task_index_in_unprivileged_group // n_tasks_in_unprivileged_jobs 287 | index_within_job = task_index_in_unprivileged_group % n_tasks_in_unprivileged_jobs 288 | n_tasks_in_job = n_tasks_in_unprivileged_jobs 289 | 290 | return job_index, index_within_job, n_tasks_in_job 291 | 292 | def choose_from_cartesian_product(list_of_values, i, total = None): 293 | 294 | n = int(np.prod(list(map(len, list_of_values)))) 295 | 296 | if total is not None: 297 | assert n == total 298 | 299 | chosen_elements = [] 300 | 301 | for values in list_of_values: 302 | n //= len(values) 303 | chosen_elements.append(values[i // n]) 304 | i %= n 305 | 306 | return chosen_elements 307 | 308 | def calc_overlap_between_segments(ordered_segments1, ordered_segments2): 309 | 310 | ''' 311 | Calculates the total overlap size between a pair of ordered and disjoint groups of segments. 312 | Each group of segment is given by: [(start1, end1), (start2, end2), ...]. 
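    For example, for the inclusive segments [(1, 5), (10, 20)] and [(3, 12)], the overlapping ranges are 3..5 and
    10..12, so the returned total is (5 - 3 + 1) + (12 - 10 + 1) = 6.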
313 | ''' 314 | 315 | from interval_tree import IntervalTree 316 | 317 | if len(ordered_segments1) == 0 or len(ordered_segments2) == 0: 318 | return 0 319 | 320 | if len(ordered_segments1) > len(ordered_segments2): 321 | ordered_segments1, ordered_segments2 = ordered_segments2, ordered_segments1 322 | 323 | min_value = min(ordered_segments1[0][0], ordered_segments2[0][0]) 324 | max_value = max(ordered_segments1[-1][1], ordered_segments2[-1][1]) 325 | interval_tree1 = IntervalTree([segment + (segment,) for segment in ordered_segments1], min_value, max_value) 326 | total_overlap = 0 327 | 328 | for segment in ordered_segments2: 329 | for overlapping_segment in interval_tree1.find_range(segment): 330 | overlapping_start = max(segment[0], overlapping_segment[0]) 331 | overlapping_end = min(segment[1], overlapping_segment[1]) 332 | assert overlapping_start <= overlapping_end, 'Reported overlap between %d..%d to %d..%d.' % (segment + \ 333 | overlapping_segment) 334 | total_overlap += (overlapping_end - overlapping_start + 1) 335 | 336 | return total_overlap 337 | 338 | def merge_lists_with_compatible_relative_order(lists): 339 | 340 | ''' 341 | Given a list of lists with compatible relative ordering (i.e. for every two sublists, the subset of elements that exist in the two 342 | sublists will have the same relative order), returns a merging of these sublists into a single grand list that contains all the 343 | elements (each element only once), and preserves the same ordering. 344 | ''' 345 | 346 | def merge_two_sublists(list1, list2): 347 | 348 | value_to_index = {value: float(i) for i, value in enumerate(list1)} 349 | unique_list2_index = {} 350 | last_identified_index = len(list1) 351 | 352 | for i, value in list(enumerate(list2))[::-1]: 353 | if value in value_to_index: 354 | last_identified_index = value_to_index[value] 355 | else: 356 | unique_list2_index[value] = last_identified_index - 1 + i / len(list2) 357 | 358 | value_to_index.update(unique_list2_index) 359 | return sorted(value_to_index.keys(), key = value_to_index.get) 360 | 361 | return reduce(merge_two_sublists, lists, []) 362 | 363 | 364 | ### argparse ### 365 | 366 | def get_parser_bool_type(parser): 367 | 368 | def _bool_type(value): 369 | if isinstance(value, bool): 370 | return value 371 | if value.lower() in ['yes', 'true', 't', 'y', '1']: 372 | return True 373 | elif value.lower() in ['no', 'false', 'f', 'n', '0']: 374 | return False 375 | else: 376 | raise parser.error('"%s": unrecognized boolean value.' 
% value) 377 | 378 | return _bool_type 379 | 380 | def get_parser_file_type(parser, must_exist = False): 381 | 382 | def _file_type(path): 383 | 384 | path = os.path.expanduser(path) 385 | 386 | if must_exist: 387 | if not os.path.exists(path): 388 | parser.error('File doesn\'t exist: %s' % path) 389 | elif not os.path.isfile(path): 390 | parser.error('Not a file: %s' % path) 391 | else: 392 | return path 393 | else: 394 | 395 | dir_path = os.path.dirname(path) 396 | 397 | if dir_path and not os.path.exists(dir_path): 398 | parser.error('Parent directory doesn\'t exist: %s' % dir_path) 399 | else: 400 | return path 401 | 402 | return _file_type 403 | 404 | def get_parser_directory_type(parser, create_if_not_exists = False): 405 | 406 | def _directory_type(path): 407 | 408 | path = os.path.expanduser(path) 409 | 410 | if not os.path.exists(path): 411 | if create_if_not_exists: 412 | 413 | parent_path = os.path.dirname(path) 414 | 415 | if parent_path and not os.path.exists(parent_path): 416 | parser.error('Cannot create empty directory (parent directory doesn\'t exist): %s' % path) 417 | else: 418 | os.mkdir(path) 419 | return path 420 | else: 421 | parser.error('Path doesn\'t exist: %s' % path) 422 | elif not os.path.isdir(path): 423 | parser.error('Not a directory: %s' % path) 424 | else: 425 | return path 426 | 427 | return _directory_type 428 | 429 | def add_parser_task_arguments(parser): 430 | parser.add_argument('--task-index', dest = 'task_index', metavar = '<0,...,N_TASKS-1>', type = int, default = None, help = 'If you want to ' + \ 431 | ' distribute this process across multiple computation resources (e.g. on a cluster) you can specify the total number of tasks ' + \ 432 | '(--total-tasks) to split it into, and the index of the current task to run (--task-index).') 433 | parser.add_argument('--total-tasks', dest = 'total_tasks', metavar = '', type = int, default = None, help = 'See --task-index.') 434 | parser.add_argument('--task-index-env-variable', dest = 'task_index_env_variable', metavar = '', type = str, default = None, \ 435 | help = 'Instead of specifying a hardcoded --task-index, you can specify an environtment variable to take it from (e.g. SLURM_ARRAY_TASK_ID ' + \ 436 | 'if you use SLURM to distribute the jobs).') 437 | parser.add_argument('--total-tasks-env-variable', dest = 'total_tasks_env_variable', metavar = '', type = str, \ 438 | default = None, help = 'Instead of specifying a hardcoded --total-tasks, you can specify an environtment variable to take it from (e.g. 
' + \ 439 | 'SLURM_ARRAY_TASK_COUNT if you use SLURM to distribute the jobs).') 440 | 441 | def determine_parser_task_details(args): 442 | 443 | if args.task_index is not None and args.task_index_env_variable is not None: 444 | parser.error('You must choose between --task-index and --task-index-env-variable.') 445 | 446 | if args.task_index is not None: 447 | task_index = args.task_index 448 | elif args.task_index_env_variable is not None: 449 | task_index = int(os.getenv(args.task_index_env_variable)) 450 | else: 451 | task_index = None 452 | 453 | if args.total_tasks is not None and args.total_tasks_env_variable is not None: 454 | parser.error('You must choose between --total-tasks and --total-tasks-env-variable.') 455 | 456 | if args.total_tasks is not None: 457 | total_tasks = args.total_tasks 458 | elif args.total_tasks_env_variable is not None: 459 | total_tasks = int(os.getenv(args.total_tasks_env_variable)) 460 | else: 461 | total_tasks = None 462 | 463 | if task_index is None and total_tasks is None: 464 | task_index = 0 465 | total_tasks = 1 466 | elif task_index is None or total_tasks is None: 467 | parser.error('Task index and total tasks must either be specified or unspecified together.') 468 | 469 | if task_index < 0 or task_index >= total_tasks: 470 | parser.error('Task index must be in the range 0,...,(total tasks)-1.') 471 | 472 | return task_index, total_tasks 473 | 474 | 475 | ### Numpy ### 476 | 477 | def normalize(x): 478 | 479 | if isinstance(x, list): 480 | x = np.array(x) 481 | 482 | u = np.mean(x) 483 | sigma = np.std(x) 484 | 485 | if sigma == 0: 486 | return np.ones_like(x) 487 | else: 488 | return (x - u) / sigma 489 | 490 | def random_mask(size, n_trues): 491 | assert n_trues <= size 492 | mask = np.full(size, False) 493 | mask[:n_trues] = True 494 | np.random.shuffle(mask) 495 | return mask 496 | 497 | def indices_to_masks(n, indices): 498 | positive_mask = np.zeros(n, dtype = bool) 499 | positive_mask[indices] = True 500 | negative_mask = np.ones(n, dtype = bool) 501 | negative_mask[indices] = False 502 | return positive_mask, negative_mask 503 | 504 | def as_hot_encoding(values, value_to_index, n_values = None): 505 | 506 | if n_values is None: 507 | n_values = len(value_to_index) 508 | 509 | result = np.zeros(n_values) 510 | 511 | try: 512 | values = iter(values) 513 | except TypeError: 514 | values = iter([values]) 515 | 516 | for value in values: 517 | result[value_to_index[value]] += 1 518 | 519 | def is_full_rank(matrix): 520 | return np.linalg.matrix_rank(matrix) == min(matrix.shape) 521 | 522 | def find_linearly_independent_columns(matrix): 523 | 524 | ''' 525 | The calculation is fasciliated by the Gram Schmidt process, everytime taking the next column and removing its projections 526 | from all next columns, getting rid of columns which end up zero. 
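    For example, if the columns of the matrix are c0, c1, c0 + c1 and c3 (with c0, c1 and c3 linearly independent),
    the returned indices are [0, 1, 3]: the third column is reduced to (numerically) zero once the projections of
    the first two columns are removed, so it is skipped.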
527 | ''' 528 | 529 | n_rows, n_cols = matrix.shape 530 | 531 | if np.linalg.matrix_rank(matrix) == n_cols: 532 | return np.arange(n_cols) 533 | 534 | orthogonalized_matrix = matrix.copy().astype(float) 535 | independent_columns = [] 536 | 537 | for i in range(n_cols): 538 | if not np.isclose(orthogonalized_matrix[:, i], 0).all(): 539 | 540 | independent_columns.append(i) 541 | 542 | if len(independent_columns) >= n_rows: 543 | break 544 | 545 | orthogonalized_matrix[:, i] = orthogonalized_matrix[:, i] / np.linalg.norm(orthogonalized_matrix[:, i]) 546 | 547 | if i < n_cols - 1: 548 | # Remove the projection of the ith column from all next columns 549 | orthogonalized_matrix[:, (i + 1):] -= np.dot(orthogonalized_matrix[:, i], \ 550 | orthogonalized_matrix[:, (i + 1):]).reshape(1, -1) * orthogonalized_matrix[:, i].reshape(-1, 1) 551 | 552 | return np.array(independent_columns) 553 | 554 | def transpose_dataset(src, dst, max_memory_bytes, flush_func = None): 555 | 556 | n_rows, n_cols = src.shape[:2] 557 | entry_nbytes = src[:1, :1].nbytes 558 | ideal_entries_per_chunk = max_memory_bytes / entry_nbytes 559 | ideal_chunk_size = np.sqrt(ideal_entries_per_chunk) 560 | 561 | if n_rows <= n_cols: 562 | row_chunk_size = min(int(ideal_chunk_size), n_rows) 563 | col_chunk_size = min(int(ideal_entries_per_chunk / row_chunk_size), n_cols) 564 | else: 565 | col_chunk_size = min(int(ideal_chunk_size), n_cols) 566 | row_chunk_size = min(int(ideal_entries_per_chunk / col_chunk_size), n_rows) 567 | 568 | log('Will use chunks of size %dx%d to transpose a %dx%d matrix.' % (row_chunk_size, col_chunk_size, n_rows, n_cols)) 569 | 570 | for row_start, row_end in get_chunk_intervals(n_rows, row_chunk_size): 571 | for col_start, col_end in get_chunk_intervals(n_cols, col_chunk_size): 572 | 573 | log('Transposing chunk (%d..%d)x(%d..%d)...' % (row_start, row_end - 1, col_start, col_end - 1)) 574 | dst[col_start:col_end, row_start:row_end] = src[row_start:row_end, col_start:col_end].transpose() 575 | 576 | if flush_func is not None: 577 | flush_func() 578 | 579 | log('Finished transposing.') 580 | 581 | 582 | ### Pandas ### 583 | 584 | def summarize(df, n = 5, sample = False): 585 | 586 | from IPython.display import display 587 | 588 | if sample: 589 | display(df.sample(n)) 590 | else: 591 | display(df.head(n)) 592 | 593 | print('%d records' % len(df)) 594 | 595 | def nullable_idxmin(series): 596 | 597 | result = series.idxmin() 598 | 599 | if pd.isnull(result): 600 | if len(series) == 0: 601 | return np.nan 602 | else: 603 | return series.index[0] 604 | else: 605 | return result 606 | 607 | def get_first_value(df): 608 | ''' 609 | Will return a Series with the same index. For each row the value will be that of the first column which is not null. 
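    For example, for a DataFrame with rows [NaN, 5, 7] and [2, NaN, 3], the returned Series holds [5, 2].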
610 | ''' 611 | col_idxs = np.argmax(pd.notnull(df).values, axis = 1) 612 | return pd.Series(df.values[np.arange(len(df)), col_idxs], index = df.index) 613 | 614 | def slice_not_in_index(df_or_series, index_to_exclude): 615 | mask = pd.Series(True, index = df_or_series.index) 616 | mask.loc[index_to_exclude] = False 617 | return df_or_series.loc[mask] 618 | 619 | def swap_series_index_and_value(series): 620 | return pd.Series(series.index, index = series.values) 621 | 622 | def concat_dfs_with_partial_columns(dfs): 623 | columns = max([df.columns for df in dfs], key = len) 624 | assert all([set(df.columns) <= set(columns) for df in dfs]) 625 | return pd.concat(dfs, sort = False)[columns] 626 | 627 | def concat_dfs_with_compatible_columns(dfs): 628 | columns = merge_lists_with_compatible_relative_order([df.columns for df in dfs]) 629 | return pd.concat(dfs, sort = False)[columns] 630 | 631 | def safe_get_df_group(df_groupby, group_name): 632 | if group_name in df_groupby.groups: 633 | return df_groupby.get_group(group_name) 634 | else: 635 | _, some_group_df = next(iter(df_groupby)) 636 | return pd.DataFrame(columns = some_group_df.columns) 637 | 638 | def bin_groupby(df, series_or_col_name, n_bins): 639 | 640 | if len(df) == 0: 641 | return df 642 | 643 | if isinstance(series_or_col_name, str): 644 | series = df[series_or_col_name] 645 | else: 646 | series = series_or_col_name 647 | 648 | min_value, max_value = series.min(), series.max() 649 | bin_size = (max_value - min_value) / n_bins 650 | 651 | bind_ids = ((series - min_value) / bin_size).astype(int) 652 | bind_ids[bind_ids >= n_bins] = n_bins - 1 653 | 654 | return df.groupby(bind_ids) 655 | 656 | def value_df_to_hot_encoding_df(value_df, value_headers = {}): 657 | 658 | flat_values = value_df.values.flatten() 659 | all_values = sorted(np.unique(flat_values[pd.notnull(flat_values)])) 660 | value_to_index = {value: i for i, value in enumerate(all_values)} 661 | hot_encoding_matrix = np.zeros((len(value_df), len(all_values))) 662 | 663 | for _, column_values in value_df.iteritems(): 664 | row_position_to_value_index = column_values.reset_index(drop = True).dropna().map(value_to_index) 665 | hot_encoding_matrix[row_position_to_value_index.index.values, row_position_to_value_index.values] = 1 666 | 667 | headers = [value_headers.get(value, value) for value in all_values] 668 | return pd.DataFrame(hot_encoding_matrix, index = value_df.index, columns = headers) 669 | 670 | def set_series_to_hot_encoding_df(set_series, value_headers = {}): 671 | 672 | all_values = sorted(set.union(*set_series)) 673 | value_to_index = {value: i for i, value in enumerate(all_values)} 674 | hot_encoding_matrix = np.zeros((len(set_series), len(all_values))) 675 | 676 | for i, record_values in enumerate(set_series): 677 | hot_encoding_matrix[i, [value_to_index[value] for value in record_values]] = 1 678 | 679 | headers = [value_headers.get(value, value) for value in all_values] 680 | return pd.DataFrame(hot_encoding_matrix, index = set_series.index, columns = headers) 681 | 682 | def resolve_dummy_variable_trap(hot_encoding_df, validate_completeness = True, inplace = False, verbose = True): 683 | 684 | ''' 685 | When using one-hot-encoding in regression, there is a problem of encoding all possible variables if also using an intercept/const variable, 686 | because then the variables end up linearly dependent (a singular matrix is problematic with many implementations of regression). 
See for 687 | example: https://www.algosome.com/articles/dummy-variable-trap-regression.html 688 | To resolve this issue, this function will remove the most frequent column (to minimize the chance of any subset of the rows resulting a 689 | matrix which is not fully ranked). 690 | ''' 691 | 692 | # Validate we are indeed dealing with one-hot-encoding. 693 | assert set(np.unique(hot_encoding_df.values).astype(float)) <= {0.0, 1.0} 694 | 695 | if validate_completeness: 696 | assert (hot_encoding_df.sum(axis = 1) == 1).all() 697 | else: 698 | assert (hot_encoding_df.sum(axis = 1) <= 1).all() 699 | 700 | most_frequent_variable = hot_encoding_df.sum().idxmax() 701 | 702 | if verbose: 703 | log('To avoid the "dummy variable trap", removing the %s column (%d matching records).' % (most_frequent_variable, \ 704 | hot_encoding_df[most_frequent_variable].sum())) 705 | 706 | if inplace: 707 | del hot_encoding_df[most_frequent_variable] 708 | else: 709 | return hot_encoding_df[[column_name for column_name in hot_encoding_df.columns if column_name != most_frequent_variable]] 710 | 711 | def set_constant_row(df, row_mask, row_values): 712 | df[row_mask] = np.tile(row_values, (row_mask.sum(), 1)) 713 | 714 | def construct_df_from_rows(row_repertoire, row_indexer): 715 | 716 | result = pd.DataFrame(index = row_indexer.index, columns = row_repertoire.columns) 717 | 718 | for row_index, row_values in row_repertoire.iterrows(): 719 | set_constant_row(result, row_indexer == row_index, row_values) 720 | 721 | return result 722 | 723 | def get_row_last_values(df): 724 | 725 | result = pd.Series(np.nan, index = df.index) 726 | 727 | for column in df.columns[::-1]: 728 | result = result.where(pd.notnull(result), df[column]) 729 | 730 | return result 731 | 732 | def are_close_dfs(df1, df2, rtol = 1e-05, atol = 1e-08): 733 | 734 | assert (df1.dtypes == df2.dtypes).all() 735 | 736 | for column, dtype in df1.dtypes.iteritems(): 737 | 738 | if np.issubdtype(dtype, np.float): 739 | cmp_series = np.isclose(df1[column], df2[column], rtol = rtol, atol = atol) | (pd.isnull(df1[column]) & \ 740 | pd.isnull(df2[column])) 741 | else: 742 | cmp_series = (df1[column] == df2[column]) 743 | 744 | if not cmp_series.all(): 745 | return False 746 | 747 | return True 748 | 749 | def append_df_to_excel(excel_writer, df, sheet_name, index = True): 750 | 751 | header_format = excel_writer.book.add_format({'bold': True}) 752 | 753 | df.to_excel(excel_writer, sheet_name, index = index) 754 | worksheet = excel_writer.sheets[sheet_name] 755 | 756 | for column_index, column_name in enumerate(df.columns): 757 | worksheet.write(0, column_index + int(index), column_name, header_format) 758 | 759 | if index: 760 | for row_index_number, row_index_value in enumerate(df.index): 761 | worksheet.write(row_index_number + 1, 0, row_index_value) 762 | 763 | def is_binary_series(series): 764 | 765 | # First validating that the type of the series is convertable to float. 766 | try: 767 | float(series.iloc[0]) 768 | except TypeError: 769 | return False 770 | 771 | return set(series.unique().astype(float)) <= {0.0, 1.0} 772 | 773 | def resolve_quasi_complete_separation_by_removing_binary_columns(X, y): 774 | 775 | ''' 776 | When performing logistic regression of y against X, the matrix X must be of full rank; otherwise (i.e. if the columns of X are 777 | linearly dependent), then statsmodel's Logit model gives a singular-matrix error. 
It also appears that quasi-complete separation 778 | causes troubles, namely if the columns of X are linearly dependent conditioned on y. In other words, assuming that y is binary, 779 | we need that X[y, :] would still be of full rank (we assume that the vast majority of records have a negative y value, and only 780 | a small fraction have a positive value, so given that X is of full rank we need not worry about X[~y, :]). To resolve this problem, 781 | this function will remove binary columns of X until X[y, :] is of full rank. Whenever a column of X is removed, we also remove the 782 | corresponding records (rows of X and y) that have those values (so if a removed column represent some covariate, e.g. a certain 783 | batch, we also remove all the samples from this batch in order for not having any covariates not accounted for). 784 | @param X (pd.DataFrame): The exogenous variables (rows are records, columns are variables). 785 | @pram y (pd.Series): The endogenous variable (must have the same index as X). 786 | ''' 787 | 788 | row_mask = pd.Series(True, index = X.index) 789 | 790 | if not is_binary_series(y): 791 | return X, y, X.columns, set(), row_mask 792 | 793 | boolean_y = y.astype(bool) 794 | all_kept_binary_columns = np.array([column_name for column_name in X.columns if is_binary_series(X[column_name])]) 795 | # We sort the binary columns by how common they are, so when we start removing them, we will give priority to the more common ones 796 | # (i.e. remove the least frequent first). 797 | all_kept_binary_columns = X[all_kept_binary_columns].sum().sort_values(ascending = False).index 798 | all_removed_binary_columns = set() 799 | 800 | while len(all_kept_binary_columns) > 0: 801 | 802 | positive_X = X.loc[row_mask & boolean_y, all_kept_binary_columns] 803 | old_all_kept_binary_columns = all_kept_binary_columns 804 | all_kept_binary_columns = all_kept_binary_columns[find_linearly_independent_columns(positive_X.values)] 805 | columns_to_remove = set(old_all_kept_binary_columns) - set(all_kept_binary_columns) 806 | 807 | for column_name in columns_to_remove: 808 | log('Removing the columns %s (%d occurances) to avoid quasi-complete separation.' % (column_name, X[column_name].sum())) 809 | all_removed_binary_columns.add(column_name) 810 | row_mask &= (~X[column_name].astype(bool)) 811 | 812 | if len(columns_to_remove) == 0: 813 | break 814 | 815 | if not row_mask.all(): 816 | log('Overall removed %d columns occuring in %d records to avoid quasi-complete separation.' 
% (len(all_removed_binary_columns), \ 817 | (~row_mask).sum())) 818 | 819 | retained_columns = [column_name for column_name in X.columns if column_name not in all_removed_binary_columns] 820 | X = X.loc[row_mask, retained_columns] 821 | y = y.loc[row_mask] 822 | 823 | return X, y, retained_columns, all_removed_binary_columns, row_mask 824 | 825 | 826 | ### Statistics ### 827 | 828 | def to_normal_z_values(raw_values): 829 | 830 | from scipy.stats import rankdata, norm 831 | 832 | pvals = (rankdata(raw_values) - 0.5) / len(raw_values) 833 | normal_z_values = norm.ppf(pvals) 834 | 835 | if isinstance(raw_values, pd.Series): 836 | return pd.Series(normal_z_values, index = raw_values.index) 837 | else: 838 | return normal_z_values 839 | 840 | def multipletests_with_nulls(values, method = 'fdr_bh'): 841 | 842 | from statsmodels.stats.multitest import multipletests 843 | 844 | significance = np.zeros(len(values), dtype = bool) 845 | qvals = np.nan * np.empty(len(values)) 846 | mask = pd.notnull(values) 847 | 848 | if mask.any(): 849 | significance[np.array(mask)], qvals[np.array(mask)], _, _ = multipletests(values[mask], method = method) 850 | 851 | return significance, qvals 852 | 853 | def test_enrichment(mask1, mask2): 854 | 855 | from scipy.stats import fisher_exact 856 | 857 | assert len(mask1) == len(mask2) 858 | 859 | n1 = mask1.sum() 860 | n2 = mask2.sum() 861 | n_both = (mask1 & mask2).sum() 862 | n_total = len(mask1) 863 | n_expected = n1 * n2 / n_total 864 | enrichment_factor = n_both / n_expected 865 | 866 | contingency_table = np.array([ 867 | [(mask1 & mask2).sum(), (mask1 & (~mask2)).sum()], 868 | [((~mask1) & mask2).sum(), ((~mask1) & (~mask2)).sum()], 869 | ]) 870 | _, pval = fisher_exact(contingency_table) 871 | 872 | return n1, n2, n_both, n_total, n_expected, enrichment_factor, contingency_table, pval 873 | 874 | def test_enrichment_sets(set1, set2, n_total): 875 | 876 | from scipy.stats import fisher_exact 877 | 878 | n1 = len(set1) 879 | n2 = len(set2) 880 | n_both = len(set1 & set2) 881 | n_expected = n1 * n2 / n_total 882 | enrichment_factor = n_both / n_expected 883 | 884 | contingency_table = np.array([ 885 | [n_both, n1 - n_both], 886 | [n2 - n_both, n_total - n1 - n2 + n_both], 887 | ]) 888 | _, pval = fisher_exact(contingency_table) 889 | 890 | return n1, n2, n_both, n_total, n_expected, enrichment_factor, contingency_table, pval 891 | 892 | 893 | ### h5f ### 894 | 895 | def flush_h5_file(h5f): 896 | h5f.flush() 897 | os.fsync(h5f.id.get_vfd_handle()) 898 | 899 | def transpose_h5f_dataset(h5f, src_name, dst_name, max_memory_bytes): 900 | flush_func = lambda: flush_h5_file(h5f) 901 | src = h5f[src_name] 902 | nrows, ncols = src.shape[:2] 903 | dst = h5f.create_dataset(dst_name, shape = (ncols, nrows), dtype = src.dtype) 904 | transpose_dataset(src, dst, max_memory_bytes, flush_func) 905 | 906 | 907 | ### Matplotlib ### 908 | 909 | def draw_rectangle(ax, start_x, end_x, start_y, end_y, **kwargs): 910 | from matplotlib import patches 911 | ax.add_patch(patches.Rectangle((start_x, start_y), end_x - start_x, end_y - start_y, **kwargs)) 912 | 913 | def set_ax_border_color(ax, color): 914 | 915 | import matplotlib.pyplot as plt 916 | 917 | for child in ax.get_children(): 918 | if isinstance(child, plt.matplotlib.spines.Spine): 919 | child.set_color(color) 920 | 921 | def plot_prediction_scatter(y_pred, y_true, value = 'value'): 922 | 923 | import matplotlib.pyplot as plt 924 | 925 | log(pearsonr(y_pred, y_true)) 926 | log(spearmanr(y_pred, y_true)) 927 | 928 | fig, ax = 
plt.subplots(figsize = (10, 6)) 929 | ax.scatter(y_pred, y_true) 930 | ax.set_xlabel('Predicted %s' % value) 931 | ax.set_ylabel('Actual %s' % value) 932 | 933 | def draw_pvals_qq_plot(pvals, max_density = 100, min_pval = None, ax = None, figsize = (7, 7), scatter_options = {}, \ 934 | xlabel = 'Expected p-values (-log10)', ylabel = 'Observed p-values (-log10)'): 935 | 936 | import matplotlib.pyplot as plt 937 | 938 | if 'color' not in scatter_options: 939 | scatter_options['color'] = '#2e75b6' 940 | 941 | pvals = np.array(pvals) 942 | 943 | if min_pval is not None: 944 | pvals = np.maximum(pvals, min_pval) 945 | 946 | n_total_pvals = len(pvals) 947 | sorted_mlog_pvals = np.sort(-np.log10(pvals)) 948 | max_mlog_pval = sorted_mlog_pvals.max() 949 | 950 | if ax is None: 951 | _, ax = plt.subplots(figsize = figsize) 952 | 953 | ax.plot([0, max_mlog_pval], [0, max_mlog_pval], color = 'red', linestyle = '--', alpha = 0.5) 954 | ax.set_xlim((0, max_mlog_pval)) 955 | ax.set_ylim((0, max_mlog_pval)) 956 | ax.set_xlabel(xlabel) 957 | ax.set_ylabel(ylabel) 958 | 959 | for upper_limit in range(1, int(max_mlog_pval + 3)): 960 | 961 | n_remained_pvals = len(sorted_mlog_pvals) 962 | i = np.searchsorted(sorted_mlog_pvals, upper_limit) 963 | range_pvals = sorted_mlog_pvals[:i] 964 | sorted_mlog_pvals = sorted_mlog_pvals[i:] 965 | 966 | if len(range_pvals) > 0: 967 | 968 | if len(range_pvals) <= max_density: 969 | range_chosen_indices = np.arange(len(range_pvals)) 970 | else: 971 | # We want to choose the p-values uniformly in the space of their expected frequencies (i.e. sampling more towards the higher end of the 972 | # spectrum). 973 | range_min_mlog_freq = -np.log10(n_remained_pvals / n_total_pvals) 974 | range_max_mlog_freq = -np.log10((n_remained_pvals - len(range_pvals) + 1) / n_total_pvals) 975 | range_chosen_mlog_freqs = np.linspace(range_min_mlog_freq, range_max_mlog_freq, max_density) 976 | range_chosen_freqs = np.power(10, -range_chosen_mlog_freqs) 977 | # Once having the desired freqs, reverse the function to get the indices that provide them 978 | range_chosen_indices = np.unique((n_remained_pvals - n_total_pvals * range_chosen_freqs).astype(int)) 979 | 980 | range_pvals = range_pvals[range_chosen_indices] 981 | range_freqs = (n_remained_pvals - range_chosen_indices) / n_total_pvals 982 | range_mlog_freqs = -np.log10(range_freqs) 983 | ax.scatter(range_mlog_freqs, range_pvals, **scatter_options) 984 | 985 | def draw_manhattan_plot(gwas_results, significance_treshold = 5e-08, max_results_to_plot = 1e06, \ 986 | pval_threshold_to_force_inclusion = 1e-03, min_pval = 1e-300, ax = None, figsize = (12, 6), \ 987 | s = 1.5, chrom_to_color = None): 988 | 989 | ''' 990 | gwas_results (pd.DataFrame): Should have the following columns: 991 | - chromosome (str) 992 | - position (int) 993 | - pval (float) 994 | ''' 995 | 996 | import matplotlib.pyplot as plt 997 | 998 | CHROMS = list(map(str, range(1, 23))) + ['X', 'Y'] 999 | CHROM_TO_COLOR = {'1': '#0100fb', '2': '#ffff00', '3': '#00ff03', '4': '#bfbfbf', '5': '#acdae9', '6': '#a020f1', 1000 | '7': '#ffa502', '8': '#ff00fe', '9': '#fe0000', '10': '#90ee90', '11': '#a52929', '12': '#000000', 1001 | '13': '#ffbfcf', '14': '#4484b2', '15': '#b63063', '16': '#f8816f', '17': '#ed84f3', '18': '#006401', 1002 | '19': '#020184', '20': '#ced000', '21': '#cd0001', '22': '#050098', 'X': '#505050', 'Y': '#ff8000'} 1003 | 1004 | if chrom_to_color is None: 1005 | chrom_to_color = CHROM_TO_COLOR 1006 | 1007 | if len(gwas_results) > max_results_to_plot: 1008 | mask = 
pd.Series(random_mask(len(gwas_results), int(max_results_to_plot)), index = gwas_results.index) 1009 | mask[gwas_results['pval'] <= pval_threshold_to_force_inclusion] = True 1010 | gwas_results = gwas_results[mask] 1011 | 1012 | max_pos_per_chrom = gwas_results.groupby('chromosome')['position'].max() 1013 | accumulating_pos = 0 1014 | chrom_accumulating_positions = [] 1015 | 1016 | for chrom in CHROMS: 1017 | if chrom in max_pos_per_chrom.index: 1018 | chrom_accumulating_positions.append((chrom, accumulating_pos + 1, accumulating_pos + max_pos_per_chrom[chrom])) 1019 | accumulating_pos += max_pos_per_chrom[chrom] 1020 | 1021 | chrom_accumulating_positions = pd.DataFrame(chrom_accumulating_positions, columns = ['chromosome', \ 1022 | 'accumulating_start_position', 'accumulating_end_position']).set_index('chromosome', drop = True) 1023 | chrom_middle_accumulating_positions = (chrom_accumulating_positions['accumulating_start_position'] + \ 1024 | chrom_accumulating_positions['accumulating_end_position']) / 2 1025 | 1026 | if ax is None: 1027 | _, ax = plt.subplots(figsize = figsize) 1028 | 1029 | ax.set_facecolor('white') 1030 | plt.setp(ax.spines.values(), color = '#444444') 1031 | ax.grid(False) 1032 | 1033 | if significance_treshold is not None: 1034 | ax.axhline(y = -np.log10(significance_treshold), linestyle = '--', linewidth = 1, color = 'red') 1035 | 1036 | gwas_results_per_chrom = gwas_results.groupby('chromosome') 1037 | max_y = 0 1038 | 1039 | for chrom in chrom_accumulating_positions.index: 1040 | chrom_gwas_results = gwas_results_per_chrom.get_group(chrom) 1041 | chrom_gwas_accumulating_positions = chrom_accumulating_positions.loc[chrom, 'accumulating_start_position'] + \ 1042 | chrom_gwas_results['position'] 1043 | chrom_gwas_minus_log_pval = -np.log10(np.maximum(chrom_gwas_results['pval'], min_pval)) 1044 | max_y = max(max_y, chrom_gwas_minus_log_pval.max()) 1045 | ax.scatter(chrom_gwas_accumulating_positions, chrom_gwas_minus_log_pval, color = chrom_to_color[chrom], s = s) 1046 | 1047 | ax.set_xlabel('Chromosome') 1048 | ax.set_ylabel('-log10(p-value)') 1049 | ax.set_xticks(chrom_middle_accumulating_positions) 1050 | ax.set_xticklabels(chrom_middle_accumulating_positions.index) 1051 | ax.set_xlim(1, accumulating_pos) 1052 | ax.set_ylim(0, max_y + 1) 1053 | 1054 | return ax 1055 | 1056 | 1057 | ### Biopython Helper Functions ### 1058 | 1059 | def as_biopython_seq(seq): 1060 | 1061 | from Bio.Seq import Seq 1062 | 1063 | if isinstance(seq, Seq): 1064 | return seq 1065 | elif isinstance(seq, str): 1066 | return Seq(seq) 1067 | else: 1068 | raise Exception('Cannot resolve type %s as Biopython Seq' % type(seq)) 1069 | 1070 | 1071 | ### Slurm ### 1072 | 1073 | def get_slurm_job_array_ids(parse_total_tasks_by_max_variable = True, log_ids = True, verbose = True, task_index_remapping_json_file_path = None): 1074 | 1075 | job_id = int(os.getenv('SLURM_ARRAY_JOB_ID')) 1076 | task_index = int(os.getenv('SLURM_ARRAY_TASK_ID')) 1077 | 1078 | if 'TASK_ID_OFFSET' in os.environ: 1079 | 1080 | task_offset = int(os.getenv('TASK_ID_OFFSET')) 1081 | 1082 | if verbose: 1083 | log('Raw task index %d with offset %d.' % (task_index, task_offset)) 1084 | 1085 | task_index += task_offset 1086 | 1087 | if task_index_remapping_json_file_path is not None: 1088 | 1089 | with open(task_index_remapping_json_file_path, 'r') as f: 1090 | task_index_remapping = json.load(f) 1091 | 1092 | remapped_task_index = task_index_remapping[task_index] 1093 | 1094 | if verbose: 1095 | log('Remapped task index %d into %d.' 
### Biopython Helper Functions ###

def as_biopython_seq(seq):

    from Bio.Seq import Seq

    if isinstance(seq, Seq):
        return seq
    elif isinstance(seq, str):
        return Seq(seq)
    else:
        raise Exception('Cannot resolve type %s as Biopython Seq' % type(seq))


### Slurm ###

def get_slurm_job_array_ids(parse_total_tasks_by_max_variable = True, log_ids = True, verbose = True, task_index_remapping_json_file_path = None):

    job_id = int(os.getenv('SLURM_ARRAY_JOB_ID'))
    task_index = int(os.getenv('SLURM_ARRAY_TASK_ID'))

    if 'TASK_ID_OFFSET' in os.environ:

        task_offset = int(os.getenv('TASK_ID_OFFSET'))

        if verbose:
            log('Raw task index %d with offset %d.' % (task_index, task_offset))

        task_index += task_offset

    if task_index_remapping_json_file_path is not None:

        with open(task_index_remapping_json_file_path, 'r') as f:
            task_index_remapping = json.load(f)

        remapped_task_index = task_index_remapping[task_index]

        if verbose:
            log('Remapped task index %d into %d.' % (task_index, remapped_task_index))

        task_index = remapped_task_index

    if 'TOTAL_TASKS' in os.environ:
        total_tasks = int(os.getenv('TOTAL_TASKS'))
    elif parse_total_tasks_by_max_variable:
        total_tasks = int(os.getenv('SLURM_ARRAY_TASK_MAX')) + 1
    else:
        total_tasks = int(os.getenv('SLURM_ARRAY_TASK_COUNT'))

    if log_ids:
        log('Running job %s, task %d of %d.' % (job_id, task_index, total_tasks))

    return job_id, total_tasks, task_index


### Liftover ###

def liftover_locus(liftover, chrom, pos):
    try:

        pos = int(pos)

        if not isinstance(chrom, str) or not chrom.startswith('chr'):
            chrom = 'chr%s' % chrom

        (new_chrom, new_pos, _, _), = liftover.convert_coordinate(chrom, pos)

        if new_chrom.startswith('chr'):
            new_chrom = new_chrom[3:]

        return new_chrom, new_pos
    except:
        return np.nan, np.nan

def liftover_loci_in_df(df, chrom_column = 'chromosome', pos_column = 'position', source_ref_genome = 'hg38', \
        target_ref_genome = 'hg19'):

    from pyliftover import LiftOver

    liftover = LiftOver(source_ref_genome, target_ref_genome)
    new_loci = []

    for _, (chrom, pos) in df[[chrom_column, pos_column]].iterrows():
        new_loci.append(liftover_locus(liftover, chrom, pos))

    new_chroms, new_positions = (pd.Series(list(values), index = df.index) for values in zip(*new_loci))
    return pd.concat([new_chroms.rename(chrom_column) if column == chrom_column else (new_positions.rename(pos_column) if \
            column == pos_column else df[column]) for column in df.columns], axis = 1)
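A hypothetical sketch for the liftover helpers; it needs the optional pyliftover dependency (which fetches the required chain file if it is not available locally), the example loci are arbitrary, and the import path is again an assumption.

# Hypothetical usage sketch -- the loci below are arbitrary examples.
import pandas as pd
from proteinbert.shared_utils.util import liftover_loci_in_df

loci = pd.DataFrame({'chromosome': ['1', '17'], 'position': [1000000, 43044295]})
lifted = liftover_loci_in_df(loci, source_ref_genome = 'hg38', target_ref_genome = 'hg19')
print(lifted)    # same columns, with coordinates mapped to hg19 (NaNs where the liftover failed)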
### Helper classes ###

class DummyContext(object):

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

class TimeMeasure(object):

    def __init__(self, opening_statement):
        self.opening_statement = opening_statement

    def __enter__(self):
        self.start_time = datetime.now()
        log(self.opening_statement)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.finish_time = datetime.now()
        self.elapsed_time = self.finish_time - self.start_time
        log('Finished after %s.' % self.elapsed_time)

class Profiler(object):

    def __init__(self):
        self.creation_time = datetime.now()
        self.profiles = defaultdict(Profiler.Profile)

    def measure(self, profile_name):
        return self.profiles[profile_name].measure()

    def format(self, delimiter = '\n'):
        all_profiles = list(self.profiles.items()) + [('Total', Profiler.Profile(total_invokes = 1, total_time = datetime.now() - self.creation_time))]
        sorted_profiles = sorted(all_profiles, key = lambda profile_tuple: profile_tuple[1].total_time, reverse = True)
        return delimiter.join(['%s: %s' % (profile_name, profile) for profile_name, profile in sorted_profiles])

    def __repr__(self):
        return self.format()

    class Profile(object):

        def __init__(self, total_invokes = 0, total_time = timedelta(0)):
            self.total_invokes = total_invokes
            self.total_time = total_time

        def measure(self):
            return Profiler._Measurement(self)

        def __repr__(self):
            return '%s (%d times)' % (self.total_time, self.total_invokes)

    class _Measurement(object):

        def __init__(self, profile):
            self.profile = profile

        def __enter__(self):
            self.start_time = datetime.now()

        def __exit__(self, exc_type, exc_value, exc_traceback):
            self.profile.total_time += (datetime.now() - self.start_time)
            self.profile.total_invokes += 1

--------------------------------------------------------------------------------
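To illustrate how the timing helpers compose, a final hypothetical sketch; the profile names and sleep durations are made up, and the import path assumes these classes live in proteinbert.shared_utils.util.

# Hypothetical usage sketch for TimeMeasure and Profiler.
import time
from proteinbert.shared_utils.util import Profiler, TimeMeasure

profiler = Profiler()

with TimeMeasure('Running toy workload...'):
    for _ in range(3):
        with profiler.measure('sleep'):
            time.sleep(0.1)
        with profiler.measure('noop'):
            pass

print(profiler)    # per-profile totals plus a 'Total' entry, slowest first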