├── LICENSE
├── README.md
├── args_utils.py
├── bleu.py
├── bleu_eval.py
├── data
│   ├── README.md
│   ├── iupac_spm.model
│   ├── real_iupac_tokenizer.pt
│   └── train_spm.py
├── dataloader_utils.py
├── dataloader_utils_ori.py
├── example
│   ├── seed.iupac
│   └── seed.smiles
├── figure
│   ├── README
│   └── framework.png
├── inference_main.py
├── iupac_tokenization.py
├── main.py
├── main_retrain.py
├── model_utils.py
├── modeling_bart.py
├── nmt_bleu.py
├── requirements.txt
├── results
│   └── README.md
├── rouge.py
├── sacre_bleu.py
├── smile_tokenization.py
├── src
│   ├── __init__.py
│   ├── controllable
│   │   ├── classifier.py
│   │   ├── controllable_text_sample.py
│   │   └── langevin.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── losses.py
│   │   │   ├── nn.py
│   │   │   ├── resample.py
│   │   │   ├── respace.py
│   │   │   └── rounding.py
│   │   └── predictor
│   │       └── transformer_model.py
│   └── utils
│       ├── args_utils.py
│       ├── custom_tokenizer.py
│       ├── data_utils_sentencepiece.py
│       ├── dist_util.py
│       ├── eval_ppl.py
│       ├── fp16_util.py
│       ├── logger.py
│       ├── show_sampling_progress.py
│       └── test_util.py
├── temp.py
├── tokenizer_utils.py
├── train_scripts
│   ├── gen_opt.sh
│   ├── wjm_iupac_smiles.sh
│   └── wjm_iupac_smiles_retrain.sh
├── train_spm.py
└── trainer.py

/README.md:
--------------------------------------------------------------------------------
[![License: GNU](https://img.shields.io/badge/License-GNU-yellow)](https://github.com/AspirinCode/DiffIUPAC)
[![J. Pharm. Anal.](https://img.shields.io/badge/10.1016%2Fj.jpha.2024.101137-green)](https://doi.org/10.1016/j.jpha.2024.101137)


## DiffIUPAC

**Diffusion-based generative drug-like molecular editing with chemical natural language**

Recently, diffusion models have emerged as a promising paradigm for molecular design and optimization. However, most diffusion-based molecular generative models focus on modeling 2D graphs or 3D geometries, with limited research on molecular sequence diffusion models. The International Union of Pure and Applied Chemistry (IUPAC) names are more akin to chemical natural language than the Simplified Molecular Input Line Entry System (SMILES) for organic compounds. In this work, we apply an IUPAC-guided conditional diffusion model to facilitate molecular editing from chemical natural language to chemical language (SMILES) and explore whether the pre-trained generative performance of diffusion models can be transferred to chemical natural language. We propose DiffIUPAC, a controllable molecular editing diffusion model that converts IUPAC names to SMILES strings. Evaluation results demonstrate that our model outperforms existing methods and successfully captures the semantic rules of both chemical languages. Chemical space and scaffold analysis show that the model can generate similar compounds with diverse scaffolds within the specified constraints. Additionally, to illustrate the model's applicability in drug design, we conducted case studies in functional group editing, analogue design and linker design.
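As a quick illustration of the input/output format this task involves, a generated SMILES string can be validity-checked and canonicalized with RDKit (linked under Requirements below). This is a minimal, illustrative sketch rather than part of the repository; the name/SMILES pair is the aspirin example used in the Data section:

```python
from rdkit import Chem

# hypothetical model output for the IUPAC name "2-acetyloxybenzoic acid" (aspirin)
iupac_name = "2-acetyloxybenzoic acid"
generated_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"

mol = Chem.MolFromSmiles(generated_smiles)  # returns None if the SMILES is invalid
if mol is None:
    print(f"invalid SMILES generated for: {iupac_name}")
else:
    # canonical SMILES makes outputs comparable across models and runs
    print(Chem.MolToSmiles(mol))
```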
![Model Architecture of DiffIUPAC](https://github.com/AspirinCode/DiffIUPAC/blob/main/figure/framework.png)


## Acknowledgements
We thank the authors of C5T5: Controllable Generation of Organic Molecules with Transformers, IUPAC2Struct: Transformer-based artificial neural networks for the conversion between chemical notations, Deep molecular generative model based on variant transformer for antiviral drug design, and SeqDiffuSeq: Text Diffusion with Encoder-Decoder Transformers for releasing their code. The code in this repository is based on their source code releases (https://github.com/dhroth/c5t5, https://github.com/sergsb/IUPAC2Struct, https://github.com/AspirinCode/TransAntivirus, and https://github.com/yuanhy1997/seqdiffuseq). If you find this code useful, please consider citing their work.


## News!

**[2024/11/02]** Available [online](https://doi.org/10.1016/j.jpha.2024.101137) in **Journal of Pharmaceutical Analysis**, 2024.

**[2024/10/29]** Accepted in **Journal of Pharmaceutical Analysis**, 2024.

**[2024/05/14]** Submitted to **Journal of Pharmaceutical Analysis**, 2024.



## Requirements
```bash
conda create -n diffiupac python=3.8
conda install mpi4py
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0
pip install -r requirements.txt
```

RDKit: https://github.com/rdkit/rdkit




## System Requirements
* Requires system memory larger than 228 GB.

* If a GPU is available, requires GPU memory larger than 80 GB.




## Data


**PubChem**

https://pubchem.ncbi.nlm.nih.gov/

IUPAC Name-Canonical SMILES pairs

```
# example: Aspirin
2-acetyloxybenzoic acid | CC(=O)OC1=CC=CC=C1C(=O)O
```

## IUPAC name ⇆ SMILES string


### Structure/SMILES2IUPAC

**IUPAC Naming**

https://web.chemdoodle.com/demos/iupac-naming


**SMILES2IUPAC**

https://huggingface.co/knowledgator/SMILES2IUPAC-canonical-base

**Smiles-TO-iUpac-Translator**

https://github.com/Kohulan/Smiles-TO-iUpac-Translator



### IUPAC2SMILES

https://www.antvaset.com/iupac-to-smiles

https://web.chemdoodle.com/demos/iupac-naming


## Training

To run the code, we use iwslt14 en-de as an illustrative example:

**Prepare the data:**
Learn the BPE tokenizer by running
```
sh ./tokenizer_utils.py train-byte-level iwslt14 10000
```

**To train, run:**
```
mkdir ckpts
bash ./train_scripts/train.sh 0 iupac smiles
#(for en to de translation) bash ./train_scripts/iwslt_en_de.sh 0 smiles iupac
```

You may modify the scripts in ./train_scripts for your own training settings (the scripts shipped with this repository are wjm_iupac_smiles.sh, wjm_iupac_smiles_retrain.sh and gen_opt.sh).


**To fine-tune, run:**

```
bash ./train_scripts/fine_tune.sh 0 iupac smiles

```

## Generating

Example data is provided in the example folder. To generate, run:

```
bash ./train_scripts/gen_opt.sh

```


## Model Metrics

### MOSES

Molecular Sets (MOSES) is a benchmarking platform to support research on machine learning for drug discovery.
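A minimal usage sketch is shown below (this assumes the MOSES package from the link further down is installed as `molsets`; the exact metric set and reference data depend on the installed version):

```python
import moses

# hypothetical list of SMILES strings generated by DiffIUPAC
generated = [
    "CC(=O)OC1=CC=CC=C1C(=O)O",
    "CC(=O)NC1=CC=C(O)C=C1",
]

# computes validity, uniqueness, novelty, FCD, etc. against the MOSES reference sets
metrics = moses.get_all_metrics(generated)
print(metrics)
```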
MOSES implements several popular molecular generation models and provides a set of metrics to evaluate the quality and diversity of generated molecules. With MOSES, MOSES aim to standardize the research on molecular generation and facilitate the sharing and comparison of new models. 149 | https://github.com/molecularsets/moses 150 | 151 | ### QEPPI 152 | quantitative estimate of protein-protein interaction targeting drug-likeness 153 | 154 | https://github.com/ohuelab/QEPPI 155 | 156 | 157 | ## License 158 | Code is released under GNU GENERAL PUBLIC LICENSE. 159 | 160 | 161 | ## Cite: 162 | 163 | * J. Wang, P. Zhou, Z. Wang, W. Long, Y. Chen, K.T. No, D. Ouyang, J. Mao, X. Zeng, Diffusion-based generative drug-like molecular editing with chemical natural language, Journal of Pharmaceutical Analysis, https://doi.org/10.1016/j.jpha.2024.101137. 164 | 165 | * Jiashun Mao, Jianmin Wang, Amir Zeb, Kwang-Hwi Cho, Haiyan Jin, Jongwan Kim, Onju Lee, Yunyun Wang, and Kyoung Tai No. "Transformer-Based Molecular Generative Model for Antiviral Drug Design" Journal of Chemical Information and Modeling, 2023;, [DOI: 10.1021/acs.jcim.3c00536](https://doi.org/10.1021/acs.jcim.3c00536) 166 | 167 | * Yuan, Hongyi, Zheng Yuan, Chuanqi Tan, Fei Huang, and Songfang Huang. "SeqDiffuSeq: Text Diffusion with Encoder-Decoder Transformers." arXiv preprint arXiv:2212.10325 (2022). 168 | 169 | * Rothchild, Daniel, Alex Tamkin, Julie Yu, Ujval Misra, and Joseph Gonzalez. "C5t5: Controllable generation of organic molecules with transformers." arXiv preprint arXiv:2108.10307 (2021). 170 | -------------------------------------------------------------------------------- /args_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | def create_argparser(): 5 | defaults = dict( 6 | data_dir="", 7 | src='src', 8 | tgt='tgt', 9 | schedule_sampler="uniform", 10 | lr=1e-4, 11 | weight_decay=0.0, 12 | lr_anneal_steps=30000, 13 | warmup=0, 14 | batch_size=1, 15 | microbatch=-1, # -1 disables microbatches 16 | ema_rate="0.9999", # comma-separated list of EMA values 17 | log_interval=50, 18 | save_interval=25000, 19 | resume_checkpoint="", 20 | use_fp16=False, 21 | fp16_scale_growth=1e-3, 22 | seed=101, 23 | gradient_clipping=-1.0, 24 | eval_interval=2000, 25 | checkpoint_path="diff_models", 26 | train_txt_path="data/quotes_train.txt", 27 | val_txt_path="data/quotes_valid.txt", 28 | dataset="", 29 | notes="", 30 | ) 31 | text_defaults = dict( 32 | modality="text", 33 | emb_scale_factor=1.0, 34 | in_channel=16, 35 | out_channel=16, 36 | noise_level=0.0, 37 | cache_mode="no", 38 | use_bert_tokenizer="no", 39 | padding_mode="block", 40 | preprocessing_num_workers=1, 41 | tok_thresh=150 42 | ) 43 | 44 | guided_generation_defaults = dict( 45 | classifier_num_epochs=15 46 | ) 47 | 48 | defaults.update(model_and_diffusion_defaults()) 49 | defaults.update(text_defaults) 50 | defaults.update(guided_generation_defaults) 51 | defaults.update(decoding_defaults()) 52 | defaults.update(additional_args_for_translation()) 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--debug", action="store_true") 55 | 56 | add_dict_to_argparser(parser, defaults) 57 | return parser 58 | 59 | def additional_args_for_translation(): 60 | 61 | return dict( 62 | pretrained_tokenizer=None, 63 | sequence_len_src=64, 64 | use_pretrained_tokenizer=False, 65 | generate_by_q=False, 66 | generate_by_mix=False, 67 | generate_by_mix_prob=0.0, 68 | generate_by_mix_part=1.0, 69 | ) 70 | 71 | 72 
| def model_and_diffusion_defaults(): 73 | """ 74 | Defaults for text-diffusion model training. 75 | """ 76 | return dict( 77 | encoder_layers=6, 78 | decoder_layers=6, 79 | sequence_len=64, 80 | num_channels=16, 81 | num_heads=4, 82 | dropout=0.0, 83 | learn_sigma=False, 84 | sigma_small=False, 85 | class_cond=False, 86 | diffusion_steps=10000, 87 | noise_schedule="linear", 88 | timestep_respacing="", 89 | use_kl=False, 90 | predict_xstart=False, 91 | rescale_timesteps=True, 92 | rescale_learned_sigmas=True, 93 | use_checkpoint=False, 94 | model_arch="transformer", 95 | in_channel=16, 96 | out_channel=16, 97 | vocab_size=66, 98 | config_name="bert-base-uncased", 99 | logits_mode=1, 100 | training_mode="diffusion-lm", 101 | init_pretrained=False, 102 | freeze_embeddings=False, 103 | use_pretrained_embeddings=True, 104 | load_ckpt=None, 105 | loss_update_granu=None, 106 | schedule_update_stride=0, 107 | ) 108 | 109 | 110 | def decoding_defaults(): 111 | return dict( 112 | num_samples=50, 113 | top_p=0.9, 114 | out_dir="", 115 | model_name_or_path="", 116 | checkpoint_path="", 117 | use_ddim=False, 118 | clip_denoised=False, 119 | batch_size=64, 120 | mbr_sample=1, 121 | verbose="yes", 122 | clamp="clamp", 123 | preprocessing_num_workers=1, 124 | emb_scale_factor=1.0, 125 | classifier_path="", 126 | time_schedule_path='', 127 | comment='', 128 | ) 129 | 130 | 131 | def add_dict_to_argparser(parser, default_dict): 132 | for k, v in default_dict.items(): 133 | v_type = type(v) 134 | if v is None: 135 | v_type = str 136 | elif isinstance(v, bool): 137 | v_type = str2bool 138 | parser.add_argument(f"--{k}", default=v, type=v_type) 139 | 140 | 141 | def args_to_dict(args, keys): 142 | return {k: getattr(args, k) for k in keys} 143 | 144 | 145 | def str2bool(v): 146 | """ 147 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 148 | """ 149 | if isinstance(v, bool): 150 | return v 151 | if v.lower() in ("yes", "true", "t", "y", "1"): 152 | return True 153 | elif v.lower() in ("no", "false", "f", "n", "0"): 154 | return False 155 | else: 156 | raise argparse.ArgumentTypeError("boolean value expected") 157 | -------------------------------------------------------------------------------- /bleu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ BLEU metric. 
""" 16 | 17 | import datasets 18 | 19 | from nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py 20 | 21 | 22 | _CITATION = """\ 23 | @INPROCEEDINGS{Papineni02bleu:a, 24 | author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, 25 | title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, 26 | booktitle = {}, 27 | year = {2002}, 28 | pages = {311--318} 29 | } 30 | @inproceedings{lin-och-2004-orange, 31 | title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", 32 | author = "Lin, Chin-Yew and 33 | Och, Franz Josef", 34 | booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics", 35 | month = "aug 23{--}aug 27", 36 | year = "2004", 37 | address = "Geneva, Switzerland", 38 | publisher = "COLING", 39 | url = "https://www.aclweb.org/anthology/C04-1072", 40 | pages = "501--507", 41 | } 42 | """ 43 | 44 | _DESCRIPTION = """\ 45 | BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. 46 | Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation, 47 | the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and 48 | remains one of the most popular automated and inexpensive metrics. 49 | 50 | Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations. 51 | Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness 52 | are not taken into account[citation needed]. 53 | 54 | BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1 55 | representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the 56 | reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional 57 | reference translations will increase the BLEU score. 58 | """ 59 | 60 | _KWARGS_DESCRIPTION = """ 61 | Computes BLEU score of translated segments against one or more references. 62 | Args: 63 | predictions: list of translations to score. 64 | Each translation should be tokenized into a list of tokens. 65 | references: list of lists of references for each translation. 66 | Each reference should be tokenized into a list of tokens. 67 | max_order: Maximum n-gram order to use when computing BLEU score. 68 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 69 | Returns: 70 | 'bleu': bleu score, 71 | 'precisions': geometric mean of n-gram precisions, 72 | 'brevity_penalty': brevity penalty, 73 | 'length_ratio': ratio of lengths, 74 | 'translation_length': translation_length, 75 | 'reference_length': reference_length 76 | Examples: 77 | 78 | >>> predictions = [ 79 | ... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample 80 | ... ["foo", "bar", "foobar"] # tokenized prediction of the second sample 81 | ... ] 82 | >>> references = [ 83 | ... 
[["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references) 84 | ... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference) 85 | ... ] 86 | >>> bleu = datasets.load_metric("bleu") 87 | >>> results = bleu.compute(predictions=predictions, references=references) 88 | >>> print(results["bleu"]) 89 | 1.0 90 | """ 91 | 92 | 93 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 94 | class Bleu(datasets.Metric): 95 | def _info(self): 96 | return datasets.MetricInfo( 97 | description=_DESCRIPTION, 98 | citation=_CITATION, 99 | inputs_description=_KWARGS_DESCRIPTION, 100 | features=datasets.Features( 101 | { 102 | "predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), 103 | "references": datasets.Sequence( 104 | datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references" 105 | ), 106 | } 107 | ), 108 | codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"], 109 | reference_urls=[ 110 | "https://en.wikipedia.org/wiki/BLEU", 111 | "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", 112 | ], 113 | ) 114 | 115 | def _compute(self, predictions, references, max_order=4, smooth=False): 116 | score = compute_bleu( 117 | reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth 118 | ) 119 | (bleu, precisions, bp, ratio, translation_length, reference_length) = score 120 | return { 121 | "bleu": bleu, 122 | "precisions": precisions, 123 | "brevity_penalty": bp, 124 | "length_ratio": ratio, 125 | "translation_length": translation_length, 126 | "reference_length": reference_length, 127 | } -------------------------------------------------------------------------------- /bleu_eval.py: -------------------------------------------------------------------------------- 1 | from datasets import load_metric 2 | import numpy as np 3 | import json 4 | import sys 5 | from tokenizer_utils import create_tokenizer 6 | from transformers import AutoTokenizer 7 | from sacremoses import MosesDetokenizer, MosesTokenizer 8 | import os 9 | 10 | mt, md = MosesTokenizer(lang='en'), MosesDetokenizer(lang='en') 11 | metric_bleu = load_metric("./bleu.py") 12 | metric_sacrebleu = load_metric("./sacre_bleu.py") 13 | metric_rouge = load_metric("./rouge.py") 14 | tokenizer_mbert = AutoTokenizer.from_pretrained('bert-base-multilingual-cased') 15 | 16 | def cal_metrics(data): 17 | refs = [[md.detokenize(mt.tokenize(item[-1]))] for item in data] 18 | preds = [md.detokenize(mt.tokenize(item[0])) for item in data] 19 | sacre_results = metric_sacrebleu.compute(predictions=preds, references=refs) 20 | print('***SacreBLEU score', round(sacre_results['score'], 2)) 21 | 22 | refs = [[tokenizer_mbert.tokenize(item[-1])] for item in data] 23 | preds = [tokenizer_mbert.tokenize(item[0]) for item in data] 24 | results = metric_bleu.compute(predictions=preds, references=refs) 25 | print('*** tokenized BLEU score', round(results['bleu']*100, 2)) 26 | 27 | 28 | refs = [item[-1] for item in data] 29 | preds = [item[0] for item in data] 30 | results = metric_rouge.compute(predictions=preds, references=refs) 31 | print('Rouge score', results) 32 | 33 | return sacre_results['score'] 34 | 35 | def selectBest(sentences): 36 | selfBleu = [[] for i in range(len(sentences))] 37 | for i, s1 in enumerate(sentences): 38 | for j, s2 in enumerate(sentences): 39 | score = 
metric_sacrebleu.compute(predictions=[s1], 40 | references=[[s2]])['score'] 41 | selfBleu[i].append(score) 42 | for i, s1 in enumerate(sentences): 43 | selfBleu[i][i] = 0 44 | idx = np.argmax(np.sum(selfBleu, -1)) 45 | 46 | return sentences[idx] 47 | 48 | input_file = sys.argv[1] 49 | if os.path.exists(input_file): 50 | with open(input_file, 'r') as f: 51 | data = f.readlines() 52 | data = [json.loads(item.strip('\n')) for item in data] 53 | cal_metrics(data) 54 | 55 | else: 56 | path = '/'.join(input_file.split('/')[:-1]) 57 | prefix = input_file.split('/')[-1] 58 | files = [os.path.join(path, f) for f in os.listdir(path) if f.startswith(prefix) and sys.argv[2] in f] 59 | print(files) 60 | refs = [] 61 | preds = [] 62 | for f in files: 63 | print('===='+f.split('/')[-1]) 64 | 65 | with open(f, 'r') as fi: 66 | data = fi.readlines() 67 | data = [json.loads(item.strip('\n')) for item in data] 68 | 69 | if not refs: 70 | refs = [md.detokenize(mt.tokenize(item[-1])) for item in data] 71 | if not preds: 72 | preds = [[md.detokenize(mt.tokenize(item[0]))] for item in data] 73 | else: 74 | for idx, item in enumerate(data): 75 | preds[idx].append(item[0]) 76 | 77 | preds = [selectBest(item) for item in preds] 78 | data_buffer = [] 79 | for p, r in zip(preds, refs): 80 | data_buffer.append([p,r]) 81 | cal_metrics(data_buffer) 82 | 83 | 84 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## PubChem 2 | 3 | 4 | https://ftp.ncbi.nlm.nih.gov/pubchem/RDF/compound/ 5 | -------------------------------------------------------------------------------- /data/iupac_spm.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AspirinCode/DiffIUPAC/525ffe6850c21d7cafbca94c5c6f971da2a450d4/data/iupac_spm.model -------------------------------------------------------------------------------- /data/real_iupac_tokenizer.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AspirinCode/DiffIUPAC/525ffe6850c21d7cafbca94c5c6f971da2a450d4/data/real_iupac_tokenizer.pt -------------------------------------------------------------------------------- /data/train_spm.py: -------------------------------------------------------------------------------- 1 | import sentencepiece as spm 2 | import sys 3 | from collections import Counter 4 | 5 | # file with a list of IUPAC names (can be just 1 line if you want) 6 | #iupacs_fn = int(sys.argv[1]) 7 | 8 | 9 | with open("opsin_vocab_reduced.txt", "r") as f: 10 | words = f.read().split("\n") 11 | words = list(map(str, range(100))) + words 12 | 13 | smile_atom =[ 14 | 'Ac', 'Ag', 'Al', 'Am', 'Ar', 'As', 'At', 'Au', 'B', 'Ba', 'Be', 'Bh', 15 | 'Bi', 'Bk', 'Br', 'C', 'Ca', 'Cd', 'Ce', 'Cf', 'Cl', 'Cm', 'Co', 'Cr', 16 | 'Cs', 'Cu', 'Db', 'Dy', 'Er', 'Es', 'Eu', 'F', 'Fe', 'Fm', 'Fr', 'Ga', 17 | 'Gd', 'Ge', 'H', 'He', 'Hf', 'Hg', 'Ho', 'Hs', 'I', 'In', 'Ir', 'K', 18 | 'Kr', 'La', 'Li', 'Lr', 'Lu', 'Md', 'Mg', 'Mn', 'Mo', 'Mt', 'N', 'Na', 19 | 'Nb', 'Nd', 'Ne', 'Ni', 'No', 'Np', 'O', 'Os', 'P', 'Pa', 'Pb', 'Pd', 20 | 'Pm', 'Po', 'Pr', 'Pt', 'Pu', 'Ra', 'Rb', 'Re', 'Rf', 'Rh', 'Rn', 21 | 'Ru', 'S', 'Sb', 'Sc', 'Se', 'Sg', 'Si', 'Sm', 'Sn', 'Sr', 'Ta', 'Tb', 22 | 'Tc', 'Te', 'Th', 'Ti', 'Tl', 'Tm', 'U', 'V', 'W', 'Xe', 'Y', 'Yb', 23 | 'Zn', 'Zr' 24 | ] 25 | smile_non_atom = [ 26 | '-', '=', '#', ':', '(', ')', '.', '[', ']', '+', 
'-', '\\', '/', '*', 27 | #'1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 28 | '@', 'AL', 'TH', 'SP', 'TB', 'OH', 29 | ] 30 | 31 | #words = smile_atom+smile_non_atom+words 32 | 33 | words = list(set(words)) 34 | 35 | vocab_size = len(words) + 1+100 36 | 37 | user_defined_symbols = words 38 | 39 | print("num user defined:", len(user_defined_symbols)) 40 | 41 | args = {"input": sys.argv[1], 42 | "model_type": "unigram", 43 | "model_prefix": "iupac_spm".format(vocab_size), 44 | "vocab_size": vocab_size, 45 | "input_sentence_size": 50000, 46 | "shuffle_input_sentence": True, 47 | "user_defined_symbols": user_defined_symbols, 48 | "split_by_number": False, 49 | "split_by_whitespace": False, 50 | "hard_vocab_limit": False, 51 | "max_sentencepiece_length": 320, 52 | "character_coverage": 0.99, 53 | "pad_id": 0, 54 | "eos_id": 1, 55 | "unk_id": 2, 56 | "bos_id": -1 57 | } 58 | #"train_extremely_large_corpus": True 59 | 60 | spm.SentencePieceTrainer.train(**args) 61 | -------------------------------------------------------------------------------- /dataloader_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import pandas as pd 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch 6 | from functools import partial 7 | from mpi4py import MPI 8 | import os 9 | import random 10 | import numpy as np 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | def get_dataloader(iupac_tokenizer,smiles_tokenizer, data_path, batch_size, max_seq_len, max_seq_len_src, args): 14 | 15 | dataset = TextDataset_translation(iupac_tokenizer=iupac_tokenizer,smiles_tokenizer=smiles_tokenizer, data_path=data_path, source=args.src, target=args.tgt, 16 | shard=MPI.COMM_WORLD.Get_rank(), 17 | num_shards=MPI.COMM_WORLD.Get_size()) 18 | 19 | dataloader = DataLoader( 20 | dataset, 21 | batch_size=batch_size, # 20, 22 | drop_last=True, 23 | shuffle='train' in data_path, 24 | num_workers=10, 25 | collate_fn=partial(TextDataset_translation.collate_pad, 26 | args=args, 27 | cutoff=max_seq_len, 28 | cutoff_src=max_seq_len_src, 29 | padding_token=iupac_tokenizer.pad_token_id if hasattr(iupac_tokenizer, 'pad_token_id') else iupac_tokenizer.get_vocab()['']), 30 | ) 31 | 32 | while True: 33 | for batch in dataloader: 34 | yield batch 35 | 36 | class TextDataset(Dataset): 37 | def __init__( 38 | self, 39 | tokenizer, 40 | data_path: str, 41 | has_labels: bool = False 42 | ) -> None: 43 | super().__init__() 44 | self.data_path = data_path 45 | self.tokenizer = tokenizer 46 | self.read_data() 47 | if has_labels: 48 | self.read_labels() 49 | 50 | def read_data(self): 51 | logging.info("Reading data from {}".format(self.data_path)) 52 | data = pd.read_csv(self.data_path, sep="\t", header=None) # read text file 53 | logging.info(f"Tokenizing {len(data)} sentences") 54 | 55 | self.text = data[0].apply(lambda x: x.strip()).tolist() 56 | if hasattr(self.tokenizer, 'encode_batch'): 57 | 58 | encoded_input = self.tokenizer.encode_batch(self.text) 59 | self.input_ids = [x.ids for x in encoded_input] 60 | 61 | else: 62 | encoded_input = self.tokenizer(self.text) 63 | self.input_ids = encoded_input["input_ids"] 64 | 65 | 66 | 67 | def read_labels(self): 68 | self.labels = pd.read_csv(self.data_path, sep="\t", header=None)[1].tolist() 69 | # check if labels are already numerical 70 | self.labels = [str(x) for x in self.labels] 71 | if isinstance(self.labels[0], int): 72 | return 73 | # if not, convert to numerical 74 | all_labels = 
sorted(list(set(self.labels))) 75 | self.label_to_idx = {label: i for i, label in enumerate(all_labels)} 76 | self.idx_to_label = {i: label for i, label in self.label_to_idx.items()} 77 | self.labels = [self.label_to_idx[label] for label in self.labels] 78 | 79 | 80 | 81 | def __len__(self) -> int: 82 | return len(self.text) 83 | 84 | def __getitem__(self, i): 85 | out_dict = { 86 | "input_ids": self.input_ids[i], 87 | # "attention_mask": [1] * len(self.input_ids[i]), 88 | } 89 | if hasattr(self, "labels"): 90 | out_dict["label"] = self.labels[i] 91 | return out_dict 92 | 93 | @staticmethod 94 | def collate_pad(batch, cutoff: int): 95 | max_token_len = 0 96 | num_elems = len(batch) 97 | # batch[0] -> __getitem__[0] --> returns a tuple (embeddings, out_dict) 98 | 99 | for i in range(num_elems): 100 | max_token_len = max(max_token_len, len(batch[i]["input_ids"])) 101 | 102 | max_token_len = min(cutoff, max_token_len) 103 | 104 | tokens = torch.zeros(num_elems, max_token_len).long() 105 | tokens_mask = torch.zeros(num_elems, max_token_len).long() 106 | 107 | has_labels = False 108 | if "label" in batch[0]: 109 | labels = torch.zeros(num_elems).long() 110 | has_labels = True 111 | 112 | for i in range(num_elems): 113 | toks = batch[i]["input_ids"] 114 | length = len(toks) 115 | tokens[i, :length] = torch.LongTensor(toks) 116 | tokens_mask[i, :length] = 1 117 | if has_labels: 118 | labels[i] = batch[i]["label"] 119 | 120 | # TODO: the first return None is just for backward compatibility -- can be removed 121 | if has_labels: 122 | return None, {"input_ids": tokens, "attention_mask": tokens_mask, "labels": labels} 123 | else: 124 | return None, {"input_ids": tokens, "attention_mask": tokens_mask} 125 | 126 | 127 | class TextDataset_translation(TextDataset): 128 | 129 | def __init__( 130 | self, 131 | iupac_tokenizer, 132 | smiles_tokenizer, 133 | data_path: str, 134 | source, 135 | target, 136 | shard, 137 | num_shards, 138 | ) -> None: 139 | self.data_path = data_path 140 | self.iupac_tokenizer = iupac_tokenizer 141 | self.smiles_tokenizer = smiles_tokenizer 142 | self.shard = shard 143 | self.src = source 144 | self.tgt = target 145 | self.num_shards = num_shards 146 | self.read_data() 147 | 148 | def read_data(self): 149 | print("Reading data from {}".format(self.data_path)) 150 | data = [open(self.data_path+'.'+self.src, 'r').readlines(), 151 | open(self.data_path+'.'+self.tgt, 'r').readlines()] 152 | print(f"Tokenizing {len(data[0])} sentences") 153 | 154 | data = [[src, tgt] for src, tgt in zip(data[0], data[1])] 155 | # random.shuffle(data) 156 | 157 | self.src_text = [item[0].strip('\n') for item in data] 158 | self.tgt_text = [item[1].strip('\n') for item in data] 159 | 160 | bos_idx = (len(self.src_text) // self.num_shards) * self.shard 161 | eos_idx = (len(self.src_text) // self.num_shards) * (self.shard+1) 162 | self.src_text = self.src_text[bos_idx:eos_idx] 163 | self.tgt_text = self.tgt_text[bos_idx:eos_idx] 164 | 165 | print('examples src', self.src_text[0]) 166 | print('examples tgt', self.tgt_text[0]) 167 | 168 | # check if iupac_tokenizer has a method 'encode_batch' 169 | if hasattr(self.iupac_tokenizer, 'encode_batch'): 170 | 171 | encoded_input_src = self.iupac_tokenizer.encode_batch(self.src_text) 172 | self.input_ids_src = [x.ids for x in encoded_input_src] 173 | 174 | encoded_input_tgt = self.smiles_tokenizer.encode_batch(self.tgt_text) 175 | self.input_ids_tgt = [x.ids for x in encoded_input_tgt] 176 | 177 | else: 178 | 179 | encoded_input_src = 
self.iupac_tokenizer(self.src_text) 180 | self.input_ids_src = encoded_input_src["input_ids"] 181 | 182 | encoded_input_tgt = self.smiles_tokenizer(self.tgt_text) 183 | self.input_ids_tgt = encoded_input_tgt["input_ids"] 184 | 185 | count_length_src = np.mean([len(item) for item in self.input_ids_src]) 186 | count_length_tgt = np.mean([len(item) for item in self.input_ids_tgt]) 187 | 188 | print(f'average number of tokens in source {count_length_src}') 189 | print(f'average number of tokens in target {count_length_tgt}') 190 | 191 | def __len__(self) -> int: 192 | return len(self.src_text) 193 | 194 | def __getitem__(self, i): 195 | out_dict = { 196 | "encoder_input_ids": self.input_ids_src[i], 197 | "decoder_input_ids": self.input_ids_tgt[i], 198 | } 199 | return out_dict 200 | 201 | @staticmethod 202 | def collate_pad(batch, args, cutoff: int, cutoff_src: int, padding_token: int): 203 | max_token_len_src, max_token_len_tgt = cutoff_src, cutoff 204 | num_elems = len(batch) 205 | 206 | tokens_src = torch.ones(num_elems, max_token_len_src).long() * padding_token 207 | tokens_mask_src = torch.zeros(num_elems, max_token_len_src).long() 208 | 209 | tokens_tgt = torch.ones(num_elems, max_token_len_tgt).long() * padding_token 210 | tokens_mask_tgt = torch.zeros(num_elems, max_token_len_tgt).long() 211 | 212 | for i in range(num_elems): 213 | toks_src = batch[i]["encoder_input_ids"][:max_token_len_src] 214 | toks_tgt = batch[i]["decoder_input_ids"][:max_token_len_tgt] 215 | l_s, l_t = len(toks_src), len(toks_tgt) 216 | tokens_src[i, :l_s] = torch.LongTensor(toks_src) 217 | tokens_tgt[i, :l_t] = torch.LongTensor(toks_tgt) 218 | tokens_mask_src[i, :l_s] = 1 219 | tokens_mask_tgt[i, :] = 1 220 | 221 | return {"input_ids": tokens_src, "attention_mask": tokens_mask_src, 222 | 'decoder_input_ids': tokens_tgt, 'decoder_attention_mask': tokens_mask_tgt}, None 223 | -------------------------------------------------------------------------------- /dataloader_utils_ori.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import pandas as pd 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch 6 | from functools import partial 7 | from mpi4py import MPI 8 | import os 9 | import random 10 | import numpy as np 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | def get_dataloader(tokenizer, data_path, batch_size, max_seq_len, max_seq_len_src, args): 14 | 15 | dataset = TextDataset_translation(tokenizer=tokenizer, data_path=data_path, source=args.src, target=args.tgt, 16 | shard=MPI.COMM_WORLD.Get_rank(), 17 | num_shards=MPI.COMM_WORLD.Get_size()) 18 | 19 | dataloader = DataLoader( 20 | dataset, 21 | batch_size=batch_size, # 20, 22 | drop_last=True, 23 | shuffle='train' in data_path, 24 | num_workers=10, 25 | collate_fn=partial(TextDataset_translation.collate_pad, 26 | args=args, 27 | cutoff=max_seq_len, 28 | cutoff_src=max_seq_len_src, 29 | padding_token=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.get_vocab()['']), 30 | ) 31 | 32 | while True: 33 | for batch in dataloader: 34 | yield batch 35 | 36 | class TextDataset(Dataset): 37 | def __init__( 38 | self, 39 | tokenizer, 40 | data_path: str, 41 | has_labels: bool = False 42 | ) -> None: 43 | super().__init__() 44 | self.data_path = data_path 45 | self.tokenizer = tokenizer 46 | self.read_data() 47 | if has_labels: 48 | self.read_labels() 49 | 50 | def read_data(self): 51 | logging.info("Reading data from {}".format(self.data_path)) 52 | 
data = pd.read_csv(self.data_path, sep="\t", header=None) # read text file 53 | logging.info(f"Tokenizing {len(data)} sentences") 54 | 55 | self.text = data[0].apply(lambda x: x.strip()).tolist() 56 | if hasattr(self.tokenizer, 'encode_batch'): 57 | 58 | encoded_input = self.tokenizer.encode_batch(self.text) 59 | self.input_ids = [x.ids for x in encoded_input] 60 | 61 | else: 62 | encoded_input = self.tokenizer(self.text) 63 | self.input_ids = encoded_input["input_ids"] 64 | 65 | 66 | 67 | def read_labels(self): 68 | self.labels = pd.read_csv(self.data_path, sep="\t", header=None)[1].tolist() 69 | # check if labels are already numerical 70 | self.labels = [str(x) for x in self.labels] 71 | if isinstance(self.labels[0], int): 72 | return 73 | # if not, convert to numerical 74 | all_labels = sorted(list(set(self.labels))) 75 | self.label_to_idx = {label: i for i, label in enumerate(all_labels)} 76 | self.idx_to_label = {i: label for i, label in self.label_to_idx.items()} 77 | self.labels = [self.label_to_idx[label] for label in self.labels] 78 | 79 | 80 | 81 | def __len__(self) -> int: 82 | return len(self.text) 83 | 84 | def __getitem__(self, i): 85 | out_dict = { 86 | "input_ids": self.input_ids[i], 87 | # "attention_mask": [1] * len(self.input_ids[i]), 88 | } 89 | if hasattr(self, "labels"): 90 | out_dict["label"] = self.labels[i] 91 | return out_dict 92 | 93 | @staticmethod 94 | def collate_pad(batch, cutoff: int): 95 | max_token_len = 0 96 | num_elems = len(batch) 97 | # batch[0] -> __getitem__[0] --> returns a tuple (embeddings, out_dict) 98 | 99 | for i in range(num_elems): 100 | max_token_len = max(max_token_len, len(batch[i]["input_ids"])) 101 | 102 | max_token_len = min(cutoff, max_token_len) 103 | 104 | tokens = torch.zeros(num_elems, max_token_len).long() 105 | tokens_mask = torch.zeros(num_elems, max_token_len).long() 106 | 107 | has_labels = False 108 | if "label" in batch[0]: 109 | labels = torch.zeros(num_elems).long() 110 | has_labels = True 111 | 112 | for i in range(num_elems): 113 | toks = batch[i]["input_ids"] 114 | length = len(toks) 115 | tokens[i, :length] = torch.LongTensor(toks) 116 | tokens_mask[i, :length] = 1 117 | if has_labels: 118 | labels[i] = batch[i]["label"] 119 | 120 | # TODO: the first return None is just for backward compatibility -- can be removed 121 | if has_labels: 122 | return None, {"input_ids": tokens, "attention_mask": tokens_mask, "labels": labels} 123 | else: 124 | return None, {"input_ids": tokens, "attention_mask": tokens_mask} 125 | 126 | 127 | class TextDataset_translation(TextDataset): 128 | 129 | def __init__( 130 | self, 131 | tokenizer, 132 | data_path: str, 133 | source, 134 | target, 135 | shard, 136 | num_shards, 137 | ) -> None: 138 | self.data_path = data_path 139 | self.tokenizer = tokenizer 140 | self.shard = shard 141 | self.src = source 142 | self.tgt = target 143 | self.num_shards = num_shards 144 | self.read_data() 145 | 146 | def read_data(self): 147 | print("Reading data from {}".format(self.data_path)) 148 | data = [open(self.data_path+'.'+self.src, 'r').readlines(), 149 | open(self.data_path+'.'+self.tgt, 'r').readlines()] 150 | print(f"Tokenizing {len(data[0])} sentences") 151 | 152 | data = [[src, tgt] for src, tgt in zip(data[0], data[1])] 153 | # random.shuffle(data) 154 | 155 | self.src_text = [item[0].strip('\n') for item in data] 156 | self.tgt_text = [item[1].strip('\n') for item in data] 157 | 158 | bos_idx = (len(self.src_text) // self.num_shards) * self.shard 159 | eos_idx = (len(self.src_text) // 
self.num_shards) * (self.shard+1) 160 | self.src_text = self.src_text[bos_idx:eos_idx] 161 | self.tgt_text = self.tgt_text[bos_idx:eos_idx] 162 | 163 | print('examples src', self.src_text[0]) 164 | print('examples tgt', self.tgt_text[0]) 165 | 166 | # check if tokenizer has a method 'encode_batch' 167 | if hasattr(self.tokenizer, 'encode_batch'): 168 | 169 | encoded_input_src = self.tokenizer.encode_batch(self.src_text) 170 | self.input_ids_src = [x.ids for x in encoded_input_src] 171 | 172 | encoded_input_tgt = self.tokenizer.encode_batch(self.tgt_text) 173 | self.input_ids_tgt = [x.ids for x in encoded_input_tgt] 174 | 175 | else: 176 | 177 | encoded_input_src = self.tokenizer(self.src_text) 178 | self.input_ids_src = encoded_input_src["input_ids"] 179 | 180 | encoded_input_tgt = self.tokenizer(self.tgt_text) 181 | self.input_ids_tgt = encoded_input_tgt["input_ids"] 182 | 183 | count_length_src = np.mean([len(item) for item in self.input_ids_src]) 184 | count_length_tgt = np.mean([len(item) for item in self.input_ids_tgt]) 185 | 186 | print(f'average number of tokens in source {count_length_src}') 187 | print(f'average number of tokens in target {count_length_tgt}') 188 | 189 | def __len__(self) -> int: 190 | return len(self.src_text) 191 | 192 | def __getitem__(self, i): 193 | out_dict = { 194 | "encoder_input_ids": self.input_ids_src[i], 195 | "decoder_input_ids": self.input_ids_tgt[i], 196 | } 197 | return out_dict 198 | 199 | @staticmethod 200 | def collate_pad(batch, args, cutoff: int, cutoff_src: int, padding_token: int): 201 | max_token_len_src, max_token_len_tgt = cutoff_src, cutoff 202 | num_elems = len(batch) 203 | 204 | tokens_src = torch.ones(num_elems, max_token_len_src).long() * padding_token 205 | tokens_mask_src = torch.zeros(num_elems, max_token_len_src).long() 206 | 207 | tokens_tgt = torch.ones(num_elems, max_token_len_tgt).long() * padding_token 208 | tokens_mask_tgt = torch.zeros(num_elems, max_token_len_tgt).long() 209 | 210 | for i in range(num_elems): 211 | toks_src = batch[i]["encoder_input_ids"][:max_token_len_src] 212 | toks_tgt = batch[i]["decoder_input_ids"][:max_token_len_tgt] 213 | l_s, l_t = len(toks_src), len(toks_tgt) 214 | tokens_src[i, :l_s] = torch.LongTensor(toks_src) 215 | tokens_tgt[i, :l_t] = torch.LongTensor(toks_tgt) 216 | tokens_mask_src[i, :l_s] = 1 217 | tokens_mask_tgt[i, :] = 1 218 | 219 | return {"input_ids": tokens_src, "attention_mask": tokens_mask_src, 220 | 'decoder_input_ids': tokens_tgt, 'decoder_attention_mask': tokens_mask_tgt}, None 221 | -------------------------------------------------------------------------------- /figure/README: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figure/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AspirinCode/DiffIUPAC/525ffe6850c21d7cafbca94c5c6f971da2a450d4/figure/framework.png -------------------------------------------------------------------------------- /inference_main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | import os, json 6 | from typing import List 7 | import numpy as np 8 | import torch as th 9 | import torch.distributed as dist 10 | from transformers import set_seed 11 | from src.utils import dist_util, logger 12 | 13 | from args_utils import * 14 | from model_utils import create_model_and_diffusion 15 | from args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 16 | from tokenizer_utils import create_tokenizer 17 | import dataloader_utils 18 | from mpi4py import MPI 19 | 20 | def main(): 21 | 22 | args = create_argparser().parse_args() 23 | 24 | set_seed(args.seed) 25 | th.manual_seed(args.seed) 26 | print(args.seed) 27 | dist_util.setup_dist() 28 | logger.configure() 29 | 30 | # load configurations. 31 | args.checkpoint_path = os.path.split(args.model_name_or_path)[0] 32 | 33 | config_path = os.path.join(args.checkpoint_path, "training_args.json") 34 | training_args = read_training_args(config_path) 35 | training_args["batch_size"] = args.batch_size 36 | training_args["diffusion_steps"] = args.diffusion_steps 37 | training_args['model_name_or_path'] = args.model_name_or_path 38 | training_args["clamp"] = args.clamp 39 | training_args['out_dir'] = args.out_dir 40 | training_args['num_samples'] = args.num_samples 41 | training_args['val_txt_path'] = args.val_txt_path 42 | training_args['top_p'] = args.top_p 43 | training_args['sequence_len_src'] = args.sequence_len_src 44 | training_args['sequence_len'] = args.sequence_len 45 | training_args['generate_by_q'] = args.generate_by_q 46 | training_args['generate_by_mix'] = args.generate_by_mix 47 | training_args['time_schedule_path'] = args.time_schedule_path 48 | training_args['seed'] = args.seed 49 | 50 | args.__dict__.update(training_args) 51 | args.sigma_small = True 52 | 53 | 54 | logger.info(f"Init pretrained = {args.init_pretrained}") 55 | logger.info(f"Freeze embeddings = {args.freeze_embeddings}") 56 | logger.info(f"Use pretrained embeddings = {args.use_pretrained_embeddings}") 57 | logger.info(f"Use pretrained embeddings = {args.use_pretrained_tokenizer}") 58 | 59 | tokenizer = create_tokenizer(return_pretokenized=args.use_pretrained_tokenizer, 60 | path=f"data/{args.dataset}/", 61 | tokenizer_type='byte-level', 62 | tokenizer_ckpt=args.pretrained_tokenizer) 63 | 64 | model, diffusion = create_model_and_diffusion( 65 | pad_tok_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.get_vocab()[''], 66 | resume_checkpoint=args.resume_checkpoint, **args_to_dict(args, model_and_diffusion_defaults().keys()) 67 | ) 68 | 69 | diffusion._load_time_schedule(args.time_schedule_path) 70 | model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu")) 71 | model.eval() 72 | 73 | print('data path', args.val_txt_path) 74 | val_dataloader = dataloader_utils.get_dataloader( 75 | tokenizer=tokenizer, 76 | args=args, 77 | data_path=args.val_txt_path, 78 | batch_size=args.batch_size, 79 | max_seq_len=args.sequence_len, 80 | max_seq_len_src=args.sequence_len_src, 81 | ) 82 | 83 | if args.num_samples <= 0: 84 | args.num_samples = len(dataloader_utils.TextDataset_translation(tokenizer=tokenizer, data_path=args.val_txt_path, source=args.src, target=args.tgt, 85 | shard=MPI.COMM_WORLD.Get_rank(), 86 | num_shards=MPI.COMM_WORLD.Get_size())) 87 | logger.log(f"sample count is {args.num_samples}") 88 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 89 | logger.log(f"the parameter count is {pytorch_total_params}") 90 | 91 | diffusion.rescale_timesteps = 
True 92 | 93 | model.to(dist_util.dev()) 94 | model.eval() # DEBUG 95 | 96 | logger.log("sampling...") 97 | logger.log(f"Clamping is set to {args.clamp}") 98 | all_samples = [] 99 | ground_true_samples = [] 100 | while len(all_samples) * args.batch_size < args.num_samples: 101 | batch, _ = next(val_dataloader) 102 | model_kwargs = {key:item.to(dist_util.dev()) for key, item in batch.items() if 'decoder' not in key} 103 | sample_shape = (args.batch_size, args.sequence_len, model.input_transformers.shared.weight.shape[1]) 104 | print('sample_shape', sample_shape) 105 | sample = diffusion.p_sample_loop( 106 | model, 107 | sample_shape, 108 | clip_denoised=args.clip_denoised, 109 | denoised_fn=None, 110 | model_kwargs=model_kwargs, 111 | top_p=args.top_p, 112 | progress=True, 113 | tokenizer=tokenizer, 114 | log_verbose=True, 115 | decoder_inputs=batch['decoder_input_ids'], 116 | generate_by_q=args.generate_by_q, 117 | generate_by_mix=args.generate_by_mix, 118 | generate_by_mix_prob=args.generate_by_mix_prob, 119 | generate_by_mix_part=args.generate_by_mix_part, 120 | ) 121 | 122 | logits = model.get_logits(sample) # bsz, seqlen, vocab 123 | cands = th.topk(logits, k=1, dim=-1).indices.squeeze() 124 | if args.decoder_attention_mask: 125 | cands[model_kwargs['decoder_attention_mask']==0] = 1 126 | 127 | gathered_samples = [th.zeros_like(cands) for _ in range(dist.get_world_size())] 128 | dist.all_gather(gathered_samples, cands) # gather not supported with NCCL 129 | all_samples.extend([sample.cpu().numpy() for sample in gathered_samples]) 130 | print('number of sample', len(all_samples), all_samples[0].shape) 131 | 132 | batch['decoder_input_ids'] = batch['decoder_input_ids'].to(dist_util.dev()) 133 | gathered_ground_true_sample = [th.zeros_like(batch['decoder_input_ids']) for _ in range(dist.get_world_size())] 134 | dist.all_gather(gathered_ground_true_sample, batch['decoder_input_ids']) 135 | ground_true_samples.extend([sample.cpu().numpy() for sample in gathered_ground_true_sample]) 136 | 137 | logger.log(f"created {len(all_samples) * args.batch_size} samples") 138 | 139 | cands = np.concatenate(all_samples, axis=0) 140 | cands = cands[: args.num_samples] 141 | 142 | decoded_sentences = [] 143 | for seq in cands: 144 | seq = seq[seq>2] 145 | decoded_sentence = tokenizer.decode(seq.tolist(), skip_special_tokens=True) 146 | decoded_sentences.append(decoded_sentence) 147 | 148 | ground_true_sentences = [] 149 | ground_true_samples = np.concatenate(ground_true_samples, axis=0)[: args.num_samples] 150 | for seq in ground_true_samples: 151 | seq = seq[seq>2] 152 | ground_true_sentence = tokenizer.decode(seq.squeeze().tolist(), skip_special_tokens=True) 153 | ground_true_sentences.append(ground_true_sentence) 154 | 155 | dist.barrier() 156 | logger.log("sampling complete") 157 | 158 | write_outputs(args=args, 159 | sentences=decoded_sentences, 160 | gt_sentences = ground_true_sentences, 161 | raw_sentences=cands, 162 | raw_gt_sentences=ground_true_samples,) 163 | 164 | 165 | def load_embeddings(checkpoint_path, tokenizer, emb_dim): 166 | embeddings = th.nn.Embedding(tokenizer.vocab_size, emb_dim) 167 | embeddings.load_state_dict(th.load(f'{checkpoint_path}/random_emb.torch')) 168 | return embeddings 169 | 170 | 171 | def read_training_args(config_path): 172 | with open(config_path, "r") as f: 173 | return json.load(f) 174 | 175 | 176 | def write_outputs(args: dict, sentences: List[str], gt_sentences: List[str], raw_sentences, raw_gt_sentences) -> None: 177 | 178 | model_dir = 
os.path.split(args.model_name_or_path)[0] 179 | model_base_name = os.path.split(args.model_name_or_path)[1] 180 | if args.generate_by_q: 181 | comments = f'predict_by_qsample_{args.seed}' 182 | elif args.generate_by_mix: 183 | comments = f'predict_by_mixsample_{args.generate_by_mix_prob}_{args.generate_by_mix_part}_{args.seed}' 184 | else: 185 | comments = f'normal_{args.seed}' 186 | num_samples = len(sentences) 187 | output_file_basepath = os.path.join( 188 | model_dir, 189 | f"{model_base_name}.samples_{num_samples}.steps-{args.diffusion_steps}.clamp-{args.clamp}-{comments}", 190 | ) + ".txt" 191 | with open(output_file_basepath, "w") as text_fout: 192 | for generated_sentence, ground_true_sentence in zip(sentences, gt_sentences): 193 | text_fout.write(json.dumps([generated_sentence, ground_true_sentence]) + "\n") 194 | 195 | print(f"written the decoded output to {output_file_basepath}") 196 | 197 | output_file_basepath = os.path.join( 198 | model_dir, 199 | f"{model_base_name}.samples_{num_samples}.steps-{args.diffusion_steps}.clamp-{args.clamp}.raw-output-ids-{comments}", 200 | ) + ".txt" 201 | with open(output_file_basepath, "w") as text_fout: 202 | for generated_sentence, ground_true_sentence in zip(raw_sentences, raw_gt_sentences): 203 | text_fout.write(json.dumps([generated_sentence.tolist(), ground_true_sentence.tolist()]) + "\n") 204 | 205 | print(f"written the decoded output to {output_file_basepath}") 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /iupac_tokenization.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AdamW, 3 | DataCollatorWithPadding, 4 | HfArgumentParser, 5 | T5Config, 6 | T5ForConditionalGeneration, 7 | T5Tokenizer, 8 | Trainer, 9 | TrainingArguments, 10 | ) 11 | import os 12 | import re 13 | import pandas as pd 14 | import numpy as np 15 | import torch 16 | from torch.nn.utils.rnn import pad_sequence 17 | import os.path as pt 18 | #os.environ["CUDA_VISIBLE_DEVICES"]="0" 19 | 20 | 21 | class T5Collator: 22 | def __init__(self, pad_token_id): 23 | super().__init__() 24 | self.pad_token_id = pad_token_id 25 | def __call__(self, records): 26 | # records is a list of dicts 27 | batch = {} 28 | padvals = {"input_ids": self.pad_token_id,'labels':-100} 29 | for k in records[0]: 30 | if k in padvals: 31 | batch[k] = pad_sequence([torch.tensor(r[k]) for r in records], 32 | batch_first=True, 33 | padding_value=padvals[k]) 34 | else: 35 | batch[k] = torch.FloatTensor([r[k] for r in records]) #torch.Tensor 36 | return batch 37 | 38 | class T5IUPACTokenizer(T5Tokenizer): 39 | def prepare_for_tokenization(self, text, is_split_into_words=False, 40 | **kwargs): 41 | return re.sub(" ", "_", text), kwargs 42 | 43 | def _decode(self, *args, **kwargs): 44 | # replace "_" with " ", except for the _ in extra_id_# 45 | text = super()._decode(*args, **kwargs) 46 | text = re.sub("extra_id_", "extraAidA", text) 47 | text = re.sub("_", " ", text) 48 | text = re.sub("extraAidA", "extra_id_", text) 49 | return text 50 | 51 | def sentinels(self, sentinel_ids): 52 | return self.vocab_size - sentinel_ids - 1 53 | 54 | def sentinel_mask(self, ids): 55 | return ((self.vocab_size - self._extra_ids <= ids) & 56 | (ids < self.vocab_size)) 57 | 58 | def _tokenize(self, text, sample=False): 59 | #pieces = super()._tokenize(text, sample=sample) 60 | pieces = super()._tokenize(text) 61 | # sentencepiece adds a non-printing token at 
the start. Remove it 62 | return [""]+pieces[1:] 63 | 64 | 65 | def get_iupac_tokenizer(is_train=1,full_path = './data'): 66 | 67 | iupac_tokenizer = T5IUPACTokenizer(vocab_file=pt.join(full_path,'iupac_spm.model')) 68 | iupac_vocab_size = iupac_tokenizer.vocab_size 69 | print('iupac_vocab_size:',iupac_vocab_size) 70 | if is_train: 71 | torch.save(iupac_tokenizer, pt.join(full_path,"real_iupac_tokenizer.pt")) 72 | print("training...",len(iupac_tokenizer)) 73 | else: 74 | iupac_tokenizer = torch.load(pt.join(full_path,"real_iupac_tokenizer.pt"), map_location="cpu") 75 | print('fina_tune...',len(iupac_tokenizer)) 76 | 77 | #collator = T5Collator(iupac_tokenizer.pad_token_id) 78 | 79 | return iupac_tokenizer 80 | 81 | if __name__ == "__main__": 82 | 83 | iupac_tokenizer = get_iupac_tokenizer(is_train=1,full_path = './data') 84 | 85 | print(iupac_tokenizer,iupac_tokenizer.vocab_size) 86 | 87 | iupac_string = "2-(6-aminopurin-9-yl)-5-(methylsulfanylmethyl)oxolane-3,4-diol" 88 | iupac_encoded = iupac_tokenizer(iupac_string) 89 | iupac_merges = iupac_tokenizer.convert_ids_to_tokens(iupac_encoded["input_ids"]) 90 | print(iupac_encoded) 91 | print(iupac_merges) 92 | 93 | line_number = 1 94 | 95 | valid_line=[] 96 | 97 | with open("data/pubchem_iupac.csv",'r') as f: 98 | myline = f.readline() 99 | while myline: 100 | #print("line_number:",line_number) 101 | 102 | iupac_encoded = iupac_tokenizer(myline) 103 | iupac_merges = iupac_tokenizer.convert_ids_to_tokens(iupac_encoded["input_ids"]) 104 | #print(iupac_encoded) 105 | #print(iupac_merges) 106 | 107 | if iupac_encoded["input_ids"].count(2)==1: 108 | valid_line.append(myline) 109 | 110 | if line_number%50000==0: 111 | with open("data/pubchem_iupac_valid.csv",'a') as ff: 112 | for j in valid_line: 113 | ff.write(j) 114 | valid_line=[] 115 | 116 | myline = f.readline() 117 | line_number = 1+line_number 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import json, os 6 | import pathlib 7 | import pprint 8 | import sys 9 | from transformers import set_seed 10 | import os 11 | 12 | from src.utils import dist_util, logger 13 | from src.modeling.diffusion.resample import create_named_schedule_sampler 14 | from model_utils import create_model_and_diffusion 15 | from trainer import Trainer 16 | import dataloader_utils 17 | from args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 18 | from tokenizer_utils import create_iupac_smiles_tokenizer #create_tokenizer 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | dist_util.setup_dist() 24 | logger.configure(dir=os.path.join(args.checkpoint_path, 'logger/')) 25 | set_seed(args.seed) 26 | print(f'set seed {args.seed + int(os.environ["RANK"])}') 27 | 28 | logger.log("creating data loader") 29 | pathlib.Path(args.checkpoint_path).mkdir(parents=True, exist_ok=True) 30 | 31 | #tokenizer = create_tokenizer(return_pretokenized=args.use_pretrained_tokenizer, 32 | # path=f"data/{args.dataset}/", 33 | # tokenizer_type='byte-level', 34 | # tokenizer_ckpt=args.pretrained_tokenizer) 35 | 36 | iupac_tokenizer,smiles_tokenizer = create_iupac_smiles_tokenizer(return_pretokenized=True, 37 | path=f"data/{args.dataset}/", 38 | tokenizer_ckpt='./data') # 39 | 40 | 41 | train_dataloader = dataloader_utils.get_dataloader( 42 | iupac_tokenizer=iupac_tokenizer, 43 | smiles_tokenizer=smiles_tokenizer, 44 | args=args, 45 | data_path=args.train_txt_path, 46 | batch_size=args.batch_size, 47 | max_seq_len=args.sequence_len, 48 | max_seq_len_src=args.sequence_len_src, 49 | ) 50 | 51 | val_dataloader = dataloader_utils.get_dataloader( 52 | iupac_tokenizer=iupac_tokenizer, 53 | smiles_tokenizer=smiles_tokenizer, 54 | args=args, 55 | data_path=args.val_txt_path, 56 | batch_size=args.batch_size, 57 | max_seq_len=args.sequence_len, 58 | max_seq_len_src=args.sequence_len_src, 59 | ) 60 | args.vocab_size = iupac_tokenizer.vocab_size 61 | 62 | logger.log("creating model and diffusion...", args.checkpoint_path) 63 | model, diffusion = create_model_and_diffusion( 64 | pad_tok_id=iupac_tokenizer.pad_token_id if hasattr(iupac_tokenizer, 'pad_token_id') else iupac_tokenizer.get_vocab()[''], 65 | resume_checkpoint=args.checkpoint_path, 66 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 67 | ) 68 | model.to(dist_util.dev()) 69 | 70 | print(model) 71 | 72 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 73 | 74 | logger.log(f"the parameter count is {pytorch_total_params}") 75 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 76 | 77 | logger.log(f"saving the hyperparameters to {args.checkpoint_path}/training_args.json") 78 | with open(f"{args.checkpoint_path}/training_args.json", "w") as f: 79 | json.dump(args.__dict__, f, indent=2) 80 | 81 | logger.log("training...") 82 | Trainer( 83 | model=model, 84 | diffusion=diffusion, 85 | data=train_dataloader, 86 | batch_size=args.batch_size, 87 | microbatch=args.microbatch, 88 | lr=args.lr, 89 | ema_rate=args.ema_rate, 90 | log_interval=args.log_interval, 91 | save_interval=args.save_interval, 92 | resume_checkpoint=args.resume_checkpoint, 93 | use_fp16=args.use_fp16, 94 | fp16_scale_growth=args.fp16_scale_growth, 95 | schedule_sampler=schedule_sampler, 96 | weight_decay=args.weight_decay, 97 | lr_anneal_steps=args.lr_anneal_steps, 98 | checkpoint_path=args.checkpoint_path, 99 | gradient_clipping=args.gradient_clipping, 100 | eval_data=val_dataloader, 101 | 
eval_interval=args.eval_interval, 102 | warmup=args.warmup, 103 | ).run_loop() 104 | 105 | 106 | def make_tensorboard_name_from_args(args): 107 | keys_to_add = ["batch_size", "lr", "num_heads", "lr_anneal_steps", "config_name", "seed", "in_channel"] 108 | name = "" 109 | for key in keys_to_add: 110 | name += f"{key}={getattr(args, key)}_" 111 | return name 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /main_retrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import json, os 6 | import pathlib 7 | import pprint 8 | import sys 9 | from transformers import set_seed 10 | import os 11 | 12 | from src.utils import dist_util, logger 13 | from src.modeling.diffusion.resample import create_named_schedule_sampler 14 | from model_utils import create_model_and_diffusion 15 | from trainer import Trainer 16 | import dataloader_utils 17 | from args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 18 | from tokenizer_utils import create_iupac_smiles_tokenizer #create_tokenizer 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | dist_util.setup_dist() 24 | logger.configure(dir=os.path.join(args.checkpoint_path, 'logger/')) 25 | set_seed(args.seed) 26 | print(f'set seed {args.seed + int(os.environ["RANK"])}') 27 | 28 | logger.log("creating data loader") 29 | pathlib.Path(args.checkpoint_path).mkdir(parents=True, exist_ok=True) 30 | 31 | #tokenizer = create_tokenizer(return_pretokenized=args.use_pretrained_tokenizer, 32 | # path=f"data/{args.dataset}/", 33 | # tokenizer_type='byte-level', 34 | # tokenizer_ckpt=args.pretrained_tokenizer) 35 | 36 | iupac_tokenizer,smiles_tokenizer = create_iupac_smiles_tokenizer(return_pretokenized=True, 37 | path=f"data/{args.dataset}/", 38 | tokenizer_ckpt='./data') # 39 | 40 | 41 | train_dataloader = dataloader_utils.get_dataloader( 42 | iupac_tokenizer=iupac_tokenizer, 43 | smiles_tokenizer=smiles_tokenizer, 44 | args=args, 45 | data_path=args.train_txt_path, 46 | batch_size=args.batch_size, 47 | max_seq_len=args.sequence_len, 48 | max_seq_len_src=args.sequence_len_src, 49 | ) 50 | 51 | val_dataloader = dataloader_utils.get_dataloader( 52 | iupac_tokenizer=iupac_tokenizer, 53 | smiles_tokenizer=smiles_tokenizer, 54 | args=args, 55 | data_path=args.val_txt_path, 56 | batch_size=args.batch_size, 57 | max_seq_len=args.sequence_len, 58 | max_seq_len_src=args.sequence_len_src, 59 | ) 60 | args.vocab_size = iupac_tokenizer.vocab_size 61 | 62 | args.load_ckpt=args.model_name_or_path 63 | 64 | args.resume_checkpoint=args.checkpoint_path 65 | 66 | logger.log("creating model and diffusion...", args.checkpoint_path) 67 | model, diffusion = create_model_and_diffusion( 68 | pad_tok_id=iupac_tokenizer.pad_token_id if hasattr(iupac_tokenizer, 'pad_token_id') else iupac_tokenizer.get_vocab()[''], 69 | resume_checkpoint=args.checkpoint_path, 70 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 71 | ) 72 | 73 | 74 | diffusion._load_time_schedule(args.time_schedule_path) 75 | #model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu")) 76 | 77 | 78 | model.to(dist_util.dev()) 79 | 80 | print(model) 81 | 82 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 83 | 84 | logger.log(f"the parameter count is {pytorch_total_params}") 85 | schedule_sampler = 
create_named_schedule_sampler(args.schedule_sampler, diffusion) 86 | 87 | logger.log(f"saving the hyperparameters to {args.checkpoint_path}/training_args.json") 88 | with open(f"{args.checkpoint_path}/training_args.json", "w") as f: 89 | json.dump(args.__dict__, f, indent=2) 90 | 91 | logger.log("training...") 92 | Trainer( 93 | model=model, 94 | diffusion=diffusion, 95 | data=train_dataloader, 96 | batch_size=args.batch_size, 97 | microbatch=args.microbatch, 98 | lr=args.lr, 99 | ema_rate=args.ema_rate, 100 | log_interval=args.log_interval, 101 | save_interval=args.save_interval, 102 | resume_checkpoint=args.resume_checkpoint, 103 | use_fp16=args.use_fp16, 104 | fp16_scale_growth=args.fp16_scale_growth, 105 | schedule_sampler=schedule_sampler, 106 | weight_decay=args.weight_decay, 107 | lr_anneal_steps=args.lr_anneal_steps, 108 | checkpoint_path=args.checkpoint_path, 109 | gradient_clipping=args.gradient_clipping, 110 | eval_data=val_dataloader, 111 | eval_interval=args.eval_interval, 112 | warmup=args.warmup, 113 | ).run_loop() 114 | 115 | 116 | def make_tensorboard_name_from_args(args): 117 | keys_to_add = ["batch_size", "lr", "num_heads", "lr_anneal_steps", "config_name", "seed", "in_channel"] 118 | name = "" 119 | for key in keys_to_add: 120 | name += f"{key}={getattr(args, key)}_" 121 | return name 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import src.modeling.diffusion.gaussian_diffusion as gd 2 | from src.modeling.diffusion.respace import SpacedDiffusion, space_timesteps 3 | from src.modeling.predictor.transformer_model import TransformerNetModel_encoder_decoder 4 | 5 | 6 | def create_model_and_diffusion( 7 | class_cond, 8 | learn_sigma, 9 | sigma_small, 10 | num_channels, 11 | num_heads, 12 | dropout, 13 | diffusion_steps, 14 | noise_schedule, 15 | timestep_respacing, 16 | use_kl, 17 | predict_xstart, 18 | rescale_timesteps, 19 | rescale_learned_sigmas, 20 | use_checkpoint, 21 | model_arch, 22 | in_channel, 23 | out_channel, 24 | training_mode, 25 | vocab_size, 26 | config_name, 27 | logits_mode, 28 | init_pretrained, 29 | freeze_embeddings, 30 | use_pretrained_embeddings, 31 | load_ckpt, 32 | sequence_len, 33 | resume_checkpoint, 34 | pad_tok_id, 35 | loss_update_granu, 36 | schedule_update_stride, 37 | **kwargs, 38 | ): 39 | model = create_model( 40 | num_channels, 41 | learn_sigma=learn_sigma, 42 | class_cond=class_cond, 43 | use_checkpoint=use_checkpoint, 44 | num_heads=num_heads, 45 | dropout=dropout, 46 | in_channel=in_channel, 47 | out_channel=out_channel, 48 | training_mode=training_mode, 49 | vocab_size=vocab_size, 50 | config_name=config_name, 51 | logits_mode=logits_mode, 52 | init_pretrained=init_pretrained, 53 | freeze_embeddings=freeze_embeddings, 54 | use_pretrained_embeddings=use_pretrained_embeddings, 55 | load_ckpt=load_ckpt, 56 | ) 57 | diffusion = create_gaussian_diffusion( 58 | steps=diffusion_steps, 59 | learn_sigma=learn_sigma, 60 | sigma_small=sigma_small, 61 | noise_schedule=noise_schedule, 62 | use_kl=use_kl, 63 | predict_xstart=predict_xstart, 64 | rescale_timesteps=rescale_timesteps, 65 | rescale_learned_sigmas=rescale_learned_sigmas, 66 | timestep_respacing=timestep_respacing, 67 | model_arch=model_arch, 68 | training_mode=training_mode, 69 | sequence_len=sequence_len, 70 | resume_checkpoint=resume_checkpoint, 71 | pad_tok_id=pad_tok_id, 72 | 
loss_update_granu=loss_update_granu, 73 | schedule_update_stride=schedule_update_stride, 74 | ) 75 | return model, diffusion 76 | 77 | 78 | def create_model( 79 | num_channels, 80 | learn_sigma, 81 | use_checkpoint, 82 | class_cond, # TODO for the next version 83 | num_heads, 84 | dropout, 85 | init_pretrained, 86 | freeze_embeddings, 87 | use_pretrained_embeddings, 88 | in_channel, 89 | out_channel, 90 | training_mode, 91 | vocab_size, 92 | config_name, 93 | logits_mode, 94 | load_ckpt, 95 | encoder_layers = 6, 96 | decoder_layers = 6, 97 | model_type = 'encoder_decoder', 98 | ): 99 | return TransformerNetModel_encoder_decoder( 100 | in_channels=in_channel, 101 | model_channels=num_channels, 102 | out_channels=(out_channel if not learn_sigma else out_channel * 2), 103 | dropout=dropout, 104 | use_checkpoint=use_checkpoint, 105 | num_heads=num_heads, 106 | config_name=config_name, 107 | vocab_size=vocab_size, 108 | logits_mode=logits_mode, 109 | init_pretrained=init_pretrained, 110 | use_pretrained_embeddings=use_pretrained_embeddings, 111 | freeze_embeddings=freeze_embeddings, 112 | encoder_layers = encoder_layers, 113 | decoder_layers = decoder_layers, 114 | load_ckpt=load_ckpt, 115 | ) 116 | 117 | 118 | 119 | def create_gaussian_diffusion( 120 | *, 121 | steps=1000, 122 | learn_sigma=False, 123 | sigma_small=False, 124 | noise_schedule="linear", 125 | use_kl=False, 126 | predict_xstart=False, 127 | rescale_timesteps=False, 128 | rescale_learned_sigmas=False, 129 | timestep_respacing="", 130 | model_arch="transformer", 131 | training_mode="diffusion-lm", 132 | sequence_len=None, 133 | resume_checkpoint='', 134 | pad_tok_id=None, 135 | loss_update_granu=None, 136 | schedule_update_stride=0, 137 | ): 138 | 139 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 140 | 141 | if use_kl: 142 | loss_type = gd.LossType.E2E_KL 143 | else: 144 | loss_type = gd.LossType.E2E_MSE 145 | 146 | if not timestep_respacing: 147 | timestep_respacing = [steps] 148 | 149 | # Whether variance is learned or fixed 150 | model_var_type = None 151 | if not learn_sigma: 152 | if sigma_small: 153 | model_var_type = gd.ModelVarType.FIXED_SMALL 154 | else: 155 | model_var_type = gd.ModelVarType.FIXED_LARGE 156 | else: 157 | model_var_type = gd.ModelVarType.LEARNED_RANGE 158 | 159 | # what is the interpretation of the output generated by the model? Is it generating the noise or the mean directly? 160 | 161 | model_mean_type = None 162 | if not predict_xstart: 163 | model_mean_type = gd.ModelMeanType.EPSILON # predicts noise 164 | else: # predicts starting x (x0 estimate, possibly used by DDIM?) 165 | model_mean_type = gd.ModelMeanType.START_X 166 | 167 | return SpacedDiffusion( 168 | use_timesteps=space_timesteps(steps, timestep_respacing), 169 | betas=betas, 170 | model_var_type=model_var_type, 171 | model_mean_type=model_mean_type, 172 | loss_type=loss_type, 173 | rescale_timesteps=rescale_timesteps, 174 | model_arch=model_arch, 175 | training_mode=training_mode, 176 | token_max_length=sequence_len, 177 | save_dir=resume_checkpoint, 178 | pad_tok_id=pad_tok_id, 179 | loss_update_granu=loss_update_granu, 180 | schedule_update_stride=schedule_update_stride, 181 | ) 182 | -------------------------------------------------------------------------------- /nmt_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams upto a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | methods. 35 | 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) 
/ 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. / ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | bert-score 3 | blobfile 4 | datasets 5 | huggingface-hub==0.4.0 6 | mpi4py 7 | nltk 8 | numpy 9 | pandas 10 | protobuf 11 | rouge-score 12 | sacrebleu 13 | sacremoses 14 | scikit-learn 15 | scipy 16 | spacy 17 | tokenizers 18 | torchmetrics 19 | tqdm 20 | transformers==4.18.0 -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Scaffold analysis 3 | 4 | https://github.com/grisoniFr/scaffold_hopping_whales 5 | 6 | 7 | 8 | 9 | 10 | 11 | ### IUPAC name ⇆ SMILES string 12 | 13 | 14 | 15 | 16 | #### Structure/SMILES2IUPAC 17 | 18 | 19 | **IUPAC Naming** 20 | 21 | https://web.chemdoodle.com/demos/iupac-naming 22 | 23 | 24 | **SMILES2IUPAC** 25 | 26 | https://huggingface.co/knowledgator/SMILES2IUPAC-canonical-base 27 | 28 | **Smiles-TO-iUpac-Translator** 29 | 30 | https://github.com/Kohulan/Smiles-TO-iUpac-Translator 31 | 32 | 33 | 34 | #### IUPAC2SMILES 35 | 36 | 37 | https://www.antvaset.com/iupac-to-smiles 38 | 39 | 40 | https://web.chemdoodle.com/demos/iupac-naming 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /rouge.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ ROUGE metric from Google Research github repo. 
""" 16 | 17 | # The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt 18 | import absl # Here to have a nice missing dependency error message early on 19 | import nltk # Here to have a nice missing dependency error message early on 20 | import numpy # Here to have a nice missing dependency error message early on 21 | import six # Here to have a nice missing dependency error message early on 22 | from rouge_score import rouge_scorer, scoring 23 | 24 | import datasets 25 | 26 | 27 | _CITATION = """\ 28 | @inproceedings{lin-2004-rouge, 29 | title = "{ROUGE}: A Package for Automatic Evaluation of Summaries", 30 | author = "Lin, Chin-Yew", 31 | booktitle = "Text Summarization Branches Out", 32 | month = jul, 33 | year = "2004", 34 | address = "Barcelona, Spain", 35 | publisher = "Association for Computational Linguistics", 36 | url = "https://www.aclweb.org/anthology/W04-1013", 37 | pages = "74--81", 38 | } 39 | """ 40 | 41 | _DESCRIPTION = """\ 42 | ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for 43 | evaluating automatic summarization and machine translation software in natural language processing. 44 | The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation. 45 | 46 | Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters. 47 | 48 | This metrics is a wrapper around Google Research reimplementation of ROUGE: 49 | https://github.com/google-research/google-research/tree/master/rouge 50 | """ 51 | 52 | _KWARGS_DESCRIPTION = """ 53 | Calculates average rouge scores for a list of hypotheses and references 54 | Args: 55 | predictions: list of predictions to score. Each predictions 56 | should be a string with tokens separated by spaces. 57 | references: list of reference for each prediction. Each 58 | reference should be a string with tokens separated by spaces. 59 | rouge_types: A list of rouge types to calculate. 60 | Valid names: 61 | `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring, 62 | `"rougeL"`: Longest common subsequence based scoring. 63 | `"rougeLSum"`: rougeLsum splits text using `"\n"`. 64 | See details in https://github.com/huggingface/datasets/issues/617 65 | use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes. 
66 | use_agregator: Return aggregates if this is set to True 67 | Returns: 68 | rouge1: rouge_1 (precision, recall, f1), 69 | rouge2: rouge_2 (precision, recall, f1), 70 | rougeL: rouge_l (precision, recall, f1), 71 | rougeLsum: rouge_lsum (precision, recall, f1) 72 | Examples: 73 | 74 | >>> rouge = datasets.load_metric('rouge') 75 | >>> predictions = ["hello there", "general kenobi"] 76 | >>> references = ["hello there", "general kenobi"] 77 | >>> results = rouge.compute(predictions=predictions, references=references) 78 | >>> print(list(results.keys())) 79 | ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'] 80 | >>> print(results["rouge1"]) 81 | AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)) 82 | >>> print(results["rouge1"].mid.fmeasure) 83 | 1.0 84 | """ 85 | 86 | 87 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 88 | class Rouge(datasets.Metric): 89 | def _info(self): 90 | return datasets.MetricInfo( 91 | description=_DESCRIPTION, 92 | citation=_CITATION, 93 | inputs_description=_KWARGS_DESCRIPTION, 94 | features=datasets.Features( 95 | { 96 | "predictions": datasets.Value("string", id="sequence"), 97 | "references": datasets.Value("string", id="sequence"), 98 | } 99 | ), 100 | codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"], 101 | reference_urls=[ 102 | "https://en.wikipedia.org/wiki/ROUGE_(metric)", 103 | "https://github.com/google-research/google-research/tree/master/rouge", 104 | ], 105 | ) 106 | 107 | def _compute(self, predictions, references, rouge_types=None, use_agregator=True, use_stemmer=False): 108 | if rouge_types is None: 109 | rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 110 | 111 | scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer) 112 | if use_agregator: 113 | aggregator = scoring.BootstrapAggregator() 114 | else: 115 | scores = [] 116 | 117 | for ref, pred in zip(references, predictions): 118 | score = scorer.score(ref, pred) 119 | if use_agregator: 120 | aggregator.add_scores(score) 121 | else: 122 | scores.append(score) 123 | 124 | if use_agregator: 125 | result = aggregator.aggregate() 126 | else: 127 | result = {} 128 | for key in scores[0]: 129 | result[key] = list(score[key] for score in scores) 130 | 131 | return result -------------------------------------------------------------------------------- /sacre_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ SACREBLEU metric. 
""" 15 | 16 | import sacrebleu as scb 17 | from packaging import version 18 | 19 | import datasets 20 | 21 | 22 | _CITATION = """\ 23 | @inproceedings{post-2018-call, 24 | title = "A Call for Clarity in Reporting {BLEU} Scores", 25 | author = "Post, Matt", 26 | booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers", 27 | month = oct, 28 | year = "2018", 29 | address = "Belgium, Brussels", 30 | publisher = "Association for Computational Linguistics", 31 | url = "https://www.aclweb.org/anthology/W18-6319", 32 | pages = "186--191", 33 | } 34 | """ 35 | 36 | _DESCRIPTION = """\ 37 | SacreBLEU provides hassle-free computation of shareable, comparable, and reproducible BLEU scores. 38 | Inspired by Rico Sennrich's `multi-bleu-detok.perl`, it produces the official WMT scores but works with plain text. 39 | It also knows all the standard test sets and handles downloading, processing, and tokenization for you. 40 | 41 | See the [README.md] file at https://github.com/mjpost/sacreBLEU for more information. 42 | """ 43 | 44 | _KWARGS_DESCRIPTION = """ 45 | Produces BLEU scores along with its sufficient statistics 46 | from a source against one or more references. 47 | 48 | Args: 49 | predictions: The system stream (a sequence of segments). 50 | references: A list of one or more reference streams (each a sequence of segments). 51 | smooth_method: The smoothing method to use. (Default: 'exp'). 52 | smooth_value: The smoothing value. Only valid for 'floor' and 'add-k'. (Defaults: floor: 0.1, add-k: 1). 53 | tokenize: Tokenization method to use for BLEU. If not provided, defaults to 'zh' for Chinese, 'ja-mecab' for 54 | Japanese and '13a' (mteval) otherwise. 55 | lowercase: Lowercase the data. If True, enables case-insensitivity. (Default: False). 56 | force: Insist that your tokenized input is actually detokenized. 57 | 58 | Returns: 59 | 'score': BLEU score, 60 | 'counts': Counts, 61 | 'totals': Totals, 62 | 'precisions': Precisions, 63 | 'bp': Brevity penalty, 64 | 'sys_len': predictions length, 65 | 'ref_len': reference length, 66 | 67 | Examples: 68 | 69 | >>> predictions = ["hello there general kenobi", "foo bar foobar"] 70 | >>> references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]] 71 | >>> sacrebleu = datasets.load_metric("sacrebleu") 72 | >>> results = sacrebleu.compute(predictions=predictions, references=references) 73 | >>> print(list(results.keys())) 74 | ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'] 75 | >>> print(round(results["score"], 1)) 76 | 100.0 77 | """ 78 | 79 | 80 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 81 | class Sacrebleu(datasets.Metric): 82 | def _info(self): 83 | if version.parse(scb.__version__) < version.parse("1.4.12"): 84 | raise ImportWarning( 85 | "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n" 86 | 'You can install it with `pip install "sacrebleu>=1.4.12"`.' 
87 | ) 88 | return datasets.MetricInfo( 89 | description=_DESCRIPTION, 90 | citation=_CITATION, 91 | homepage="https://github.com/mjpost/sacreBLEU", 92 | inputs_description=_KWARGS_DESCRIPTION, 93 | features=datasets.Features( 94 | { 95 | "predictions": datasets.Value("string", id="sequence"), 96 | "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"), 97 | } 98 | ), 99 | codebase_urls=["https://github.com/mjpost/sacreBLEU"], 100 | reference_urls=[ 101 | "https://github.com/mjpost/sacreBLEU", 102 | "https://en.wikipedia.org/wiki/BLEU", 103 | "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", 104 | ], 105 | ) 106 | 107 | def _compute( 108 | self, 109 | predictions, 110 | references, 111 | smooth_method="exp", 112 | smooth_value=None, 113 | force=False, 114 | lowercase=False, 115 | tokenize=None, 116 | use_effective_order=False, 117 | ): 118 | references_per_prediction = len(references[0]) 119 | if any(len(refs) != references_per_prediction for refs in references): 120 | raise ValueError("Sacrebleu requires the same number of references for each prediction") 121 | transformed_references = [[refs[i] for refs in references] for i in range(references_per_prediction)] 122 | output = scb.corpus_bleu( 123 | predictions, 124 | transformed_references, 125 | smooth_method=smooth_method, 126 | smooth_value=smooth_value, 127 | force=force, 128 | lowercase=lowercase, 129 | use_effective_order=use_effective_order, 130 | **(dict(tokenize=tokenize) if tokenize else {}), 131 | ) 132 | output_dict = { 133 | "score": output.score, 134 | "counts": output.counts, 135 | "totals": output.totals, 136 | "precisions": output.precisions, 137 | "bp": output.bp, 138 | "sys_len": output.sys_len, 139 | "ref_len": output.ref_len, 140 | } 141 | return output_dict -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/controllable/classifier.py: -------------------------------------------------------------------------------- 1 | ### Trains a classifier on the latent space of a diffusion model 2 | from functools import partial 3 | import json 4 | import os 5 | import sys 6 | from torch import nn 7 | import torch 8 | import pandas as pd 9 | from torch.utils.data import DataLoader 10 | 11 | 12 | from transformers.models.bert.modeling_bert import BertConfig, BertModel, BertPooler 13 | from transformers import AutoTokenizer 14 | from transformers.modeling_outputs import SequenceClassifierOutput 15 | from modeling.diffusion.gaussian_diffusion import GaussianDiffusion 16 | 17 | 18 | from train_infer.factory_methods import create_model_and_diffusion 19 | from src.utils import dist_util 20 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 21 | from src.utils.data_utils_sentencepiece import TextDataset 22 | from src.utils.custom_tokenizer import create_tokenizer 23 | from utils.logger import log 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | class DiffusionBertForSequenceClassification(nn.Module): 29 | """A bert based classifier that uses the latent space of a diffusion model as input""" 30 | 31 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight", "word_embeddings.weight"] 32 | 33 | def __init__(self, 
config: BertConfig, diffusion_model: GaussianDiffusion, num_labels: int): 34 | super().__init__() 35 | self.diffusion_model = diffusion_model 36 | self.classifier = BertModel(config) 37 | self.word_embeddings = nn.Embedding( 38 | num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim 39 | ) 40 | 41 | 42 | self.pooler = BertPooler(config) 43 | 44 | self.num_labels = num_labels 45 | 46 | self.up_proj = nn.Sequential( 47 | nn.Linear(config.embedding_dim, config.embedding_dim * 4), 48 | nn.Tanh(), 49 | nn.Linear(config.embedding_dim * 4, config.hidden_size), 50 | ) 51 | 52 | self.train_diffusion_steps = config.train_diffusion_steps 53 | 54 | # self.time_embeddings = nn.Embedding(self.train_diffusion_steps + 1, config.hidden_size) 55 | self.time_embeddings = nn.Embedding(2000 + 1, config.hidden_size) 56 | 57 | # Model parallel 58 | self.model_parallel = False 59 | self.device_map = None 60 | 61 | self.classification_head = nn.Sequential( 62 | nn.Linear(config.hidden_size, config.hidden_size), 63 | nn.Tanh(), 64 | nn.Linear(config.hidden_size, self.num_labels), 65 | ) 66 | 67 | def forward( 68 | self, 69 | input_ids=None, 70 | past_key_values=None, 71 | attention_mask=None, 72 | token_type_ids=None, 73 | position_ids=None, 74 | head_mask=None, 75 | inputs_embeds=None, 76 | encoder_hidden_states=None, 77 | encoder_attention_mask=None, 78 | labels=None, 79 | use_cache=None, 80 | output_attentions=None, 81 | output_hidden_states=None, 82 | return_dict=None, 83 | ): 84 | 85 | if inputs_embeds is None: 86 | # The classifier is supposed to be used with a diffusion model. During training, the embeddings should be provided 87 | # by the backbone model. The word_embeddings are included here for to test the classifier on its own. 88 | inputs_embeds = self.word_embeddings(input_ids) 89 | 90 | t = torch.randint(-1, self.train_diffusion_steps, (inputs_embeds.shape[0],)).to( 91 | inputs_embeds.device 92 | ) 93 | # done this way because torch randint is [inclusive, exclusive), and we don't want to pass samples with t = num_diffusion_steps 94 | # TODO: double-check this 95 | t_mask = t >= 0 96 | 97 | inputs_with_added_noise = self.diffusion_model.q_sample(x_start=inputs_embeds, t=t) 98 | # replace the embeddings with the noisy versions for all samples with t >= 0 99 | inputs_embeds[t_mask] = inputs_with_added_noise[t_mask] 100 | inputs_embeds = self.up_proj(inputs_embeds) 101 | 102 | # essentially, t = -1 is the last step 103 | t[~t_mask] = self.train_diffusion_steps 104 | time_embedded = self.time_embeddings(t).unsqueeze(1) 105 | 106 | 107 | 108 | inputs_embeds = torch.cat([inputs_embeds, time_embedded], dim=1) 109 | 110 | outputs = self.classifier( 111 | inputs_embeds=inputs_embeds, 112 | past_key_values=past_key_values, 113 | attention_mask=attention_mask, 114 | token_type_ids=token_type_ids, 115 | position_ids=position_ids, 116 | head_mask=head_mask, 117 | encoder_hidden_states=encoder_hidden_states, 118 | encoder_attention_mask=encoder_attention_mask, 119 | use_cache=use_cache, 120 | output_attentions=output_attentions, 121 | output_hidden_states=output_hidden_states, 122 | return_dict=return_dict, 123 | ) 124 | 125 | return self.loss_from_outputs(outputs, labels) 126 | 127 | def label_logp(self, inputs_with_added_noise, t, labels): 128 | """ 129 | Returns p(labels | x_t, t) for a batch of samples. Note that inputs_with_added_noise are supposed to be the noisy versions of the inputs. Using DDPM terminology, this is the x_t. 
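Assuming the standard DDPM forward process implemented by GaussianDiffusion.q_sample, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps with eps ~ N(0, I), so the classifier is queried on latents whose noise level grows with t and its gradients stay informative at every step of guided sampling.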
130 | """ 131 | 132 | inputs_with_added_noise = self.up_proj(inputs_with_added_noise) 133 | 134 | time_embedded = self.time_embeddings(t).unsqueeze(1) 135 | inputs_embeds = torch.cat([inputs_with_added_noise, time_embedded], dim=1) 136 | outputs = self.classifier( 137 | inputs_embeds=inputs_embeds, 138 | ) 139 | return self.loss_from_outputs(outputs, labels) 140 | 141 | 142 | 143 | def loss_from_outputs(self, outputs, labels = None): 144 | pooled_output = self.pooler(outputs[0]) 145 | logits = self.classification_head(pooled_output) 146 | if labels is not None: 147 | loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1)) 148 | else: 149 | loss = None 150 | 151 | return SequenceClassifierOutput( 152 | loss=loss, 153 | logits=logits, 154 | hidden_states=outputs.hidden_states, 155 | attentions=outputs.attentions, 156 | ) 157 | 158 | # TODO: make num_labels a property of the config 159 | @staticmethod 160 | def load_from_checkpoint( 161 | checkpoint_path: str, 162 | config: BertConfig, 163 | diffusion_model: GaussianDiffusion, 164 | num_labels: int = 2, 165 | ): 166 | model = DiffusionBertForSequenceClassification(config, diffusion_model, num_labels) 167 | model.load_state_dict(torch.load(checkpoint_path), strict=False) 168 | return model 169 | 170 | def train_classifier_on_diffusion_latents(): 171 | 172 | # Step 1: load the arguments 173 | args = get_training_args() 174 | 175 | # Step 2: load the model and diffusion 176 | model, diffusion = create_model_and_diffusion( 177 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 178 | ) 179 | model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu")) 180 | 181 | tokenizer = create_tokenizer( 182 | return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/" 183 | ) 184 | 185 | # Step 3: load the data 186 | dataloader = get_dataloader( 187 | path=f"data/{args.dataset}/{args.dataset}_labeled.tsv", tokenizer=tokenizer, 188 | max_seq_len=args.sequence_len 189 | 190 | ) 191 | 192 | # Step 4: create the classifier 193 | config = BertConfig.from_pretrained("bert-base-uncased") 194 | config.train_diffusion_steps = args.diffusion_steps 195 | config.embedding_dim = args.in_channel 196 | config.vocab_size = tokenizer.vocab_size 197 | 198 | model = DiffusionBertForSequenceClassification( 199 | config=config, num_labels=2, diffusion_model=diffusion 200 | ).to(device) 201 | 202 | # Step 5: train the classifier 203 | 204 | model = training_loop(model=model, dataloader=dataloader, num_epochs=args.classifier_num_epochs) 205 | 206 | # Step 6: save the model 207 | torch.save(model.state_dict(), f"{args.checkpoint_path}/classifier.pt") 208 | 209 | 210 | def get_dataloader(path, tokenizer, max_seq_len: int, batch_size=32): 211 | dataset = TextDataset(data_path=path, has_labels=True, tokenizer=tokenizer) 212 | return DataLoader( 213 | dataset, batch_size=batch_size, shuffle=True, 214 | collate_fn=partial(TextDataset.collate_pad, cutoff=max_seq_len) 215 | ) 216 | 217 | 218 | def training_loop(model, dataloader, num_epochs: int, lr: float = 1e-5): 219 | from transformers import AdamW 220 | 221 | optimizer = AdamW(model.parameters(), lr=lr) 222 | for epoch_idx in range(num_epochs): 223 | epoch_loss = 0.0 224 | num_batches = 0 225 | for batch_idx, (_, batch) in enumerate(dataloader): 226 | optimizer.zero_grad() 227 | batch_input_ids = batch["input_ids"].to(device) 228 | batch_labels = batch["labels"].to(device) 229 | 230 | outputs = model(input_ids=batch_input_ids.to(device), 
labels=batch_labels.to(device)) 231 | outputs.loss.backward() 232 | optimizer.step() 233 | 234 | epoch_loss += outputs.loss.item() 235 | num_batches += 1 236 | 237 | print(f"Epoch {epoch_idx}: {epoch_loss / num_batches:.2f}") 238 | return model 239 | 240 | 241 | def get_training_args(): 242 | 243 | args = create_argparser().parse_args() 244 | args.checkpoint_path = os.path.split(args.model_name_or_path)[0] 245 | with open(f"{args.checkpoint_path}/training_args.json", "r") as f: 246 | training_args = json.load(f) 247 | 248 | # we want to retain defaults for some arguments 249 | training_args["batch_size"] = args.batch_size 250 | training_args["model_name_or_path"] = args.model_name_or_path 251 | training_args["clamp"] = args.clamp 252 | training_args["out_dir"] = args.out_dir 253 | training_args["num_samples"] = args.num_samples 254 | 255 | args.__dict__.update(training_args) 256 | return args 257 | 258 | 259 | ### TESTING ### 260 | 261 | 262 | class StubDiffusionModel(nn.Module): 263 | def __init__(self): 264 | super().__init__() 265 | 266 | def q_sample(self, x_start, t): 267 | return x_start 268 | 269 | 270 | def unit_data_for_text_classification(train_epochs=50): 271 | # generate 100 sentences talking about food both good and bad. Also generate the label 1 or 0 for good or bad 272 | dishes = [ 273 | "pancakes", 274 | "pizza", 275 | "pasta", 276 | "salad", 277 | "steak", 278 | "chicken", 279 | "fish", 280 | "soup", 281 | "ice cream", 282 | "cake", 283 | "pie", 284 | "cookies", 285 | "brownies", 286 | "sushi", 287 | "ramen", 288 | "tacos", 289 | "burritos", 290 | "sandwiches", 291 | "waffles", 292 | "french fries", 293 | "chips", 294 | "popcorn", 295 | "chocolate", 296 | "candy", 297 | "ice cream", 298 | "milkshake", 299 | "coffee", 300 | "tea", 301 | "juice", 302 | "water", 303 | "beer", 304 | "wine", 305 | "soda", 306 | "milk", 307 | "eggs", 308 | "bacon", 309 | "sausage", 310 | "cheese", 311 | "bread", 312 | "rice", 313 | "beans", 314 | "potatoes", 315 | "carrots", 316 | "broccoli", 317 | ] 318 | 319 | positive_adjectives = [ 320 | "delicious", 321 | "superb", 322 | "amazing", 323 | "fantastic", 324 | "great", 325 | "good", 326 | "nice", 327 | "yummy", 328 | "tasty", 329 | ] 330 | negative_adjectives = [ 331 | "disgusting", 332 | "awful", 333 | "terrible", 334 | "horrible", 335 | "bad", 336 | "gross", 337 | "nasty", 338 | "yucky", 339 | "icky", 340 | ] 341 | 342 | # set random seed 343 | import random 344 | 345 | random.seed(0) 346 | 347 | # generate 100 positive sentences 348 | sentences, labels = [], [] 349 | for _ in range(100): 350 | sentences.append(f"The {random.choice(dishes)} was {random.choice(positive_adjectives)}.") 351 | labels.append(1) 352 | 353 | for _ in range(100): 354 | sentences.append(f"The {random.choice(dishes)} was {random.choice(negative_adjectives)}.") 355 | labels.append(0) 356 | 357 | data = pd.DataFrame({"sentence": sentences, "label": labels}) 358 | 359 | data.to_csv("ipynb/food_reviews.csv", index=False) 360 | 361 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 362 | 363 | data = data.sample(frac=1).reset_index(drop=True) 364 | sentences = data["sentence"].tolist() 365 | labels = data["label"].tolist() 366 | 367 | input_ids = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")[ 368 | "input_ids" 369 | ] 370 | 371 | config = BertConfig.from_pretrained("bert-base-uncased") 372 | config.train_diffusion_steps = 200 373 | config.embedding_dim = 128 374 | 375 | model = DiffusionBertForSequenceClassification( 376 | 
config=config, num_labels=2, diffusion_model=StubDiffusionModel() 377 | ).to(device) 378 | 379 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) 380 | 381 | for epoch_idx in range(train_epochs): 382 | batch_size = 32 383 | epoch_loss = 0.0 384 | num_batches = 0 385 | for i in range(0, len(input_ids), batch_size): 386 | optimizer.zero_grad() 387 | batch_input_ids = input_ids[i : i + batch_size] 388 | batch_labels = torch.tensor(labels[i : i + batch_size]) 389 | outputs = model(input_ids=batch_input_ids.to(device), labels=batch_labels.to(device)) 390 | outputs.loss.backward() 391 | optimizer.step() 392 | epoch_loss += outputs.loss.item() 393 | num_batches += 1 394 | 395 | print(f"Epoch {epoch_idx}: {epoch_loss / num_batches:.2f}") 396 | 397 | # sanity test for overfitting 398 | 399 | is_corr = 0 400 | 401 | print("Sanity test for overfitting") 402 | for i, row in data.iterrows(): 403 | inferred_label = get_label_from_sentence( 404 | sentence=row["sentence"], model=model, tokenizer=tokenizer 405 | ) 406 | is_corr += int(inferred_label == row["label"]) 407 | 408 | frac_corr = is_corr / len(data) 409 | 410 | print(f"Accuracy: {frac_corr:.2f}") 411 | 412 | assert frac_corr > 0.99, f"Fraction correct is {frac_corr}, should be > 0.9" 413 | 414 | print("Sanity test passed!") 415 | 416 | # test on new data 417 | test_sentences = [ 418 | "The pancakes were delicious.", 419 | "The pancakes were disgusting.", 420 | "The smoothie was terrible.", 421 | "The smoothie was great.", 422 | ] 423 | 424 | for sentence in test_sentences: 425 | inferred_label = get_label_from_sentence( 426 | sentence=sentence, model=model, tokenizer=tokenizer 427 | ) 428 | print(f"{sentence} -> {inferred_label}") 429 | 430 | 431 | def get_label_from_sentence(model, sentence, tokenizer): 432 | ids = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt")["input_ids"] 433 | return model(ids.to(device)).logits.argmax().item() 434 | 435 | 436 | if __name__ == "__main__": 437 | import sys 438 | 439 | if sys.argv[1] == "run_unit_tests": 440 | unit_data_for_text_classification() 441 | 442 | else: 443 | train_classifier_on_diffusion_latents() 444 | 445 | 446 | def txt_to_jsonl(basedir): 447 | sentences, labels = [], [] 448 | with open(f"{basedir}/train.text", "r") as f: 449 | for line in f: 450 | sentences.append(line.strip()) 451 | with open(f"{basedir}/train.labels", "r") as f: 452 | for line in f: 453 | labels.append(int(line.strip())) 454 | return pd.DataFrame({"sentence": sentences, "label": labels}) 455 | -------------------------------------------------------------------------------- /src/controllable/controllable_text_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
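In this repository the same loop is reused for classifier-guided sequence generation: latents are drawn with p_sample_loop under a langevin_fn built from the trained DiffusionBertForSequenceClassification, mapped back to tokens through model.get_logits with a top-1 selection, decoded by the tokenizer, and written to a .txt.ctrl file by write_outputs below.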
4 | """ 5 | import os, json 6 | import sys 7 | from typing import List 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | from transformers import set_seed 12 | from functools import partial 13 | from src.utils import dist_util, logger 14 | 15 | 16 | from src.utils.args_utils import * 17 | from train_infer.factory_methods import create_model_and_diffusion 18 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 19 | from src.utils.custom_tokenizer import create_tokenizer 20 | from src.controllable.langevin import langevin_binary_classifier 21 | from src.controllable.classifier import DiffusionBertForSequenceClassification 22 | 23 | 24 | def main(): 25 | 26 | args = create_argparser().parse_args() 27 | 28 | set_seed(args.seed) 29 | dist_util.setup_dist() 30 | logger.configure() 31 | 32 | # load configurations. 33 | args.checkpoint_path = os.path.split(args.model_name_or_path)[0] 34 | 35 | config_path = os.path.join(args.checkpoint_path, "training_args.json") 36 | training_args = read_training_args(config_path) 37 | training_args["batch_size"] = args.batch_size 38 | # overwrite this because we want to allow generation for any diffusion step. 39 | training_args["diffusion_steps"] = args.diffusion_steps 40 | training_args["model_name_or_path"] = args.model_name_or_path 41 | training_args["clamp"] = args.clamp 42 | training_args["out_dir"] = args.out_dir 43 | training_args["num_samples"] = args.num_samples 44 | 45 | args.__dict__.update(training_args) 46 | args.sigma_small = True 47 | 48 | logger.info(f"Init pretrained = {args.init_pretrained}") 49 | logger.info(f"Freeze embeddings = {args.freeze_embeddings}") 50 | logger.info(f"Use pretrained embeddings = {args.use_pretrained_embeddings}") 51 | 52 | model, diffusion = create_model_and_diffusion( 53 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 54 | ) 55 | model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu")) 56 | model.eval() 57 | 58 | tokenizer = create_tokenizer( 59 | return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/" 60 | ) 61 | 62 | model.config.update({"embedding_dim": args.in_channel}) 63 | model.config.update({"train_diffusion_steps": args.diffusion_steps}) 64 | model.config.update({"vocab_size": tokenizer.vocab_size}) 65 | 66 | classifier = DiffusionBertForSequenceClassification.load_from_checkpoint( 67 | checkpoint_path=args.checkpoint_path + "/classifier.pt", 68 | config=model.config, 69 | diffusion_model=diffusion, 70 | ).to("cuda") 71 | 72 | # freeze the classifier 73 | for param in classifier.parameters(): 74 | param.requires_grad = False 75 | 76 | langevin_classifier_wrapper = partial(langevin_binary_classifier, classifier=classifier) 77 | 78 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 79 | logger.log(f"the parameter count is {pytorch_total_params}") 80 | 81 | diffusion.rescale_timesteps = True 82 | 83 | model.to(dist_util.dev()) 84 | model.eval() # DEBUG 85 | 86 | logger.log(f"Generating {args.num_samples} samples") 87 | logger.log(f"Clamping is set to {args.clamp}") 88 | all_samples = [] 89 | while len(all_samples) * args.batch_size < args.num_samples: 90 | model_kwargs = {} 91 | sample_shape = (args.batch_size, args.sequence_len, model.word_embedding.weight.shape[1]) 92 | sample = diffusion.p_sample_loop( 93 | model, 94 | sample_shape, 95 | clip_denoised=args.clip_denoised, 96 | denoised_fn=None, 97 | model_kwargs=model_kwargs, 98 | 
top_p=args.top_p, 99 | progress=True, 100 | tokenizer=tokenizer, 101 | log_verbose=True, 102 | langevin_fn=langevin_classifier_wrapper, 103 | ) 104 | 105 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 106 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 107 | all_samples.extend([sample.cpu().numpy() for sample in gathered_samples]) 108 | 109 | logger.log(f"created {len(all_samples)} samples") 110 | 111 | arr = np.concatenate(all_samples, axis=0) 112 | arr = arr[: args.num_samples * args.mbr_sample] 113 | 114 | x_t = th.tensor(arr).cuda() 115 | 116 | logits = model.get_logits(x_t) # bsz, seqlen, vocab 117 | cands = th.topk(logits, k=1, dim=-1) 118 | 119 | decoded_sentences = [] 120 | 121 | for seq in cands.indices: 122 | decoded_sentence = tokenizer.decode(seq.squeeze(1).tolist()) 123 | decoded_sentences.append(decoded_sentence) 124 | 125 | dist.barrier() 126 | logger.log("sampling complete") 127 | 128 | write_outputs(args=args, sentences=decoded_sentences) 129 | 130 | 131 | def load_embeddings(checkpoint_path, tokenizer, emb_dim): 132 | embeddings = th.nn.Embedding(tokenizer.vocab_size, emb_dim) 133 | embeddings.load_state_dict(th.load(f"{checkpoint_path}/random_emb.torch")) 134 | return embeddings 135 | 136 | 137 | def read_training_args(config_path): 138 | with open(config_path, "r") as f: 139 | return json.load(f) 140 | 141 | 142 | def write_outputs(args: dict, sentences: List[str]) -> None: 143 | 144 | model_dir = os.path.split(args.model_name_or_path)[0] 145 | model_base_name = os.path.split(args.model_name_or_path)[1] 146 | 147 | num_samples = len(sentences) 148 | output_file_basepath = ( 149 | os.path.join( 150 | model_dir, 151 | f"{model_base_name}.samples_{num_samples}.steps-{args.diffusion_steps}.clamp-{args.clamp}", 152 | ) 153 | + ".txt.ctrl" 154 | ) 155 | 156 | with open(output_file_basepath, "w") as text_fout: 157 | for generated_sentence in sentences: 158 | text_fout.write(generated_sentence + "\n") 159 | 160 | print(f"written the decoded output to {output_file_basepath}") 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /src/controllable/langevin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilizes a trained classifier model to guide the diffusion process. 3 | 4 | - Given: 5 | 1. input embeddings 6 | 2. A classifier model 7 | 3. Labels 8 | 9 | The classifier model is used to refine the input embeddings such that the logits of the classifier model are maximized for the labels. 10 | """ 11 | import torch 12 | 13 | 14 | def langevin_binary_classifier(classifier, label_ids, x_t, t, num_langevin_steps: int = 1, step_size: float=1e-2): # current best. 
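"""Gradient-guidance refinement of the latents (see the module docstring above): x_t is wrapped as a torch.nn.Parameter and updated with Adagrad (lr=step_size) for num_langevin_steps iterations using gradients of the classifier objective for label_ids at timestep t, pushing the latents toward the target labels. The refined latents are returned detached from the autograd graph; despite the name, no explicit Gaussian noise term is injected."""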
15 | 16 | x_t_as_params = torch.nn.Parameter(x_t) 17 | 18 | with torch.enable_grad(): 19 | for i in range(num_langevin_steps): 20 | optimizer = torch.optim.Adagrad([x_t_as_params], lr=step_size) 21 | 22 | optimizer.zero_grad() 23 | model_out = classifier.label_logp(inputs_with_added_noise=x_t_as_params, 24 | labels=label_ids, 25 | t=t) 26 | loss = -model_out.loss # logp 27 | loss.backward() 28 | # print(f"{i}> grad norm: {x_t_as_params.grad.data.norm(2)} | loss: {loss}") 29 | 30 | optimizer.step() 31 | 32 | 33 | x_t_as_params = torch.nn.Parameter(x_t_as_params.data.detach()) 34 | 35 | return x_t_as_params.data.detach() 36 | -------------------------------------------------------------------------------- /src/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AspirinCode/DiffIUPAC/525ffe6850c21d7cafbca94c5c6f971da2a450d4/src/modeling/__init__.py -------------------------------------------------------------------------------- /src/modeling/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AspirinCode/DiffIUPAC/525ffe6850c21d7cafbca94c5c6f971da2a450d4/src/modeling/diffusion/__init__.py -------------------------------------------------------------------------------- /src/modeling/diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | # print(logvar2.shape) 34 | # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2)) 35 | # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}') 36 | 37 | return 0.5 * ( 38 | -1.0 39 | + logvar2 40 | - logvar1 41 | + th.exp(logvar1 - logvar2) 42 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 43 | ) 44 | 45 | 46 | def approx_standard_normal_cdf(x): 47 | """ 48 | A fast approximation of the cumulative distribution function of the 49 | standard normal. 50 | """ 51 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 52 | 53 | 54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 55 | """ 56 | Compute the log-likelihood of a Gaussian distribution discretizing to a 57 | given image. 58 | 59 | :param x: the target images. It is assumed that this was uint8 values, 60 | rescaled to the range [-1, 1]. 61 | :param means: the Gaussian mean Tensor. 62 | :param log_scales: the Gaussian log stddev Tensor. 
63 | :return: a tensor like x of log probabilities (in nats). 64 | """ 65 | assert x.shape == means.shape == log_scales.shape 66 | centered_x = x - means 67 | inv_stdv = th.exp(-log_scales) 68 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 69 | cdf_plus = approx_standard_normal_cdf(plus_in) 70 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 71 | cdf_min = approx_standard_normal_cdf(min_in) 72 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 73 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 74 | cdf_delta = cdf_plus - cdf_min 75 | log_probs = th.where( 76 | x < -0.999, 77 | log_cdf_plus, 78 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 79 | ) 80 | assert log_probs.shape == x.shape 81 | return log_probs 82 | 83 | def gaussian_density(x, *, means, log_scales): 84 | from torch.distributions import Normal 85 | normal_dist = Normal(means, log_scales.exp()) 86 | logp = normal_dist.log_prob(x) 87 | return logp 88 | -------------------------------------------------------------------------------- /src/modeling/diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor, mask = None): 87 | """ 88 | Take the mean over all non-batch dimensions. 
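If mask is given it is broadcast to tensor's shape and the mean is taken only over positions where the mask is non-zero (for example, to exclude padding positions from a loss).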
89 | """ 90 | if mask is None: 91 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 92 | else: 93 | for _ in range(len(tensor.shape)-len(mask.shape)): 94 | mask = mask.unsqueeze(-1) 95 | mask = mask.expand(tensor.shape) 96 | tensor = tensor * mask 97 | return tensor.sum(dim=list(range(1, len(tensor.shape)))) / mask.sum(dim=list(range(1, len(mask.shape)))) 98 | 99 | 100 | 101 | def normalization(channels): 102 | """ 103 | Make a standard normalization layer. 104 | 105 | :param channels: number of input channels. 106 | :return: an nn.Module for normalization. 107 | """ 108 | return GroupNorm32(32, channels) 109 | 110 | 111 | def timestep_embedding(timesteps, dim, max_period=10000): 112 | """ 113 | Create sinusoidal timestep embeddings. 114 | 115 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 116 | These may be fractional. 117 | :param dim: the dimension of the output. 118 | :param max_period: controls the minimum frequency of the embeddings. 119 | :return: an [N x dim] Tensor of positional embeddings. 120 | """ 121 | half = dim // 2 122 | freqs = th.exp( 123 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 124 | ).to(device=timesteps.device) 125 | args = timesteps[:, None].float() * freqs[None] 126 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 127 | if dim % 2: 128 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 129 | return embedding 130 | 131 | def checkpoint(func, inputs, params, flag): 132 | """ 133 | Evaluate a function without caching intermediate activations, allowing for 134 | reduced memory at the expense of extra compute in the backward pass. 135 | 136 | :param func: the function to evaluate. 137 | :param inputs: the argument sequence to pass to `func`. 138 | :param params: a sequence of parameters `func` depends on but does not 139 | explicitly take as arguments. 140 | :param flag: if False, disable gradient checkpointing. 141 | """ 142 | if flag: 143 | args = tuple(inputs) + tuple(params) 144 | return CheckpointFunction.apply(func, len(inputs), *args) 145 | else: 146 | return func(*inputs) 147 | 148 | 149 | class CheckpointFunction(th.autograd.Function): 150 | @staticmethod 151 | def forward(ctx, run_function, length, *args): 152 | ctx.run_function = run_function 153 | ctx.input_tensors = list(args[:length]) 154 | ctx.input_params = list(args[length:]) 155 | with th.no_grad(): 156 | output_tensors = ctx.run_function(*ctx.input_tensors) 157 | return output_tensors 158 | 159 | @staticmethod 160 | def backward(ctx, *output_grads): 161 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 162 | with th.enable_grad(): 163 | # Fixes a bug where the first op in run_function modifies the 164 | # Tensor storage in place, which is not allowed for detach()'d 165 | # Tensors. 
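# x.view_as(x) returns a non-leaf alias that shares storage with the input, so an in-place first op inside run_function acts on the alias rather than on a leaf tensor that requires grad.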
166 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 167 | output_tensors = ctx.run_function(*shallow_copies) 168 | input_grads = th.autograd.grad( 169 | output_tensors, 170 | ctx.input_tensors + ctx.input_params, 171 | output_grads, 172 | allow_unused=True, 173 | ) 174 | del ctx.input_tensors 175 | del ctx.input_params 176 | del output_tensors 177 | return (None, None) + input_grads 178 | -------------------------------------------------------------------------------- /src/modeling/diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | print('using loss-second-moment sampler') 19 | return LossSecondMomentResampler(diffusion) 20 | elif name == 'uniform-sample-second-moment': 21 | print('using uniform-sample-second-moment sampler') 22 | return UniformSamplerSecondMomentWeigth(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | 32 | By default, samplers perform unbiased importance sampling, in which the 33 | objective's mean is unchanged. 34 | However, subclasses may override sample() to change how the resampled 35 | terms are reweighted, allowing for actual changes in the objective. 36 | """ 37 | 38 | @abstractmethod 39 | def weights(self): 40 | """ 41 | Get a numpy array of weights, one per diffusion step. 42 | 43 | The weights needn't be normalized, but must be positive. 44 | """ 45 | 46 | def sample(self, batch_size, device): 47 | """ 48 | Importance-sample timesteps for a batch. 49 | 50 | :param batch_size: the number of timesteps. 51 | :param device: the torch device to save to. 52 | :return: a tuple (timesteps, weights): 53 | - timesteps: a tensor of timestep indices. 54 | - weights: a tensor of weights to scale the resulting losses. 55 | """ 56 | w = self.weights() 57 | p = w / np.sum(w) 58 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 59 | indices = th.from_numpy(indices_np).long().to(device) 60 | weights_np = 1 / (len(p) * p[indices_np]) 61 | weights = th.from_numpy(weights_np).float().to(device) 62 | return indices, weights 63 | 64 | 65 | class UniformSampler(ScheduleSampler): 66 | def __init__(self, diffusion): 67 | self.diffusion = diffusion 68 | self._weights = np.ones([diffusion.num_timesteps]) 69 | 70 | def weights(self): 71 | return self._weights 72 | 73 | 74 | class LossAwareSampler(ScheduleSampler): 75 | def update_with_local_losses(self, local_ts, local_losses): 76 | """ 77 | Update the reweighting using losses from a model. 78 | 79 | Call this method from each rank with a batch of timesteps and the 80 | corresponding losses for each of those timesteps. 81 | This method will perform synchronization to make sure all of the ranks 82 | maintain the exact same reweighting. 83 | 84 | :param local_ts: an integer Tensor of timesteps. 85 | :param local_losses: a 1D Tensor of losses. 
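        The per-rank batches are padded to the largest batch size and all-gathered, so every rank ends up calling update_with_all_losses() with the same globally collected (timestep, loss) lists.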
86 | """ 87 | batch_sizes = [ 88 | th.tensor([0], dtype=th.int32, device=local_ts.device) 89 | for _ in range(dist.get_world_size()) 90 | ] 91 | dist.all_gather( 92 | batch_sizes, 93 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 94 | ) 95 | 96 | # Pad all_gather batches to be the maximum batch size. 97 | batch_sizes = [x.item() for x in batch_sizes] 98 | max_bs = max(batch_sizes) 99 | 100 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 101 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 102 | dist.all_gather(timestep_batches, local_ts) 103 | dist.all_gather(loss_batches, local_losses) 104 | timesteps = [ 105 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 106 | ] 107 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 108 | self.update_with_all_losses(timesteps, losses) 109 | 110 | @abstractmethod 111 | def update_with_all_losses(self, ts, losses): 112 | """ 113 | Update the reweighting using losses from a model. 114 | 115 | Sub-classes should override this method to update the reweighting 116 | using losses from the model. 117 | 118 | This method directly updates the reweighting without synchronizing 119 | between workers. It is called by update_with_local_losses from all 120 | ranks with identical arguments. Thus, it should have deterministic 121 | behavior to maintain state across workers. 122 | 123 | :param ts: a list of int timesteps. 124 | :param losses: a list of float losses, one per timestep. 125 | """ 126 | 127 | 128 | class LossSecondMomentResampler(LossAwareSampler): 129 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 130 | self.diffusion = diffusion 131 | self.history_per_term = history_per_term 132 | self.uniform_prob = uniform_prob 133 | self._loss_history = np.zeros( 134 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 135 | ) 136 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 137 | 138 | def weights(self): 139 | if not self._warmed_up(): 140 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 141 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 142 | weights /= np.sum(weights) 143 | weights *= 1 - self.uniform_prob 144 | weights += self.uniform_prob / len(weights) 145 | return weights 146 | 147 | def update_with_all_losses(self, ts, losses): 148 | for t, loss in zip(ts, losses): 149 | if self._loss_counts[t] == self.history_per_term: 150 | # Shift out the oldest loss term. 151 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 152 | self._loss_history[t, -1] = loss 153 | else: 154 | self._loss_history[t, self._loss_counts[t]] = loss 155 | self._loss_counts[t] += 1 156 | 157 | def _warmed_up(self): 158 | return (self._loss_counts == self.history_per_term).all() 159 | 160 | 161 | class UniformSamplerSecondMomentWeigth(LossSecondMomentResampler): 162 | 163 | def sample(self, batch_size, device): 164 | """ 165 | Importance-sample timesteps for a batch. 166 | 167 | :param batch_size: the number of timesteps. 168 | :param device: the torch device to save to. 169 | :return: a tuple (timesteps, weights): 170 | - timesteps: a tensor of timestep indices. 171 | - weights: a tensor of weights to scale the resulting losses. 
172 | """ 173 | w = self.weights() 174 | p = w / np.sum(w) 175 | 176 | p_samper = 1 / len(p) * np.ones(len(p)) 177 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p_samper) 178 | indices = th.from_numpy(indices_np).long().to(device) 179 | 180 | weights_np = p[indices_np] * batch_size / np.sum(p[indices_np]) 181 | weights = th.from_numpy(weights_np).float().to(device) 182 | return indices, weights 183 | -------------------------------------------------------------------------------- /src/modeling/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from src.modeling.diffusion.gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 
70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 77 | last_alpha_cumprod = 1.0 78 | new_betas = [] 79 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 80 | if i in self.use_timesteps: 81 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 82 | last_alpha_cumprod = alpha_cumprod 83 | self.timestep_map.append(i) 84 | kwargs["betas"] = np.array(new_betas) 85 | super().__init__(**kwargs) 86 | 87 | def p_mean_variance( 88 | self, model, *args, **kwargs 89 | ): # pylint: disable=signature-differs 90 | # print('called p_mean_var') 91 | # print(kwargs.keys()) 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | # print('called training_losses') 98 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 99 | 100 | def _wrap_model(self, model): 101 | if isinstance(model, _WrappedModel_encoder_decoder): 102 | return model 103 | return _WrappedModel_encoder_decoder( 104 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 105 | ) 106 | 107 | def _scale_timesteps(self, t): 108 | # Scaling is done by the wrapped model. 109 | return t 110 | 111 | class _WrappedModel_encoder_decoder: 112 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 113 | self.model = model 114 | self.timestep_map = timestep_map 115 | self.rescale_timesteps = rescale_timesteps 116 | self.original_num_steps = original_num_steps 117 | 118 | def __call__(self, x, ts, **kwargs): 119 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 120 | new_ts = map_tensor[ts] 121 | if self.rescale_timesteps: 122 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 123 | return self.model(decoder_inputs_embeds = x, timesteps = new_ts, **kwargs) 124 | 125 | -------------------------------------------------------------------------------- /src/modeling/diffusion/rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # bert results 3 | 4 | # print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) 5 | # sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling') 6 | # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) 7 | # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise 8 | import json 9 | 10 | def load_embeddings_and_tokenizer(modality=None, mode=None, model_name_or_path=None, emb_dim=None, checkpoint_path=None, extra_args=None): 11 | 12 | path_save_tokenizer = '{}/vocab.json'.format(checkpoint_path) 13 | print(f'loading from {path_save_tokenizer}') 14 | with open(path_save_tokenizer, 'r') as f: 15 | vocab = json.load(f) 16 | print(len(vocab)) 17 | tokenizer = {v: k for k, v in vocab.items()} 18 | model = torch.nn.Embedding(tokenizer.vocab_size, emb_dim) 19 | path_save = '{}/random_emb.torch'.format(checkpoint_path) 20 | model.load_state_dict(torch.load(path_save)) 21 | 22 | return model, tokenizer 23 | 24 | 25 | def load_tokenizer(modality, mode, model_name_or_path): 26 | import json 27 | path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path) 28 | 
with open(path_save_tokenizer, 'r') as f: 29 | vocab = json.load(f) 30 | tokenizer = {v: k for k, v in vocab.items()} 31 | 32 | return tokenizer 33 | 34 | def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0): 35 | decoded_out_lst = [] 36 | if mode in ['random', 'random_up_proj', 'glove']: 37 | down_proj_emb = model.weight # input_embs 38 | down_proj_emb2 = None 39 | 40 | 41 | def get_knn(down_proj_emb, text_emb, dist='cos'): 42 | 43 | if dist == 'cos': 44 | adjacency = down_proj_emb @ text_emb.transpose(1, 0).to(down_proj_emb.device) 45 | elif dist == 'l2': 46 | adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 47 | down_proj_emb.size(0), -1, -1) 48 | adjacency = -torch.norm(adjacency, dim=-1) 49 | topk_out = torch.topk(adjacency, k=6, dim=0) 50 | return topk_out.values, topk_out.indices 51 | 52 | dist = 'l2' 53 | # print(npzfile['arr_0'].shape) 54 | for text_emb in text_emb_lst: 55 | import torch 56 | text_emb = torch.tensor(text_emb) 57 | # print(text_emb.shape) 58 | if len(text_emb.shape) > 2: 59 | text_emb = text_emb.view(-1, text_emb.size(-1)) 60 | else: 61 | text_emb = text_emb 62 | val, indices = get_knn((down_proj_emb2 if dist == 'cos' else down_proj_emb), 63 | text_emb.to(down_proj_emb.device), dist=dist) 64 | # generated_lst.append(tuple(indices[0].tolist())) 65 | 66 | # print(indices[0].tolist()) 67 | # for i in range(64): 68 | # print([tokenizer[x.item()] for x in indices[:,i]]) 69 | decoded_out = " ".join([tokenizer[i] for i in indices[0].tolist()]) 70 | decoded_out_lst.append(decoded_out) 71 | 72 | return decoded_out_lst 73 | 74 | -------------------------------------------------------------------------------- /src/modeling/predictor/transformer_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | from modeling_bart import BartModel 3 | import torch 4 | import torch as th 5 | import torch.nn as nn 6 | from src.modeling.diffusion.nn import ( 7 | SiLU, 8 | linear, 9 | timestep_embedding, 10 | ) 11 | import math 12 | 13 | 14 | class TransformerNetModel_encoder_decoder(nn.Module): 15 | """ 16 | A transformer model to be used in Diffusion Model Training. 17 | 18 | :param in_channels: channels in the input Tensor. 19 | :param model_channels: base channel count for the model. 20 | :param out_channels: channels in the output Tensor. 21 | :param dropout: the dropout probability. 22 | :param channel_mult: channel multiplier for each level of the UNet. 23 | :param dims: determines if the signal is 1D, 2D, or 3D. 24 | :param num_classes: if specified (as an int), then this model will be 25 | class-conditional with `num_classes` classes. TODO for the next version 26 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 27 | :param num_heads: the number of attention heads in each attention layer. 
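    :param init_pretrained: if True, the underlying BartModel is initialised with from_pretrained(config_name); if False, a fresh BartModel is built and encoder_layers, decoder_layers, vocab_size, num_heads, in_channels and model_channels overwrite the corresponding config fields.
    :param load_ckpt: optional path to a saved state dict that is loaded into the module after construction.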
28 | """ 29 | 30 | def __init__( 31 | self, 32 | in_channels, 33 | model_channels, 34 | out_channels, 35 | init_pretrained, 36 | freeze_embeddings, 37 | use_pretrained_embeddings, 38 | dropout=0, 39 | use_checkpoint=False, 40 | num_heads=1, 41 | config=None, 42 | config_name="bert-base-uncased", 43 | vocab_size=None, 44 | logits_mode=1, 45 | encoder_layers = 6, 46 | decoder_layers = 6, 47 | load_ckpt=None, 48 | ): 49 | super().__init__() 50 | 51 | if config is None: 52 | config = AutoConfig.from_pretrained(config_name) 53 | config.dropout = dropout 54 | # config.hidden_size = 512 55 | 56 | self.in_channels = in_channels 57 | self.model_channels = model_channels 58 | self.out_channels = out_channels 59 | self.dropout = dropout 60 | self.use_checkpoint = use_checkpoint 61 | self.num_heads = num_heads 62 | self.logits_mode = logits_mode 63 | self.vocab_size = vocab_size 64 | self.init_pretrained = init_pretrained 65 | self.freeze_embeddings = freeze_embeddings 66 | self.use_pretrained_embeddings = use_pretrained_embeddings 67 | self.config = config 68 | self.config_name = config_name 69 | self.load_ckpt = load_ckpt 70 | 71 | if not self.init_pretrained: 72 | self.config.encoder_layers = encoder_layers 73 | self.config.decoder_layers = decoder_layers 74 | self.config.vocab_size = vocab_size 75 | self.config.encoder_attention_heads = num_heads 76 | self.config.decoder_attention_heads = num_heads 77 | self.config.d_model = in_channels 78 | self.config.encoder_ffn_dim = model_channels 79 | self.config.decoder_ffn_dim = model_channels 80 | self.embedding_dim = 128 #self.config.d_model // 4 81 | self.embed_scale = math.sqrt(self.embedding_dim) if self.config.scale_embedding else 1.0 82 | 83 | time_embed_dim = in_channels 84 | self.time_embed = nn.Sequential( 85 | linear(in_channels, time_embed_dim), 86 | SiLU(), 87 | linear(time_embed_dim, config.d_model), 88 | ) 89 | 90 | 91 | self.build_xstart_predictor() 92 | self.build_input_output_projections() 93 | self.build_embeddings() 94 | 95 | self.LayerNorm = nn.LayerNorm(config.d_model) 96 | self.dropout = nn.Dropout(config.dropout) 97 | 98 | if self.load_ckpt is not None: 99 | self.load_weight(self.load_ckpt) 100 | 101 | def get_embeds(self, input_ids): 102 | return self.input_transformers.decoder.embed_tokens(input_ids) * self.embed_scale 103 | 104 | def load_weight(self, path): 105 | 106 | self.load_state_dict(torch.load(path)) 107 | print(f'weigth initialize from {path}') 108 | 109 | def build_xstart_predictor(self): 110 | if self.init_pretrained: 111 | 112 | temp_bart = BartModel.from_pretrained(self.config_name, config=self.config) 113 | self.input_transformers = temp_bart 114 | else: 115 | self.input_transformers = BartModel(self.config, self.embedding_dim) 116 | 117 | def build_input_output_projections(self): 118 | if self.in_channels != self.embedding_dim: 119 | # need to adapt the model to the embedding size 120 | self.input_up_proj_dec = nn.Sequential( 121 | nn.Linear(self.embedding_dim * 2, self.config.d_model), 122 | nn.Tanh(), 123 | nn.Linear(self.config.d_model, self.config.d_model), 124 | ) 125 | 126 | self.input_up_proj_enc = nn.Sequential( 127 | nn.Linear(self.embedding_dim, self.config.d_model), 128 | nn.Tanh(), 129 | nn.Linear(self.config.d_model, self.config.d_model), 130 | ) 131 | 132 | self.output_down_proj = nn.Sequential( 133 | nn.Linear(self.config.d_model, self.config.d_model), 134 | nn.Tanh(), 135 | nn.Linear(self.config.d_model, self.embedding_dim), 136 | ) 137 | else: 138 | self.input_up_proj = nn.Identity() 139 | 
self.output_down_proj = nn.Identity() 140 | 141 | 142 | def build_embeddings(self): 143 | 144 | self.lm_head = nn.Linear(self.embedding_dim, self.input_transformers.shared.weight.shape[0]) 145 | 146 | with th.no_grad(): 147 | self.lm_head.weight = self.input_transformers.shared.weight 148 | 149 | def get_logits(self, hidden_repr): 150 | return self.lm_head(hidden_repr) 151 | 152 | def forward_encoder(self, 153 | input_ids = None, 154 | timesteps = None, 155 | attention_mask = None, 156 | decoder_inputs_embeds = None, 157 | decoder_attention_mask = None, 158 | self_conditions = None, 159 | ): 160 | """ 161 | Apply the model to an input batch. 162 | 163 | :param x: an [N x C x ...] Tensor of inputs. 164 | :param timesteps: a 1-D batch of timesteps. 165 | :param y: an [N] Tensor of labels, if class-conditional. 166 | :return: an [N x C x ...] Tensor of outputs. 167 | """ 168 | 169 | emb = self.time_embed(timestep_embedding(timesteps, self.in_channels)) 170 | seq_length = decoder_inputs_embeds.size(1) 171 | if len(emb.shape) < 3: 172 | emb = emb.unsqueeze(1).expand(-1, seq_length, -1) 173 | # decoder_inputs_embeds = self.input_transformers.decoder.embed_tokens(decoder_input_ids) * self.embed_scale 174 | if self_conditions is not None: 175 | 176 | decoder_inputs_embeds = th.concat((decoder_inputs_embeds, self_conditions), dim = -1) 177 | 178 | decoder_inputs_embeds = ( 179 | self.input_up_proj_dec(decoder_inputs_embeds) 180 | + emb 181 | ) 182 | emb_inputs = self.dropout(self.LayerNorm(decoder_inputs_embeds)) 183 | 184 | encoder_hidden_states = self.input_transformers( 185 | input_ids = None, 186 | attention_mask=attention_mask, 187 | inputs_embeds = self.input_up_proj_enc(self.input_transformers.encoder.embed_tokens(input_ids) * self.embed_scale), 188 | decoder_input_ids=None, 189 | decoder_inputs_embeds=emb_inputs, 190 | decoder_attention_mask=decoder_attention_mask, 191 | output_attentions=True, 192 | ).encoder_last_hidden_state 193 | 194 | return encoder_hidden_states 195 | 196 | def forward(self, 197 | input_ids = None, 198 | timesteps = None, 199 | attention_mask = None, 200 | decoder_inputs_embeds = None, 201 | decoder_attention_mask = None, 202 | self_conditions = None, 203 | encoder_outputs=None, 204 | ): 205 | """ 206 | Apply the model to an input batch. 207 | 208 | :param x: an [N x C x ...] Tensor of inputs. 209 | :param timesteps: a 1-D batch of timesteps. 210 | :param y: an [N] Tensor of labels, if class-conditional. 211 | :return: an [N x C x ...] Tensor of outputs. 
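        In this implementation `decoder_inputs_embeds` carries the (noised) target embedding sequence, `timesteps` the diffusion step per batch element, and `input_ids`/`attention_mask` the source-side condition; `input_ids` and `encoder_outputs` are mutually exclusive ways of supplying the encoder side. If `self_conditions` is given, it is concatenated to the decoder embeddings along the feature dimension before the up-projection.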
212 | """ 213 | assert encoder_outputs is None or input_ids is None 214 | emb = self.time_embed(timestep_embedding(timesteps, self.in_channels)) 215 | seq_length = decoder_inputs_embeds.size(1) 216 | if len(emb.shape) < 3: 217 | emb = emb.unsqueeze(1).expand(-1, seq_length, -1) 218 | if self_conditions is not None: 219 | 220 | decoder_inputs_embeds = th.concat((decoder_inputs_embeds, self_conditions), dim = -1) 221 | 222 | decoder_inputs_embeds = ( 223 | self.input_up_proj_dec(decoder_inputs_embeds) 224 | + emb 225 | ) 226 | emb_inputs = self.dropout(self.LayerNorm(decoder_inputs_embeds)) 227 | 228 | input_trans_hidden_states = self.input_transformers( 229 | input_ids = None, 230 | attention_mask=attention_mask, 231 | inputs_embeds = self.input_up_proj_enc(self.input_transformers.encoder.embed_tokens(input_ids) * self.embed_scale) if input_ids is not None else None, 232 | decoder_input_ids=None, 233 | decoder_inputs_embeds=emb_inputs, 234 | decoder_attention_mask=decoder_attention_mask, 235 | encoder_outputs=encoder_outputs 236 | ).last_hidden_state 237 | 238 | h = self.output_down_proj(input_trans_hidden_states) 239 | 240 | return h 241 | -------------------------------------------------------------------------------- /src/utils/args_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for command line arguments. 3 | """ 4 | 5 | import argparse 6 | 7 | 8 | 9 | def create_argparser(): 10 | defaults = dict( 11 | data_dir="", 12 | schedule_sampler="uniform", 13 | lr=1e-4, 14 | weight_decay=0.0, 15 | lr_anneal_steps=30000, 16 | batch_size=1, 17 | microbatch=-1, # -1 disables microbatches 18 | ema_rate="0.9999", # comma-separated list of EMA values 19 | log_interval=50, 20 | save_interval=25000, 21 | resume_checkpoint="", 22 | use_fp16=False, 23 | fp16_scale_growth=1e-3, 24 | seed=101, 25 | gradient_clipping=-1.0, 26 | eval_interval=2000, 27 | checkpoint_path="diff_models", 28 | train_txt_path="data/quotes_train.txt", 29 | val_txt_path="data/quotes_valid.txt", 30 | dataset="", 31 | notes="", 32 | ) 33 | text_defaults = dict( 34 | modality="text", 35 | emb_scale_factor=1.0, 36 | in_channel=16, 37 | out_channel=16, 38 | noise_level=0.0, 39 | cache_mode="no", 40 | use_bert_tokenizer="no", 41 | padding_mode="block", 42 | preprocessing_num_workers=1, 43 | tok_thresh=150 44 | ) 45 | 46 | guided_generation_defaults = dict( 47 | classifier_num_epochs=15 48 | ) 49 | 50 | defaults.update(model_and_diffusion_defaults()) 51 | defaults.update(text_defaults) 52 | defaults.update(guided_generation_defaults) 53 | defaults.update(decoding_defaults()) 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--debug", action="store_true") 56 | 57 | add_dict_to_argparser(parser, defaults) 58 | return parser 59 | 60 | 61 | def model_and_diffusion_defaults(): 62 | """ 63 | Defaults for text-diffusion model training. 
64 | """ 65 | return dict( 66 | sequence_len=64, 67 | num_channels=16, 68 | num_heads=4, 69 | dropout=0.0, 70 | learn_sigma=False, 71 | sigma_small=False, 72 | class_cond=False, 73 | diffusion_steps=10000, 74 | noise_schedule="linear", 75 | timestep_respacing="", 76 | use_kl=False, 77 | predict_xstart=False, 78 | rescale_timesteps=True, 79 | rescale_learned_sigmas=True, 80 | use_checkpoint=False, 81 | model_arch="transformer", 82 | in_channel=16, 83 | out_channel=16, 84 | vocab_size=66, 85 | config_name="bert-base-uncased", 86 | logits_mode=1, 87 | training_mode="diffusion-lm", 88 | init_pretrained=False, 89 | freeze_embeddings=False, 90 | use_pretrained_embeddings=True, 91 | ) 92 | 93 | 94 | def decoding_defaults(): 95 | return dict( 96 | num_samples=50, 97 | top_p=0.9, 98 | out_dir="", 99 | model_name_or_path="", 100 | checkpoint_path="", 101 | use_ddim=False, 102 | clip_denoised=False, 103 | batch_size=64, 104 | mbr_sample=1, 105 | verbose="yes", 106 | clamp="clamp", 107 | preprocessing_num_workers=1, 108 | emb_scale_factor=1.0, 109 | classifier_path="", 110 | ) 111 | 112 | 113 | def add_dict_to_argparser(parser, default_dict): 114 | for k, v in default_dict.items(): 115 | v_type = type(v) 116 | if v is None: 117 | v_type = str 118 | elif isinstance(v, bool): 119 | v_type = str2bool 120 | parser.add_argument(f"--{k}", default=v, type=v_type) 121 | 122 | 123 | def args_to_dict(args, keys): 124 | return {k: getattr(args, k) for k in keys} 125 | 126 | 127 | def str2bool(v): 128 | """ 129 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 130 | """ 131 | if isinstance(v, bool): 132 | return v 133 | if v.lower() in ("yes", "true", "t", "y", "1"): 134 | return True 135 | elif v.lower() in ("no", "false", "f", "n", "0"): 136 | return False 137 | else: 138 | raise argparse.ArgumentTypeError("boolean value expected") 139 | -------------------------------------------------------------------------------- /src/utils/custom_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pathlib 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | from tokenizers.processors import BertProcessing 8 | from tokenizers import ByteLevelBPETokenizer, decoders 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | def create_tokenizer(return_pretokenized, path, tokenizer_type: str = "word-level"): 13 | if return_pretokenized: 14 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 15 | return tokenizer 16 | 17 | if tokenizer_type == "byte-level": 18 | return read_byte_level(path) 19 | elif tokenizer_type == "word-level": 20 | return read_word_level(path) 21 | else: 22 | raise ValueError(f"Invalid tokenizer type: {tokenizer_type}") 23 | 24 | def train_bytelevel( 25 | path, 26 | vocab_size=10000, 27 | min_frequency=1, 28 | special_tokens=["", "", "", "", ""], 29 | ): 30 | 31 | tokenizer = ByteLevelBPETokenizer() 32 | 33 | # Customize training 34 | tokenizer.train( 35 | files=[path], 36 | vocab_size=vocab_size, 37 | min_frequency=min_frequency, 38 | special_tokens=special_tokens, 39 | ) 40 | 41 | tokenizer.save_model(str(pathlib.Path(path).parent)) 42 | 43 | 44 | 45 | def read_byte_level(path: str): 46 | tokenizer = ByteLevelBPETokenizer( 47 | f"{path}/vocab.json", 48 | f"{path}/merges.txt", 49 | ) 50 | 51 | tokenizer._tokenizer.post_processor = BertProcessing( 52 | ("", tokenizer.token_to_id("")), 53 | ("", tokenizer.token_to_id("")), 54 | ) 55 | 56 | 
tokenizer.enable_truncation(max_length=512) 57 | 58 | print( 59 | tokenizer.encode( 60 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 61 | ).tokens 62 | ) 63 | 64 | with open(f"{path}/vocab.json", "r") as fin: 65 | vocab = json.load(fin) 66 | 67 | # add length method to tokenizer object 68 | tokenizer.vocab_size = len(vocab) 69 | 70 | # add length property to tokenizer object 71 | tokenizer.__len__ = property(lambda self: self.vocab_size) 72 | 73 | tokenizer.decoder = decoders.ByteLevel() 74 | print(tokenizer.vocab_size) 75 | 76 | print( 77 | tokenizer.encode( 78 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 79 | ).ids 80 | ) 81 | 82 | print( 83 | tokenizer.decode( 84 | tokenizer.encode( 85 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 86 | ).ids, 87 | skip_special_tokens=True, 88 | ) 89 | ) 90 | 91 | ids = tokenizer.encode( 92 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 93 | ).ids 94 | tensor = torch.tensor(ids) 95 | print(tokenizer.decode(tensor.tolist(), skip_special_tokens=True)) 96 | print(f"Vocab size: {tokenizer.vocab_size}") 97 | 98 | return tokenizer 99 | 100 | 101 | def read_word_level(path: str): 102 | 103 | from transformers import PreTrainedTokenizerFast 104 | 105 | logging.info(f"Loading tokenizer from {path}/word-level-vocab.json") 106 | tokenizer = PreTrainedTokenizerFast( 107 | tokenizer_file=f"{str(pathlib.Path(path))}/word-level-vocab.json", 108 | bos_token="[CLS]", 109 | eos_token="[SEP]", 110 | unk_token="[UNK]", 111 | sep_token="[SEP]", 112 | pad_token="[PAD]", 113 | cls_token="[CLS]", 114 | mask_token="[MASK]", 115 | padding_side="right", 116 | ) 117 | 118 | # add length property to tokenizer object 119 | tokenizer.__len__ = property(lambda self: self.vocab_size) 120 | 121 | return tokenizer 122 | 123 | 124 | def train_word_level_tokenizer( 125 | path: str, 126 | vocab_size: int = 10000, 127 | special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 128 | ): 129 | 130 | from tokenizers import Tokenizer, normalizers, pre_tokenizers 131 | from tokenizers.models import WordLevel 132 | from tokenizers.normalizers import NFD, Lowercase, StripAccents 133 | from tokenizers.pre_tokenizers import Digits, Whitespace 134 | from tokenizers.processors import TemplateProcessing 135 | from tokenizers.trainers import WordLevelTrainer 136 | 137 | tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 138 | tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()]) 139 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence( 140 | [Digits(individual_digits=True), Whitespace()] 141 | ) 142 | tokenizer.post_processor = TemplateProcessing( 143 | single="[CLS] $A [SEP]", special_tokens=[("[CLS]", 1), ("[SEP]", 2)] 144 | ) 145 | 146 | trainer = WordLevelTrainer(vocab_size=vocab_size, special_tokens=special_tokens) 147 | tokenizer.train(files=[path], trainer=trainer) 148 | 149 | tokenizer.__len__ = property(lambda self: self.vocab_size) 150 | 151 | tokenizer.enable_truncation(max_length=512) 152 | 153 | print(tokenizer.encode("the red.").ids) 154 | 155 | print(tokenizer.encode("the red.")) 156 | 157 | tokenizer.save(f"{str(pathlib.Path(path).parent)}/word-level-vocab.json") 158 | 159 | 160 | if __name__ == "__main__": 161 | import sys 162 | 163 | if 
sys.argv[1] == "train-word-level": 164 | train_word_level_tokenizer(path=sys.argv[2]) 165 | elif sys.argv[1] == "train-byte-level": 166 | train_bytelevel(path=sys.argv[2]) 167 | elif sys.argv[1] == "create": 168 | create_tokenizer(path=sys.argv[2]) 169 | -------------------------------------------------------------------------------- /src/utils/data_utils_sentencepiece.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import pandas as pd 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch 6 | from functools import partial 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | # BAD: this should not be global 11 | # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 12 | 13 | 14 | 15 | 16 | def get_dataloader(tokenizer, data_path, batch_size, max_seq_len): 17 | dataset = TextDataset(tokenizer=tokenizer, data_path=data_path) 18 | 19 | dataloader = DataLoader( 20 | dataset, 21 | batch_size=batch_size, # 20, 22 | drop_last=True, 23 | shuffle=True, 24 | num_workers=1, 25 | collate_fn=partial(TextDataset.collate_pad, cutoff=max_seq_len), 26 | ) 27 | 28 | while True: 29 | for batch in dataloader: 30 | yield batch 31 | 32 | 33 | class TextDataset(Dataset): 34 | def __init__( 35 | self, 36 | tokenizer, 37 | data_path: str, 38 | has_labels: bool = False 39 | ) -> None: 40 | super().__init__() 41 | self.data_path = data_path 42 | self.tokenizer = tokenizer 43 | self.read_data() 44 | if has_labels: 45 | self.read_labels() 46 | 47 | def read_data(self): 48 | logging.info("Reading data from {}".format(self.data_path)) 49 | data = pd.read_csv(self.data_path, sep="\t", header=None) # read text file 50 | logging.info(f"Tokenizing {len(data)} sentences") 51 | 52 | self.text = data[0].apply(lambda x: x.strip()).tolist() 53 | # encoded_input = self.tokenizer(self.questions, self.paragraphs) 54 | 55 | # check if tokenizer has a method 'encode_batch' 56 | if hasattr(self.tokenizer, 'encode_batch'): 57 | 58 | encoded_input = self.tokenizer.encode_batch(self.text) 59 | self.input_ids = [x.ids for x in encoded_input] 60 | 61 | else: 62 | encoded_input = self.tokenizer(self.text) 63 | self.input_ids = encoded_input["input_ids"] 64 | 65 | def read_labels(self): 66 | self.labels = pd.read_csv(self.data_path, sep="\t", header=None)[1].tolist() 67 | # check if labels are already numerical 68 | self.labels = [str(x) for x in self.labels] 69 | if isinstance(self.labels[0], int): 70 | return 71 | # if not, convert to numerical 72 | all_labels = sorted(list(set(self.labels))) 73 | self.label_to_idx = {label: i for i, label in enumerate(all_labels)} 74 | self.idx_to_label = {i: label for i, label in self.label_to_idx.items()} 75 | self.labels = [self.label_to_idx[label] for label in self.labels] 76 | 77 | 78 | 79 | def __len__(self) -> int: 80 | return len(self.text) 81 | 82 | def __getitem__(self, i): 83 | out_dict = { 84 | "input_ids": self.input_ids[i], 85 | # "attention_mask": [1] * len(self.input_ids[i]), 86 | } 87 | if hasattr(self, "labels"): 88 | out_dict["label"] = self.labels[i] 89 | return out_dict 90 | 91 | @staticmethod 92 | def collate_pad(batch, cutoff: int): 93 | max_token_len = 0 94 | num_elems = len(batch) 95 | # batch[0] -> __getitem__[0] --> returns a tuple (embeddings, out_dict) 96 | 97 | for i in range(num_elems): 98 | max_token_len = max(max_token_len, len(batch[i]["input_ids"])) 99 | 100 | max_token_len = min(cutoff, max_token_len) 101 | 102 | tokens = torch.zeros(num_elems, 
max_token_len).long() 103 | tokens_mask = torch.zeros(num_elems, max_token_len).long() 104 | 105 | has_labels = False 106 | if "label" in batch[0]: 107 | labels = torch.zeros(num_elems).long() 108 | has_labels = True 109 | 110 | for i in range(num_elems): 111 | toks = batch[i]["input_ids"] 112 | length = len(toks) 113 | tokens[i, :length] = torch.LongTensor(toks) 114 | tokens_mask[i, :length] = 1 115 | if has_labels: 116 | labels[i] = batch[i]["label"] 117 | 118 | # TODO: the first return None is just for backward compatibility -- can be removed 119 | if has_labels: 120 | return None, {"input_ids": tokens, "attention_mask": tokens_mask, "labels": labels} 121 | else: 122 | return None, {"input_ids": tokens, "attention_mask": tokens_mask} 123 | -------------------------------------------------------------------------------- /src/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 10 #8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | if MPI.COMM_WORLD.Get_rank() == 0: 59 | with bf.BlobFile(path, "rb") as f: 60 | data = f.read() 61 | else: 62 | data = None 63 | data = MPI.COMM_WORLD.bcast(data) 64 | return th.load(io.BytesIO(data), **kwargs) 65 | 66 | 67 | def sync_params(params): 68 | """ 69 | Synchronize a sequence of Tensors across ranks from rank 0. 70 | """ 71 | for p in params: 72 | with th.no_grad(): 73 | dist.broadcast(p, 0) 74 | 75 | 76 | def _find_free_port(): 77 | try: 78 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 79 | s.bind(("", 0)) 80 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 81 | return s.getsockname()[1] 82 | finally: 83 | s.close() 84 | -------------------------------------------------------------------------------- /src/utils/eval_ppl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluates perplexity of a model on a dataset. 
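Each input line is scored with a small causal LM (distilgpt2), and per-sample as well as mean perplexities are reported.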
3 | Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py 4 | """ 5 | 6 | from itertools import chain 7 | from typing import List 8 | import torch 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn import CrossEntropyLoss 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | model_id = "distilgpt2" 18 | model = AutoModelForCausalLM.from_pretrained(model_id) 19 | model = model.to(device) 20 | 21 | tokenizer = AutoTokenizer.from_pretrained(model_id) 22 | 23 | 24 | def compute_perplexity(data, batch_size: int = 16, add_start_token: bool = True): 25 | 26 | 27 | 28 | # if batch_size > 1 (which generally leads to padding being required), and 29 | # if there is not an already assigned pad_token, assign an existing 30 | # special token to also be the padding token 31 | if tokenizer.pad_token is None and batch_size > 1: 32 | existing_special_tokens = list(tokenizer.special_tokens_map_extended.values()) 33 | # check that the model already has at least one special token defined 34 | assert ( 35 | len(existing_special_tokens) > 0 36 | ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1." 37 | # assign one of the special tokens to also be the pad token 38 | tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]}) 39 | 40 | if add_start_token: 41 | # leave room for token to be added: 42 | assert ( 43 | tokenizer.bos_token is not None 44 | ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False" 45 | max_tokenized_len = model.config.max_length - 1 46 | else: 47 | max_tokenized_len = model.config.max_length 48 | 49 | encodings = tokenizer( 50 | data, 51 | add_special_tokens=False, 52 | padding=True, 53 | truncation=True, 54 | max_length=max_tokenized_len, 55 | return_tensors="pt", 56 | return_attention_mask=True, 57 | ).to(device) 58 | 59 | encoded_texts = encodings["input_ids"] 60 | attn_masks = encodings["attention_mask"] 61 | 62 | # check that each input is long enough: 63 | if add_start_token: 64 | assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long." 65 | else: 66 | assert torch.all( 67 | torch.ge(attn_masks.sum(1), 2) 68 | ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings." 
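    # Per batch: run the causal LM, apply the standard next-token shift (drop the last-position logits and the first label), mask out padding, and take ppl_i = exp( sum_t mask_t * CE_t / sum_t mask_t ), i.e. the exponentiated masked mean negative log-likelihood of each sequence.
    # Usage sketch (hypothetical inputs):
    #   compute_perplexity(["CC(=O)OC1=CC=CC=C1C(=O)O", "c1ccccc1O"], batch_size=2)
    # returns {"perplexities": [...], "mean_perplexity": ...}.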
69 | 70 | ppls = [] 71 | loss_fct = CrossEntropyLoss(reduction="none") 72 | 73 | for start_index in tqdm(range(0, len(encoded_texts), batch_size)): 74 | end_index = min(start_index + batch_size, len(encoded_texts)) 75 | encoded_batch = encoded_texts[start_index:end_index] 76 | attn_mask = attn_masks[start_index:end_index] 77 | 78 | if add_start_token: 79 | bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device) 80 | encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1) 81 | attn_mask = torch.cat( 82 | [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1 83 | ) 84 | 85 | labels = encoded_batch 86 | 87 | with torch.no_grad(): 88 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits 89 | 90 | shift_logits = out_logits[..., :-1, :].contiguous() 91 | shift_labels = labels[..., 1:].contiguous() 92 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 93 | 94 | perplexity_batch = torch.exp( 95 | (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) 96 | / shift_attention_mask_batch.sum(1) 97 | ) 98 | 99 | ppls += perplexity_batch.tolist() 100 | 101 | return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)} 102 | 103 | 104 | def calculate_perplexity_for_file(path: str): 105 | # read lines 106 | special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[MASK]", "", ""] 107 | with open(path, "r") as f: 108 | lines = f.readlines() 109 | 110 | lines = [remove_all(line.strip(), special_tokens) for line in lines if len(line.strip()) > 0] 111 | try: 112 | num_unique_lines = len(set(lines)) 113 | perc_unique_lines = round(num_unique_lines * 100 / len(lines), 2) 114 | all_tokens = list(chain(*[line.split() for line in lines])) 115 | perc_unique_tokens = round(len(set(all_tokens)) * 100 / len(all_tokens), 2) 116 | 117 | return {"data": lines, "ppl": compute_perplexity(lines)['mean_perplexity'], "perc_unique_lines": perc_unique_lines, "perc_unique_tokens": perc_unique_tokens} 118 | except Exception as e: 119 | return {"data": [], "ppl": 1e6, "perc_unique_lines": 0, "perc_unique_tokens": 0} 120 | 121 | 122 | def remove_all(line: str, special_toks: List[str]) -> str: 123 | for tok in special_toks: 124 | line = line.replace(tok, "").strip() 125 | return line 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import glob 131 | from tqdm import tqdm 132 | from pprint import pprint 133 | import json 134 | files = glob.glob(sys.argv[1]) 135 | res = dict() 136 | for file in tqdm(files): 137 | res[file] = calculate_perplexity_for_file(file) 138 | 139 | 140 | # sort by perplexity 141 | res = {k: v for k, v in sorted(res.items(), key=lambda item: item[1]['ppl'])} 142 | 143 | for file in res: 144 | # show a few lines 145 | print(f"File: {file}") 146 | pprint(res[file]['data'][:5]) 147 | 148 | # show the perplexity 149 | print(f"Perplexity: {res[file]['ppl']}") 150 | print(f"Percentage of unique lines: {res[file]['perc_unique_lines']}") 151 | print(f"Percentage of unique tokens: {res[file]['perc_unique_tokens']}") 152 | print("-" * 100) 153 | 154 | # Create a nice MARKDOWN report with: i) sample sentences, ii) perplexity, iii) percentage of unique lines, iv) percentage of unique tokens 155 | 156 | import random 157 | print("| File | Sample Sentences | Perplexity | % Unique Lines | % Unique Tokens |") 158 | for file in res: 159 | sentences = set(res[file]['data']) 160 | # pick 5 random sentences 161 | sentences = random.sample(sentences, 5) if len(sentences) 
> 5 else sentences 162 | 163 | filename = "#".join(file.split("/")[:-1]) 164 | # print row 165 | print('-' * 80) 166 | if res[file]['perc_unique_tokens'] > 0: 167 | print(f"| {filename} | {', '.join(sentences)} | {res[file]['ppl']} | {res[file]['perc_unique_lines']} | {res[file]['perc_unique_tokens']} |") 168 | 169 | 170 | with open("perplexity.json", "w") as f: 171 | json.dump(res, f) 172 | 173 | -------------------------------------------------------------------------------- /src/utils/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 
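    This undoes the flattening done in make_master_params(): the single flat master tensor is split back into chunks shaped like the entries of model_params.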
67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | # import wandb 18 | 19 | DEBUG = 10 20 | INFO = 20 21 | WARN = 30 22 | ERROR = 40 23 | 24 | DISABLED = 50 25 | 26 | 27 | class KVWriter(object): 28 | def writekvs(self, kvs): 29 | raise NotImplementedError 30 | 31 | 32 | class SeqWriter(object): 33 | def writeseq(self, seq): 34 | raise NotImplementedError 35 | 36 | 37 | class HumanOutputFormat(KVWriter, SeqWriter): 38 | def __init__(self, filename_or_file): 39 | if isinstance(filename_or_file, str): 40 | self.file = open(filename_or_file, "wt") 41 | self.own_file = True 42 | else: 43 | assert hasattr(filename_or_file, "read"), ( 44 | "expected file or str, got %s" % filename_or_file 45 | ) 46 | self.file = filename_or_file 47 | self.own_file = False 48 | 49 | def writekvs(self, kvs): 50 | # Create strings for printing 51 | key2str = {} 52 | for (key, val) in sorted(kvs.items()): 53 | if hasattr(val, "__float__"): 54 | valstr = "%-8.3g" % val 55 | else: 56 | valstr = str(val) 57 | key2str[self._truncate(key)] = self._truncate(valstr) 58 | 59 | # Find max widths 60 | if len(key2str) == 0: 61 | print("WARNING: tried to write empty key-value dict") 62 | return 63 | else: 64 | keywidth = max(map(len, key2str.keys())) 65 | valwidth = max(map(len, key2str.values())) 66 | 67 | # Write out the data 68 | dashes = "-" * (keywidth + valwidth + 7) 69 | lines = [dashes] 70 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 71 | lines.append( 72 | "| %s%s | %s%s |" 73 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 74 | ) 75 | lines.append(dashes) 76 | self.file.write("\n".join(lines) + "\n") 77 | 78 | # Flush the output to the file 79 | self.file.flush() 80 | 81 | def _truncate(self, s): 82 | maxlen = 30 83 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 84 | 85 | def writeseq(self, seq): 86 | seq = list(seq) 87 | for (i, elem) in enumerate(seq): 88 | self.file.write(elem) 89 | if i < len(seq) - 1: # add space unless this is the last one 90 | self.file.write(" ") 91 | self.file.write("\n") 92 | self.file.flush() 93 | 94 | def close(self): 95 | if self.own_file: 96 | self.file.close() 97 | 98 | 99 | class JSONOutputFormat(KVWriter): 100 | def __init__(self, filename): 101 | self.file = open(filename, "wt") 102 | 103 | def writekvs(self, kvs): 104 | for k, v in sorted(kvs.items()): 105 | if hasattr(v, "dtype"): 106 | kvs[k] = float(v) 107 | self.file.write(json.dumps(kvs) + "\n") 108 | self.file.flush() 109 | 110 | def close(self): 111 | self.file.close() 112 | 113 | 114 | class CSVOutputFormat(KVWriter): 115 | def __init__(self, filename): 116 | self.file = open(filename, "w+t") 117 | self.keys = [] 118 | self.sep = "," 119 | 120 | def writekvs(self, kvs): 121 | # Add our current row to the history 122 | extra_keys = list(kvs.keys() - self.keys) 123 | extra_keys.sort() 124 | if extra_keys: 125 | self.keys.extend(extra_keys) 126 | self.file.seek(0) 127 | lines = self.file.readlines() 128 | self.file.seek(0) 129 | for (i, k) in enumerate(self.keys): 130 | if i > 0: 131 | self.file.write(",") 132 | self.file.write(k) 133 | self.file.write("\n") 134 | for line in lines[1:]: 135 | self.file.write(line[:-1]) 136 | self.file.write(self.sep * len(extra_keys)) 137 | self.file.write("\n") 138 | for (i, k) in enumerate(self.keys): 139 | if i > 0: 140 | self.file.write(",") 141 | v = kvs.get(k) 142 | if v is not None: 143 | self.file.write(str(v)) 144 | self.file.write("\n") 145 | self.file.flush() 146 | 147 | def close(self): 148 | self.file.close() 149 | 150 | 151 | class TensorBoardOutputFormat(KVWriter): 152 | """ 153 | Dumps key/value pairs into TensorBoard's numeric format. 154 | """ 155 | 156 | def __init__(self, dir): 157 | os.makedirs(dir, exist_ok=True) 158 | self.dir = dir 159 | self.step = 1 160 | prefix = "events" 161 | path = osp.join(osp.abspath(dir), prefix) 162 | import tensorflow as tf 163 | from tensorflow.python import pywrap_tensorflow 164 | from tensorflow.core.util import event_pb2 165 | from tensorflow.python.util import compat 166 | 167 | self.tf = tf 168 | self.event_pb2 = event_pb2 169 | self.pywrap_tensorflow = pywrap_tensorflow 170 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 171 | 172 | def writekvs(self, kvs): 173 | def summary_val(k, v): 174 | kwargs = {"tag": k, "simple_value": float(v)} 175 | return self.tf.Summary.Value(**kwargs) 176 | 177 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 178 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 179 | event.step = ( 180 | self.step 181 | ) # is there any reason why you'd want to specify the step? 
182 | self.writer.WriteEvent(event) 183 | self.writer.Flush() 184 | self.step += 1 185 | 186 | def close(self): 187 | if self.writer: 188 | self.writer.Close() 189 | self.writer = None 190 | 191 | 192 | def make_output_format(format, ev_dir, log_suffix=""): 193 | os.makedirs(ev_dir, exist_ok=True) 194 | if format == "stdout": 195 | return HumanOutputFormat(sys.stdout) 196 | elif format == "log": 197 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 198 | elif format == "json": 199 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 200 | elif format == "csv": 201 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 202 | elif format == "tensorboard": 203 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 204 | else: 205 | raise ValueError("Unknown format specified: %s" % (format,)) 206 | 207 | 208 | # ================================================================ 209 | # API 210 | # ================================================================ 211 | 212 | 213 | def logkv(key, val): 214 | """ 215 | Log a value of some diagnostic 216 | Call this once for each diagnostic quantity, each iteration 217 | If called many times, last value will be used. 218 | """ 219 | get_current().logkv(key, val) 220 | 221 | 222 | def logkv_mean(key, val): 223 | """ 224 | The same as logkv(), but if called many times, values averaged. 225 | """ 226 | get_current().logkv_mean(key, val) 227 | 228 | 229 | def logkvs(d): 230 | """ 231 | Log a dictionary of key-value pairs 232 | """ 233 | for (k, v) in d.items(): 234 | logkv(k, v) 235 | 236 | 237 | def dumpkvs(): 238 | """ 239 | Write all of the diagnostics from the current iteration 240 | """ 241 | return get_current().dumpkvs() 242 | 243 | 244 | def getkvs(): 245 | return get_current().name2val 246 | 247 | 248 | def log(*args, level=INFO): 249 | """ 250 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 251 | """ 252 | get_current().log(*args, level=level) 253 | 254 | 255 | def debug(*args): 256 | log(*args, level=DEBUG) 257 | 258 | 259 | def info(*args): 260 | log(*args, level=INFO) 261 | 262 | 263 | def warn(*args): 264 | log(*args, level=WARN) 265 | 266 | 267 | def error(*args): 268 | log(*args, level=ERROR) 269 | 270 | 271 | def set_level(level): 272 | """ 273 | Set logging threshold on current logger. 274 | """ 275 | get_current().set_level(level) 276 | 277 | 278 | def set_comm(comm): 279 | get_current().set_comm(comm) 280 | 281 | 282 | def get_dir(): 283 | """ 284 | Get directory that log files are being written to. 
285 | will be None if there is no output directory (i.e., if you didn't call start) 286 | """ 287 | return get_current().get_dir() 288 | 289 | 290 | record_tabular = logkv 291 | dump_tabular = dumpkvs 292 | 293 | 294 | @contextmanager 295 | def profile_kv(scopename): 296 | logkey = "wait_" + scopename 297 | tstart = time.time() 298 | try: 299 | yield 300 | finally: 301 | get_current().name2val[logkey] += time.time() - tstart 302 | 303 | 304 | def profile(n): 305 | """ 306 | Usage: 307 | @profile("my_func") 308 | def my_func(): code 309 | """ 310 | 311 | def decorator_with_name(func): 312 | def func_wrapper(*args, **kwargs): 313 | with profile_kv(n): 314 | return func(*args, **kwargs) 315 | 316 | return func_wrapper 317 | 318 | return decorator_with_name 319 | 320 | 321 | # ================================================================ 322 | # Backend 323 | # ================================================================ 324 | 325 | 326 | def get_current(): 327 | if Logger.CURRENT is None: 328 | _configure_default_logger() 329 | 330 | return Logger.CURRENT 331 | 332 | 333 | class Logger(object): 334 | DEFAULT = None # A logger with no output files. (See right below class definition) 335 | # So that you can still log to the terminal without setting up any output files 336 | CURRENT = None # Current logger being used by the free functions above 337 | 338 | def __init__(self, dir, output_formats, comm=None): 339 | self.name2val = defaultdict(float) # values this iteration 340 | self.name2cnt = defaultdict(int) 341 | self.level = INFO 342 | self.dir = dir 343 | self.output_formats = output_formats 344 | self.comm = comm 345 | 346 | # Logging API, forwarded 347 | # ---------------------------------------- 348 | def logkv(self, key, val): 349 | self.name2val[key] = val 350 | 351 | def logkv_mean(self, key, val): 352 | oldval, cnt = self.name2val[key], self.name2cnt[key] 353 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 354 | self.name2cnt[key] = cnt + 1 355 | 356 | def dumpkvs(self, prefix=None): 357 | if self.comm is None: 358 | d = self.name2val 359 | else: 360 | d = mpi_weighted_mean( 361 | self.comm, 362 | { 363 | name: (val, self.name2cnt.get(name, 1)) 364 | for (name, val) in self.name2val.items() 365 | }, 366 | ) 367 | if self.comm.rank != 0: 368 | d["dummy"] = 1 # so we don't get a warning about empty dict 369 | # LISA 370 | # wandb.log({**d}) 371 | out = d.copy() # Return the dict for unit testing purposes 372 | for fmt in self.output_formats: 373 | if isinstance(fmt, KVWriter): 374 | fmt.writekvs(d) 375 | self.name2val.clear() 376 | self.name2cnt.clear() 377 | return out 378 | 379 | def log(self, *args, level=INFO): 380 | if self.level <= level: 381 | self._do_log(args) 382 | 383 | # Configuration 384 | # ---------------------------------------- 385 | def set_level(self, level): 386 | self.level = level 387 | 388 | def set_comm(self, comm): 389 | self.comm = comm 390 | 391 | def get_dir(self): 392 | return self.dir 393 | 394 | def close(self): 395 | for fmt in self.output_formats: 396 | fmt.close() 397 | 398 | # Misc 399 | # ---------------------------------------- 400 | def _do_log(self, args): 401 | for fmt in self.output_formats: 402 | if isinstance(fmt, SeqWriter): 403 | fmt.writeseq(map(str, args)) 404 | 405 | 406 | def get_rank_without_mpi_import(): 407 | # check environment variables here instead of importing mpi4py 408 | # to avoid calling MPI_Init() when this module is imported 409 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 410 | if varname 
in os.environ: 411 | return int(os.environ[varname]) 412 | return 0 413 | 414 | 415 | def mpi_weighted_mean(comm, local_name2valcount): 416 | """ 417 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 418 | Perform a weighted average over dicts that are each on a different node 419 | Input: local_name2valcount: dict mapping key -> (value, count) 420 | Returns: key -> mean 421 | """ 422 | all_name2valcount = comm.gather(local_name2valcount) 423 | if comm.rank == 0: 424 | name2sum = defaultdict(float) 425 | name2count = defaultdict(float) 426 | for n2vc in all_name2valcount: 427 | for (name, (val, count)) in n2vc.items(): 428 | try: 429 | val = float(val) 430 | except ValueError: 431 | if comm.rank == 0: 432 | warnings.warn( 433 | "WARNING: tried to compute mean on non-float {}={}".format( 434 | name, val 435 | ) 436 | ) 437 | else: 438 | name2sum[name] += val * count 439 | name2count[name] += count 440 | return {name: name2sum[name] / name2count[name] for name in name2sum} 441 | else: 442 | return {} 443 | 444 | 445 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 446 | """ 447 | If comm is provided, average all numerical stats across that comm 448 | """ 449 | if dir is None: 450 | dir = os.getenv("OPENAI_LOGDIR") 451 | if dir is None: 452 | dir = osp.join( 453 | tempfile.gettempdir(), 454 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 455 | ) 456 | assert isinstance(dir, str) 457 | dir = os.path.expanduser(dir) 458 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 459 | 460 | rank = get_rank_without_mpi_import() 461 | if rank > 0: 462 | log_suffix = log_suffix + "-rank%03i" % rank 463 | 464 | if format_strs is None: 465 | if rank == 0: 466 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 467 | else: 468 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 469 | format_strs = filter(None, format_strs) 470 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 471 | 472 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 473 | if output_formats: 474 | log("Logging to %s" % dir) 475 | 476 | 477 | def _configure_default_logger(): 478 | configure() 479 | Logger.DEFAULT = Logger.CURRENT 480 | 481 | 482 | def reset(): 483 | if Logger.CURRENT is not Logger.DEFAULT: 484 | Logger.CURRENT.close() 485 | Logger.CURRENT = Logger.DEFAULT 486 | log("Reset logger") 487 | 488 | 489 | @contextmanager 490 | def scoped_configure(dir=None, format_strs=None, comm=None): 491 | prevlogger = Logger.CURRENT 492 | configure(dir=dir, format_strs=format_strs, comm=comm) 493 | try: 494 | yield 495 | finally: 496 | Logger.CURRENT.close() 497 | Logger.CURRENT = prevlogger 498 | 499 | -------------------------------------------------------------------------------- /src/utils/show_sampling_progress.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List 3 | list_of_colors_from_red_to_blue = [f"\033[38;2;{r};0;{b}m" for r, b in zip(range(255, 0, -10), range(0, 255, 10))] 4 | 5 | def pprint_sentences(sentences: List[str], banner: str = "", sep: str = ""): 6 | """ 7 | Given a list of sentences, prints them with a gradient of colors from red to blue 8 | """ 9 | print() 10 | print(f"\033[1m{'=' * 20} {banner} {'=' * 20}\033[0m") 11 | for i, sentence in enumerate(sentences): 12 | sentence_color = list_of_colors_from_red_to_blue[i] 13 | if i == 
len(sentences) - 1: 14 | print(f"{sentence_color}{sentence}\033[0m") 15 | else: 16 | print(f"{sentence_color}{sentence}\033[0m", end=sep) 17 | print() 18 | 19 | 20 | if __name__ == '__main__': 21 | sentences = [ 22 | "This is a sentence", 23 | "This is another sentence", 24 | "This is a third sentence", 25 | "This is a fourth sentence", 26 | "This is a fifth sentence", 27 | "This is a sixth sentence", 28 | "This is a seventh sentence", 29 | "This is an eighth sentence", 30 | "This is a ninth sentence", 31 | "This is a tenth sentence", 32 | "This is an eleventh sentence", 33 | "This is a twelfth sentence", 34 | "This is a thirteenth sentence", 35 | "This is a fourteenth sentence", 36 | "This is a fifteenth sentence", 37 | "This is a sixteenth sentence", 38 | "This is a seventeenth sentence", 39 | "This is an eighteenth sentence", 40 | "This is a nineteenth sentence", 41 | "This is a twentieth sentence", 42 | ] 43 | for i in range(1, len(sentences) + 1): 44 | pprint_sentences(sentences[:i], sep= " -> ") 45 | print("---") 46 | 47 | -------------------------------------------------------------------------------- /src/utils/test_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | def compute_logp(args, model, x, input_ids): 5 | word_emb = model.weight 6 | sigma = 0.1 7 | if args.model_arch == '1d-unet': 8 | x = x.permute(0, 2, 1) 9 | 10 | bsz, seqlen, dim = x.shape 11 | 12 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*seqlen, dim 13 | word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim 14 | diff = (x_flat - word_emb_flat) ** 2 # vocab, bsz*seqlen, dim 15 | 16 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, bsz*seqlen 17 | logp_expanded = logp_expanded.permute((1, 0)) 18 | # print(th.topk(logp_expanded.view(bsz, seqlen, -1), k=5, dim=-1)[0]) 19 | # print(input_ids[0]) 20 | ce = th.nn.CrossEntropyLoss(reduction='none') 21 | loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen) 22 | # print(loss[0]) 23 | 24 | # print(loss.shape) 25 | return loss 26 | 27 | def get_weights(model, args): 28 | if hasattr(model, 'transformer'): 29 | input_embs = model.transformer.wte # input_embs 30 | down_proj = model.down_proj 31 | down_proj_emb = down_proj(input_embs.weight) 32 | print(down_proj_emb.shape) 33 | # model = th.nn.Embedding(down_proj_emb.shape[1], down_proj_emb.shape[0]) 34 | model = th.nn.Embedding(down_proj_emb.size(0), down_proj_emb.size(1)) 35 | print(args.emb_scale_factor) 36 | model.weight.data = down_proj_emb * args.emb_scale_factor 37 | 38 | elif hasattr(model, 'weight'): 39 | pass 40 | else: 41 | raise NotImplementedError 42 | 43 | model.weight.requires_grad = False 44 | return model 45 | 46 | def denoised_fn_round(args, model, text_emb, t): 47 | 48 | down_proj_emb = model.weight # input_embs 49 | # print(t) 50 | old_shape = text_emb.shape 51 | old_device = text_emb.device 52 | 53 | def get_efficient_knn(down_proj_emb, text_emb, dist='l2'): 54 | if dist == 'l2': 55 | emb_norm = (down_proj_emb**2).sum(-1).view(-1, 1) #vocab 56 | text_emb_t = th.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) #d, bsz*seqlen 57 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) #bsz*seqlen, 1 58 | # print(emb_norm.shape, arr_norm.shape) 59 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * th.mm(down_proj_emb, text_emb_t) #(vocab, d) x (d, bsz*seqlen) 60 | dist = th.clamp(dist, 0.0, np.inf) 61 | # print(dist.shape) 62 | topk_out = th.topk(-dist, k=1, dim=0) 
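# Explanatory note on the lines above: the squared L2 distance is expanded as
# ||e||^2 + ||x||^2 - 2*e.x, so the nearest vocabulary embedding for every token
# position is obtained from a single matrix multiply instead of materialising the
# full (vocab, bsz*seqlen, dim) difference tensor built by the naive get_knn below.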
63 | # adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 64 | # down_proj_emb.size(0), -1, -1) 65 | # adjacency = -th.norm(adjacency, dim=-1) 66 | # topk_out = th.topk(adjacency, k=1, dim=0) 67 | # print(topk_out1.indices == topk_out.indices) 68 | # assert th.all(topk_out1.indices == topk_out.indices) 69 | return topk_out.values, topk_out.indices 70 | 71 | def get_knn(down_proj_emb, text_emb, dist='l2'): 72 | if dist == 'l2': 73 | adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 74 | down_proj_emb.size(0), -1, -1) 75 | adjacency = -th.norm(adjacency, dim=-1) 76 | topk_out = th.topk(adjacency, k=1, dim=0) 77 | return topk_out.values, topk_out.indices 78 | 79 | dist = 'l2' 80 | if len(text_emb.shape) > 2: 81 | text_emb = text_emb.reshape(-1, text_emb.size(-1)) 82 | else: 83 | text_emb = text_emb 84 | # val, indices = get_knn(down_proj_emb, 85 | # text_emb.to(down_proj_emb.device), dist=dist) 86 | val, indices = get_efficient_knn(down_proj_emb, 87 | text_emb.to(down_proj_emb.device), dist=dist) 88 | rounded_tokens = indices[0] 89 | # print(rounded_tokens.shape) 90 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device) 91 | if args.model_arch == '1d-unet': 92 | new_embeds = new_embeds.permute(0, 2, 1) 93 | return new_embeds 94 | 95 | def load_results(json_path, load_dict): 96 | import json 97 | with open(json_path, 'w') as f: 98 | json.dump(load_dict, f, indent=2) 99 | -------------------------------------------------------------------------------- /temp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import pandas as pd 5 | 6 | 7 | a=pd.read_csv('./data/pubchem_MolWt200-1000_iupac.csv',sep='|',header=0) 8 | 9 | 10 | #mask = (a['PUBCHEM_IUPAC_NAME'].str.len() <=512) 11 | 12 | #a = a.loc[mask] 13 | 14 | #a=a[a.PUBCHEM_IUPAC_NAME.astype(str).str.len()<=512] 15 | from iupac_tokenization import get_iupac_tokenizer 16 | from smile_tokenization import get_smiles_tokenizer 17 | 18 | iupac_tokenizer = get_iupac_tokenizer(is_train=1,full_path ='./data') 19 | smiles_tokenizer = get_smiles_tokenizer(is_train=1,checkpoint = "./data/smile_tocken") 20 | 21 | def seq_valid_iupac(iupac_tokenizer,myline): 22 | iupac_encoded = iupac_tokenizer(myline) 23 | if iupac_encoded["input_ids"].count(2)==1: 24 | return 1 25 | else: 26 | return 0 27 | 28 | def seq_valid_smiles(smiles_tokenizer,myline): 29 | iupac_encoded = smiles_tokenizer(myline) 30 | if iupac_encoded["input_ids"].count(1)==1: 31 | return 1 32 | else: 33 | return 0 34 | 35 | a['PUBCHEM_IUPAC_NAME_if'] = a['PUBCHEM_IUPAC_NAME'].apply(lambda x :seq_valid_iupac(iupac_tokenizer,x)) 36 | a['canon_smiles_if'] = a['canon_smiles'].apply(lambda x :seq_valid_smiles(smiles_tokenizer,x)) 37 | 38 | a = a[(a['PUBCHEM_IUPAC_NAME_if']==1)&(a['canon_smiles_if']==1)] 39 | 40 | #a[['PUBCHEM_IUPAC_NAME']].to_csv('./data/pubchem_iupac_valid.csv',header=None,index=None,sep='|') 41 | #a[['canon_smiles']].to_csv('./data/pubchem_smiles_valid.csv',header=None,index=None,sep='|') 42 | 43 | 44 | b=a.iloc[0:30000000] 45 | 46 | b[['PUBCHEM_IUPAC_NAME']].to_csv('./data/pubchem_iupac_train_3qw.csv',header=None,index=None,sep='|') 47 | b[['canon_smiles']].to_csv('./data/pubchem_smiles_train_3qw.csv',header=None,index=None,sep='|') 48 | 49 | c=a.iloc[30000000:] 50 | 51 | c[['PUBCHEM_IUPAC_NAME']].to_csv('./data/pubchem_iupac_valid_3qw.csv',header=None,index=None,sep='|') 52 | 
c[['canon_smiles']].to_csv('./data/pubchem_smiles_valid_3qw.csv',header=None,index=None,sep='|') 53 | -------------------------------------------------------------------------------- /tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pathlib 4 | import torch 5 | from transformers import AutoTokenizer 6 | import os.path as pt 7 | from tokenizers.processors import BertProcessing 8 | from tokenizers import ByteLevelBPETokenizer, decoders 9 | from iupac_tokenization import get_iupac_tokenizer 10 | from smile_tokenization import get_smiles_tokenizer 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | 15 | def create_iupac_smiles_tokenizer(return_pretokenized, path, tokenizer_ckpt: str = './data'): 16 | 17 | if return_pretokenized: 18 | print(f'*******use pretrained iupac_tokenizer*****{return_pretokenized}*******') 19 | print(pt.join(tokenizer_ckpt,"real_iupac_tokenizer.pt")) 20 | iupac_tokenizer = get_iupac_tokenizer(is_train=1,full_path =tokenizer_ckpt) 21 | smiles_tokenizer = get_smiles_tokenizer(is_train=1,checkpoint = "./data/smile_tocken") 22 | return iupac_tokenizer,smiles_tokenizer 23 | else: 24 | return None,None 25 | 26 | def create_tokenizer(return_pretokenized, path, tokenizer_type: str = "word-level", tokenizer_ckpt: str = None): 27 | 28 | if return_pretokenized: 29 | print(f'*******use pretrained tokenizer*****{return_pretokenized}*******') 30 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt) 31 | return tokenizer 32 | 33 | if tokenizer_type == "byte-level": 34 | return read_byte_level(path) 35 | elif tokenizer_type == "word-level": 36 | return read_word_level(path) 37 | else: 38 | raise ValueError(f"Invalid tokenizer type: {tokenizer_type}") 39 | 40 | def train_bytelevel( 41 | path, #list 42 | save_path, 43 | vocab_size=10000, 44 | min_frequency=1, 45 | special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"], 46 | ): 47 | 48 | tokenizer = ByteLevelBPETokenizer() 49 | 50 | # Customize training 51 | tokenizer.train( 52 | files=path, 53 | vocab_size=vocab_size, 54 | min_frequency=min_frequency, 55 | special_tokens=special_tokens, 56 | ) 57 | 58 | tokenizer.save_model(str(pathlib.Path(save_path))) 59 | 60 | def read_byte_level(path: str): 61 | tokenizer = ByteLevelBPETokenizer( 62 | f"{path}/vocab.json", 63 | f"{path}/merges.txt", 64 | ) 65 | 66 | tokenizer._tokenizer.post_processor = BertProcessing( 67 | ("</s>", tokenizer.token_to_id("</s>")), 68 | ("<s>", tokenizer.token_to_id("<s>")), 69 | ) 70 | 71 | tokenizer.enable_truncation(max_length=512) 72 | 73 | with open(f"{path}/vocab.json", "r") as fin: 74 | vocab = json.load(fin) 75 | 76 | # add length method to tokenizer object 77 | tokenizer.vocab_size = len(vocab) 78 | 79 | # add length property to tokenizer object 80 | tokenizer.__len__ = property(lambda self: self.vocab_size) 81 | 82 | tokenizer.decoder = decoders.ByteLevel() 83 | print(tokenizer.vocab_size) 84 | 85 | print( 86 | tokenizer.encode( 87 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 88 | ).ids 89 | ) 90 | 91 | print( 92 | tokenizer.decode( 93 | tokenizer.encode( 94 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 95 | ).ids, 96 | skip_special_tokens=True, 97 | ) 98 | ) 99 | 100 | ids = tokenizer.encode( 101 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 
102 | ).ids 103 | tensor = torch.tensor(ids) 104 | print(tokenizer.decode(tensor.tolist(), skip_special_tokens=True)) 105 | print(f"Vocab size: {tokenizer.vocab_size}") 106 | 107 | return tokenizer 108 | 109 | 110 | def read_word_level(path: str): 111 | 112 | from transformers import PreTrainedTokenizerFast 113 | 114 | logging.info(f"Loading tokenizer from {path}/word-level-vocab.json") 115 | tokenizer = PreTrainedTokenizerFast( 116 | tokenizer_file=f"{str(pathlib.Path(path))}/word-level-vocab.json", 117 | bos_token="[CLS]", 118 | eos_token="[SEP]", 119 | unk_token="[UNK]", 120 | sep_token="[SEP]", 121 | pad_token="[PAD]", 122 | cls_token="[CLS]", 123 | mask_token="[MASK]", 124 | padding_side="right", 125 | ) 126 | 127 | # add length property to tokenizer object 128 | tokenizer.__len__ = property(lambda self: self.vocab_size) 129 | 130 | return tokenizer 131 | 132 | 133 | def train_word_level_tokenizer( 134 | path: str, 135 | vocab_size: int = 10000, 136 | special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 137 | ): 138 | 139 | from tokenizers import Tokenizer, normalizers, pre_tokenizers 140 | from tokenizers.models import WordLevel 141 | from tokenizers.normalizers import NFD, Lowercase, StripAccents 142 | from tokenizers.pre_tokenizers import Digits, Whitespace 143 | from tokenizers.processors import TemplateProcessing 144 | from tokenizers.trainers import WordLevelTrainer 145 | 146 | tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 147 | tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()]) 148 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence( 149 | [Digits(individual_digits=True), Whitespace()] 150 | ) 151 | tokenizer.post_processor = TemplateProcessing( 152 | single="[CLS] $A [SEP]", special_tokens=[("[CLS]", 1), ("[SEP]", 2)] 153 | ) 154 | 155 | trainer = WordLevelTrainer(vocab_size=vocab_size, special_tokens=special_tokens) 156 | tokenizer.train(files=[path], trainer=trainer) 157 | 158 | tokenizer.__len__ = property(lambda self: self.vocab_size) 159 | 160 | tokenizer.enable_truncation(max_length=512) 161 | 162 | print(tokenizer.encode("the red.").ids) 163 | 164 | print(tokenizer.encode("the red.")) 165 | 166 | tokenizer.save(f"{str(pathlib.Path(path).parent)}/word-level-vocab.json") 167 | 168 | 169 | if __name__ == "__main__": 170 | import sys 171 | import os 172 | 173 | if sys.argv[1] == "train-word-level": 174 | train_word_level_tokenizer(path=sys.argv[2]) 175 | elif sys.argv[1] == "train-byte-level": 176 | path = f"./data/{sys.argv[2]}/" 177 | data_path = [path + item for item in os.listdir(path) if 'train' in item] 178 | train_bytelevel(path=data_path, vocab_size=int(sys.argv[3])+5, save_path=path) 179 | elif sys.argv[1] == "create": 180 | create_tokenizer(path=sys.argv[2]) 181 | -------------------------------------------------------------------------------- /train_scripts/gen_opt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #!/bin/bash 4 | 5 | MODEL_DIR=./ckpts/wjm_ckpts/wjm14_128_0.0001_2000_1000000_10000_schegran20000_srciupac_tgtsmiles 6 | MODEL_NAME=${MODEL_DIR}/$1 7 | OUT_DIR=${2} 8 | SCHEDULE_PATH=${MODEL_DIR}/${3} 9 | VAL_TXT=./data/${4}/seed 10 | SEED=${5:-10708} 11 | 12 | if [ -z "$OUT_DIR" ]; then 13 | OUT_DIR=${MODEL_NAME} 14 | fi 15 | 16 | GEN_BY_Q=${6:-"False"} 17 | GEN_BY_MIX=${7:-"True"} 18 | MIX_PROB=${8:-1} 19 | MIX_PART=${9:-0} 20 | TOP_P=-1 21 | CLAMP="no_clamp" 22 | BATCH_SIZE=50 23 | SEQ_LEN=128 24 | DIFFUSION_STEPS=2000 25 | NUM_SAMPLES=-1 26 | 27 
| 28 | 29 | python -u inference_main_opt_input.py --model_name_or_path ${MODEL_NAME} --sequence_len_src 1024 \ 30 | --batch_size ${BATCH_SIZE} --num_samples ${NUM_SAMPLES} --top_p ${TOP_P} --time_schedule_path ${SCHEDULE_PATH} \ 31 | --seed ${SEED} --val_txt_path ${VAL_TXT} --generate_by_q ${GEN_BY_Q} --generate_by_mix ${GEN_BY_MIX} \ 32 | --out_dir ${OUT_DIR} --diffusion_steps ${DIFFUSION_STEPS} --clamp ${CLAMP} --sequence_len ${SEQ_LEN} \ 33 | --generate_by_mix_prob ${MIX_PROB} --generate_by_mix_part ${MIX_PART} 34 | 35 | 36 | 37 | 38 | #bash ./inference_scrpts/non_translation_inf_opt.sh ema_0.9999_160000.pt ./output alpha_cumprod_step_160000.npy wjm14 39 | #https://mp.weixin.qq.com/s/v8wHFwRm3IbGgnGD0syTgw -------------------------------------------------------------------------------- /train_scripts/wjm_iupac_smiles.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -u 3 | 4 | 5 | GPU=${1} 6 | NUM_GPUS=1 7 | LOSS_FUNC="uniform" 8 | SRC=${2:-'iupac'} 9 | TGT=${3:-'smiles'} 10 | LR=0.0001 11 | SEQ_LEN=128 12 | WARMUP=10000 13 | SCHEDULE_UPDATE_STRIDE=20000 14 | DSET="wjm_ckpts" 15 | UPDATE_GRANU=20 16 | INIT_PRETRAINED_MODEL="False" 17 | USE_PRETRAINED_EMBEDDINGS="False" 18 | FREEZE_EMBEDDINGS="False" 19 | LR_ANNEAL_STEPS=1000000 20 | DIFFUSION_STEPS=2000 21 | NOISE_SCHEDULE=sqrt 22 | BATCH_SIZE=64 23 | 24 | 25 | CHECKPOINT_PATH="ckpts/${DSET}/wjm14_${SEQ_LEN}_${LR}_${DIFFUSION_STEPS}_${LR_ANNEAL_STEPS}_${WARMUP}_schegran${SCHEDULE_UPDATE_STRIDE}_src${SRC}_tgt${TGT}" 26 | TRAIN_TXT_PATH="./data/wjm14/train" 27 | VAL_TXT_PATH="./data/wjm14/valid" 28 | IN_CHANNELS=512 29 | WEIGHT_DECAY=0.0 30 | SEED=10708 31 | DROPOUT=0.3 32 | NUM_HEADS=8 33 | CONFIG_NAME="facebook/bart-base" 34 | NOTES="wjm14 training with noise schedule and self condition" 35 | 36 | mkdir -p ${CHECKPOINT_PATH} 37 | mkdir -p ${CHECKPOINT_PATH}/log/ 38 | export DIFFUSION_BLOB_LOGDIR=${CHECKPOINT_PATH}/log/ 39 | 40 | 41 | ARGS=(--checkpoint_path ${CHECKPOINT_PATH} 42 | --save_interval ${WARMUP} --lr ${LR} 43 | --batch_size ${BATCH_SIZE} 44 | --src ${SRC} 45 | --tgt ${TGT} 46 | --diffusion_steps ${DIFFUSION_STEPS} 47 | --noise_schedule ${NOISE_SCHEDULE} 48 | --sequence_len ${SEQ_LEN} --seed ${SEED} 49 | --weight_decay ${WEIGHT_DECAY} 50 | --predict_xstart True 51 | --train_txt_path ${TRAIN_TXT_PATH} 52 | --dataset "wjm14" 53 | --val_txt_path ${VAL_TXT_PATH} 54 | --config_name ${CONFIG_NAME} 55 | --init_pretrained ${INIT_PRETRAINED_MODEL} 56 | --freeze_embeddings ${FREEZE_EMBEDDINGS} 57 | --use_pretrained_embeddings ${USE_PRETRAINED_EMBEDDINGS} 58 | --notes \""${NOTES}"\") 59 | 60 | if [ ${LR_ANNEAL_STEPS} -eq 0 ]; then 61 | LR_ANNEAL_STEPS=100 62 | DEBUG=true 63 | else 64 | DEBUG=false 65 | fi 66 | 67 | ARGS+=(--lr_anneal_steps $LR_ANNEAL_STEPS) 68 | 69 | if [ $DEBUG = true ]; then 70 | ARGS+=(--debug) 71 | fi 72 | 73 | ARGS+=(--encoder_layers 6 74 | --decoder_layers 6 75 | --num_heads 8 76 | --num_heads 8 77 | --in_channel 512 78 | --out_channel 512 79 | --num_channels 2048 80 | --sequence_len_src 1024 81 | --warmup $WARMUP 82 | --schedule_sampler $LOSS_FUNC 83 | --loss_update_granu $UPDATE_GRANU 84 | --schedule_update_stride $SCHEDULE_UPDATE_STRIDE) 85 | 86 | export CUDA_VISIBLE_DEVICES=$GPU && mpiexec -n $NUM_GPUS python -u main.py "${ARGS[@]}" 87 | 88 | 89 | -------------------------------------------------------------------------------- /train_scripts/wjm_iupac_smiles_retrain.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | set -u 3 | 4 | 5 | GPU=${1} 6 | NUM_GPUS=1 7 | LOSS_FUNC="uniform" 8 | SRC=${2:-'iupac'} 9 | TGT=${3:-'smiles'} 10 | LR=0.0001 11 | SEQ_LEN=128 12 | WARMUP=10000 13 | SCHEDULE_UPDATE_STRIDE=20000 14 | DSET="wjm_ckpts" 15 | UPDATE_GRANU=20 16 | INIT_PRETRAINED_MODEL="False" 17 | USE_PRETRAINED_EMBEDDINGS="False" 18 | FREEZE_EMBEDDINGS="False" 19 | LR_ANNEAL_STEPS=1000000 20 | DIFFUSION_STEPS=2000 21 | NOISE_SCHEDULE=sqrt 22 | BATCH_SIZE=64 23 | 24 | 25 | CHECKPOINT_PATH="ckpts/${DSET}/wjm14_${SEQ_LEN}_${LR}_${DIFFUSION_STEPS}_${LR_ANNEAL_STEPS}_${WARMUP}_schegran${SCHEDULE_UPDATE_STRIDE}_src${SRC}_tgt${TGT}" 26 | TRAIN_TXT_PATH="./data/wjm14/train" 27 | VAL_TXT_PATH="./data/wjm14/valid" 28 | IN_CHANNELS=512 29 | WEIGHT_DECAY=0.0 30 | SEED=10708 31 | DROPOUT=0.3 32 | NUM_HEADS=8 33 | CONFIG_NAME="facebook/bart-base" 34 | NOTES="wjm14 training with noise schedule and self condition" 35 | 36 | mkdir -p ${CHECKPOINT_PATH} 37 | mkdir -p ${CHECKPOINT_PATH}/log/ 38 | export DIFFUSION_BLOB_LOGDIR=${CHECKPOINT_PATH}/log/ 39 | 40 | 41 | ARGS=(--checkpoint_path ${CHECKPOINT_PATH} 42 | --save_interval ${WARMUP} --lr ${LR} 43 | --batch_size ${BATCH_SIZE} 44 | --src ${SRC} 45 | --tgt ${TGT} 46 | --diffusion_steps ${DIFFUSION_STEPS} 47 | --noise_schedule ${NOISE_SCHEDULE} 48 | --sequence_len ${SEQ_LEN} --seed ${SEED} 49 | --weight_decay ${WEIGHT_DECAY} 50 | --predict_xstart True 51 | --train_txt_path ${TRAIN_TXT_PATH} 52 | --dataset "wjm14" 53 | --val_txt_path ${VAL_TXT_PATH} 54 | --config_name ${CONFIG_NAME} 55 | --init_pretrained ${INIT_PRETRAINED_MODEL} 56 | --freeze_embeddings ${FREEZE_EMBEDDINGS} 57 | --use_pretrained_embeddings ${USE_PRETRAINED_EMBEDDINGS} 58 | --notes \""${NOTES}"\") 59 | 60 | if [ ${LR_ANNEAL_STEPS} -eq 0 ]; then 61 | LR_ANNEAL_STEPS=100 62 | DEBUG=true 63 | else 64 | DEBUG=false 65 | fi 66 | 67 | ARGS+=(--lr_anneal_steps $LR_ANNEAL_STEPS) 68 | 69 | if [ $DEBUG = true ]; then 70 | ARGS+=(--debug) 71 | fi 72 | 73 | ARGS+=(--encoder_layers 6 74 | --decoder_layers 6 75 | --num_heads 8 76 | --num_heads 8 77 | --in_channel 512 78 | --out_channel 512 79 | --num_channels 2048 80 | --sequence_len_src 1024 81 | --warmup $WARMUP 82 | --schedule_sampler $LOSS_FUNC 83 | --loss_update_granu $UPDATE_GRANU 84 | --schedule_update_stride $SCHEDULE_UPDATE_STRIDE) 85 | 86 | 87 | MODEL_NAME=/root/autodl-tmp/wjm/SeqDiffuSeq-main/ckpts/wjm_ckpts/wjm14_128_0.0001_2000_1000000_10000_schegran20000_srciupac_tgtsmiles/ema_0.9999_310000.pt 88 | SCHEDULE_PATH=/root/autodl-tmp/wjm/SeqDiffuSeq-main/ckpts/wjm_ckpts/wjm14_128_0.0001_2000_1000000_10000_schegran20000_srciupac_tgtsmiles/alpha_cumprod_step_320000.npy 89 | 90 | ARGS+=(--model_name_or_path ${MODEL_NAME} 91 | --time_schedule_path ${SCHEDULE_PATH}) 92 | 93 | 94 | export CUDA_VISIBLE_DEVICES=$GPU && mpiexec -n $NUM_GPUS python -u main_retrain.py "${ARGS[@]}" 95 | 96 | 97 | #bash ./train_scripts/wjm_iupac_smiles_retrain.sh 0 iupac smiles -------------------------------------------------------------------------------- /train_spm.py: -------------------------------------------------------------------------------- 1 | import sentencepiece as spm 2 | import sys 3 | from collections import Counter 4 | 5 | # file with a list of IUPAC names (can be just 1 line if you want) 6 | #iupacs_fn = int(sys.argv[1]) 7 | 8 | 9 | with open("opsin_vocab_reduced.txt", "r") as f: 10 | words = f.read().split("\n") 11 | words = list(map(str, range(100))) + words 12 | 13 | smile_atom =[ 14 | 'Ac', 'Ag', 'Al', 'Am', 'Ar', 'As', 'At', 'Au', 'B', 'Ba', 'Be', 'Bh', 15 
| 'Bi', 'Bk', 'Br', 'C', 'Ca', 'Cd', 'Ce', 'Cf', 'Cl', 'Cm', 'Co', 'Cr', 16 | 'Cs', 'Cu', 'Db', 'Dy', 'Er', 'Es', 'Eu', 'F', 'Fe', 'Fm', 'Fr', 'Ga', 17 | 'Gd', 'Ge', 'H', 'He', 'Hf', 'Hg', 'Ho', 'Hs', 'I', 'In', 'Ir', 'K', 18 | 'Kr', 'La', 'Li', 'Lr', 'Lu', 'Md', 'Mg', 'Mn', 'Mo', 'Mt', 'N', 'Na', 19 | 'Nb', 'Nd', 'Ne', 'Ni', 'No', 'Np', 'O', 'Os', 'P', 'Pa', 'Pb', 'Pd', 20 | 'Pm', 'Po', 'Pr', 'Pt', 'Pu', 'Ra', 'Rb', 'Re', 'Rf', 'Rh', 'Rn', 21 | 'Ru', 'S', 'Sb', 'Sc', 'Se', 'Sg', 'Si', 'Sm', 'Sn', 'Sr', 'Ta', 'Tb', 22 | 'Tc', 'Te', 'Th', 'Ti', 'Tl', 'Tm', 'U', 'V', 'W', 'Xe', 'Y', 'Yb', 23 | 'Zn', 'Zr' 24 | ] 25 | smile_non_atom = [ 26 | '-', '=', '#', ':', '(', ')', '.', '[', ']', '+', '-', '\\', '/', '*', 27 | #'1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 28 | '@', 'AL', 'TH', 'SP', 'TB', 'OH', 29 | ] 30 | 31 | #words = smile_atom+smile_non_atom+words 32 | 33 | words = list(set(words)) 34 | 35 | vocab_size = len(words) + 1+100 36 | 37 | user_defined_symbols = words 38 | 39 | print("num user defined:", len(user_defined_symbols)) 40 | 41 | args = {"input": sys.argv[1], 42 | "model_type": "unigram", 43 | "model_prefix": "iupac_spm".format(vocab_size), 44 | "vocab_size": vocab_size, 45 | "input_sentence_size": 50000, 46 | "shuffle_input_sentence": True, 47 | "user_defined_symbols": user_defined_symbols, 48 | "split_by_number": False, 49 | "split_by_whitespace": False, 50 | "hard_vocab_limit": False, 51 | "max_sentencepiece_length": 320, 52 | "character_coverage": 0.99, 53 | "pad_id": 0, 54 | "eos_id": 1, 55 | "unk_id": 2, 56 | "bos_id": -1 57 | } 58 | #"train_extremely_large_corpus": True 59 | 60 | spm.SentencePieceTrainer.train(**args) 61 | --------------------------------------------------------------------------------
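
The configuration above trains a unigram SentencePiece model with the OPSIN-derived IUPAC vocabulary registered as user-defined symbols, writing `iupac_spm.model` and `iupac_spm.vocab` (a pre-built copy ships under `data/`). The snippet below is a minimal sketch of loading that model with the standard `sentencepiece` API; it assumes `iupac_spm.model` is in the working directory, and the sample IUPAC name is an arbitrary choice, not taken from the training data.

```python
# Minimal sketch: load the unigram model trained by train_spm.py and tokenize one IUPAC name.
# Assumes iupac_spm.model is in the current directory (the repository also ships a copy under data/).
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("iupac_spm.model")

name = "2-methylpropan-2-ol"          # arbitrary example IUPAC name
pieces = sp.encode_as_pieces(name)    # sub-word pieces
ids = sp.encode_as_ids(name)          # integer ids (pad_id=0, eos_id=1, unk_id=2 per the config above)
print(pieces)
print(ids)
print(sp.decode_pieces(pieces))       # reconstruct the name from its pieces
```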