├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── LICENSE.txt ├── README.md ├── SECURITY.md └── progen2 ├── README.md ├── likelihood.py ├── models └── progen │ ├── configuration_progen.py │ └── modeling_progen.py ├── requirements.txt ├── sample.py └── tokenizer.json /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **/.venv 3 | **/.vscode 4 | **/checkpoints -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 
19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable 
behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 
96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Salesforce.com, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ProGen: Language Modeling for Protein Engineering 2 | 3 | Suite of open-sourced projects and models for protein engineering and design. 4 | 5 | ### License 6 | Our code and models are BSD-3 licensed. See LICENSE.txt for details. 7 | 8 | ### Ethics 9 | Predicting the fitness of a protein sequence and capturing the distribution of natural proteins for generative purposes could be a powerful tool for protein design. If our technique or a future iteration thereof is adopted broadly, care should be taken in terms of the end use-cases of these designed samples and downstream effects to ensure safe, non-nefarious, and ethical applications. For projects in any domain, active oversight during project initiation, experimental optimization, and deployment phases should be put in place to ensure safe usage and limitation of unintended harmful effects. 10 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. 
This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as possible, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /progen2/README.md: -------------------------------------------------------------------------------- 1 | # ProGen2 2 | Official release of the **ProGen2** models (`151M`, `764M`, `2.7B`, `6.4B`) for **Protein Engineering** (see the paper cited below). 3 | 4 | ## Models 5 | 6 | | Model | Size | Checkpoint | 7 | | ------ | ------ | ---------- | 8 | | progen2-small | `151M` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-small.tar.gz | 9 | | progen2-medium | `764M` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-medium.tar.gz | 10 | | progen2-oas | `764M` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-oas.tar.gz | 11 | | progen2-base | `764M` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-base.tar.gz | 12 | | progen2-large | `2.7B` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-large.tar.gz | 13 | | progen2-BFD90 | `2.7B` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-BFD90.tar.gz | 14 | | progen2-xlarge | `6.4B` | https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-xlarge.tar.gz | 15 | 16 | ## Setup 17 | ```sh 18 | # code 19 | git clone https://github.com/salesforce/progen 20 | cd progen/progen2 21 | 22 | # checkpoint 23 | model=progen2-large 24 | wget -P checkpoints/${model} https://storage.googleapis.com/sfr-progen-research/checkpoints/${model}.tar.gz 25 | tar -xvf checkpoints/${model}/${model}.tar.gz -C checkpoints/${model}/ 26 | 27 | # venv 28 | python3.8 -m venv .venv 29 | source .venv/bin/activate 30 | pip3 install --upgrade pip setuptools 31 | pip3
install -r requirements.txt 32 | 33 | # sample 34 | python3 sample.py --model ${model} --t 0.8 --p 0.9 --max-length 1024 --num-samples 2 --context "1" 35 | 36 | # log-likelihood (GenBank: TMF32756.1) 37 | python3 likelihood.py --model ${model} --context "1MGHGVSRPPVVTLRPAVLDDCPVLWRWRNDPETRQASVDEREIPVDTHTRWFEETLKRFDRKLFIVSADGVDAGMVRLDIQDRDAAVSVNIAPEWRGRGVGPRALGCLSREAFGPLALLRMSAVVKRENAASRIAFERAGFTVVDTGGPLLHSSKARLHVVAAIQARMGSTRLPGKVLVSIAGRPTIQRIAERLAVCQELDAVAVSTSVENRDDAIADLAAHLGLVCVRGSETDLIERLGRTAARTGADALVRITADCPLVDPALVDRVVGVWRRSAGRLEYVSNVFPPTFPDGLDVEVLSRTVLERLDREVSDPFFRESLTAYVREHPAAFEIANVEHPEDLSRLRWTMDYPEDLAFVEAVYRRLGNQGEIFGMDDLLRLLEWSPELRDLNRCREDVTVERGIRGTGYHAALRARGQAP2" 38 | ``` 39 | 40 | ## Citation 41 | If you find our code or paper useful, please cite: 42 | ```bibtex 43 | @article{ProGen2, 44 | title={ProGen2: Exploring the Boundaries of Protein Language Models}, 45 | author={Nijkamp, Erik and Ruffolo, Jeffrey and Weinstein, Eli N. and Naik, Nikhil and Madani, Ali}, 46 | journal={arXiv}, 47 | year={2022} 48 | } 49 | ``` 50 | 51 | ## License 52 | Our code and models are BSD-3 licensed. See LICENSE.txt for details. 53 | 54 | ## Ethics 55 | Predicting the fitness of a protein sequence and capturing the distribution of natural proteins for generative purposes could be a powerful tool for protein design. If our technique or a future iteration thereof is adopted broadly, care should be taken in terms of the end use-cases of these designed samples and downstream effects to ensure safe, non-nefarious, and ethical applications. For projects in any domain, active oversight during project initiation, experimental optimization, and deployment phases should be put in place to ensure safe usage and limitation of unintended harmful effects. 56 | -------------------------------------------------------------------------------- /progen2/likelihood.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 
2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import time 8 | import random 9 | import argparse 10 | 11 | import torch 12 | 13 | from tokenizers import Tokenizer 14 | from models.progen.modeling_progen import ProGenForCausalLM 15 | 16 | 17 | 18 | ######################################################################## 19 | # util 20 | 21 | 22 | class print_time: 23 | def __init__(self, desc): 24 | self.desc = desc 25 | 26 | def __enter__(self): 27 | print(self.desc) 28 | self.t = time.time() 29 | 30 | def __exit__(self, type, value, traceback): 31 | print(f'{self.desc} took {time.time()-self.t:.02f}s') 32 | 33 | 34 | def set_env(): 35 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 36 | 37 | 38 | def set_seed(seed, deterministic=True): 39 | random.seed(seed) 40 | os.environ['PYTHONHASHSEED'] = str(seed) 41 | torch.manual_seed(seed) 42 | if torch.cuda.is_available(): 43 | torch.cuda.manual_seed(seed) 44 | torch.backends.cudnn.deterministic = deterministic 45 | torch.backends.cudnn.benchmark = not deterministic 46 | 47 | 48 | 49 | ######################################################################## 50 | # model 51 | 52 | 53 | def create_model(ckpt, fp16=True): 54 | if fp16: 55 | return ProGenForCausalLM.from_pretrained(ckpt, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True) 56 | else: 57 | return ProGenForCausalLM.from_pretrained(ckpt) 58 | 59 | 60 | def create_tokenizer_custom(file): 61 | with open(file, 'r') as f: 62 | return Tokenizer.from_str(f.read()) 63 | 64 | 65 | ######################################################################## 66 | # sample 67 | 68 | 69 | def sample(device, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id): 70 | 71 | with torch.no_grad(): 72 | input_ids = torch.tensor(tokenizer.encode(context).ids).view([1, 
-1]).to(device) 73 | tokens_batch = model.generate(input_ids, do_sample=True, temperature=temp, max_length=max_length, top_p=top_p, num_return_sequences=num_return_sequences, pad_token_id=pad_token_id) 74 | as_lists = lambda batch: [batch[i, ...].detach().cpu().numpy().tolist() for i in range(batch.shape[0])] 75 | return tokenizer.decode_batch(as_lists(tokens_batch)) 76 | 77 | 78 | ######################################################################## 79 | # likelihood 80 | 81 | def cross_entropy(logits, target, reduction='mean'): 82 | return torch.nn.functional.cross_entropy(input=logits, target=target, weight=None, size_average=None, reduce=None, reduction=reduction) 83 | 84 | 85 | def log_likelihood(logits, target, reduction='mean'): 86 | return -cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1), reduction=reduction) 87 | 88 | 89 | def log_likelihood_custom_1(logits, target, reduction='mean'): 90 | return -torch.nn.functional.nll_loss(input=torch.log_softmax(logits, dim=1), target=target, reduction=reduction) 91 | 92 | 93 | def log_likelihood_custom_2(logits, target, reduction='mean'): 94 | assert len(target.shape) == 1 95 | assert logits.shape[0] == target.shape[0] 96 | 97 | log_likelihood = 0.0 98 | n = logits.shape[0] 99 | for i in range(n): 100 | log_likelihood += torch.log_softmax(logits, dim=1)[i, target[i]] / (1. 
if reduction == 'sum' else n) 101 | return log_likelihood 102 | 103 | 104 | ######################################################################## 105 | # main 106 | 107 | 108 | def main(): 109 | 110 | # (0) constants 111 | 112 | models_151M = [ 'progen2-small' ] 113 | models_754M = [ 'progen2-medium', 'progen2-oas', 'progen2-base' ] 114 | models_2B = [ 'progen2-large', 'progen2-BFD90' ] 115 | models_6B = [ 'progen2-xlarge' ] 116 | models = models_151M + models_754M + models_2B + models_6B 117 | 118 | 119 | # (1) params 120 | 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('--model', type=str, choices=models, default='progen2-base') 123 | parser.add_argument('--device', type=str, default='cuda:0') 124 | parser.add_argument('--rng-seed', type=int, default=42) 125 | parser.add_argument('--rng-deterministic', default=True, type=lambda x: (str(x).lower() == 'true')) 126 | parser.add_argument('--fp16', default=True, type=lambda x: (str(x).lower() == 'true')) 127 | parser.add_argument('--context', type=str, default='1MGHGVSRPPVVTLRPAVLDDCPVLWRWRNDPETRQASVDEREIPVDTHTRWFEETLKRFDRKLFIVSADGVDAGMVRLDIQDRDAAVSVNIAPEWRGRGVGPRALGCLSREAFGPLALLRMSAVVKRENAASRIAFERAGFTVVDTGGPLLHSSKARLHVVAAIQARMGSTRLPGKVLVSIAGRPTIQRIAERLAVCQELDAVAVSTSVENRDDAIADLAAHLGLVCVRGSETDLIERLGRTAARTGADALVRITADCPLVDPALVDRVVGVWRRSAGRLEYVSNVFPPTFPDGLDVEVLSRTVLERLDREVSDPFFRESLTAYVREHPAAFEIANVEHPEDLSRLRWTMDYPEDLAFVEAVYRRLGNQGEIFGMDDLLRLLEWSPELRDLNRCREDVTVERGIRGTGYHAALRARGQAP2') 128 | parser.add_argument('--sanity', default=False, type=lambda x: (str(x).lower() == 'true')) 129 | args = parser.parse_args() 130 | 131 | 132 | # (2) preamble 133 | 134 | set_env() 135 | set_seed(args.rng_seed, deterministic=args.rng_deterministic) 136 | 137 | if not torch.cuda.is_available(): 138 | print('falling back to cpu') 139 | args.device = 'cpu' 140 | 141 | device = torch.device(args.device) 142 | ckpt = f'./checkpoints/{args.model}' 143 | 144 | if device.type == 'cpu': 145 | print('falling back to fp32') 146 | 
args.fp16 = False 147 | 148 | 149 | # (3) load 150 | 151 | with print_time('loading parameters'): 152 | model = create_model(ckpt=ckpt, fp16=args.fp16).to(device) 153 | 154 | 155 | with print_time('loading tokenizer'): 156 | tokenizer = create_tokenizer_custom(file='tokenizer.json') 157 | 158 | 159 | # (4) log likelihood 160 | 161 | def ce(tokens): 162 | with torch.no_grad(): 163 | with torch.cuda.amp.autocast(enabled=args.fp16): 164 | target = torch.tensor(tokenizer.encode(tokens).ids).to(device) 165 | logits = model(target, labels=target).logits 166 | 167 | # shift 168 | logits = logits[:-1, ...] 169 | target = target[1:] 170 | 171 | return cross_entropy(logits=logits, target=target).item() 172 | 173 | 174 | def ll(tokens, f=log_likelihood, reduction='mean'): 175 | with torch.no_grad(): 176 | with torch.cuda.amp.autocast(enabled=args.fp16): 177 | target = torch.tensor(tokenizer.encode(tokens).ids).to(device) 178 | logits = model(target, labels=target).logits 179 | 180 | # shift 181 | logits = logits[:-1, ...] 182 | target = target[1:] 183 | 184 | # remove terminals 185 | bos_token, eos_token = 3, 4 186 | if target[-1] in [bos_token, eos_token]: 187 | logits = logits[:-1, ...] 
188 | target = target[:-1] 189 | 190 | assert (target == bos_token).sum() == 0 191 | assert (target == eos_token).sum() == 0 192 | 193 | # remove unused logits 194 | first_token, last_token = 5, 29 195 | logits = logits[:, first_token:(last_token+1)] 196 | target = target - first_token 197 | 198 | assert logits.shape[1] == (last_token - first_token + 1) 199 | 200 | return f(logits=logits, target=target, reduction=reduction).item() 201 | 202 | 203 | # (5) sanity 204 | 205 | if args.sanity: 206 | 207 | with print_time('sanity cross-entropy'): 208 | 209 | x_uniref90bfd30 = '2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1' 210 | x_oas = '1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPWKGLEYVSAISSNGGSTYYANSVKGRFTISRDNSKNTLYLQMGSLRAEDMAVYYCARDESGYSYGWGYYFDYWGQGTLVTVSS2' 211 | x_bfd90 = '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2' 212 | 213 | checkpoint_x_ce = { 214 | 'progen2-small': (x_uniref90bfd30, 2.4), 215 | 'progen2-medium': (x_uniref90bfd30, 1.9), 216 | 'progen2-base': (x_uniref90bfd30, 1.9), 217 | 'progen2-large': (x_uniref90bfd30, 1.8), 218 | 'progen2-xlarge': (x_uniref90bfd30, 1.0), 219 | 'progen2-oas': (x_oas, 0.3), 220 | 'progen2-BFD90': (x_bfd90, 1.3), 221 | } 222 | 223 | ce_eval = ce(checkpoint_x_ce[args.model][0]) 224 | ce_target = checkpoint_x_ce[args.model][1] 225 | 226 | print(ce_target, ce_eval, abs(ce_eval - ce_target)) 227 | 228 | assert abs(ce_eval - ce_target) < 0.1 229 | 230 | 231 | with print_time('sanity log-likelihood'): 232 | 233 | x_data = 
'2PAQGRARLAAHYGTGRIGREVTVDERCRNLDRLEPSWELLRLLDDMGFIEGQNGLRRYVAEVFALDEPYDMTWRLRSLDEPHEVNAIEFAAPHERVYATLSERFFPDSVERDLRELVTRSLVEVDLGDPFTPPFVNSVYELRGASRRWVGVVRDVLAPDVLPCDATIRVLADAGTRAATRGLREILDTESGRVCVLGLHAALDAIADDRNEVSTSVAVADLEQCVALREAIRQITPRGAISVLVKGPLRTSGMRAQIAAVVHLRAKSSHLLPGGTDVVTFGAREFAIRSAANERKVVASMRLLALPGFAERSLCGLARPGVGRGRWEPAINVSVAADRDQIDLRVMGADVGDASVIFLKRDFRKLTEEFWRTHTDVPIEREDVSAQRTEPDNRWRWLVPCDDLVAPRLTVVPPRSVGHGM1' 234 | 235 | ll_0 = ll(x_data, f=log_likelihood, reduction='mean') 236 | ll_1 = ll(x_data, f=log_likelihood_custom_1, reduction='mean') 237 | ll_2 = ll(x_data, f=log_likelihood_custom_2, reduction='mean') 238 | 239 | print(f'll_0={ll_0}') 240 | print(f'll_1={ll_1}') 241 | print(f'll_2={ll_2}') 242 | 243 | assert abs(ll_0 - ll_1) < 1e-2 244 | assert abs(ll_0 - ll_2) < 1e-2 245 | 246 | 247 | with print_time('sanity model'): 248 | 249 | alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] 250 | x_data = '2PAQGRARLAAHYGTGRIGREVTVDERCRNLDRLEPSWELLRLLDDMGFIEGQNGLRRYVAEVFALDEPYDMTWRLRSLDEPHEVNAIEFAAPHERVYATLSERFFPDSVERDLRELVTRSLVEVDLGDPFTPPFVNSVYELRGASRRWVGVVRDVLAPDVLPCDATIRVLADAGTRAATRGLREILDTESGRVCVLGLHAALDAIADDRNEVSTSVAVADLEQCVALREAIRQITPRGAISVLVKGPLRTSGMRAQIAAVVHLRAKSSHLLPGGTDVVTFGAREFAIRSAANERKVVASMRLLALPGFAERSLCGLARPGVGRGRWEPAINVSVAADRDQIDLRVMGADVGDASVIFLKRDFRKLTEEFWRTHTDVPIEREDVSAQRTEPDNRWRWLVPCDDLVAPRLTVVPPRSVGHGM1' 251 | x_random = '2' + ''.join([random.choice(alphabet) for _ in range(len(x_data)-2)]) + '1' 252 | x_perturb = x_random[:64] + x_data[len(x_random[:64]):] 253 | 254 | print(x_data) 255 | print(x_perturb) 256 | print(x_random) 257 | 258 | assert x_data != x_perturb 259 | 260 | ll_x_data = ll(x_data) 261 | ll_x_random = ll(x_random) 262 | ll_x_perturb = ll(x_perturb) 263 | 264 | print(f'll_x_data={ll_x_data}') 265 | print(f'll_x_random={ll_x_random}') 266 | print(f'll_x_perturb={ll_x_perturb}') 267 | 268 | assert ll_x_data > ll_x_random 269 | assert 
ll_x_data > ll_x_perturb 270 | 271 | ''' 272 | # (6) likelihood 273 | 274 | with print_time('log-likelihood (left-to-right)'): 275 | 276 | ll_sum = ll(tokens=args.context, reduction='sum') 277 | ll_mean = ll(tokens=args.context, reduction='mean') 278 | 279 | print(f'll_sum={ll_sum}') 280 | print(f'll_mean={ll_mean}') 281 | ''' 282 | 283 | # (7) likelihood 284 | 285 | with print_time('log-likelihood (left-to-right, right-to-left)'): 286 | 287 | reverse = lambda s: s[::-1] 288 | 289 | ll_lr_sum = ll(tokens=args.context, reduction='sum') 290 | ll_rl_sum = ll(tokens=reverse(args.context), reduction='sum') 291 | 292 | ll_lr_mean = ll(tokens=args.context, reduction='mean') 293 | ll_rl_mean = ll(tokens=reverse(args.context), reduction='mean') 294 | 295 | ll_sum = .5 * (ll_lr_sum + ll_rl_sum) 296 | ll_mean = .5 * (ll_lr_mean + ll_rl_mean) 297 | 298 | print(f'll_sum={(ll_sum)}') 299 | print(f'll_mean={ll_mean}') 300 | 301 | 302 | 303 | if __name__ == '__main__': 304 | main() 305 | print('done.') 306 | -------------------------------------------------------------------------------- /progen2/models/progen/configuration_progen.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Modified configuration implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/configuration_gptj.py 17 | 18 | from transformers.configuration_utils import PretrainedConfig 19 | from transformers.utils import logging 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | class ProGenConfig(PretrainedConfig): 25 | model_type = "progen" 26 | 27 | def __init__( 28 | self, 29 | vocab_size=50400, 30 | n_positions=2048, 31 | n_ctx=2048, 32 | n_embd=4096, 33 | n_layer=28, 34 | n_head=16, 35 | rotary_dim=64, 36 | n_inner=None, 37 | activation_function="gelu_new", 38 | resid_pdrop=0.0, 39 | embd_pdrop=0.0, 40 | attn_pdrop=0.0, 41 | layer_norm_epsilon=1e-5, 42 | initializer_range=0.02, 43 | scale_attn_weights=True, 44 | gradient_checkpointing=False, 45 | use_cache=True, 46 | bos_token_id=50256, 47 | eos_token_id=50256, 48 | **kwargs 49 | ): 50 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 51 | 52 | self.vocab_size = vocab_size 53 | self.n_ctx = n_ctx 54 | self.n_positions = n_positions 55 | self.n_embd = n_embd 56 | self.n_layer = n_layer 57 | self.n_head = n_head 58 | self.n_inner = n_inner 59 | self.rotary_dim = rotary_dim 60 | self.activation_function = activation_function 61 | self.resid_pdrop = resid_pdrop 62 | self.embd_pdrop = embd_pdrop 63 | self.attn_pdrop = attn_pdrop 64 | self.layer_norm_epsilon = layer_norm_epsilon 65 | self.initializer_range = initializer_range 66 | self.gradient_checkpointing = gradient_checkpointing 67 | self.scale_attn_weights = scale_attn_weights 68 | self.use_cache = use_cache 69 | 70 | self.bos_token_id = bos_token_id 71 | self.eos_token_id = eos_token_id 72 | 73 | @property 74 | def max_position_embeddings(self): 75 | return self.n_positions 76 | 77 | @property 78 | def hidden_size(self): 79 | return self.n_embd 80 | 81 | @property 82 | def num_attention_heads(self): 83 | return self.n_head 84 | 85 | @property 86 | def 
num_hidden_layers(self): 87 | return self.n_layer 88 | -------------------------------------------------------------------------------- /progen2/models/progen/modeling_progen.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Modified forward-pass implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py 17 | 18 | from typing import Tuple 19 | 20 | import numpy as np 21 | 22 | import torch 23 | import torch.utils.checkpoint 24 | from torch import nn 25 | from torch.nn import CrossEntropyLoss 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 29 | from transformers.modeling_utils import PreTrainedModel 30 | from transformers.utils import logging 31 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 32 | from .configuration_progen import ProGenConfig 33 | 34 | 35 | logger = logging.get_logger(__name__) 36 | 37 | 38 | def fixed_pos_embedding(x, seq_dim=1, seq_len=None): 39 | dim = x.shape[-1] 40 | if seq_len is None: 41 | seq_len = x.shape[seq_dim] 42 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) 43 | sinusoid_inp = 
torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() 44 | return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) 45 | 46 | 47 | def rotate_every_two(x): 48 | x1 = x[:, :, :, ::2] 49 | x2 = x[:, :, :, 1::2] 50 | x = torch.stack((-x2, x1), axis=-1) 51 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') 52 | 53 | 54 | def apply_rotary_pos_emb(x, sincos, offset=0): 55 | sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) 56 | # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) 57 | return (x * cos) + (rotate_every_two(x) * sin) 58 | 59 | 60 | class ProGenAttention(nn.Module): 61 | def __init__(self, config): 62 | super().__init__() 63 | 64 | max_positions = config.max_position_embeddings 65 | self.register_buffer( 66 | "bias", 67 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 68 | 1, 1, max_positions, max_positions 69 | ), 70 | ) 71 | self.register_buffer("masked_bias", torch.tensor(-1e9)) 72 | 73 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 74 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 75 | 76 | self.embed_dim = config.hidden_size 77 | self.num_attention_heads = config.num_attention_heads 78 | self.head_dim = self.embed_dim // self.num_attention_heads 79 | if self.head_dim * self.num_attention_heads != self.embed_dim: 80 | raise ValueError( 81 | f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
82 | ) 83 | self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) 84 | self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False) 85 | 86 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 87 | self.rotary_dim = None 88 | if config.rotary_dim is not None: 89 | self.rotary_dim = config.rotary_dim 90 | 91 | def _split_heads(self, x, n_head, dim_head, mp_num): 92 | reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head)) 93 | reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) 94 | return reshaped 95 | 96 | def _merge_heads(self, tensor, num_attention_heads, attn_head_size): 97 | """ 98 | Merges attn_head_size dim and num_attn_heads dim into n_ctx 99 | """ 100 | if len(tensor.shape) == 5: 101 | tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() 102 | elif len(tensor.shape) == 4: 103 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 104 | else: 105 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") 106 | new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) 107 | return tensor.view(new_shape) 108 | 109 | def _attn( 110 | self, 111 | query, 112 | key, 113 | value, 114 | attention_mask=None, 115 | head_mask=None, 116 | ): 117 | 118 | # compute causal mask from causal mask buffer 119 | query_length, key_length = query.size(-2), key.size(-2) 120 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] 121 | 122 | # Keep the attention weights computation in fp32 to avoid overflow issues 123 | query = query.to(torch.float32) 124 | key = key.to(torch.float32) 125 | 126 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 127 | 128 | attn_weights = attn_weights / self.scale_attn 129 | attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) 130 | 131 | if attention_mask is not None: 132 | # Apply the attention mask 133 | 
attn_weights = attn_weights + attention_mask 134 | 135 | attn_weights = nn.Softmax(dim=-1)(attn_weights) 136 | attn_weights = attn_weights.to(value.dtype) 137 | attn_weights = self.attn_dropout(attn_weights) 138 | 139 | # Mask heads if we want to 140 | if head_mask is not None: 141 | attn_weights = attn_weights * head_mask 142 | 143 | attn_output = torch.matmul(attn_weights, value) 144 | 145 | return attn_output, attn_weights 146 | 147 | def forward( 148 | self, 149 | hidden_states, 150 | attention_mask=None, 151 | layer_past=None, 152 | head_mask=None, 153 | use_cache=False, 154 | output_attentions=False, 155 | ): 156 | 157 | qkv = self.qkv_proj(hidden_states) 158 | # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic 159 | # mp_num = 4 160 | mp_num = 8 161 | qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) 162 | 163 | local_dim = self.head_dim * self.num_attention_heads // mp_num 164 | query, value, key = torch.split(qkv_split, local_dim, dim=-1) 165 | query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) 166 | key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) 167 | 168 | value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) 169 | value = value.permute(0, 2, 1, 3) 170 | 171 | seq_len = key.shape[1] 172 | offset = 0 173 | 174 | if layer_past is not None: 175 | offset = layer_past[0].shape[-2] 176 | seq_len += offset 177 | 178 | if self.rotary_dim is not None: 179 | k_rot = key[:, :, :, : self.rotary_dim] 180 | k_pass = key[:, :, :, self.rotary_dim :] 181 | 182 | q_rot = query[:, :, :, : self.rotary_dim] 183 | q_pass = query[:, :, :, self.rotary_dim :] 184 | 185 | sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) 186 | k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) 187 | q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) 188 | 189 | key = torch.cat([k_rot, k_pass], dim=-1) 190 | query = 
torch.cat([q_rot, q_pass], dim=-1) 191 | else: 192 | sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) 193 | key = apply_rotary_pos_emb(key, sincos, offset=offset) 194 | query = apply_rotary_pos_emb(query, sincos, offset=offset) 195 | 196 | key = key.permute(0, 2, 1, 3) 197 | query = query.permute(0, 2, 1, 3) 198 | 199 | if layer_past is not None: 200 | past_key = layer_past[0] 201 | past_value = layer_past[1] 202 | key = torch.cat((past_key, key), dim=-2) 203 | value = torch.cat((past_value, value), dim=-2) 204 | 205 | if use_cache is True: 206 | present = (key, value) 207 | else: 208 | present = None 209 | 210 | # compute self-attention: V x Softmax(QK^T) 211 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 212 | 213 | attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) 214 | 215 | attn_output = self.out_proj(attn_output) 216 | attn_output = self.resid_dropout(attn_output) 217 | 218 | outputs = (attn_output, present) 219 | if output_attentions: 220 | outputs += (attn_weights,) 221 | 222 | return outputs # a, present, (attentions) 223 | 224 | 225 | class ProGenMLP(nn.Module): 226 | def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim 227 | super().__init__() 228 | embed_dim = config.n_embd 229 | 230 | self.fc_in = nn.Linear(embed_dim, intermediate_size) 231 | self.fc_out = nn.Linear(intermediate_size, embed_dim) 232 | 233 | self.act = ACT2FN[config.activation_function] 234 | self.dropout = nn.Dropout(config.resid_pdrop) 235 | 236 | def forward(self, hidden_states): 237 | hidden_states = self.fc_in(hidden_states) 238 | hidden_states = self.act(hidden_states) 239 | hidden_states = self.fc_out(hidden_states) 240 | hidden_states = self.dropout(hidden_states) 241 | return hidden_states 242 | 243 | 244 | class ProGenBlock(nn.Module): 245 | def __init__(self, config): 246 | super().__init__() 247 | inner_dim = config.n_inner if config.n_inner is not None else 
4 * config.n_embd 248 | self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 249 | self.attn = ProGenAttention(config) 250 | self.mlp = ProGenMLP(inner_dim, config) 251 | 252 | def forward( 253 | self, 254 | hidden_states, 255 | layer_past=None, 256 | attention_mask=None, 257 | head_mask=None, 258 | use_cache=False, 259 | output_attentions=False, 260 | ): 261 | residual = hidden_states 262 | hidden_states = self.ln_1(hidden_states) 263 | attn_outputs = self.attn( 264 | hidden_states, 265 | layer_past=layer_past, 266 | attention_mask=attention_mask, 267 | head_mask=head_mask, 268 | use_cache=use_cache, 269 | output_attentions=output_attentions, 270 | ) 271 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 272 | outputs = attn_outputs[1:] 273 | 274 | feed_forward_hidden_states = self.mlp(hidden_states) 275 | hidden_states = attn_output + feed_forward_hidden_states + residual 276 | 277 | if use_cache: 278 | outputs = (hidden_states,) + outputs 279 | else: 280 | outputs = (hidden_states,) + outputs[1:] 281 | 282 | return outputs # hidden_states, present, (attentions) 283 | 284 | 285 | class ProGenPreTrainedModel(PreTrainedModel): 286 | """ 287 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 288 | models. 
289 | """ 290 | 291 | config_class = ProGenConfig 292 | base_model_prefix = "transformer" 293 | is_parallelizable = True 294 | 295 | def __init__(self, *inputs, **kwargs): 296 | super().__init__(*inputs, **kwargs) 297 | 298 | def _init_weights(self, module): 299 | """Initialize the weights.""" 300 | if isinstance(module, (nn.Linear,)): 301 | # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization 302 | # cf https://github.com/pytorch/pytorch/pull/5617 303 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 304 | if module.bias is not None: 305 | module.bias.data.zero_() 306 | elif isinstance(module, nn.Embedding): 307 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 308 | if module.padding_idx is not None: 309 | module.weight.data[module.padding_idx].zero_() 310 | elif isinstance(module, nn.LayerNorm): 311 | module.bias.data.zero_() 312 | module.weight.data.fill_(1.0) 313 | 314 | 315 | class ProGenModel(ProGenPreTrainedModel): 316 | def __init__(self, config): 317 | super().__init__(config) 318 | 319 | self.embed_dim = config.n_embd 320 | self.vocab_size = config.vocab_size 321 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 322 | self.drop = nn.Dropout(config.embd_pdrop) 323 | self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)]) 324 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 325 | self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) 326 | self.init_weights() 327 | 328 | # Model parallel 329 | self.model_parallel = False 330 | self.device_map = None 331 | 332 | 333 | def parallelize(self, device_map=None): 334 | # Check validity of device_map 335 | self.device_map = ( 336 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 337 | ) 338 | assert_device_map(self.device_map, len(self.h)) 339 | self.model_parallel = True 340 | 
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 341 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 342 | self.wte = self.wte.to(self.first_device) 343 | # Load onto devices 344 | for k, v in self.device_map.items(): 345 | for block in v: 346 | cuda_device = "cuda:" + str(k) 347 | self.h[block] = self.h[block].to(cuda_device) 348 | # ln_f to last 349 | self.ln_f = self.ln_f.to(self.last_device) 350 | 351 | 352 | def deparallelize(self): 353 | self.model_parallel = False 354 | self.device_map = None 355 | self.first_device = "cpu" 356 | self.last_device = "cpu" 357 | self.wte = self.wte.to("cpu") 358 | for index in range(len(self.h)): 359 | self.h[index] = self.h[index].to("cpu") 360 | self.ln_f = self.ln_f.to("cpu") 361 | torch.cuda.empty_cache() 362 | 363 | def get_input_embeddings(self): 364 | return self.wte 365 | 366 | def set_input_embeddings(self, new_embeddings): 367 | self.wte = new_embeddings 368 | 369 | def forward( 370 | self, 371 | input_ids=None, 372 | past_key_values=None, 373 | attention_mask=None, 374 | token_type_ids=None, 375 | position_ids=None, 376 | head_mask=None, 377 | inputs_embeds=None, 378 | use_cache=None, 379 | output_attentions=None, 380 | output_hidden_states=None, 381 | return_dict=None, 382 | ): 383 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 384 | output_hidden_states = ( 385 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 386 | ) 387 | use_cache = use_cache if use_cache is not None else self.config.use_cache 388 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 389 | 390 | if input_ids is not None and inputs_embeds is not None: 391 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 392 | elif input_ids is not None: 393 | input_shape = input_ids.size() 394 | input_ids = 
input_ids.view(-1, input_shape[-1]) 395 | batch_size = input_ids.shape[0] 396 | elif inputs_embeds is not None: 397 | input_shape = inputs_embeds.size()[:-1] 398 | batch_size = inputs_embeds.shape[0] 399 | else: 400 | raise ValueError("You have to specify either input_ids or inputs_embeds") 401 | 402 | device = input_ids.device if input_ids is not None else inputs_embeds.device 403 | 404 | if token_type_ids is not None: 405 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 406 | 407 | if position_ids is not None: 408 | position_ids = position_ids.view(-1, input_shape[-1]) 409 | 410 | if past_key_values is None: 411 | past_length = 0 412 | past_key_values = tuple([None] * len(self.h)) 413 | else: 414 | past_length = past_key_values[0][0].size(-2) 415 | 416 | if position_ids is None: 417 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 418 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 419 | 420 | # Attention mask. 421 | if attention_mask is not None: 422 | assert batch_size > 0, "batch_size has to be defined and > 0" 423 | attention_mask = attention_mask.view(batch_size, -1) 424 | # We create a 3D attention mask from a 2D tensor mask. 425 | # Sizes are [batch_size, 1, 1, to_seq_length] 426 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 427 | # this attention mask is more simple than the triangular masking of causal attention 428 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 429 | attention_mask = attention_mask[:, None, None, :] 430 | 431 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 432 | # masked positions, this operation will create a tensor which is 0.0 for 433 | # positions we want to attend and -10000.0 for masked positions. 434 | # Since we are adding it to the raw scores before the softmax, this is 435 | # effectively the same as removing these entirely. 
436 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 437 | attention_mask = (1.0 - attention_mask) * -10000.0 438 | 439 | # Prepare head mask if needed 440 | # 1.0 in head_mask indicate we keep the head 441 | # attention_probs has shape bsz x num_attention_heads x N x N 442 | # head_mask has shape n_layer x batch x num_attention_heads x N x N 443 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 444 | 445 | if inputs_embeds is None: 446 | inputs_embeds = self.wte(input_ids) 447 | 448 | hidden_states = inputs_embeds 449 | 450 | if token_type_ids is not None: 451 | token_type_embeds = self.wte(token_type_ids) 452 | hidden_states = hidden_states + token_type_embeds 453 | 454 | hidden_states = self.drop(hidden_states) 455 | 456 | output_shape = input_shape + (hidden_states.size(-1),) 457 | 458 | presents = () if use_cache else None 459 | all_self_attentions = () if output_attentions else None 460 | all_hidden_states = () if output_hidden_states else None 461 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 462 | 463 | # Model parallel 464 | if self.model_parallel: 465 | torch.cuda.set_device(hidden_states.device) 466 | # Ensure layer_past is on same device as hidden_states (might not be correct) 467 | if layer_past is not None: 468 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 469 | # Ensure that attention_mask is always on the same device as hidden_states 470 | if attention_mask is not None: 471 | attention_mask = attention_mask.to(hidden_states.device) 472 | if isinstance(head_mask, torch.Tensor): 473 | head_mask = head_mask.to(hidden_states.device) 474 | if output_hidden_states: 475 | all_hidden_states = all_hidden_states + (hidden_states,) 476 | 477 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 478 | 479 | if use_cache: 480 | logger.warning( 481 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. 
Setting " 482 | "`use_cache=False`..." 483 | ) 484 | use_cache = False 485 | 486 | def create_custom_forward(module): 487 | def custom_forward(*inputs): 488 | # None for past_key_value 489 | return module(*inputs, use_cache, output_attentions) 490 | 491 | return custom_forward 492 | 493 | outputs = torch.utils.checkpoint.checkpoint( 494 | create_custom_forward(block), 495 | hidden_states, 496 | None, 497 | attention_mask, 498 | head_mask[i], 499 | ) 500 | else: 501 | outputs = block( 502 | hidden_states, 503 | layer_past=layer_past, 504 | attention_mask=attention_mask, 505 | head_mask=head_mask[i], 506 | use_cache=use_cache, 507 | output_attentions=output_attentions, 508 | ) 509 | 510 | hidden_states = outputs[0] 511 | if use_cache is True: 512 | presents = presents + (outputs[1],) 513 | 514 | if output_attentions: 515 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 516 | 517 | # Model Parallel: If it's the last layer for that device, put things on the next device 518 | if self.model_parallel: 519 | for k, v in self.device_map.items(): 520 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 521 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 522 | 523 | hidden_states = self.ln_f(hidden_states) 524 | 525 | hidden_states = hidden_states.view(*output_shape) 526 | # Add last hidden state 527 | if output_hidden_states: 528 | all_hidden_states = all_hidden_states + (hidden_states,) 529 | 530 | if not return_dict: 531 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 532 | 533 | return BaseModelOutputWithPast( 534 | last_hidden_state=hidden_states, 535 | past_key_values=presents, 536 | hidden_states=all_hidden_states, 537 | attentions=all_self_attentions, 538 | ) 539 | 540 | 541 | class ProGenForCausalLM(ProGenPreTrainedModel): 542 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] 543 | 544 | def 
__init__(self, config): 545 | super().__init__(config) 546 | self.transformer = ProGenModel(config) 547 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size) 548 | self.init_weights() 549 | 550 | # Model parallel 551 | self.model_parallel = False 552 | self.device_map = None 553 | 554 | def parallelize(self, device_map=None): 555 | self.device_map = ( 556 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 557 | if device_map is None 558 | else device_map 559 | ) 560 | assert_device_map(self.device_map, len(self.transformer.h)) 561 | self.transformer.parallelize(self.device_map) 562 | self.lm_head = self.lm_head.to(self.transformer.first_device) 563 | self.model_parallel = True 564 | 565 | def deparallelize(self): 566 | self.transformer.deparallelize() 567 | self.transformer = self.transformer.to("cpu") 568 | self.lm_head = self.lm_head.to("cpu") 569 | self.model_parallel = False 570 | torch.cuda.empty_cache() 571 | 572 | def get_output_embeddings(self): 573 | return None 574 | 575 | def set_output_embeddings(self, new_embeddings): 576 | return 577 | 578 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 579 | token_type_ids = kwargs.get("token_type_ids", None) 580 | # only last token for inputs_ids if past is defined in kwargs 581 | if past: 582 | input_ids = input_ids[:, -1].unsqueeze(-1) 583 | if token_type_ids is not None: 584 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 585 | 586 | attention_mask = kwargs.get("attention_mask", None) 587 | position_ids = kwargs.get("position_ids", None) 588 | 589 | if attention_mask is not None and position_ids is None: 590 | # create position_ids on the fly for batch generation 591 | position_ids = attention_mask.long().cumsum(-1) - 1 592 | position_ids.masked_fill_(attention_mask == 0, 1) 593 | if past: 594 | position_ids = position_ids[:, -1].unsqueeze(-1) 595 | else: 596 | position_ids = None 597 | return { 598 | "input_ids": input_ids, 599 | 
"past_key_values": past, 600 | "use_cache": kwargs.get("use_cache"), 601 | "position_ids": position_ids, 602 | "attention_mask": attention_mask, 603 | "token_type_ids": token_type_ids, 604 | } 605 | 606 | def forward( 607 | self, 608 | input_ids=None, 609 | past_key_values=None, 610 | attention_mask=None, 611 | token_type_ids=None, 612 | position_ids=None, 613 | head_mask=None, 614 | inputs_embeds=None, 615 | labels=None, 616 | use_cache=None, 617 | output_attentions=None, 618 | output_hidden_states=None, 619 | return_dict=None, 620 | ): 621 | r""" 622 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 623 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 624 | ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to 625 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 626 | """ 627 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 628 | 629 | transformer_outputs = self.transformer( 630 | input_ids, 631 | past_key_values=past_key_values, 632 | attention_mask=attention_mask, 633 | token_type_ids=token_type_ids, 634 | position_ids=position_ids, 635 | head_mask=head_mask, 636 | inputs_embeds=inputs_embeds, 637 | use_cache=use_cache, 638 | output_attentions=output_attentions, 639 | output_hidden_states=output_hidden_states, 640 | return_dict=return_dict, 641 | ) 642 | hidden_states = transformer_outputs[0] 643 | 644 | # Set device for model parallelism 645 | if self.model_parallel: 646 | torch.cuda.set_device(self.transformer.first_device) 647 | hidden_states = hidden_states.to(self.lm_head.weight.device) 648 | 649 | # make sure sampling in fp16 works correctly and 650 | # compute loss in fp32 to match with mesh-tf version 651 | # 
https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 652 | lm_logits = self.lm_head(hidden_states).to(torch.float32) 653 | 654 | loss = None 655 | if labels is not None: 656 | # Shift so that tokens < n predict n 657 | shift_logits = lm_logits[..., :-1, :].contiguous() 658 | shift_labels = labels[..., 1:].contiguous() 659 | # Flatten the tokens 660 | loss_fct = CrossEntropyLoss() 661 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 662 | 663 | loss = loss.to(hidden_states.dtype) 664 | 665 | if not return_dict: 666 | output = (lm_logits,) + transformer_outputs[1:] 667 | return ((loss,) + output) if loss is not None else output 668 | 669 | return CausalLMOutputWithPast( 670 | loss=loss, 671 | logits=lm_logits, 672 | past_key_values=transformer_outputs.past_key_values, 673 | hidden_states=transformer_outputs.hidden_states, 674 | attentions=transformer_outputs.attentions, 675 | ) 676 | 677 | @staticmethod 678 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 679 | """ 680 | This function is used to re-order the :obj:`past_key_values` cache if 681 | :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is 682 | called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
683 | """ 684 | return tuple( 685 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 686 | for layer_past in past 687 | ) 688 | -------------------------------------------------------------------------------- /progen2/requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.9.0+cu111 3 | transformers==4.16.2 4 | tokenizers==0.10.3 -------------------------------------------------------------------------------- /progen2/sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import time 8 | import random 9 | import argparse 10 | 11 | import torch 12 | 13 | from tokenizers import Tokenizer 14 | from models.progen.modeling_progen import ProGenForCausalLM 15 | 16 | 17 | 18 | ######################################################################## 19 | # util 20 | 21 | 22 | class print_time: 23 | def __init__(self, desc): 24 | self.desc = desc 25 | 26 | def __enter__(self): 27 | print(self.desc) 28 | self.t = time.time() 29 | 30 | def __exit__(self, type, value, traceback): 31 | print(f'{self.desc} took {time.time()-self.t:.02f}s') 32 | 33 | 34 | def set_env(): 35 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 36 | 37 | 38 | def set_seed(seed, deterministic=True): 39 | random.seed(seed) 40 | os.environ['PYTHONHASHSEED'] = str(seed) 41 | torch.manual_seed(seed) 42 | if torch.cuda.is_available(): 43 | torch.cuda.manual_seed(seed) 44 | torch.backends.cudnn.deterministic = deterministic 45 | torch.backends.cudnn.benchmark = not deterministic 46 | 47 | 48 | 49 | 
######################################################################## 50 | # model 51 | 52 | 53 | def create_model(ckpt, fp16=True): 54 | if fp16: 55 | return ProGenForCausalLM.from_pretrained(ckpt, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True) 56 | else: 57 | return ProGenForCausalLM.from_pretrained(ckpt) 58 | 59 | 60 | def create_tokenizer_custom(file): 61 | with open(file, 'r') as f: 62 | return Tokenizer.from_str(f.read()) 63 | 64 | 65 | ######################################################################## 66 | # sample 67 | 68 | 69 | def sample(device, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id): 70 | 71 | with torch.no_grad(): 72 | input_ids = torch.tensor(tokenizer.encode(context).ids).view([1, -1]).to(device) 73 | tokens_batch = model.generate(input_ids, do_sample=True, temperature=temp, max_length=max_length, top_p=top_p, num_return_sequences=num_return_sequences, pad_token_id=pad_token_id) 74 | as_lists = lambda batch: [batch[i, ...].detach().cpu().numpy().tolist() for i in range(batch.shape[0])] 75 | return tokenizer.decode_batch(as_lists(tokens_batch)) 76 | 77 | 78 | def truncate(sample, terminals): 79 | pos = [] 80 | for terminal in terminals: 81 | find_pos = sample.find(terminal, 1) 82 | if find_pos != -1: 83 | pos.append(find_pos) 84 | if len(pos) > 0: 85 | return sample[:(min(pos)+1)] 86 | else: 87 | return sample 88 | 89 | 90 | def cross_entropy(logits, target, reduction='mean'): 91 | return torch.nn.functional.cross_entropy(input=logits, target=target, weight=None, size_average=None, reduce=None, reduction=reduction) 92 | 93 | 94 | 95 | ######################################################################## 96 | # main 97 | 98 | 99 | def main(): 100 | 101 | # (0) constants 102 | 103 | models_151M = [ 'progen2-small' ] 104 | models_754M = [ 'progen2-medium', 'progen2-oas', 'progen2-base' ] 105 | models_2B = [ 'progen2-large', 'progen2-BFD90' ] 106 | models_6B = [ 
'progen2-xlarge' ] 107 | models = models_151M + models_754M + models_2B + models_6B 108 | 109 | # (1) params 110 | 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('--model', type=str, choices=models, default='progen2-large') 113 | parser.add_argument('--device', type=str, default='cuda:0') 114 | parser.add_argument('--rng-seed', type=int, default=42) 115 | parser.add_argument('--rng-deterministic', default=True, type=lambda x: (str(x).lower() == 'true')) 116 | parser.add_argument('--p', type=float, default=0.95) 117 | parser.add_argument('--t', type=float, default=0.2) 118 | parser.add_argument('--max-length', type=int, default=256) 119 | parser.add_argument('--num-samples', type=int, default=1) 120 | parser.add_argument('--fp16', default=True, type=lambda x: (str(x).lower() == 'true')) 121 | parser.add_argument('--context', type=str, default='1') 122 | parser.add_argument('--sanity', default=True, type=lambda x: (str(x).lower() == 'true')) 123 | args = parser.parse_args() 124 | 125 | 126 | # (2) preamble 127 | 128 | set_env() 129 | set_seed(args.rng_seed, deterministic=args.rng_deterministic) 130 | 131 | if not torch.cuda.is_available(): 132 | print('falling back to cpu') 133 | args.device = 'cpu' 134 | 135 | device = torch.device(args.device) 136 | ckpt = f'./checkpoints/{args.model}' 137 | 138 | if device.type == 'cpu': 139 | print('falling back to fp32') 140 | args.fp16 = False 141 | 142 | # (3) load 143 | 144 | with print_time('loading parameters'): 145 | model = create_model(ckpt=ckpt, fp16=args.fp16).to(device) 146 | 147 | 148 | with print_time('loading tokenizer'): 149 | tokenizer = create_tokenizer_custom(file='tokenizer.json') 150 | 151 | # (4) sanity 152 | 153 | if args.sanity: 154 | 155 | with print_time('sanity cross-entropy'): 156 | 157 | def ce(tokens): 158 | with torch.no_grad(): 159 | with torch.cuda.amp.autocast(enabled=args.fp16): 160 | target = torch.tensor(tokenizer.encode(tokens).ids).to(device) 161 | logits = model(target, 
labels=target).logits 162 | 163 | # shift 164 | logits = logits[:-1, ...] 165 | target = target[1:] 166 | 167 | return cross_entropy(logits=logits, target=target).item() 168 | 169 | x_uniref90bfd30 = '2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1' 170 | x_oas = '1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPWKGLEYVSAISSNGGSTYYANSVKGRFTISRDNSKNTLYLQMGSLRAEDMAVYYCARDESGYSYGWGYYFDYWGQGTLVTVSS2' 171 | x_bfd90 = '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2' 172 | 173 | checkpoint_x_ce = { 174 | 'progen2-small': (x_uniref90bfd30, 2.4), 175 | 'progen2-medium': (x_uniref90bfd30, 1.9), 176 | 'progen2-base': (x_uniref90bfd30, 1.9), 177 | 'progen2-large': (x_uniref90bfd30, 1.8), 178 | 'progen2-xlarge': (x_uniref90bfd30, 1.0), 179 | 'progen2-oas': (x_oas, 0.3), 180 | 'progen2-BFD90': (x_bfd90, 1.3), 181 | } 182 | 183 | ce_eval = ce(checkpoint_x_ce[args.model][0]) 184 | ce_target = checkpoint_x_ce[args.model][1] 185 | 186 | print(ce_target, ce_eval, abs(ce_eval - ce_target)) 187 | 188 | assert abs(ce_eval - ce_target) < 0.1 189 | 190 | # (5) sample 191 | 192 | with print_time('sampling'): 193 | completions = sample(device=device, model=model, tokenizer=tokenizer, context=args.context, pad_token_id=tokenizer.encode('<|pad|>').ids[0], num_return_sequences=args.num_samples, temp=args.t, top_p=args.p, max_length=args.max_length) 194 | truncations = [truncate(completion, terminals=['1', '2']) for completion in completions] 195 | 196 | print(args.context) 197 | 198 | for (i, truncation) in enumerate(truncations): 199 | 200 | print() 201 | print(i) 202 | print(truncation) 203 | 204 | 205 | 206 | if __name__ == '__main__': 207 | main() 208 | print('done.') 209 | 
-------------------------------------------------------------------------------- /progen2/tokenizer.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": null, 4 | "padding": null, 5 | "added_tokens": [ 6 | { 7 | "id": 0, 8 | "special": true, 9 | "content": "<|pad|>", 10 | "single_word": false, 11 | "lstrip": false, 12 | "rstrip": false, 13 | "normalized": false 14 | }, 15 | { 16 | "id": 1, 17 | "special": true, 18 | "content": "<|bos|>", 19 | "single_word": false, 20 | "lstrip": false, 21 | "rstrip": false, 22 | "normalized": false 23 | }, 24 | { 25 | "id": 2, 26 | "special": true, 27 | "content": "<|eos|>", 28 | "single_word": false, 29 | "lstrip": false, 30 | "rstrip": false, 31 | "normalized": false 32 | } 33 | ], 34 | "normalizer": null, 35 | "pre_tokenizer": { 36 | "type": "ByteLevel", 37 | "add_prefix_space": false, 38 | "trim_offsets": true 39 | }, 40 | "post_processor": { 41 | "type": "ByteLevel", 42 | "add_prefix_space": true, 43 | "trim_offsets": true 44 | }, 45 | "decoder": { 46 | "type": "ByteLevel", 47 | "add_prefix_space": true, 48 | "trim_offsets": true 49 | }, 50 | "model": { 51 | "type": "BPE", 52 | "dropout": null, 53 | "unk_token": null, 54 | "continuing_subword_prefix": null, 55 | "end_of_word_suffix": null, 56 | "fuse_unk": false, 57 | "vocab": { 58 | "<|pad|>": 0, 59 | "<|bos|>": 1, 60 | "<|eos|>": 2, 61 | "1": 3, 62 | "2": 4, 63 | "A": 5, 64 | "B": 6, 65 | "C": 7, 66 | "D": 8, 67 | "E": 9, 68 | "F": 10, 69 | "G": 11, 70 | "H": 12, 71 | "I": 13, 72 | "K": 14, 73 | "L": 15, 74 | "M": 16, 75 | "N": 17, 76 | "O": 18, 77 | "P": 19, 78 | "Q": 20, 79 | "R": 21, 80 | "S": 22, 81 | "T": 23, 82 | "U": 24, 83 | "V": 25, 84 | "W": 26, 85 | "X": 27, 86 | "Y": 28, 87 | "Z": 29 88 | }, 89 | "merges": [] 90 | } 91 | } --------------------------------------------------------------------------------