├── LICENSE ├── README.md ├── assets └── overview_trinityLLM.png ├── data └── README.md ├── environment.md ├── finetune ├── args.py ├── bert_vocab.txt ├── finetune_pubchem_light.py ├── finetune_pubchem_light_classification.py ├── finetune_pubchem_light_classification_multitask.py ├── fire_finetune_pubchem_light_relL2Loss_LRSchedule.py ├── fire_finetune_pubchem_light_relL2Loss_LRSchedule_classify.py ├── fire_molformer_predict.py ├── fire_predict_conductivity.sh ├── fire_predict_density.sh ├── fire_predict_qpua_pk50.sh ├── fire_predict_sea.sh ├── fire_predict_specific_heat.sh ├── fire_predict_tig50.sh ├── fire_run_finetune_conductivity.sh ├── fire_run_finetune_density.sh ├── fire_run_finetune_qpua_pk50.sh ├── fire_run_finetune_sea.sh ├── fire_run_finetune_specific_heat.sh ├── fire_run_finetune_tig50.sh ├── rotate_attention │ ├── attention_layer.py │ ├── rotary.py │ └── rotate_builder.py ├── tokenizer │ └── tokenizer.py └── utils.py ├── notebooks ├── full_attention_rotary │ ├── args.py │ ├── attention_analysis_rotary_full.ipynb │ ├── bert_vocab.txt │ ├── get_attention_map_full.py │ ├── rotate_attention │ │ ├── attention_layer.py │ │ ├── full_attention.py │ │ ├── linear_attention.py │ │ ├── rotary.py │ │ ├── rotate_builder.py │ │ └── transformers.py │ ├── tokenizer │ │ └── tokenizer.py │ └── utils.py ├── linear_attention_rotary │ ├── args.py │ ├── attention_analysis_rotary_linear.ipynb │ ├── bert_vocab.txt │ ├── get_attention_map_full.py │ ├── rotate_attention │ │ ├── attention_layer.py │ │ ├── full_attention.py │ │ ├── linear_attention.py │ │ ├── rotary.py │ │ ├── rotate_builder.py │ │ └── transformers.py │ ├── tokenizer │ │ └── tokenizer.py │ └── utils.py └── pretrained_molformer │ ├── args.py │ ├── bert_vocab.txt │ ├── frozen_embeddings_classification.ipynb │ ├── pubchem_canon_zinc_final_vocab_sorted.pth │ ├── pubchem_encoder.py │ ├── rotate_attention │ ├── attention_layer.py │ ├── rotary.py │ └── rotate_builder.py │ ├── tokenizer │ └── tokenizer.py │ └── train_pubchem_light.py └── training ├── Dockerfile ├── args.py ├── bert_vocab.txt ├── continue_pubchem_light_distributed_multinode_bert_both.sh ├── dataset_pubchem.py ├── environ.sh ├── pubchem_canon_script.py ├── pubchem_canon_zinc_final_vocab_sorted.pth ├── pubchem_encoder.py ├── pubchem_script.py ├── rotate_attention ├── attention_layer.py ├── rotary.py └── rotate_builder.py ├── run_pubchem_light.sh ├── tokenizer └── tokenizer.py ├── train_pubchem_light.py ├── utils.py ├── val_script.py └── zinc_script.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ning Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.14278416.svg)](https://doi.org/10.5281/zenodo.14278416) 2 | 3 | # Large language models, physics-based modeling, experimental measurements: the trinity of data-scarce learning of polymer properties 4 | 5 | This is the official implementation of TrinityLLM: 6 | 7 | Ning Liu, Siavash Jafarzadeh, Brian Y Lattimer, Shuna Ni, Jim Lua, Yue Yu. "[Large language models, physics-based modeling, experimental measurements: the trinity of data-scarce learning of polymer properties](https://arxiv.org/abs/2407.02770)". 8 | 9 | 10 | ![](assets/overview_trinityLLM.png) 11 | 12 | Large language models (LLMs) bear promise as a fast and accurate material modeling paradigm for evaluation, analysis, and design. Their vast number of trainable parameters necessitates a wealth of data to achieve accuracy and mitigate overfitting. However, experimental measurements are often limited and costly to obtain in sufficient quantities for finetuning. To this end, we present a physics-based training pipeline that tackles the pathology of data scarcity. The core enabler is a physics-based modeling framework that generates a multitude of synthetic data to align the LLM to a physically consistent initial state before finetuning. Our framework features a two-phase training strategy: (1) utilizing the large-in-amount while less accurate synthetic data for supervised pretraining, and (2) finetuning the phase-1 model with limited experimental data. We empirically demonstrate that supervised pretraining is vital to obtaining accurate finetuned LLMs, via the lens of learning polymer flammability metrics where cone calorimeter data is sparse. 13 | 14 | ## Requirements 15 | 16 | We use the apex.optimizers package in our code, so Apex needs to be compiled from source. Step-by-step instructions are provided in [environment.md](environment.md). 17 | 18 | 19 | ## Pretraining Datasets 20 | 21 | The pretraining datasets are available via the [MoLFormer repo](https://github.com/IBM/molformer) and can be downloaded [here](https://ibm.box.com/v/MoLFormer-data). We briefly reproduce the instructions below. 22 | 23 | Because the combined PubChem and ZINC datasets are large (over 1.1 billion molecules in total), the code expects the data to be in a specific location and format. The processing details are documented below for each individual dataset. 24 | 25 | The code expects both the ZINC15 (ZINC) and PubChem datasets to be located in the ```./data/``` directory of the training directory. 26 | * ZINC15 itself should be located in ```data/ZINC/``` and is expected to be split into multiple .smi files, each containing one SMILES string per line. 27 | * PubChem should be located in ```data/pubchem/``` and is expected to be processed as a single “CID-SMILES” text file with 2 columns (index and SMILES string). We take the raw PubChem dataset, convert every SMILES string into its canonical form using RDKit, and trim down the file itself. Our dataloader expects PubChem in this converted form and will not run on the raw PubChem file; a minimal sketch of the canonicalization step is shown below.
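The canonicalization described above can be sketched in a few lines of RDKit (a minimal illustration of the idea only, not the repository's exact preprocessing pipeline; the helper name is ours):

```
from rdkit import Chem

def canonicalize(smiles):
    """Return the RDKit canonical SMILES, or None if the string cannot be parsed."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True)

print(canonicalize("C1=CC=CC=C1"))  # benzene -> 'c1ccccc1'
```

The expected layout of the ```./data/``` directory is: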
28 | 29 | ``` 30 | data/ 31 | ├── pubchem 32 | │   └── CID-SMILES-CANONICAL.smi 33 | └── ZINC 34 | ├── AAAA.smi 35 | ├── AAAB.smi 36 | ├── AAAC.smi 37 | ├── AAAD.smi 38 | ├── AABA.smi 39 | ├── AABB.smi 40 | ├── AABD.smi 41 | ├── AACA.smi 42 | ├── AACB.smi 43 | ├── AAEA.smi 44 | ├── AAEB.smi 45 | ├── AAED.smi 46 | ├── ABAA.smi 47 | ├── ABAB.smi 48 | ├── ABAC.smi 49 | ├── ABAD.smi 50 | ├── ABBA.smi 51 | ├── ABBB.smi 52 | ├── ABBD.smi 53 | ├── ABCA.smi 54 | ├── ABCB.smi 55 | ├── ABCD.smi 56 | ├── ABEA.smi 57 | ├── ABEB.smi 58 | ├── ABEC.smi 59 | ├── ABED.smi 60 | ├── ACAA.smi 61 | ├── ACAB.smi 62 | ``` 63 | 64 | #### Pretrained Models and training logs 65 | The checkpoints of MoLFormer encoders pre-trained on a dataset of ~100M molecules are available for [download](https://ibm.box.com/v/MoLFormer-data). This dataset combines 10% of the ZINC and 10% of the PubChem molecules used for MoLFormer-XL training. 66 | 67 | Extract `Pretrained MoLFormer.zip`, containing the pretrained models and the associated training logs, to the `./data/` directory. 68 | The hierarchy looks like the following: 69 | 70 | ``` 71 | data/ 72 | ├── Pretrained MoLFormer 73 | │ ├── checkpoints 74 | │ │ ├── N-Step-Checkpoint_0_0.ckpt 75 | │ │ ├── N-Step-Checkpoint_0_5000.ckpt 76 | │ │ ├── N-Step-Checkpoint_1_10000.ckpt 77 | │ │ ├── N-Step-Checkpoint_1_15000.ckpt 78 | │ │ ├── N-Step-Checkpoint_2_20000.ckpt 79 | │ │ ├── N-Step-Checkpoint_3_25000.ckpt 80 | │ │ └── N-Step-Checkpoint_3_30000.ckpt 81 | │ ├── events.out.tfevents.1643396916.cccxc543.3427421.0 82 | │ └── hparams.yaml 83 | ``` 84 | 85 | ## Pretraining 86 | To train a model from scratch, the masked-language-modeling objective is used. 87 | 88 | Note that the TrinityLLM framework can be applied to any polymer LLM; we use MoLFormer as the example in the following demonstration. MoLFormer is pre-trained on canonicalized SMILES of >1B molecules from ZINC and PubChem with the following constraints: 89 | 90 | During pre-processing, the compounds are filtered to keep a maximum length of 211 characters. A 100/0/0 split was used for training, validation, and test, i.e., all of the data was used for training the model. As a confidence test, the model is evaluated at the end of each epoch on additional independent data. Data canonicalization is performed using RDKit. 91 | 92 | The pre-training code provides an example of data processing and of training a model on a smaller pre-training dataset, which requires 16 V100 GPUs. 93 | 94 | To train a model, run: 95 | 96 | > bash run_pubchem_light.sh 97 | 98 | ## Finetuning 99 | 100 | The finetuning datasets need to be placed in the corresponding folders as specified in the bash scripts below. Once you have the environment set up, you can run a finetuning task with 101 | 102 | > bash fire_run_finetune_density.sh 103 | 104 | Finetuning training/checkpointing resources will be available in the directory named ```checkpoint_<measure_name>```. The path to the results csv will be of the form ```./checkpoint_<measure_name>/<measure_name>/results/results_<measure_name>.csv```. The ```results_<measure_name>.csv``` file contains 4 columns of data: column 1 holds the validation score for each epoch, column 2 holds the test score for each epoch, column 3 holds the best validation score observed up to that point of finetuning, and column 4 holds the test score of the epoch with the best validation score. 105 | 106 | To perform the two-phase finetuning discussed in the paper, first pretrain the LLM on physics-based synthetic data (phase 1), then load the phase-1 model and finetune it further on the limited experimental data (phase 2); a sketch of the phase-2 call is given below.
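As a concrete illustration, phase 2 amounts to re-running a finetuning script with `--seed_path` pointed at the phase-1 checkpoint rather than the original MoLFormer weights. The sketch below mirrors `fire_run_finetune_density.sh`; the phase-1 checkpoint path and the experimental-data folder are hypothetical placeholders for your own run:

```
# Phase 2 (sketch): finetune on experimental data, warm-starting from the phase-1
# model that was trained on physics-based synthetic data.
# Both paths below are placeholders -- substitute the checkpoint written by your
# phase-1 run and the folder holding your experimental dataset.
python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \
        --device cuda \
        --batch_size 50 \
        --n_head 12 \
        --n_layer 12 \
        --n_embd 768 \
        --d_dropout 0.1 \
        --dropout 0.1 \
        --lr_start 3e-5 \
        --num_workers 8 \
        --max_epochs 500 \
        --num_feats 32 \
        --checkpoint_every 100 \
        --seed_path './checkpoints_density_phase1/<phase1_checkpoint>.ckpt' \
        --dataset_name density \
        --data_root '../data/density_experimental' \
        --measure_name density \
        --dims 768 768 768 1 \
        --checkpoints_folder './checkpoints_density_phase2'
```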
107 | 108 | ## Running trained TrinityLLM models 109 | 110 | The trained models in the paper are available for download [here](https://drive.google.com/drive/folders/1l_GlVGW49iqu1Q8QpcJxEveaomjy2Z8a?usp=sharing). 111 | 112 | To load a trained model and directly make predictions (note that the trained checkpoints need to be placed in the corresponding folder as specified in the bash script), run: 113 | 114 | > bash fire_predict_density.sh 115 | 116 | ## Citation 117 | 118 | If you find this repository useful, please consider citing our paper: 119 | ```latex 120 | @article{liu2025harnessing, 121 | title={Harnessing large language models for data-scarce learning of polymer properties}, 122 | author={Liu, Ning and Jafarzadeh, Siavash and Lattimer, Brian Y and Ni, Shuna and Lua, Jim and Yu, Yue}, 123 | journal={Nature Computational Science}, 124 | pages={1--10}, 125 | year={2025}, 126 | publisher={Nature Publishing Group US New York} 127 | } 128 | ``` 129 | 130 | ## Acknowledgments 131 | This repository is constructed using the codebase from [MoLFormer](https://github.com/IBM/molformer). 132 | -------------------------------------------------------------------------------- /assets/overview_trinityLLM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ningliu-iga/TrinityLLM/5659e6e7d7dfb748a338a88f1da1e3f85a2676f7/assets/overview_trinityLLM.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | All datasets and models can be downloaded from [https://ibm.box.com/v/MoLFormer-data](https://ibm.box.com/v/MoLFormer-data) and should be placed in this directory according to the instructions in the main [README.md](../README.md). 2 | -------------------------------------------------------------------------------- /environment.md: -------------------------------------------------------------------------------- 1 | # Python Environment 2 | 3 | ## Conda Create and Activate Environment 4 | 5 | ``` 6 | conda create --name MolTran_CUDA11 python=3.8.10 7 | conda activate MolTran_CUDA11 8 | ``` 9 | 10 | ## Conda Install Packages 11 | ``` 12 | conda install pytorch==1.7.1 cudatoolkit=11.0 -c pytorch 13 | conda install numpy=1.22.3 pandas=1.2.4 scikit-learn=0.24.2 scipy=1.6.2 14 | conda install rdkit==2022.03.2 -c conda-forge 15 | ``` 16 | 17 | ## Pip install Packages 18 | ``` 19 | pip install transformers==4.6.0 pytorch-lightning==1.1.5 pytorch-fast-transformers==0.4.0 datasets==1.6.2 jupyterlab==3.4.0 ipywidgets==7.7.0 bertviz==1.4.0 20 | ``` 21 | 22 | ## Compile Apex from source 23 | 24 | Due to the use of [Apex Optimizers](https://nvidia.github.io/apex/optimizers.html), Apex must be compiled with CUDA and C++ extensions via: 25 | 26 | 27 | ``` 28 | git clone https://github.com/NVIDIA/apex 29 | cd apex 30 | git checkout tags/22.03 -b v22.03 31 | export CUDA_HOME=<path to your CUDA 11 installation> 32 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 33 | ``` 34 | --------------------------------------------------------------------------------
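A quick, hedged sanity check that the Apex build in [environment.md](environment.md) above succeeded (the extension-module name `amp_C` is our assumption about what the `--cuda_ext` build produces; adjust if your Apex version names it differently):

```
python -c "from apex.optimizers import FusedLAMB; print('apex.optimizers imports OK')"
python -c "import amp_C; print('Apex CUDA extensions found')"   # amp_C name assumed
```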
/finetune/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(parser=None): 5 | if parser is None: 6 | parser = argparse.ArgumentParser() 7 | 8 | # Model 9 | #model_arg = parser.add_argument_group('Model') 10 | parser.add_argument('--n_head', 11 | type=int, default=8, 12 | help='GPT number of heads') 13 | parser.add_argument('--fold', 14 | type=int, default=0, 15 | help='number of folds for fine tuning') 16 | parser.add_argument('--n_layer', 17 | type=int, default=12, 18 | help='GPT number of layers') 19 | parser.add_argument('--d_dropout', 20 | type=float, default=0.1, 21 | help='Decoder layers dropout') 22 | parser.add_argument('--n_embd', 23 | type=int, default=768, 24 | help='Latent vector dimensionality') 25 | parser.add_argument('--fc_h', 26 | type=int, default=512, 27 | help='Fully connected hidden dimensionality') 28 | 29 | 30 | # Train 31 | #train_arg = parser.add_argument_group('Train') 32 | parser.add_argument('--n_batch', 33 | type=int, default=512, 34 | help='Batch size') 35 | parser.add_argument('--from_scratch', 36 | action='store_true', default=False, 37 | help='train on qm9 from scratch') 38 | parser.add_argument('--checkpoint_every', 39 | type=int, default=1000, 40 | help='save checkpoint every x iterations') 41 | parser.add_argument('--lr_start', 42 | type=float, default=3 * 1e-4, 43 | help='Initial lr value') 44 | parser.add_argument('--lr_multiplier', 45 | type=int, default=1, 46 | help='lr weight multiplier') 47 | parser.add_argument('--n_jobs', 48 | type=int, default=1, 49 | help='Number of threads') 50 | parser.add_argument('--device', 51 | type=str, default='cuda', 52 | help='Device to run: "cpu" or "cuda:"') 53 | parser.add_argument('--seed', 54 | type=int, default=12345, 55 | help='Seed') 56 | 57 | parser.add_argument('--seed_path', 58 | type=str, default="", 59 | help='path to trainer file to continue training') 60 | 61 | parser.add_argument('--num_feats', 62 | type=int, required=False, default=32, 63 | help='number of random reatures for FAVOR+') 64 | parser.add_argument('--max_epochs', 65 | type=int, required=False, default=1, 66 | help='max number of epochs') 67 | 68 | # debug() FINE TUNEING 69 | # parser.add_argument('--save_dir', type=str, required=True) 70 | parser.add_argument('--mode', 71 | type=str, default='avg', 72 | help='type of pooling to use') 73 | parser.add_argument("--train_dataset_length", type=int, default=None, required=False) 74 | parser.add_argument("--eval_dataset_length", type=int, default=None, required=False) 75 | parser.add_argument("--desc_skip_connection", type=bool, default=False, required=False) 76 | parser.add_argument("--num_workers", type=int, default=8, required=False) 77 | parser.add_argument("--dropout", type=float, default=0.1, required=False) 78 | parser.add_argument("--dims", type=int, nargs="*", default="[]", required=False) 79 | parser.add_argument( 80 | "--smiles_embedding", 81 | type=str, 82 | default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt", 83 | ) 84 | # parser.add_argument("--train_pct", type=str, required=False, default="95") 85 | parser.add_argument("--aug", type=int, required=False) 86 | parser.add_argument("--num_classes", type=int, required=False) 87 | parser.add_argument("--dataset_name", type=str, required=False, default="sol") 88 | parser.add_argument("--measure_name", type=str, required=False, default="measure") 89 | 
parser.add_argument("--checkpoints_folder", type=str, required=True) 90 | parser.add_argument("--checkpoint_root", type=str, required=False) 91 | 92 | parser.add_argument( 93 | "--data_root", 94 | type=str, 95 | required=False, 96 | default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity", 97 | ) 98 | parser.add_argument("--batch_size", type=int, default=64) 99 | 100 | return parser 101 | def parse_args(): 102 | parser = get_parser() 103 | args = parser.parse_args() 104 | return args 105 | 106 | -------------------------------------------------------------------------------- /finetune/fire_predict_conductivity.sh: -------------------------------------------------------------------------------- 1 | python -W ignore fire_molformer_predict.py \ 2 | --device cuda \ 3 | --batch_size 1 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --num_feats 32 \ 12 | --dataset_name conductivity \ 13 | --data_root ../data/conductivity \ 14 | --measure_name conductivity \ 15 | --dims 768 768 768 1 \ 16 | --checkpoints_folder './checkpoints_conductivity'\ 17 | -------------------------------------------------------------------------------- /finetune/fire_predict_density.sh: -------------------------------------------------------------------------------- 1 | python -W ignore fire_molformer_predict.py \ 2 | --device cuda \ 3 | --batch_size 1 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --num_feats 32 \ 12 | --dataset_name density \ 13 | --data_root ../data/density \ 14 | --measure_name density \ 15 | --dims 768 768 768 1 \ 16 | --checkpoints_folder './checkpoints_density'\ 17 | -------------------------------------------------------------------------------- /finetune/fire_predict_qpua_pk50.sh: -------------------------------------------------------------------------------- 1 | python -W ignore fire_molformer_predict.py \ 2 | --device cuda \ 3 | --batch_size 1 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --num_feats 32 \ 12 | --dataset_name qpua_pk50 \ 13 | --data_root ../data/qpua_pk50 \ 14 | --measure_name qpua_pk50 \ 15 | --dims 768 768 768 1 \ 16 | --checkpoints_folder './checkpoints_qpua_pk50'\ 17 | -------------------------------------------------------------------------------- /finetune/fire_predict_sea.sh: -------------------------------------------------------------------------------- 1 | python -W ignore fire_molformer_predict.py \ 2 | --device cuda \ 3 | --batch_size 1 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --num_feats 32 \ 12 | --dataset_name sea \ 13 | --data_root ../data/sea \ 14 | --measure_name sea \ 15 | --dims 768 768 768 1 \ 16 | --checkpoints_folder './checkpoints_sea'\ 17 | -------------------------------------------------------------------------------- /finetune/fire_predict_specific_heat.sh: -------------------------------------------------------------------------------- 1 | python -W ignore fire_molformer_predict.py \ 2 | --device cuda \ 3 | --batch_size 1 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --num_feats 32 \ 12 | --dataset_name specific_heat \ 13 
| --data_root ../data/specific_heat \ 14 | --measure_name specific_heat \ 15 | --dims 768 768 768 1 \ 16 | --checkpoints_folder './checkpoints_specific_heat'\ 17 | -------------------------------------------------------------------------------- /finetune/fire_predict_tig50.sh: -------------------------------------------------------------------------------- 1 | python -W ignore fire_molformer_predict.py \ 2 | --device cuda \ 3 | --batch_size 1 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --num_feats 32 \ 12 | --dataset_name tig50 \ 13 | --data_root ../data/tig50 \ 14 | --measure_name tig50 \ 15 | --dims 768 768 768 1 \ 16 | --checkpoints_folder './checkpoints_tig50'\ 17 | -------------------------------------------------------------------------------- /finetune/fire_run_finetune_conductivity.sh: -------------------------------------------------------------------------------- 1 | python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \ 2 | --device cuda \ 3 | --batch_size 10 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --max_epochs 500 \ 12 | --num_feats 32 \ 13 | --checkpoint_every 100 \ 14 | --seed_path '../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt' \ 15 | --dataset_name conductivity \ 16 | --data_root ../data/conductivity \ 17 | --measure_name conductivity \ 18 | --dims 768 768 768 1 \ 19 | --checkpoints_folder './checkpoints_conductivity'\ 20 | -------------------------------------------------------------------------------- /finetune/fire_run_finetune_density.sh: -------------------------------------------------------------------------------- 1 | python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \ 2 | --device cuda \ 3 | --batch_size 50 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --max_epochs 500 \ 12 | --num_feats 32 \ 13 | --checkpoint_every 100 \ 14 | --seed_path '../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt' \ 15 | --dataset_name density \ 16 | --data_root ../data/density \ 17 | --measure_name density \ 18 | --dims 768 768 768 1 \ 19 | --checkpoints_folder './checkpoints_density'\ 20 | -------------------------------------------------------------------------------- /finetune/fire_run_finetune_qpua_pk50.sh: -------------------------------------------------------------------------------- 1 | python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \ 2 | --device cuda \ 3 | --batch_size 3 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --max_epochs 500 \ 12 | --num_feats 32 \ 13 | --checkpoint_every 100 \ 14 | --seed_path '../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt' \ 15 | --dataset_name qpua_pk50 \ 16 | --data_root ../data/qpua_pk50 \ 17 | --measure_name qpua_pk50 \ 18 | --dims 768 768 768 1 \ 19 | --checkpoints_folder './checkpoints_qpua_pk50'\ 20 | -------------------------------------------------------------------------------- /finetune/fire_run_finetune_sea.sh: -------------------------------------------------------------------------------- 1 | python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \ 2 | --device cuda \ 3 | --batch_size 2 \ 4 | --n_head 12 \ 5 | --n_layer 12 
\ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --max_epochs 500 \ 12 | --num_feats 32 \ 13 | --checkpoint_every 100 \ 14 | --seed_path '../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt' \ 15 | --dataset_name sea \ 16 | --data_root ../data/sea \ 17 | --measure_name sea \ 18 | --dims 768 768 768 1 \ 19 | --checkpoints_folder './checkpoints_sea'\ 20 | -------------------------------------------------------------------------------- /finetune/fire_run_finetune_specific_heat.sh: -------------------------------------------------------------------------------- 1 | python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \ 2 | --device cuda \ 3 | --batch_size 20 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --max_epochs 500 \ 12 | --num_feats 32 \ 13 | --checkpoint_every 100 \ 14 | --seed_path '../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt' \ 15 | --dataset_name specific_heat \ 16 | --data_root ../data/specific_heat \ 17 | --measure_name specific_heat \ 18 | --dims 768 768 768 1 \ 19 | --checkpoints_folder './checkpoints_specific_heat'\ 20 | -------------------------------------------------------------------------------- /finetune/fire_run_finetune_tig50.sh: -------------------------------------------------------------------------------- 1 | python fire_finetune_pubchem_light_relL2Loss_LRSchedule.py \ 2 | --device cuda \ 3 | --batch_size 3 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --d_dropout 0.1 \ 8 | --dropout 0.1 \ 9 | --lr_start 3e-5 \ 10 | --num_workers 8\ 11 | --max_epochs 500 \ 12 | --num_feats 32 \ 13 | --checkpoint_every 100 \ 14 | --seed_path '../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt' \ 15 | --dataset_name tig50 \ 16 | --data_root ../data/tig50 \ 17 | --measure_name tig50 \ 18 | --dims 768 768 768 1 \ 19 | --checkpoints_folder './checkpoints_tig50'\ 20 | -------------------------------------------------------------------------------- /finetune/rotate_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | """The rotate attention layer performs all the query key value projections and 2 | output projections leaving the implementation of the attention to the inner 3 | attention module. 4 | """ 5 | 6 | from torch.nn import Linear, Module 7 | 8 | from fast_transformers.attention import AttentionLayer 9 | from fast_transformers.events import EventDispatcher, QKVEvent 10 | from .rotary import RotaryEmbedding, apply_rotary_pos_emb 11 | 12 | class RotateAttentionLayer(AttentionLayer): 13 | """Rotate attention layer inherits from fast_transformer attention layer. 
14 | The only thing added is an Embedding encoding, for more information 15 | on the attention layer see the fast_transformers code 16 | """ 17 | def __init__(self, attention, d_model, n_heads, d_keys=None, 18 | d_values=None, event_dispatcher=""): 19 | super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, 20 | d_values=d_values, event_dispatcher=event_dispatcher) 21 | 22 | self.rotaryemb = RotaryEmbedding(d_keys) 23 | print('Using Rotation Embedding') 24 | 25 | def forward(self, queries, keys, values, attn_mask, query_lengths, 26 | key_lengths): 27 | """ 28 | Using the same frame work as the fast_Transformers attention layer 29 | but injecting rotary information to the queries and the keys 30 | after the keys and queries are projected. 31 | In the argument description we make use of the following sizes 32 | 33 | - N: the batch size 34 | - L: The maximum length of the queries 35 | - S: The maximum length of the keys (the actual length per sequence 36 | is given by the length mask) 37 | - D: The input feature dimensionality passed in the constructor as 38 | 'd_model' 39 | 40 | Arguments 41 | --------- 42 | queries: (N, L, D) The tensor containing the queries 43 | keys: (N, S, D) The tensor containing the keys 44 | values: (N, S, D) The tensor containing the values 45 | attn_mask: An implementation of BaseMask that encodes where each 46 | query can attend to 47 | query_lengths: An implementation of BaseMask that encodes how 48 | many queries each sequence in the batch consists of 49 | key_lengths: An implementation of BaseMask that encodes how 50 | many queries each sequence in the batch consists of 51 | 52 | Returns 53 | ------- 54 | The new value for each query as a tensor of shape (N, L, D). 55 | """ 56 | # Extract the dimensions into local variables 57 | N, L, _ = queries.shape 58 | _, S, _ = keys.shape 59 | H = self.n_heads 60 | 61 | # Project the queries/keys/values 62 | queries = self.query_projection(queries).view(N, L, H, -1) 63 | keys = self.key_projection(keys).view(N, S, H, -1) 64 | cos, sin = self.rotaryemb(queries) 65 | queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) 66 | values = self.value_projection(values).view(N, S, H, -1) 67 | # Let the world know of the qkv 68 | self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) 69 | 70 | 71 | # Compute the attention 72 | new_values = self.inner_attention( 73 | queries, 74 | keys, 75 | values, 76 | attn_mask, 77 | query_lengths, 78 | key_lengths 79 | ).view(N, L, -1) 80 | 81 | # Project the output and return 82 | return self.out_projection(new_values) 83 | -------------------------------------------------------------------------------- /finetune/rotate_attention/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RotaryEmbedding(torch.nn.Module): 4 | 5 | def __init__(self, dim, base=10000): 6 | super().__init__() 7 | inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) 8 | self.register_buffer('inv_freq', inv_freq) 9 | self.seq_len_cached = 0 10 | self.cos_cached = None 11 | self.sin_cached = None 12 | 13 | def forward(self, x, seq_dim=1): 14 | seq_len = x.shape[seq_dim] 15 | if seq_len != self.seq_len_cached: 16 | #if seq_len > self.seq_len_cached: 17 | self.seq_len_cached = seq_len 18 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 21 | self.cos_cached = emb.cos()[None,:, None, :] 22 | self.sin_cached = emb.sin()[None,:, None, :] 23 | #else: 24 | # cos_return = self.cos_cached[..., :seq_len] 25 | # sin_return = self.sin_cached[..., :seq_len] 26 | # return cos_return, sin_return 27 | 28 | return self.cos_cached, self.sin_cached 29 | 30 | 31 | # rotary pos emb helpers: 32 | 33 | def rotate_half(x): 34 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 35 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 36 | 37 | @torch.jit.script 38 | def apply_rotary_pos_emb(q, k, cos, sin): 39 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 40 | -------------------------------------------------------------------------------- /finetune/rotate_attention/rotate_builder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import LayerNorm 2 | 3 | from .attention_layer import RotateAttentionLayer 4 | from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer 5 | from fast_transformers.builders.base import BaseBuilder 6 | from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder 7 | from fast_transformers.builders.attention_builders import AttentionBuilder 8 | 9 | 10 | class RotateEncoderBuilder(BaseTransformerEncoderBuilder): 11 | """Build a batch transformer encoder with Relative Rotary embeddings 12 | for training or processing of sequences all elements at a time. 
13 | 14 | Example usage: 15 | 16 | builder = RotateEncoderBuilder() 17 | builder.n_layers = 12 18 | builder.n_heads = 8 19 | builder.feed_forward_dimensions = 1024 20 | builder.query_dimensions = 64 21 | builder.value_dimensions = 64 22 | builder.dropout = 0.1 23 | builder.attention_dropout = 0.1 24 | builder.attention_type = "linear" 25 | transformer = builder.get() 26 | """ 27 | def _get_attention_builder(self): 28 | """Return an instance of the appropriate attention builder.""" 29 | return AttentionBuilder() 30 | 31 | def _get_attention_layer_class(self): 32 | """Return the class for the layer that projects queries keys and 33 | values.""" 34 | return RotateAttentionLayer 35 | 36 | def _get_encoder_class(self): 37 | """Return the class for the transformer encoder.""" 38 | return TransformerEncoder 39 | 40 | def _get_encoder_layer_class(self): 41 | """Return the class for the transformer encoder layer.""" 42 | return TransformerEncoderLayer 43 | -------------------------------------------------------------------------------- /finetune/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | import regex as re 3 | 4 | PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 5 | 6 | 7 | class MolTranBertTokenizer(BertTokenizer): 8 | def __init__(self, vocab_file: str = '', 9 | do_lower_case=False, 10 | unk_token='', 11 | sep_token='', 12 | pad_token='', 13 | cls_token='', 14 | mask_token='', 15 | **kwargs): 16 | super().__init__(vocab_file, 17 | unk_token=unk_token, 18 | sep_token=sep_token, 19 | pad_token=pad_token, 20 | cls_token=cls_token, 21 | mask_token=mask_token, 22 | **kwargs) 23 | 24 | self.regex_tokenizer = re.compile(PATTERN) 25 | self.wordpiece_tokenizer = None 26 | self.basic_tokenizer = None 27 | 28 | def _tokenize(self, text): 29 | split_tokens = self.regex_tokenizer.findall(text) 30 | return split_tokens 31 | 32 | def convert_tokens_to_string(self, tokens): 33 | out_string = "".join(tokens).strip() 34 | return out_string 35 | -------------------------------------------------------------------------------- /finetune/utils.py: -------------------------------------------------------------------------------- 1 | import debugpy 2 | import socket 3 | import glob 4 | import pandas as pd 5 | from typing import List 6 | from rdkit import Chem 7 | import torch 8 | 9 | 10 | def getipaddress(): 11 | return socket.gethostbyname(socket.getfqdn()) 12 | 13 | 14 | def debug(): 15 | print("Waiting for debugger to connect") 16 | if ( 17 | socket.getfqdn().startswith("dcc") 18 | or socket.getfqdn().startswith("mol") 19 | or socket.getfqdn().startswith("ccc") 20 | ): 21 | debugpy.listen(address=(getipaddress(), 3000)) 22 | debugpy.wait_for_client() 23 | debugpy.breakpoint() 24 | 25 | 26 | class ListDataset: 27 | def __init__(self, seqs): 28 | self.seqs = seqs 29 | 30 | def __getitem__(self, index): 31 | return self.seqs[index] 32 | 33 | def __len__(self): 34 | return len(self.seqs) 35 | 36 | 37 | def transform_single_embedding_to_multiple(smiles_z_map): 38 | """Transforms an embedding map of the format smi->embedding to 39 | smi-> {"canonical_embeddings":embedding}. 
This function exists 40 | as a compatibility layer 41 | 42 | Args: 43 | smiles_z_map ([type]): [description] 44 | """ 45 | retval = dict() 46 | for key in smiles_z_map: 47 | retval[key] = {"canonical_embeddings": smiles_z_map[key]} 48 | return retval 49 | 50 | 51 | def normalize_smiles(smi, canonical, isomeric): 52 | normalized = Chem.MolToSmiles( 53 | Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric 54 | ) 55 | return normalized 56 | 57 | 58 | def get_all_proteins(affinity_dir: str): 59 | files = glob.glob(affinity_dir + "/*.csv") 60 | all_proteins = [] 61 | print(files) 62 | for file in files: 63 | df = pd.read_csv(file) 64 | all_proteins.extend(df["protein"].tolist()) 65 | return set(all_proteins) 66 | 67 | 68 | def append_to_file(filename, line): 69 | with open(filename, "a") as f: 70 | f.write(line + "\n") 71 | 72 | 73 | def write_to_file(filename, line): 74 | with open(filename, "w") as f: 75 | f.write(line + "\n") 76 | 77 | 78 | class LpLoss(object): 79 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 80 | super(LpLoss, self).__init__() 81 | 82 | # Dimension and Lp-norm type are postive 83 | assert d > 0 and p > 0 84 | 85 | self.d = d 86 | self.p = p 87 | self.reduction = reduction 88 | self.size_average = size_average 89 | 90 | def abs(self, x, y): 91 | num_examples = x.size()[0] 92 | 93 | # Assume uniform mesh 94 | h = 1.0 / (x.size()[1] - 1.0) 95 | 96 | all_norms = (h ** (self.d / self.p)) * torch.norm(x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 97 | 1) 98 | 99 | if self.reduction: 100 | if self.size_average: 101 | return torch.mean(all_norms) 102 | else: 103 | return torch.sum(all_norms) 104 | 105 | return all_norms 106 | 107 | def rel(self, x, y): 108 | num_examples = x.size()[0] 109 | 110 | diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) 111 | y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) 112 | 113 | if self.reduction: 114 | if self.size_average: 115 | return torch.mean(diff_norms / y_norms) 116 | else: 117 | return torch.sum(diff_norms / y_norms) 118 | 119 | return diff_norms / y_norms 120 | 121 | def __call__(self, x, y): 122 | return self.rel(x, y) -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(parser=None): 5 | if parser is None: 6 | parser = argparse.ArgumentParser() 7 | 8 | # Model 9 | # model_arg = parser.add_argument_group('Model') 10 | parser.add_argument("--n_head", type=int, default=12, help="GPT number of heads") 11 | parser.add_argument( 12 | "--fold", type=int, default=0, help="number of folds for fine tuning" 13 | ) 14 | parser.add_argument("--n_layer", type=int, default=12, help="GPT number of layers") 15 | parser.add_argument( 16 | "--q_dropout", type=float, default=0.5, help="Encoder layers dropout" 17 | ) 18 | parser.add_argument( 19 | "--d_dropout", type=float, default=0.1, help="Decoder layers dropout" 20 | ) 21 | parser.add_argument( 22 | "--n_embd", type=int, default=768, help="Latent vector dimensionality" 23 | ) 24 | parser.add_argument( 25 | "--fc_h", type=int, default=512, help="Fully connected hidden dimensionality" 26 | ) 27 | 28 | # Train 29 | # train_arg = parser.add_argument_group('Train') 30 | parser.add_argument("--n_batch", type=int, default=512, help="Batch size") 31 | parser.add_argument( 32 | "--unlike_alpha", type=float, 
default=1.0, help="unlikelihood loss alpha weight" 33 | ) 34 | parser.add_argument( 35 | "--from_scratch", 36 | action="store_true", 37 | default=False, 38 | help="train on qm9 from scratch", 39 | ) 40 | parser.add_argument( 41 | "--unlikelihood", 42 | action="store_true", 43 | default=False, 44 | help="use unlikelihood loss with gpt pretrain", 45 | ) 46 | parser.add_argument( 47 | "--grad_acc", 48 | type=int, 49 | default=1, 50 | help="number of batches to accumulate gradients", 51 | ) 52 | parser.add_argument( 53 | "--checkpoint_every", 54 | type=int, 55 | default=1000, 56 | help="save checkpoint every x iterations", 57 | ) 58 | parser.add_argument( 59 | "--clip_grad", type=int, default=50, help="Clip gradients to this value" 60 | ) 61 | parser.add_argument( 62 | "--lr_start", type=float, default=3 * 1e-4, help="Initial lr value" 63 | ) 64 | parser.add_argument( 65 | "--lr_end", type=float, default=3 * 1e-4, help="Maximum lr weight value" 66 | ) 67 | parser.add_argument( 68 | "--lr_multiplier", type=int, default=1, help="lr weight multiplier" 69 | ) 70 | parser.add_argument( 71 | "--n_last", type=int, default=1000, help="Number of iters to smooth loss calc" 72 | ) 73 | parser.add_argument("--n_jobs", type=int, default=1, help="Number of threads") 74 | """parser.add_argument('--molecule', 75 | type=str, default='C=O=C', 76 | help='input molecule string from command line')""" 77 | parser.add_argument( 78 | "--accelerator", 79 | type=str, 80 | default="ddp", 81 | help="The accelerator backend to use (previously known as distributed_backend)", 82 | ) 83 | parser.add_argument( 84 | "--num_nodes", 85 | type=int, 86 | default=1, 87 | help="number of GPU nodes for distributed training", 88 | ) 89 | parser.add_argument( 90 | "--device", 91 | type=str, 92 | default="cuda", 93 | help='Device to run: "cpu" or "cuda:"', 94 | ) 95 | parser.add_argument("--seed", type=int, default=12345, help="Seed") 96 | 97 | # common_arg = parser.add_argument_group('Common') 98 | parser.add_argument( 99 | "--vocab_load", type=str, required=False, help="Where to load the vocab" 100 | ) 101 | parser.add_argument( 102 | "--n_samples", type=int, required=False, help="Number of samples to sample" 103 | ) 104 | parser.add_argument( 105 | "--gen_save", type=str, required=False, help="Where to save the gen molecules" 106 | ) 107 | parser.add_argument( 108 | "--max_len", type=int, default=202, help="Max of length of SMILES" 109 | ) 110 | parser.add_argument( 111 | "--train_load", type=str, required=False, help="Where to load the model" 112 | ) 113 | parser.add_argument( 114 | "--val_load", type=str, required=False, help="Where to load the model" 115 | ) 116 | parser.add_argument( 117 | "--n_workers", 118 | type=int, 119 | required=False, 120 | default=8, 121 | help="Where to load the model", 122 | ) 123 | # beam search hyper parameters 124 | parser.add_argument( 125 | "--beam_size", type=int, default=0, help="Number of beams to generate" 126 | ) 127 | parser.add_argument( 128 | "--num_seq_returned", 129 | type=int, 130 | default=0, 131 | help="number of beams to be returned (must be <= beam_size", 132 | ) 133 | parser.add_argument( 134 | "--min_len", type=int, default=1, help="minimum length to be generated" 135 | ) 136 | parser.add_argument( 137 | "--nucleus_thresh", type=float, default=0.9, help="nucleus sampling threshold" 138 | ) 139 | parser.add_argument( 140 | "--seed_path", 141 | type=str, 142 | default="../../data/checkpoints/full_rotary.ckpt", 143 | help="path to trainer file to continue training", 144 | ) 145 | 
parser.add_argument( 146 | "--data_path", type=str, default="", help="path to pubchem file" 147 | ) 148 | parser.add_argument( 149 | "--pretext_size", type=int, default=0, help="number of k-mers to pretext" 150 | ) 151 | parser.add_argument( 152 | "--model_save_dir", 153 | type=str, 154 | required=False, 155 | default="./models_dump/", 156 | help="Where to save the models/log/config/vocab", 157 | ) 158 | parser.add_argument( 159 | "--model_save", 160 | type=str, 161 | required=False, 162 | default="model.pt", 163 | help="Where to save the model", 164 | ) 165 | # parser.add_argument('--save_frequency', 166 | # type=int, default=20, 167 | # help='How often to save the model') 168 | parser.add_argument( 169 | "--num_epoch", type=int, default=1, help="number of epochs to train" 170 | ) 171 | # parser.add_argument('--num_iter', 172 | # type=int, default=-1, 173 | # help='how many itersations per epoch (for unlikelihood tuning)') 174 | parser.add_argument( 175 | "--log_file", type=str, required=False, help="Where to save the log" 176 | ) 177 | parser.add_argument( 178 | "--tb_loc", 179 | type=str, 180 | required=False, 181 | help="Where to save the tensorflow location", 182 | ) 183 | parser.add_argument( 184 | "--config_save", type=str, required=False, help="Where to save the config" 185 | ) 186 | parser.add_argument("--vocab_save", type=str, help="Where to save the vocab") 187 | 188 | # resume_arg = parser.add_argument_group('Resume') 189 | parser.add_argument("--mask", default=False, action="store_true", help="mask") 190 | parser.add_argument( 191 | "--canonical", default=False, action="store_true", help="convert to canonical" 192 | ) 193 | parser.add_argument( 194 | "--debug", 195 | default=False, 196 | action="store_true", 197 | help="do not erase cache at end of program", 198 | ) 199 | parser.add_argument( 200 | "--fast_dev_run", 201 | default=False, 202 | help="This flag runs a “unit test” by running n if set to n (int) else 1 if set to True training and validation batch(es).", 203 | ) 204 | parser.add_argument( 205 | "--freeze_model", 206 | default=False, 207 | action="store_true", 208 | help="freeze weights of bert model during fine tuning", 209 | ) 210 | parser.add_argument( 211 | "--resume", default=False, action="store_true", help="Resume from a saved model" 212 | ) 213 | parser.add_argument( 214 | "--rotate", 215 | default=True, 216 | action="store_true", 217 | help="use rotational relative embedding", 218 | ) 219 | parser.add_argument( 220 | "--model_load", type=str, required=False, help="Where to load the model" 221 | ) 222 | parser.add_argument( 223 | "--root_dir", type=str, required=False, default=".", help="location of root dir" 224 | ) 225 | parser.add_argument( 226 | "--config_load", type=str, required=False, help="Where to load the config" 227 | ) 228 | parser.add_argument( 229 | "--gpus", type=int, required=False, default=1, help="number of gpus to use" 230 | ) 231 | # parser.add_argument('--start_epoch', 232 | # type=int, required=False, default=0, 233 | # help='Where to load the config') 234 | 235 | parser.add_argument( 236 | "--model_arch", 237 | type=str, 238 | required=False, 239 | help="used to teack model arch in params", 240 | ) 241 | parser.add_argument( 242 | "--eval_every", 243 | type=int, 244 | default=50000, 245 | help="run evaluation every x iterations", 246 | ) 247 | parser.add_argument( 248 | "--num_feats", 249 | type=int, 250 | required=False, 251 | default=32, 252 | help="number of random reatures for FAVOR+", 253 | ) 254 | parser.add_argument( 255 | 
"--max_epochs", type=int, required=False, default=3, help="max number of epochs" 256 | ) 257 | 258 | # debug() FINE TUNEING 259 | # parser.add_argument('--save_dir', type=str, required=True) 260 | parser.add_argument( 261 | "--mode", type=str, default="avg", help="type of pooling to use" 262 | ) 263 | parser.add_argument( 264 | "--train_dataset_length", type=int, default=None, required=False 265 | ) 266 | parser.add_argument("--eval_dataset_length", type=int, default=None, required=False) 267 | parser.add_argument( 268 | "--desc_skip_connection", type=bool, default=False, required=False 269 | ) 270 | parser.add_argument("--num_workers", type=int, default=2, required=False) 271 | parser.add_argument("--dropout", type=float, default=0.1, required=False) 272 | parser.add_argument( 273 | "--smiles_embedding", 274 | type=str, 275 | default="../../data/etc/ba_embeddings.pt", 276 | ) 277 | # parser.add_argument("--train_pct", type=str, required=False, default="95") 278 | parser.add_argument("--aug", type=int, required=False) 279 | parser.add_argument("--num_classes", type=int, required=False) 280 | parser.add_argument("--dataset_name", type=str, required=False, default="sol") 281 | parser.add_argument("--measure_name", type=str, required=False, default="measure") 282 | parser.add_argument("--checkpoints_folder", type=str, required=False) 283 | parser.add_argument("--emb_type", type=str, required=False) 284 | parser.add_argument("--checkpoint_root", type=str, required=False) 285 | 286 | parser.add_argument( 287 | "--data_root", 288 | type=str, 289 | required=False, 290 | default="../../data/etc/ba-predictor/small-data/affinity", 291 | ) 292 | parser.add_argument("--use_linear", type=int, default=0) 293 | 294 | parser.add_argument("--lr", type=float, default=0.001) 295 | parser.add_argument("--batch_size", type=int, default=64) 296 | 297 | return parser 298 | 299 | 300 | def parse_args(): 301 | parser = get_parser() 302 | args, _ = parser.parse_known_args() 303 | return args 304 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/get_attention_map_full.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch.nn.functional as F 3 | 4 | from tokenizer.tokenizer import MolTranBertTokenizer 5 | from utils import normalize_smiles 6 | import torch 7 | import shutil 8 | from torch import nn 9 | import args 10 | import os 11 | import getpass 12 | from datasets import load_dataset, concatenate_datasets, load_from_disk 13 | 14 | from fast_transformers.builders import TransformerEncoderBuilder 15 | from fast_transformers.masking import FullMask, LengthMask as LM 16 | from rotate_attention.rotate_builder import RotateEncoderBuilder as rotate_builder 17 | from fast_transformers.feature_maps import Favor, GeneralizedRandomFeatures 18 | from functools import partial 19 | import rotate_attention.full_attention 20 | 21 | from torch.utils.data import DataLoader 22 | 23 | 24 | class TestBert(nn.Module): 25 | def __init__( 26 | self, vocab, model_path=None, extend_pos=False, rotate=False, device="cpu" 27 | ): 28 | if model_path == None: 29 | assert False 30 | super().__init__() 31 | model = torch.load(model_path, map_location=torch.device(device)) 32 | config = Namespace(**model["hyper_parameters"]) 33 | config.rotate = rotate 34 | ( 35 | self.tok_emb, 36 | self.pos_emb, 37 | self.blocks, 38 | self.drop, 39 | self.lang_model, 40 | ) = self.create_model(config, vocab) 41 | 42 | 
self.drop = nn.Dropout(config.d_dropout) 43 | 44 | def create_model(self, config, vocab): 45 | 46 | n_vocab, d_emb = len(vocab.keys()), config.n_embd 47 | block_size = 250 48 | if config.rotate: 49 | builder = rotate_builder.from_kwargs( 50 | n_layers=config.n_layer, 51 | n_heads=config.n_head, 52 | query_dimensions=config.n_embd // config.n_head, 53 | value_dimensions=config.n_embd // config.n_head, 54 | feed_forward_dimensions=config.n_embd, 55 | attention_type="fullwweights", 56 | #feature_map=partial( 57 | # GeneralizedRandomFeatures, 58 | # n_dims=config.num_feats, 59 | # deterministic_eval=True, 60 | #), 61 | activation="gelu", 62 | ) 63 | pos_emb = None 64 | else: 65 | builder = TransformerEncoderBuilder.from_kwargs( 66 | n_layers=config.n_layer, 67 | n_heads=config.n_head, 68 | query_dimensions=config.n_embd // config.n_head, 69 | value_dimensions=config.n_embd // config.n_head, 70 | feed_forward_dimensions=config.n_embd, 71 | attention_type="fullwweights", 72 | activation="gelu", 73 | ) 74 | pos_emb = nn.Parameter(torch.zeros(1, block_size, config.n_embd)) 75 | tok_emb = nn.Embedding(n_vocab, config.n_embd) 76 | drop = nn.Dropout(config.d_dropout) 77 | 78 | blocks = builder.get() 79 | lang_model = lm_layer(config.n_embd, n_vocab) 80 | train_config = config 81 | block_size = block_size 82 | 83 | return tok_emb, pos_emb, blocks, drop, lang_model 84 | 85 | def forward(self, batch, mask=None, mode="cls"): 86 | b, t = batch.size() 87 | 88 | # forward the GPT model 89 | token_embeddings = self.tok_emb( 90 | batch 91 | ) # each index maps to a (learnable) vector 92 | if self.pos_emb != None: 93 | position_embeddings = self.pos_emb[ 94 | :, :t, : 95 | ] # each position maps to a (learnable) vector 96 | x = self.drop(token_embeddings + position_embeddings) 97 | else: 98 | x = self.drop(token_embeddings) 99 | 100 | if mask != None: 101 | x, attention_mask = self.blocks(x, length_mask=LM(mask._mask.sum(-1))) 102 | 103 | else: 104 | x, attention_mask = self.blocks(x) 105 | 106 | if mode == "cls": 107 | return x[:, 0, :], attention_mask 108 | elif mode == "max": 109 | token_embeddings = x 110 | input_mask_expanded = ( 111 | mask._mask.unsqueeze(-1).expand(token_embeddings.size()).float() 112 | ) 113 | token_embeddings[ 114 | input_mask_expanded == 0 115 | ] = -1e9 # Set padding tokens to large negative value 116 | max_over_time = torch.max(token_embeddings, 1)[0] 117 | return max_over_time, attention_mask 118 | elif mode == "avg": 119 | 120 | token_embeddings = x 121 | input_mask_expanded = ( 122 | mask._mask.unsqueeze(-1).expand(token_embeddings.size()).float() 123 | ) 124 | sum_embeddings = torch.sum( 125 | token_embeddings * input_mask_expanded, 1 126 | ) # sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 127 | sum_mask = input_mask_expanded.sum(1) 128 | return sum_embeddings / sum_mask, attention_mask 129 | 130 | 131 | class lm_layer(nn.Module): 132 | def __init__(self, n_embd, n_vocab): 133 | super().__init__() 134 | self.embed = nn.Linear(n_embd, n_embd) 135 | self.ln_f = nn.LayerNorm(n_embd) 136 | self.head = nn.Linear(n_embd, n_vocab, bias=False) 137 | 138 | def forward(self, tensor): 139 | tensor = self.embed(tensor) 140 | tensor = F.gelu(tensor) 141 | tensor = self.ln_f(tensor) 142 | tensor = self.head(tensor) 143 | return tensor 144 | 145 | 146 | def get_database(config): 147 | pubchem_path = { 148 | "train": "/dccstor/trustedgen/data/pubchem/CID-SMILES-CANONICAL.smi" 149 | } 150 | if "CANONICAL" in pubchem_path: 151 | pubchem_script = "./pubchem_canon_script.py" 152 | 
else: 153 | pubchem_script = "./pubchem_script.py" 154 | dataset_dict = load_dataset( 155 | pubchem_script, 156 | data_files=pubchem_path, 157 | cache_dir=os.path.join( 158 | "/tmp", getpass.getuser(), "pubchem_{}".format(config.chunk_num) 159 | ), 160 | split="train", 161 | ) 162 | train_config = { 163 | "batch_size": config.n_batch, 164 | "shuffle": False, 165 | "num_workers": config.n_workers, 166 | "pin_memory": True, 167 | } 168 | # loader = DataLoader(dataset_dict, **train_config) 169 | print(dataset_dict.cache_files) 170 | cache_files = [] 171 | for cache in dataset_dict.cache_files: 172 | tmp = "/".join(cache["filename"].split("/")[:4]) 173 | print(tmp) 174 | cache_files.append(tmp) 175 | 176 | print("dataset length {}".format(len(dataset_dict))) 177 | if 50000 * config.chunk_num > len(dataset_dict): 178 | index_end = 0 179 | loader = None 180 | elif 50000 + 50000 * config.chunk_num > len(dataset_dict): 181 | index_end = 50000 + 50000 * config.chunk_num - len(dataset_dict) 182 | index = [i + (50000 * config.chunk_num) for i in range(index_end)] 183 | loader = torch.utils.data.Subset(dataset_dict, index) 184 | loader = DataLoader(loader, **train_config) 185 | else: 186 | index_end = 50000 187 | index = [i + (50000 * config.chunk_num) for i in range(index_end)] 188 | loader = torch.utils.data.Subset(dataset_dict, index) 189 | loader = DataLoader(loader, **train_config) 190 | # index= [i+(50000*config.chunk_num) for i in range(index_end)]) 191 | return loader, cache_files 192 | 193 | 194 | def get_bert(config, tokenizer): 195 | bert_model = ( 196 | TestBert( 197 | tokenizer.vocab, 198 | config.seed_path, 199 | rotate=config.rotate, 200 | device=config.device, 201 | ) 202 | .to(config.device) 203 | .eval() 204 | ) 205 | tmp_model = torch.load(config.seed_path)["state_dict"] 206 | bert_model.load_state_dict(tmp_model, strict=True) 207 | return bert_model 208 | 209 | 210 | def remove_tree(cachefiles): 211 | if type(cachefiles) == type([]): 212 | cachefiles = list(set(cachefiles)) 213 | for cache in cachefiles: 214 | shutil.rmtree(cache) 215 | else: 216 | shutil.rmtree(cachefiles) 217 | 218 | 219 | def get_tokens_from_ids(input_ids, tokenizer): 220 | tokens = [] 221 | 222 | for idx_lst in input_ids: 223 | seq = [] 224 | for idx in idx_lst: 225 | seq.append(tokenizer.ids_to_tokens[idx]) 226 | tokens.append(seq) 227 | return tokens 228 | 229 | 230 | def get_full_attention(molecule): 231 | config = args.parse_args() 232 | 233 | model_path = config.seed_path 234 | device = config.device 235 | batch_size = config.batch_size 236 | canonical = config.canonical 237 | mode = config.mode 238 | mask = config.mask 239 | 240 | loader = None 241 | tokenizer = MolTranBertTokenizer("bert_vocab.txt") 242 | bert_model = get_bert(config, tokenizer) 243 | 244 | batch_total = 0 245 | if loader is not None: 246 | 247 | for batch_number, mols in enumerate(loader): 248 | batch_to_save = [] 249 | with torch.no_grad(): 250 | 251 | # print(batch_number) 252 | if config.canonical is True: 253 | output = [ 254 | normalize_smiles(smiles, canonical=True, isomeric=False) 255 | for smiles in mols["text"] 256 | if smiles is not None 257 | ] 258 | else: 259 | output = mols["text"] 260 | batch_ids = tokenizer.batch_encode_plus( 261 | output, 262 | padding=True, 263 | add_special_tokens=True, 264 | return_attention_mask=True, 265 | return_length=True, 266 | ) 267 | 268 | if config.mask is True: 269 | att_mask = FullMask( 270 | torch.tensor(batch_ids["attention_mask"], dtype=bool).to( 271 | device 272 | ), 273 | 
device=device, 274 | ) 275 | else: 276 | att_mask = FullMask( 277 | torch.ones( 278 | torch.tensor(batch_ids["input_ids"]).size(), dtype=bool 279 | ).to(device), 280 | device=device, 281 | ) 282 | 283 | embeddings, attention_mask = bert_model( 284 | torch.tensor(batch_ids["input_ids"]).to(device), 285 | att_mask, 286 | mode=config.mode, 287 | ) 288 | 289 | for number, mol in enumerate(output): 290 | batch_to_save.append((embeddings[number].cpu().numpy(), mol)) 291 | 292 | # if len(batch_to_save) >= 500: 293 | batch_name = "batch_num_{}.pth".format( 294 | batch_number + (50000 * config.chunk_num) 295 | ) 296 | chunk_name = "chunk_num_{}".format(config.chunk_num) 297 | if batch_number % 250 == 0: 298 | print(batch_name) 299 | torch.save( 300 | batch_to_save[0], 301 | os.path.join("./embedding_dump_deterministic", chunk_name, batch_name), 302 | ) 303 | 304 | else: 305 | with torch.no_grad(): 306 | 307 | if config.canonical is True: 308 | output = [normalize_smiles(molecule, canonical=True, isomeric=False)] 309 | else: 310 | output = molecule 311 | 312 | batch_ids = tokenizer.batch_encode_plus( 313 | [output], 314 | padding=True, 315 | add_special_tokens=True, 316 | return_attention_mask=True, 317 | return_length=True, 318 | ) 319 | 320 | raw_tokens = get_tokens_from_ids(batch_ids["input_ids"], tokenizer)[0] 321 | 322 | if config.mask is True: 323 | att_mask = FullMask( 324 | torch.tensor(batch_ids["attention_mask"], dtype=bool).to(device), 325 | device=device, 326 | ) 327 | else: 328 | att_mask = FullMask( 329 | torch.ones( 330 | torch.tensor(batch_ids["input_ids"]).size(), dtype=bool 331 | ).to(device), 332 | device=device, 333 | ) 334 | 335 | embeddings, attention_mask = bert_model( 336 | torch.tensor(batch_ids["input_ids"]).to(device), 337 | att_mask, 338 | mode=config.mode, 339 | ) 340 | return attention_mask, raw_tokens 341 | 342 | if loader != None: 343 | remove_tree(cache_files) 344 | 345 | 346 | if __name__ == "__main__": 347 | attentions = get_full_attention() 348 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/rotate_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | """The rotate attention layer performs all the query key value projections and 2 | output projections leaving the implementation of the attention to the inner 3 | attention module. 4 | """ 5 | 6 | from torch.nn import Linear, Module 7 | 8 | from fast_transformers.attention import AttentionLayer 9 | from fast_transformers.events import EventDispatcher, QKVEvent 10 | from .rotary import RotaryEmbedding, apply_rotary_pos_emb 11 | 12 | class RotateAttentionLayer(AttentionLayer): 13 | """Rotate attention layer inherits from fast_transformer attention layer. 
14 | The only thing added is an Embedding encoding, for more information 15 | on the attention layer see the fast_transformers code 16 | """ 17 | def __init__(self, attention, d_model, n_heads, d_keys=None, 18 | d_values=None, event_dispatcher=""): 19 | super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, 20 | d_values=d_values, event_dispatcher=event_dispatcher) 21 | 22 | self.rotaryemb = RotaryEmbedding(d_keys) 23 | print('Using Rotation Embedding') 24 | 25 | def forward(self, queries, keys, values, attn_mask, query_lengths, 26 | key_lengths): 27 | """ 28 | Using the same frame work as the fast_Transformers attention layer 29 | but injecting rotary information to the queries and the keys 30 | after the keys and queries are projected. 31 | In the argument description we make use of the following sizes 32 | 33 | - N: the batch size 34 | - L: The maximum length of the queries 35 | - S: The maximum length of the keys (the actual length per sequence 36 | is given by the length mask) 37 | - D: The input feature dimensionality passed in the constructor as 38 | 'd_model' 39 | 40 | Arguments 41 | --------- 42 | queries: (N, L, D) The tensor containing the queries 43 | keys: (N, S, D) The tensor containing the keys 44 | values: (N, S, D) The tensor containing the values 45 | attn_mask: An implementation of BaseMask that encodes where each 46 | query can attend to 47 | query_lengths: An implementation of BaseMask that encodes how 48 | many queries each sequence in the batch consists of 49 | key_lengths: An implementation of BaseMask that encodes how 50 | many queries each sequence in the batch consists of 51 | 52 | Returns 53 | ------- 54 | The new value for each query as a tensor of shape (N, L, D). 55 | """ 56 | # Extract the dimensions into local variables 57 | N, L, _ = queries.shape 58 | _, S, _ = keys.shape 59 | H = self.n_heads 60 | 61 | # Project the queries/keys/values 62 | queries = self.query_projection(queries).view(N, L, H, -1) 63 | keys = self.key_projection(keys).view(N, S, H, -1) 64 | cos, sin = self.rotaryemb(queries) 65 | queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) 66 | values = self.value_projection(values).view(N, S, H, -1) 67 | # Let the world know of the qkv 68 | self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) 69 | 70 | 71 | # Compute the attention 72 | new_values, attention_weights = self.inner_attention( 73 | queries, 74 | keys, 75 | values, 76 | attn_mask, 77 | query_lengths, 78 | key_lengths 79 | ) 80 | 81 | new_values = new_values.view(N, L, -1) 82 | 83 | # Project the output and return 84 | return self.out_projection(new_values), attention_weights 85 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/rotate_attention/full_attention.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch.nn import Dropout, Module 5 | 6 | from fast_transformers.attention_registry import AttentionRegistry, Optional, Float, \ 7 | EventDispatcherInstance 8 | from fast_transformers.events import EventDispatcher, AttentionEvent 9 | from fast_transformers.attention.full_attention import FullAttention 10 | 11 | class FullWWeight(Module): 12 | """ 13 | Slightly modify the fast transformers Full attention to return 14 | the 'attention weights' for visual analysis 15 | """ 16 | def __init__(self, softmax_temp=None, attention_dropout=0.1, event_dispatcher=""): 17 | 
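        # Shape conventions for this module: queries arrive as (N, L, H, E) and values as
        # (N, S, H, D); softmax_temp defaults to 1/sqrt(E) in forward(). forward() returns the
        # contextualized values of shape (N, L, H, D) together with the detached softmax
        # weights of shape (N, H, L, S), intended for per-head visualization in the
        # attention-analysis notebooks.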
super(FullWWeight, self).__init__() 18 | 19 | self.softmax_temp = softmax_temp 20 | self.dropout = Dropout(attention_dropout) 21 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 22 | 23 | def forward(self, queries, keys, values, attn_mask, query_lengths, 24 | key_lengths): 25 | # Extract some shapes and compute the temperature 26 | N, L, H, E = queries.shape 27 | _, S, _, D = values.shape 28 | softmax_temp = self.softmax_temp or 1./sqrt(E) 29 | 30 | # Scale the queries instead of applying the softmax temperature to the 31 | # dot products 32 | queries = queries * softmax_temp 33 | 34 | # Compute the unnormalized attention and apply the masks 35 | QK = torch.einsum("nlhe,nshe->nhls", queries, keys) 36 | if not attn_mask.all_ones: 37 | QK = QK + attn_mask.additive_matrix 38 | if not key_lengths.all_ones: 39 | QK = QK + key_lengths.additive_matrix[:, None, None] 40 | 41 | # Compute the attention and the weighted average 42 | attention_weights = torch.softmax(QK, dim=-1) 43 | #A = self.dropout(torch.softmax(QK, dim=-1)) 44 | A = self.dropout(attention_weights) 45 | V = torch.einsum("nhls,nshd->nlhd", A, values) 46 | 47 | # Let the world know of the attention matrix 48 | self.event_dispatcher.dispatch(AttentionEvent(self, A)) 49 | 50 | # Make sure that what we return is contiguous 51 | return V.contiguous(), attention_weights.detach() 52 | 53 | 54 | # Register the attention implementation so that it becomes available in our 55 | # builders 56 | AttentionRegistry.register( 57 | "fullwweights", FullWWeight, 58 | [ 59 | ("softmax_temp", Optional(Float)), 60 | ("attention_dropout", Optional(Float, 0.1)), 61 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 62 | ] 63 | ) 64 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/rotate_attention/linear_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from fast_transformers.attention_registry import AttentionRegistry, Optional, Callable, Int, \ 5 | EventDispatcherInstance 6 | from fast_transformers.attention.linear_attention import LinearAttention 7 | from fast_transformers.events import EventDispatcher 8 | 9 | 10 | class LinearWWeight(Module): 11 | """ 12 | Slightly modify the fast transformers linear attention to return 13 | the 'attention weights' for visual analysis 14 | """ 15 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6, 16 | event_dispatcher=""): 17 | super(LinearWWeight, self).__init__() 18 | 19 | self.feature_map = (feature_map(query_dimensions)) 20 | self.eps = eps 21 | self.event_dispatcher=EventDispatcher.get(event_dispatcher) 22 | def forward(self, queries, keys, values, attn_mask, query_lengths, 23 | key_lengths): 24 | # Apply the feature map to the queries and keys 25 | self.feature_map.new_feature_map(queries.device) 26 | Q = self.feature_map.forward_queries(queries) 27 | K = self.feature_map.forward_keys(keys) 28 | 29 | # Apply the key padding mask and make sure that the attn_mask is 30 | # all_ones 31 | if not attn_mask.all_ones: 32 | raise RuntimeError(("LinearAttention does not support arbitrary " 33 | "attention masks")) 34 | K = K * key_lengths.float_matrix[:, :, None, None] 35 | 36 | # Compute the KV matrix, namely the dot product of keys and values so 37 | # that we never explicitly compute the attention matrix and thus 38 | # decrease the complexity 39 | KV = torch.einsum("nshd,nshm->nhmd", K, values) 40 | 41 | # 
Compute the normalizer 42 | Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps) 43 | 44 | attention = torch.einsum('nlhd, nshd->nlsh', Q, K) 45 | #try both postive and negative values 46 | 47 | 48 | attention_norm = 1/(torch.einsum('nlsh->nlh', attention+self.eps)) 49 | attention_out = torch.einsum('nlsh, nlh->nlsh', attention, attention_norm) 50 | attention_out = torch.einsum('nlsh->nhls', attention_out) 51 | #remove negative numbers with relu 52 | 53 | # Finally compute and return the new values 54 | V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) 55 | 56 | return V.contiguous(), attention_out.detach() 57 | 58 | 59 | # Register the attention implementation so that it becomes available in our 60 | # builders 61 | AttentionRegistry.register( 62 | "linearwweights", LinearWWeight, 63 | [ 64 | ("query_dimensions", Int), 65 | ("feature_map", Optional(Callable)), 66 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 67 | ] 68 | ) 69 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/rotate_attention/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RotaryEmbedding(torch.nn.Module): 4 | 5 | def __init__(self, dim, base=10000): 6 | super().__init__() 7 | inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) 8 | self.register_buffer('inv_freq', inv_freq) 9 | self.seq_len_cached = 0 10 | self.cos_cached = None 11 | self.sin_cached = None 12 | 13 | def forward(self, x, seq_dim=1): 14 | seq_len = x.shape[seq_dim] 15 | if seq_len != self.seq_len_cached: 16 | #if seq_len > self.seq_len_cached: 17 | self.seq_len_cached = seq_len 18 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 21 | self.cos_cached = emb.cos()[None,:, None, :] 22 | self.sin_cached = emb.sin()[None,:, None, :] 23 | #else: 24 | # cos_return = self.cos_cached[..., :seq_len] 25 | # sin_return = self.sin_cached[..., :seq_len] 26 | # return cos_return, sin_return 27 | 28 | return self.cos_cached, self.sin_cached 29 | 30 | 31 | # rotary pos emb helpers: 32 | 33 | def rotate_half(x): 34 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 35 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 36 | 37 | @torch.jit.script 38 | def apply_rotary_pos_emb(q, k, cos, sin): 39 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 40 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/rotate_attention/rotate_builder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import LayerNorm 2 | 3 | from .transformers import VizEncoderLayer, VizEncoder 4 | from .attention_layer import RotateAttentionLayer 5 | from.linear_attention import LinearWWeight 6 | from .full_attention import FullWWeight 7 | 8 | from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer 9 | from fast_transformers.builders.base import BaseBuilder 10 | from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder 11 | from fast_transformers.builders.attention_builders import AttentionBuilder 12 | 13 | 14 | class RotateEncoderBuilder(BaseTransformerEncoderBuilder): 15 | """Build a batch transformer encoder with Relative Rotary 
embeddings 16 | for training or processing of sequences all elements at a time. 17 | 18 | Example usage: 19 | 20 | builder = RotateEncoderBuilder() 21 | builder.n_layers = 12 22 | builder.n_heads = 8 23 | builder.feed_forward_dimensions = 1024 24 | builder.query_dimensions = 64 25 | builder.value_dimensions = 64 26 | builder.dropout = 0.1 27 | builder.attention_dropout = 0.1 28 | builder.attention_type = "linear" 29 | transformer = builder.get() 30 | """ 31 | def _get_attention_builder(self): 32 | """Return an instance of the appropriate attention builder.""" 33 | return AttentionBuilder() 34 | 35 | def _get_attention_layer_class(self): 36 | """Return the class for the layer that projects queries keys and 37 | values.""" 38 | return RotateAttentionLayer 39 | 40 | def _get_encoder_class(self): 41 | """Return the class for the transformer encoder.""" 42 | return VizEncoder 43 | 44 | def _get_encoder_layer_class(self): 45 | """Return the class for the transformer encoder layer.""" 46 | return VizEncoderLayer 47 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/rotate_attention/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList 3 | import torch.nn.functional as F 4 | 5 | from fast_transformers.events import EventDispatcher 6 | from fast_transformers.masking import FullMask, LengthMask 7 | from fast_transformers.transformers import TransformerEncoderLayer, TransformerEncoder 8 | 9 | 10 | class VizEncoderLayer(TransformerEncoderLayer): 11 | """Self attention and feed forward network with skip connections. 12 | 13 | This transformer encoder layer implements a modification of the Fast 14 | Transformer layer code where the attention weights of the model are 15 | returned for vizualization. 16 | 17 | Arguments 18 | --------- 19 | These are unchanged from the Fast Transformer encoder layer. For 20 | further information look there. 21 | """ 22 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, 23 | activation="relu", event_dispatcher=""): 24 | super(VizEncoderLayer, self).__init__(attention, d_model, d_ff=d_ff, dropout=dropout, 25 | activation=activation, event_dispatcher=event_dispatcher) 26 | 27 | def forward(self, x, attn_mask=None, length_mask=None): 28 | """Apply the transformer encoder to the input x. 29 | 30 | Arguments 31 | --------- 32 | x: The input features of shape (N, L, E) where N is the batch size, 33 | L is the sequence length (padded) and E is d_model passed in the 34 | constructor. 35 | attn_mask: An implementation of fast_transformers.masking.BaseMask 36 | that encodes where each element of x can attend to. 37 | length_mask: An implementation of 38 | fast_transformers.masking.BaseMask that encodes how 39 | many elements each sequence in the batch consists of. 
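        Returns
        -------
        The transformed features of shape (N, L, E) together with the detached
        attention weights produced by this layer's attention module.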
40 | """ 41 | # Normalize the masks 42 | N = x.shape[0] 43 | L = x.shape[1] 44 | attn_mask = attn_mask or FullMask(L, device=x.device) 45 | length_mask = length_mask or \ 46 | LengthMask(x.new_full((N,), L, dtype=torch.int64)) 47 | 48 | out, attention_mask = self.attention( 49 | x, x, x, 50 | attn_mask=attn_mask, 51 | query_lengths=length_mask, 52 | key_lengths=length_mask 53 | ) 54 | x = x + self.dropout(out) 55 | # Run the fully connected part of the layer 56 | y = x = self.norm1(x) 57 | y = self.dropout(self.activation(self.linear1(y))) 58 | y = self.dropout(self.linear2(y)) 59 | 60 | return self.norm2(x+y), attention_mask.detach() 61 | 62 | 63 | class VizEncoder(TransformerEncoder): 64 | 65 | """A Modifiation of TransformerEncoder where a list of attention weights 66 | are returned for vizualization purposes. 67 | 68 | Arguments 69 | --------- 70 | These are unchanged from the Fast Transformer encoder. For 71 | further information look there. 72 | """ 73 | def __init__(self, layers, norm_layer=None, event_dispatcher=""): 74 | super(VizEncoder, self).__init__(layers, norm_layer=norm_layer, event_dispatcher=event_dispatcher) 75 | 76 | def forward(self, x, attn_mask=None, length_mask=None): 77 | """Apply all transformer encoder layers to the input x. 78 | 79 | Arguments 80 | --------- 81 | x: The input features of shape (N, L, E) where N is the batch size, 82 | L is the sequence length (padded) and E is d_model passed in the 83 | constructor of each transformer encoder layer. 84 | attn_mask: An implementation of fast_transformers.masking.BaseMask 85 | that encodes where each element of x can attend to. 86 | length_mask: An implementation of 87 | fast_transformers.masking.BaseMask that encodes how 88 | many elements each sequence in the batch consists of. 
89 | """ 90 | # Normalize the masks 91 | attention_mask_list = [] 92 | N = x.shape[0] 93 | L = x.shape[1] 94 | attn_mask = attn_mask or FullMask(L, device=x.device) 95 | length_mask = length_mask or \ 96 | LengthMask(x.new_full((N,), L, dtype=torch.int64)) 97 | 98 | # Apply all the transformers 99 | for layer in self.layers: 100 | x, attention_mask = layer(x, attn_mask=attn_mask, length_mask=length_mask) 101 | attention_mask_list.append(attention_mask) 102 | # Apply the normalization if needed 103 | if self.norm is not None: 104 | x = self.norm(x) 105 | 106 | return x, attention_mask_list 107 | 108 | 109 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | import regex as re 3 | 4 | PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 5 | 6 | 7 | class MolTranBertTokenizer(BertTokenizer): 8 | def __init__( 9 | self, 10 | vocab_file: str = "", 11 | do_lower_case=False, 12 | unk_token="", 13 | sep_token="", 14 | pad_token="", 15 | cls_token="", 16 | mask_token="", 17 | **kwargs 18 | ): 19 | super().__init__( 20 | vocab_file, 21 | unk_token=unk_token, 22 | sep_token=sep_token, 23 | pad_token=pad_token, 24 | cls_token=cls_token, 25 | mask_token=mask_token, 26 | **kwargs 27 | ) 28 | 29 | self.regex_tokenizer = re.compile(PATTERN) 30 | self.wordpiece_tokenizer = None 31 | self.basic_tokenizer = None 32 | 33 | def _tokenize(self, text): 34 | split_tokens = self.regex_tokenizer.findall(text) 35 | return split_tokens 36 | 37 | def convert_tokens_to_string(self, tokens): 38 | out_string = "".join(tokens).strip() 39 | return out_string 40 | -------------------------------------------------------------------------------- /notebooks/full_attention_rotary/utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import glob 3 | import pandas as pd 4 | from rdkit import Chem 5 | 6 | 7 | def getipaddress(): 8 | return socket.gethostbyname(socket.getfqdn()) 9 | 10 | 11 | class ListDataset: 12 | def __init__(self, seqs): 13 | self.seqs = seqs 14 | 15 | def __getitem__(self, index): 16 | return self.seqs[index] 17 | 18 | def __len__(self): 19 | return len(self.seqs) 20 | 21 | 22 | def transform_single_embedding_to_multiple(smiles_z_map): 23 | """Transforms an embedding map of the format smi->embedding to 24 | smi-> {"canonical_embeddings":embedding}. 
This function exists 25 | as a compatibility layer 26 | 27 | Args: 28 | smiles_z_map ([type]): [description] 29 | """ 30 | retval = dict() 31 | for key in smiles_z_map: 32 | retval[key] = {"canonical_embeddings": smiles_z_map[key]} 33 | return retval 34 | 35 | 36 | def normalize_smiles(smi, canonical, isomeric): 37 | normalized = Chem.MolToSmiles( 38 | Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric 39 | ) 40 | return normalized 41 | 42 | 43 | def get_all_proteins(affinity_dir: str): 44 | files = glob.glob(affinity_dir + "/*.csv") 45 | all_proteins = [] 46 | print(files) 47 | for file in files: 48 | df = pd.read_csv(file) 49 | all_proteins.extend(df["protein"].tolist()) 50 | return set(all_proteins) 51 | 52 | 53 | def append_to_file(filename, line): 54 | with open(filename, "a") as f: 55 | f.write(line + "\n") 56 | 57 | 58 | def write_to_file(filename, line): 59 | with open(filename, "w") as f: 60 | f.write(line + "\n") 61 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(parser=None): 5 | if parser is None: 6 | parser = argparse.ArgumentParser() 7 | 8 | # Model 9 | # model_arg = parser.add_argument_group('Model') 10 | parser.add_argument("--n_head", type=int, default=12, help="GPT number of heads") 11 | parser.add_argument( 12 | "--fold", type=int, default=0, help="number of folds for fine tuning" 13 | ) 14 | parser.add_argument("--n_layer", type=int, default=12, help="GPT number of layers") 15 | parser.add_argument( 16 | "--q_dropout", type=float, default=0.5, help="Encoder layers dropout" 17 | ) 18 | parser.add_argument( 19 | "--d_dropout", type=float, default=0.1, help="Decoder layers dropout" 20 | ) 21 | parser.add_argument( 22 | "--n_embd", type=int, default=768, help="Latent vector dimensionality" 23 | ) 24 | parser.add_argument( 25 | "--fc_h", type=int, default=512, help="Fully connected hidden dimensionality" 26 | ) 27 | 28 | # Train 29 | # train_arg = parser.add_argument_group('Train') 30 | parser.add_argument("--n_batch", type=int, default=512, help="Batch size") 31 | parser.add_argument( 32 | "--unlike_alpha", type=float, default=1.0, help="unlikelihood loss alpha weight" 33 | ) 34 | parser.add_argument( 35 | "--from_scratch", 36 | action="store_true", 37 | default=False, 38 | help="train on qm9 from scratch", 39 | ) 40 | parser.add_argument( 41 | "--unlikelihood", 42 | action="store_true", 43 | default=False, 44 | help="use unlikelihood loss with gpt pretrain", 45 | ) 46 | parser.add_argument( 47 | "--grad_acc", 48 | type=int, 49 | default=1, 50 | help="number of batches to accumulate gradients", 51 | ) 52 | parser.add_argument( 53 | "--checkpoint_every", 54 | type=int, 55 | default=1000, 56 | help="save checkpoint every x iterations", 57 | ) 58 | parser.add_argument( 59 | "--clip_grad", type=int, default=50, help="Clip gradients to this value" 60 | ) 61 | parser.add_argument( 62 | "--lr_start", type=float, default=3 * 1e-4, help="Initial lr value" 63 | ) 64 | parser.add_argument( 65 | "--lr_end", type=float, default=3 * 1e-4, help="Maximum lr weight value" 66 | ) 67 | parser.add_argument( 68 | "--lr_multiplier", type=int, default=1, help="lr weight multiplier" 69 | ) 70 | parser.add_argument( 71 | "--n_last", type=int, default=1000, help="Number of iters to smooth loss calc" 72 | ) 73 | parser.add_argument("--n_jobs", type=int, default=1, help="Number 
of threads") 74 | """parser.add_argument('--molecule', 75 | type=str, default='C=O=C', 76 | help='input molecule string from command line')""" 77 | parser.add_argument( 78 | "--accelerator", 79 | type=str, 80 | default="ddp", 81 | help="The accelerator backend to use (previously known as distributed_backend)", 82 | ) 83 | parser.add_argument( 84 | "--num_nodes", 85 | type=int, 86 | default=1, 87 | help="number of GPU nodes for distributed training", 88 | ) 89 | parser.add_argument( 90 | "--device", 91 | type=str, 92 | default="cpu", 93 | help='Device to run: "cpu" or "cuda:"', 94 | ) 95 | parser.add_argument("--seed", type=int, default=12345, help="Seed") 96 | 97 | # common_arg = parser.add_argument_group('Common') 98 | parser.add_argument( 99 | "--vocab_load", type=str, required=False, help="Where to load the vocab" 100 | ) 101 | parser.add_argument( 102 | "--n_samples", type=int, required=False, help="Number of samples to sample" 103 | ) 104 | parser.add_argument( 105 | "--gen_save", type=str, required=False, help="Where to save the gen molecules" 106 | ) 107 | parser.add_argument( 108 | "--max_len", type=int, default=202, help="Max of length of SMILES" 109 | ) 110 | parser.add_argument( 111 | "--train_load", type=str, required=False, help="Where to load the model" 112 | ) 113 | parser.add_argument( 114 | "--val_load", type=str, required=False, help="Where to load the model" 115 | ) 116 | parser.add_argument( 117 | "--n_workers", 118 | type=int, 119 | required=False, 120 | default=8, 121 | help="Where to load the model", 122 | ) 123 | # beam search hyper parameters 124 | parser.add_argument( 125 | "--beam_size", type=int, default=0, help="Number of beams to generate" 126 | ) 127 | parser.add_argument( 128 | "--num_seq_returned", 129 | type=int, 130 | default=0, 131 | help="number of beams to be returned (must be <= beam_size", 132 | ) 133 | parser.add_argument( 134 | "--min_len", type=int, default=1, help="minimum length to be generated" 135 | ) 136 | parser.add_argument( 137 | "--nucleus_thresh", type=float, default=0.9, help="nucleus sampling threshold" 138 | ) 139 | parser.add_argument( 140 | "--seed_path", 141 | type=str, 142 | default="../../data/Pretrained MoLFormer/checkpoints/N-Step-Checkpoint_3_30000.ckpt", 143 | help="path to trainer file to continue training", 144 | ) 145 | parser.add_argument( 146 | "--data_path", type=str, default="", help="path to pubchem file" 147 | ) 148 | parser.add_argument( 149 | "--pretext_size", type=int, default=0, help="number of k-mers to pretext" 150 | ) 151 | parser.add_argument( 152 | "--model_save_dir", 153 | type=str, 154 | required=False, 155 | default="./models_dump/", 156 | help="Where to save the models/log/config/vocab", 157 | ) 158 | parser.add_argument( 159 | "--model_save", 160 | type=str, 161 | required=False, 162 | default="model.pt", 163 | help="Where to save the model", 164 | ) 165 | # parser.add_argument('--save_frequency', 166 | # type=int, default=20, 167 | # help='How often to save the model') 168 | parser.add_argument( 169 | "--num_epoch", type=int, default=1, help="number of epochs to train" 170 | ) 171 | # parser.add_argument('--num_iter', 172 | # type=int, default=-1, 173 | # help='how many itersations per epoch (for unlikelihood tuning)') 174 | parser.add_argument( 175 | "--log_file", type=str, required=False, help="Where to save the log" 176 | ) 177 | parser.add_argument( 178 | "--tb_loc", 179 | type=str, 180 | required=False, 181 | help="Where to save the tensorflow location", 182 | ) 183 | parser.add_argument( 184 | 
"--config_save", type=str, required=False, help="Where to save the config" 185 | ) 186 | parser.add_argument("--vocab_save", type=str, help="Where to save the vocab") 187 | 188 | # resume_arg = parser.add_argument_group('Resume') 189 | parser.add_argument("--mask", default=False, action="store_true", help="mask") 190 | parser.add_argument( 191 | "--canonical", default=False, action="store_true", help="convert to canonical" 192 | ) 193 | parser.add_argument( 194 | "--debug", 195 | default=False, 196 | action="store_true", 197 | help="do not erase cache at end of program", 198 | ) 199 | parser.add_argument( 200 | "--fast_dev_run", 201 | default=False, 202 | help="This flag runs a “unit test” by running n if set to n (int) else 1 if set to True training and validation batch(es).", 203 | ) 204 | parser.add_argument( 205 | "--freeze_model", 206 | default=False, 207 | action="store_true", 208 | help="freeze weights of bert model during fine tuning", 209 | ) 210 | parser.add_argument( 211 | "--resume", default=False, action="store_true", help="Resume from a saved model" 212 | ) 213 | parser.add_argument( 214 | "--rotate", 215 | default=True, 216 | action="store_true", 217 | help="use rotational relative embedding", 218 | ) 219 | parser.add_argument( 220 | "--model_load", type=str, required=False, help="Where to load the model" 221 | ) 222 | parser.add_argument( 223 | "--root_dir", type=str, required=False, default=".", help="location of root dir" 224 | ) 225 | parser.add_argument( 226 | "--config_load", type=str, required=False, help="Where to load the config" 227 | ) 228 | parser.add_argument( 229 | "--gpus", type=int, required=False, default=1, help="number of gpus to use" 230 | ) 231 | # parser.add_argument('--start_epoch', 232 | # type=int, required=False, default=0, 233 | # help='Where to load the config') 234 | 235 | parser.add_argument( 236 | "--model_arch", 237 | type=str, 238 | required=False, 239 | help="used to teack model arch in params", 240 | ) 241 | parser.add_argument( 242 | "--eval_every", 243 | type=int, 244 | default=50000, 245 | help="run evaluation every x iterations", 246 | ) 247 | parser.add_argument( 248 | "--num_feats", 249 | type=int, 250 | required=False, 251 | default=32, 252 | help="number of random reatures for FAVOR+", 253 | ) 254 | parser.add_argument( 255 | "--max_epochs", type=int, required=False, default=3, help="max number of epochs" 256 | ) 257 | 258 | # debug() FINE TUNEING 259 | # parser.add_argument('--save_dir', type=str, required=True) 260 | parser.add_argument( 261 | "--mode", type=str, default="avg", help="type of pooling to use" 262 | ) 263 | parser.add_argument( 264 | "--train_dataset_length", type=int, default=None, required=False 265 | ) 266 | parser.add_argument("--eval_dataset_length", type=int, default=None, required=False) 267 | parser.add_argument( 268 | "--desc_skip_connection", type=bool, default=False, required=False 269 | ) 270 | parser.add_argument("--num_workers", type=int, default=2, required=False) 271 | parser.add_argument("--dropout", type=float, default=0.1, required=False) 272 | parser.add_argument( 273 | "--smiles_embedding", 274 | type=str, 275 | default="../../data/etc/ba_embeddings_tanh_512_2986138_2.pt", 276 | ) 277 | # parser.add_argument("--train_pct", type=str, required=False, default="95") 278 | parser.add_argument("--aug", type=int, required=False) 279 | parser.add_argument("--num_classes", type=int, required=False) 280 | parser.add_argument("--dataset_name", type=str, required=False, default="sol") 281 | 
parser.add_argument("--measure_name", type=str, required=False, default="measure") 282 | parser.add_argument("--checkpoints_folder", type=str, required=False) 283 | parser.add_argument("--emb_type", type=str, required=False) 284 | parser.add_argument("--checkpoint_root", type=str, required=False) 285 | 286 | parser.add_argument( 287 | "--data_root", 288 | type=str, 289 | required=False, 290 | default="../../data/etc/affinity", 291 | ) 292 | parser.add_argument("--use_linear", type=int, default=0) 293 | 294 | parser.add_argument("--lr", type=float, default=0.001) 295 | parser.add_argument("--batch_size", type=int, default=64) 296 | 297 | return parser 298 | 299 | 300 | def parse_args(): 301 | parser = get_parser() 302 | args, _ = parser.parse_known_args() 303 | return args 304 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/get_attention_map_full.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch.nn.functional as F 3 | 4 | from tokenizer.tokenizer import MolTranBertTokenizer 5 | from utils import normalize_smiles 6 | import torch 7 | import shutil 8 | from torch import nn 9 | import args 10 | import os 11 | import getpass 12 | from datasets import load_dataset, concatenate_datasets, load_from_disk 13 | 14 | from fast_transformers.builders import TransformerEncoderBuilder 15 | from fast_transformers.masking import FullMask, LengthMask as LM 16 | from rotate_attention.rotate_builder import RotateEncoderBuilder as rotate_builder 17 | import fast_transformers.attention.linear_attention 18 | import rotate_attention.linear_attention 19 | 20 | from fast_transformers.feature_maps import Favor, GeneralizedRandomFeatures 21 | from functools import partial 22 | 23 | from torch.utils.data import DataLoader 24 | 25 | 26 | class TestBert(nn.Module): 27 | def __init__( 28 | self, vocab, model_path=None, extend_pos=False, rotate=False, device="cpu" 29 | ): 30 | if model_path == None: 31 | assert False 32 | super().__init__() 33 | model = torch.load(model_path, map_location=torch.device(device)) 34 | config = Namespace(**model["hyper_parameters"]) 35 | config.rotate = rotate 36 | ( 37 | self.tok_emb, 38 | self.pos_emb, 39 | self.blocks, 40 | self.drop, 41 | self.lang_model, 42 | ) = self.create_model(config, vocab) 43 | """if self.pos_emb != None: 44 | pos_emb = nn.Parameter(model['state_dict']['pos_emb']) 45 | 46 | if extend_pos is True and self.pos_emb != None: 47 | pos_extend = nn.Parameter(torch.zeros(1, 1000, config.n_embd)).to(pos_emb.device) 48 | self.pos_emb = nn.Parameter(torch.cat([pos_emb.data, pos_extend.data], dim=1))""" 49 | 50 | self.drop = nn.Dropout(config.d_dropout) 51 | 52 | def create_model(self, config, vocab): 53 | 54 | n_vocab, d_emb = len(vocab.keys()), config.n_embd 55 | block_size = 250 56 | if config.rotate: 57 | builder = rotate_builder.from_kwargs( 58 | n_layers=config.n_layer, 59 | n_heads=config.n_head, 60 | query_dimensions=config.n_embd // config.n_head, 61 | value_dimensions=config.n_embd // config.n_head, 62 | feed_forward_dimensions=config.n_embd, 63 | attention_type="linearwweights", 64 | feature_map=partial( 65 | GeneralizedRandomFeatures, 66 | n_dims=config.num_feats, 67 | deterministic_eval=True, 68 | ), 69 | activation="gelu", 70 | ) 71 | pos_emb = None 72 | else: 73 | builder = TransformerEncoderBuilder.from_kwargs( 74 | n_layers=config.n_layer, 75 | n_heads=config.n_head, 76 | query_dimensions=config.n_embd // 
config.n_head, 77 | value_dimensions=config.n_embd // config.n_head, 78 | feed_forward_dimensions=config.n_embd, 79 | attention_type="linearwweights", 80 | feature_map=partial(GeneralizedRandomFeatures, n_dims=config.num_feats, deterministic_eval=True), 81 | activation="gelu", 82 | ) 83 | pos_emb = nn.Parameter(torch.zeros(1, block_size, config.n_embd)) 84 | tok_emb = nn.Embedding(n_vocab, config.n_embd) 85 | drop = nn.Dropout(config.d_dropout) 86 | 87 | blocks = builder.get() 88 | lang_model = lm_layer(config.n_embd, n_vocab) 89 | train_config = config 90 | block_size = block_size 91 | 92 | return tok_emb, pos_emb, blocks, drop, lang_model 93 | 94 | def forward(self, batch, mask=None, mode="cls"): 95 | b, t = batch.size() 96 | 97 | # forward the GPT model 98 | token_embeddings = self.tok_emb( 99 | batch 100 | ) # each index maps to a (learnable) vector 101 | if self.pos_emb != None: 102 | position_embeddings = self.pos_emb[ 103 | :, :t, : 104 | ] # each position maps to a (learnable) vector 105 | x = self.drop(token_embeddings + position_embeddings) 106 | else: 107 | x = self.drop(token_embeddings) 108 | 109 | if mask != None: 110 | x, attention_mask = self.blocks(x, length_mask=LM(mask._mask.sum(-1))) 111 | 112 | else: 113 | x, attention_mask = self.blocks(x) 114 | 115 | if mode == "cls": 116 | return x[:, 0, :], attention_mask 117 | elif mode == "max": 118 | token_embeddings = x 119 | input_mask_expanded = ( 120 | mask._mask.unsqueeze(-1).expand(token_embeddings.size()).float() 121 | ) 122 | token_embeddings[ 123 | input_mask_expanded == 0 124 | ] = -1e9 # Set padding tokens to large negative value 125 | max_over_time = torch.max(token_embeddings, 1)[0] 126 | return max_over_time, attention_mask 127 | elif mode == "avg": 128 | 129 | token_embeddings = x 130 | input_mask_expanded = ( 131 | mask._mask.unsqueeze(-1).expand(token_embeddings.size()).float() 132 | ) 133 | sum_embeddings = torch.sum( 134 | token_embeddings * input_mask_expanded, 1 135 | ) # sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 136 | sum_mask = input_mask_expanded.sum(1) 137 | return sum_embeddings / sum_mask, attention_mask 138 | 139 | 140 | class lm_layer(nn.Module): 141 | def __init__(self, n_embd, n_vocab): 142 | super().__init__() 143 | self.embed = nn.Linear(n_embd, n_embd) 144 | self.ln_f = nn.LayerNorm(n_embd) 145 | self.head = nn.Linear(n_embd, n_vocab, bias=False) 146 | 147 | def forward(self, tensor): 148 | tensor = self.embed(tensor) 149 | tensor = F.gelu(tensor) 150 | tensor = self.ln_f(tensor) 151 | tensor = self.head(tensor) 152 | return tensor 153 | 154 | 155 | def get_database(config): 156 | pubchem_path = { 157 | "train": "/dccstor/trustedgen/data/pubchem/CID-SMILES-CANONICAL.smi" 158 | } 159 | if "CANONICAL" in pubchem_path: 160 | pubchem_script = "./pubchem_canon_script.py" 161 | else: 162 | pubchem_script = "./pubchem_script.py" 163 | dataset_dict = load_dataset( 164 | pubchem_script, 165 | data_files=pubchem_path, 166 | cache_dir=os.path.join( 167 | "/tmp", getpass.getuser(), "pubchem_{}".format(config.chunk_num) 168 | ), 169 | split="train", 170 | ) 171 | train_config = { 172 | "batch_size": config.n_batch, 173 | "shuffle": False, 174 | "num_workers": config.n_workers, 175 | "pin_memory": True, 176 | } 177 | # loader = DataLoader(dataset_dict, **train_config) 178 | print(dataset_dict.cache_files) 179 | cache_files = [] 180 | for cache in dataset_dict.cache_files: 181 | tmp = "/".join(cache["filename"].split("/")[:4]) 182 | print(tmp) 183 | cache_files.append(tmp) 184 | 185 | 
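    # The block below carves the cached dataset into 50,000-molecule windows addressed by
    # config.chunk_num: a window that starts beyond the end of the dataset yields loader=None,
    # otherwise a torch Subset drawn from that window is wrapped in a DataLoader using the
    # settings in train_config.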
print("dataset length {}".format(len(dataset_dict))) 186 | if 50000 * config.chunk_num > len(dataset_dict): 187 | index_end = 0 188 | loader = None 189 | elif 50000 + 50000 * config.chunk_num > len(dataset_dict): 190 | index_end = 50000 + 50000 * config.chunk_num - len(dataset_dict) 191 | index = [i + (50000 * config.chunk_num) for i in range(index_end)] 192 | loader = torch.utils.data.Subset(dataset_dict, index) 193 | loader = DataLoader(loader, **train_config) 194 | else: 195 | index_end = 50000 196 | index = [i + (50000 * config.chunk_num) for i in range(index_end)] 197 | loader = torch.utils.data.Subset(dataset_dict, index) 198 | loader = DataLoader(loader, **train_config) 199 | # index= [i+(50000*config.chunk_num) for i in range(index_end)]) 200 | return loader, cache_files 201 | 202 | 203 | def get_bert(config, tokenizer): 204 | bert_model = TestBert( 205 | tokenizer.vocab, config.seed_path, rotate=config.rotate, device=config.device 206 | ).to(config.device) 207 | tmp_model = torch.load(config.seed_path, map_location=torch.device(config.device))["state_dict"] 208 | bert_model.load_state_dict(tmp_model, strict=True) 209 | bert_model = bert_model.eval() 210 | return bert_model 211 | 212 | 213 | def remove_tree(cachefiles): 214 | if type(cachefiles) == type([]): 215 | cachefiles = list(set(cachefiles)) 216 | for cache in cachefiles: 217 | shutil.rmtree(cache) 218 | else: 219 | shutil.rmtree(cachefiles) 220 | 221 | 222 | def get_tokens_from_ids(input_ids, tokenizer): 223 | tokens = [] 224 | 225 | for idx_lst in input_ids: 226 | seq = [] 227 | for idx in idx_lst: 228 | seq.append(tokenizer.ids_to_tokens[idx]) 229 | tokens.append(seq) 230 | return tokens 231 | 232 | 233 | def get_full_attention(molecule): 234 | config = args.parse_args() 235 | model_path = config.seed_path 236 | device = config.device 237 | batch_size = config.batch_size 238 | canonical = config.canonical 239 | mode = config.mode 240 | mask = config.mask 241 | 242 | loader = None 243 | tokenizer = MolTranBertTokenizer("bert_vocab.txt") 244 | bert_model = get_bert(config, tokenizer) 245 | 246 | batch_total = 0 247 | if loader is not None: 248 | 249 | for batch_number, mols in enumerate(loader): 250 | batch_to_save = [] 251 | with torch.no_grad(): 252 | 253 | # print(batch_number) 254 | if config.canonical is True: 255 | output = [ 256 | normalize_smiles(smiles, canonical=True, isomeric=False) 257 | for smiles in mols["text"] 258 | if smiles is not None 259 | ] 260 | else: 261 | output = mols["text"] 262 | batch_ids = tokenizer.batch_encode_plus( 263 | output, 264 | padding=True, 265 | add_special_tokens=True, 266 | return_attention_mask=True, 267 | return_length=True, 268 | ) 269 | 270 | if config.mask is True: 271 | att_mask = FullMask( 272 | torch.tensor(batch_ids["attention_mask"], dtype=bool).to( 273 | device 274 | ), 275 | device=device, 276 | ) 277 | else: 278 | att_mask = FullMask( 279 | torch.ones( 280 | torch.tensor(batch_ids["input_ids"]).size(), dtype=bool 281 | ).to(device), 282 | device=device, 283 | ) 284 | 285 | embeddings, attention_mask = bert_model( 286 | torch.tensor(batch_ids["input_ids"]).to(device), 287 | att_mask, 288 | mode=config.mode, 289 | ) 290 | 291 | for number, mol in enumerate(output): 292 | batch_to_save.append((embeddings[number].cpu().numpy(), mol)) 293 | 294 | # if len(batch_to_save) >= 500: 295 | batch_name = "batch_num_{}.pth".format( 296 | batch_number + (50000 * config.chunk_num) 297 | ) 298 | chunk_name = "chunk_num_{}".format(config.chunk_num) 299 | if batch_number % 250 == 0: 300 
| print(batch_name) 301 | torch.save( 302 | batch_to_save[0], 303 | os.path.join("./embedding_dump_deterministic", chunk_name, batch_name), 304 | ) 305 | 306 | else: 307 | with torch.no_grad(): 308 | 309 | if config.canonical is True: 310 | output = [normalize_smiles(molecule, canonical=True, isomeric=False)] 311 | else: 312 | output = molecule 313 | 314 | batch_ids = tokenizer.batch_encode_plus( 315 | [output], 316 | padding=True, 317 | add_special_tokens=True, 318 | return_attention_mask=True, 319 | return_length=True, 320 | ) 321 | 322 | raw_tokens = get_tokens_from_ids(batch_ids["input_ids"], tokenizer)[0] 323 | 324 | if config.mask is True: 325 | att_mask = FullMask( 326 | torch.tensor(batch_ids["attention_mask"], dtype=bool).to(device), 327 | device=device, 328 | ) 329 | else: 330 | att_mask = FullMask( 331 | torch.ones( 332 | torch.tensor(batch_ids["input_ids"]).size(), dtype=bool 333 | ).to(device), 334 | device=device, 335 | ) 336 | 337 | embeddings, attention_mask = bert_model( 338 | torch.tensor(batch_ids["input_ids"]).to(device), 339 | att_mask, 340 | mode=config.mode, 341 | ) 342 | return attention_mask, raw_tokens 343 | 344 | if loader != None: 345 | remove_tree(cache_files) 346 | 347 | 348 | if __name__ == "__main__": 349 | attentions = get_full_attention() 350 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/rotate_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | """The rotate attention layer performs all the query key value projections and 2 | output projections leaving the implementation of the attention to the inner 3 | attention module. 4 | """ 5 | 6 | from torch.nn import Linear, Module 7 | 8 | from fast_transformers.attention import AttentionLayer 9 | from fast_transformers.events import EventDispatcher, QKVEvent 10 | from .rotary import RotaryEmbedding, apply_rotary_pos_emb 11 | 12 | class RotateAttentionLayer(AttentionLayer): 13 | """Rotate attention layer inherits from fast_transformer attention layer. 14 | The only thing added is an Embedding encoding, for more information 15 | on the attention layer see the fast_transformers code 16 | """ 17 | def __init__(self, attention, d_model, n_heads, d_keys=None, 18 | d_values=None, event_dispatcher=""): 19 | super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, 20 | d_values=d_values, event_dispatcher=event_dispatcher) 21 | 22 | self.rotaryemb = RotaryEmbedding(d_keys) 23 | print('Using Rotation Embedding') 24 | 25 | def forward(self, queries, keys, values, attn_mask, query_lengths, 26 | key_lengths): 27 | """ 28 | Using the same frame work as the fast_Transformers attention layer 29 | but injecting rotary information to the queries and the keys 30 | after the keys and queries are projected. 
31 | In the argument description we make use of the following sizes 32 | 33 | - N: the batch size 34 | - L: The maximum length of the queries 35 | - S: The maximum length of the keys (the actual length per sequence 36 | is given by the length mask) 37 | - D: The input feature dimensionality passed in the constructor as 38 | 'd_model' 39 | 40 | Arguments 41 | --------- 42 | queries: (N, L, D) The tensor containing the queries 43 | keys: (N, S, D) The tensor containing the keys 44 | values: (N, S, D) The tensor containing the values 45 | attn_mask: An implementation of BaseMask that encodes where each 46 | query can attend to 47 | query_lengths: An implementation of BaseMask that encodes how 48 | many queries each sequence in the batch consists of 49 | key_lengths: An implementation of BaseMask that encodes how 50 | many queries each sequence in the batch consists of 51 | 52 | Returns 53 | ------- 54 | The new value for each query as a tensor of shape (N, L, D). 55 | """ 56 | # Extract the dimensions into local variables 57 | N, L, _ = queries.shape 58 | _, S, _ = keys.shape 59 | H = self.n_heads 60 | 61 | # Project the queries/keys/values 62 | queries = self.query_projection(queries).view(N, L, H, -1) 63 | keys = self.key_projection(keys).view(N, S, H, -1) 64 | cos, sin = self.rotaryemb(queries) 65 | queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) 66 | values = self.value_projection(values).view(N, S, H, -1) 67 | # Let the world know of the qkv 68 | self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) 69 | 70 | 71 | # Compute the attention 72 | new_values, attention_weights = self.inner_attention( 73 | queries, 74 | keys, 75 | values, 76 | attn_mask, 77 | query_lengths, 78 | key_lengths 79 | ) 80 | 81 | new_values = new_values.view(N,L,-1) 82 | # Project the output and return 83 | return self.out_projection(new_values), attention_weights 84 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/rotate_attention/full_attention.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch.nn import Dropout, Module 5 | 6 | from fast_transformers.attention_registry import AttentionRegistry, Optional, Float, \ 7 | EventDispatcherInstance 8 | from fast_transformers.events import EventDispatcher, AttentionEvent 9 | from fast_transformers.attention.full_attention import FullAttention 10 | 11 | class FullWWeight(Module): 12 | """ 13 | Slightly modify the fast transformers Full attention to return 14 | the 'attention weights' for visual analysis 15 | """ 16 | def __init__(self, softmax_temp=None, attention_dropout=0.1, event_dispatcher=""): 17 | super(FullWWeight, self).__init__() 18 | 19 | self.softmax_temp = softmax_temp 20 | self.dropout = Dropout(attention_dropout) 21 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 22 | 23 | def forward(self, queries, keys, values, attn_mask, query_lengths, 24 | key_lengths): 25 | # Extract some shapes and compute the temperature 26 | N, L, H, E = queries.shape 27 | _, S, _, D = values.shape 28 | softmax_temp = self.softmax_temp or 1./sqrt(E) 29 | 30 | # Scale the queries instead of applying the softmax temperature to the 31 | # dot products 32 | queries = queries * softmax_temp 33 | 34 | # Compute the unnormalized attention and apply the masks 35 | QK = torch.einsum("nlhe,nshe->nhls", queries, keys) 36 | if not attn_mask.all_ones: 37 | QK = QK + attn_mask.additive_matrix 
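        # additive_matrix is 0.0 where attention is permitted and -inf where it is masked,
        # so adding it here (and the key-length mask just below) drives the corresponding
        # softmax weights to exactly zero.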
38 | if not key_lengths.all_ones: 39 | QK = QK + key_lengths.additive_matrix[:, None, None] 40 | 41 | # Compute the attention and the weighted average 42 | attention_weights = torch.softmax(QK, dim=-1) 43 | #A = self.dropout(torch.softmax(QK, dim=-1)) 44 | A = self.dropout(attention_weights) 45 | V = torch.einsum("nhls,nshd->nlhd", A, values) 46 | 47 | # Let the world know of the attention matrix 48 | self.event_dispatcher.dispatch(AttentionEvent(self, A)) 49 | 50 | # Make sure that what we return is contiguous 51 | return V.contiguous(), attention_weights.detach() 52 | 53 | 54 | # Register the attention implementation so that it becomes available in our 55 | # builders 56 | AttentionRegistry.register( 57 | "fullwweights", FullWWeight, 58 | [ 59 | ("softmax_temp", Optional(Float)), 60 | ("attention_dropout", Optional(Float, 0.1)), 61 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 62 | ] 63 | ) 64 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/rotate_attention/linear_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from fast_transformers.attention_registry import AttentionRegistry, Optional, Callable, Int, \ 5 | EventDispatcherInstance 6 | from fast_transformers.attention.linear_attention import LinearAttention 7 | from fast_transformers.events import EventDispatcher 8 | 9 | 10 | class LinearWWeight(Module): 11 | """ 12 | Slightly modify the fast transformers linear attention to return 13 | the 'attention weights' for visual analysis 14 | """ 15 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6, 16 | event_dispatcher=""): 17 | super(LinearWWeight, self).__init__() 18 | 19 | self.feature_map = (feature_map(query_dimensions)) 20 | self.eps = eps 21 | self.event_dispatcher=EventDispatcher.get(event_dispatcher) 22 | def forward(self, queries, keys, values, attn_mask, query_lengths, 23 | key_lengths): 24 | # Apply the feature map to the queries and keys 25 | self.feature_map.new_feature_map(queries.device) 26 | Q = self.feature_map.forward_queries(queries) 27 | K = self.feature_map.forward_keys(keys) 28 | 29 | # Apply the key padding mask and make sure that the attn_mask is 30 | # all_ones 31 | if not attn_mask.all_ones: 32 | raise RuntimeError(("LinearAttention does not support arbitrary " 33 | "attention masks")) 34 | K = K * key_lengths.float_matrix[:, :, None, None] 35 | 36 | # Compute the KV matrix, namely the dot product of keys and values so 37 | # that we never explicitly compute the attention matrix and thus 38 | # decrease the complexity 39 | KV = torch.einsum("nshd,nshm->nhmd", K, values) 40 | 41 | # Compute the normalizer 42 | Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps) 43 | 44 | attention = torch.einsum('nlhd, nshd->nlsh', Q, K) 45 | #try both postive and negative values 46 | 47 | 48 | attention_norm = 1/(torch.einsum('nlsh->nlh', attention+self.eps)) 49 | attention_out = torch.einsum('nlsh, nlh->nlsh', attention, attention_norm) 50 | attention_out = torch.einsum('nlsh->nhls', attention_out) 51 | #remove negative numbers with relu 52 | 53 | # Finally compute and return the new values 54 | V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) 55 | 56 | return V.contiguous(), attention_out.detach() 57 | 58 | 59 | # Register the attention implementation so that it becomes available in our 60 | # builders 61 | AttentionRegistry.register( 62 | "linearwweights", 
LinearWWeight, 63 | [ 64 | ("query_dimensions", Int), 65 | ("feature_map", Optional(Callable)), 66 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 67 | ] 68 | ) 69 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/rotate_attention/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RotaryEmbedding(torch.nn.Module): 4 | 5 | def __init__(self, dim, base=10000): 6 | super().__init__() 7 | inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) 8 | self.register_buffer('inv_freq', inv_freq) 9 | self.seq_len_cached = 0 10 | self.cos_cached = None 11 | self.sin_cached = None 12 | 13 | def forward(self, x, seq_dim=1): 14 | seq_len = x.shape[seq_dim] 15 | if seq_len != self.seq_len_cached: 16 | #if seq_len > self.seq_len_cached: 17 | self.seq_len_cached = seq_len 18 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 21 | self.cos_cached = emb.cos()[None,:, None, :] 22 | self.sin_cached = emb.sin()[None,:, None, :] 23 | #else: 24 | # cos_return = self.cos_cached[..., :seq_len] 25 | # sin_return = self.sin_cached[..., :seq_len] 26 | # return cos_return, sin_return 27 | 28 | return self.cos_cached, self.sin_cached 29 | 30 | 31 | # rotary pos emb helpers: 32 | 33 | def rotate_half(x): 34 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 35 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 36 | 37 | @torch.jit.script 38 | def apply_rotary_pos_emb(q, k, cos, sin): 39 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 40 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/rotate_attention/rotate_builder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import LayerNorm 2 | 3 | from .transformers import VizEncoderLayer, VizEncoder 4 | from .attention_layer import RotateAttentionLayer 5 | from .linear_attention import LinearWWeight 6 | from .full_attention import FullWWeight 7 | 8 | from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer 9 | from fast_transformers.builders.base import BaseBuilder 10 | from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder 11 | from fast_transformers.builders.attention_builders import AttentionBuilder 12 | 13 | 14 | class RotateEncoderBuilder(BaseTransformerEncoderBuilder): 15 | """Build a batch transformer encoder with Relative Rotary embeddings 16 | for training or processing of sequences all elements at a time. 
17 | 18 | Example usage: 19 | 20 | builder = RotateEncoderBuilder() 21 | builder.n_layers = 12 22 | builder.n_heads = 8 23 | builder.feed_forward_dimensions = 1024 24 | builder.query_dimensions = 64 25 | builder.value_dimensions = 64 26 | builder.dropout = 0.1 27 | builder.attention_dropout = 0.1 28 | builder.attention_type = "linear" 29 | transformer = builder.get() 30 | """ 31 | def _get_attention_builder(self): 32 | """Return an instance of the appropriate attention builder.""" 33 | return AttentionBuilder() 34 | 35 | def _get_attention_layer_class(self): 36 | """Return the class for the layer that projects queries keys and 37 | values.""" 38 | return RotateAttentionLayer 39 | 40 | def _get_encoder_class(self): 41 | """Return the class for the transformer encoder.""" 42 | return VizEncoder 43 | 44 | def _get_encoder_layer_class(self): 45 | """Return the class for the transformer encoder layer.""" 46 | return VizEncoderLayer 47 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/rotate_attention/transformers.py: -------------------------------------------------------------------------------- 1 | """Implement transformer encoders and decoders that are going to be used with 2 | different attention mechanisms. 3 | 4 | In all cases the batch dimension is first and the sequence dimension is second. 5 | """ 6 | 7 | import torch 8 | from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList 9 | import torch.nn.functional as F 10 | 11 | from fast_transformers.events import EventDispatcher 12 | from fast_transformers.masking import FullMask, LengthMask 13 | from fast_transformers.transformers import TransformerEncoderLayer, TransformerEncoder 14 | 15 | 16 | class VizEncoderLayer(TransformerEncoderLayer): 17 | """Self attention and feed forward network with skip connections. 18 | 19 | This transformer encoder layer implements a modification of the Fast 20 | Transformer layer code where the attention weights of the model are 21 | returned for vizualization. 22 | 23 | Arguments 24 | --------- 25 | These are unchanged from the Fast Transformer encoder layer. For 26 | further information look there. 27 | """ 28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, 29 | activation="relu", event_dispatcher=""): 30 | super(VizEncoderLayer, self).__init__(attention, d_model, d_ff=d_ff, dropout=dropout, 31 | activation=activation, event_dispatcher=event_dispatcher) 32 | 33 | def forward(self, x, attn_mask=None, length_mask=None): 34 | """Apply the transformer encoder to the input x. 35 | 36 | Arguments 37 | --------- 38 | x: The input features of shape (N, L, E) where N is the batch size, 39 | L is the sequence length (padded) and E is d_model passed in the 40 | constructor. 41 | attn_mask: An implementation of fast_transformers.masking.BaseMask 42 | that encodes where each element of x can attend to. 43 | length_mask: An implementation of 44 | fast_transformers.masking.BaseMask that encodes how 45 | many elements each sequence in the batch consists of. 
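        Returns
        -------
        The transformed features of shape (N, L, E) together with the detached
        attention weights produced by this layer's attention module.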
46 | """ 47 | # Normalize the masks 48 | N = x.shape[0] 49 | L = x.shape[1] 50 | attn_mask = attn_mask or FullMask(L, device=x.device) 51 | length_mask = length_mask or \ 52 | LengthMask(x.new_full((N,), L, dtype=torch.int64)) 53 | 54 | out, attention_mask = self.attention( 55 | x, x, x, 56 | attn_mask=attn_mask, 57 | query_lengths=length_mask, 58 | key_lengths=length_mask 59 | ) 60 | x = x + self.dropout(out) 61 | # Run the fully connected part of the layer 62 | y = x = self.norm1(x) 63 | y = self.dropout(self.activation(self.linear1(y))) 64 | y = self.dropout(self.linear2(y)) 65 | 66 | return self.norm2(x+y), attention_mask.detach() 67 | 68 | 69 | class VizEncoder(TransformerEncoder): 70 | 71 | """A Modifiation of TransformerEncoder where a list of attention weights 72 | are returned for vizualization purposes. 73 | 74 | Arguments 75 | --------- 76 | These are unchanged from the Fast Transformer encoder. For 77 | further information look there. 78 | """ 79 | def __init__(self, layers, norm_layer=None, event_dispatcher=""): 80 | super(VizEncoder, self).__init__(layers, norm_layer=norm_layer, event_dispatcher=event_dispatcher) 81 | 82 | def forward(self, x, attn_mask=None, length_mask=None): 83 | """Apply all transformer encoder layers to the input x. 84 | 85 | Arguments 86 | --------- 87 | x: The input features of shape (N, L, E) where N is the batch size, 88 | L is the sequence length (padded) and E is d_model passed in the 89 | constructor of each transformer encoder layer. 90 | attn_mask: An implementation of fast_transformers.masking.BaseMask 91 | that encodes where each element of x can attend to. 92 | length_mask: An implementation of 93 | fast_transformers.masking.BaseMask that encodes how 94 | many elements each sequence in the batch consists of. 
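Returns the encoded features together with a list containing one detached attention tensor per encoder layer, in the order the layers are applied.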
95 | """ 96 | # Normalize the masks 97 | attention_mask_list = [] 98 | N = x.shape[0] 99 | L = x.shape[1] 100 | attn_mask = attn_mask or FullMask(L, device=x.device) 101 | length_mask = length_mask or \ 102 | LengthMask(x.new_full((N,), L, dtype=torch.int64)) 103 | 104 | # Apply all the transformers 105 | for layer in self.layers: 106 | x, attention_mask = layer(x, attn_mask=attn_mask, length_mask=length_mask) 107 | attention_mask_list.append(attention_mask) 108 | # Apply the normalization if needed 109 | if self.norm is not None: 110 | x = self.norm(x) 111 | 112 | return x, attention_mask_list 113 | 114 | 115 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | import regex as re 3 | 4 | PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 5 | 6 | 7 | class MolTranBertTokenizer(BertTokenizer): 8 | def __init__( 9 | self, 10 | vocab_file: str = "", 11 | do_lower_case=False, 12 | unk_token="", 13 | sep_token="", 14 | pad_token="", 15 | cls_token="", 16 | mask_token="", 17 | **kwargs 18 | ): 19 | super().__init__( 20 | vocab_file, 21 | unk_token=unk_token, 22 | sep_token=sep_token, 23 | pad_token=pad_token, 24 | cls_token=cls_token, 25 | mask_token=mask_token, 26 | **kwargs 27 | ) 28 | 29 | self.regex_tokenizer = re.compile(PATTERN) 30 | self.wordpiece_tokenizer = None 31 | self.basic_tokenizer = None 32 | 33 | def _tokenize(self, text): 34 | split_tokens = self.regex_tokenizer.findall(text) 35 | return split_tokens 36 | 37 | def convert_tokens_to_string(self, tokens): 38 | out_string = "".join(tokens).strip() 39 | return out_string 40 | -------------------------------------------------------------------------------- /notebooks/linear_attention_rotary/utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import glob 3 | import pandas as pd 4 | from rdkit import Chem 5 | 6 | 7 | def getipaddress(): 8 | return socket.gethostbyname(socket.getfqdn()) 9 | 10 | class ListDataset: 11 | def __init__(self, seqs): 12 | self.seqs = seqs 13 | 14 | def __getitem__(self, index): 15 | return self.seqs[index] 16 | 17 | def __len__(self): 18 | return len(self.seqs) 19 | 20 | 21 | def transform_single_embedding_to_multiple(smiles_z_map): 22 | """Transforms an embedding map of the format smi->embedding to 23 | smi-> {"canonical_embeddings":embedding}. 
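For example, {'CCO': z} becomes {'CCO': {'canonical_embeddings': z}}.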
This function exists 24 | as a compatibility layer 25 | 26 | Args: 27 | smiles_z_map ([type]): [description] 28 | """ 29 | retval = dict() 30 | for key in smiles_z_map: 31 | retval[key] = {"canonical_embeddings": smiles_z_map[key]} 32 | return retval 33 | 34 | 35 | def normalize_smiles(smi, canonical, isomeric): 36 | normalized = Chem.MolToSmiles( 37 | Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric 38 | ) 39 | return normalized 40 | 41 | 42 | def get_all_proteins(affinity_dir: str): 43 | files = glob.glob(affinity_dir + "/*.csv") 44 | all_proteins = [] 45 | print(files) 46 | for file in files: 47 | df = pd.read_csv(file) 48 | all_proteins.extend(df["protein"].tolist()) 49 | return set(all_proteins) 50 | 51 | 52 | def append_to_file(filename, line): 53 | with open(filename, "a") as f: 54 | f.write(line + "\n") 55 | 56 | 57 | def write_to_file(filename, line): 58 | with open(filename, "w") as f: 59 | f.write(line + "\n") 60 | -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(parser=None): 5 | if parser is None: 6 | parser = argparse.ArgumentParser() 7 | 8 | # Model 9 | #model_arg = parser.add_argument_group('Model') 10 | parser.add_argument('--n_head', 11 | type=int, default=8, 12 | help='GPT number of heads') 13 | parser.add_argument('--n_layer', 14 | type=int, default=12, 15 | help='GPT number of layers') 16 | parser.add_argument('--q_dropout', 17 | type=float, default=0.5, 18 | help='Encoder layers dropout') 19 | parser.add_argument('--d_dropout', 20 | type=float, default=0.1, 21 | help='Decoder layers dropout') 22 | parser.add_argument('--n_embd', 23 | type=int, default=768, 24 | help='Latent vector dimensionality') 25 | parser.add_argument('--fc_h', 26 | type=int, default=512, 27 | help='Fully connected hidden dimensionality') 28 | 29 | 30 | # Train 31 | #train_arg = parser.add_argument_group('Train') 32 | parser.add_argument('--n_batch', 33 | type=int, default=512, 34 | help='Batch size') 35 | parser.add_argument('--unlike_alpha', 36 | type=float, default=1.0, 37 | help='unlikelihood loss alpha weight') 38 | parser.add_argument('--from_scratch', 39 | action='store_true', default=False, 40 | help='train on qm9 from scratch') 41 | parser.add_argument('--unlikelihood', 42 | action='store_true', default=False, 43 | help='use unlikelihood loss with gpt pretrain') 44 | parser.add_argument('--grad_acc', 45 | type=int, default=1, 46 | help='number of batches to accumulate gradients') 47 | parser.add_argument('--checkpoint_every', 48 | type=int, default=1000, 49 | help='save checkpoint every x iterations') 50 | parser.add_argument('--clip_grad', 51 | type=int, default=50, 52 | help='Clip gradients to this value') 53 | parser.add_argument('--lr_start', 54 | type=float, default=3 * 1e-4, 55 | help='Initial lr value') 56 | parser.add_argument('--lr_end', 57 | type=float, default=3 * 1e-4, 58 | help='Maximum lr weight value') 59 | parser.add_argument('--lr_multiplier', 60 | type=int, default=1, 61 | help='lr weight multiplier') 62 | parser.add_argument('--n_last', 63 | type=int, default=1000, 64 | help='Number of iters to smooth loss calc') 65 | parser.add_argument('--n_jobs', 66 | type=int, default=1, 67 | help='Number of threads') 68 | parser.add_argument('--accelerator', 69 | type=str, default='ddp', 70 | help='The accelerator backend to use (previously known as 
distributed_backend)') 71 | parser.add_argument('--num_nodes', 72 | type=int, default=1, 73 | help='number of GPU nodes for distributed training') 74 | parser.add_argument('--device', 75 | type=str, default='cuda', 76 | help='Device to run: "cpu" or "cuda:"') 77 | parser.add_argument('--seed', 78 | type=int, default=12345, 79 | help='Seed') 80 | 81 | #common_arg = parser.add_argument_group('Common') 82 | parser.add_argument('--vocab_load', 83 | type=str, required=False, 84 | help='Where to load the vocab') 85 | parser.add_argument('--n_samples', 86 | type=int, required=False, 87 | help='Number of samples to sample') 88 | parser.add_argument('--gen_save', 89 | type=str, required=False, 90 | help='Where to save the gen molecules') 91 | parser.add_argument("--max_len", 92 | type=int, default=100, 93 | help="Max of length of SMILES") 94 | parser.add_argument('--train_load', 95 | type=str, required=False, 96 | help='Where to load the model') 97 | parser.add_argument('--val_load', 98 | type=str, required=False, 99 | help='Where to load the model') 100 | parser.add_argument('--n_workers', 101 | type=int, required=False, default=1, 102 | help='Where to load the model') 103 | #beam search hyper parameters 104 | parser.add_argument('--beam_size', type=int, default=0, 105 | help="Number of beams to generate") 106 | parser.add_argument('--num_seq_returned', type=int, default=0, 107 | help="number of beams to be returned (must be <= beam_size") 108 | parser.add_argument('--min_len', type=int, default=1, 109 | help="minimum length to be generated") 110 | parser.add_argument('--nucleus_thresh', type=float, default=.9, 111 | help="nucleus sampling threshold") 112 | parser.add_argument('--finetune_path', 113 | type=str, default="", 114 | help='path to trainer file to continue training') 115 | parser.add_argument('--restart_path', 116 | type=str, default="", 117 | help='path to trainer file to continue training') 118 | parser.add_argument('--data_path', 119 | type=str, default="", 120 | help='path to pubchem file') 121 | parser.add_argument('--pretext_size', 122 | type=int, default=0, 123 | help='number of k-mers to pretext') 124 | parser.add_argument('--model_save_dir', 125 | type=str, required=False, default='./models_dump/', 126 | help='Where to save the models/log/config/vocab') 127 | parser.add_argument('--model_save', 128 | type=str, required=False, default='model.pt', 129 | help='Where to save the model') 130 | #parser.add_argument('--save_frequency', 131 | # type=int, default=20, 132 | # help='How often to save the model') 133 | parser.add_argument('--num_epoch', 134 | type=int, default=1, 135 | help='number of epochs to train') 136 | #parser.add_argument('--num_iter', 137 | # type=int, default=-1, 138 | # help='how many itersations per epoch (for unlikelihood tuning)') 139 | parser.add_argument('--log_file', 140 | type=str, required=False, 141 | help='Where to save the log') 142 | parser.add_argument('--tb_loc', 143 | type=str, required=False, 144 | help='Where to save the tensorflow location') 145 | parser.add_argument('--config_save', 146 | type=str, required=False, 147 | help='Where to save the config') 148 | parser.add_argument('--vocab_save', 149 | type=str, 150 | help='Where to save the vocab') 151 | 152 | # resume_arg = parser.add_argument_group('Resume') 153 | parser.add_argument('--debug', 154 | default=False, action='store_true', 155 | help='do not erase cache at end of program') 156 | parser.add_argument('--fast_dev_run', 157 | default=False, 158 | help='This flag runs a “unit test” by 
running n if set to n (int) else 1 if set to True training and validation batch(es).') 159 | parser.add_argument('--freeze_model', 160 | default=False, action='store_true', 161 | help='freeze weights of bert model during fine tuning') 162 | parser.add_argument('--resume', 163 | default=False, action='store_true', 164 | help='Resume from a saved model') 165 | parser.add_argument('--rotate', 166 | default=False, action='store_true', 167 | help='use rotational relative embedding') 168 | parser.add_argument('--model_load', 169 | type=str, required=False, 170 | help='Where to load the model') 171 | parser.add_argument('--root_dir', 172 | type=str, required=False, default='.', 173 | help='location of root dir') 174 | parser.add_argument('--config_load', 175 | type=str, required=False, 176 | help='Where to load the config') 177 | parser.add_argument('--gpus', 178 | type=int, required=False, default=1, 179 | help='number of gpus to use') 180 | #parser.add_argument('--start_epoch', 181 | # type=int, required=False, default=0, 182 | # help='Where to load the config') 183 | 184 | parser.add_argument('--model_arch', 185 | type=str, required=False, 186 | help='used to teack model arch in params') 187 | parser.add_argument('--eval_every', 188 | type=int, default=50000, 189 | help='run evaluation every x iterations') 190 | parser.add_argument('--num_feats', 191 | type=int, required=False, default=32, 192 | help='number of random reatures for FAVOR+') 193 | parser.add_argument('--max_epochs', 194 | type=int, required=False, default=1, 195 | help='max number of epochs') 196 | 197 | # debug() FINE TUNEING 198 | # parser.add_argument('--save_dir', type=str, required=True) 199 | parser.add_argument('--mode', 200 | type=str, default='cls', 201 | help='type of pooling to use') 202 | parser.add_argument("--dataset_length", type=int, default=None, required=False) 203 | parser.add_argument("--num_workers", type=int, default=0, required=False) 204 | parser.add_argument("--dropout", type=float, default=0.1, required=False) 205 | #parser.add_argument("--dims", type=int, nargs="*", default="", required=False) 206 | parser.add_argument( 207 | "--smiles_embedding", 208 | type=str, 209 | default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt", 210 | ) 211 | # parser.add_argument("--train_pct", type=str, required=False, default="95") 212 | #parser.add_argument("--aug", type=int, required=True) 213 | parser.add_argument("--dataset_name", type=str, required=False, default="sol") 214 | parser.add_argument("--measure_name", type=str, required=False, default="measure") 215 | #parser.add_argument("--emb_type", type=str, required=True) 216 | #parser.add_argument("--checkpoints_folder", type=str, required=True) 217 | #parser.add_argument("--results_dir", type=str, required=True) 218 | #parser.add_argument("--patience_epochs", type=int, required=True) 219 | 220 | parser.add_argument( 221 | "--data_root", 222 | type=str, 223 | required=False, 224 | default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity", 225 | ) 226 | # parser.add_argument("--use_bn", type=int, default=0) 227 | parser.add_argument("--use_linear", type=int, default=0) 228 | 229 | parser.add_argument("--lr", type=float, default=0.001) 230 | # parser.add_argument("--weight_decay", type=float, default=5e-4) 231 | # parser.add_argument("--val_check_interval", type=float, default=1.0) 232 | parser.add_argument("--batch_size", type=int, default=64) 233 | 234 | return parser 235 | 
def parse_args(): 236 | parser = get_parser() 237 | args = parser.parse_args() 238 | return args 239 | 240 | -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/pubchem_canon_zinc_final_vocab_sorted.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ningliu-iga/TrinityLLM/5659e6e7d7dfb748a338a88f1da1e3f85a2676f7/notebooks/pretrained_molformer/pubchem_canon_zinc_final_vocab_sorted.pth -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/pubchem_encoder.py: -------------------------------------------------------------------------------- 1 | import regex as re 2 | import torch 3 | import numpy as np 4 | import random 5 | import collections 6 | 7 | class Encoder(): 8 | 9 | def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32): 10 | self.vocab_encoder = torch.load('pubchem_canon_zinc_final_vocab_sorted.pth') 11 | 12 | self.max_length = max_length 13 | self.min_length = 1 14 | self.mod_length = 42 15 | self.mlm_probability = .15 16 | self.avg_length = 66 17 | self.tail = 122 18 | self.b0_cache=collections.deque() 19 | self.b1_cache=collections.deque() 20 | self.b2_cache=collections.deque() 21 | self.b3_cache=collections.deque() 22 | self.bucket0=collections.deque() 23 | self.bucket1=collections.deque() 24 | self.bucket2=collections.deque() 25 | self.bucket3=collections.deque() 26 | if feature_size == 32: 27 | self.b0_max=1100 28 | self.b1_max=700 29 | self.b2_max=150 30 | self.b3_max=50 31 | else: 32 | self.b0_max=1382 33 | self.b1_max=871 34 | self.b2_max=516 35 | self.b3_max=311 36 | values = list(self.vocab_encoder.values()) 37 | num_top = 0 38 | middle_top = 0 39 | bottom = 0 40 | for count in values: 41 | if count > 100000: 42 | num_top += 1 43 | if count > 50: 44 | middle_top += 1 45 | middle_top = middle_top - num_top 46 | self.cutoffs = [num_top+4, middle_top] 47 | self.char2id = {"":0, "":1, "":2, "":3} 48 | self.id2char = {0:"", 1:"", 2:"", 3:""} 49 | self.pad = self.char2id[''] 50 | self.mask = self.char2id[''] 51 | self.eos = self.char2id[''] 52 | self.bos = self.char2id[''] 53 | pos = 0 54 | for key, value in self.vocab_encoder.items(): 55 | #for pos, key in enumerate(self.vocab_encoder.keys()): 56 | self.char2id[key] = pos+4 57 | self.id2char[pos+4] = key 58 | pos += 1 59 | self.char2id[""] = pos + 4 60 | self.id2char[pos+4] = "" 61 | self.pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 62 | self.regex = re.compile(self.pattern) 63 | self.add_bos = add_bos 64 | self.add_eos = add_eos 65 | 66 | def encode(self, char): 67 | #if len(char) > self.max_length: 68 | # char = char[:self.max_length] 69 | if self.add_bos == True: 70 | char = [''] + char 71 | if self.add_eos == True: 72 | char = char + [''] 73 | 74 | return torch.tensor([self.char2id[word] for word in char]) 75 | 76 | def encoder(self, tokens): 77 | #return *map(lambda x: self.encode(x), tokens) 78 | return [self.encode(mol) for mol in tokens] 79 | 80 | def process_text(self, text): 81 | #print(text) 82 | #random length sequences seems to help training 83 | mod_length = self.mod_length #+ random.randint(-1, 3) 84 | avg_length = self.avg_length #+ random.randint(-3, 5) 85 | for mol in text: 86 | #fill up buckets and caches 87 | if '\n' in mol['text']: 88 | print('carriage return in mol') 89 | raw_regex = self.regex.findall(mol['text'].strip('\n')) 
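# tokenize the SMILES with the regex; the token count below decides which length bucket (short/medium/long/tail) the molecule falls into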
90 | length = len(raw_regex) 91 | if length > self.min_length and length < mod_length: 92 | if len(self.bucket0) < self.b0_max: 93 | self.bucket0.append(raw_regex) 94 | else: 95 | self.b0_cache.append(raw_regex) 96 | elif length >= mod_length and length < avg_length: 97 | if len(self.bucket1) < self.b1_max: 98 | self.bucket1.append(raw_regex) 99 | else: 100 | self.b1_cache.append(raw_regex) 101 | elif length >= avg_length and length < self.tail: 102 | self.b2_cache.append(raw_regex) 103 | #if len(bucket2) < self.b2_max: 104 | # bucket2.append(raw_regex) 105 | #else: 106 | # self.b2_cache.append(raw_regex) 107 | elif length >= self.tail and length < self.max_length: 108 | self.b3_cache.append(raw_regex) 109 | #if len(bucket3) < self.b3_max: 110 | # bucket3.append(raw_regex) 111 | #else: 112 | # self.b3_cache.append(raw_regex) 113 | 114 | #print('before Cache size {} {} {} {}'.format(len(self.b0_cache), len(self.b1_cache), len(self.b2_cache), len(self.b3_cache))) 115 | #pour cache elements into any open bucket 116 | if len(self.bucket0) < self.b0_max and len(self.b0_cache) > 0: 117 | cache_size = len(self.b0_cache) 118 | max_margin = self.b0_max-len(self.bucket0) 119 | range0 = min(cache_size, max_margin) 120 | outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] + [self.b0_cache.pop() for i in range(range0)] 121 | #self.b0_cache = collections.deque(self.b0_cache[:self.b0_max-len(bucket0)]) 122 | #print('0 type {}'.format(type(self.b0_cache))) 123 | else: 124 | outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] 125 | if len(self.bucket1) < self.b1_max and len(self.b1_cache) > 0: 126 | cache_size = len(self.b1_cache) 127 | max_margin = self.b1_max-len(self.bucket1) 128 | range1 = min(cache_size, max_margin) 129 | outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] + [self.b1_cache.pop() for i in range(range1)] 130 | else: 131 | outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] 132 | 133 | if len(self.b2_cache) > self.b2_max: 134 | cache_size = len(self.b2_cache) 135 | max_margin = self.b2_max 136 | range2 = min(cache_size, max_margin) 137 | outbucket2 = [self.b2_cache.pop() for i in range(range2)] 138 | else: 139 | outbucket2=[] 140 | if len(self.b3_cache) > self.b3_max: 141 | cache_size = len(self.b3_cache) 142 | max_margin = self.b3_max 143 | range3 = min(cache_size, max_margin) 144 | outbucket3 = [self.b3_cache.pop() for i in range(range3)] 145 | else: 146 | outbucket3 = [] 147 | return outbucket0, outbucket1, outbucket2, outbucket3 148 | 149 | def mask_tokens( self, inputs, special_tokens_mask= None): 150 | """ 151 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
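Positions are selected with probability self.mlm_probability (0.15); positions that are not selected are labelled -100 so that the loss is computed on masked tokens only.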
152 | """ 153 | labels = inputs.clone() 154 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 155 | probability_matrix = torch.full(labels.size(), self.mlm_probability) 156 | if special_tokens_mask is None: 157 | special_tokens_mask = [ 158 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 159 | ] 160 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 161 | else: 162 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 163 | #special_tokens_mask = special_tokens_mask.bool() 164 | 165 | #print(special_tokens_mask.size()) 166 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 167 | masked_indices = torch.bernoulli(probability_matrix).bool() 168 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 169 | 170 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 171 | indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8)).bool() & masked_indices 172 | inputs[indices_replaced] = self.mask 173 | 174 | # 10% of the time, we replace masked input tokens with random word 175 | indices_random = torch.bernoulli(torch.full(labels.size(), 0.5)).bool() & masked_indices & ~indices_replaced 176 | random_words = torch.randint(len(self.char2id.keys()), labels.size(), dtype=torch.long) 177 | inputs[indices_random] = random_words[indices_random] 178 | 179 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 180 | return inputs, labels 181 | def pack_tensors(self, tokens): 182 | array = self.encoder(tokens) 183 | array = torch.nn.utils.rnn.pad_sequence(array, batch_first=True, padding_value=self.pad) 184 | #lengths = (array!=self.pad).sum(dim=-1) 185 | #Bert tokenization 186 | special_token_mask = [list(map(lambda x: 1 if x in [self.bos, self.eos, self.pad] else 0, stuff)) for stuff in array.tolist()] 187 | masked_array, masked_labels = self.mask_tokens(array, special_token_mask) 188 | return masked_array, masked_labels#, lengths 189 | def process(self, text): 190 | arrays = [] 191 | #lengths = [] 192 | targets = [] 193 | for tokens in self.process_text(text): 194 | if len(tokens) > 0: 195 | array, target = self.pack_tensors(tokens) 196 | arrays.append(array) 197 | targets.append(target) 198 | return arrays, targets 199 | 200 | if __name__ == '__main__': 201 | 202 | text_encoder = Encoder() 203 | -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/rotate_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | """The rotate attention layer performs all the query key value projections and 2 | output projections leaving the implementation of the attention to the inner 3 | attention module. 4 | """ 5 | 6 | from torch.nn import Linear, Module 7 | 8 | from fast_transformers.attention import AttentionLayer 9 | from fast_transformers.events import EventDispatcher, QKVEvent 10 | from .rotary import RotaryEmbedding, apply_rotary_pos_emb 11 | 12 | class RotateAttentionLayer(AttentionLayer): 13 | """Rotate attention layer inherits from fast_transformer attention layer. 
14 | The only thing added is an Embedding encoding, for more information 15 | on the attention layer see the fast_transformers code 16 | """ 17 | def __init__(self, attention, d_model, n_heads, d_keys=None, 18 | d_values=None, event_dispatcher=""): 19 | super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, 20 | d_values=d_values, event_dispatcher=event_dispatcher) 21 | 22 | self.rotaryemb = RotaryEmbedding(d_keys) 23 | print('Using Rotation Embedding') 24 | 25 | def forward(self, queries, keys, values, attn_mask, query_lengths, 26 | key_lengths): 27 | """ 28 | Using the same frame work as the fast_Transformers attention layer 29 | but injecting rotary information to the queries and the keys 30 | after the keys and queries are projected. 31 | In the argument description we make use of the following sizes 32 | 33 | - N: the batch size 34 | - L: The maximum length of the queries 35 | - S: The maximum length of the keys (the actual length per sequence 36 | is given by the length mask) 37 | - D: The input feature dimensionality passed in the constructor as 38 | 'd_model' 39 | 40 | Arguments 41 | --------- 42 | queries: (N, L, D) The tensor containing the queries 43 | keys: (N, S, D) The tensor containing the keys 44 | values: (N, S, D) The tensor containing the values 45 | attn_mask: An implementation of BaseMask that encodes where each 46 | query can attend to 47 | query_lengths: An implementation of BaseMask that encodes how 48 | many queries each sequence in the batch consists of 49 | key_lengths: An implementation of BaseMask that encodes how 50 | many queries each sequence in the batch consists of 51 | 52 | Returns 53 | ------- 54 | The new value for each query as a tensor of shape (N, L, D). 55 | """ 56 | # Extract the dimensions into local variables 57 | N, L, _ = queries.shape 58 | _, S, _ = keys.shape 59 | H = self.n_heads 60 | 61 | # Project the queries/keys/values 62 | queries = self.query_projection(queries).view(N, L, H, -1) 63 | keys = self.key_projection(keys).view(N, S, H, -1) 64 | cos, sin = self.rotaryemb(queries) 65 | queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) 66 | values = self.value_projection(values).view(N, S, H, -1) 67 | # Let the world know of the qkv 68 | self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) 69 | 70 | 71 | # Compute the attention 72 | new_values = self.inner_attention( 73 | queries, 74 | keys, 75 | values, 76 | attn_mask, 77 | query_lengths, 78 | key_lengths 79 | ).view(N, L, -1) 80 | 81 | # Project the output and return 82 | return self.out_projection(new_values) 83 | -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/rotate_attention/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RotaryEmbedding(torch.nn.Module): 4 | 5 | def __init__(self, dim, base=10000): 6 | super().__init__() 7 | inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) 8 | self.register_buffer('inv_freq', inv_freq) 9 | self.seq_len_cached = 0 10 | self.cos_cached = None 11 | self.sin_cached = None 12 | 13 | def forward(self, x, seq_dim=1): 14 | seq_len = x.shape[seq_dim] 15 | if seq_len != self.seq_len_cached: 16 | #if seq_len > self.seq_len_cached: 17 | self.seq_len_cached = seq_len 18 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 21 | self.cos_cached = emb.cos()[None,:, None, :] 22 | self.sin_cached = emb.sin()[None,:, None, :] 23 | #else: 24 | # cos_return = self.cos_cached[..., :seq_len] 25 | # sin_return = self.sin_cached[..., :seq_len] 26 | # return cos_return, sin_return 27 | 28 | return self.cos_cached, self.sin_cached 29 | 30 | 31 | # rotary pos emb helpers: 32 | 33 | def rotate_half(x): 34 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 35 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 36 | 37 | @torch.jit.script 38 | def apply_rotary_pos_emb(q, k, cos, sin): 39 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 40 | -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/rotate_attention/rotate_builder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import LayerNorm 2 | 3 | from .attention_layer import RotateAttentionLayer 4 | from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer 5 | from fast_transformers.builders.base import BaseBuilder 6 | from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder 7 | from fast_transformers.builders.attention_builders import AttentionBuilder 8 | 9 | 10 | class RotateEncoderBuilder(BaseTransformerEncoderBuilder): 11 | """Build a batch transformer encoder with Relative Rotary embeddings 12 | for training or processing of sequences all elements at a time. 
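The attention layer class returned below applies rotary position encodings to the projected queries and keys.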
13 | 14 | Example usage: 15 | 16 | builder = RotateEncoderBuilder() 17 | builder.n_layers = 12 18 | builder.n_heads = 8 19 | builder.feed_forward_dimensions = 1024 20 | builder.query_dimensions = 64 21 | builder.value_dimensions = 64 22 | builder.dropout = 0.1 23 | builder.attention_dropout = 0.1 24 | builder.attention_type = "linear" 25 | transformer = builder.get() 26 | """ 27 | def _get_attention_builder(self): 28 | """Return an instance of the appropriate attention builder.""" 29 | return AttentionBuilder() 30 | 31 | def _get_attention_layer_class(self): 32 | """Return the class for the layer that projects queries keys and 33 | values.""" 34 | return RotateAttentionLayer 35 | 36 | def _get_encoder_class(self): 37 | """Return the class for the transformer encoder.""" 38 | return TransformerEncoder 39 | 40 | def _get_encoder_layer_class(self): 41 | """Return the class for the transformer encoder layer.""" 42 | return TransformerEncoderLayer 43 | -------------------------------------------------------------------------------- /notebooks/pretrained_molformer/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | import regex as re 3 | 4 | PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 5 | 6 | 7 | class MolTranBertTokenizer(BertTokenizer): 8 | def __init__(self, vocab_file: str = '', 9 | do_lower_case=False, 10 | unk_token='', 11 | sep_token='', 12 | pad_token='', 13 | cls_token='', 14 | mask_token='', 15 | **kwargs): 16 | super().__init__(vocab_file, 17 | unk_token=unk_token, 18 | sep_token=sep_token, 19 | pad_token=pad_token, 20 | cls_token=cls_token, 21 | mask_token=mask_token, 22 | **kwargs) 23 | 24 | self.regex_tokenizer = re.compile(PATTERN) 25 | self.wordpiece_tokenizer = None 26 | self.basic_tokenizer = None 27 | 28 | def _tokenize(self, text): 29 | split_tokens = self.regex_tokenizer.findall(text) 30 | return split_tokens 31 | 32 | def convert_tokens_to_string(self, tokens): 33 | out_string = "".join(tokens).strip() 34 | return out_string 35 | -------------------------------------------------------------------------------- /training/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | ENV TERM linux 5 | 6 | RUN mkdir /moses 7 | COPY . 
/moses 8 | 9 | RUN set -ex \ 10 | && apt-get update -yqq \ 11 | && apt-get upgrade -yqq \ 12 | && apt-get install -yqq --no-install-recommends \ 13 | git wget curl ssh libxrender1 libxext6 software-properties-common apt-utils \ 14 | && wget --no-check-certificate https://repo.continuum.io/miniconda/Miniconda3-4.6.14-Linux-x86_64.sh \ 15 | && /bin/bash Miniconda3-4.6.14-Linux-x86_64.sh -f -b -p /opt/miniconda \ 16 | && add-apt-repository ppa:git-core/ppa \ 17 | && (curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash) \ 18 | && apt-get install git-lfs \ 19 | && git lfs install \ 20 | && apt-get clean \ 21 | && /opt/miniconda/bin/conda install conda=4.6.14=py36_0 \ 22 | && /opt/miniconda/bin/conda clean -yq -a \ 23 | && rm Miniconda3-4.6.14-Linux-x86_64.sh \ 24 | && rm -rf \ 25 | /tmp/* \ 26 | /var/tmp/* \ 27 | /usr/share/man \ 28 | /usr/share/doc \ 29 | /usr/share/doc-base 30 | 31 | ENV PATH /opt/miniconda/bin:$PATH 32 | RUN conda install -yq numpy=1.16.0 scipy=1.2.0 matplotlib=3.0.1 pandas=0.23.3 scikit-learn=0.20.3 tqdm>=4.26.0 \ 33 | && conda install -yq -c rdkit rdkit=2019.03.2 \ 34 | && conda install -yq -c pytorch pytorch=1.1.0 torchvision=0.2.1 \ 35 | && conda clean -yq -a 36 | 37 | WORKDIR /moses 38 | RUN python setup.py install && git lfs pull && conda clean -yq -a && rm -rf .git/lfs 39 | 40 | CMD [ "/bin/bash" ] 41 | -------------------------------------------------------------------------------- /training/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(parser=None): 5 | if parser is None: 6 | parser = argparse.ArgumentParser() 7 | 8 | # Model 9 | #model_arg = parser.add_argument_group('Model') 10 | parser.add_argument('--n_head', 11 | type=int, default=8, 12 | help='GPT number of heads') 13 | parser.add_argument('--n_layer', 14 | type=int, default=12, 15 | help='GPT number of layers') 16 | parser.add_argument('--q_dropout', 17 | type=float, default=0.5, 18 | help='Encoder layers dropout') 19 | parser.add_argument('--d_dropout', 20 | type=float, default=0.1, 21 | help='Decoder layers dropout') 22 | parser.add_argument('--n_embd', 23 | type=int, default=768, 24 | help='Latent vector dimensionality') 25 | parser.add_argument('--fc_h', 26 | type=int, default=512, 27 | help='Fully connected hidden dimensionality') 28 | 29 | 30 | # Train 31 | #train_arg = parser.add_argument_group('Train') 32 | parser.add_argument('--n_batch', 33 | type=int, default=512, 34 | help='Batch size') 35 | parser.add_argument('--unlike_alpha', 36 | type=float, default=1.0, 37 | help='unlikelihood loss alpha weight') 38 | parser.add_argument('--from_scratch', 39 | action='store_true', default=False, 40 | help='train on qm9 from scratch') 41 | parser.add_argument('--unlikelihood', 42 | action='store_true', default=False, 43 | help='use unlikelihood loss with gpt pretrain') 44 | parser.add_argument('--grad_acc', 45 | type=int, default=1, 46 | help='number of batches to accumulate gradients') 47 | parser.add_argument('--checkpoint_every', 48 | type=int, default=1000, 49 | help='save checkpoint every x iterations') 50 | parser.add_argument('--clip_grad', 51 | type=int, default=50, 52 | help='Clip gradients to this value') 53 | parser.add_argument('--lr_start', 54 | type=float, default=3 * 1e-4, 55 | help='Initial lr value') 56 | parser.add_argument('--lr_end', 57 | type=float, default=3 * 1e-4, 58 | help='Maximum lr weight value') 59 | parser.add_argument('--lr_multiplier', 60 | type=int, 
default=1, 61 | help='lr weight multiplier') 62 | parser.add_argument('--n_last', 63 | type=int, default=1000, 64 | help='Number of iters to smooth loss calc') 65 | parser.add_argument('--n_jobs', 66 | type=int, default=1, 67 | help='Number of threads') 68 | parser.add_argument('--accelerator', 69 | type=str, default='ddp', 70 | help='The accelerator backend to use (previously known as distributed_backend)') 71 | parser.add_argument('--num_nodes', 72 | type=int, default=1, 73 | help='number of GPU nodes for distributed training') 74 | parser.add_argument('--device', 75 | type=str, default='cuda', 76 | help='Device to run: "cpu" or "cuda:"') 77 | parser.add_argument('--seed', 78 | type=int, default=12345, 79 | help='Seed') 80 | 81 | #common_arg = parser.add_argument_group('Common') 82 | parser.add_argument('--vocab_load', 83 | type=str, required=False, 84 | help='Where to load the vocab') 85 | parser.add_argument('--n_samples', 86 | type=int, required=False, 87 | help='Number of samples to sample') 88 | parser.add_argument('--gen_save', 89 | type=str, required=False, 90 | help='Where to save the gen molecules') 91 | parser.add_argument("--max_len", 92 | type=int, default=100, 93 | help="Max of length of SMILES") 94 | parser.add_argument('--train_load', 95 | type=str, required=False, 96 | help='Where to load the model') 97 | parser.add_argument('--val_load', 98 | type=str, required=False, 99 | help='Where to load the model') 100 | parser.add_argument('--n_workers', 101 | type=int, required=False, default=1, 102 | help='Where to load the model') 103 | #beam search hyper parameters 104 | parser.add_argument('--beam_size', type=int, default=0, 105 | help="Number of beams to generate") 106 | parser.add_argument('--num_seq_returned', type=int, default=0, 107 | help="number of beams to be returned (must be <= beam_size") 108 | parser.add_argument('--min_len', type=int, default=1, 109 | help="minimum length to be generated") 110 | parser.add_argument('--nucleus_thresh', type=float, default=.9, 111 | help="nucleus sampling threshold") 112 | parser.add_argument('--finetune_path', 113 | type=str, default="", 114 | help='path to trainer file to continue training') 115 | parser.add_argument('--restart_path', 116 | type=str, default="", 117 | help='path to trainer file to continue training') 118 | parser.add_argument('--data_path', 119 | type=str, default="", 120 | help='path to pubchem file') 121 | parser.add_argument('--pretext_size', 122 | type=int, default=0, 123 | help='number of k-mers to pretext') 124 | parser.add_argument('--model_save_dir', 125 | type=str, required=False, default='./models_dump/', 126 | help='Where to save the models/log/config/vocab') 127 | parser.add_argument('--model_save', 128 | type=str, required=False, default='model.pt', 129 | help='Where to save the model') 130 | #parser.add_argument('--save_frequency', 131 | # type=int, default=20, 132 | # help='How often to save the model') 133 | parser.add_argument('--num_epoch', 134 | type=int, default=1, 135 | help='number of epochs to train') 136 | #parser.add_argument('--num_iter', 137 | # type=int, default=-1, 138 | # help='how many itersations per epoch (for unlikelihood tuning)') 139 | parser.add_argument('--log_file', 140 | type=str, required=False, 141 | help='Where to save the log') 142 | parser.add_argument('--tb_loc', 143 | type=str, required=False, 144 | help='Where to save the tensorflow location') 145 | parser.add_argument('--config_save', 146 | type=str, required=False, 147 | help='Where to save the config') 148 | 
parser.add_argument('--vocab_save', 149 | type=str, 150 | help='Where to save the vocab') 151 | 152 | # resume_arg = parser.add_argument_group('Resume') 153 | parser.add_argument('--debug', 154 | default=False, action='store_true', 155 | help='do not erase cache at end of program') 156 | parser.add_argument('--fast_dev_run', 157 | default=False, 158 | help='This flag runs a “unit test” by running n if set to n (int) else 1 if set to True training and validation batch(es).') 159 | parser.add_argument('--freeze_model', 160 | default=False, action='store_true', 161 | help='freeze weights of bert model during fine tuning') 162 | parser.add_argument('--resume', 163 | default=False, action='store_true', 164 | help='Resume from a saved model') 165 | parser.add_argument('--rotate', 166 | default=False, action='store_true', 167 | help='use rotational relative embedding') 168 | parser.add_argument('--model_load', 169 | type=str, required=False, 170 | help='Where to load the model') 171 | parser.add_argument('--root_dir', 172 | type=str, required=False, default='.', 173 | help='location of root dir') 174 | parser.add_argument('--config_load', 175 | type=str, required=False, 176 | help='Where to load the config') 177 | parser.add_argument('--gpus', 178 | type=int, required=False, default=1, 179 | help='number of gpus to use') 180 | #parser.add_argument('--start_epoch', 181 | # type=int, required=False, default=0, 182 | # help='Where to load the config') 183 | 184 | parser.add_argument('--model_arch', 185 | type=str, required=False, 186 | help='used to teack model arch in params') 187 | parser.add_argument('--eval_every', 188 | type=int, default=50000, 189 | help='run evaluation every x iterations') 190 | parser.add_argument('--num_feats', 191 | type=int, required=False, default=32, 192 | help='number of random reatures for FAVOR+') 193 | parser.add_argument('--max_epochs', 194 | type=int, required=False, default=1, 195 | help='max number of epochs') 196 | 197 | # debug() FINE TUNEING 198 | # parser.add_argument('--save_dir', type=str, required=True) 199 | parser.add_argument('--mode', 200 | type=str, default='cls', 201 | help='type of pooling to use') 202 | parser.add_argument("--dataset_length", type=int, default=None, required=False) 203 | parser.add_argument("--num_workers", type=int, default=0, required=False) 204 | parser.add_argument("--dropout", type=float, default=0.1, required=False) 205 | #parser.add_argument("--dims", type=int, nargs="*", default="", required=False) 206 | parser.add_argument( 207 | "--smiles_embedding", 208 | type=str, 209 | default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt", 210 | ) 211 | # parser.add_argument("--train_pct", type=str, required=False, default="95") 212 | #parser.add_argument("--aug", type=int, required=True) 213 | parser.add_argument("--dataset_name", type=str, required=False, default="sol") 214 | parser.add_argument("--measure_name", type=str, required=False, default="measure") 215 | #parser.add_argument("--emb_type", type=str, required=True) 216 | #parser.add_argument("--checkpoints_folder", type=str, required=True) 217 | #parser.add_argument("--results_dir", type=str, required=True) 218 | #parser.add_argument("--patience_epochs", type=int, required=True) 219 | 220 | parser.add_argument( 221 | "--data_root", 222 | type=str, 223 | required=False, 224 | default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity", 225 | ) 226 | # parser.add_argument("--use_bn", 
type=int, default=0) 227 | parser.add_argument("--use_linear", type=int, default=0) 228 | 229 | parser.add_argument("--lr", type=float, default=0.001) 230 | # parser.add_argument("--weight_decay", type=float, default=5e-4) 231 | # parser.add_argument("--val_check_interval", type=float, default=1.0) 232 | parser.add_argument("--batch_size", type=int, default=64) 233 | 234 | return parser 235 | def parse_args(): 236 | parser = get_parser() 237 | args = parser.parse_args() 238 | return args 239 | 240 | -------------------------------------------------------------------------------- /training/continue_pubchem_light_distributed_multinode_bert_both.sh: -------------------------------------------------------------------------------- 1 | source /opt/share/anaconda3-2019.03/x86_64/etc/profile.d/conda.sh 2 | conda activate /dccstor/bmbelgod1/environments/MolTran_CUDA11 3 | python train_pubchem_light.py \ 4 | --device cuda \ 5 | --n_batch 800 \ 6 | --n_head 12 \ 7 | --n_layer 12 \ 8 | --n_embd 768 \ 9 | --max_len 202 \ 10 | --d_dropout 0.2 \ 11 | --lr_start 3e-5 \ 12 | --lr_multiplier 8 \ 13 | --n_workers 16 \ 14 | --max_epochs 4 \ 15 | --gpu 8 \ 16 | --num_nodes 2 \ 17 | --accelerator ddp \ 18 | --num_feats 32 \ 19 | --root_dir . \ 20 | --checkpoint_every 1000 \ 21 | --grad_acc 1\ 22 | --train_load 'both' \ 23 | --eval_every 2500 \ 24 | --rotate \ 25 | --debug \ 26 | --model_arch 'BERT_16GPU_Long_Run_with_Rotate_Continued' \ 27 | --restart_path /dccstor/bmbelgod1/projects/MolTran/lightning_logs/version_11/checkpoints/epoch\=2-step\=178277.ckpt \ 28 | | tee $HOSTNAME.$LSF_PM_XPROCID.$(date +%F_%R).log 29 | #--restart_path /dccstor/bmbelgod1/projects/MolTran/lightning_logs/version_11/checkpoints/N-Step-Checkpoint_2_180000.ckpt \ 30 | -------------------------------------------------------------------------------- /training/dataset_pubchem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import codecs 4 | import pytorch_lightning as pl 5 | import pandas as pd 6 | import numpy as np 7 | import time 8 | 9 | class SS: 10 | bos = '' 11 | eos = '' 12 | pad = '' 13 | unk = '' 14 | 15 | class DatasetPubchem(torch.utils.data.IterableDataset): 16 | #class DatasetPubchem(torch.utils.data.Dataset): 17 | 18 | def __init__(self, train_load=None, vocab=None, randomize_smiles=False): 19 | """PubChem Dataset 20 | 21 | 22 | Keyword Arguments: 23 | database_file {[type]} -- [description] 24 | randomize_smiles {bool} -- Randomize the smiles each epoch 25 | """ 26 | #regex vocab used by Molecular Transformer 27 | import re 28 | pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 29 | self.regex = re.compile(pattern) 30 | 31 | if train_load is None: 32 | self.database_file = '/dccstor/trustedgen/data/pubchem/CID-SMILES' 33 | else: 34 | self.database_file = train_load 35 | self.len = len(open(self.database_file, 'rb').readlines()) 36 | #self.smiles = [] 37 | #with codecs.open(self.database_file) as f: 38 | # for line in f: 39 | # self.smiles.append(''.join(self.regex.findall(line.split()[-1]))) 40 | 41 | #self.smiles = pd.read_csv(self.database_file)[smiles_header_name].tolist() 42 | self.randomize_smiles = randomize_smiles 43 | if vocab != None: 44 | self.vocab = vocab 45 | else: 46 | self.vocab = {} 47 | vocab = torch.load('./atom_vocab.pth') 48 | self.ss = SS 49 | self.c2i = {c: i for i, c in vocab.items()} 50 | self.i2c = {i: c for i, c in vocab.items()} 51 | 52 | def 
process(self,text): 53 | #return np.asarray(self.string2ids(self.smi_tokenizer(smiles), add_bos=True, add_eos=True)) 54 | print(text) 55 | mol = ''.join(self.regex.findall(text.split()[-1])) 56 | return mol 57 | #return np.asarray(self.string2ids(self.smi_tokenizer(mol), add_bos=True, add_eos=True)) 58 | def line_mapper(self, line): 59 | text = self.process(line) 60 | return text 61 | def __len__(self): 62 | #pass 63 | return self.len 64 | # return len(self.smiles) 65 | def __iter__(self): 66 | data_file = open(self.database_file) 67 | data_map = map(self.line_mapper, data_file) 68 | return data_map 69 | 70 | def get_vocab(self): 71 | return self.c2i 72 | #return {'bos':self.ss.bos, 'eos':self.ss.eos, 'pad':self.ss.pad, 73 | #'unk':self.ss.unk, 'c2i':self.c2i, 'i2c':self.i2c} 74 | 75 | def smi_tokenizer(self, smi): 76 | """ 77 | Tokenize a SMILES molecule or reaction 78 | """ 79 | tokens = self.regex.findall(smi) 80 | #tokens = [token for token in self.regex.findall(smi)] 81 | assert smi == ''.join(tokens) 82 | return tokens 83 | 84 | #def __getitem__(self, index): 85 | # smiles = self.smiles[index] 86 | # #print(smiles) 87 | # #print(self.vocab.string2ids(smiles, add_bos=True, add_eos=True)) 88 | # #print(np.asarray(self.vocab.string2ids(smiles, add_bos=True, add_eos=True))) 89 | # 90 | # if self.randomize_smiles: 91 | # smiles = self.randomize_smiles(smiles, self.isomeric_smiles) 92 | # if self.is_measure_available: 93 | # return smiles, self.measure[index] 94 | # else: 95 | # return np.asarray(self.string2ids(self.smi_tokenizer(smiles), add_bos=True, add_eos=True)) 96 | 97 | 98 | def create_collate_fn(self, pad): 99 | def collate(batch): 100 | ''' 101 | Padds batch of variable length 102 | note: it converts things ToTensor manually here since the ToTensor transform 103 | assume it takes in images rather than arbitrary tensors. 
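Sequences are split into three length bands (5-41, 42-100 and 101-200 tokens); each band is padded separately and returned alongside its lengths and next-token targets.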
104 | ''' 105 | ## get sequence lengths 106 | lengths = torch.tensor([ t.shape[0] for t in batch if t.shape[0] > 4 and t.shape[0] < 42 ]]) 107 | batch_tmp = [ torch.from_numpy(t) for t in batch if t.shape[0] > 4 and t.shape[0] < 42 ] 108 | batch_tmp = torch.nn.utils.rnn.pad_sequence(batch_tmp, batch_first=True, padding_value=pad) 109 | target = [ torch.from_numpy(t[1:]) for t in batch if t.shape[0] > 4 and t.shape[0] < 42 ] 110 | target = torch.nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=pad) 111 | target = torch.cat((target, torch.full((target.size(0), 1),pad, dtype=int)), dim=-1) 112 | lengths_mid = None 113 | batch_tmp_mid = None 114 | target_mid = None 115 | assert(False) 116 | if len(batch_tmp) < len(batch): 117 | lengths_mid = torch.tensor([ t.shape[0] for t in batch if t.shape[0] >= 42 and t.shape[0] <= 100 ]]) 118 | batch_tmp_mid = [ torch.from_numpy(t) for t in batch if t.shape[0] >= 42 and t.shape[0] <= 100 ] 119 | batch_tmp_mid = torch.nn.utils.rnn.pad_sequence(batch_tmp_mid, batch_first=True, padding_value=pad) 120 | target_mid = [ torch.from_numpy(t[1:]) for t in batch if t.shape[0] >= 42 and t.shape[0] <= 100 ] 121 | target_mid= torch.nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=pad) 122 | target_mid= torch.cat((target, torch.full((target.size(0), 1),pad, dtype=int)), dim=-1) 123 | ## compute mask 124 | #mask = (batch != pad) 125 | lengths_long = None 126 | batch_tmp_long = None 127 | target_long = None 128 | 129 | print(len(batch_tmp) < len(batch)) 130 | if len(batch_tmp) < len(batch): 131 | lengths_long = torch.tensor([ t.shape[0] for t in batch if t.shape[0] >= 101 and t.shape[0] <= 200 ]]) 132 | batch_tmp_long = [ torch.from_numpy(t) for t in batch if t.shape[0] >= 101 and t.shape[0] <= 200 ] 133 | batch_tmp_long = torch.nn.utils.rnn.pad_sequence(batch_tmp_long, batch_first=True, padding_value=pad) 134 | target_long = [ torch.from_numpy(t[1:]) for t in batch if t.shape[0] >= 101 and t.shape[0] <= 200 ] 135 | target_long= torch.nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=pad) 136 | target_long= torch.cat((target, torch.full((target.size(0), 1),pad, dtype=int)), dim=-1) 137 | ## compute mask 138 | return batch_tmp, lengths, target, batch_tmp_mid, lengths_mid, target_mid, batch_tmp_long, lengths_long, target_long 139 | return collate 140 | @property 141 | def bos(self): 142 | return self.c2i[self.ss.bos] 143 | 144 | @property 145 | def eos(self): 146 | return self.c2i[self.ss.eos] 147 | 148 | @property 149 | def pad(self): 150 | return self.c2i[self.ss.pad] 151 | 152 | @property 153 | def unk(self): 154 | return self.c2i[self.ss.unk] 155 | 156 | def char2id(self, char): 157 | if char not in self.c2i: 158 | return self.unk 159 | 160 | return self.c2i[char] 161 | 162 | def id2char(self, id): 163 | if id not in self.i2c: 164 | return self.unk 165 | 166 | return self.i2c[id] 167 | 168 | def string2ids(self, string, add_bos=False, add_eos=False): 169 | ids = [self.char2id(c) for c in string] 170 | 171 | if add_bos: 172 | ids = [self.bos] + ids 173 | if add_eos: 174 | ids = ids + [self.eos] 175 | 176 | return ids 177 | 178 | def ids2string(self, ids, rem_bos=True, rem_eos=True): 179 | if len(ids) == 0: 180 | return '' 181 | if rem_bos and ids[0] == self.bos: 182 | ids = ids[1:] 183 | if rem_eos and ids[-1] == self.eos: 184 | ids = ids[:(ids != self.eos).sum()] 185 | #ids = ids[:-1] 186 | 187 | string = ''.join([self.id2char(id) for id in ids]) 188 | 189 | return string 190 | def read_smiles_csv(path, fields=None, 
normalize=True): 191 | df = pd.read_csv(path) 192 | fields=['SMILES'] 193 | data = [df[field].to_list() for field in fields] 194 | return data 195 | #return list(zip(*data)) 196 | def randomize_smiles(self, smiles, isomeric_smiles): 197 | """Perform a randomization of a SMILES string 198 | must be RDKit sanitizable""" 199 | m = Chem.MolFromSmiles(smiles) 200 | if m is None: 201 | return None # Invalid SMILES 202 | ans = list(range(m.GetNumAtoms())) 203 | np.random.shuffle(ans) 204 | nm = Chem.RenumberAtoms(m,ans) 205 | return Chem.MolToSmiles(nm, canonical=False, isomericSmiles=isomeric_smiles) 206 | if __name__ == '__main__': 207 | 208 | t0 = time.time() 209 | print('start time is {}'.format(t0)) 210 | dataset = DatasetPubchem() 211 | t1 = time.time() 212 | print('total time is {}'.format(t1-t0)) 213 | print('len(dataset)') 214 | print(len(dataset)) 215 | -------------------------------------------------------------------------------- /training/environ.sh: -------------------------------------------------------------------------------- 1 | conda activate /dccstor/bmbelgod1/environments/MolTran_CUDA11 2 | #source /dccstor/medscan7/anaconda3/bin/activate py37exp 3 | -------------------------------------------------------------------------------- /training/pubchem_canon_script.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | import pyarrow as pa 4 | 5 | import datasets 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | FEATURES = datasets.Features( 12 | { 13 | "text": datasets.Value("string"), 14 | } 15 | ) 16 | 17 | 18 | @dataclass 19 | class PubChemConfig(datasets.BuilderConfig): 20 | """BuilderConfig for text files.""" 21 | 22 | encoding: str = "utf-8" 23 | chunksize: int = 10 << 20 # 10MB 24 | 25 | 26 | class PubChem(datasets.ArrowBasedBuilder): 27 | 28 | BUILDER_CONFIG_CLASS = PubChemConfig 29 | 30 | def _info(self): 31 | return datasets.DatasetInfo(features=FEATURES) 32 | 33 | def _split_generators(self, dl_manager): 34 | """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. 35 | 36 | If str or List[str], then the dataset returns only the 'train' split. 37 | If dict, then keys should be from the `datasets.Split` enum. 
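Each returned split simply carries the list of files to read; _generate_tables later streams every file in chunks of self.config.chunksize (roughly 10MB) and yields one pyarrow Table of raw text lines per chunk.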
38 | """ 39 | if not self.config.data_files: 40 | raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") 41 | data_files = dl_manager.download_and_extract(self.config.data_files) 42 | if isinstance(data_files, (str, list, tuple)): 43 | files = data_files 44 | if isinstance(files, str): 45 | files = [files] 46 | return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] 47 | splits = [] 48 | for split_name, files in data_files.items(): 49 | if isinstance(files, str): 50 | files = [files] 51 | splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) 52 | return splits 53 | 54 | def _generate_tables(self, files): 55 | 56 | for file_idx, file in enumerate(files): 57 | batch_idx = 0 58 | with open(file, "r", encoding=self.config.encoding) as f: 59 | while True: 60 | batch = f.read(self.config.chunksize) 61 | if not batch: 62 | break 63 | batch += f.readline() # finish current line 64 | batch = batch.splitlines() 65 | #batch = [word.split()[-1] for word in batch] 66 | pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) 67 | # Uncomment for debugging (will print the Arrow table size and elements) 68 | #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") 69 | #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) 70 | yield (file_idx, batch_idx), pa_table 71 | batch_idx += 1 72 | -------------------------------------------------------------------------------- /training/pubchem_canon_zinc_final_vocab_sorted.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ningliu-iga/TrinityLLM/5659e6e7d7dfb748a338a88f1da1e3f85a2676f7/training/pubchem_canon_zinc_final_vocab_sorted.pth -------------------------------------------------------------------------------- /training/pubchem_encoder.py: -------------------------------------------------------------------------------- 1 | import regex as re 2 | import torch 3 | import numpy as np 4 | import random 5 | import collections 6 | 7 | class Encoder(): 8 | 9 | def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32): 10 | self.vocab_encoder = torch.load('pubchem_canon_zinc_final_vocab_sorted.pth') 11 | 12 | self.max_length = max_length 13 | self.min_length = 1 14 | self.mod_length = 42 15 | self.mlm_probability = .15 16 | self.avg_length = 66 17 | self.tail = 122 18 | self.b0_cache=collections.deque() 19 | self.b1_cache=collections.deque() 20 | self.b2_cache=collections.deque() 21 | self.b3_cache=collections.deque() 22 | self.bucket0=collections.deque() 23 | self.bucket1=collections.deque() 24 | self.bucket2=collections.deque() 25 | self.bucket3=collections.deque() 26 | if feature_size == 32: 27 | self.b0_max=1100 28 | self.b1_max=700 29 | self.b2_max=150 30 | self.b3_max=50 31 | else: 32 | self.b0_max=1382 33 | self.b1_max=871 34 | self.b2_max=516 35 | self.b3_max=311 36 | values = list(self.vocab_encoder.values()) 37 | num_top = 0 38 | middle_top = 0 39 | bottom = 0 40 | for count in values: 41 | if count > 100000: 42 | num_top += 1 43 | if count > 50: 44 | middle_top += 1 45 | middle_top = middle_top - num_top 46 | self.cutoffs = [num_top+4, middle_top] 47 | self.char2id = {"":0, "":1, "":2, "":3} 48 | self.id2char = {0:"", 1:"", 2:"", 3:""} 49 | self.pad = self.char2id[''] 50 | self.mask = self.char2id[''] 51 | self.eos = self.char2id[''] 52 | self.bos = 
self.char2id[''] 53 | pos = 0 54 | for key, value in self.vocab_encoder.items(): 55 | #for pos, key in enumerate(self.vocab_encoder.keys()): 56 | self.char2id[key] = pos+4 57 | self.id2char[pos+4] = key 58 | pos += 1 59 | self.char2id[""] = pos + 4 60 | self.id2char[pos+4] = "" 61 | self.pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 62 | self.regex = re.compile(self.pattern) 63 | self.add_bos = add_bos 64 | self.add_eos = add_eos 65 | 66 | def encode(self, char): 67 | #if len(char) > self.max_length: 68 | # char = char[:self.max_length] 69 | if self.add_bos == True: 70 | char = [''] + char 71 | if self.add_eos == True: 72 | char = char + [''] 73 | 74 | return torch.tensor([self.char2id[word] for word in char]) 75 | 76 | def encoder(self, tokens): 77 | #return *map(lambda x: self.encode(x), tokens) 78 | return [self.encode(mol) for mol in tokens] 79 | 80 | def process_text(self, text): 81 | #print(text) 82 | #random length sequences seems to help training 83 | mod_length = self.mod_length #+ random.randint(-1, 3) 84 | avg_length = self.avg_length #+ random.randint(-3, 5) 85 | for mol in text: 86 | #fill up buckets and caches 87 | if '\n' in mol['text']: 88 | print('carriage return in mol') 89 | raw_regex = self.regex.findall(mol['text'].strip('\n')) 90 | length = len(raw_regex) 91 | if length > self.min_length and length < mod_length: 92 | if len(self.bucket0) < self.b0_max: 93 | self.bucket0.append(raw_regex) 94 | else: 95 | self.b0_cache.append(raw_regex) 96 | elif length >= mod_length and length < avg_length: 97 | if len(self.bucket1) < self.b1_max: 98 | self.bucket1.append(raw_regex) 99 | else: 100 | self.b1_cache.append(raw_regex) 101 | elif length >= avg_length and length < self.tail: 102 | self.b2_cache.append(raw_regex) 103 | #if len(bucket2) < self.b2_max: 104 | # bucket2.append(raw_regex) 105 | #else: 106 | # self.b2_cache.append(raw_regex) 107 | elif length >= self.tail and length < self.max_length: 108 | self.b3_cache.append(raw_regex) 109 | #if len(bucket3) < self.b3_max: 110 | # bucket3.append(raw_regex) 111 | #else: 112 | # self.b3_cache.append(raw_regex) 113 | 114 | #print('before Cache size {} {} {} {}'.format(len(self.b0_cache), len(self.b1_cache), len(self.b2_cache), len(self.b3_cache))) 115 | #pour cache elements into any open bucket 116 | if len(self.bucket0) < self.b0_max and len(self.b0_cache) > 0: 117 | cache_size = len(self.b0_cache) 118 | max_margin = self.b0_max-len(self.bucket0) 119 | range0 = min(cache_size, max_margin) 120 | outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] + [self.b0_cache.pop() for i in range(range0)] 121 | #self.b0_cache = collections.deque(self.b0_cache[:self.b0_max-len(bucket0)]) 122 | #print('0 type {}'.format(type(self.b0_cache))) 123 | else: 124 | outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] 125 | if len(self.bucket1) < self.b1_max and len(self.b1_cache) > 0: 126 | cache_size = len(self.b1_cache) 127 | max_margin = self.b1_max-len(self.bucket1) 128 | range1 = min(cache_size, max_margin) 129 | outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] + [self.b1_cache.pop() for i in range(range1)] 130 | else: 131 | outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] 132 | 133 | if len(self.b2_cache) > self.b2_max: 134 | cache_size = len(self.b2_cache) 135 | max_margin = self.b2_max 136 | range2 = min(cache_size, max_margin) 137 | outbucket2 = [self.b2_cache.pop() for i in range(range2)] 138 | 
else: 139 | outbucket2=[] 140 | if len(self.b3_cache) > self.b3_max: 141 | cache_size = len(self.b3_cache) 142 | max_margin = self.b3_max 143 | range3 = min(cache_size, max_margin) 144 | outbucket3 = [self.b3_cache.pop() for i in range(range3)] 145 | else: 146 | outbucket3 = [] 147 | return outbucket0, outbucket1, outbucket2, outbucket3 148 | 149 | def mask_tokens( self, inputs, special_tokens_mask= None): 150 | """ 151 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 152 | """ 153 | labels = inputs.clone() 154 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 155 | probability_matrix = torch.full(labels.size(), self.mlm_probability) 156 | if special_tokens_mask is None: 157 | special_tokens_mask = [ 158 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 159 | ] 160 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 161 | else: 162 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 163 | #special_tokens_mask = special_tokens_mask.bool() 164 | 165 | #print(special_tokens_mask.size()) 166 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 167 | masked_indices = torch.bernoulli(probability_matrix).bool() 168 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 169 | 170 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 171 | indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8)).bool() & masked_indices 172 | inputs[indices_replaced] = self.mask 173 | 174 | # 10% of the time, we replace masked input tokens with random word 175 | indices_random = torch.bernoulli(torch.full(labels.size(), 0.5)).bool() & masked_indices & ~indices_replaced 176 | random_words = torch.randint(len(self.char2id.keys()), labels.size(), dtype=torch.long) 177 | inputs[indices_random] = random_words[indices_random] 178 | 179 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 180 | return inputs, labels 181 | def pack_tensors(self, tokens): 182 | array = self.encoder(tokens) 183 | array = torch.nn.utils.rnn.pad_sequence(array, batch_first=True, padding_value=self.pad) 184 | #lengths = (array!=self.pad).sum(dim=-1) 185 | #Bert tokenization 186 | special_token_mask = [list(map(lambda x: 1 if x in [self.bos, self.eos, self.pad] else 0, stuff)) for stuff in array.tolist()] 187 | masked_array, masked_labels = self.mask_tokens(array, special_token_mask) 188 | return masked_array, masked_labels#, lengths 189 | def process(self, text): 190 | arrays = [] 191 | #lengths = [] 192 | targets = [] 193 | for tokens in self.process_text(text): 194 | if len(tokens) > 0: 195 | array, target = self.pack_tensors(tokens) 196 | arrays.append(array) 197 | targets.append(target) 198 | return arrays, targets 199 | 200 | if __name__ == '__main__': 201 | 202 | text_encoder = Encoder() 203 | -------------------------------------------------------------------------------- /training/pubchem_script.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | import pyarrow as pa 4 | 5 | import datasets 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | FEATURES = datasets.Features( 12 | { 13 | "text": datasets.Value("string"), 14 | } 15 | ) 16 | 17 | 18 | @dataclass 19 | class PubChemConfig(datasets.BuilderConfig): 20 | 
"""BuilderConfig for text files.""" 21 | 22 | encoding: str = "utf-8" 23 | chunksize: int = 10 << 20 # 10MB 24 | 25 | 26 | class PubChem(datasets.ArrowBasedBuilder): 27 | 28 | BUILDER_CONFIG_CLASS = PubChemConfig 29 | 30 | def _info(self): 31 | return datasets.DatasetInfo(features=FEATURES) 32 | 33 | def _split_generators(self, dl_manager): 34 | """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. 35 | 36 | If str or List[str], then the dataset returns only the 'train' split. 37 | If dict, then keys should be from the `datasets.Split` enum. 38 | """ 39 | if not self.config.data_files: 40 | raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") 41 | data_files = dl_manager.download_and_extract(self.config.data_files) 42 | if isinstance(data_files, (str, list, tuple)): 43 | files = data_files 44 | if isinstance(files, str): 45 | files = [files] 46 | return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] 47 | splits = [] 48 | for split_name, files in data_files.items(): 49 | if isinstance(files, str): 50 | files = [files] 51 | splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) 52 | return splits 53 | 54 | def _generate_tables(self, files): 55 | 56 | for file_idx, file in enumerate(files): 57 | batch_idx = 0 58 | with open(file, "r", encoding=self.config.encoding) as f: 59 | while True: 60 | batch = f.read(self.config.chunksize) 61 | if not batch: 62 | break 63 | batch += f.readline() # finish current line 64 | batch = batch.splitlines() 65 | batch = [word.split()[-1] for word in batch] 66 | pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) 67 | # Uncomment for debugging (will print the Arrow table size and elements) 68 | #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") 69 | #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) 70 | yield (file_idx, batch_idx), pa_table 71 | batch_idx += 1 72 | -------------------------------------------------------------------------------- /training/rotate_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | """The rotate attention layer performs all the query key value projections and 2 | output projections leaving the implementation of the attention to the inner 3 | attention module. 4 | """ 5 | 6 | from torch.nn import Linear, Module 7 | 8 | from fast_transformers.attention import AttentionLayer 9 | from fast_transformers.events import EventDispatcher, QKVEvent 10 | from .rotary import RotaryEmbedding, apply_rotary_pos_emb 11 | 12 | class RotateAttentionLayer(AttentionLayer): 13 | """Rotate attention layer inherits from fast_transformer attention layer. 
14 | The only thing added is an Embedding encoding, for more information 15 | on the attention layer see the fast_transformers code 16 | """ 17 | def __init__(self, attention, d_model, n_heads, d_keys=None, 18 | d_values=None, event_dispatcher=""): 19 | super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, 20 | d_values=d_values, event_dispatcher=event_dispatcher) 21 | 22 | self.rotaryemb = RotaryEmbedding(d_keys) 23 | print('Using Rotation Embedding') 24 | 25 | def forward(self, queries, keys, values, attn_mask, query_lengths, 26 | key_lengths): 27 | """ 28 | Using the same frame work as the fast_Transformers attention layer 29 | but injecting rotary information to the queries and the keys 30 | after the keys and queries are projected. 31 | In the argument description we make use of the following sizes 32 | 33 | - N: the batch size 34 | - L: The maximum length of the queries 35 | - S: The maximum length of the keys (the actual length per sequence 36 | is given by the length mask) 37 | - D: The input feature dimensionality passed in the constructor as 38 | 'd_model' 39 | 40 | Arguments 41 | --------- 42 | queries: (N, L, D) The tensor containing the queries 43 | keys: (N, S, D) The tensor containing the keys 44 | values: (N, S, D) The tensor containing the values 45 | attn_mask: An implementation of BaseMask that encodes where each 46 | query can attend to 47 | query_lengths: An implementation of BaseMask that encodes how 48 | many queries each sequence in the batch consists of 49 | key_lengths: An implementation of BaseMask that encodes how 50 | many queries each sequence in the batch consists of 51 | 52 | Returns 53 | ------- 54 | The new value for each query as a tensor of shape (N, L, D). 55 | """ 56 | # Extract the dimensions into local variables 57 | N, L, _ = queries.shape 58 | _, S, _ = keys.shape 59 | H = self.n_heads 60 | 61 | # Project the queries/keys/values 62 | queries = self.query_projection(queries).view(N, L, H, -1) 63 | keys = self.key_projection(keys).view(N, S, H, -1) 64 | cos, sin = self.rotaryemb(queries) 65 | queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) 66 | values = self.value_projection(values).view(N, S, H, -1) 67 | # Let the world know of the qkv 68 | self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) 69 | 70 | 71 | # Compute the attention 72 | new_values = self.inner_attention( 73 | queries, 74 | keys, 75 | values, 76 | attn_mask, 77 | query_lengths, 78 | key_lengths 79 | ).view(N, L, -1) 80 | 81 | # Project the output and return 82 | return self.out_projection(new_values) 83 | -------------------------------------------------------------------------------- /training/rotate_attention/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RotaryEmbedding(torch.nn.Module): 4 | 5 | def __init__(self, dim, base=10000): 6 | super().__init__() 7 | inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) 8 | self.register_buffer('inv_freq', inv_freq) 9 | self.seq_len_cached = 0 10 | self.cos_cached = None 11 | self.sin_cached = None 12 | 13 | def forward(self, x, seq_dim=1): 14 | seq_len = x.shape[seq_dim] 15 | if seq_len != self.seq_len_cached: 16 | #if seq_len > self.seq_len_cached: 17 | self.seq_len_cached = seq_len 18 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 21 | self.cos_cached = emb.cos()[None,:, None, :] 22 | self.sin_cached = emb.sin()[None,:, None, :] 23 | #else: 24 | # cos_return = self.cos_cached[..., :seq_len] 25 | # sin_return = self.sin_cached[..., :seq_len] 26 | # return cos_return, sin_return 27 | 28 | return self.cos_cached, self.sin_cached 29 | 30 | 31 | # rotary pos emb helpers: 32 | 33 | def rotate_half(x): 34 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 35 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 36 | 37 | @torch.jit.script 38 | def apply_rotary_pos_emb(q, k, cos, sin): 39 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 40 | -------------------------------------------------------------------------------- /training/rotate_attention/rotate_builder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import LayerNorm 2 | 3 | from .attention_layer import RotateAttentionLayer 4 | from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer 5 | from fast_transformers.builders.base import BaseBuilder 6 | from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder 7 | from fast_transformers.builders.attention_builders import AttentionBuilder 8 | 9 | 10 | class RotateEncoderBuilder(BaseTransformerEncoderBuilder): 11 | """Build a batch transformer encoder with Relative Rotary embeddings 12 | for training or processing of sequences all elements at a time. 
13 | 14 | Example usage: 15 | 16 | builder = RotateEncoderBuilder() 17 | builder.n_layers = 12 18 | builder.n_heads = 8 19 | builder.feed_forward_dimensions = 1024 20 | builder.query_dimensions = 64 21 | builder.value_dimensions = 64 22 | builder.dropout = 0.1 23 | builder.attention_dropout = 0.1 24 | builder.attention_type = "linear" 25 | transformer = builder.get() 26 | """ 27 | def _get_attention_builder(self): 28 | """Return an instance of the appropriate attention builder.""" 29 | return AttentionBuilder() 30 | 31 | def _get_attention_layer_class(self): 32 | """Return the class for the layer that projects queries keys and 33 | values.""" 34 | return RotateAttentionLayer 35 | 36 | def _get_encoder_class(self): 37 | """Return the class for the transformer encoder.""" 38 | return TransformerEncoder 39 | 40 | def _get_encoder_layer_class(self): 41 | """Return the class for the transformer encoder layer.""" 42 | return TransformerEncoderLayer 43 | -------------------------------------------------------------------------------- /training/run_pubchem_light.sh: -------------------------------------------------------------------------------- 1 | python train_pubchem_light.py \ 2 | --device cuda \ 3 | --n_batch 1200 \ 4 | --n_head 12 \ 5 | --n_layer 12 \ 6 | --n_embd 768 \ 7 | --max_len 202 \ 8 | --d_dropout 0.2 \ 9 | --lr_start 3e-5 \ 10 | --lr_multiplier 8 \ 11 | --n_workers 8 \ 12 | --max_epochs 4 \ 13 | --gpu -1 \ 14 | --num_nodes 1 \ 15 | --accelerator ddp \ 16 | --num_feats 32 \ 17 | --root_dir . \ 18 | --checkpoint_every 1000 \ 19 | --grad_acc 1 \ 20 | --train_load 'both' \ 21 | --eval_every 1000 \ 22 | --rotate \ 23 | --debug \ 24 | --model_arch 'BERT__both_rotate' \ 25 | | tee $HOSTNAME.$LSF_PM_XPROCID.$(date +%F_%R).8.log 26 | -------------------------------------------------------------------------------- /training/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | import regex as re 3 | 4 | PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 5 | 6 | 7 | class MolTranBertTokenizer(BertTokenizer): 8 | def __init__(self, vocab_file: str = '', 9 | do_lower_case=False, 10 | unk_token='<pad>', 11 | sep_token='<eos>', 12 | pad_token='<pad>', 13 | cls_token='<bos>', 14 | mask_token='<mask>', 15 | **kwargs): 16 | super().__init__(vocab_file, 17 | unk_token=unk_token, 18 | sep_token=sep_token, 19 | pad_token=pad_token, 20 | cls_token=cls_token, 21 | mask_token=mask_token, 22 | **kwargs) 23 | 24 | self.regex_tokenizer = re.compile(PATTERN) 25 | self.wordpiece_tokenizer = None 26 | self.basic_tokenizer = None 27 | 28 | def _tokenize(self, text): 29 | split_tokens = self.regex_tokenizer.findall(text) 30 | return split_tokens 31 | 32 | def convert_tokens_to_string(self, tokens): 33 | out_string = "".join(tokens).strip() 34 | return out_string 35 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | #import debugpy 2 | import socket 3 | import glob 4 | import pandas as pd 5 | from typing import List 6 | from rdkit import Chem 7 | 8 | 9 | def getipaddress(): 10 | return socket.gethostbyname(socket.getfqdn()) 11 | 12 | 13 | def debug(): 14 | print("Waiting for debugger to connect") 15 | if ( 16 | socket.getfqdn().startswith("dcc") 17 | or socket.getfqdn().startswith("mol") 18 | or socket.getfqdn().startswith("ccc") 19 | ): 20 |
debugpy.listen(address=(getipaddress(), 3000)) 21 | debugpy.wait_for_client() 22 | debugpy.breakpoint() 23 | 24 | 25 | class ListDataset: 26 | def __init__(self, seqs): 27 | self.seqs = seqs 28 | 29 | def __getitem__(self, index): 30 | return self.seqs[index] 31 | 32 | def __len__(self): 33 | return len(self.seqs) 34 | 35 | 36 | def transform_single_embedding_to_multiple(smiles_z_map): 37 | """Transforms an embedding map of the format smi->embedding to 38 | smi-> {"canonical_embeddings":embedding}. This function exists 39 | as a compatibility layer 40 | 41 | Args: 42 | smiles_z_map ([type]): [description] 43 | """ 44 | retval = dict() 45 | for key in smiles_z_map: 46 | retval[key] = {"canonical_embeddings": smiles_z_map[key]} 47 | return retval 48 | 49 | 50 | def normalize_smiles(smi, canonical, isomeric): 51 | normalized = Chem.MolToSmiles( 52 | Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric 53 | ) 54 | return normalized 55 | 56 | 57 | def get_all_proteins(affinity_dir: str): 58 | files = glob.glob(affinity_dir + "/*.csv") 59 | all_proteins = [] 60 | print(files) 61 | for file in files: 62 | df = pd.read_csv(file) 63 | all_proteins.extend(df["protein"].tolist()) 64 | return set(all_proteins) 65 | 66 | 67 | def append_to_file(filename, line): 68 | with open(filename, "a") as f: 69 | f.write(line + "\n") 70 | 71 | 72 | def write_to_file(filename, line): 73 | with open(filename, "w") as f: 74 | f.write(line + "\n") 75 | -------------------------------------------------------------------------------- /training/val_script.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | import pyarrow as pa 4 | 5 | import datasets 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | FEATURES = datasets.Features( 11 | { 12 | "text": datasets.Value("string"), 13 | } 14 | ) 15 | 16 | 17 | @dataclass 18 | class ValConfig(datasets.BuilderConfig): 19 | """BuilderConfig for text files.""" 20 | 21 | encoding: str = "utf-8" 22 | chunksize: int = 10 << 20 # 10MB 23 | 24 | 25 | class Val(datasets.ArrowBasedBuilder): 26 | 27 | BUILDER_CONFIG_CLASS = ValConfig 28 | 29 | def _info(self): 30 | return datasets.DatasetInfo(features=FEATURES) 31 | 32 | def _split_generators(self, dl_manager): 33 | """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. 34 | 35 | If str or List[str], then the dataset returns only the 'train' split. 36 | If dict, then keys should be from the `datasets.Split` enum. 
37 | """ 38 | if not self.config.data_files: 39 | raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") 40 | data_files = dl_manager.download_and_extract(self.config.data_files) 41 | if isinstance(data_files, (str, list, tuple)): 42 | files = data_files 43 | if isinstance(files, str): 44 | files = [files] 45 | return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] 46 | splits = [] 47 | for split_name, files in data_files.items(): 48 | if isinstance(files, str): 49 | files = [files] 50 | splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) 51 | return splits 52 | 53 | def _generate_tables(self, files): 54 | 55 | for file_idx, file in enumerate(files): 56 | batch_idx = 0 57 | with open(file, "r", encoding=self.config.encoding) as f: 58 | while True: 59 | batch = f.read(self.config.chunksize) 60 | if not batch: 61 | break 62 | batch += f.readline() # finish current line 63 | batch = batch.splitlines() 64 | batch = [word.split(',')[0] for word in batch if 'smiles' not in word] 65 | pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) 66 | # Uncomment for debugging (will print the Arrow table size and elements) 67 | #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") 68 | #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) 69 | yield (file_idx, batch_idx), pa_table 70 | batch_idx += 1 71 | -------------------------------------------------------------------------------- /training/zinc_script.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | import pyarrow as pa 4 | 5 | import datasets 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | FEATURES = datasets.Features( 11 | { 12 | "text": datasets.Value("string"), 13 | } 14 | ) 15 | 16 | 17 | @dataclass 18 | class ZincConfig(datasets.BuilderConfig): 19 | """BuilderConfig for text files.""" 20 | 21 | encoding: str = "utf-8" 22 | chunksize: int = 10 << 20 # 10MB 23 | 24 | 25 | class Zinc(datasets.ArrowBasedBuilder): 26 | 27 | BUILDER_CONFIG_CLASS = ZincConfig 28 | 29 | def _info(self): 30 | return datasets.DatasetInfo(features=FEATURES) 31 | 32 | def _split_generators(self, dl_manager): 33 | """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. 34 | 35 | If str or List[str], then the dataset returns only the 'train' split. 36 | If dict, then keys should be from the `datasets.Split` enum. 
37 | """ 38 | if not self.config.data_files: 39 | raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") 40 | data_files = dl_manager.download_and_extract(self.config.data_files) 41 | if isinstance(data_files, (str, list, tuple)): 42 | files = data_files 43 | if isinstance(files, str): 44 | files = [files] 45 | return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] 46 | splits = [] 47 | for split_name, files in data_files.items(): 48 | if isinstance(files, str): 49 | files = [files] 50 | splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) 51 | return splits 52 | 53 | def _generate_tables(self, files): 54 | 55 | for file_idx, file in enumerate(files): 56 | batch_idx = 0 57 | with open(file, "r", encoding=self.config.encoding) as f: 58 | while True: 59 | batch = f.read(self.config.chunksize) 60 | if not batch: 61 | break 62 | batch += f.readline() # finish current line 63 | batch = batch.splitlines() 64 | batch = [word.split()[0] for word in batch if 'smiles' not in word] 65 | pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) 66 | # Uncomment for debugging (will print the Arrow table size and elements) 67 | #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") 68 | #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) 69 | yield (file_idx, batch_idx), pa_table 70 | batch_idx += 1 71 | --------------------------------------------------------------------------------