├── docs ├── CNAME ├── loggers.md ├── gpt-2-simple.md ├── upload.md ├── ethics.md ├── tutorials │ ├── colab.md │ ├── hello-world.md │ ├── model-from-scratch.md │ └── generate_1_5b.md ├── generate-performance.md ├── cli.md ├── save-model.md ├── index.md ├── helpful-notes.md ├── load-model.md ├── generate.md └── dataset.md ├── MANIFEST.in ├── aitextgen ├── __init__.py ├── cli.py ├── colab.py ├── tokenizers.py ├── utils.py ├── train.py ├── TokenDataset.py └── aitextgen.py ├── requirements.txt ├── .gitignore ├── Dockerfile ├── ROADMAP.md ├── .github └── FUNDING.yml ├── setup.py ├── LICENSE ├── UPCOMING.md ├── mkdocs.yml ├── CHANGELOG.md ├── DESIGN.md ├── README.md └── notebooks ├── reddit_demo.ipynb ├── generation_hello_world.ipynb ├── hacker_news_demo.ipynb └── training_hello_world.ipynb /docs/CNAME: -------------------------------------------------------------------------------- 1 | docs.aitextgen.io -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft */static 2 | global-exclude .DS_Store 3 | -------------------------------------------------------------------------------- /aitextgen/__init__.py: -------------------------------------------------------------------------------- 1 | from .aitextgen import aitextgen # noqa 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.5.1 2 | fire>=0.3.0 3 | pytorch-lightning>=1.8.0 4 | torch>=1.6.0 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | test_notebooks/ 3 | /build 4 | /dist 5 | *.egg-info 6 | .vscode/settings.json 7 | /site 8 | .DS_Store -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8.6-slim 2 | 3 | RUN apt-get -y update && apt-get -y install gcc 4 | 5 | WORKDIR / 6 | 7 | COPY requirements.txt . 8 | 9 | # install dependencies 10 | RUN pip --no-cache-dir install -r requirements.txt 11 | COPY * / 12 | 13 | # Clean up APT when done. 14 | RUN apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -------------------------------------------------------------------------------- /docs/loggers.md: -------------------------------------------------------------------------------- 1 | # Loggers 2 | 3 | You can create loggers with popular tools such as [TensorBoard](https://www.tensorflow.org/tensorboard) and [Weights and Biases](https://www.wandb.com) by leveraging pytorch-lightning's logger functionality. 4 | 5 | [See their documentation](https://pytorch-lightning.readthedocs.io/en/stable/loggers.html) on all the available options for loggers. 6 | 7 | For example, if you want to create a TensorBoard logger, you can create it: 8 | 9 | ```py3 10 | from pytorch_lightning import loggers 11 | 12 | tb_logger = loggers.TensorBoardLogger('logs/') 13 | ``` 14 | 15 | Then pass it to the `loggers` parameter for `ai.train()`. 
16 | 17 | ```py3 18 | ai.train(train_data=data, loggers=tb_logger) 19 | ``` 20 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Roadmap 2 | 3 | A rough roadmap for implementing new features. **All is subject to change at a moment's notice.** 4 | 5 | ## Launch 6 | 7 | - Training using pytorch-lightning, with suppport for fp16 and Colab TPUs. 8 | - Training a GPT-2 model from scratch w/ parametricized context window sizes and parameters 9 | - PyTorch support for training/generating 10 | - Generation from Transformer's native generate() function 11 | - Actual documentation 12 | - Examples 13 | - Training on a CPU 14 | - Training on a GPU 15 | - Training on multiple GPU (4x T4) 16 | - Training on a TPU 17 | - Cross-Training on Multiple Datasets 18 | - Generate on a CPU 19 | - Generate on a GPU 20 | - API docs for all classes 21 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: minimaxir # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: minimaxir # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /docs/gpt-2-simple.md: -------------------------------------------------------------------------------- 1 | # Importing from gpt-2-simple 2 | 3 | Want to import a model trained using [gpt-2-simple](https://github.com/minimaxir/gpt-2-simple), or another GPT-2 based finetuning approach? You can do that [using the transformers-cli](https://huggingface.co/transformers/converting_tensorflow_models.html). 4 | 5 | In the case of gpt-2-simple (where the output is structured `checkpoint/run1`), you'd `cd` into the directory containing the `checkpoint` folder and run: 6 | 7 | ```sh 8 | transformers-cli convert --model_type gpt2 --tf_checkpoint checkpoint/run1 --pytorch_dump_output pytorch --config checkpoint/run1/hparams.json 9 | ``` 10 | 11 | This will put a `pytorch_model.bin` and `config.json` in the `pytorch` folder, which is what you'll need to pass to `aitextgen()` to load the model. 12 | 13 | That's it! 14 | -------------------------------------------------------------------------------- /docs/upload.md: -------------------------------------------------------------------------------- 1 | # Upload Model to Huggingface 2 | 3 | You can [upload your trained models](https://huggingface.co/transformers/model_sharing.html) to Huggingface, where it can be downloaded by others! 
4 | 5 | To upload your model, you'll have to create a folder which has **6** files: 6 | 7 | - pytorch_model.bin 8 | - config.json 9 | - vocab.json 10 | - merges.txt 11 | - special_tokens_map.json 12 | - tokenizer_config.json 13 | 14 | You can generate all of these files at the same time into a given folder by running `ai.save_for_upload(model_name)`. 15 | 16 | Then, follow the `transformers-cli` instructions to upload the model. 17 | 18 | ```sh 19 | transformers-cli login 20 | ``` 21 | 22 | ```sh 23 | transformers-cli upload model_name 24 | ``` 25 | 26 | You (or another user) can download cache, and generate from that model via: 27 | 28 | ``` 29 | ai = aitextgen(model="username/model_name") 30 | ``` 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="aitextgen", 5 | packages=["aitextgen"], # this must be the same as the name above 6 | version="0.6.0", 7 | description="A robust Python tool for text-based AI training and generation using GPT-2.", 8 | long_description=open("README.md", "r", encoding="utf-8").read(), 9 | long_description_content_type="text/markdown", 10 | author="Max Woolf", 11 | author_email="max@minimaxir.com", 12 | url="https://github.com/minimaxir/aitextgen", 13 | keywords=["gpt-2", "gpt2", "text generation", "ai"], 14 | classifiers=[], 15 | license="MIT", 16 | entry_points={"console_scripts": ["aitextgen=aitextgen.cli:aitextgen_cli"]}, 17 | python_requires=">=3.6", 18 | include_package_data=True, 19 | install_requires=[ 20 | "transformers>=4.5.1", 21 | "fire>=0.3.0", 22 | "pytorch-lightning>=1.7.0", 23 | "torch>=1.6.0", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019-2020 Max Woolf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /docs/ethics.md: -------------------------------------------------------------------------------- 1 | # Ethics 2 | 3 | aitextgen is a tool primarily intended to help facilitate creative content. It is not a tool intended to deceive. 
Although parody accounts are an obvious use case for this package, make sure you are _as upfront as possible_ with the methodology of the text you create. This includes: 4 | 5 | - State that the text was generated using aitextgen and/or a GPT-2 model architecture. (a link to this repo would be a bonus!) 6 | - If parodying a person, explicitly state that it is a parody, and reference who it is parodying. 7 | - State whether the generated text is human-curated, or if it's unsupervised random output. 8 | - Indicate who is maintaining/curating the AI-generated text. 9 | - Make a good-faith effort to remove overfit output from the generated text that matches the input text verbatim. 10 | 11 | It's fun to anthropomorphise the nameless "AI" as an abstract genius, but part of the reason I made aitextgen (and all my previous text-generation projects) is to make the technology more accessible and accurately demonstrate both its promise and its limitations. **Any AI text generation projects that are deliberately deceptive may be disavowed.** -------------------------------------------------------------------------------- /docs/tutorials/colab.md: -------------------------------------------------------------------------------- 1 | # Colaboratory Notebooks 2 | 3 | You cannot finetune OpenAI's GPT-2 models on CPU (and not even on some consumer GPUs). Therefore, there are a couple of Google Colaboratory notebooks, which provide a GPU suitable for finetuning a model. 4 | 5 | The Colab Notebooks also contain utilities to make it easier to export the model to Google Drive during and after training. 6 | 7 | ## Finetuning OpenAI's Model 8 | 9 | [Colab Notebook](https://colab.research.google.com/drive/15qBZx5y9rdaQSyWpsreMDnTiZ5IlN0zD?usp=sharing) 10 | 11 | A Notebook for finetuning OpenAI's model on a GPU. This is the most common use case. 12 | 13 | 14 | !!! note 15 | Currently you can only finetune the 124M/355M/774M OpenAI GPT-2 models, with the latter two forcing `gradient_checkpointing=True` to ensure they do not cause the Colab GPU to go OOM. 16 | 17 | ## Training Your Own GPT-2 Model 18 | 19 | [Colab Notebook](https://colab.research.google.com/drive/144MdX5aLqrQ3-YW-po81CQMrD6kpgpYh?usp=sharing) 20 | 21 | A Notebook for creating your own GPT-2 model with your own tokenizer. See the Model From Scratch section on the advantages and disadvantages of this approach. 22 | -------------------------------------------------------------------------------- /docs/generate-performance.md: -------------------------------------------------------------------------------- 1 | # Improving Generation Performance 2 | 3 | A few tips and tricks for improving generation performance on both CPUs and GPUs. (Note that with these tricks, you cannot train the model afterwards!) 4 | 5 | ## CPU 6 | 7 | ### Quantization 8 | 9 | PyTorch has the ability to quantize models on the CPU. Currently, it will only quantize the Linear layer of GPT-2, but generation performance increases by **15%–25%**; far from trivial!
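Under the hood, this uses PyTorch's dynamic quantization. A rough sketch of the equivalent low-level call (purely illustrative, and assuming the wrapped Transformers model is reachable as `ai.model`, which is an assumption rather than something documented here):

```py3
import torch

# Illustrative only: dynamically quantize the Linear layers to int8 weights,
# which is what the built-in helper is described as doing above.
quantized_model = torch.quantization.quantize_dynamic(
    ai.model,  # assumes the underlying Transformers model is exposed as `ai.model`
    {torch.nn.Linear},
    dtype=torch.qint8,
)
```

In practice, you should simply use the built-in helper.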
10 | 11 | To quantize a model after it's loaded, just run: 12 | 13 | ```py3 14 | ai.quantize() 15 | ``` 16 | 17 | ## GPU 18 | 19 | ### FP16 20 | 21 | Certain GPUs, notably the cheap T4 and the expensive V100, support the ability to process models using FP16, giving massive speed and memory improvements, 22 | 23 | Assuming you are using a compatable GPU and already have [apex](https://github.com/NVIDIA/apex) installed, you can convert a model to the "half" FP16 mode with this: 24 | 25 | ```py3 26 | ai.to_fp16() 27 | ``` 28 | 29 | If you want to convert the model _before_ loading it into GPU memory (which may help avoid memory leaks), you can instantiate the model like this: 30 | 31 | ```py3 32 | ai.to_fp16(to_gpu=True, to_fp16=True) 33 | ``` 34 | 35 | With this, you can generate massive amounts of text from even the GPT-2 1.5B model! 36 | -------------------------------------------------------------------------------- /aitextgen/cli.py: -------------------------------------------------------------------------------- 1 | from .aitextgen import aitextgen 2 | from .TokenDataset import TokenDataset 3 | from .tokenizers import train_tokenizer 4 | import fire 5 | 6 | 7 | def aitextgen_cli(**kwargs): 8 | """Entrypoint for the CLI""" 9 | fire.Fire( 10 | { 11 | "encode": encode_cli, 12 | "train": train_cli, 13 | "generate": generate_cli, 14 | "train_tokenizer": train_tokenizer_cli, 15 | } 16 | ) 17 | 18 | 19 | def encode_cli(file_path: str, **kwargs): 20 | """Encode + compress a dataset""" 21 | TokenDataset(file_path, save_cache=True, **kwargs) 22 | 23 | 24 | def train_cli(file_path: str, **kwargs): 25 | """Train on a dataset.""" 26 | ai = aitextgen(**kwargs) 27 | 28 | from_cache = file_path.endswith(".tar.gz") 29 | dataset = TokenDataset(file_path, from_cache=from_cache, **kwargs) 30 | 31 | ai.train(dataset, **kwargs) 32 | 33 | 34 | def generate_cli(to_file: bool = True, **kwargs): 35 | """Generate from a trained model, or download one if not present.""" 36 | 37 | ai = aitextgen(**kwargs) 38 | if to_file: 39 | ai.generate_to_file(**kwargs) 40 | else: 41 | ai.generate(**kwargs) 42 | 43 | 44 | def train_tokenizer_cli(files: str, **kwargs): 45 | """Trains a tokenizer on the specified file.""" 46 | train_tokenizer(files, **kwargs) 47 | -------------------------------------------------------------------------------- /docs/cli.md: -------------------------------------------------------------------------------- 1 | # Command-Line Interface 2 | 3 | aitextgen has a command-line interface to quickly automate common tasks, and make it less necessary to use a script; helpful if running on a remote server. 4 | 5 | ## Encode 6 | 7 | Encodes given text `text.txt` into a cache and compressed `TokenDataset`, good for prepping a dataset for transit to a remote server. 8 | 9 | ```sh 10 | aitextgen encode text.txt 11 | ``` 12 | 13 | If you are encoding a CSV, you should pass in the `line_by_line` parameter as well. 14 | 15 | ```sh 16 | aitextgen encode reddit.csv --line_by_line True 17 | ``` 18 | 19 | ## Train 20 | 21 | To train/finetune on the default 124M GPT-2, given text `text.txt` and all default parameters: 22 | 23 | ```sh 24 | aitextgen train text.txt 25 | ``` 26 | 27 | If you are using a cached/compressed dataset that ends with `tar.gz` (e.g one created by the Encoding CLI command above), you can pass that to this function as well. 28 | 29 | ```sh 30 | aitextgen train dataset_cache.tar.gz 31 | ``` 32 | 33 | Other parameters to the TokenDataset constructor can be used. 
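For example, to finetune on a single-column CSV while reading it row by row, reuse the `line_by_line` flag shown in the Encode section above:

```sh
aitextgen train reddit.csv --line_by_line True
```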
34 | 35 | ## Generate 36 | 37 | Loads a model and generates to a file. 38 | 39 | By default, it will generate 20 texts to the file, 1 at a time at temperature of 0.7. 40 | 41 | ```sh 42 | aitextgen generate 43 | ``` 44 | 45 | You can print to console instead by passing `--to_file False` 46 | 47 | ```sh 48 | aitextgen generate --prompt "I believe in unicorns because" --to_file False 49 | ``` 50 | -------------------------------------------------------------------------------- /aitextgen/colab.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import shutil 3 | import os 4 | 5 | try: 6 | from google.colab import drive 7 | except ImportError: 8 | pass 9 | 10 | 11 | def mount_gdrive(): 12 | """Mounts the user's Google Drive in Colaboratory.""" 13 | assert ( 14 | "google.colab" in sys.modules 15 | ), "You must be in Colaboratory to mount your Google Drive" 16 | 17 | drive.mount("/content/drive") 18 | 19 | 20 | def is_mounted(): 21 | """Checks if the Google Drive is mounted.""" 22 | assert os.path.isdir("/content/drive"), "You must mount first using mount_gdrive()" 23 | 24 | 25 | def copy_file_to_gdrive(file_path, to_folder=None): 26 | """Copies a file to a mounted Google Drive.""" 27 | is_mounted() 28 | 29 | if to_folder: 30 | dest_path = os.path.join("/content/drive/MyDrive/", to_folder, file_path) 31 | else: 32 | dest_path = os.path.join("/content/drive/MyDrive/", file_path) 33 | 34 | shutil.copyfile(file_path, dest_path) 35 | 36 | 37 | def copy_file_from_gdrive(file_path, from_folder=None): 38 | """Copies a file from a mounted Google Drive.""" 39 | is_mounted() 40 | 41 | if from_folder: 42 | source_path = os.path.join("/content/drive/MyDrive/", from_folder, file_path) 43 | else: 44 | source_path = os.path.join("/content/drive/MyDrive/", file_path) 45 | 46 | shutil.copyfile(source_path, file_path) 47 | 48 | 49 | def create_gdrive_folder(folder_name): 50 | """Creates a folder in a mounted Google Drive.""" 51 | is_mounted() 52 | 53 | folder_path = os.path.join("/content/drive/MyDrive/", folder_name) 54 | 55 | if not os.path.exists(folder_path): 56 | os.makedirs(folder_path) 57 | -------------------------------------------------------------------------------- /UPCOMING.md: -------------------------------------------------------------------------------- 1 | # Upcoming Features 2 | 3 | Here is a list of features, in no particular order, of what I _hope_ to add to aitextgen, given that it's feasible. 4 | 5 | ## Training Features 6 | 7 | - Sparse Transformers? 8 | - Include keyword conditioning as well 9 | - - For vocab prefix, use heuristics on tokens, since very speedy and can scale high with strong tokenizers 10 | - Context support, provided as a named dict to a generation function 11 | e sequence stronger than later text (e.g. recent tweets than older tweets!) 12 | 13 | ## Generation Features 14 | 15 | - Allow a user-curated mode 16 | - Allow returning probabilities of next token 17 | - - Console BG coloring of token prediction confidence? 18 | - Allow excluding of tokens by index (e.g. allow model generation without the letter e) 19 | - For postfix prediction, allow returning best guess or probabilities of all classes. 20 | - Calculate text Dimensionality by using activation weights of all tokens prior to EOS token. 21 | - Use dimensionality to calculate a similarity score from generated texts to real texts; scores below a threshold may be considered incoherent and can be discarded. 
22 | - Dedupe existing texts when generating using token id matching. 23 | - Allow cycling context tokens when generating 24 | - Unlimited text generation via sliding context window (need to include a warning if doing so) 25 | 26 | ## Deployment Features 27 | 28 | - Use [ray](https://github.com/ray-project/ray) async actors for async generation. May need a custom async generation function. 29 | - Use websockets in [starlette](https://www.starlette.io) so output can be returned token by token. 30 | - Export function: PyTorch trace, TensorFlow Serving, TensorFlow.js, CoreML 31 | 32 | ## Quality-of-Life Features 33 | 34 | - Have a super minimal version that can be distributed with the PyPi package. (< 3MB) 35 | - Include a small model trainable on a CPU, trained from a distilled larger model 36 | - Support any CLM model, not just GPT2. 37 | - Support CPU-based XLA for faster PyTorch CPU Train/Predict performance 38 | -------------------------------------------------------------------------------- /docs/save-model.md: -------------------------------------------------------------------------------- 1 | # Model Saving 2 | 3 | There are are multiple ways to save models. 4 | 5 | Whenever a model is saved, two files are generated: `pytorch_model.bin` which contains the model weights, and `config.json` which is needed to load the model. 6 | 7 | Assuming we have an aitextgen model `ai`: 8 | 9 | ## Ad Hoc saving 10 | 11 | The aitextgen model can be saved at any time using `save`. 12 | 13 | ```py3 14 | ai.save() 15 | ``` 16 | 17 | ## Save to Google Drive 18 | 19 | If you are using Google Colaboratory, you can mount your personal Google Drive to the notebook and save your models there. 20 | 21 | 22 | !!! note "Downloading models from Colab Notebooks" 23 | It's strongly recommended to move models to Google Drive before downloading them from Colaboratory. 24 | 25 | First mount your Google Drive using `mount_gdrive()`: 26 | 27 | ```py3 28 | from aitextgen.colab import mount_gdrive, copy_file_to_gdrive 29 | mount_gdrive() 30 | ``` 31 | 32 | You'll be asked for an auth code; input it and press enter, and a `MyDrive` folder will appear in Colab Files view. 33 | 34 | You can drag and drop the model files into the Google Drive, or use `copy_file_to_gdrive` to copy them programmatically. 35 | 36 | ```py3 37 | copy_file_to_gdrive("pytorch_model.bin") 38 | copy_file_to_gdrive("config.json") 39 | ``` 40 | 41 | ## Saving During Training 42 | 43 | By default, the `train()` function has `save_every = 1000`, which means the model will save every 1000 steps to the specified `output_dir` (`trained_model` by default). You can adjust as necessary. 44 | 45 | ## Saving During Training in Google Colab 46 | 47 | Concerned about timeouts in Google Colab? aitextgen has a feature that will copy models to your Google Drive periodically in case the instance gets killed! 48 | 49 | As long as your drive is mounted as above, pass `save_gdrive = True` to the `train()` function: 50 | 51 | ```py3 52 | ai.train(save_gdrive=True) 53 | ``` 54 | 55 | This will save the model to the folder corresponding to the training `run_id` parameter (the datetime training was called, to prevent accidently overwriting). 
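If the Colab instance is later recycled, the companion helper `copy_file_from_gdrive` can pull those saved files back out of Google Drive into a fresh instance. A sketch (the folder name below is hypothetical; point it at whatever folder your training run saved into):

```py3
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive

mount_gdrive()

# "my_run_folder" is a hypothetical example folder name
copy_file_from_gdrive("pytorch_model.bin", from_folder="my_run_folder")
copy_file_from_gdrive("config.json", from_folder="my_run_folder")
```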
56 | -------------------------------------------------------------------------------- /aitextgen/tokenizers.py: -------------------------------------------------------------------------------- 1 | from tokenizers import ByteLevelBPETokenizer 2 | from typing import Union, List 3 | 4 | 5 | def train_tokenizer( 6 | files: Union[str, List[str]], 7 | dropout: float = None, 8 | vocab_size: int = 1000, 9 | min_frequency: int = 2, 10 | prefix: str = "aitextgen", 11 | save_path: str = "", 12 | added_tokens: List[str] = [], 13 | bos_token: str = "<|endoftext|>", 14 | eos_token: str = "<|endoftext|>", 15 | unk_token: str = "<|endoftext|>", 16 | serialize: bool = True, 17 | trim_offsets: bool = True, 18 | ) -> None: 19 | """ 20 | Tokenizes the text(s) as a tokenizer, wrapping the tokenizer package. 21 | See: https://huggingface.co/blog/how-to-train 22 | 23 | For consistency, this function makes opinionated assuptions. 24 | 25 | :param files: path to file(s) to train tokenizer on 26 | :param dropout: Training dropout 27 | :param vocab_size: Final vocabulary size 28 | :param min_frequency: Minimum number of occurences to add to vocab 29 | :param prefix: File name prefix of the final tokenizer 30 | :param save_path: Where to save the final tokenizer 31 | :param added_tokens: List of tokens to add to the tokenizer (currently not working) 32 | :param bos_token: Beginning-of-string special token 33 | :param eos_token: End-of-string special token 34 | :param unk_token: Unknown special token 35 | """ 36 | 37 | assert isinstance(files, str) or isinstance( 38 | files, list 39 | ), "files must be a string or a list." 40 | 41 | assert isinstance(added_tokens, list), "added_tokens must be a list." 42 | 43 | if isinstance(files, str): 44 | files = [files] 45 | 46 | tokenizer = ByteLevelBPETokenizer(dropout=dropout, trim_offsets=trim_offsets) 47 | 48 | tokenizer.train( 49 | files=files, 50 | vocab_size=vocab_size, 51 | min_frequency=min_frequency, 52 | special_tokens=[bos_token, eos_token, unk_token] + added_tokens, 53 | ) 54 | 55 | if serialize: 56 | tokenizer.save(f"{prefix}.tokenizer.json") 57 | else: 58 | tokenizer.save_model(save_path, prefix) 59 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # aitextgen 2 | 3 | _Last Updated: April 18th, 2021 (aitextgen v0.5.0)_ 4 | 5 | A robust Python tool for text-based AI training and generation using [OpenAI's](https://openai.com) [GPT-2](https://openai.com/blog/better-language-models/) and [EleutherAI's](https://www.eleuther.ai) [GPT Neo/GPT-3](https://github.com/EleutherAI/gpt-neo) architecture. 6 | 7 | aitextgen is a Python package that leverages [PyTorch](https://pytorch.org), [Hugging Face Transformers](https://github.com/huggingface/transformers) and [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning) with specific optimizations for text generation using GPT-2, plus _many_ added features. It is the successor to [textgenrnn](https://github.com/minimaxir/textgenrnn) and [gpt-2-simple](https://github.com/minimaxir/gpt-2-simple), taking the best of both packages: 8 | 9 | - Finetunes on a pretrained 124M/355M/774M GPT-2 model from OpenAI or a 125M/355M GPT Neo model from EleutherAI...or create your own GPT-2/GPT Neo model + tokenizer and train from scratch! 10 | - Generates text faster than gpt-2-simple and with better memory efficiency! 
11 | - With Transformers, aitextgen preserves compatibility with the base package, allowing you to use the model for other NLP tasks, download custom GPT-2 models from the HuggingFace model repository, and upload your own models! Also, it uses the included `generate()` function to allow a massive amount of control over the generated text. 12 | - With pytorch-lightning, aitextgen trains models not just on CPUs and GPUs, but also _multiple_ GPUs and (eventually) TPUs! It also includes a pretty training progress bar, with the ability to add optional loggers. 13 | - The input dataset is its own object, allowing you to not only easily encode megabytes of data in seconds, cache, and compress it on a local computer before transporting to a remote server, but you are able to _merge_ datasets without biasing the resulting dataset, or _cross-train_ on multiple datasets to create blended output. 14 | 15 | ## Installation 16 | 17 | aitextgen can be installed from PyPI: 18 | 19 | ```sh 20 | pip3 install aitextgen 21 | ``` 22 | 23 | ## More 24 | 25 | For more info on how to use aitextgen, check out the nav sidebar, or you can do a [Hello World tutorial](tutorials/hello-world.md). 26 | -------------------------------------------------------------------------------- /docs/helpful-notes.md: -------------------------------------------------------------------------------- 1 | # Helpful Notes 2 | 3 | A few helpful tips and tricks for using aitextgen. 4 | 5 | 6 | - **Not all AI generated text will be good**, hence why human curation is currently a necessary strategy for many finetuned models. In testing, only **5% — 10% of generated text** is viable. One of the design goals of aitextgen is to help provide tools to improve that signal-to-noise ratio. 7 | - **You may not necessarily get better results with larger models**. Larger models perform better on academic benchmarks, yes, but the _quality_ of text can vary strongly depending on the size of the model used, _especially_ if you do not have a lot of input data. 8 | - To convert a GPT-2 model trained using earlier TensorFlow-based finetuning tools such as gpt-2-simple to the PyTorch format, use the transformers-cli command and the [instructions here](https://huggingface.co/transformers/converting_tensorflow_models.html) to convert the checkpoint (where `OPENAI_GPT2_CHECKPOINT_PATH` is the _folder_ containing the model) 9 | - When running on Google Cloud Platform (including Google Colab), it's recommended to download the TF-based GPT-2 from the Google API vs. downloading the PyTorch GPT-2 from Huggingface as the download will be _much_ faster and also saves Huggingface some bandwidth. 10 | - If you want to generate text from a GPU, you must manually move the model to the GPU (it will not be done automatically to save GPU VRAM for training). Either call `to_gpu=True` when loading the model or call `to_gpu()` from the aitextgen object. 11 | - Encoding your text dataset before moving it to a cloud/remote server is _strongly_ recommended. You can do that quickly from the CLI (`aitextgen encode text.txt`) Thanks to a few tricks, the file size is reduced by about 1/2 to 2/3, and the encoded text will instantly load on the remote server! 12 | - If you're making a micro-GPT-2 model, using a GPU with a large batch size is recommended, as with the AdamW optimizer, it effectively has built-in batch normalization. 
13 | - If you make frequent use of the Colab Notebooks, I recommend purchasing [Colab Pro](https://colab.research.google.com/signup): the timeouts are infrequent, and being able to reliably train on a P100 (normally [$1.46/hr](https://cloud.google.com/compute/gpus-pricing)) for hours on end very quickly pays for itself! 14 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: aitextgen 2 | site_description: A robust Python tool for text-based AI training and generation using GPT-2 and GPT Neo. 3 | site_author: Max Woolf (@minimaxir) 4 | 5 | nav: 6 | - Home: index.md 7 | - Hello World Tutorial: tutorials/hello-world.md 8 | - Loading a Model: load-model.md 9 | - Saving a Model: save-model.md 10 | - TokenDataset: dataset.md 11 | - Training a Model: 12 | - Colaboratory Notebooks: tutorials/colab.md 13 | # - Improving Training Performance: train-performance.md 14 | - Training a GPT-2 Model From Scratch: tutorials/model-from-scratch.md 15 | - Loggers: loggers.md 16 | - Generating from a Model: 17 | - Generating Text: generate.md 18 | # - Improving Generation Performance: generate-performance.md 19 | # - Generating From GPT-2 1.5B: tutorials/generate_1_5b.md 20 | - Importing from gpt-2-simple: gpt-2-simple.md 21 | - Helpful Notes: helpful-notes.md 22 | - Ethics: ethics.md 23 | - Command-Line Interface: cli.md 24 | - Upload Model to Huggingface: upload.md 25 | 26 | theme: 27 | name: "material" 28 | palette: 29 | - media: "(prefers-color-scheme: light)" 30 | scheme: default 31 | toggle: 32 | icon: material/toggle-switch-off-outline 33 | name: Switch to dark mode 34 | - media: "(prefers-color-scheme: dark)" 35 | scheme: slate 36 | primary: black 37 | toggle: 38 | icon: material/toggle-switch 39 | name: Switch to light mode 40 | font: 41 | text: "Source Sans Pro" 42 | code: "Fira Code" 43 | features: 44 | - tabs 45 | - instant 46 | icon: 47 | logo: fontawesome/solid/robot 48 | repo: fontawesome/brands/github 49 | 50 | repo_name: minimaxir/aitextgen 51 | repo_url: https://github.com/minimaxir/aitextgen 52 | edit_uri: "" 53 | 54 | copyright: "Copyright © 2019 - 2021 Max Woolf" 55 | 56 | extra: 57 | social: 58 | - icon: "fontawesome/brands/github-alt" 59 | link: "https://github.com/minimaxir/aitextgen" 60 | - icon: "fontawesome/brands/twitter" 61 | link: "https://twitter.com/minimaxir" 62 | 63 | markdown_extensions: 64 | - pymdownx.highlight 65 | - pymdownx.superfences 66 | - admonition 67 | - toc: 68 | permalink: true 69 | 70 | plugins: 71 | - search # necessary for search to work 72 | - minify: 73 | minify_html: true 74 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## [0.4.0] - 2021-02-21 8 | 9 | - Increased minimum versions of dependencies (`transformers` to 4.3.0, `pytorch-lightning` to 1.2.0) 10 | - Remove dependency on `tokenizers` as `transformers` pins it. 
11 | - Made Fast tokenizers the default (as it is the default in `transformers` 4.0.0) 12 | - Made serialized tokenizers the default for custom tokenizers, and added support for loading them for both `aitextgen` and `TokenDataset`s 13 | - Added gradient checkpointing for GPT-2, and set it to the default for training 355M and 774M. 14 | - Added layer freezing to freeze the first `n` layers of GPT-2 while training. This allows 1.5B GPT-2 to be trained with a high `n`. 15 | - Added schema-based generation for specificed schema_tokens (which can be encoded in the Transformers config). This can be used with an appropriate dataset for schema-based generation. 16 | - Switched TensorFlow weight download URL from GCP (as OpenAI removed it from there) to Azure 17 | - Fixed issue where prompt character length was used to check for a too-long assert instead of prompt token length (#90) 18 | - Workaround breaking issue in Transformers 4.3.0 by moving special token stripping into aitextgen instead of the tokenizer (#90) 19 | - Added an `lstrip` param to generation, which strips all whitespace at the beginning of generated text (related to point above) 20 | 21 | ## [0.3.0] - 2020-11-30 22 | 23 | - Increased minimum versions of dependencies (`transformers` to 4.0.0, `pytorch-lightning` to 1.0.8, Pytorch to 1.6) 24 | - Fixed imports to account for new Transfomers file architecture 25 | - Fixed training to account for new transformer/pytorch-lightning minimums 26 | - Fully removed TorchScript code (ONNX implementation will supercede it) 27 | - Made prompt specification for generation more canonical with Transformers 28 | - Set default `vocab` size for new tokenizers to `1000` 29 | - Began work on serializing tokenizers in accordance to the new `tokenizers` approach 30 | 31 | ## [0.2.1] - 2020-06-28 32 | 33 | ### Added 34 | 35 | - CHANGELOG.md 36 | 37 | ## [0.2.0] - 2020-06-01 38 | 39 | ### Added 40 | 41 | - Progress bar for loading a dataset. 42 | - `progress_bar_refresh_rate` parameter for `train()` and `TokenDataset()`. 43 | 44 | ### Changed 45 | 46 | - Set numpy data store for `TokenDataset`. 47 | - Set single-text files to be loaded delimited as newlines. 48 | 49 | ### Removed 50 | 51 | - `shuffle` and `seed` parameters for TokenDataset. 52 | 53 | ## [0.1.1] - 2020-05-17 54 | 55 | ### Changed 56 | 57 | - Set `generate()` defaults to `max_length=256` and `temperature=0.7`. 58 | - Added to docs notes about GPT-2 maximum length of 1024. 59 | 60 | ## [0.1] - 2020-05-17 61 | 62 | ### Added 63 | 64 | - Everything! 65 | -------------------------------------------------------------------------------- /docs/tutorials/hello-world.md: -------------------------------------------------------------------------------- 1 | # Hello World 2 | 3 | Here's how you can quickly test out aitextgen on your own computer, even if you don't have a GPU! 
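If you haven't installed the package yet, it is available from PyPI:

```sh
pip3 install aitextgen
```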
4 | 5 | For generating text from a pretrained GPT-2 model: 6 | 7 | ```py3 8 | from aitextgen import aitextgen 9 | 10 | # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model 11 | ai = aitextgen() 12 | 13 | ai.generate() 14 | ai.generate(n=3, max_length=100) 15 | ai.generate(n=3, prompt="I believe in unicorns because", max_length=100) 16 | ai.generate_to_file(n=10, prompt="I believe in unicorns because", max_length=100, temperature=1.2) 17 | ``` 18 | 19 | You can also generate from the command line: 20 | 21 | ```sh 22 | aitextgen generate 23 | aitextgen generate --prompt "I believe in unicorns because" --to_file False 24 | ``` 25 | 26 | Want to train your own mini GPT-2 model on your own computer? Download this [text file of Shakespeare plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), cd to that directory in a Teriminal, open up a `python3` console and go: 27 | 28 | ```py3 29 | from aitextgen.TokenDataset import TokenDataset 30 | from aitextgen.tokenizers import train_tokenizer 31 | from aitextgen.utils import GPT2ConfigCPU 32 | from aitextgen import aitextgen 33 | 34 | # The name of the downloaded Shakespeare text for training 35 | file_name = "input.txt" 36 | 37 | # Train a custom BPE Tokenizer on the downloaded text 38 | # This will save one file: `aitextgen.tokenizer.json`, which contains the 39 | # information needed to rebuild the tokenizer. 40 | train_tokenizer(file_name) 41 | tokenizer_file = "aitextgen.tokenizer.json" 42 | 43 | # GPT2ConfigCPU is a mini variant of GPT-2 optimized for CPU-training 44 | # e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2. 45 | config = GPT2ConfigCPU() 46 | 47 | # Instantiate aitextgen using the created tokenizer and config 48 | ai = aitextgen(tokenizer_file=tokenizer_file, config=config) 49 | 50 | # You can build datasets for training by creating TokenDatasets, 51 | # which automatically processes the dataset with the appropriate size. 52 | data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64) 53 | 54 | # Train the model! It will save pytorch_model.bin periodically and after completion to the `trained_model` folder. 55 | # On a 2020 8-core iMac, this took ~25 minutes to run. 56 | ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000) 57 | 58 | # Generate text from it! 59 | ai.generate(10, prompt="ROMEO:") 60 | 61 | # With your trained model, you can reload the model at any time by 62 | # providing the folder containing the pytorch_model.bin model weights + the config, and providing the tokenizer. 63 | ai2 = aitextgen(model_folder="trained_model", 64 | tokenizer_file="aitextgen.tokenizer.json") 65 | 66 | ai2.generate(10, prompt="ROMEO:") 67 | ``` 68 | -------------------------------------------------------------------------------- /docs/load-model.md: -------------------------------------------------------------------------------- 1 | # Model Loading 2 | 3 | There are several ways to load models. 4 | 5 | 6 | !!! note "Continue Training/Finetuning" 7 | You can further train a model by reloading a model that has already been trained, using the methods outlined below. 8 | 9 | ## Loading an aitextgen model 10 | 11 | For the base case, loading the default 124M GPT-2 model via Huggingface: 12 | 13 | ```py3 14 | ai = aitextgen() 15 | ``` 16 | 17 | The downloaded model will be downloaded to `cache_dir`: `/aitextgen` by default. 
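If you prefer the downloaded weights to live somewhere else (for example, on a mounted drive), the cache location can be redirected when instantiating the model. A sketch, assuming `cache_dir` is accepted as a keyword argument (as the references to `cache_dir` on this page imply):

```py3
ai = aitextgen(cache_dir="my_models")  # "my_models" is a hypothetical directory name
```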
18 | 19 | If you're loading a custom model for a different GPT-2/GPT-Neo architecture _from scratch_ but with the normal GPT-2 tokenizer, you can pass only a config. 20 | 21 | ```py3 22 | from aitextgen.utils import GPT2ConfigCPU 23 | config = GPT2ConfigCPU() 24 | ai = aitextgen(config=config) 25 | ``` 26 | 27 | While training/finetuning a model, two files will be created: the `pytorch_model.bin` which contains the weights for the model, and a `config.json` illustrating the architecture for the model. Both of these files are needed to reload the model. 28 | 29 | If you've finetuned a model using aitextgen (the default model), you can pass the **folder name** containing the generated `pytorch_model.bin` and `config.json` to aitextgen (e.g. `trained_model`, which is where trained models will be saved by default). 30 | 31 | 32 | !!! note "Same Directory" 33 | If both files are in the current directory, you can pass `model_folder="."`. 34 | 35 | ```py3 36 | ai = aitextgen(model_folder="trained_model") 37 | ``` 38 | 39 | These examples assume you are using the default GPT-2 tokenizer. If you have a _custom tokenizer_, you'll need to pass that along with loading the model. 40 | 41 | ```py3 42 | ai3 = aitextgen(model_folder="trained_model", 43 | tokenizer_file="aitextgen.tokenizer.json") 44 | ``` 45 | 46 | If you want to download an alternative GPT-2 model from Hugging Face's repository of models, pass that model name to `model`. 47 | 48 | ```py3 49 | ai = aitextgen(model="minimaxir/hacker-news") 50 | ``` 51 | 52 | The model and associated config + tokenizer will be downloaded into `cache_dir`. 53 | 54 | This can also be used to download the [pretrained GPT Neo models](https://huggingface.co/EleutherAI) from EleutherAI. 55 | 56 | ```py3 57 | ai = aitextgen(model="EleutherAI/gpt-neo-125M") 58 | ``` 59 | 60 | ## Loading TensorFlow-based GPT-2 models 61 | 62 | aitextgen lets you download the models from Microsoft's servers that OpenAI had uploaded back when GPT-2 was first released in 2019. These models are then converted to a PyTorch format. 63 | 64 | To use this workflow, pass the corresponding model number to `tf_gpt2`: 65 | 66 | ```py3 67 | ai = aitextgen(tf_gpt2="124M") 68 | ``` 69 | 70 | This will cache the converted model locally in `cache_dir`, and using the same parameters will load the converted model. 71 | 72 | The valid TF model names are `["124M","355M","774M","1558M"]`. 73 | -------------------------------------------------------------------------------- /docs/tutorials/model-from-scratch.md: -------------------------------------------------------------------------------- 1 | # Training a GPT-2 Model From Scratch 2 | 3 | The original GPT-2 model released by OpenAI was trained on English webpages linked to from Reddit, with a strong bias toward longform content (multiple paragraphs). 4 | 5 | If that is _not_ your use case, you may get a better generation quality _and_ speed by training your own model and Tokenizer. Examples of good use cases: 6 | 7 | - Short-form content (e.g. Tweets, Reddit post titles) 8 | - Non-English Text 9 | - Heavily Encoded Text 10 | 11 | It still will require a _massive_ amount of training time (several hours) but will be more flexible. 12 | 13 | ## Building a Custom Tokenizer. 14 | 15 | The `train_tokenizer()` function from `aitextgen.tokenizers` trains the model on the specified text(s) on disk. 16 | 17 | 18 | !!! note "Vocabulary Size" 19 | The default vocabulary size for `train_tokenizer()` is 1,000 tokens. 
Although this is much lower than GPT-2's 50k vocab size, the smaller the vocab size, the easier it is to train the model (since it's more likely for the model to make a correct "guess"), and the model file size will be _much_ smaller. 20 | 21 | ```py3 22 | from aitextgen.tokenizers import train_tokenizer 23 | train_tokenizer(file_name) 24 | ``` 25 | 26 | This creates one file, `aitextgen.tokenizer.json`, which is needed to rebuild the tokenizer. 27 | 28 | # Building a Custom Dataset 29 | 30 | You can build a TokenDataset based off your custom Tokenizer, to be fed into the model. 31 | 32 | ```py3 33 | data = TokenDataset(file_name, vocab_file=vocab_file, merges_file=merges_file, block_size=32) 34 | ``` 35 | 36 | ## Building a Custom Config 37 | 38 | Whenever you load a default 124M GPT-2 model, it uses a `GPT2Config()` under the hood. But you can create your own, with whatever parameters you want. 39 | 40 | The `build_gpt2_config()` function from `aitextgen.utils` gives you more control. 41 | 42 | ```py3 43 | config = build_gpt2_config(vocab_size=5000, max_length=32, dropout=0.0, n_embd=256, n_layer=8, n_head=8) 44 | ``` 45 | 46 | A few notes on the inputs: 47 | 48 | - `vocab_size`: Vocabulary size: this _must_ match what you used to build the tokenizer! 49 | - `max_length`: Context window for the GPT-2 model: this _must_ match the `block_size` used in the TokenDataset! 50 | - `dropout`: Dropout on various areas of the model to limit overfitting (you should likely keep at 0) 51 | - `n_embd`: The embedding size for each vocab token. 52 | - `n_layers`: Transformer layers 53 | - `n_head`: Transformer heads 54 | 55 | 56 | !!! note "Model Size" 57 | GPT-2 Model size is directly proportional to `vocab_size` \* `embeddings`. 58 | 59 | ## Training the Custom Model 60 | 61 | You can instantiate an empty GPT-2 according to your custom config, and construct a custom tokenizer according to your vocab and merges file: 62 | 63 | ```py3 64 | ai = aitextgen(tokenizer_file=tokenizer_file, config=config) 65 | ``` 66 | 67 | Training is done as normal. 68 | 69 | ```py3 70 | ai.train(data, batch_size=16, num_steps=5000) 71 | ``` 72 | 73 | ## Reloading the Custom Model 74 | 75 | You'll always need to provide the tokenizer_file and the folder containing the `pytorch_model.bin` and `config.json`. 76 | 77 | ```py3 78 | ai = aitextgen(model_folder="trained_model", tokenizer_file="aitextgen.tokenizer.json") 79 | ``` 80 | -------------------------------------------------------------------------------- /docs/generate.md: -------------------------------------------------------------------------------- 1 | # Text Generation 2 | 3 | Thanks to the base Transformers package, aitextgen has more options for generating text than other text-generating apps before. 4 | 5 | ## Generation Parameters 6 | 7 | See [this article](https://huggingface.co/blog/how-to-generate) by Huggingface engineer Patrick von Platen for how sampling and these parameters are used in practice. 8 | 9 | - `n`: Number of texts generated. 10 | - `max_length`: Maximum length of the generated text (default: 200; for GPT-2, the maximum is 1024; for GPT Neo, the maximum is 2048) 11 | - `prompt`: Prompt that starts the generated text and is included in the generated text. 12 | - `temperature`: Controls the "craziness" of the text (default: 0.7) 13 | - `top_k`: If nonzero, limits the sampled tokens to the top _k_ values. 
(default: 0) 14 | - `top_p`: If nonzero, limits the sampled tokens to the cumulative probability 15 | 16 | Some lesser-known-but-still-useful-parameters that are unique to Transformers: 17 | 18 | 19 | !!! warning "Performance" 20 | Enabling these parameters may slow down generation. 21 | 22 | - `num_beams`: If greater than 1, executes beam search for cleaner text. 23 | - `repetition_penalty`: If greater than 1.0, penalizes repetition in a text to avoid infinite loops. 24 | - `length_penalty`: If greater than 1.0, penalizes text proportional to the length 25 | - `no_repeat_ngram_size`: Token length to avoid repeating given phrases. 26 | 27 | ## Generation Functions 28 | 29 | Given a `aitextgen` object with a loaded model + tokenizer named `ai`: 30 | 31 | 32 | !!! note "About devices" 33 | aitextgen does not automatically set the device used to generate text. If you 34 | want to generate on the GPU, make sure you call `ai.to_gpu()` beforehand, or 35 | load the model into the GPU using `ai = aitextgen(to_gpu=True)` 36 | 37 | - `ai.generate()`: Generates and prints text to console. If `prompt` is used, the `prompt` is **bolded**. 38 | - `ai.generate_one()`: A helper function which generates a single text and returns as a string (good for APIs) 39 | - `ai.generate_samples()`: Generates multiple samples at specified temperatures: great for debugging. 40 | - `ai.generate_to_file()`: Generates a bulk amount of texts to file. (this accepts a `batch_size` parameter which is useful if using on a GPU, as it can generate texts in parallel with no performance loss) 41 | 42 | 43 | !!! note "lstrip and nonempty_output" 44 | By default, the `lstrip` and `nonempty_output` parameters to `generate` are set to `True`, which alters the behavior of the generated text in a way that is most likely preferable. `lstrip`: Removes all whitespace at the beginning of the generated space. `nonempty_output`: If the output is empty (possible on shortform content), skip it if generating multiple texts, or try again if it's a single text. If `min_length` is specified, the same behavior occurs for texts below the minimum length after processing. 45 | 46 | ## Seed 47 | 48 | aitextgen has a new `seed` parameter for generation. Using any generate function with a `seed` parameter (must be an integer) and all other models/parameters the same, and the generated text will be identical. This allows for reproducible generations in case someone accuses you of faking the AI output. 49 | 50 | For `generate_to_file()`, the 8-digit number at the end of the file name will be the seed used to generate the file, making reprodicibility easy. 51 | -------------------------------------------------------------------------------- /docs/tutorials/generate_1_5b.md: -------------------------------------------------------------------------------- 1 | # Generating From GPT-2 1.5B 2 | 3 | 4 | Want to generate a ton of text with the largest GPT-2 model, with the generation control provided by aitextgen? Now you can, at a surprisingly low cost! ($0.382/hr, prorated to the nearest second) 5 | 6 | Here's how to set it up on Google Cloud Platform. 7 | 8 | ## Setting Up an AI Platform Notebook 9 | 10 | An [AI Platform Notebook](https://cloud.google.com/ai-platform-notebooks) is a hosted Jupyter Lab instance on Google Cloud Platform oriented for AI training and inference. Since it requires zero setup _and has no additional costs outside of CPU/GPUs_, it's the best tool to play with aitextgen. 
11 | 12 | First, go to [AI Platform Notebooks in the GCP console](https://console.cloud.google.com/ai-platform/notebooks/) (if you haven't made a project + billing, it should prompt you to do so). Go to `New Instance`, select `PyTorch 1.4` and `With 1 NVIDIA Tesla T4`. 13 | 14 | 15 | !!! note "Quotas" 16 | You may need T4 quota to create a VM with a T4; accounts should have enough by default, but you may want to confirm. 17 | 18 | The rest of the VM config settings are fine to leave as/is, however make sure you check `Install NVIDIA GPU driver automatically for me`! 19 | 20 | Once the instance is created, wait a bit (it takes awhile to install the driver), and a `OPEN JUPYTERLAB` button will appear. Click it to open the hosted Jupyter Lab 21 | 22 | ## Installing aitextgen on the VM 23 | 24 | Now we have to install the dependencies, which only have to be done once. 25 | 26 | In the Jupyter Lab instance, open a Terminal tab, and install both aitextgen and tensorflow (we'll need tensorflow later) 27 | 28 | ```sh 29 | pip3 install aitextgen tensorflow 30 | ``` 31 | 32 | Now the harder part: we need to install and compile [apex](https://github.com/NVIDIA/apex) for FP16 support with the T4 GPU. To do that, run: 33 | 34 | ```sh 35 | git clone https://github.com/NVIDIA/apex 36 | cd apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 37 | ``` 38 | 39 | That will take a few minutes, but once that is done, you are good to go and do not need to rerun these steps again! 40 | 41 | ## Loading GPT-2 1.5B 42 | 43 | Now go back to the Launcher and create a Python 3 Notebook (or upload the one here). 44 | 45 | 46 | !!! warning "CUDA" 47 | You may want to ensure the Notebook sees the CUDA installation, which appears to be somewhat random. This can be verified by running `import torch` in a cell, then `torch.cuda.is_available()`. 48 | 49 | In a cell, load aitextgen: 50 | 51 | ```py3 52 | from aitextgen import aitextgen 53 | ``` 54 | 55 | In another cell, input and run: 56 | 57 | ```py3 58 | ai = aitextgen(tf_gpt2="1558M", to_gpu=True, to_fp16=True) 59 | ``` 60 | 61 | A few things going on here: 62 | 63 | - The TensorFlow-based GPT-2 1.5B is downloaded from Google's servers. (download rate is _very_ fast). This download will only occur once. 64 | - It is converted to a corresponding PyTorch model, and then loaded. 65 | - After it is loaded, it is converted to a FP16 representation. 66 | - Then it is moved to the T4 GPU. 67 | 68 | ## Generating from GPT-2 1.5B 69 | 70 | Now we can generate texts! The T4, for GPT-2 1.5B in FP16 mode, can generate about 30 texts in a batch without going OOM. (you can verify GPU memory usage at any time by opening up a Terminal and running `nvidia-smi`) 71 | 72 | Create a cell and add: 73 | 74 | ```py3 75 | ai.generate_to_file(n=300, batch_size=30) 76 | ``` 77 | 78 | 79 | !!! warning "Batch Size" 80 | The batch size of 30 above assumes the default `max_length` of 256. If you want to use the full 1024 token max length, lower the batch size to 15, as the GPU will go OOM otherwise. 81 | 82 | And it will generate the texts to a file! When completed, you can double-click to view it in Jupyter Lab, and you can download the file by right-clicking it from the file viewer. 83 | 84 | More importantly, all parameters to `generate` are valid, allowing massive flexibility! 
85 | 86 | ```py3 87 | ai.generate_to_file(n=150, batch_size=15, max_length=1024, top_p=0.9, temperature=1.2, prompt="President Donald Trump has magically transformed into a unicorn.") 88 | ``` 89 | 90 | ## Cleanup 91 | 92 | **Make sure you Stop the instance when you are done to avoid being charged**. To do that, go back to the AI Platform Notebook console, select the instance, and press Stop. 93 | 94 | Since 100GB of storage may be pricy, you may want to consider deleting the VM fully if you are done with it as well. 95 | -------------------------------------------------------------------------------- /docs/dataset.md: -------------------------------------------------------------------------------- 1 | # TokenDataset 2 | 3 | aitextgen has a special class, `TokenDataset`, used for managing tokenized datasets to be fed into model training. (this is in contrast with other GPT-2 finetuning approaches, which tokenizes at training time although you can still do that by passing a `file_path` and other relevant parameters to `ai.train()`.) 4 | 5 | This has a few nice bonuses, including: 6 | 7 | - Tokenize a dataset on a local machine ahead of time and compress it, saving time/bandwidth transporting data to a remote machine 8 | - Supports both reading a dataset line-by-line (including single-column CSVs), or bulk texts. 9 | - Debug and log the loaded texts. 10 | - Merge datasets together without using external libraries 11 | - Cross-train on multiple datasets to "blend" them together. 12 | 13 | ## Creating a TokenDataset For GPT-2 Finetuning 14 | 15 | The easiest way to create a TokenDataset is to provide a target file. If no `tokenizer_file` is provided, it will use the default GPT-2 tokenizer. 16 | 17 | ```py3 18 | from aitextgen.TokenDataset import TokenDataset 19 | 20 | data = TokenDataset("shakespeare.txt") 21 | ``` 22 | 23 | If you pass a single-column CSV and specify `line_by_line=True`, the TokenDataset will parse it row-by-row, and is the recommended way to handle multiline texts. 24 | 25 | ```py3 26 | data = TokenDataset("politics.csv", line_by_line=True) 27 | ``` 28 | 29 | You can also manually pass a list of texts to `texts` instead if you've processed them elsewhere. 30 | 31 | ```py3 32 | data = TokenDataset(texts = ["Lorem", "Ipsum", "Dolor"]) 33 | ``` 34 | 35 | ## Block Size 36 | 37 | `block_size` is another parameter that can be passed when creating a TokenDataset, more useful for custom models. This should match the context window (e.g. the `n_positions` or `max_position_embeddings` config parameters). By default, it will choose `1024`: the GPT-2 context window. 38 | 39 | When implicitly loading a dataset via `ai.train()`, the `block_size` will be set to what is supported by the corresponding model `config`. 40 | 41 | ## Debugging a TokenDataset 42 | 43 | When loading a dataset, a progress bar will appear showing how many texts are loaded and 44 | 45 | If you want to see what exactly is input to the model during training, you can access a slice via `data[0]`. 46 | 47 | ## Saving/Loading a TokenDataset 48 | 49 | When creating a TokenDataset, you can automatically save it as a compressed gzipped numpy array when completed. 50 | 51 | ```py3 52 | data = TokenDataset("shakespeare.txt", save_cache=True) 53 | ``` 54 | 55 | Or save it after you've loaded it with the `save()` function. 56 | 57 | ```py3 58 | data = TokenDataset("shakespeare.txt") 59 | data.save() 60 | ``` 61 | 62 | By default, it will save to `dataset_cache.tar.gz`. 
You can then reload that into another Python session by specifying the cache. 63 | 64 | ```py3 65 | data = TokenDataset("dataset_cache.tar.gz", from_cache=True) 66 | ``` 67 | 68 | 69 | !!! note "CLI" 70 | You can quickly create a tokenized dataset using the command line, e.g. `aitextgen encode text.txt`. This will drastically reduce the file size, and is recommended before moving the file to cloud services (where it can be loaded using the `from_cache` parameter noted above). 71 | 72 | ## Using TokenDatasets with a Custom GPT-2 Model 73 | 74 | The default TokenDataset has a `block_size` of `1024`, which corresponds to the _context window of the default GPT-2 model_. If you're using a custom model with a different context window, set `block_size` to match it. Additionally, you must explicitly provide the tokenizer file to rebuild the tokenizer, as the tokenizer will be different from the normal GPT-2 one. 75 | 76 | See the [Model From Scratch](tutorials/model-from-scratch.md) docs for more info. 77 | 78 | ## Merging TokenDatasets 79 | 80 | Merging processed TokenDatasets can be done with the `merge_datasets()` function. By default, it will take an equal number of samples from each TokenDataset (equal to the size of the smallest dataset), randomly sampling the appropriate number of texts from the larger datasets. This ensures that model training does not bias toward one particular dataset. (This behavior can be disabled by setting `equalize=False`.) 81 | 82 | (You _can_ merge bulk datasets and line-by-line datasets, but the training output may be bizarre!) 83 | 84 | 85 | !!! note "About Merging" 86 | The current implementation merges by subset count, so equalization may not be perfect, but it will not significantly impact training. 87 | 88 | ```py3 89 | from aitextgen.TokenDataset import TokenDataset, merge_datasets 90 | 91 | data1 = TokenDataset("politics1000.csv", line_by_line=True) # 1000 samples 92 | data2 = TokenDataset("politics4000.csv", line_by_line=True) # 4000 samples 93 | 94 | data_merged = merge_datasets([data1, data2]) # ~2000 samples 95 | ``` 96 | -------------------------------------------------------------------------------- /DESIGN.md: -------------------------------------------------------------------------------- 1 | # Design 2 | 3 | A few notes on some opinionated design decisions present in aitextgen. 4 | 5 | ## Features 6 | 7 | - TokenDataset caching is done via [MessagePack](https://msgpack.org/index.html). Tokens are encoded as integers, which is the case MessagePack is especially designed for, compressing the final file by about 50%. MessagePack will also save/reload the file much faster than `pickle` or `json`. 8 | - By default, this is additionally compressed via `gzip`, further reducing file size. This may have memory constraints for superlarge files (1GB+), hence a parameter to disable compression. 9 | - Although GPT-2 is the default Transformers model architecture used, there is limited GPT-2 specific code. This allows this tool to easily adapt to new CLM architectures that may be released in the future, such as SparseTransformers or Reformers. 10 | - `generate_to_file()` automatically assigns the generation process a seed if one is not specified. This allows other users to reproduce a generation deterministically (e.g. in a Jupyter Notebook), in order to provide proof that a text was generated by AI and not altered. 11 | - For training checkpoints, aitextgen deliberately disables pytorch-lightning's checkpoint feature since it incurs a lot of overhead, and using the model's own native saving functionality is easier.
12 | - Testing/Validation is deliberately not implemented since it serves as more of a crutch and doesn't provide that much help in identifying overfitting. 13 | 14 | ## Philosophies 15 | 16 | - The development intent of aitextgen is as a _tool_ for AI text generation, and not a philosophical experiment about AI consciousness or whatnot. (Alternatively, one could argue that _humans_ are the ones who perform actions based on prior knowledge and [free will is a myth](https://www.youtube.com/watch?v=kQjb-EP2JEE), but that's a discussion for another time.) 17 | - AI generated text should _never_ be called "deepfake text." Stop trying to make deepfake text happen. 18 | 19 | ## Deviations from Huggingface Transformers 20 | 21 | - While aitextgen was in development, Huggingface added a finetuning `Trainer` to Transformers, which would make this package somewhat redundant. However, pytorch-lightning has a number of specific features and optimizations that make separating out the logic more efficient and easier to modify/tweak. 22 | - For generation, random sampling is the default, while in the base Transformers [greedy sampling is used](https://github.com/huggingface/transformers/pull/3298), which makes the output deterministic. 23 | - The TokenDataset uses an implementation different from Transformers' TextDataset, as that implementation is flawed for GPT-2. The Transformers TextDataset splits the corpus into blocks of `block_size`, which means a) many token subsets may never be fed into the model and b) BPE merges may be split across the block-size boundaries, causing incorrect tokenization. TokenDataset fixes both issues by concatenating the corpus into a single giant text and feeding the model random subsets of it with length `block_size`. 24 | 25 | ## Why aitextgen was made vs. improving gpt-2-simple 26 | 27 | gpt-2-simple was a quick project intended to make it possible to use GPT-2 in an app when no user-friendly alternative existed at the time, and for others to be able to use GPT-2 without fiddling around with the command line. However, since development on GPT-2 from OpenAI abruptly and unexpectedly stopped, it would take too much time to manually implement the necessary improvements on my end. 28 | 29 | It is faster to build a new project on a stronger, more futureproof base than it is to iterate on gpt-2-simple. Both transformers and pytorch-lightning are in rapid development, so I am confident they will be around for a while. 30 | 31 | I am also trying to avoid feature creep, as that happened with textgenrnn and gpt-2-simple and made the packages more difficult to maintain. Not all Issues/PRs will be accepted. 32 | 33 | ## Deviations from gpt-2-simple 34 | 35 | - PyTorch is the primary ML framework since it's the first-class citizen for Transformers. Additionally, as of 2020, the differences (in both performance and usability) between TensorFlow and PyTorch have converged, so I figured I'd give PyTorch a try. There were also a few random issues when developing gpt-2-simple (memory leaks, having to start a new session) which I believe were attributable to the TF base; maybe PyTorch will work out better. 36 | - The base `generate()` function in Transformers which aitextgen uses is "more correct" than gpt-2-simple: it stops and starts at special tokens, negating the need for encoding hacks; it uses caching to speed up decoding; and it only generates as much text as needed. 37 | - The OOP approach (which I had used long ago in textgenrnn) should make it much easier to integrate aitextgen into other apps.
38 | - For generation, you need to explicitly load/move the model to the GPU via `to_gpu()`, whereas TensorFlow did it automatically (it's not necessary to do so for training). 39 | 40 | ## Why create a separate package instead of committing directly to Transformers 41 | 42 | aitextgen has specific optimizations that may be out of scope for the Transformers package's primary use cases. If an optimization benefits the default use cases of Transformers without negatively impacting it, I may do a pull request. 43 | 44 | As the license of this repo is permissive (MIT), the Huggingface team is able to apply any modifications/ideas to their own repo if necessary without explicit permission. 45 | -------------------------------------------------------------------------------- /aitextgen/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from tqdm.auto import tqdm 4 | import torch 5 | import numpy as np 6 | import random 7 | from transformers import GPT2Config, GPTNeoConfig 8 | 9 | 10 | def download_gpt2(model_dir: str = "tf_model", model_name: str = "124M") -> None: 11 | """ 12 | Downloads the GPT-2 model (weights only) into the specified directory 13 | from Google Cloud Storage. 14 | 15 | If running in Colaboratory or Google Compute Engine, 16 | this is substantially faster (and cheaper for HuggingFace) than using the 17 | default model downloading. However, the model is in TensorFlow, 18 | so the weights must be converted. 19 | 20 | Adapted from gpt-2-simple. 21 | """ 22 | 23 | # create the model_dir/model_name subdirectory if not present 24 | sub_dir = os.path.join(model_dir, model_name) 25 | if not os.path.exists(sub_dir): 26 | os.makedirs(sub_dir) 27 | sub_dir = sub_dir.replace("\\", "/") # needed for Windows 28 | 29 | for file_name in [ 30 | "checkpoint", 31 | "hparams.json", 32 | "model.ckpt.data-00000-of-00001", 33 | "model.ckpt.index", 34 | "model.ckpt.meta", 35 | ]: 36 | if not os.path.isfile(os.path.join(sub_dir, file_name)): 37 | download_file_with_progress( 38 | url_base="https://openaipublic.blob.core.windows.net/gpt-2", 39 | sub_dir=sub_dir, 40 | model_name=model_name, 41 | file_name=file_name, 42 | ) 43 | 44 | 45 | def download_file_with_progress( 46 | url_base: str, sub_dir: str, model_name: str, file_name: str 47 | ): 48 | """ 49 | General utility for incrementally downloading files from the internet 50 | with a progress bar. 51 | 52 | Adapted from gpt-2-simple. 53 | """ 54 | 55 | # set to download 1MB at a time. This could be much larger with no issue 56 | DOWNLOAD_CHUNK_SIZE = 1024 * 1024 57 | r = requests.get( 58 | os.path.join(url_base, "models", model_name, file_name), stream=True 59 | ) 60 | with open(os.path.join(sub_dir, file_name), "wb") as f: 61 | file_size = int(r.headers["content-length"]) 62 | with tqdm( 63 | desc="Fetching " + file_name, 64 | total=file_size, 65 | unit_scale=True, 66 | ) as pbar: 67 | for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): 68 | f.write(chunk) 69 | pbar.update(DOWNLOAD_CHUNK_SIZE) 70 | 71 | 72 | def set_seed(seed: int): 73 | """ 74 | Sets the seed for all potential generation libraries. 75 | """ 76 | 77 | assert isinstance(seed, int), "seed must be an integer." 78 | random.seed(seed) 79 | np.random.seed(seed) 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed_all(seed) 82 | 83 | 84 | def reset_seed(): 85 | """ 86 | Resets the seed for all potential generation libraries.
87 | """ 88 | random.seed() 89 | np.random.seed() 90 | # torch.seed() 91 | # torch.cuda.seed_all() 92 | 93 | 94 | def build_gpt2_config( 95 | vocab_size: int = 10000, 96 | bos_token_id: int = 0, 97 | eos_token_id: int = 0, 98 | max_length: int = 1024, 99 | dropout: float = 0.0, 100 | **kwargs 101 | ): 102 | """ 103 | Builds a custom GPT-2 config based on a given Transformers config, 104 | with a few more user-friendly aliases. 105 | """ 106 | 107 | return GPT2Config( 108 | vocab_size=vocab_size, 109 | n_positions=max_length, 110 | n_ctx=max_length, 111 | resid_pdrop=dropout, 112 | embd_pdrop=dropout, 113 | attn_pdrop=dropout, 114 | summary_first_dropout=dropout, 115 | bos_token_id=bos_token_id, 116 | eos_token_id=eos_token_id, 117 | **kwargs, 118 | ) 119 | 120 | 121 | def GPT2ConfigCPU( 122 | vocab_size: int = 1000, bos_token_id: int = 0, eos_token_id: int = 0, **kwargs 123 | ): 124 | """ 125 | Returns a GPT-2 config more suitable for training on a regular consumer CPU. 126 | """ 127 | 128 | return GPT2Config( 129 | vocab_size=vocab_size, 130 | n_positions=64, 131 | n_ctx=64, 132 | n_embd=128, 133 | n_layer=4, 134 | n_head=4, 135 | bos_token_id=bos_token_id, 136 | eos_token_id=eos_token_id, 137 | **kwargs, 138 | ) 139 | 140 | 141 | def GPTNeoConfigCPU( 142 | vocab_size: int = 1000, bos_token_id: int = 0, eos_token_id: int = 0, **kwargs 143 | ): 144 | """ 145 | Returns a GPT Neo config more suitable for training on a regular consumer CPU. 146 | """ 147 | 148 | return GPTNeoConfig( 149 | vocab_size=vocab_size, 150 | max_position_embeddings=64, 151 | hidden_size=256, 152 | window_size=32, 153 | intermediate_size=256, 154 | attention_types=[[["global", "local"], 2]], 155 | num_layers=4, 156 | num_heads=4, 157 | bos_token_id=bos_token_id, 158 | eos_token_id=eos_token_id, 159 | **kwargs, 160 | ) 161 | 162 | 163 | def find_index_of_subset(large_list, small_list): 164 | """ 165 | Returns the index after the small_list within the large list, 166 | Returns -1 if not present. 167 | 168 | Adapted from https://stackoverflow.com/a/45819222 which shows that it is 169 | performant for input lengths < 1000, which is the common case for this function. 170 | """ 171 | length_small_list = len(small_list) 172 | firstneedle = small_list[0] 173 | for idx, item in enumerate(large_list): 174 | if item == firstneedle: 175 | if large_list[idx : idx + length_small_list] == small_list: 176 | return idx + length_small_list 177 | return -1 178 | 179 | 180 | def skip_special_tokens(tensor, device, special_token_ids): 181 | """Filters out special tokens by ids in the given 1D tensor. 
182 | 183 | Adapted from https://stackoverflow.com/a/62588955 184 | 185 | Args: 186 | tensor (tensor): PyTorch Tensor 187 | device (str): Device, usually "cpu" or "cuda:0" 188 | token_ids (set): List of Token IDs 189 | """ 190 | special_token_id_tensor = torch.unique(torch.as_tensor(special_token_ids)).to( 191 | device 192 | ) 193 | return tensor[ 194 | ~tensor.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1) 195 | ].tolist() 196 | 197 | 198 | def model_max_length(config): 199 | """Returns the maximum generation length for the given model.""" 200 | return getattr(config, "n_positions", None) or getattr( 201 | config, "max_position_embeddings", None 202 | ) 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aitextgen 2 | 3 | A robust Python tool for text-based AI training and generation using [OpenAI's](https://openai.com) [GPT-2](https://openai.com/blog/better-language-models/) and [EleutherAI's](https://www.eleuther.ai) [GPT Neo/GPT-3](https://github.com/EleutherAI/gpt-neo) architecture. 4 | 5 | aitextgen is a Python package that leverages [PyTorch](https://pytorch.org), [Hugging Face Transformers](https://github.com/huggingface/transformers) and [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning) with specific optimizations for text generation using GPT-2, plus _many_ added features. It is the successor to [textgenrnn](https://github.com/minimaxir/textgenrnn) and [gpt-2-simple](https://github.com/minimaxir/gpt-2-simple), taking the best of both packages: 6 | 7 | - Finetunes on a pretrained 124M/355M/774M GPT-2 model from OpenAI or a 125M/350M GPT Neo model from EleutherAI...or create your own GPT-2/GPT Neo model + tokenizer and train from scratch! 8 | - Generates text faster than gpt-2-simple and with better memory efficiency! 9 | - With Transformers, aitextgen preserves compatibility with the base package, allowing you to use the model for other NLP tasks, download custom GPT-2 models from the HuggingFace model repository, and upload your own models! Also, it uses the included `generate()` function to allow a massive amount of control over the generated text. 10 | - With pytorch-lightning, aitextgen trains models not just on CPUs and GPUs, but also _multiple_ GPUs and (eventually) TPUs! It also includes a pretty training progress bar, with the ability to add optional loggers. 11 | - The input dataset is its own object, allowing you to not only easily encode megabytes of data in seconds, cache, and compress it on a local computer before transporting to a remote server, but you are able to _merge_ datasets without biasing the resulting dataset, or _cross-train_ on multiple datasets to create blended output. 12 | 13 | You can read more about aitextgen [in the documentation](https://aitextgen.minimaxir.com/)! 14 | 15 | ## Demo 16 | 17 | You can play with aitextgen _for free_ with powerful GPUs using these Colaboratory Notebooks! 18 | 19 | - [Finetune OpenAI's 124M GPT-2 model (or GPT Neo) on your own dataset (GPU)](https://colab.research.google.com/drive/15qBZx5y9rdaQSyWpsreMDnTiZ5IlN0zD?usp=sharing) 20 | - [Train a GPT-2 model + tokenizer from scratch (GPU)](https://colab.research.google.com/drive/144MdX5aLqrQ3-YW-po81CQMrD6kpgpYh?usp=sharing) 21 | 22 | You can also play with custom [Reddit](notebooks/reddit_demo.ipynb) and [Hacker News](notebooks/hacker_news_demo.ipynb) demo models on your own PC. 
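For example, loading one of these demo models locally is a one-liner; the sketch below mirrors the bundled Hacker News notebook (the model name, cache location, and prompt come from that notebook):

```py3
from aitextgen import aitextgen

# Downloads the ~30 MB minimaxir/hacker-news model from the
# Hugging Face model repository and caches it to /aitextgen.
ai = aitextgen(model="minimaxir/hacker-news")

# Generate five submission titles seeded with a prompt
# (the prompt is bolded in the console output).
ai.generate(5, prompt="Ask HN")
```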
23 | 24 | ## Installation 25 | 26 | aitextgen can be installed [from PyPI](https://pypi.org/project/aitextgen/): 27 | 28 | ```sh 29 | pip3 install aitextgen 30 | ``` 31 | 32 | ## Quick Examples 33 | 34 | Here's how you can quickly test out aitextgen on your own computer, even if you don't have a GPU! 35 | 36 | For generating text from a pretrained GPT-2 model: 37 | 38 | ```py3 39 | from aitextgen import aitextgen 40 | 41 | # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model 42 | ai = aitextgen() 43 | 44 | ai.generate() 45 | ai.generate(n=3, max_length=100) 46 | ai.generate(n=3, prompt="I believe in unicorns because", max_length=100) 47 | ai.generate_to_file(n=10, prompt="I believe in unicorns because", max_length=100, temperature=1.2) 48 | ``` 49 | 50 | You can also generate from the command line: 51 | 52 | ```sh 53 | aitextgen generate 54 | aitextgen generate --prompt "I believe in unicorns because" --to_file False 55 | ``` 56 | 57 | Want to train your own mini GPT-2 model on your own computer? You can follow along [in this Jupyter Notebook](/notebooks/training_hello_world.ipynb) or, download this [text file of Shakespeare's plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), cd to that directory in a Terminal, open up a `python3` console and go: 58 | 59 | ```py3 60 | from aitextgen.TokenDataset import TokenDataset 61 | from aitextgen.tokenizers import train_tokenizer 62 | from aitextgen.utils import GPT2ConfigCPU 63 | from aitextgen import aitextgen 64 | 65 | # The name of the downloaded Shakespeare text for training 66 | file_name = "input.txt" 67 | 68 | # Train a custom BPE Tokenizer on the downloaded text 69 | # This will save one file: `aitextgen.tokenizer.json`, which contains the 70 | # information needed to rebuild the tokenizer. 71 | train_tokenizer(file_name) 72 | tokenizer_file = "aitextgen.tokenizer.json" 73 | 74 | # GPT2ConfigCPU is a mini variant of GPT-2 optimized for CPU-training 75 | # e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2. 76 | config = GPT2ConfigCPU() 77 | 78 | # Instantiate aitextgen using the created tokenizer and config 79 | ai = aitextgen(tokenizer_file=tokenizer_file, config=config) 80 | 81 | # You can build datasets for training by creating TokenDatasets, 82 | # which automatically processes the dataset with the appropriate size. 83 | data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64) 84 | 85 | # Train the model! It will save pytorch_model.bin periodically and after completion to the `trained_model` folder. 86 | # On a 2020 8-core iMac, this took ~25 minutes to run. 87 | ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000) 88 | 89 | # Generate text from it! 90 | ai.generate(10, prompt="ROMEO:") 91 | 92 | # With your trained model, you can reload the model at any time by 93 | # providing the folder containing the pytorch_model.bin model weights + the config, and providing the tokenizer. 94 | ai2 = aitextgen(model_folder="trained_model", 95 | tokenizer_file="aitextgen.tokenizer.json") 96 | 97 | ai2.generate(10, prompt="ROMEO:") 98 | ``` 99 | 100 | Want to run aitextgen and finetune GPT-2? Use the Colab notebooks in the Demos section, or [follow the documentation](https://aitextgen.minimaxir.com/) to get more information and learn some helpful tips! 
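In outline, a GPU finetuning run looks roughly like the sketch below (a hedged sketch rather than the canonical workflow: `dataset.txt` is a placeholder for your own corpus, and the batch size/step counts are arbitrary):

```py3
from aitextgen import aitextgen
from aitextgen.TokenDataset import TokenDataset

# Download OpenAI's 124M GPT-2 weights and move the model to the GPU.
ai = aitextgen(tf_gpt2="124M", to_gpu=True)

# "dataset.txt" is a placeholder; tokenize your own corpus ahead of time.
data = TokenDataset("dataset.txt")

# Finetune; checkpoints are written to the trained_model folder periodically.
ai.train(data, batch_size=1, num_steps=5000, generate_every=1000, save_every=1000)
```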
101 | 102 | ## Known Issues 103 | 104 | - TPUs cannot be used to train a model: although you _can_ train an aitextgen model on TPUs by setting `n_tpu_cores=8` in an appropriate runtime, and the training loss indeed does decrease, there are a number of miscellaneous blocking problems. [[Tracking GitHub Issue](https://github.com/minimaxir/aitextgen/issues/3)] 105 | 106 | ## Upcoming Features 107 | 108 | The current release (v0.5.X) of aitextgen **is considered to be a beta**, targeting the most common use cases. The Notebooks and examples written so far are tested to work, but more fleshing out of the docs/use cases will be done over the next few months in addition to fixing the known issues noted above. 109 | 110 | The next versions of aitextgen (and one of the reasons I made this package in the first place) will have native support for _schema-based generation_. (See [this repo](https://github.com/minimaxir/gpt-2-keyword-generation) for a rough proof-of-concept.) 111 | 112 | Additionally, I plan to develop an aitextgen [SaaS](https://en.wikipedia.org/wiki/Software_as_a_service) to allow anyone to run aitextgen in the cloud and build APIs/Twitter+Slack+Discord bots with just a few clicks. (The primary constraint is compute cost; if any venture capitalists are interested in funding the development of such a service, let me know.) 113 | 114 | I've listed more tentative features in the [UPCOMING](UPCOMING.md) document. 115 | 116 | ## Ethics 117 | 118 | aitextgen is a tool primarily intended to help facilitate creative content. It is not a tool intended to deceive. Although parody accounts are an obvious use case for this package, make sure you are _as upfront as possible_ with the methodology of the text you create. This includes: 119 | 120 | - State that the text was generated using aitextgen and/or a GPT-2 model architecture. (A link to this repo would be a bonus!) 121 | - If parodying a person, explicitly state that it is a parody, and reference who it is parodying. 122 | - Indicate whether the generated text is human-curated or unsupervised random output. 123 | - Indicate who is maintaining/curating the AI-generated text. 124 | - Make a good-faith effort to remove overfit output from the generated text that matches the input text verbatim. 125 | 126 | It's fun to anthropomorphise the nameless "AI" as an abstract genius, but part of the reason I made aitextgen (and all my previous text-generation projects) is to make the technology more accessible and accurately demonstrate both its promise and its limitations. **Any AI text generation projects that are deliberately deceptive may be disavowed.** 127 | 128 | ## Maintainer/Creator 129 | 130 | Max Woolf ([@minimaxir](https://minimaxir.com)) 131 | 132 | _Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir) and [GitHub Sponsors](https://github.com/sponsors/minimaxir).
If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use._ 133 | 134 | ## License 135 | 136 | MIT 137 | -------------------------------------------------------------------------------- /aitextgen/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import sys 5 | 6 | import torch 7 | from torch.optim import AdamW 8 | from torch.utils.data import DataLoader 9 | from tqdm.auto import tqdm 10 | from transformers import get_linear_schedule_with_warmup 11 | 12 | import pytorch_lightning as pl 13 | from pytorch_lightning.callbacks.progress import ProgressBarBase 14 | from pytorch_lightning.accelerators import TPUAccelerator 15 | 16 | 17 | class ATGTransformer(pl.LightningModule): 18 | """ 19 | A training module for aitextgen. 20 | """ 21 | 22 | def __init__(self, model, dataset, hparams, tokenizer): 23 | super(ATGTransformer, self).__init__() 24 | self.model, self.dataset, self.tokenizer = ( 25 | model, 26 | dataset, 27 | tokenizer, 28 | ) 29 | self.save_hyperparameters(hparams) 30 | 31 | def forward(self, inputs): 32 | return self.model(**inputs, return_dict=False) 33 | 34 | def training_step(self, batch, batch_num): 35 | outputs = self({"input_ids": batch, "labels": batch}) 36 | loss = outputs[0] 37 | 38 | return {"loss": loss} 39 | 40 | def train_dataloader(self): 41 | return DataLoader( 42 | self.dataset, 43 | batch_size=self.hparams["batch_size"], 44 | shuffle=True, 45 | pin_memory=self.hparams["pin_memory"], 46 | num_workers=self.hparams["num_workers"], 47 | ) 48 | 49 | def configure_optimizers(self): 50 | "Prepare optimizer" 51 | 52 | no_decay = ["bias", "LayerNorm.weight"] 53 | optimizer_grouped_parameters = [ 54 | { 55 | "params": [ 56 | p 57 | for n, p in self.model.named_parameters() 58 | if not any(nd in n for nd in no_decay) 59 | ], 60 | "weight_decay": self.hparams["weight_decay"], 61 | }, 62 | { 63 | "params": [ 64 | p 65 | for n, p in self.model.named_parameters() 66 | if any(nd in n for nd in no_decay) 67 | ], 68 | "weight_decay": 0.0, 69 | }, 70 | ] 71 | optimizer = AdamW( 72 | optimizer_grouped_parameters, 73 | lr=self.hparams["learning_rate"], 74 | eps=self.hparams["adam_epsilon"], 75 | ) 76 | 77 | scheduler = get_linear_schedule_with_warmup( 78 | optimizer, 79 | num_warmup_steps=self.hparams["warmup_steps"], 80 | num_training_steps=self.hparams["num_steps"], 81 | ) 82 | 83 | return [optimizer], [scheduler] 84 | 85 | 86 | class ATGProgressBar(ProgressBarBase): 87 | """A variant progress bar that works off of steps and prints periodically.""" 88 | 89 | def __init__( 90 | self, 91 | save_every, 92 | generate_every, 93 | output_dir, 94 | n_generate, 95 | gpu, 96 | smoothing, 97 | run_id, 98 | save_gdrive, 99 | progress_bar_refresh_rate, 100 | train_transformers_only, 101 | num_layers_freeze, 102 | ): 103 | super().__init__() 104 | self.enabled = True 105 | self.save_every = save_every 106 | self.generate_every = generate_every 107 | self.output_dir = output_dir 108 | self.n_generate = n_generate 109 | self.gpu = gpu 110 | self.steps = 0 111 | self.prev_avg_loss = None 112 | self.smoothing = smoothing 113 | self.run_id = run_id 114 | self.save_gdrive = save_gdrive 115 | self.progress_bar_refresh_rate = progress_bar_refresh_rate 116 | self.train_transformers_only = train_transformers_only 117 | self.num_layers_freeze = num_layers_freeze 118 | 119 | @property 120 | def save_every_check(self): 121 | return 
self.save_every > 0 and self.steps % self.save_every == 0 122 | 123 | def enabled(self): 124 | self.enabled = True 125 | 126 | def disable(self): 127 | self.enabled = False 128 | 129 | def on_train_start(self, trainer, pl_module): 130 | super().on_train_start(trainer, pl_module) 131 | self.main_progress_bar = tqdm( 132 | total=trainer.max_steps, 133 | disable=not self.enabled, 134 | smoothing=0, 135 | leave=True, 136 | dynamic_ncols=True, 137 | file=sys.stdout, 138 | ) 139 | self.freeze_layers(pl_module) 140 | 141 | def on_train_end(self, trainer, pl_module): 142 | self.main_progress_bar.close() 143 | self.unfreeze_layers(pl_module) 144 | 145 | def get_metrics(self, trainer, pl_module): 146 | # don't show the version number 147 | items = super().get_metrics(trainer, pl_module) 148 | items.pop("v_num", None) 149 | return items 150 | 151 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 152 | super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) 153 | 154 | # clean up the GPU cache used for the benchmark 155 | # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4 156 | if self.steps == 0 and self.gpu: 157 | torch.cuda.empty_cache() 158 | 159 | metrics = self.get_metrics(trainer, pl_module) 160 | current_loss = float(metrics["loss"]) 161 | self.steps += 1 162 | avg_loss = 0 163 | if current_loss == current_loss: # don't add if current_loss is NaN 164 | avg_loss = self.average_loss( 165 | current_loss, self.prev_avg_loss, self.smoothing 166 | ) 167 | self.prev_avg_loss = avg_loss 168 | 169 | desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}" 170 | 171 | if self.steps % self.progress_bar_refresh_rate == 0: 172 | if self.gpu: 173 | # via pytorch-lightning's get_gpu_memory_map() 174 | result = subprocess.run( 175 | [ 176 | shutil.which("nvidia-smi"), 177 | "--query-gpu=memory.used", 178 | "--format=csv,nounits,noheader", 179 | ], 180 | encoding="utf-8", 181 | stdout=subprocess.PIPE, 182 | stderr=subprocess.PIPE, 183 | check=True, 184 | ) 185 | gpu_memory = result.stdout.strip().split(os.linesep)[0] 186 | desc += f" — GPU Mem: {gpu_memory} MB" 187 | self.main_progress_bar.update(self.progress_bar_refresh_rate) 188 | self.main_progress_bar.set_description(desc) 189 | 190 | if TPUAccelerator.is_available() and self.save_every_check: 191 | did_unfreeze = False 192 | if self.enabled: 193 | self.unfreeze_layers(pl_module) 194 | did_unfreeze = True 195 | self.save_pytorch_model(trainer, pl_module, tpu=True) 196 | if did_unfreeze: 197 | self.freeze_layers(pl_module) 198 | 199 | if self.enabled: 200 | did_unfreeze = False 201 | if not TPUAccelerator.is_available() and self.save_every_check: 202 | self.unfreeze_layers(pl_module) 203 | self.save_pytorch_model(trainer, pl_module) 204 | did_unfreeze = True 205 | 206 | if self.generate_every > 0 and self.steps % self.generate_every == 0: 207 | self.unfreeze_layers(pl_module) 208 | self.generate_sample_text(trainer, pl_module) 209 | did_unfreeze = True 210 | 211 | if did_unfreeze: 212 | self.freeze_layers(pl_module) 213 | 214 | def generate_sample_text(self, trainer, pl_module): 215 | self.main_progress_bar.write( 216 | f"\033[1m{self.steps:,} steps reached: generating sample texts.\033[0m" 217 | ) 218 | 219 | gen_length_max = getattr( 220 | pl_module.model.config, "n_positions", None 221 | ) or getattr(pl_module.model.config, "max_position_embeddings", None) 222 | gen_length = min(gen_length_max, 256) 223 | 224 | pad_token_id = getattr(pl_module.tokenizer, "pad_token_id", None) or getattr( 
225 | pl_module.tokenizer, "eos_token_id", None 226 | ) 227 | 228 | outputs = pl_module.model.generate( 229 | input_ids=None, 230 | max_length=gen_length, 231 | do_sample=True, 232 | num_return_sequences=self.n_generate, 233 | temperature=0.7, 234 | pad_token_id=pad_token_id, 235 | ) 236 | 237 | gen_texts = pl_module.tokenizer.batch_decode(outputs, skip_special_tokens=True) 238 | 239 | for text in gen_texts: 240 | self.main_progress_bar.write("=" * 10) 241 | self.main_progress_bar.write(text) 242 | 243 | self.main_progress_bar.write("=" * 10) 244 | 245 | def save_pytorch_model(self, trainer, pl_module, tpu=False): 246 | 247 | if self.enabled: 248 | self.main_progress_bar.write( 249 | f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m" 250 | ) 251 | if tpu: 252 | import torch_xla.core.xla_model as xm 253 | 254 | pl_module.model.save_pretrained(self.output_dir, save_function=xm.save) 255 | else: 256 | pl_module.model.save_pretrained(self.output_dir) 257 | 258 | if self.enabled and self.save_gdrive: 259 | for pt_file in ["pytorch_model.bin", "config.json"]: 260 | shutil.copyfile( 261 | os.path.join(self.output_dir, pt_file), 262 | os.path.join("/content/drive/MyDrive/", self.run_id, pt_file), 263 | ) 264 | 265 | def average_loss(self, current_loss, prev_avg_loss, smoothing): 266 | if prev_avg_loss is None: 267 | return current_loss 268 | else: 269 | return (smoothing * current_loss) + (1 - smoothing) * prev_avg_loss 270 | 271 | def modify_layers(self, pl_module, unfreeze): 272 | if self.train_transformers_only: 273 | for name, param in pl_module.model.named_parameters(): 274 | if self.num_layers_freeze: 275 | layer_num = int(name.split(".")[2]) if ".h." in name else None 276 | to_freeze = layer_num and layer_num < self.num_layers_freeze 277 | else: 278 | to_freeze = False 279 | if name == "transformer.wte.weight" or to_freeze: 280 | param.requires_grad = unfreeze 281 | 282 | def freeze_layers(self, pl_module): 283 | self.modify_layers(pl_module, False) 284 | 285 | def unfreeze_layers(self, pl_module): 286 | self.modify_layers(pl_module, True) 287 | -------------------------------------------------------------------------------- /notebooks/reddit_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Reddit aitextgen\n", 8 | "\n", 9 | "A demo on how aitextgen can be used to create bespoke Reddit submission titles.\n", 10 | "\n", 11 | "**WARNING**: The content of the output may be NSFW!\n", 12 | "\n", 13 | "**NOTE**: This is released as a proof of concept for mini-GPT-2 models; quality of titles may vary." 
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from aitextgen import aitextgen" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Loading the Reddit Model\n", 30 | "\n", 31 | "The `minimaxir/reddit` model was finetuned on Reddit submissions up until August 2019, from the top 1,000 posts for each of the [top 1,000 subreddits](https://docs.google.com/spreadsheets/d/1zInLaR3daOC3N2ZBkudG5ypgJSaEwXqHRo9VazPFq-w/edit?usp=sharing) by unique submitters (an equal number of posts from each subreddit is important to prevent sampling bias).\n", 32 | "\n", 33 | "It uses a custom GPT-2 architecture that is only 30 MB on disk (compared to 124M GPT-2's 500MB on disk.)\n", 34 | "\n", 35 | "Running the cell will download the model and cache it into `/aitextgen`." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "INFO:aitextgen:Loading minimaxir/reddit model from /aitextgen.\n", 48 | "INFO:aitextgen:Using the tokenizer for minimaxir/reddit.\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "ai = aitextgen(model=\"minimaxir/reddit\")" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Generation\n", 61 | "\n", 62 | "Since the model is so small, generation happens almost immediately, even in bulk." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 9, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "halo Damn...\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "ai.generate()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "tifu TIFU by having a boner at my girlfriend's house\n", 92 | "==========\n", 93 | "television The Amazon Prime Minister's New Year, and a General Library Is Family With 'Realistic'\n", 94 | "==========\n", 95 | "summonerswar My friends and I are getting a new new tier list!\n", 96 | "==========\n", 97 | "congratslikeimfive I've been born and it's been an hour since I was 6.\n", 98 | "==========\n", 99 | "cakeday I'm still a bit inexperienced with it.\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "ai.generate(5)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "## Prompted Input\n", 112 | "\n", 113 | "You can seed input with a `prompt` to get specific types of Reddit posts. The prompt will be **bolded** in the output.\n", 114 | "\n", 115 | "Since the lowercase name of the subreddit is always first, you can use that in a prompt as a control code to (roughly) ensure output is from that subreddit." 
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "\u001b[1maskreddit What\u001b[0m’s your favourite most unique book you’ve ever read?\n", 128 | "==========\n", 129 | "\u001b[1maskreddit What\u001b[0m's the biggest difference between a pedophile and a child?\n", 130 | "==========\n", 131 | "\u001b[1maskreddit What\u001b[0m's a way to avoid a risk of your life?\n", 132 | "==========\n", 133 | "\u001b[1maskreddit What\u001b[0m's the most horrible thing you've ever done for your SO?\n", 134 | "==========\n", 135 | "\u001b[1maskreddit What\u001b[0m’s your most illegal/vocal/miscarriage/sort of Reddit?\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "ai.generate(5, prompt=\"askreddit What\")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 6, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "\u001b[1mgaming The Witcher 3\u001b[0m will be the first game in a row when it's in a game\n", 153 | "==========\n", 154 | "\u001b[1mgaming The Witcher 3\u001b[0m is looking for an end game\n", 155 | "==========\n", 156 | "\u001b[1mgaming The Witcher 3\u001b[0m is coming.\n", 157 | "==========\n", 158 | "\u001b[1mgaming The Witcher 3\u001b[0m\n", 159 | "==========\n", 160 | "\u001b[1mgaming The Witcher 3\u001b[0m is now in the new DS3\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "ai.generate(5, prompt=\"gaming The Witcher 3\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 13, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "\u001b[1mrelationships My parents\u001b[0m [25/M] made me [27/F] with a \"mother\". The hospital has finally done it. I'm starting to believe it's been a long day today.\n", 178 | "==========\n", 179 | "\u001b[1mrelationships My parents\u001b[0m (26F) told me that my boyfriend (27M) of 4 years is pregnant, and I have a long long long partner and have no idea how to afford it.\n", 180 | "==========\n", 181 | "\u001b[1mrelationships My parents\u001b[0m [28M] wants me [27F] that I'm not sure how to respond to my neighbor's [21F] relationship with my girlfriend [31F] of 3 years. Should I be worried about her?\n", 182 | "==========\n", 183 | "\u001b[1mrelationships My parents\u001b[0m [33F] broke my boyfriend [28F] with my best friend [26F] because I don't want to be a gentleman. 
What do I do?\n", 184 | "==========\n", 185 | "\u001b[1mrelationships My parents\u001b[0m [27F] cheated on me (27F) with my girlfriend (26M) and my husband (22F) don't know how to deal with her boyfriend (27M)\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "ai.generate(5, prompt=\"relationships My parents\")" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 17, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "\u001b[1mamitheasshole AITA\u001b[0m for telling my son to be a parent because she was a kid in middle school?\n", 203 | "==========\n", 204 | "\u001b[1mamitheasshole AITA\u001b[0m for telling my son she was being an adult?\n", 205 | "==========\n", 206 | "\u001b[1mamitheasshole AITA\u001b[0m for not telling my wife that she is pregnant?\n", 207 | "==========\n", 208 | "\u001b[1mamitheasshole AITA\u001b[0m for telling my wife that I’m not pregnant if it’s my best friend?\n", 209 | "==========\n", 210 | "\u001b[1mamitheasshole AITA\u001b[0m for not giving up my child over my brother's wedding?\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "ai.generate(5, prompt=\"amitheasshole AITA\")" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "## Bulk Generation to File\n", 223 | "\n", 224 | "You can use `generate_to_file()` to create many Reddit titles." 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 18, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stderr", 234 | "output_type": "stream", 235 | "text": [ 236 | "INFO:aitextgen:Generating 1,000 texts to ATG_20200518_001105_30355161.txt\n" 237 | ] 238 | }, 239 | { 240 | "data": { 241 | "application/vnd.jupyter.widget-view+json": { 242 | "model_id": "63e3b99f783c4457a9342714ff009963", 243 | "version_major": 2, 244 | "version_minor": 0 245 | }, 246 | "text/plain": [ 247 | "HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 248 | ] 249 | }, 250 | "metadata": {}, 251 | "output_type": "display_data" 252 | }, 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "ai.generate_to_file(1000, batch_size=20)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "# MIT License\n", 270 | "\n", 271 | "Copyright (c) 2020 Max Woolf\n", 272 | "\n", 273 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 274 | "of this software and associated documentation files (the \"Software\"), to deal\n", 275 | "in the Software without restriction, including without limitation the rights\n", 276 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 277 | "copies of the Software, and to permit persons to whom the Software is\n", 278 | "furnished to do so, subject to the following conditions:\n", 279 | "\n", 280 | "The above copyright notice and this permission notice shall be included in all\n", 281 | "copies or substantial portions of the Software.\n", 282 | "\n", 283 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 284 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 285 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 286 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 287 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 288 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 289 | "SOFTWARE." 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3.7.5 64-bit", 296 | "language": "python", 297 | "name": "python37564bitb9ff4e3157b244a896f88d1e5f3eb324" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.7.7" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 4 314 | } -------------------------------------------------------------------------------- /notebooks/generation_hello_world.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# aitextgen Generation Hello World\n", 8 | "\n", 9 | "by Max Woolf\n", 10 | "\n", 11 | "A \"Hello World\" Tutorial to show how generation works with aitextgen!" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from aitextgen import aitextgen" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "Without any parameters, `aitextgen()` will download, cache, and load the 124M \"small\" form of GPT-2 (~500MB on disk) to `/aitextgen`; good for prototyping generation.\n", 28 | "\n", 29 | "You can change the directory the model is saved/loaded from by setting the `cache_dir` parameter." 
30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "INFO:aitextgen.aitextgen:Downloading gpt2 model.\n" 42 | ] 43 | }, 44 | { 45 | "data": { 46 | "application/vnd.jupyter.widget-view+json": { 47 | "model_id": "2a4e7d7f9a5a4dcbbe4f95fec9cde7c0", 48 | "version_major": 2, 49 | "version_minor": 0 50 | }, 51 | "text/plain": [ 52 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=554.0, style=ProgressStyle(description_…" 53 | ] 54 | }, 55 | "metadata": {}, 56 | "output_type": "display_data" 57 | }, 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "\n" 63 | ] 64 | }, 65 | { 66 | "data": { 67 | "application/vnd.jupyter.widget-view+json": { 68 | "model_id": "6ba5dd224b7c41d995dc7d73169adf94", 69 | "version_major": 2, 70 | "version_minor": 0 71 | }, 72 | "text/plain": [ 73 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=548118077.0, style=ProgressStyle(descri…" 74 | ] 75 | }, 76 | "metadata": {}, 77 | "output_type": "display_data" 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "\n" 84 | ] 85 | }, 86 | { 87 | "name": "stderr", 88 | "output_type": "stream", 89 | "text": [ 90 | "INFO:aitextgen.aitextgen:Using the default GPT-2 Tokenizer.\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "ai = aitextgen()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "\n", 108 | "SARUNANA SINGAPORE\n", 109 | "\n", 110 | "The Sri Lankan Navy conducted its fourth round of combat maneuvers for the first time in its 20 year training history on Thursday.\n", 111 | "\n", 112 | "The Sea Southerners' four-speed patrol was part of the naval exercise \"Operation Sustainability,\" which is a collaborative effort with the U.S. Naval Reserve, the Navy and the South Korea Maritime Self Defense Force.\n", 113 | "\n", 114 | "The patrol, which was the result of a joint joint training exercise that resulted in over 7,000 American sailors and military aircraft, took place on the South Korean coast of North and South Korea, said Vice Admiral James M. \"Skip\" Pincap for the Navy, which has been involved in exercises in the North and South Korean seas for about 70 years.\n", 115 | "\n", 116 | "The U.S. Navy's current naval exercise, which spans four months, commenced last year. That's in response to a joint military exercise of the U.S. Navy\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "ai.generate()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "You can generate multiple texts at a time with the `n` parameter. 
You can also control the length of the generated text with the `max_length` parameter (default 200 tokens)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 4, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "AUSTIN -- In March, Travis County prosecutors said that their investigation into a fatal car-crash accident in a neighborhood near Austin was being compromised because they were using an incomplete and unverified criminal history.\n", 141 | "\n", 142 | "Prosecutors said they learned that Tilden County DA Greg Saylor and district attorney Chris D'Amico are two of four Texas officials on the case that were implicated in the case. According to a search warrant obtained by The Austin American-Statesman, the warrant reads:\n", 143 | "==========\n", 144 | "In 2008, the city had a new water tower. It was located on the corner of North and State avenues in the city center. This tower was about an inch tall, and had about five stories in the middle. The walls were dark brick with a concrete base, white flooring, and thick slabs of brick that filled the surrounding building.\n", 145 | "\n", 146 | "This building was designed by the City.\n", 147 | "\n", 148 | "The new housing did not last long…\n", 149 | "\n", 150 | "On April 18, 2009, a\n", 151 | "==========\n", 152 | "In recent years there's been a significant escalation in the fight against Ebola that can result in a serious illness or death, and several countries have begun to seek out and treat such people.\n", 153 | "\n", 154 | "The United States and others have begun to see and track how they've been treated by WHO's two \"Ebola Surveillance teams,\" where they monitor, evaluate and respond to the cases and determine who are the most likely to be infected in their local communities.\n", 155 | "\n", 156 | "But there's also been much\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "ai.generate(n=3, max_length=100)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "You can optionally seed the input text with a `prompt`. If you do, the prompt will be bolded in the console/notebook output." 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 5, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "\u001b[1mI believe in unicorns because\u001b[0m of their ability to read, write and recognize syllables. They have also become the source of modern computer science, thanks to the power of words like, say, \"This is a monkey\"; or \"This is a tree.\"\n", 181 | "\n", 182 | "Now some will point fingers at Apple, which says that \"C-word\" is more like a dictionary than a computer program. They are correct, that Apple's product language, known as Apple II, is for words. They\n", 183 | "==========\n", 184 | "\u001b[1mI believe in unicorns because\u001b[0m of their good, bad, and ugly qualities. I believe that unicorns are the kind that have good, bad, and ugly qualities. 
I believe that unicorns have high moral and ethical standards because they're kindhearted.\"\n", 185 | "\n", 186 | "The group's first act will be to show that a group's character is worth celebrating because it's kind, gentle, and not cruel.\n", 187 | "\n", 188 | "\"By encouraging good behavior, to show we can get things done, to tell us\n", 189 | "==========\n", 190 | "\u001b[1mI believe in unicorns because\u001b[0m there are thousands of them.\n", 191 | "\n", 192 | "A few weeks ago I saw people get frustrated when I told them about how long it takes to get from point A to point B, as my first priority was getting those same people where they actually belong. Why isn't there something more to accomplish before those first six days, or even half a day after these first six days of making such a conscious effort? I want my family and friends to know that I are not here today\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "ai.generate(n=3, max_length=100, prompt=\"I believe in unicorns because\")" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "Lastly, you can generate texts in bulk using `generate_to_file()` and save them to file, allowing you search through them and curate them!\n", 205 | "\n", 206 | "You can also change the `temperature` of the text to make it less crazy...or more crazy." 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 6, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "name": "stderr", 216 | "output_type": "stream", 217 | "text": [ 218 | "INFO:root:Generating 10 texts to ATG_20200506_035322_13582062.txt\n" 219 | ] 220 | }, 221 | { 222 | "data": { 223 | "application/vnd.jupyter.widget-view+json": { 224 | "model_id": "08055273e46f4983a8042683880171ff", 225 | "version_major": 2, 226 | "version_minor": 0 227 | }, 228 | "text/plain": [ 229 | "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))" 230 | ] 231 | }, 232 | "metadata": {}, 233 | "output_type": "display_data" 234 | }, 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "ai.generate_to_file(n=10,\n", 245 | " prompt=\"I believe in unicorns because\",\n", 246 | " max_length=100, temperature=1.2)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "# MIT License\n", 254 | "\n", 255 | "Copyright (c) 2020 Max Woolf\n", 256 | "\n", 257 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 258 | "of this software and associated documentation files (the \"Software\"), to deal\n", 259 | "in the Software without restriction, including without limitation the rights\n", 260 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 261 | "copies of the Software, and to permit persons to whom the Software is\n", 262 | "furnished to do so, subject to the following conditions:\n", 263 | "\n", 264 | "The above copyright notice and this permission notice shall be included in all\n", 265 | "copies or substantial portions of the Software.\n", 266 | "\n", 267 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 268 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 269 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 270 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 271 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 272 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 273 | "SOFTWARE." 274 | ] 275 | } 276 | ], 277 | "metadata": { 278 | "kernelspec": { 279 | "display_name": "Python 3.7.5 64-bit", 280 | "language": "python", 281 | "name": "python37564bitb9ff4e3157b244a896f88d1e5f3eb324" 282 | }, 283 | "language_info": { 284 | "codemirror_mode": { 285 | "name": "ipython", 286 | "version": 3 287 | }, 288 | "file_extension": ".py", 289 | "mimetype": "text/x-python", 290 | "name": "python", 291 | "nbconvert_exporter": "python", 292 | "pygments_lexer": "ipython3", 293 | "version": "3.7.7" 294 | } 295 | }, 296 | "nbformat": 4, 297 | "nbformat_minor": 2 298 | } 299 | -------------------------------------------------------------------------------- /notebooks/hacker_news_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Hacker News aitextgen\n", 8 | "\n", 9 | "A demo on how aitextgen can be used to create bespoke Hacker News submission titles.\n", 10 | "\n", 11 | "**NOTE**: This is released as a proof of concept for mini-GPT-2 models; quality of titles may vary." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from aitextgen import aitextgen" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Loading the Hacker News Model\n", 28 | "\n", 29 | "The `minimaxir/hacker-news` model was finetuned on HN submissions up until May 12th with atleast 5 points.\n", 30 | "\n", 31 | "It uses a custom GPT-2 architecture that is only 30 MB on disk (compared to 124M GPT-2's 500MB on disk.)\n", 32 | "\n", 33 | "Running the cell will download the model and cache it into `/aitextgen`." 
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "INFO:aitextgen:Loading minimaxir/hacker-news model from /aitextgen.\n" 46 | ] 47 | }, 48 | { 49 | "data": { 50 | "application/vnd.jupyter.widget-view+json": { 51 | "model_id": "219f11651d264ea89c53d26e51fc1a2e", 52 | "version_major": 2, 53 | "version_minor": 0 54 | }, 55 | "text/plain": [ 56 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=539.0, style=ProgressStyle(description_…" 57 | ] 58 | }, 59 | "metadata": {}, 60 | "output_type": "display_data" 61 | }, 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "\n" 67 | ] 68 | }, 69 | { 70 | "data": { 71 | "application/vnd.jupyter.widget-view+json": { 72 | "model_id": "9692a8e7fa934ee9b5b0cbe12c95b4a4", 73 | "version_major": 2, 74 | "version_minor": 0 75 | }, 76 | "text/plain": [ 77 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=30458628.0, style=ProgressStyle(descrip…" 78 | ] 79 | }, 80 | "metadata": {}, 81 | "output_type": "display_data" 82 | }, 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "\n" 88 | ] 89 | }, 90 | { 91 | "name": "stderr", 92 | "output_type": "stream", 93 | "text": [ 94 | "INFO:aitextgen:Using the tokenizer for minimaxir/hacker-news.\n" 95 | ] 96 | }, 97 | { 98 | "data": { 99 | "application/vnd.jupyter.widget-view+json": { 100 | "model_id": "bc616f81fd264615a7a32e2776ea29b5", 101 | "version_major": 2, 102 | "version_minor": 0 103 | }, 104 | "text/plain": [ 105 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=74927.0, style=ProgressStyle(descriptio…" 106 | ] 107 | }, 108 | "metadata": {}, 109 | "output_type": "display_data" 110 | }, 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "\n" 116 | ] 117 | }, 118 | { 119 | "data": { 120 | "application/vnd.jupyter.widget-view+json": { 121 | "model_id": "41387b5113ea4e45bee4fc5a790bbf61", 122 | "version_major": 2, 123 | "version_minor": 0 124 | }, 125 | "text/plain": [ 126 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=35091.0, style=ProgressStyle(descriptio…" 127 | ] 128 | }, 129 | "metadata": {}, 130 | "output_type": "display_data" 131 | }, 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "\n" 137 | ] 138 | }, 139 | { 140 | "data": { 141 | "application/vnd.jupyter.widget-view+json": { 142 | "model_id": "2427ce1184754b8088601d2a0296411a", 143 | "version_major": 2, 144 | "version_minor": 0 145 | }, 146 | "text/plain": [ 147 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=120.0, style=ProgressStyle(description_…" 148 | ] 149 | }, 150 | "metadata": {}, 151 | "output_type": "display_data" 152 | }, 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "\n" 158 | ] 159 | }, 160 | { 161 | "data": { 162 | "application/vnd.jupyter.widget-view+json": { 163 | "model_id": "06052496d23e4c7b8e2c0722ff65b9de", 164 | "version_major": 2, 165 | "version_minor": 0 166 | }, 167 | "text/plain": [ 168 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2.0, style=ProgressStyle(description_wi…" 169 | ] 170 | }, 171 | "metadata": {}, 172 | "output_type": "display_data" 173 | }, 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "\n" 179 | ] 180 | } 181 | ], 182 | 
"source": [ 183 | "ai = aitextgen(model=\"minimaxir/hacker-news\")" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "## Generation\n", 191 | "\n", 192 | "Since the model is so small, generation happens almost immediately, even in bulk." 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 3, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "Kinect can now centralize cellphone locations, not their pictures\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "ai.generate()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 4, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "Ask HN: Should I start writing a blog post in Python?\n", 222 | "==========\n", 223 | "The Psychology of Human Misjudgment (2012)\n", 224 | "==========\n", 225 | "New York' New Year: $99 Linux PC\n", 226 | "==========\n", 227 | "C++11/12 Released\n", 228 | "==========\n", 229 | "Dynamic types in Go\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "ai.generate(5)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "## Prompted Input\n", 242 | "\n", 243 | "You can seed input with a `prompt` to get specific types of HN posts. The prompt will be **bolded** in the output." 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 12, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "\u001b[1mAsk HN\u001b[0m: What are some good (O'Reilly eval) books for a new web-based project?\n", 256 | "==========\n", 257 | "\u001b[1mAsk HN\u001b[0m: How to avoid the Huawei of 20k job candidates?\n", 258 | "==========\n", 259 | "\u001b[1mAsk HN\u001b[0m: How to grow your startup\n", 260 | "==========\n", 261 | "\u001b[1mAsk HN\u001b[0m: What's the best way to learn a new languages on your website?\n", 262 | "==========\n", 263 | "\u001b[1mAsk HN\u001b[0m: How to get started in Machine Learning?\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "ai.generate(5, prompt=\"Ask HN\")" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 6, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "\u001b[1mShow HN\u001b[0m: The Penetration Tester\n", 281 | "==========\n", 282 | "\u001b[1mShow HN\u001b[0m: qVD.S.Next Windows Awesomeness\n", 283 | "==========\n", 284 | "\u001b[1mShow HN\u001b[0m: My Startup – a crowdfunded satellite news aggregator\n", 285 | "==========\n", 286 | "\u001b[1mShow HN\u001b[0m: The JavaScript Way to Learn JavaScript Within the Web\n", 287 | "==========\n", 288 | "\u001b[1mShow HN\u001b[0m: Hacker News like / you read the message\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "ai.generate(5, prompt=\"Show HN\")" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 7, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "\u001b[1mElon Musk\u001b[0m Says Tesla Is a Wireless Carrier Has Been Laying Off\n", 306 | "==========\n", 307 | "\u001b[1mElon Musk\u001b[0m’s Family Secretary of Munich Is the New Model 3\n", 308 | "==========\n", 309 | "\u001b[1mElon Musk\u001b[0m is a suitable person to learn the originally good\n", 310 | 
"==========\n", 311 | "\u001b[1mElon Musk\u001b[0m's Hyperloop Is a Success\n", 312 | "==========\n", 313 | "\u001b[1mElon Musk\u001b[0m’s New Nexus Program\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "ai.generate(5, prompt=\"Elon Musk\")" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 10, 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "\u001b[1mGoogle says\u001b[0m its employees are working with Amazon and Apple\n", 331 | "==========\n", 332 | "\u001b[1mGoogle says\u001b[0m it’s peaked\n", 333 | "==========\n", 334 | "\u001b[1mGoogle says\u001b[0m it is flea banning visible to people who worked in U.S.\n", 335 | "==========\n", 336 | "\u001b[1mGoogle says\u001b[0m it will not allow enemy mine to secure sensitive information\n", 337 | "==========\n", 338 | "\u001b[1mGoogle says\u001b[0m no to Google for Java\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "ai.generate(5, prompt=\"Google says\")" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "## Bulk Generation to File\n", 351 | "\n", 352 | "You can use `generate_to_file()` to create many HN titles." 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 11, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "name": "stderr", 362 | "output_type": "stream", 363 | "text": [ 364 | "INFO:aitextgen:Generating 1,000 texts to ATG_20200517_235441_14821584.txt\n" 365 | ] 366 | }, 367 | { 368 | "data": { 369 | "application/vnd.jupyter.widget-view+json": { 370 | "model_id": "48e73d0062ea487d8402af3acc06bddc", 371 | "version_major": 2, 372 | "version_minor": 0 373 | }, 374 | "text/plain": [ 375 | "HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 376 | ] 377 | }, 378 | "metadata": {}, 379 | "output_type": "display_data" 380 | }, 381 | { 382 | "name": "stdout", 383 | "output_type": "stream", 384 | "text": [ 385 | "\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "ai.generate_to_file(1000, batch_size=20)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": {}, 396 | "source": [ 397 | "# MIT License\n", 398 | "\n", 399 | "Copyright (c) 2020 Max Woolf\n", 400 | "\n", 401 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 402 | "of this software and associated documentation files (the \"Software\"), to deal\n", 403 | "in the Software without restriction, including without limitation the rights\n", 404 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 405 | "copies of the Software, and to permit persons to whom the Software is\n", 406 | "furnished to do so, subject to the following conditions:\n", 407 | "\n", 408 | "The above copyright notice and this permission notice shall be included in all\n", 409 | "copies or substantial portions of the Software.\n", 410 | "\n", 411 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 412 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 413 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 414 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 415 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 416 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 417 | "SOFTWARE." 
418 | ] 419 | } 420 | ], 421 | "metadata": { 422 | "kernelspec": { 423 | "display_name": "Python 3.7.5 64-bit", 424 | "language": "python", 425 | "name": "python37564bitb9ff4e3157b244a896f88d1e5f3eb324" 426 | }, 427 | "language_info": { 428 | "codemirror_mode": { 429 | "name": "ipython", 430 | "version": 3 431 | }, 432 | "file_extension": ".py", 433 | "mimetype": "text/x-python", 434 | "name": "python", 435 | "nbconvert_exporter": "python", 436 | "pygments_lexer": "ipython3", 437 | "version": "3.7.7" 438 | } 439 | }, 440 | "nbformat": 4, 441 | "nbformat_minor": 4 442 | } -------------------------------------------------------------------------------- /notebooks/training_hello_world.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# aitextgen Training Hello World\n", 8 | "\n", 9 | "_Last Updated: Feb 21, 2021 (v.0.4.0)_\n", 10 | "\n", 11 | "by Max Woolf\n", 12 | "\n", 13 | "A \"Hello World\" Tutorial to show how training works with aitextgen, even on a CPU!" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from aitextgen.TokenDataset import TokenDataset\n", 23 | "from aitextgen.tokenizers import train_tokenizer\n", 24 | "from aitextgen.utils import GPT2ConfigCPU\n", 25 | "from aitextgen import aitextgen" 26 | ] 27 | }, 28 | { 29 | "source": [ 30 | "First, download this [text file of Shakespeare's plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), to the folder with this notebook, then put the name of the downloaded Shakespeare text for training into the cell below." 31 | ], 32 | "cell_type": "markdown", 33 | "metadata": {} 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "file_name = \"input.txt\"" 42 | ] 43 | }, 44 | { 45 | "source": [ 46 | "You can now train a custom Byte Pair Encoding Tokenizer on the downloaded text!\n", 47 | "\n", 48 | "This will save one file: `aitextgen.tokenizer.json`, which contains the information needed to rebuild the tokenizer." 49 | ], 50 | "cell_type": "markdown", 51 | "metadata": {} 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "train_tokenizer(file_name)\n", 60 | "tokenizer_file = \"aitextgen.tokenizer.json\"" 61 | ] 62 | }, 63 | { 64 | "source": [ 65 | "`GPT2ConfigCPU()` is a mini variant of GPT-2 optimized for CPU-training.\n", 66 | "\n", 67 | "e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2. This dramatically speeds training up." 68 | ], 69 | "cell_type": "markdown", 70 | "metadata": {} 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "config = GPT2ConfigCPU()" 79 | ] 80 | }, 81 | { 82 | "source": [ 83 | "Instantiate aitextgen using the created tokenizer and config" 84 | ], 85 | "cell_type": "markdown", 86 | "metadata": {} 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "ai = aitextgen(tokenizer_file=tokenizer_file, config=config)" 95 | ] 96 | }, 97 | { 98 | "source": [ 99 | "You can build datasets for training by creating TokenDatasets, which automatically processes the dataset with the appropriate size." 
100 | ], 101 | "cell_type": "markdown", 102 | "metadata": {} 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "output_type": "stream", 111 | "name": "stderr", 112 | "text": [ 113 | "100%|██████████| 40000/40000 [00:00<00:00, 86712.61it/s]\n" 114 | ] 115 | }, 116 | { 117 | "output_type": "execute_result", 118 | "data": { 119 | "text/plain": [ 120 | "TokenDataset containing 462,820 subsets loaded from file at input.txt." 121 | ] 122 | }, 123 | "metadata": {}, 124 | "execution_count": 6 125 | } 126 | ], 127 | "source": [ 128 | "data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)\n", 129 | "data" 130 | ] 131 | }, 132 | { 133 | "source": [ 134 | "Train the model! It will save pytorch_model.bin periodically and after completion to the `trained_model` folder. On a 2020 8-core iMac, this took ~25 minutes to run.\n", 135 | "\n", 136 | "The configuration below processes 400,000 subsets of tokens (8 * 50000), which is about just one pass through all the data (1 epoch). Ideally you'll want multiple passes through the data and a training loss less than `2.0` for coherent output; when training a model from scratch, that's more difficult, but with long enough training you can get there!" 137 | ], 138 | "cell_type": "markdown", 139 | "metadata": {} 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 7, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stderr", 149 | "text": [ 150 | "pytorch_model.bin already exists in /trained_model and will be overwritten!\n", 151 | "GPU available: False, used: False\n", 152 | "TPU available: None, using: 0 TPU cores\n", 153 | "\u001b[1m5,000 steps reached: saving model to /trained_model\u001b[0m\n", 154 | "\u001b[1m5,000 steps reached: generating sample texts.\u001b[0m\n", 155 | "==========\n", 156 | "'s dead;\n", 157 | "But is no winted in his northeritiff\n", 158 | "Tave passage, and eleve your hours.\n", 159 | "\n", 160 | "PETRUCHIO:\n", 161 | "What is this I does, I will, sir;\n", 162 | "That, you have, nor tolding we\n", 163 | "==========\n", 164 | "\u001b[1m10,000 steps reached: saving model to /trained_model\u001b[0m\n", 165 | "\u001b[1m10,000 steps reached: generating sample texts.\u001b[0m\n", 166 | "==========\n", 167 | ".\n", 168 | "\n", 169 | "QUEEN ELIZABETH:\n", 170 | "I know, to, fair beat, to my soul is wonder'd intend.\n", 171 | "\n", 172 | "KING RICHARD III:\n", 173 | "Hold, and threaten, my lord, and my shame!\n", 174 | "\n", 175 | "QUEEN ELIZAB\n", 176 | "==========\n", 177 | "\u001b[1m15,000 steps reached: saving model to /trained_model\u001b[0m\n", 178 | "\u001b[1m15,000 steps reached: generating sample texts.\u001b[0m\n", 179 | "==========\n", 180 | "s of capitcts!\n", 181 | "\n", 182 | "EDWARD:\n", 183 | "Gardener, what is this hour will not say.\n", 184 | "What, shall the joint, I pray, if they\n", 185 | "Harry, let bid me as he would readness so.\n", 186 | "\n", 187 | "B\n", 188 | "==========\n", 189 | "\u001b[1m20,000 steps reached: saving model to /trained_model\u001b[0m\n", 190 | "\u001b[1m20,000 steps reached: generating sample texts.\u001b[0m\n", 191 | "==========\n", 192 | " for.\n", 193 | "\n", 194 | "ROMEO:\n", 195 | "Fair to the iercing wide's fretch,\n", 196 | "And happy talk of the master,\n", 197 | "And waste their justice with the feet and punning,\n", 198 | "And therefore be ben\n", 199 | "==========\n", 200 | "\u001b[1m25,000 steps reached: saving model to 
/trained_model\u001b[0m\n", 201 | "\u001b[1m25,000 steps reached: generating sample texts.\u001b[0m\n", 202 | "==========\n", 203 | ",\n", 204 | "That we we will have not lose such.\n", 205 | "\n", 206 | "See, to the kingdom of our virtue,\n", 207 | "You banish'd our purpose, for our own ignorse,\n", 208 | "Dispon I remain, and seem'd in\n", 209 | "==========\n", 210 | "\u001b[1m30,000 steps reached: saving model to /trained_model\u001b[0m\n", 211 | "\u001b[1m30,000 steps reached: generating sample texts.\u001b[0m\n", 212 | "==========\n", 213 | ".\n", 214 | "\n", 215 | "BENVOLIO:\n", 216 | "O, she's dead!\n", 217 | "\n", 218 | "CAMILLO:\n", 219 | "No, my lord;\n", 220 | "These accession will be hous.\n", 221 | "\n", 222 | "DERBY:\n", 223 | "No, my lord.\n", 224 | "\n", 225 | "GLOUCESTER:\n", 226 | "What is the\n", 227 | "==========\n", 228 | "\u001b[1m35,000 steps reached: saving model to /trained_model\u001b[0m\n", 229 | "\u001b[1m35,000 steps reached: generating sample texts.\u001b[0m\n", 230 | "==========\n", 231 | ",\n", 232 | "And whiles it is but the castle,\n", 233 | "That stavin'd in the gods of men.\n", 234 | "\n", 235 | "COMFEY:\n", 236 | "What, then?\n", 237 | "\n", 238 | "ELBOW:\n", 239 | "Peace, my lord,\n", 240 | "And weat your greats\n", 241 | "==========\n", 242 | "\u001b[1m40,000 steps reached: saving model to /trained_model\u001b[0m\n", 243 | "\u001b[1m40,000 steps reached: generating sample texts.\u001b[0m\n", 244 | "==========\n", 245 | "\n", 246 | "The white mercy of the sun upon my past,\n", 247 | "Of my father's son be first, thy sake,\n", 248 | "His son's chief son, and my includy;\n", 249 | "And if thy brother's loss, thy thrief,\n", 250 | "\n", 251 | "==========\n", 252 | "\u001b[1m45,000 steps reached: saving model to /trained_model\u001b[0m\n", 253 | "\u001b[1m45,000 steps reached: generating sample texts.\u001b[0m\n", 254 | "==========\n", 255 | " to the crown,\n", 256 | "Or I'll privy I have.\n", 257 | "\n", 258 | "POLIXENES:\n", 259 | "I have been a stir.\n", 260 | "\n", 261 | "LEONTES:\n", 262 | "The worshiped, the benefition of the crown.\n", 263 | "\n", 264 | "His somet\n", 265 | "==========\n", 266 | "\u001b[1m50,000 steps reached: saving model to /trained_model\u001b[0m\n", 267 | "\u001b[1m50,000 steps reached: generating sample texts.\u001b[0m\n", 268 | "==========\n", 269 | ":\n", 270 | "Catesby, girls, and make avoides;\n", 271 | "But, welcome a far\n", 272 | "That ever home, like a villain, and behold\n", 273 | "Canusy not passing nonquial at the g\n", 274 | "==========\n", 275 | "Loss: 2.940 — Avg: 2.884: 100%|██████████| 50000/50000 [31:39<00:00, 26.32it/s]\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)" 281 | ] 282 | }, 283 | { 284 | "source": [ 285 | "Generate text from your trained model!" 286 | ], 287 | "cell_type": "markdown", 288 | "metadata": {} 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 8, 293 | "metadata": {}, 294 | "outputs": [ 295 | { 296 | "output_type": "stream", 297 | "name": "stdout", 298 | "text": [ 299 | "\u001b[1mROMEO:\u001b[0m\nAbook, ho! 
forthing me, gentle Earl's royal king,\nAnd this, I, with that I do not beseech you\nTo visit the battle, that I should believe you,\nWhich I would never\n==========\n\u001b[1mROMEO:\u001b[0m\nConfound is gone, thou art a maid into the widow;\nPut up my life and make me no harmony\nAnd make thee I know uncle,\nUnconted and curses: therefore in my\n==========\n\u001b[1mROMEO:\u001b[0m\nGod push! but what days to see\nThe giving bleedom's heart I do? Therefore,\nAnd most unless I had rather. He saddle\nTake your cold shack down; and so far I\n==========\n\u001b[1mROMEO:\u001b[0m\nPersetain'd up the earth of mercy,\nAnd never yet, the sun to make him all the\nMore than my battle.\n\nROMEO:\nI warrant him, to know, we'll not do't, but hate me\n==========\n\u001b[1mROMEO:\u001b[0m\nMethinks I am a mile, and trench one\nThy winded makes, in faults and cast\nWith one to meether, of twenty days,\nThat in my waters, that f\n==========\n\u001b[1mROMEO:\u001b[0m\nO, here is such a woman guilty.\n\nROMEO:\nI do not think it; I should be renowned\nThat I am in that which can controy\nA bawd I take it to the purpose.\n\nJU\n==========\n\u001b[1mROMEO:\u001b[0m\nI know not what I am.\n\nFLORIZEL:\nAy, as I did,\nI would be adverpite of the homely treason\nFrom the doubled in the farm of his bed.\nTa\n==========\n\u001b[1mROMEO:\u001b[0m\nI pray you, he would have taken to him but,\nAnd freely mark his into a fine of it,\nSpeak to the second to our cheek;\nAnd every day, and sanctious cover\n==========\n\u001b[1mROMEO:\u001b[0m\nI had left me--born to be drawn.\n\nJULIET:\nMy husbour, I will have thee here:\nAnd, I have found to seek thyself.\n\nJULIET:\nI will be not b\n==========\n\u001b[1mROMEO:\u001b[0m\nThat is a hour,\nThe castard is, I'll not buy, or indeeding.\n\nNurse:\nLADY CAPULET:\nThe matter, that ta'en as I may find thee.\n\n" 300 | ] 301 | } 302 | ], 303 | "source": [ 304 | "ai.generate(10, prompt=\"ROMEO:\")" 305 | ] 306 | }, 307 | { 308 | "source": [ 309 | "With your trained model, you can reload the model at any time by providing the `pytorch_model.bin` model weights, the `config`, and the `tokenizer`." 
310 | ], 311 | "cell_type": "markdown", 312 | "metadata": {} 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 9, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "ai2 = aitextgen(model_folder=\"trained_model\",\n", 321 | " tokenizer_file=\"aitextgen.tokenizer.json\")" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 10, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "output_type": "stream", 331 | "name": "stdout", 332 | "text": [ 333 | "\u001b[1mROMEO:\u001b[0m\nBoy, unreacher, unhallupony, in Padua,\nUntimely fall till I be learn'd.\n\nROMEO:\nFie, good friar, be quick, for I am,\nI'll\n==========\n\u001b[1mROMEO:\u001b[0m\nI'll be plain, I am a tail of blessed wounds;\nFor I am dead, I have not borne to make\nA couple of her fortune, but that I'll bear,\nAnd say 'Ay, chur\n==========\n\u001b[1mROMEO:\u001b[0m\nAnd yet I am a resolution of my dear dear:\nIf I have not reason to do me say\nI'll deny the sea of my body to answer,\nAnd all thy tale, or I have my m\n==========\n\u001b[1mROMEO:\u001b[0m\nIntenty to a bawd of my bait,--\n\nJULIET:\nNo, I hope to know the title,\nFor that I wish her place.\n\nJULIET:\nDo I assure her?\n==========\n\u001b[1mROMEO:\u001b[0m\nO, what's the parle that I chide thee,\nThat honourable may be, that I have still'd thee:\nI pray thee, my lord.\n\nMERCUTIO:\nI', my lord.\n\nROMEO:\nHere is a\n==========\n\u001b[1mROMEO:\u001b[0m\nAnd, for I am, and not talk of that?\n\nROMEO:\nWhere's my child, I would guess thee here.\n\nJULIET:\nNay, boy, I'll not be bowling why I;\nO thou\n==========\n\u001b[1mROMEO:\u001b[0m\nO, but thou hast seen thee of mine own.\n\nROMEO:\nI would assist thee--\n\nJULIET:\nAy, it is, and not so.\n\nROMEO:\nNo, but that I must told me with it.\n\nROMEO\n==========\n\u001b[1mROMEO:\u001b[0m\nNo, no, nor I am. 
I am content.\n\nBENVOLIO:\nI will not, sir: but I have required\nAs I am grown in the lawful virtue\nThat it hath bid you think, and I\n==========\n\u001b[1mROMEO:\u001b[0m\nThat I should pardon, I would be gone.\n\nESCALUS:\nI should believe you, sir, sir, ay, I would not\nnot know more, but that I can, but I would have savour me.\n\nP\n==========\n\u001b[1mROMEO:\u001b[0m\nAnd thou, I will find out thy life the wind of love.\n\nROMEO:\nIt is the morning groom of it.\n\nJULIET:\nFie, good sweet boy, I will take my leave to a happy day,\n" 334 | ] 335 | } 336 | ], 337 | "source": [ 338 | "ai2.generate(10, prompt=\"ROMEO:\")" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "# MIT License\n", 346 | "\n", 347 | "Copyright (c) 2021 Max Woolf\n", 348 | "\n", 349 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 350 | "of this software and associated documentation files (the \"Software\"), to deal\n", 351 | "in the Software without restriction, including without limitation the rights\n", 352 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 353 | "copies of the Software, and to permit persons to whom the Software is\n", 354 | "furnished to do so, subject to the following conditions:\n", 355 | "\n", 356 | "The above copyright notice and this permission notice shall be included in all\n", 357 | "copies or substantial portions of the Software.\n", 358 | "\n", 359 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 360 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 361 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 362 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 363 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 364 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 365 | "SOFTWARE." 
366 | ] 367 | } 368 | ], 369 | "metadata": { 370 | "kernelspec": { 371 | "name": "python3", 372 | "display_name": "Python 3.9.1 64-bit", 373 | "metadata": { 374 | "interpreter": { 375 | "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" 376 | } 377 | } 378 | }, 379 | "language_info": { 380 | "codemirror_mode": { 381 | "name": "ipython", 382 | "version": 3 383 | }, 384 | "file_extension": ".py", 385 | "mimetype": "text/x-python", 386 | "name": "python", 387 | "nbconvert_exporter": "python", 388 | "pygments_lexer": "ipython3", 389 | "version": "3.9.1-final" 390 | } 391 | }, 392 | "nbformat": 4, 393 | "nbformat_minor": 2 394 | } -------------------------------------------------------------------------------- /aitextgen/TokenDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import csv 4 | import os 5 | import gzip 6 | from torch.utils.data import Dataset 7 | from typing import List 8 | from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast 9 | from pkg_resources import resource_filename 10 | import itertools 11 | from tqdm.auto import tqdm 12 | import numpy as np 13 | 14 | csv.field_size_limit(2**31 - 1) 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | STATIC_PATH = resource_filename(__name__, "static") 20 | 21 | 22 | class TokenDataset(Dataset): 23 | """ 24 | Class that merges TextDataset and LineByLineTextDataset from 25 | run_language_modeling.py in transformers, plus 26 | adds more ways to ingest text such as with CSVs. 27 | 28 | :param file_path: A string indicating the relative file path of the text 29 | to be tokenized, or the cached dataset. 30 | :param vocab_file: Path to a vocab file (generated by train_tokenizer()) 31 | :param merges_file: Path to a merges file (generated by train_tokenizer()) 32 | :param texts: A list of input texts (if providing texts manually) 33 | :param line_by_line: A boolean to indicate if the input file should be read 34 | line by line (True) or as a full text (False). 35 | :param from_cache: A string indicating if loading from a pregenerated MsgPack 36 | dump. 37 | :param header: A boolean indicating if loading from a CSV, if it has a header. 38 | :param save_cache: A boolean indicating whether to save the tokenized 39 | dataset as a MsgPack dump to load later. 40 | :param cache_destination: A string indicating where to save the cache. 41 | :param block_size: An integer indicating maximum length of the text document 42 | (usually set by the model architecture) 43 | :param tokenized_texts: Texts that are already tokenized; only should 44 | be used by merge_datasets(). 
45 | :param text_delim: delimiter to use to split bulk texts (default paragraph breaks) 46 | :param bos_token: String to override the beginning-of-string token 47 | :param eos_token: String to override the end-of-string token 48 | :param unk_token: String to override the unknown token 49 | :param pad_token: String to override the padding token 50 | :param progress_bar_refresh_rate: How often to update progress bar when loading 51 | """ 52 | 53 | def __init__( 54 | self, 55 | file_path: str = None, 56 | vocab_file: str = os.path.join(STATIC_PATH, "gpt2_vocab.json"), 57 | merges_file: str = os.path.join(STATIC_PATH, "gpt2_merges.txt"), 58 | tokenizer: GPT2TokenizerFast = None, 59 | tokenizer_file: str = None, 60 | texts: List[str] = None, 61 | line_by_line: bool = False, 62 | from_cache: bool = False, 63 | header: bool = True, 64 | save_cache: bool = False, 65 | cache_destination: str = "dataset_cache.tar.gz", 66 | compress: bool = True, 67 | block_size: int = 1024, 68 | tokenized_texts: bool = False, 69 | text_delim: str = "\n", 70 | bos_token: str = "<|endoftext|>", 71 | eos_token: str = "<|endoftext|>", 72 | unk_token: str = "<|endoftext|>", 73 | pad_token: str = "<|endoftext|>", 74 | progress_bar_refresh_rate: int = 20, 75 | **kwargs, 76 | ) -> None: 77 | 78 | self.line_by_line = False 79 | 80 | # Special case; load tokenized texts immediately 81 | if tokenized_texts: 82 | self.tokens = np.asarray(tokenized_texts) 83 | self.num_subsets = self.tokens.shape[0] - block_size 84 | self.block_size = block_size 85 | self.file_path = "merged TokenDataset" 86 | self.str_suffix = "by merging TokenDatasets." 87 | return 88 | 89 | assert any([texts, file_path]), "texts or file_path must be specified." 90 | 91 | if not tokenizer: 92 | if tokenizer_file: 93 | # load the custom tokenizer from a serialized tokenizer 94 | tokenizer = PreTrainedTokenizerFast( 95 | tokenizer_file=tokenizer_file, 96 | bos_token=bos_token, 97 | eos_token=eos_token, 98 | unk_token=unk_token, 99 | pad_token=pad_token, 100 | ) 101 | else: 102 | tokenizer = GPT2TokenizerFast( 103 | vocab_file=vocab_file, 104 | merges_file=merges_file, 105 | bos_token=bos_token, 106 | eos_token=eos_token, 107 | unk_token=unk_token, 108 | pad_token=pad_token, 109 | verbose=False, 110 | ) 111 | # https://github.com/huggingface/transformers/issues/10202 112 | tokenizer.add_special_tokens( 113 | {"additional_special_tokens": ["<|endoftext|>"]} 114 | ) 115 | 116 | # If a cache path is provided, load it. 117 | if from_cache: 118 | open_func = gzip.open if file_path.endswith(".gz") else open 119 | 120 | with open_func(file_path, "rb") as f: 121 | self.tokens = np.load(f) 122 | self.num_subsets = self.tokens.shape[0] - block_size 123 | self.block_size = block_size 124 | self.line_by_line = line_by_line 125 | self.str_suffix = "via cache." 126 | 127 | logger.info( 128 | f"TokenDataset containing {self.num_subsets:,} subsets loaded {self.str_suffix}" 129 | ) 130 | return 131 | 132 | # if texts are present, just tokenize them. 133 | elif texts: 134 | self.str_suffix = "via application." 135 | 136 | # if a file is specified, and it's line-delimited, 137 | # the text must be processed line-by-line into a a single bulk file 138 | elif line_by_line: 139 | assert os.path.isfile( 140 | file_path 141 | ), f"{file_path} is not present in the current directory." 142 | 143 | text_delim = None 144 | self.line_by_line = True 145 | self.file_path = file_path 146 | self.str_suffix = f"from line-by-line file at {file_path}." 
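        # Illustrative usage sketch (not part of the original source; the file
        # names below are hypothetical): a line-delimited file can be tokenized
        # once with save_cache=True, then reloaded from that cache on later runs.
        #
        #   data = TokenDataset("titles.csv", line_by_line=True,
        #                       save_cache=True,
        #                       cache_destination="titles_cache.tar.gz")
        #   data = TokenDataset("titles_cache.tar.gz", from_cache=True,
        #                       block_size=64)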
147 | 148 | # if a file is specified, and it's not line-delimited, 149 | # the texts must be parsed as a single bulk file. 150 | else: 151 | assert os.path.isfile( 152 | file_path 153 | ), f"{file_path} is not present in the current directory." 154 | if file_path.endswith(".csv"): 155 | logger.warning( 156 | "You are tokenizing a CSV file, but you did not " 157 | + "set line_by_line=True. Please change if unintended." 158 | ) 159 | 160 | eos_token = "" 161 | header = False 162 | self.file_path = file_path 163 | self.str_suffix = f"from file at {file_path}." 164 | 165 | # Encode tokens in a batched manner to ensure constant memory usage 166 | if texts: 167 | self.tokens = encode_tokens_from_list( 168 | texts, eos_token, tokenizer, progress_bar_refresh_rate 169 | ) 170 | else: 171 | self.tokens = encode_tokens_from_file( 172 | file_path, 173 | eos_token, 174 | tokenizer, 175 | text_delim, 176 | header, 177 | progress_bar_refresh_rate, 178 | ) 179 | 180 | assert ( 181 | self.tokens.shape[0] >= block_size 182 | ), f"There are fewer than {block_size} encoded tokens." 183 | self.num_subsets = self.tokens.shape[0] - block_size 184 | self.block_size = block_size 185 | 186 | if save_cache: 187 | self.save(cache_destination, compress=compress) 188 | 189 | def save( 190 | self, cache_destination: str = "dataset_cache.tar.gz", compress: bool = True 191 | ) -> None: 192 | assert self.tokens.shape[0] > 0, "No data loaded to save." 193 | 194 | if compress: 195 | open_func = gzip.open 196 | compress_str = "and compressing " 197 | else: 198 | open_func = open 199 | cache_destination = ( 200 | "dataset_cache.npy" 201 | if cache_destination == "dataset_cache.tar.gz" 202 | else cache_destination 203 | ) 204 | compress_str = "" 205 | 206 | logger.info(f"Caching {compress_str}dataset to {cache_destination}") 207 | 208 | with open_func(cache_destination, "wb") as f: 209 | np.save(f, self.tokens) 210 | 211 | def __len__(self) -> int: 212 | return self.num_subsets 213 | 214 | def __getitem__(self, item: int) -> torch.Tensor: 215 | return torch.as_tensor( 216 | self.tokens[item : (item + self.block_size)].astype(np.int64, copy=False), 217 | dtype=torch.long, 218 | ) 219 | 220 | def __str__(self) -> str: 221 | return self.file_path if self.file_path is not None else "loaded dataset" 222 | 223 | def __repr__(self) -> str: 224 | return f"TokenDataset containing {self.num_subsets:,} subsets loaded {self.str_suffix}" 225 | 226 | 227 | def get_lines_in_file(file_path: str, newline: str = None) -> int: 228 | """ 229 | Returns the number of lines in a file to build progress bar. 230 | c.f. https://stackoverflow.com/a/16108605/9314418 231 | """ 232 | 233 | with open(file_path, "r", encoding="utf-8", newline=newline) as f: 234 | return sum(1 for row in f) 235 | 236 | 237 | def get_lines_in_file_csv(file_path: str, header: bool = True) -> int: 238 | """ 239 | Returns the number of lines in a CSV to build progress bar. 240 | c.f. https://stackoverflow.com/a/16108605/9314418 241 | """ 242 | 243 | with open(file_path, "r", encoding="utf-8") as f: 244 | if header: 245 | f.readline() 246 | reader = csv.reader(f) 247 | return sum(1 for row in reader) 248 | 249 | 250 | def get_dtype(vocab_size: int): 251 | """ 252 | Finds the appropriate numpy dtype depending on vocab size. 253 | 254 | The highest value for the dtype serves as a placeholder. 
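    For example (an illustrative case, not an exhaustive table): the default
    GPT-2 vocabulary of 50,257 tokens satisfies 50257 < 2**16 - 1, so token IDs
    are stored as np.uint16 rather than a wider integer type:

        >>> get_dtype(50257)
        <class 'numpy.uint16'>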
255 | """ 256 | if vocab_size < 2**8 - 1: 257 | return np.uint8 258 | elif vocab_size < 2**16 - 1: 259 | return np.uint16 260 | elif vocab_size < 2**32 - 1: 261 | return np.uint32 262 | 263 | return np.uint64 264 | 265 | 266 | def encode_tokens_from_file( 267 | file_path: str, 268 | eos_token: str, 269 | tokenizer: GPT2TokenizerFast, 270 | newline: str, 271 | header: bool = True, 272 | progress_bar_refresh_rate: int = 20, 273 | batch_size: int = 1024, 274 | ) -> List[int]: 275 | """ 276 | Retrieves texts from a newline-delimited file/CSV and returns texts. 277 | """ 278 | 279 | is_csv = file_path.endswith(".csv") 280 | a_dtype = get_dtype(tokenizer.vocab_size) 281 | 282 | if is_csv: 283 | num_texts = get_lines_in_file_csv(file_path, header) 284 | else: 285 | num_texts = get_lines_in_file(file_path, newline) 286 | 287 | pbar = tqdm( 288 | total=num_texts, 289 | smoothing=0, 290 | leave=True, 291 | dynamic_ncols=True, 292 | ) 293 | tokens = np.full((num_texts, 1), -1, dtype=a_dtype) 294 | num_batches = 0 295 | 296 | with open(file_path, "r", encoding="utf-8", newline=newline) as f_load: 297 | 298 | if header: 299 | f_load.readline() 300 | if is_csv: 301 | f_read = csv.reader(f_load) 302 | logger.info(f"Encoding {num_texts:,} rows from {file_path}.") 303 | else: 304 | f_read = f_load 305 | logger.info(f"Encoding {num_texts:,} sets of tokens from {file_path}.") 306 | 307 | # https://stackoverflow.com/a/6335876/9314418 308 | while True: 309 | if is_csv: 310 | batch = [ 311 | text[0] + eos_token 312 | for text in list(itertools.islice(f_read, batch_size)) 313 | ] 314 | else: 315 | batch = [ 316 | text + eos_token 317 | for text in list(itertools.islice(f_read, batch_size)) 318 | ] 319 | 320 | if not batch: 321 | break 322 | 323 | encoded_texts = tokenizer( 324 | batch, 325 | add_special_tokens=False, 326 | return_token_type_ids=False, 327 | return_attention_mask=False, 328 | )["input_ids"] 329 | 330 | for i, encoded_text in enumerate(encoded_texts): 331 | if len(encoded_text) > tokens.shape[1]: 332 | cols_to_add = len(encoded_text) - tokens.shape[1] 333 | tokens = np.concatenate( 334 | ( 335 | tokens, 336 | np.full( 337 | (num_texts, cols_to_add), 338 | -1, 339 | dtype=a_dtype, 340 | ), 341 | ), 342 | axis=1, 343 | ) 344 | tokens[ 345 | (num_batches * batch_size) + i, : len(encoded_text) 346 | ] = encoded_text 347 | 348 | num_batches += 1 349 | 350 | if num_batches % progress_bar_refresh_rate == 0: 351 | pbar.update(batch_size * progress_bar_refresh_rate) 352 | 353 | pbar.n = num_texts 354 | pbar.refresh() 355 | pbar.close() 356 | tokens = tokens.flatten() 357 | return tokens[tokens < np.array(-1, dtype=a_dtype)] 358 | 359 | 360 | def encode_tokens_from_list( 361 | texts: List[str], 362 | eos_token: str, 363 | tokenizer: GPT2TokenizerFast, 364 | progress_bar_refresh_rate: int = 20, 365 | batch_size: int = 1024, 366 | ) -> List[int]: 367 | """ 368 | Retrieves texts from a newline-delimited file/CSV and returns texts. 
369 | """ 370 | 371 | num_texts = len(texts) 372 | a_dtype = get_dtype(tokenizer.vocab_size) 373 | logger.info(f"Encoding {num_texts:,} texts.") 374 | 375 | pbar = tqdm( 376 | total=num_texts, 377 | smoothing=0, 378 | leave=True, 379 | dynamic_ncols=True, 380 | ) 381 | tokens = np.full((len(texts), 1), -1, dtype=a_dtype) 382 | 383 | for i_start in range(num_texts // batch_size + 1): 384 | batch = [ 385 | text + eos_token 386 | for text in texts[ 387 | (i_start * batch_size) : ((i_start * batch_size) + batch_size) 388 | ] 389 | ] 390 | 391 | encoded_texts = tokenizer( 392 | batch, 393 | add_special_tokens=False, 394 | return_token_type_ids=False, 395 | return_attention_mask=False, 396 | )["input_ids"] 397 | 398 | for i, encoded_text in enumerate(encoded_texts): 399 | if len(encoded_text) > tokens.shape[1]: 400 | cols_to_add = len(encoded_text) - tokens.shape[1] 401 | tokens = np.concatenate( 402 | ( 403 | tokens, 404 | np.full( 405 | (num_texts, cols_to_add), 406 | -1, 407 | dtype=a_dtype, 408 | ), 409 | ), 410 | axis=1, 411 | ) 412 | tokens[(i_start * batch_size) + i, : len(encoded_text)] = encoded_text 413 | 414 | if i_start % progress_bar_refresh_rate == 0: 415 | pbar.update(batch_size * progress_bar_refresh_rate) 416 | 417 | pbar.n = num_texts 418 | pbar.refresh() 419 | pbar.close() 420 | tokens = tokens.flatten() 421 | return tokens[tokens < np.array(-1, dtype=a_dtype)] 422 | 423 | 424 | def merge_datasets(datasets: List[TokenDataset], equalize: bool = True) -> TokenDataset: 425 | """ 426 | Merges multiple TokenDatasets into a single TokenDataset. 427 | This assumes that you are using the same tokenizer for all TokenDatasets. 428 | 429 | :param datasets: A list of TokenDatasets. 430 | :param equalize: Whether to take an equal amount of samples from all 431 | input datasets (by taking random samples from 432 | each dataset equal to the smallest dataset) 433 | in order to balance out the result dataset. 434 | """ 435 | 436 | assert ( 437 | isinstance(datasets, list) and len(datasets) > 1 438 | ), "datasets must be a list of multiple TokenDatasets." 439 | 440 | len_smallest = min([len(dataset) for dataset in datasets]) 441 | block_size = datasets[0].block_size 442 | 443 | tokenized_texts = [] 444 | 445 | for dataset in datasets: 446 | assert ( 447 | dataset.block_size == block_size 448 | ), "The input datasets have different block sizes." 
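        # Illustrative usage (the dataset variables here are hypothetical):
        #   merged = merge_datasets([data_a, data_b], equalize=True)
        # With equalize=True, only the first len_smallest token IDs of each
        # dataset are kept, so every source contributes an equal share.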
449 | if equalize: 450 | tokenized_texts.extend(dataset.tokens[0:len_smallest]) 451 | else: 452 | tokenized_texts.extend(dataset.tokens) 453 | 454 | return TokenDataset(tokenized_texts=tokenized_texts, block_size=block_size) 455 | -------------------------------------------------------------------------------- /aitextgen/aitextgen.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import platform 4 | import re 5 | import shutil 6 | import sys 7 | from datetime import datetime 8 | from random import randint 9 | from typing import List, Optional, Union 10 | 11 | import pytorch_lightning as pl 12 | import torch 13 | from pkg_resources import resource_filename 14 | # from pytorch_lightning.plugins import DeepSpeedPlugin 15 | from tqdm.auto import trange 16 | from transformers import ( 17 | AutoConfig, 18 | AutoModelForCausalLM, 19 | AutoTokenizer, 20 | GPT2Config, 21 | GPT2LMHeadModel, 22 | GPT2TokenizerFast, 23 | PreTrainedTokenizerFast, 24 | ) 25 | from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import ( 26 | convert_gpt2_checkpoint_to_pytorch, 27 | ) 28 | 29 | from .colab import create_gdrive_folder 30 | from .TokenDataset import TokenDataset 31 | from .train import ATGProgressBar, ATGTransformer 32 | from .utils import ( 33 | download_gpt2, 34 | find_index_of_subset, 35 | model_max_length, 36 | reset_seed, 37 | set_seed, 38 | ) 39 | 40 | logger = logging.getLogger("aitextgen") 41 | logger.setLevel(logging.INFO) 42 | 43 | STATIC_PATH = resource_filename(__name__, "static") 44 | 45 | 46 | class aitextgen: 47 | """ 48 | Class that serves as the main aitextgen object for training and generation. 49 | 50 | :param model: Either the file path of a PyTorch GPT-2 model, or a string 51 | representing the Huggingface model to download. 52 | :param config: Either a file path of a config.json representing the model, 53 | or a GPT2Config with the model architecture. 54 | :param vocab_file: Path to a vocab file (generated by train_tokenizer()) 55 | :param merges_file: Path to a merges file (generated by train_tokenizer()) 56 | :param cache_dir: folder path which downloaded models will be stored and loaded 57 | :param tf_gpt2: model indicator of OpenAI-distributed version of GPT-2. 58 | This will convert the model to PyTorch if not present. 
59 | :param to_gpu: Whether to load the model into the GPU after loading 60 | (good for generation) 61 | :param to_fp16: Whether to convert the model to FP16 before loading 62 | to GPU (for supported GPUs only) 63 | :param verbose: Whether to enable logging from base Huggingface packages 64 | :param bos_token: String to override the beginning-of-string token 65 | :param eos_token: String to override the end-of-string token 66 | :param unk_token: String to override the unknown token 67 | """ 68 | 69 | openai_tf_gpt2 = None 70 | 71 | # default values for GPT2Tokenizer 72 | tokenizer = None 73 | vocab_file = os.path.join(STATIC_PATH, "gpt2_vocab.json") 74 | merges_file = os.path.join(STATIC_PATH, "gpt2_merges.txt") 75 | bos_token = "<|endoftext|>" 76 | eos_token = "<|endoftext|>" 77 | unk_token = "<|endoftext|>" 78 | pad_token = "<|endoftext|>" 79 | 80 | def __init__( 81 | self, 82 | model: str = None, 83 | model_folder: str = None, 84 | config: Union[str, GPT2Config] = None, 85 | vocab_file: str = None, 86 | merges_file: str = None, 87 | tokenizer_file: str = None, 88 | schema_tokens: List[str] = None, 89 | schema_return: List[str] = None, 90 | cache_dir: str = "aitextgen", 91 | tf_gpt2: str = None, 92 | to_gpu: bool = False, 93 | to_fp16: bool = False, 94 | verbose: bool = False, 95 | gradient_checkpointing: bool = False, 96 | bos_token: str = None, 97 | eos_token: str = None, 98 | unk_token: str = None, 99 | **kwargs, 100 | ) -> None: 101 | 102 | if model: 103 | assert not os.path.isfile(model), ( 104 | "As of aitextgen 0.5.0, you must " 105 | + "use `model_folder` to load an existing model." 106 | ) 107 | 108 | if not verbose: 109 | for module in [ 110 | "transformers.file_utils", 111 | "transformers.configuration_utils", 112 | "transformers.tokenization_utils", 113 | "filelock", 114 | "transformers.modeling_gpt2", 115 | ]: 116 | logging.getLogger(module).setLevel(logging.WARN) 117 | logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) 118 | 119 | if tf_gpt2: 120 | self.openai_tf_gpt2 = tf_gpt2 121 | 122 | # Download + convert the TF weights if a PyTorch model has not been created 123 | if not os.path.isfile( 124 | os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin") 125 | ): 126 | assert tf_gpt2 in [ 127 | "124M", 128 | "355M", 129 | "774M", 130 | "1558M", 131 | ], "Invalid TensorFlow GPT-2 model size." 132 | 133 | logger.info( 134 | f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config " 135 | + "from Google's servers" 136 | ) 137 | 138 | download_gpt2(cache_dir, tf_gpt2) 139 | 140 | logger.info( 141 | f"Converting the {tf_gpt2} GPT-2 TensorFlow weights to PyTorch." 
142 | ) 143 | 144 | config_path = os.path.join(cache_dir, tf_gpt2, "hparams.json") 145 | 146 | convert_gpt2_checkpoint_to_pytorch( 147 | os.path.join(cache_dir, tf_gpt2), 148 | config_path, 149 | cache_dir, 150 | ) 151 | 152 | os.rename( 153 | os.path.join(cache_dir, "pytorch_model.bin"), 154 | os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin"), 155 | ) 156 | 157 | os.rename( 158 | os.path.join(cache_dir, "config.json"), 159 | os.path.join(cache_dir, f"config_{tf_gpt2}.json"), 160 | ) 161 | 162 | logger.info(f"Loading {tf_gpt2} GPT-2 model from /{cache_dir}.") 163 | model = os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin") 164 | config = os.path.join(cache_dir, f"config_{tf_gpt2}.json") 165 | 166 | self.model = GPT2LMHeadModel.from_pretrained(model, config=config) 167 | 168 | elif model_folder: 169 | # A folder is provided containing pytorch_model.bin and config.json 170 | assert os.path.exists( 171 | os.path.join(model_folder, "pytorch_model.bin") 172 | ), f"There is no pytorch_model.bin in /{model_folder}." 173 | assert os.path.exists( 174 | os.path.join(model_folder, "config.json") 175 | ), f"There is no config.json in /{model_folder}." 176 | 177 | logger.info( 178 | f"Loading model from provided weights and config in /{model_folder}." 179 | ) 180 | self.model = AutoModelForCausalLM.from_pretrained( 181 | model_folder, local_files_only=True 182 | ) 183 | elif config: 184 | # Manually construct a model from scratch 185 | logger.info("Constructing model from provided config.") 186 | if isinstance(config, str): 187 | config = AutoConfig.from_pretrained(config) 188 | self.model = AutoModelForCausalLM.from_config(config=config) 189 | else: 190 | # Download and cache model from Huggingface 191 | if os.path.isdir(cache_dir) and len(os.listdir(cache_dir)) > 0: 192 | logger.info(f"Loading {model or 'gpt2'} model from /{cache_dir}.") 193 | else: 194 | logger.info(f"Downloading {model or 'gpt2'} model to /{cache_dir}.") 195 | self.model = AutoModelForCausalLM.from_pretrained( 196 | model or "gpt2", cache_dir=cache_dir 197 | ) 198 | if model and "gpt2" not in model: 199 | logger.info(f"Using the tokenizer for {model}.") 200 | self.tokenizer = AutoTokenizer.from_pretrained( 201 | model, 202 | cache_dir=cache_dir, 203 | ) 204 | 205 | logger.info(self) 206 | 207 | if gradient_checkpointing or tf_gpt2 in ["355M", "774M", "1558M"]: 208 | logger.info("Gradient checkpointing enabled for model training.") 209 | setattr(self.model.config, "gradient_checkpointing", True) 210 | setattr(self.model.config, "use_cache", False) 211 | 212 | if schema_tokens: 213 | setattr(self.model.config, "schema_tokens", schema_tokens) 214 | 215 | if schema_return: 216 | setattr(self.model.config, "schema_return", schema_return) 217 | 218 | if self.tokenizer is None: 219 | # Update tokenizer settings (if not set already) 220 | args = locals() 221 | custom_tokenizer = False 222 | for attr in [ 223 | "vocab_file", 224 | "merges_file", 225 | "tokenizer_file", 226 | "bos_token", 227 | "eos_token", 228 | "unk_token", 229 | ]: 230 | if args[attr] is not None: 231 | custom_tokenizer = True 232 | setattr(self, attr, args[attr]) 233 | 234 | if custom_tokenizer: 235 | logger.info("Using a custom tokenizer.") 236 | else: 237 | logger.info("Using the default GPT-2 Tokenizer.") 238 | 239 | if tokenizer_file: 240 | # load the custom GPT-2 tokenizer from a serialized tokenizer. 241 | # GPT-Neo uses the GPT-2 tokenizer. 
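                # Illustrative call that reaches this branch (file names as
                # produced in the training_hello_world notebook):
                #   ai = aitextgen(model_folder="trained_model",
                #                  tokenizer_file="aitextgen.tokenizer.json")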
242 | self.tokenizer = PreTrainedTokenizerFast( 243 | tokenizer_file=tokenizer_file, 244 | bos_token=self.bos_token, 245 | eos_token=self.eos_token, 246 | unk_token=self.unk_token, 247 | pad_token=self.pad_token, 248 | ) 249 | else: 250 | self.tokenizer = GPT2TokenizerFast( 251 | vocab_file=self.vocab_file, 252 | merges_file=self.merges_file, 253 | bos_token=self.bos_token, 254 | eos_token=self.eos_token, 255 | unk_token=self.unk_token, 256 | pad_token=self.pad_token, 257 | verbose=False, 258 | ) 259 | if not custom_tokenizer: 260 | # https://github.com/huggingface/transformers/issues/10202 261 | self.tokenizer.add_special_tokens( 262 | {"additional_special_tokens": ["<|endoftext|>"]} 263 | ) 264 | 265 | self.tokenizer.padding_side = "left" 266 | 267 | if to_gpu: 268 | if to_fp16: 269 | logger.warn( 270 | "Currently, FP16 text generation results in random output. " 271 | + "You may want to avoid using to_fp16 for the time being." 272 | ) 273 | self.to_fp16() 274 | self.to_gpu() 275 | 276 | def generate( 277 | self, 278 | n: int = 1, 279 | prompt: str = "", 280 | prepend_bos: bool = None, 281 | min_length: int = None, 282 | max_length: int = 256, 283 | temperature: float = 0.7, 284 | do_sample: bool = True, 285 | return_as_list: bool = False, 286 | seed: int = None, 287 | pad_token_id: str = None, 288 | schema: str = False, 289 | normalize_key: bool = True, 290 | use_cache: bool = True, 291 | lstrip: bool = True, 292 | nonempty_output: bool = True, 293 | skip_special_tokens: bool = True, 294 | **kwargs, 295 | ) -> Optional[str]: 296 | """ 297 | Generates texts using the stored Transformers model. 298 | Currently generates text using the model's generate() function. 299 | 300 | :param n: Numbers of texts to generate. 301 | :param prompt: Text to force the generated text to start with 302 | :param max_length: Maximum length for the generated text 303 | :param temperature: Determines the "creativity" of the generated text. 304 | The value range is different for each type of Transformer. 305 | :param do_sample: Samples the text, which is what we want. If False, 306 | the generated text will be the optimal prediction at each time, 307 | and therefore deterministic. 308 | :param return_as_list: Boolean which determine if text should be returned 309 | as a list. If False, the generated texts will be print to console. 310 | :param seed: A numeric seed which sets all randomness, allowing the 311 | generate text to be reproducible if rerunning with same parameters 312 | and model. 313 | """ 314 | 315 | prompt_text = prompt 316 | prompt_tensors = self.tokenizer(text=prompt, return_tensors="pt") 317 | 318 | if prompt: 319 | prompt_num_tokens = list(prompt_tensors["input_ids"].shape)[1] 320 | assert prompt_num_tokens < model_max_length( 321 | self.model.config 322 | ), f"The prompt is too large for the model. 
({prompt_num_tokens} tokens)" 323 | 324 | input_ids = ( 325 | prompt_tensors["input_ids"].to(self.get_device()) if prompt else None 326 | ) 327 | 328 | if prepend_bos is None: 329 | prepend_bos = getattr(self.model.config, "line_by_line", None) 330 | 331 | if prepend_bos: 332 | bos = torch.tensor([[self.tokenizer.bos_token_id]]).to(self.get_device()) 333 | if prompt: 334 | input_ids = torch.cat((bos, input_ids), dim=1) 335 | else: 336 | input_ids = bos 337 | 338 | if seed: 339 | set_seed(seed) 340 | 341 | if pad_token_id is None: 342 | pad_token_id = getattr(self.tokenizer, "pad_token_id", None) or getattr( 343 | self.tokenizer, "eos_token_id", None 344 | ) 345 | 346 | # prevent an error from using a length greater than the model 347 | gen_max_length = model_max_length(self.model.config) 348 | max_length = min(gen_max_length, max_length) 349 | 350 | while True: 351 | outputs = self.model.generate( 352 | input_ids=input_ids, 353 | min_length=min_length, 354 | max_length=max_length, 355 | temperature=temperature, 356 | do_sample=do_sample, 357 | num_return_sequences=n, 358 | pad_token_id=pad_token_id, 359 | use_cache=use_cache, 360 | **kwargs, 361 | ) 362 | 363 | # Schema token handling 364 | if schema: 365 | schema_tokens = getattr(self.model.config, "schema_tokens") 366 | schema_return = getattr(self.model.config, "schema_return", None) 367 | schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"] 368 | 369 | nonalphanum_pattern = re.compile(r"[\W_]+", re.UNICODE) 370 | 371 | outputs = outputs.tolist() 372 | gen_texts = [] 373 | for output in outputs: 374 | gen_text_dict = {} 375 | 376 | # Get indices of each schema token within the text 377 | schema_token_indices = [ 378 | (schema_tokens[i], find_index_of_subset(output, token_enc)) 379 | for i, token_enc in enumerate(schema_tokens_enc) 380 | ] 381 | 382 | schema_token_indices.sort(key=lambda x: x[1]) 383 | 384 | for i, token_tuple in enumerate(schema_token_indices): 385 | start_index = token_tuple[1] 386 | key = ( 387 | nonalphanum_pattern.sub("", token_tuple[0]) 388 | if normalize_key 389 | else token_tuple[0] 390 | ) 391 | if start_index == -1: 392 | gen_text_dict[key] = "" 393 | else: 394 | end_index = ( 395 | schema_token_indices[i + 1][1] - 1 396 | if i + 1 < len(schema_token_indices) 397 | else None 398 | ) 399 | 400 | gen_text_dict[key] = self.tokenizer.decode( 401 | output[start_index:end_index], skip_special_tokens=True 402 | ) 403 | 404 | # remove fields not in schema_return 405 | if schema_return: 406 | keys = gen_text_dict.keys() 407 | if len(schema_return) == 1: 408 | gen_text_dict = gen_text_dict[schema_return[0]] 409 | for key in keys: 410 | if key not in schema_return: 411 | gen_text_dict.pop(key, None) 412 | 413 | gen_texts.append(gen_text_dict) 414 | 415 | # Reset seed if used 416 | if seed: 417 | reset_seed() 418 | 419 | if not return_as_list: 420 | print(*gen_texts, sep="\n" + "=" * 10 + "\n") 421 | break 422 | else: 423 | if n > 1: 424 | return gen_texts 425 | else: 426 | return gen_texts[0] 427 | 428 | # Typical use case 429 | else: 430 | gen_texts = self.tokenizer.batch_decode( 431 | outputs, skip_special_tokens=skip_special_tokens 432 | ) 433 | 434 | # Handle stripping tokenization spaces w/ regex 435 | if lstrip: 436 | gen_texts = [re.sub(r"^\s+", "", text) for text in gen_texts] 437 | 438 | if nonempty_output: 439 | if min_length: 440 | gen_texts = list( 441 | filter(lambda x: len(x) > min_length, gen_texts) 442 | ) 443 | else: 444 | gen_texts = list(filter(lambda x: len(x) > 0, gen_texts)) 445 | 446 
| # if there is no generated text after cleanup, try again. 447 | if len(gen_texts) == 0: 448 | continue 449 | 450 | # Reset seed if used 451 | if seed: 452 | reset_seed() 453 | 454 | if not return_as_list: 455 | if prompt: 456 | # Bold the prompt if printing to console 457 | gen_texts = [ 458 | text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1) 459 | for text in gen_texts 460 | ] 461 | 462 | if n > 1: 463 | print(*gen_texts, sep="\n" + "=" * 10 + "\n") 464 | else: 465 | print(gen_texts[0]) 466 | break 467 | else: 468 | return gen_texts 469 | 470 | def generate_one(self, **kwargs) -> None: 471 | """ 472 | Generates a single text, and returns it as a string. Useful for 473 | returning a generated text within an API. 474 | 475 | See generate() for more parameters. 476 | """ 477 | 478 | return self.generate(n=1, return_as_list=True, **kwargs)[0] 479 | 480 | def generate_samples( 481 | self, n: int = 3, temperatures: List[float] = [0.7, 1.0, 1.2], **kwargs 482 | ) -> None: 483 | """ 484 | Prints multiple samples to console at specified temperatures. 485 | """ 486 | 487 | for temperature in temperatures: 488 | print("#" * 20 + f"\nTemperature: {temperature}\n" + "#" * 20) 489 | self.generate(n=n, temperature=temperature, return_as_list=False, **kwargs) 490 | 491 | def generate_to_file( 492 | self, 493 | n: int = 20, 494 | batch_size: int = 1, 495 | destination_path: str = None, 496 | sample_delim: str = "=" * 20 + "\n", 497 | seed: int = None, 498 | **kwargs, 499 | ) -> None: 500 | """ 501 | Generates a bulk amount of texts to a file, into a format 502 | good for manually inspecting and curating the texts. 503 | 504 | :param n: Number of texts to generate 505 | :param batch_size: Number of texts to generate simultaneously, taking 506 | advantage of CPU/GPU parallelization. 507 | :param destination_path: File name of the file. If None, a timestampped 508 | file name is automatically used. 509 | :param sample_delim: The text used to delimit each generated text. 510 | :param seed: Seed used for the generation. The last part of a file name 511 | will be the seed used to reproduce a generation. 512 | :param cleanup: Whether to polish the text before returning 513 | 514 | See generate() for more parameters. 515 | """ 516 | 517 | assert n % batch_size == 0, f"n must be divisible by batch_size ({batch_size})." 518 | 519 | self.model = self.model.eval() 520 | 521 | if destination_path is None: 522 | # Create a time-based file name to prevent overwriting. 523 | # Use a 8-digit number as the seed, which is the last 524 | # numeric part of the file name. 
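        # e.g. a resulting file name looks like "ATG_20200517_235441_14821584.txt",
        # where the trailing 8-digit number is the seed that can be passed back
        # in to reproduce the same generation.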
525 | if seed is None: 526 | seed = randint(10 ** 7, 10 ** 8 - 1) 527 | 528 | destination_path = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}_{seed}.txt" 529 | 530 | if seed: 531 | set_seed(seed) 532 | 533 | logger.info(f"Generating {n:,} texts to {destination_path}") 534 | 535 | pbar = trange(n) 536 | f = open(destination_path, "w", encoding="utf-8") 537 | 538 | for _ in range(n // batch_size): 539 | gen_texts = self.generate(n=batch_size, return_as_list=True, **kwargs) 540 | 541 | for gen_text in gen_texts: 542 | f.write("{}\n{}".format(gen_text, sample_delim)) 543 | pbar.update(batch_size) 544 | 545 | pbar.close() 546 | f.close() 547 | 548 | if seed: 549 | reset_seed() 550 | 551 | def train( 552 | self, 553 | train_data: Union[str, TokenDataset], 554 | output_dir: str = "trained_model", 555 | fp16: bool = False, 556 | fp16_opt_level: str = "O1", 557 | n_gpu: int = -1, 558 | tpu_cores: int = 0, 559 | max_grad_norm: float = 0.5, 560 | gradient_accumulation_steps: int = 1, 561 | seed: int = None, 562 | learning_rate: float = 1e-3, 563 | weight_decay: float = 0.05, 564 | adam_epsilon: float = 1e-8, 565 | warmup_steps: int = 0, 566 | num_steps: int = 5000, 567 | save_every: int = 1000, 568 | generate_every: int = 1000, 569 | n_generate: int = 1, 570 | loggers: List = None, 571 | batch_size: int = 1, 572 | num_workers: int = None, 573 | benchmark: bool = True, 574 | avg_loss_smoothing: float = 0.01, 575 | save_gdrive: bool = False, 576 | run_id: str = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}", 577 | progress_bar_refresh_rate: int = 20, 578 | freeze_layers: bool = False, 579 | num_layers_freeze: int = None, 580 | use_deepspeed: bool = False, 581 | **kwargs, 582 | ) -> None: 583 | """ 584 | Trains/finetunes the model on the provided file/dataset using pytorch-lightning. 585 | 586 | :param train_data: Either a TokenDataset containing the samples to be trained, or 587 | the file path of a text file to be trained on (shortcut instead of a dataset) 588 | :param output_dir: A string indicating where to store the resulting 589 | model folder. 590 | :param fp16: Boolean whether to use fp16, assuming a compatible GPU/TPU. 591 | :param fp16_opt_level: Option level for FP16/APEX training. 592 | :param n_gpu: Number of GPUs to use (-1 implies all available GPUs) 593 | :param tpu_cores: Number of TPU cores to use (should be a multiple of 8) 594 | :param max_grad_norm: Maximum gradient norm (used for gradient clipping) 595 | :param gradient_accumulation_steps: Number of gradient accumulation steps 596 | :param seed: Integer representing the training seed. 597 | :param learning_rate: Training learning rate for the default AdamW optimizer. 598 | :param weight_decay: Weight decay for the default AdamW optimizer. 599 | :param warmup_steps: Warmup steps for the default AdamW optimizer. 600 | :param num_steps: Number of training steps. 601 | :param save_every: Number of steps between each model save to disk 602 | :param generate_every: Number of steps between each sample text generation 603 | :param n_generate: Number of texts to generate when generate_every occurs. 604 | :param loggers: pytorch-lightning logger(s) to log results.
    def train(
        self,
        train_data: Union[str, TokenDataset],
        output_dir: str = "trained_model",
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        n_gpu: int = -1,
        tpu_cores: int = 0,
        max_grad_norm: float = 0.5,
        gradient_accumulation_steps: int = 1,
        seed: int = None,
        learning_rate: float = 1e-3,
        weight_decay: float = 0.05,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        num_steps: int = 5000,
        save_every: int = 1000,
        generate_every: int = 1000,
        n_generate: int = 1,
        loggers: List = None,
        batch_size: int = 1,
        num_workers: int = None,
        benchmark: bool = True,
        avg_loss_smoothing: float = 0.01,
        save_gdrive: bool = False,
        run_id: str = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}",
        progress_bar_refresh_rate: int = 20,
        freeze_layers: bool = False,
        num_layers_freeze: int = None,
        use_deepspeed: bool = False,
        **kwargs,
    ) -> None:
        """
        Trains/finetunes the model on the provided file/dataset using pytorch-lightning.

        :param train_data: Either a TokenDataset containing the samples to be trained, or
        a string containing the text to be trained (shortcut instead of dataset)
        :param output_dir: A string indicating where to store the resulting
        model folder.
        :param fp16: Boolean whether to use fp16, assuming using a compatible GPU/TPU.
        :param fp16_opt_level: Option level for FP16/APEX training.
        :param n_gpu: Number of GPUs to use (-1 implies all available GPUs)
        :param tpu_cores: Number of TPU cores to use (should be a multiple of 8)
        :param max_grad_norm: Maximum gradient normalization
        :param gradient_accumulation_steps: Number of gradient accumulation steps
        :param seed: Integer representing the training seed.
        :param learning_rate: Training learning rate for the default AdamW optimizer.
        :param weight_decay: Weight decay for the default AdamW optimizer.
        :param adam_epsilon: Epsilon for the default AdamW optimizer.
        :param warmup_steps: Warmup steps for the default AdamW optimizer.
        :param num_steps: Number of training steps.
        :param save_every: Number of steps for each time to save the model to disk
        :param generate_every: Number of steps for each time to generate sample text
        :param n_generate: Number of texts to generate when generate_every occurs.
        :param loggers: pytorch-lightning logger(s) to log results.
        :param batch_size: Number of input samples per batch
        :param num_workers: Number of DataLoader workers
        :param benchmark: If using GPU, whether to use cudnn.benchmark
        :param avg_loss_smoothing: Smoothing factor for Avg loss in progress bar
        :param save_gdrive: If using Colab, whether to save the notebook
        to Google Drive at each save_every
        :param run_id: Run identifier; used for save_gdrive
        :param progress_bar_refresh_rate: How often to update
        the progress bar while training.
        :param freeze_layers: Whether to freeze Transformer layers during training.
        :param num_layers_freeze: Number of layers to freeze if freeze_layers is enabled.
        :param use_deepspeed: Whether to use the DeepSpeed plugin (GPU training only).
        """

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if save_gdrive:
            assert (
                "google.colab" in sys.modules
            ), "You must be in Colaboratory to copy to your Google Drive"
            create_gdrive_folder(run_id)

        self.model = self.model.train()
        is_gpu_used = torch.cuda.is_available() and n_gpu != 0

        if isinstance(train_data, str):
            block_size = model_max_length(self.model.config)
            logger.info(
                f"Loading text from {train_data} with generation length of {block_size}."
            )
            train_data = TokenDataset(
                tokenizer=self.tokenizer,
                bos_token=self.bos_token,
                eos_token=self.eos_token,
                unk_token=self.unk_token,
                file_path=train_data,
                block_size=block_size,
                **kwargs,
            )

        setattr(self.model.config, "line_by_line", train_data.line_by_line)

        if freeze_layers or self.openai_tf_gpt2 == "1558M":
            logger.info("Layer freezing enabled for model training.")
            freeze_layers = True
            if num_layers_freeze:
                assert (
                    num_layers_freeze < self.model.config.n_layer
                ), "You are freezing more Transformer layers than in the model."

        if num_workers is None:
            # Use all CPU cores as workers if not training on CPU
            if is_gpu_used or tpu_cores > 0:
                num_workers = os.cpu_count()
            # If training on the CPU, use half the CPUs
            else:
                num_workers = int(os.cpu_count() / 2)

        hparams = dict(
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            adam_epsilon=adam_epsilon,
            warmup_steps=warmup_steps,
            batch_size=batch_size,
            num_steps=num_steps,
            pin_memory=is_gpu_used,
            num_workers=num_workers,
            save_every=save_every,
            generate_every=generate_every,
            use_tpu=tpu_cores > 0,
        )

        # Wrap the model in a pytorch-lightning module
        train_model = ATGTransformer(self.model, train_data, hparams, self.tokenizer)

        # Begin training
        if seed:
            set_seed(seed)

        if os.path.exists(output_dir) and "pytorch_model.bin" in os.listdir(output_dir):
            logger.warning(
                f"pytorch_model.bin already exists in /{output_dir} and will be overwritten!"
            )

        # if trying to use a GPU but CUDA is unavailable, fall back to the CPU
        if not is_gpu_used:
            n_gpu = 0

        # force single-GPU on Windows
        if platform.system() == "Windows" and is_gpu_used and n_gpu != 1:
            logger.warning(
                "Windows does not support multi-GPU training. Setting to 1 GPU."
            )
            n_gpu = 1

        # use the DeepSpeed plugin if installed and specified
        deepspeed_plugin = None
        if is_gpu_used and use_deepspeed:
            deepspeed_plugin = DeepSpeedPlugin()
            logger.info("Using DeepSpeed training.")
            if not fp16:
                logger.info("Setting FP16 to True for DeepSpeed ZeRO Training.")
                fp16 = True

        train_params = dict(
            accumulate_grad_batches=gradient_accumulation_steps,
            gpus=n_gpu,
            max_steps=num_steps,
            gradient_clip_val=max_grad_norm,
            enable_checkpointing=False,  # checkpoint_callback deprecated in pytorch_lightning v1.7
            logger=loggers if loggers else False,
            enable_model_summary=None,  # weights_summary and progress_bar_refresh_rate are removed in pytorch_lightning v1.7
            callbacks=[
                ATGProgressBar(
                    save_every,
                    generate_every,
                    output_dir,
                    n_generate,
                    is_gpu_used,
                    avg_loss_smoothing,
                    run_id,
                    save_gdrive,
                    progress_bar_refresh_rate,
                    freeze_layers,
                    num_layers_freeze,
                )
            ],
            plugins=deepspeed_plugin,
        )

        if fp16:
            train_params["precision"] = 16
            train_params["amp_level"] = fp16_opt_level
            train_params["amp_backend"] = "apex"

        if tpu_cores > 0:
            train_params["tpu_cores"] = tpu_cores
            train_params["gpus"] = 0
            n_gpu = 0

        # benchmark gives a boost for GPUs if input size is constant,
        # which will always be the case with aitextgen training
        if is_gpu_used and benchmark:
            train_params["benchmark"] = True

        if n_gpu > 1:
            # distributed_backend was removed from the Trainer in newer
            # pytorch-lightning releases; strategy is its replacement.
            train_params["strategy"] = "ddp"

        trainer = pl.Trainer(**train_params)
        trainer.fit(train_model)

        logger.info(f"Saving trained model pytorch_model.bin to /{output_dir}")

        self.model.save_pretrained(output_dir)

        if save_gdrive:
            for pt_file in ["pytorch_model.bin", "config.json"]:
                shutil.copyfile(
                    os.path.join(output_dir, pt_file),
                    os.path.join("/content/drive/MyDrive/", run_id, pt_file),
                )

        if seed:
            reset_seed()
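    # Sketch of a typical finetuning call; the file name and hyperparameter
    # values below are illustrative only:
    #
    #   ai.train(
    #       "input.txt",
    #       num_steps=2000,
    #       batch_size=1,
    #       learning_rate=1e-3,
    #       save_every=500,
    #       generate_every=500,
    #   )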
    def cross_train(
        self,
        inputs: List[TokenDataset],
        learning_rate: Union[float, List[float]] = 1e-4,
        num_steps: Union[int, List[int]] = 4000,
        run_id: str = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}",
        **kwargs,
    ) -> None:
        """Trains a model across multiple input datasets, halving the
        learning rate and number of steps after each dataset by default."""

        datasets = [
            TokenDataset(
                vocab_file=self.vocab_file,
                merges_file=self.merges_file,
                bos_token=self.bos_token,
                eos_token=self.eos_token,
                unk_token=self.unk_token,
                file_path=x,
                **kwargs,
            )
            if isinstance(x, str)
            else x
            for x in inputs
        ]

        if not isinstance(learning_rate, list):
            learning_rate = [learning_rate / (2 ** x) for x in range(len(datasets))]

        if not isinstance(num_steps, list):
            num_steps = [int(num_steps / (2 ** x)) for x in range(len(datasets))]

        assert len(datasets) == len(learning_rate) == len(num_steps), (
            "The provided learning_rates or num_steps"
            + " is not equal to the number of inputs."
        )

        for i, dataset in enumerate(datasets):
            logger.info(f"Now training on {dataset} for {num_steps[i]:,} steps.")
            self.train(
                dataset,
                learning_rate=learning_rate[i],
                num_steps=num_steps[i],
                run_id=run_id,
                **kwargs,
            )
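    # Usage sketch for cross_train(); the file names are illustrative. Train on
    # a broad corpus first, then a smaller domain-specific one; learning_rate
    # and num_steps are halved for the second dataset by default:
    #
    #   ai.cross_train(["all_posts.txt", "tech_posts.txt"],
    #                  learning_rate=1e-4, num_steps=4000)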
    def save(self, target_folder: str = os.getcwd()):
        """Saves the model into the specified directory."""
        self.model.save_pretrained(target_folder)

    def save_for_upload(self, target_folder: str = "my-model"):
        """
        Saves the model + tokenizer into the specified directory.

        This generates the 6 files needed to upload the model to
        Huggingface's S3 bucket.
        """
        self.model.save_pretrained(target_folder)
        self.tokenizer.save_pretrained(target_folder)

    def export(
        self,
        quantize: bool = True,
    ) -> None:
        """
        Exports the model, with optional quantization.

        Currently a stub: no export logic is implemented here.
        """

    def to_gpu(self, index: int = 0) -> None:
        """Moves the model to the specified GPU."""

        assert torch.cuda.is_available(), "CUDA is not installed."

        self.model.to(torch.device("cuda", index))

    def to_cpu(self, index: int = 0) -> None:
        """Moves the model to the CPU."""

        self.model.to(torch.device("cpu", index))

    def to_fp16(self) -> None:
        """
        Converts the model to an FP16 representation.
        Should only be used to generate on a supported GPU.
        """

        self.model = self.model.half()

    def get_device(self) -> str:
        """Getter for the current device where the model is located."""
        return self.model.device.type

    def __repr__(self) -> str:
        # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/24
        num_params_m = int(sum(p.numel() for p in self.model.parameters()) / 10 ** 6)
        model_name = type(self.model.config).__name__.replace("Config", "")
        return f"{model_name} loaded with {num_params_m}M parameters."
--------------------------------------------------------------------------------