├── requirements.txt
├── .gitignore
├── GPT2
│   ├── config.py
│   ├── sample.py
│   ├── utils.py
│   ├── encoder.py
│   └── model.py
├── LICENSE
├── main.py
├── README.md
└── GPT2_Pytorch.ipynb
/requirements.txt: -------------------------------------------------------------------------------- 1 | regex==2017.4.5 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | venv 3 | models 4 | gpt2-pytorch_model.bin 5 | __pycache__ -------------------------------------------------------------------------------- /GPT2/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | class GPT2Config(object): 7 | def __init__( 8 | self, 9 | vocab_size_or_config_json_file=50257, 10 | n_positions=1024, 11 | n_ctx=1024, 12 | n_embd=768, 13 | n_layer=12, 14 | n_head=12, 15 | layer_norm_epsilon=1e-5, 16 | initializer_range=0.02, 17 | ): 18 | self.vocab_size = vocab_size_or_config_json_file 19 | self.n_ctx = n_ctx 20 | self.n_positions = n_positions 21 | self.n_embd = n_embd 22 | self.n_layer = n_layer 23 | self.n_head = n_head 24 | self.layer_norm_epsilon = layer_norm_epsilon 25 | self.initializer_range = initializer_range -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 OpenAI, HuggingFace Inc. team, and TaeHwan Jung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
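To see how `GPT2Config` above is consumed, here is a minimal sketch that mirrors what `main.py` (later in this dump) does: build the model from the config, map the downloaded HuggingFace checkpoint onto it with `load_weight` from `GPT2/utils.py`, and switch to eval mode. The checkpoint filename follows the README's download step; everything else is defined in this repository.

```python
import torch
from GPT2.config import GPT2Config
from GPT2.model import GPT2LMHeadModel
from GPT2.utils import load_weight

# Defaults in GPT2Config correspond to the released 117M-parameter GPT-2.
config = GPT2Config()
model = GPT2LMHeadModel(config)

# Checkpoint downloaded as in the README's Quick Start.
state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu')
model = load_weight(model, state_dict)  # also re-ties input/output embeddings via set_tied()
model.eval()
```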
-------------------------------------------------------------------------------- /GPT2/sample.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import trange 9 | 10 | def top_k_logits(logits, k): 11 | if k == 0: 12 | return logits 13 | values, _ = torch.topk(logits, k) 14 | min_values = values[:, -1].unsqueeze(1)  # keep shape (batch, 1) so the comparison below broadcasts for batch_size > 1 15 | return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) 16 | 17 | def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): 18 | if start_token is None: 19 | assert context is not None, 'Specify exactly one of start_token and context!' 20 | context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) 21 | else: 22 | assert context is None, 'Specify exactly one of start_token and context!' 23 | context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long) 24 | prev = context 25 | output = context 26 | past = None 27 | with torch.no_grad(): 28 | for _ in trange(length): 29 | logits, past = model(prev, past=past) 30 | logits = logits[:, -1, :] / temperature 31 | logits = top_k_logits(logits, k=top_k) 32 | probs = F.softmax(logits, dim=-1)  # probabilities, not log-probabilities 33 | if sample: 34 | prev = torch.multinomial(probs, num_samples=1) 35 | else: 36 | _, prev = torch.topk(probs, k=1, dim=-1) 37 | output = torch.cat((output, prev), dim=1) 38 | return output -------------------------------------------------------------------------------- /GPT2/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | def load_weight(model, state_dict): 11 | old_keys = [] 12 | new_keys = [] 13 | for key in state_dict.keys(): 14 | new_key = None 15 | if key.endswith(".g"): 16 | new_key = key[:-2] + ".weight" 17 | elif key.endswith(".b"): 18 | new_key = key[:-2] + ".bias" 19 | elif key.endswith(".w"): 20 | new_key = key[:-2] + ".weight" 21 | if new_key: 22 | old_keys.append(key) 23 | new_keys.append(new_key) 24 | for old_key, new_key in zip(old_keys, new_keys): 25 | state_dict[new_key] = state_dict.pop(old_key) 26 | 27 | missing_keys = [] 28 | unexpected_keys = [] 29 | error_msgs = [] 30 | # copy state_dict so _load_from_state_dict can modify it 31 | metadata = getattr(state_dict, "_metadata", None) 32 | state_dict = state_dict.copy() 33 | if metadata is not None: 34 | state_dict._metadata = metadata 35 | 36 | def load(module, prefix=""): 37 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 38 | module._load_from_state_dict( 39 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs 40 | ) 41 | for name, child in module._modules.items(): 42 | if child is not None: 43 | load(child, prefix + name + ".") 44 | 45 | start_model = model 46 | if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()): 47 | start_model = model.transformer 48 |
load(start_model, prefix="") 49 | 50 | # Make sure we are still sharing the output and input embeddings after loading weights 51 | model.set_tied() 52 | return model -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import os 7 | import sys 8 | import torch 9 | import random 10 | import argparse 11 | import numpy as np 12 | from GPT2.model import GPT2LMHeadModel 13 | from GPT2.utils import load_weight 14 | from GPT2.config import GPT2Config 15 | from GPT2.sample import sample_sequence 16 | from GPT2.encoder import get_encoder 17 | 18 | def text_generator(state_dict): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--text", type=str, required=True) 21 | parser.add_argument("--quiet", action='store_true')  # type=bool is a trap: argparse parses any non-empty string, even "False", as True 22 | parser.add_argument("--nsamples", type=int, default=1) 23 | parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.') 24 | parser.add_argument("--batch_size", type=int, default=-1) 25 | parser.add_argument("--length", type=int, default=-1) 26 | parser.add_argument("--temperature", type=float, default=0.7) 27 | parser.add_argument("--top_k", type=int, default=40) 28 | args = parser.parse_args() 29 | 30 | if not args.quiet: 31 | print(args) 32 | 33 | if args.batch_size == -1: 34 | args.batch_size = 1 35 | assert args.nsamples % args.batch_size == 0 36 | 37 | seed = random.randint(0, 2147483647) 38 | np.random.seed(seed) 39 | torch.random.manual_seed(seed) 40 | torch.cuda.manual_seed(seed) 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | 43 | # Load Model 44 | enc = get_encoder() 45 | config = GPT2Config() 46 | model = GPT2LMHeadModel(config) 47 | model = load_weight(model, state_dict) 48 | model.to(device) 49 | model.eval() 50 | 51 | if args.length == -1: 52 | args.length = config.n_ctx // 2 53 | elif args.length > config.n_ctx: 54 | raise ValueError("Can't get samples longer than window size: %s" % config.n_ctx) 55 | 56 | print(args.text) 57 | context_tokens = enc.encode(args.text) 58 | 59 | generated = 0 60 | for _ in range(args.nsamples // args.batch_size): 61 | out = sample_sequence( 62 | model=model, length=args.length, 63 | context=context_tokens if not args.unconditional else None, 64 | start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None, 65 | batch_size=args.batch_size, 66 | temperature=args.temperature, top_k=args.top_k, device=device 67 | ) 68 | out = out[:, (1 if args.unconditional else len(context_tokens)):].tolist()  # strip the prompt, or the lone start token in the unconditional case 69 | for i in range(args.batch_size): 70 | generated += 1 71 | text = enc.decode(out[i]) 72 | if not args.quiet: 73 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 74 | print(text) 75 | 76 | if __name__ == '__main__': 77 | if os.path.exists('gpt2-pytorch_model.bin'): 78 | state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu' if not torch.cuda.is_available() else None) 79 | text_generator(state_dict) 80 | else: 81 | print('Please download gpt2-pytorch_model.bin') 82 | sys.exit() 83 | -------------------------------------------------------------------------------- /GPT2/encoder.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 |
import os 4 | import json 5 | import regex as re 6 | from functools import lru_cache 7 | 8 | @lru_cache() 9 | def bytes_to_unicode(): 10 | """ 11 | Returns a dict mapping utf-8 bytes to unicode strings. 12 | The reversible bpe codes work on unicode strings. 13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 15 | This is a significant percentage of your normal, say, 32K bpe vocab. 16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 17 | It also avoids mapping to whitespace/control characters that the bpe code barfs on. 18 | """ 19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 20 | cs = bs[:] 21 | n = 0 22 | for b in range(2**8): 23 | if b not in bs: 24 | bs.append(b) 25 | cs.append(2**8+n) 26 | n += 1 27 | cs = [chr(n) for n in cs] 28 | return dict(zip(bs, cs)) 29 | 30 | def get_pairs(word): 31 | """Return set of symbol pairs in a word. 32 | Word is represented as tuple of symbols (symbols being variable-length strings). 33 | """ 34 | pairs = set() 35 | prev_char = word[0] 36 | for char in word[1:]: 37 | pairs.add((prev_char, char)) 38 | prev_char = char 39 | return pairs 40 | 41 | class Encoder: 42 | def __init__(self, encoder, bpe_merges, errors='replace'): 43 | self.encoder = encoder 44 | self.decoder = {v:k for k,v in self.encoder.items()} 45 | self.errors = errors # how to handle errors in decoding 46 | self.byte_encoder = bytes_to_unicode() 47 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 48 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 49 | self.cache = {} 50 | 51 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 52 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 53 | 54 | def bpe(self, token): 55 | if token in self.cache: 56 | return self.cache[token] 57 | word = tuple(token) 58 | pairs = get_pairs(word) 59 | 60 | if not pairs: 61 | return token 62 | 63 | while True: 64 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 65 | if bigram not in self.bpe_ranks: 66 | break 67 | first, second = bigram 68 | new_word = [] 69 | i = 0 70 | while i < len(word): 71 | try: 72 | j = word.index(first, i) 73 | new_word.extend(word[i:j]) 74 | i = j 75 | except ValueError:  # `first` does not occur in the rest of the word 76 | new_word.extend(word[i:]) 77 | break 78 | 79 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 80 | new_word.append(first+second) 81 | i += 2 82 | else: 83 | new_word.append(word[i]) 84 | i += 1 85 | new_word = tuple(new_word) 86 | word = new_word 87 | if len(word) == 1: 88 | break 89 | else: 90 | pairs = get_pairs(word) 91 | word = ' '.join(word) 92 | self.cache[token] = word 93 | return word 94 | 95 | def encode(self, text): 96 | bpe_tokens = [] 97 | for token in re.findall(self.pat, text): 98 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 99 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 100 | return bpe_tokens 101 | 102 | def decode(self, tokens): 103 | text = ''.join([self.decoder[token] for token in tokens]) 104 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 105 | return text 106 | 107 | def get_encoder(): 108 | with open('./GPT2/encoder.json', 'r') as f: 109 | encoder =
json.load(f) 110 | with open('./GPT2/vocab.bpe', 'r', encoding="utf-8") as f: 111 | bpe_data = f.read() 112 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 113 | return Encoder( 114 | encoder=encoder, 115 | bpe_merges=bpe_merges, 116 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## **GPT2-Pytorch with Text-Generator** 2 | 3 |

4 | 5 | **Better Language Models and Their Implications** 6 | 7 | > Our model, called GPT-2 (a successor to [GPT](https://blog.openai.com/language-unsupervised/)), was trained simply to predict the next word in 40GB of Internet text. Due to our concerns about malicious applications of the technology, we are not releasing the trained model. As an experiment in responsible disclosure, we are instead releasing a much [smaller model](https://github.com/openai/gpt-2) for researchers to experiment with, as well as a [technical paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). (from the [OpenAI Blog](https://blog.openai.com/better-language-models/)) 8 | 9 | This repository is a simple implementation of the GPT-2 **text generator** in **PyTorch**, with compressed code. 10 | 11 | - The original repository is [openai/gpt-2](https://github.com/openai/gpt-2). You can also read the GPT-2 paper, ["Language Models are Unsupervised Multitask Learners"](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf). To understand the concepts in more detail, I recommend the papers on the Transformer model listed below. 12 | - A good PyTorch implementation of GPT-2 that I referred to is [huggingface/pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT); see that repository for a more detailed implementation. 13 | 14 | - Transformer (self-attention) paper: [Attention Is All You Need (2017)](https://arxiv.org/abs/1706.03762) 15 | - First OpenAI GPT paper: [Improving Language Understanding by Generative Pre-Training (2018)](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) 16 | - See the [OpenAI Blog](https://blog.openai.com/better-language-models/) about GPT-2 and the paper 17 | 18 | 19 | 20 | ## Quick Start 21 | 22 | 1. Download the pre-trained GPT-2 PyTorch model that huggingface/pytorch-pretrained-BERT already converted. (Thanks for sharing! It solved my problem of converting the TensorFlow checkpoint (ckpt) file to a PyTorch model.) 23 | ```shell 24 | $ git clone https://github.com/graykode/gpt-2-Pytorch && cd gpt-2-Pytorch 25 | # download huggingface's pytorch model 26 | $ curl --output gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin 27 | # setup requirements; if using macOS, run the additional setup described below 28 | $ pip install -r requirements.txt 29 | ``` 30 | 31 | 32 | 2. Now you can run it like this. 33 | 34 | - Text from the book *1984*, by George Orwell 35 | 36 | ```shell 37 | $ python main.py --text "It was a bright cold day in April, and the clocks were striking thirteen. Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind, slipped quickly through the glass doors of Victory Mansions, though not quickly enough to prevent a swirl of gritty dust from entering along with him." 38 | ``` 39 | 40 | 3. You can also get started quickly in [Google Colab](https://colab.research.google.com/github/graykode/gpt-2-Pytorch/blob/master/GPT2_Pytorch.ipynb) 41 | 42 | 43 | 44 | ## Option 45 | 46 | - `--text` : the prompt sentence to begin generation with. 47 | - `--quiet` : suppress extraneous output such as the "================" sample separators. 48 | - `--nsamples` : number of samples to generate. 49 | - `--unconditional` : if set, generate unconditionally (starting from the `<|endoftext|>` token).
50 | - `--batch_size` : batch size used during generation. 51 | - `--length` : length of the generated sequence; must not exceed the context window `n_ctx` (defaults to `n_ctx // 2`). 52 | - `--temperature` : sampling temperature applied to the logits; lower values make the output more deterministic `(default: 0.7)` 53 | - `--top_k` : restrict sampling to the `k` most likely next tokens; `0` disables the filter `(default: 40)` 54 | 55 | See [here](https://github.com/openai/gpt-2#gpt-2-samples) for more detail on `temperature` and `top_k` 56 | 57 | 58 | 59 | ## Dependencies 60 | 61 | - PyTorch 0.4.1+ 62 | - regex 2017.4.5 63 | 64 | ### Mac OS Setup 65 | ```shell 66 | $ python3 -m venv venv 67 | $ source venv/bin/activate 68 | $ pip install torch tqdm 69 | $ brew install libomp 70 | $ export LC_ALL=en_US.UTF-8 71 | $ export LANG=en_US.UTF-8 72 | $ pip install -r requirements.txt 73 | ``` 74 | 75 | ## Author 76 | 77 | - Tae Hwan Jung(Jeff Jung) @graykode 78 | - Author Email : [nlkey2022@gmail.com](mailto:nlkey2022@gmail.com) 79 | 80 | 81 | 82 | ## License 83 | 84 | - OpenAI/GPT-2 is under the MIT license; huggingface/pytorch-pretrained-BERT is under the Apache license. 85 | - This repository follows the MIT license, as does the original GPT-2 repository. 86 | 87 | 88 | 89 | ## Acknowledgement 90 | 91 | [Jeff Wu(@WuTheFWasThat)](https://github.com/WuTheFWasThat), [Thomas Wolf(@thomwolf)](https://github.com/thomwolf) for allowing me to refer to their code. -------------------------------------------------------------------------------- /GPT2_Pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "GPT2-Pytorch.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "metadata": { 20 | "id": "4L7HpNaFHO5D", 21 | "colab_type": "code", 22 | "outputId": "a97fd9ea-0390-473c-f96a-2b0e32e70cca", 23 | "colab": { 24 | "base_uri": "https://localhost:8080/", 25 | "height": 373 26 | } 27 | }, 28 | "cell_type": "code", 29 | "source": [ 30 | "!git clone https://github.com/graykode/gpt-2-Pytorch\n", 31 | "%cd gpt-2-Pytorch\n", 32 | "!curl --output gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin\n", 33 | "!pip install -r requirements.txt" 34 | ], 35 | "execution_count": 1, 36 | "outputs": [ 37 | { 38 | "output_type": "stream", 39 | "text": [ 40 | "Cloning into 'gpt-2-Pytorch'...\n", 41 | "remote: Enumerating objects: 51, done.\u001b[K\n", 42 | "remote: Counting objects: 100% (51/51), done.\u001b[K\n", 43 | "remote: Compressing objects: 100% (40/40), done.\u001b[K\n", 44 | "remote: Total 51 (delta 15), reused 44 (delta 9), pack-reused 0\u001b[K\n", 45 | "Unpacking objects: 100% (51/51), done.\n", 46 | "/content/gpt-2-Pytorch\n", 47 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 48 | " Dload Upload Total Spent Left Speed\n", 49 | "100 522M 100 522M 0 0 77.7M 0 0:00:06 0:00:06 --:--:-- 81.3M\n", 50 | "Collecting fire>=0.1.3 (from -r requirements.txt (line 1))\n", 51 | " Downloading https://files.pythonhosted.org/packages/5a/b7/205702f348aab198baecd1d8344a90748cb68f53bdcd1cc30cbc08e47d3e/fire-0.1.3.tar.gz\n", 52 | "Collecting regex==2017.4.5 (from -r requirements.txt (line 2))\n", 53 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/36/62/c0c0d762ffd4ffaf39f372eb8561b8d491a11ace5a7884610424a8b40f95/regex-2017.04.05.tar.gz (601kB)\n", 54 | "\u001b[K 100% 
|████████████████████████████████| 604kB 26.2MB/s \n", 55 | "\u001b[?25hCollecting gpt2-pytorch_model.bin (from -r requirements.txt (line 3))\n", 56 | "\u001b[31m Could not find a version that satisfies the requirement gpt2-pytorch_model.bin (from -r requirements.txt (line 3)) (from versions: )\u001b[0m\n", 57 | "\u001b[31mNo matching distribution found for gpt2-pytorch_model.bin (from -r requirements.txt (line 3))\u001b[0m\n" 58 | ], 59 | "name": "stdout" 60 | } 61 | ] 62 | }, 63 | { 64 | "metadata": { 65 | "id": "he_YiEC9T6-D", 66 | "colab_type": "code", 67 | "colab": { 68 | "base_uri": "https://localhost:8080/", 69 | "height": 205 70 | }, 71 | "outputId": "58e3f04e-905a-4adf-fd0c-98c682abe9d2" 72 | }, 73 | "cell_type": "code", 74 | "source": [ 75 | "!python main.py --text \"Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest.\"" 76 | ], 77 | "execution_count": 2, 78 | "outputs": [ 79 | { 80 | "output_type": "stream", 81 | "text": [ 82 | "Namespace(batch_size=-1, length=-1, nsamples=1, seed=0, temperature=1, text='Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest.', top_k=0, unconditional=False)\n", 83 | "Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest.\n", 84 | "[7454, 618, 314, 373, 2237, 812, 1468, 314, 2497, 257, 25023, 4286, 287, 257, 1492, 11, 1444, 6407, 18152, 422, 10362, 11, 546, 262, 6994, 2100, 8222, 13]\n", 85 | "100% 512/512 [00:12<00:00, 40.58it/s]\n", 86 | "======================================== SAMPLE 1 ========================================\n", 87 | " I didn't know about it until a few months later when comparing the figure of ichthyosis emblazoned on a top of mushroom scales with misinformation spread by reputable population labs that discovered breeders of various breeds and classes of plants that are more closely related to these plants than others, at least according to their interpretation of what was being observed, very closely adapted to replace the need for the invasive species, hence the perception that only some had even arrived where the original frontier dogs would not. A few good photos were taken of lumps hooked around short arelets, showing far more frequently than others. The large puff-thraps (chloroxy pickshares) that have grown up very near Mare Crosloquifroides: a species using the word to export its products frepeace, far more than it had acquired at Mare Crosloquifroides. The coffee fennel (pelleaphrotechus) has got several, distinct markings in particular. The pinks were some from bloom and others from tame stages. Again, almost, all the people of these flocks experienced this species on an uneventful day.\n", 88 | "\n", 89 | "It strays to the story of Brammas Nostalgia made popular by that particular years-old-islet. 
It told a incurable and unfathomable story about strong women abused without violence or the death call of the grand poet who so loved an East Asian wife in spite of violent boyfriends who'd filiated her, destroyed wife, and threatened her on the street (and still \"Rammas Sadasta\" period) with a triangle wrapped around her finger, which the common thief would wrap around an animal's hand, regardless of true love or it to shaving half of her bitten pink cat's back, and would wrap herself in a stuffed dog-sweat sack that Kimi visited when she went to America with a crush on her black-muscled shoes or on her Manchester level Stade de Pau (Date Of Birth). The Lions Superinformations Black Pockets grew on pilgrimage to Calais somewhere to see help those soldiers and stay quiet about raging racism in our cities really makper up such stories as did Brammas Nostalgia. That all comes off as too romantic to be true in Molly Szierabich's novel Rhizopus. Remember the liberal megalopolis in BBC World-Herald Affair where the infamous Milo Yiannopoulos extravaganza featured such scenes? That do it now as a museum in Museum de Citat? It was set to\n", 90 | "================================================================================\n" 91 | ], 92 | "name": "stdout" 93 | } 94 | ] 95 | } 96 | ] 97 | } -------------------------------------------------------------------------------- /GPT2/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import copy 7 | import torch 8 | import math 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | 12 | def gelu(x): 13 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 14 | 15 | class LayerNorm(nn.Module): 16 | def __init__(self, hidden_size, eps=1e-12): 17 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
18 | """ 19 | super(LayerNorm, self).__init__() 20 | self.weight = nn.Parameter(torch.ones(hidden_size)) 21 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 22 | self.variance_epsilon = eps 23 | 24 | def forward(self, x): 25 | u = x.mean(-1, keepdim=True) 26 | s = (x - u).pow(2).mean(-1, keepdim=True) 27 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 28 | return self.weight * x + self.bias 29 | 30 | class Conv1D(nn.Module): 31 | def __init__(self, nf, nx): 32 | super(Conv1D, self).__init__() 33 | self.nf = nf 34 | w = torch.empty(nx, nf) 35 | nn.init.normal_(w, std=0.02) 36 | self.weight = Parameter(w) 37 | self.bias = Parameter(torch.zeros(nf)) 38 | 39 | def forward(self, x): 40 | size_out = x.size()[:-1] + (self.nf,) 41 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 42 | x = x.view(*size_out) 43 | return x 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, nx, n_ctx, config, scale=False): 47 | super(Attention, self).__init__() 48 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 49 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 50 | assert n_state % config.n_head == 0 51 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 52 | self.n_head = config.n_head 53 | self.split_size = n_state 54 | self.scale = scale 55 | self.c_attn = Conv1D(n_state * 3, nx) 56 | self.c_proj = Conv1D(n_state, nx) 57 | 58 | def _attn(self, q, k, v): 59 | w = torch.matmul(q, k) 60 | if self.scale: 61 | w = w / math.sqrt(v.size(-1)) 62 | nd, ns = w.size(-2), w.size(-1) 63 | b = self.bias[:, :, ns-nd:ns, :ns] 64 | w = w * b - 1e10 * (1 - b) 65 | w = nn.Softmax(dim=-1)(w) 66 | return torch.matmul(w, v) 67 | 68 | def merge_heads(self, x): 69 | x = x.permute(0, 2, 1, 3).contiguous() 70 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 71 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 72 | 73 | def split_heads(self, x, k=False): 74 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 75 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 76 | if k: 77 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 78 | else: 79 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 80 | 81 | def forward(self, x, layer_past=None): 82 | x = self.c_attn(x) 83 | query, key, value = x.split(self.split_size, dim=2) 84 | query = self.split_heads(query) 85 | key = self.split_heads(key, k=True) 86 | value = self.split_heads(value) 87 | if layer_past is not None: 88 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 89 | key = torch.cat((past_key, key), dim=-1) 90 | value = torch.cat((past_value, value), dim=-2) 91 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 92 | a = self._attn(query, key, value) 93 | a = self.merge_heads(a) 94 | a = self.c_proj(a) 95 | return a, present 96 | 97 | class MLP(nn.Module): 98 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 99 | super(MLP, self).__init__() 100 | nx = config.n_embd 101 | self.c_fc = Conv1D(n_state, nx) 102 | self.c_proj = Conv1D(nx, n_state) 103 | self.act = gelu 104 | 105 | def forward(self, x): 106 | h = self.act(self.c_fc(x)) 107 | h2 = self.c_proj(h) 108 | return h2 109 | 110 | class Block(nn.Module): 111 | def __init__(self, n_ctx, config, scale=False): 112 | super(Block, self).__init__() 113 | nx = config.n_embd 114 | self.ln_1 = 
LayerNorm(nx, eps=config.layer_norm_epsilon) 115 | self.attn = Attention(nx, n_ctx, config, scale) 116 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 117 | self.mlp = MLP(4 * nx, config) 118 | 119 | def forward(self, x, layer_past=None): 120 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 121 | x = x + a 122 | m = self.mlp(self.ln_2(x)) 123 | x = x + m 124 | return x, present 125 | 126 | class GPT2Model(nn.Module): 127 | def __init__(self, config): 128 | super(GPT2Model, self).__init__() 129 | self.n_layer = config.n_layer 130 | self.n_embd = config.n_embd 131 | self.n_vocab = config.vocab_size 132 | 133 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 134 | self.wpe = nn.Embedding(config.n_positions, config.n_embd) 135 | block = Block(config.n_ctx, config, scale=True) 136 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 137 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 138 | 139 | def set_embeddings_weights(self, model_embeddings_weights): 140 | embed_shape = model_embeddings_weights.shape 141 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 142 | self.decoder.weight = model_embeddings_weights # Tied weights 143 | 144 | def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None): 145 | if past is None: 146 | past_length = 0 147 | past = [None] * len(self.h) 148 | else: 149 | past_length = past[0][0].size(-2) 150 | if position_ids is None: 151 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, 152 | device=input_ids.device) 153 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 154 | 155 | input_shape = input_ids.size() 156 | input_ids = input_ids.view(-1, input_ids.size(-1)) 157 | position_ids = position_ids.view(-1, position_ids.size(-1)) 158 | 159 | inputs_embeds = self.wte(input_ids) 160 | position_embeds = self.wpe(position_ids) 161 | if token_type_ids is not None: 162 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 163 | token_type_embeds = self.wte(token_type_ids) 164 | else: 165 | token_type_embeds = 0 166 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 167 | presents = [] 168 | for block, layer_past in zip(self.h, past): 169 | hidden_states, present = block(hidden_states, layer_past) 170 | presents.append(present) 171 | hidden_states = self.ln_f(hidden_states) 172 | output_shape = input_shape + (hidden_states.size(-1),) 173 | return hidden_states.view(*output_shape), presents 174 | 175 | class GPT2LMHead(nn.Module): 176 | def __init__(self, model_embeddings_weights, config): 177 | super(GPT2LMHead, self).__init__() 178 | self.n_embd = config.n_embd 179 | self.set_embeddings_weights(model_embeddings_weights) 180 | 181 | def set_embeddings_weights(self, model_embeddings_weights): 182 | embed_shape = model_embeddings_weights.shape 183 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 184 | self.decoder.weight = model_embeddings_weights # Tied weights 185 | 186 | def forward(self, hidden_state): 187 | # Truncated Language modeling logits (we remove the last token) 188 | # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) 189 | lm_logits = self.decoder(hidden_state) 190 | return lm_logits 191 | 192 | class GPT2LMHeadModel(nn.Module): 193 | def __init__(self, config): 194 | super(GPT2LMHeadModel, self).__init__() 195 | self.transformer = GPT2Model(config) 196 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) 197 | 198 
| def set_tied(self): 199 | """ Make sure we are sharing the embeddings 200 | """ 201 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 202 | 203 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): 204 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 205 | lm_logits = self.lm_head(hidden_states) 206 | if lm_labels is not None: 207 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 208 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)) 209 | return loss 210 | return lm_logits, presents --------------------------------------------------------------------------------
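To close, here is a short illustration of the `past`/`presents` key-value caching contract that `GPT2Model.forward` exposes and `sample_sequence` consumes. This is a minimal sketch: the model is randomly initialized, and the token ids are arbitrary placeholders rather than a real prompt.

```python
import torch
from GPT2.config import GPT2Config
from GPT2.model import GPT2LMHeadModel

config = GPT2Config()
model = GPT2LMHeadModel(config).eval()  # untrained weights; structure only

prompt = torch.tensor([[464, 3290]])  # placeholder token ids, shape (batch=1, seq=2)
with torch.no_grad():
    # First call: feed the whole prompt; `past` caches one stacked
    # (key, value) tensor per layer.
    logits, past = model(prompt)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    # Subsequent calls: feed only the newest token plus the cache, so each
    # step attends from a single query instead of re-encoding the prefix.
    logits, past = model(next_token, past=past)
```

This caching is why `sample_sequence` in `GPT2/sample.py` feeds back only `prev` (the most recent token) on each iteration instead of the growing `output` tensor.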