├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
├── setup.py
└── src
    ├── scripts
    │   └── train_gsm8k.py
    └── toolformer
        ├── __init__.py
        ├── config.py
        ├── sequence_scoring.py
        ├── tool.py
        ├── tool_sampling.py
        ├── toolformer.py
        └── tools
            ├── __init__.py
            ├── calculator.py
            └── date.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
 1 | # This workflow will upload a Python Package using Twine when a release is created
 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
 3 | 
 4 | # This workflow uses actions that are not certified by GitHub.
 5 | # They are provided by a third-party and are governed by
 6 | # separate terms of service, privacy policy, and support
 7 | # documentation.
 8 | 
 9 | name: Upload Python Package
10 | 
11 | on:
12 |   release:
13 |     types: [published]
14 | 
15 | permissions:
16 |   contents: read
17 | 
18 | jobs:
19 |   deploy:
20 | 
21 |     runs-on: ubuntu-latest
22 | 
23 |     steps:
24 |     - uses: actions/checkout@v3
25 |     - name: Set up Python
26 |       uses: actions/setup-python@v3
27 |       with:
28 |         python-version: '3.x'
29 |     - name: Install dependencies
30 |       run: |
31 |         python -m pip install --upgrade pip
32 |         pip install build
33 |     - name: Build package
34 |       run: python -m build
35 |     - name: Publish package
36 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 |       with:
38 |         user: __token__
39 |         password: ${{ secrets.PYPI_API_TOKEN }}
40 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 mrcabbage972
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # simple-toolformer
 2 | # Introduction
 3 | A Python implementation of [Toolformer](https://arxiv.org/abs/2302.04761) using PyTorch and Hugging Face Transformers.
 4 | 
 5 | This implementation is under active development and has not yet been verified to work end-to-end.
 6 | Therefore, it is currently intended for educational purposes only.
 7 | 
 8 | The immediate goal is to fine-tune a model on a downstream task and verify that this yields a lift over fine-tuning
 9 | just the backbone on the same task.
10 | 
11 | # Usage
12 | First, install the dependencies: `pip install -r requirements.txt`.
13 | 
14 | The example training script is at `src/scripts/train_gsm8k.py`. It trains the model on the [GSM8k](https://huggingface.co/datasets/gsm8k) dataset of math word problems. A minimal end-to-end sketch is also given in the Example section below.
15 | 
16 | # Contributing
17 | If you wish to contribute to this project, please check out the existing issues or open a new one.
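18 | 
19 | # Example
20 | A minimal end-to-end sketch of the intended API, mirroring `src/scripts/train_gsm8k.py`. Since the project is not yet verified end-to-end, treat it as illustrative rather than a guaranteed-working recipe:
21 | 
22 | ```python
23 | from datasets import Dataset
24 | 
25 | from toolformer.toolformer import Toolformer
26 | from toolformer.tools.calculator import CalculatorTool
27 | 
28 | # Toolformer.fit expects a dataset with 'input' and 'label' text columns
29 | dataset = Dataset.from_dict({
30 |     'input': ['What is 18 + 12 x 3?'],
31 |     'label': ['The number in the next term is 18 + 12 x 3 = 54.'],
32 | })
33 | 
34 | # Samples tool calls, filters them by likelihood gain, then fine-tunes the model
35 | Toolformer().fit(dataset, [CalculatorTool()])
36 | ```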
37 | 
38 | # Citations
39 | 
40 | ```bibtex
41 | @article{Schick2023ToolformerLM,
42 |     title   = {Toolformer: Language Models Can Teach Themselves to Use Tools},
43 |     author  = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
44 |     journal = {ArXiv},
45 |     volume  = {abs/2302.04761},
46 |     year    = {2023}
47 | }
48 | ```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm~=4.64.1
2 | torch~=1.13.1
3 | numpy~=1.24.2
4 | datasets~=2.9.0
5 | transformers~=4.26.1
6 | sentencepiece~=0.1.97
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | 
 3 | setup(
 4 |   name = 'simple-toolformer',
 5 |   packages = find_packages('src'),
 6 |   package_dir={'':'src'},
 7 |   version = '0.0.1',
 8 |   license='MIT',
 9 |   description = 'Toolformer',
10 |   long_description_content_type = 'text/markdown',
11 |   author = 'mrcabbage972',
12 |   url = 'https://github.com/mrcabbage972/simple-toolformer',
13 |   keywords = [
14 |     'deep learning',
15 |     'transformers',
16 |     'natural language processing'
17 |   ],
18 |   install_requires=[
19 |   ],
20 |   setup_requires=[
21 |     'pytest-runner',
22 |   ],
23 |   tests_require=[
24 |     'pytest'
25 |   ],
26 |   classifiers=[
27 |     'Development Status :: 4 - Beta',
28 |     'Intended Audience :: Developers',
29 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
30 |     'License :: OSI Approved :: MIT License',
31 |     'Programming Language :: Python :: 3.6',
32 |   ],
33 | )
34 | 
--------------------------------------------------------------------------------
/src/scripts/train_gsm8k.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import re
 3 | 
 4 | from datasets import load_dataset
 5 | 
 6 | from toolformer.toolformer import Toolformer
 7 | from toolformer.tools.calculator import CalculatorTool
 8 | 
 9 | logging.basicConfig(level=logging.INFO)
10 | 
11 | def main():
12 |     tf = Toolformer()
13 | 
14 |     dataset = load_dataset("gsm8k", 'main', split="train").select(range(5))
15 |     dataset = dataset.rename_column('question', 'input')
16 |     dataset = dataset.rename_column('answer', 'label')
17 |     # Strip the <<...>> calculator annotations and the trailing '#### <answer>' marker from the GSM8k labels
18 |     dataset = dataset.map(lambda x: {'input': x['input'],
19 |                                      'label': re.sub("(<<).*?(>>)", "", x['label']).split('####')[0]
20 |                                      .rstrip().replace('\n', ' ')})
21 |     apis = [CalculatorTool()]
22 | 
23 |     tf.fit(dataset, apis)
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     main()
--------------------------------------------------------------------------------
/src/toolformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrcabbage972/simple-toolformer/9986d1784daddc533536b724dce6938700d6c313/src/toolformer/__init__.py
--------------------------------------------------------------------------------
/src/toolformer/config.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import dataclass
 2 | 
 3 | import torch
 4 | 
 5 | 
 6 | @dataclass
 7 | class ToolformerConfig:
 8 |     # General (type annotations are required for @dataclass to treat these as fields)
 9 |     model_name: str = "EleutherAI/gpt-neo-125M"
10 |     causal_model: bool = True
11 |     target_device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
12 | 
13 |     # Sampling
14 |     sampler: str = 'basic'  # 'basic' or 'two_step'
15 | 
16 |     # Inference
17 |     max_new_tokens: int = 128
18 | 
19 |     # Training
20 |     mlm_prob: float = 0.15
21 |     max_length: int = 256
22 |     output_path: str = '..'
23 |     output_name: str = 'model'
24 |     learning_rate: float = 1e-4
25 |     train_batch_size: int = 16
26 |     eval_batch_size: int = 32
27 |     epochs: int = 1
28 |     weight_decay: float = 0.01
29 |     warmup_ratio: float = 0.1
30 |     fp16: bool = False
31 |     early_stopping_patience: int = 1
32 |     test_size: float = 0.2
33 | 
34 |     # Filtering
35 |     tool_call_thresh: float = 0.0
--------------------------------------------------------------------------------
/src/toolformer/sequence_scoring.py:
--------------------------------------------------------------------------------
 1 | from typing import List
 2 | 
 3 | import torch
 4 | import torch.nn.functional as F
 5 | 
 6 | def get_scores_for_labels(input: List[str], labels: List[str], model, tokenizer):
 7 |     """
 8 |     Calculates the conditional log-likelihood of the labels given the inputs, for an encoder-decoder model.
 9 |     :param input: a batch of input texts
10 |     :param labels: a batch of candidate label texts
11 |     :param model: an encoder-decoder model (e.g. from the T5 family)
12 |     :param tokenizer: the model's tokenizer
13 |     :return: a (batch_size, num_labels) tensor of log-likelihoods
14 |     """
15 |     # Taken from: https://colab.research.google.com/drive/1Q8VAwCPB12ZzYH79nAuiiSSnVDkZ2-u7?usp=sharing#scrollTo=WJshYeFQ_IeB
16 | 
17 |     batch_size, num_labels = len(input), len(labels)
18 |     # Get encodings
19 |     input_enc = tokenizer.batch_encode_plus(input, return_tensors="pt", add_special_tokens=True, truncation=True, padding="longest")
20 |     target_enc = tokenizer.batch_encode_plus(labels, return_tensors="pt", padding="longest").input_ids
21 | 
22 |     # Get encoder's last hidden state
23 |     encoder_hidden_states = model.encoder(**input_enc)[0]
24 | 
25 |     # Repeat the inputs `num_labels` times
26 |     encoder_hidden_states = encoder_hidden_states.unsqueeze(dim=1).repeat(1, num_labels, 1, 1).flatten(0, 1)
27 |     attention_mask = input_enc.attention_mask.unsqueeze(dim=1).repeat(1, num_labels, 1).flatten(0, 1)
28 | 
29 |     # Create the decoder inputs (shifted right, as the T5 model would do at predict time) -- makes it more efficient
30 |     decoder_input_ids = torch.cat([torch.zeros((num_labels * batch_size, 1), dtype=torch.int), target_enc[:, :-1].repeat(num_labels, 1)], dim=1)
31 |     decoder_attention_mask = torch.ones_like(decoder_input_ids, dtype=torch.float)
32 |     # Map pad tokens to -100 so that cross_entropy ignores them (assumes pad_token_id == 0, as in T5)
33 |     lm_target = target_enc - 100 * (target_enc == tokenizer.pad_token_id).long()
34 | 
35 |     model_output = model(
36 |         attention_mask=attention_mask,
37 |         encoder_outputs=[encoder_hidden_states],
38 |         decoder_input_ids=decoder_input_ids,
39 |         decoder_attention_mask=decoder_attention_mask,
40 |     )
41 | 
42 |     # Compute the log probabilities associated with each of the labels
43 |     labels_log_probs = F.cross_entropy(
44 |         model_output.logits.flatten(0, 1),
45 |         lm_target.repeat(num_labels, 1).flatten(0, 1),
46 |         reduction="none",
47 |     )
48 | 
49 |     # Sum log probs for each (input, label) pair
50 |     labels_scores = labels_log_probs.view(batch_size, num_labels, -1)
51 |     labels_scores = labels_scores.sum(dim=-1)
52 | 
53 |     # Note: the summed cross-entropy values are positive; multiply by -1 to obtain log-likelihoods
54 |     return labels_scores * -1
--------------------------------------------------------------------------------
/src/toolformer/tool.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | from abc import ABCMeta, abstractmethod
 3 | 
 4 | 
 5 | class Tool(metaclass=ABCMeta):
 6 |     """
 7 |     This is the base class for tools. It provides some convenience methods around extracting tool annotations.
8 | """ 9 | API_CALL_PREFIX = '[' 10 | API_CALL_SUFFIX = ']' 11 | RESULT_PREFIX = '->' 12 | 13 | def get_tool_regex(self, match_before=False) -> str: 14 | result = r'\[{}\(.*\)\]'.format(self.get_tool_name().upper()) 15 | if match_before: 16 | result = r'^.*' + result 17 | return result 18 | 19 | def text_has_call(self, text) -> bool: 20 | return re.match(self.get_tool_regex(), text) is not None 21 | 22 | def get_call_from_text(self, text) -> str: 23 | result = re.search('^.*(?P{})'.format(self.get_tool_regex()), text) 24 | return result.groupdict()['api_call'] 25 | 26 | def get_text_before_call(self, text) -> str: 27 | # TODO: refactor 28 | result = re.search('^.*(?P{})'.format(self.get_tool_regex()), text) 29 | return text[:result.span('api_call')[0]] 30 | 31 | def get_text_after_call(self, text) -> str: 32 | result = re.search('^.*(?P{})'.format(self.get_tool_regex()), text) 33 | return text[result.span('api_call')[1]:] 34 | 35 | @ abstractmethod 36 | def get_tool_name(self): 37 | pass 38 | 39 | @abstractmethod 40 | def get_prompt_template(self) -> str: 41 | pass 42 | 43 | @abstractmethod 44 | def run(self, input: str) -> str: 45 | pass 46 | -------------------------------------------------------------------------------- /src/toolformer/tool_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import Dataset, concatenate_datasets 3 | from torch.utils.data import DataLoader 4 | from transformers import DataCollatorWithPadding 5 | import torch.nn.functional as F 6 | from toolformer.tool import Tool 7 | 8 | def prepare_dataset_for_sampling(dataset, tokenizer, tool): 9 | empty_prompt = tool.get_prompt_template().format('') 10 | tok_empty_prompt = tokenizer(empty_prompt) 11 | 12 | prompts_dataset = dataset.map(lambda x: {'prompt': tool.get_prompt_template().format(x['label'])}) 13 | encoded_dataset = prompts_dataset.map(lambda x: tokenizer(x['prompt'], 14 | truncation=True, padding=True), batched=True) 15 | encoded_dataset = encoded_dataset.map(lambda x: {'pos_idx_mask': 16 | torch.cat([torch.zeros(len(tok_empty_prompt['input_ids'])), 17 | torch.ones(len(x['input_ids']) - len(tok_empty_prompt['input_ids']))], 18 | -1)}) 19 | encoded_dataset.set_format(columns=['input_ids', 'attention_mask'], type='torch') 20 | return encoded_dataset 21 | 22 | def postprocess_samples(all_preds, dataset, tool, is_causal_model): 23 | pred_ds = Dataset.from_dict({'text': all_preds, 24 | 'prompt': [tool.get_prompt_template().format(z['label']) for z in dataset]}) 25 | # prompt_end_idx = len(tool.get_prompt_template().replace('{}', '').rstrip()) 26 | if is_causal_model: 27 | return pred_ds.map(lambda x: {'text': x['text'][len(x['prompt']):]}) 28 | else: 29 | return pred_ds 30 | 31 | 32 | class BasicToolSampler: 33 | """ 34 | A basic tool sampler that just calls the generate method of them model 35 | 36 | """ 37 | def __init__(self, tokenizer, model, cfg): 38 | self.cfg = cfg 39 | self.model = model 40 | self.tokenizer = tokenizer 41 | 42 | def sample(self, dataset: Dataset, tool: Tool) -> Dataset: 43 | encoded_dataset = prepare_dataset_for_sampling(dataset, self.tokenizer, tool) 44 | data_loader = DataLoader(encoded_dataset, batch_size=32, 45 | collate_fn=DataCollatorWithPadding(self.tokenizer)) 46 | data_iter = iter(data_loader) 47 | 48 | all_preds = [] 49 | for inputs in data_iter: 50 | inputs = {k: v.to(self.cfg.target_device) for k, v in inputs.items()} 51 | with torch.no_grad(): 52 | batch_preds = self.model.generate(**inputs, 53 | 
max_new_tokens=self.cfg.max_new_tokens, 54 | return_dict_in_generate=True, 55 | output_scores=True) 56 | 57 | all_preds += [self.tokenizer.decode(x, skip_special_tokens=True) for x in batch_preds['sequences']] 58 | 59 | # This is a bit ugly due to iterating over the dataset manually 60 | return postprocess_samples(all_preds, dataset, tool, self.cfg.causal_model) 61 | 62 | 63 | class TwoStepToolSampler: 64 | """ 65 | WORK IN PROGRESS 66 | 67 | Implements the sampling procedure as detailed in the paper: 68 | First, sample K positions for the [ token. 69 | Then, sample M sequences out of each of the K. 70 | """ 71 | 72 | def __init__(self, tokenizer, model, cfg, top_k, num_seq_per_pos): 73 | self.cfg = cfg 74 | self.model = model 75 | self.tokenizer = tokenizer 76 | self.num_seq_per_pos = num_seq_per_pos 77 | self.top_k = top_k 78 | self.tool_call_token_id = tokenizer.convert_tokens_to_ids(Tool.API_CALL_PREFIX) 79 | self.tool_call_end_token_id = tokenizer.convert_tokens_to_ids(Tool.API_CALL_SUFFIX) 80 | 81 | def sample(self, dataset: Dataset, tool: Tool) -> Dataset: 82 | encoded_dataset = prepare_dataset_for_sampling(dataset, self.tokenizer, tool) 83 | 84 | topk_pos_idx = self.get_topk_pos_idx(encoded_dataset, tool) 85 | anns_at_pos = [self.get_anns_at_pos(dataset, encoded_dataset, topk_pos_idx[:, idx], tool) for idx in range(self.top_k)] 86 | return concatenate_datasets(anns_at_pos) 87 | 88 | def get_topk_pos_idx(self, encoded_dataset, tool): 89 | encoded_dataset.set_format(columns=['input_ids', 'attention_mask', 'pos_idx_mask'], type='torch') 90 | data_loader = DataLoader(encoded_dataset, batch_size=32, 91 | collate_fn=DataCollatorWithPadding(self.tokenizer)) 92 | data_iter = iter(data_loader) 93 | 94 | all_preds = [] 95 | for inputs in data_iter: 96 | inputs = {k: v.to(self.cfg.target_device) for k, v in inputs.items()} 97 | out = self.model(**{k:v for k,v in inputs.items() if k != 'pos_idx_mask'}) 98 | api_prob_at_idx = out.logits[:, :, self.tool_call_token_id] 99 | api_prob_at_idx[~inputs['attention_mask']] = -100 100 | api_prob_at_idx[~inputs['pos_idx_mask'].long()] = -100 101 | api_prob_topk_idx = api_prob_at_idx.topk(self.top_k).indices 102 | all_preds.append(api_prob_topk_idx.detach()) 103 | return torch.concat(all_preds, 0) 104 | 105 | def get_anns_at_pos(self, dataset, encoded_dataset, pos_idx, tool): 106 | # TODO: refactor to avoid having to pass two dataset objects 107 | # Get the text before the desired position and add the tool call token 108 | dataset_at_idx = encoded_dataset.add_column('pos_idx', pos_idx.numpy())\ 109 | .map(lambda x: {'input_ids': torch.cat([x['input_ids'][:x['pos_idx']], torch.tensor(x['pos_idx']).unsqueeze(-1)], -1), 110 | #'input_ids_suffix': x['input_ids'][x['pos_idx']:], 111 | 'attention_mask': torch.cat([x['attention_mask'][:x['pos_idx']], torch.tensor(1).unsqueeze(-1)], -1)}) 112 | suffixes_ds = dataset_at_idx.map(lambda x: {'suffix': self.tokenizer.decode(x['input_ids'][x['pos_idx']:], skip_special_tokens=True)}) 113 | suffixes_ds = suffixes_ds.remove_columns(list(dataset_at_idx.features.keys())) 114 | dataset_at_idx.set_format(columns=['input_ids', 'attention_mask'], type='torch') 115 | data_loader = DataLoader(dataset_at_idx, batch_size=32, 116 | collate_fn=DataCollatorWithPadding(self.tokenizer)) 117 | data_iter = iter(data_loader) 118 | 119 | all_preds = [] 120 | for inputs in data_iter: 121 | inputs = {k: v.to(self.cfg.target_device) for k, v in inputs.items()} 122 | batch_preds = self.model.generate(**inputs, 123 | 
--------------------------------------------------------------------------------
/src/toolformer/toolformer.py:
--------------------------------------------------------------------------------
  1 | import logging
  2 | import os
  3 | from typing import List
  4 | 
  5 | from datasets import Dataset, concatenate_datasets
  6 | from transformers import DataCollatorForLanguageModeling, EarlyStoppingCallback, T5ForConditionalGeneration, \
  7 |     AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
  8 | 
  9 | from toolformer.config import ToolformerConfig
 10 | from toolformer.sequence_scoring import get_scores_for_labels
 11 | from toolformer.tool import Tool
 12 | from toolformer.tool_sampling import BasicToolSampler, TwoStepToolSampler
 13 | 
 14 | logger = logging.getLogger(__name__)
 15 | 
 16 | class Toolformer:
 17 |     def __init__(self):
 18 |         self.cfg = ToolformerConfig()
 19 |         self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name, padding_side='left')
 20 |         self.tokenizer.pad_token = self.tokenizer.eos_token
 21 |         if self.cfg.causal_model:
 22 |             self.model = AutoModelForCausalLM.from_pretrained(self.cfg.model_name)
 23 |         else:  # Currently assumes that the only non-causal models are from the T5 family
 24 |             self.model = T5ForConditionalGeneration.from_pretrained(self.cfg.model_name)
 25 | 
 26 |         if self.cfg.sampler == 'basic':
 27 |             self.tool_sampler = BasicToolSampler(self.tokenizer, self.model, self.cfg)
 28 |         elif self.cfg.sampler == 'two_step':
 29 |             # TODO: remove hard-coded args
 30 |             self.tool_sampler = TwoStepToolSampler(self.tokenizer, self.model, self.cfg, 4, 2)
 31 |         else:
 32 |             raise ValueError('Unknown sampler: {}'.format(self.cfg.sampler))
 33 | 
 34 |     def fit(self, dataset: Dataset, tools: List[Tool]):
 35 |         """
 36 |         This is the main method implementing the training process described in the Toolformer paper.
 37 |         It contains three stages:
 38 |         1. Use the model to generate samples of tool usage, using few-shot prompting.
 39 |         2. Filter the samples by a measure of how much they improve the likelihood of the text after the tool call.
 40 |         3. Fit the model on the filtered samples.
 41 |         :param dataset:
 42 |             The dataset to fit on.
 43 |         :param tools:
 44 |             A list of tools to be considered.
 45 |         """
 46 |         logger.info('Fitting with a dataset of size {} and {} tools'.format(len(dataset), len(tools)))
 47 |         samples_for_tuning = []
 48 |         for tool in tools:
 49 |             maybe_tool_samples = self.sample_dataset(dataset, tool)
 50 |             logger.info('Examples of {} tool generation results: {}'.format(tool.get_tool_name(), ','.join(maybe_tool_samples[:2]['text'])))
 51 |             tool_samples = maybe_tool_samples.filter(lambda x: tool.text_has_call(x['text']))
 52 |             logger.info('{} samples left after filtering for tool name'.format(len(tool_samples)))
 53 |             if len(tool_samples) > 0:
 54 |                 logger.info('Examples of {} tool filtered annotations: {}'.format(tool.get_tool_name(), ','.join(tool_samples[:2]['text'])))
 55 |                 executed_tool_samples = tool_samples.map(lambda x: self.execute_tool_call(x, tool))
 56 |                 likely_samples = self.filter_likelihood(executed_tool_samples, tool)
 57 |                 logger.info('{} samples left after filtering by likelihood'.format(len(likely_samples)))
 58 |                 samples_for_tuning.append(likely_samples)
 59 |         if len(samples_for_tuning) > 0:
 60 |             dataset_for_tuning = concatenate_datasets(samples_for_tuning)
 61 |         else:
 62 |             dataset_for_tuning = []
 63 |         if len(dataset_for_tuning) == 0:
 64 |             raise ValueError("Can't proceed: There is no data to fine-tune on!")
 65 |         self.fine_tune(dataset_for_tuning)
 66 | 
 67 |     def sample_dataset(self, dataset: Dataset, tool: Tool) -> Dataset:
 68 |         """
 69 |         This method samples a dataset to produce example API calls.
 70 |         The sampling procedure is implemented as straightforward generation.
 71 |         The method in the paper is to first find the top K positions for the next token being [ and then
 72 |         to try generating M calls starting from each of these K positions (see TwoStepToolSampler).
 73 | 
 74 |         :param dataset:
 75 |             The input texts
 76 |         :param tool:
 77 |             The tool to annotate the input texts with
 78 |         :return:
 79 |             A Dataset containing a text field and a score field.
 80 |         """
 81 |         logger.info('Sampling dataset')
 82 |         return self.tool_sampler.sample(dataset, tool)
 83 | 
 84 | 
 85 |     def filter_likelihood(self, inputs: Dataset, tool: Tool) -> Dataset:
 86 |         """
 87 |         Filters the sampled tool uses by a criterion that quantifies how much they improve the likelihood
 88 |         of the text after the tool call. The paper uses a weighting scheme which is currently not implemented here.
 89 |         Another thing to note is that in this stage the tool annotation is prepended to the input text rather than
 90 |         inserted at its correct place.
 91 |         The criterion can be roughly described as:
 92 |         loss(with tool call and result prefix) < min(loss(with tool call prefix), loss(no tool call))
 93 | 
 94 |         :param inputs:
 95 |             A dataset with tool call annotations.
 96 |         :param tool:
 97 |             The tool.
 98 |         :return:
 99 |             Same as inputs but filtered by the criterion.
100 |         """
101 |         logger.info('Filtering generated samples by their likelihood')
102 | 
103 |         if self.cfg.causal_model:
104 |             raise NotImplementedError
105 | 
106 |         inputs = inputs.map(lambda x: {**x,
107 |                                        'text_before': tool.get_text_before_call(x['text']),
108 |                                        'tool_call': tool.get_call_from_text(x['text']),
109 |                                        'text_after': tool.get_text_after_call(x['text'])})
110 | 
111 |         inputs = inputs.map(lambda x: {**x,
112 |                                        'tool_call_text_before': x['tool_call'] + x['text_before'],
113 |                                        'tool_call_result_text_before': x['tool_call'] + x['tool_result'] + x[
114 |                                            'text_before'],
115 |                                        })
116 | 
117 |         # get_scores_for_labels returns log-likelihoods, so negate them to obtain losses;
118 |         # .diagonal() pairs the i-th input with the i-th label within each batch
119 |         inputs = inputs.map(lambda x: {**x,
120 |                                        'loss_no_tool':
121 |                                            (-get_scores_for_labels(x['text_before'], x['text_after'], self.model,
122 |                                                                    self.tokenizer).diagonal()).tolist(),
123 |                                        'loss_tool':
124 |                                            (-get_scores_for_labels(x['tool_call_result_text_before'], x['text_after'],
125 |                                                                    self.model, self.tokenizer).diagonal()).tolist(),
126 |                                        'loss_tool_no_result':
127 |                                            (-get_scores_for_labels(x['tool_call_text_before'], x['text_after'],
128 |                                                                    self.model, self.tokenizer).diagonal()).tolist()
129 |                                        }, batched=True)
130 | 
131 |         return inputs.filter(
132 |             lambda x: min(x['loss_no_tool'], x['loss_tool_no_result']) - x['loss_tool'] >= self.cfg.tool_call_thresh)
133 | 
134 |     def fine_tune(self, api_call_samples: Dataset):
135 |         """
136 |         This is just standard HF fine-tuning with the language modeling objective.
137 |         See e.g. https://huggingface.co/docs/transformers/tasks/language_modeling
138 |         :param api_call_samples:
139 |             The filtered tool call samples to fine-tune on.
140 |         """
141 |         logger.info('Fine-tuning the model on {} API call samples'.format(len(api_call_samples)))
142 |         datasets = api_call_samples.train_test_split(test_size=self.cfg.test_size)
143 | 
144 |         train_args = TrainingArguments(
145 |             output_dir=os.path.join(self.cfg.output_path, self.cfg.output_name),
146 |             evaluation_strategy="epoch",
147 |             save_strategy="epoch",
148 |             logging_strategy="epoch",
149 |             learning_rate=self.cfg.learning_rate,
150 |             per_device_train_batch_size=self.cfg.train_batch_size,
151 |             per_device_eval_batch_size=self.cfg.eval_batch_size,
152 |             num_train_epochs=self.cfg.epochs,
153 |             weight_decay=self.cfg.weight_decay,
154 |             warmup_ratio=self.cfg.warmup_ratio,
155 |             load_best_model_at_end=True,
156 |             save_total_limit=1,
157 |             fp16=self.cfg.fp16,
158 |         )
159 | 
160 |         if self.cfg.causal_model:
161 |             data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
162 |         else:
163 |             data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=self.cfg.mlm_prob)
164 | 
165 |         trainer = Trainer(
166 |             self.model,
167 |             train_args,
168 |             train_dataset=datasets["train"],
169 |             eval_dataset=datasets["test"],
170 |             tokenizer=self.tokenizer,
171 |             compute_metrics=None,
172 |             data_collator=data_collator,
173 |             callbacks=[EarlyStoppingCallback(early_stopping_patience=self.cfg.early_stopping_patience)]
174 |         )
175 | 
176 |         trainer.train()
177 | 
178 |     def execute_tool_call(self, sample, tool: Tool) -> dict:
179 |         sample['tool_result'] = tool.run(sample['text'])
180 |         return sample
181 | 
--------------------------------------------------------------------------------
/src/toolformer/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrcabbage972/simple-toolformer/9986d1784daddc533536b724dce6938700d6c313/src/toolformer/tools/__init__.py
--------------------------------------------------------------------------------
/src/toolformer/tools/calculator.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | 
 3 | from toolformer.tool import Tool
 4 | 
 5 | PROMPT_TEMPLATE = \
 6 | """
 7 | Your task is to add calls to a Calculator API to a piece of text.
 8 | The calls should help you get information required to complete the text.
 9 | You can call the API by writing "[Calculator(expression)]" where "expression" is the expression to be computed.
10 | Here are some examples of API calls:
11 | 
12 | Input: The number in the next term is 18 + 12 x 3 = 54.
13 | Output: The number in the next term is 18 + 12 x 3 = [Calculator(18 + 12 * 3)] 54.
14 | 
15 | Input: The population is 658,893 people. This is 11.4% of the national average of 5,763,868 people.
16 | Output: The population is 658,893 people. This is 11.4% of the national average of [Calculator(658,893 / 11.4%)] 5,763,868 people.
17 | 
18 | Input: A total of 252 qualifying matches were played, and 723 goals were scored (an average of 2.87 per match). This is three times less than the 2169 goals last year.
19 | Output: A total of 252 qualifying matches were played, and 723 goals were scored (an average of [Calculator(723 / 252)] 2.87 per match). This is three times less than the [Calculator(723 * 3)] 2169 goals last year.
20 | 
21 | Input: I went to Paris in 1994 and stayed there until 2011, so in total, it was 17 years.
22 | Output: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011 - 1994)] 17 years.
23 | 
24 | Input: From this, we have 4 * 30 minutes = 120 minutes.
25 | Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.
26 | 
27 | Input: {}
28 | Output: """
29 | 
30 | 
31 | class CalculatorTool(Tool):
32 |     def get_tool_name(self):
33 |         return 'Calculator'
34 | 
35 |     def get_prompt_template(self) -> str:
36 |         return PROMPT_TEMPLATE
37 | 
38 |     def run(self, input: str) -> str:
39 |         # TODO: the following parsing code should be a method in the Tool class.
40 |         # Extract the expression inside the parentheses of the tool call, e.g. "723 / 252"
41 |         call = self.get_call_from_text(input)
42 |         expression = re.search(r'\((?P<expr>.*)\)', call).group('expr')
43 |         # Only a single binary operation between two plain numbers is supported
44 |         for operator in ['+', '-', '*', '/']:
45 |             if operator in expression:
46 |                 operands = [float(x.strip()) for x in expression.split(operator)]
47 |                 if operator == '+':
48 |                     result = operands[0] + operands[1]
49 |                 elif operator == '-':
50 |                     result = operands[0] - operands[1]
51 |                 elif operator == '*':
52 |                     result = operands[0] * operands[1]
53 |                 else:
54 |                     result = operands[0] / operands[1]
55 |                 return "{:.4g}".format(result)
56 |         raise ValueError('Tool call not found!')
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     print(CalculatorTool().run('asdadsad [Calculator(723 / 252000)] asd'))
--------------------------------------------------------------------------------
/src/toolformer/tools/date.py:
--------------------------------------------------------------------------------
 1 | from datetime import datetime
 2 | 
 3 | from toolformer.tool import Tool
 4 | 
 5 | 
 6 | class DateTool(Tool):
 7 |     def get_tool_name(self):
 8 |         return "DATE"
 9 | 
10 |     def get_prompt_template(self) -> str:
11 |         return """
12 | Your task is to add calls to a Date API to a piece of text.
13 | The calls should help you get information required to complete the text.
14 | Here is an example of an API call: 15 | 16 | Input: Joe Biden was born 80 years ago 17 | Output: Joe Biden was born [DATE()] 80 years ago 18 | 19 | Input: {} 20 | Output: 21 | """ 22 | 23 | def run(self, input: str) -> str: 24 | return datetime.today().strftime('%Y-%m-%d') 25 | 26 | 27 | if __name__ == '__main__': 28 | print(DateTool().text_has_call('aaa [DATE()] bbb')) 29 | print(DateTool().get_text_before_call('aaa [DATE()] bbb')) 30 | print(DateTool().get_call_from_text('aaa [DATE()] bbb')) 31 | print(DateTool().get_text_after_call('aaa [DATE()] bbb')) --------------------------------------------------------------------------------
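
The `Tool` base class above is all a new tool needs to implement: `get_tool_name`, `get_prompt_template` and `run`. The sketch below shows a minimal, hypothetical `EchoTool` (not part of the repository) illustrating how a custom tool could be wired in and passed to `Toolformer.fit` alongside the built-in calculator and date tools:

```python
from toolformer.tool import Tool


class EchoTool(Tool):
    """A toy tool that returns the argument of the call, e.g. [ECHO(hello)] -> hello."""

    def get_tool_name(self):
        return 'Echo'

    def get_prompt_template(self) -> str:
        return """
Your task is to add calls to an Echo API to a piece of text.
You can call the API by writing "[ECHO(text)]" where "text" is the text to be echoed.

Input: {}
Output: """

    def run(self, input: str) -> str:
        # get_call_from_text returns e.g. "[ECHO(hello)]"; return what is inside the parentheses
        call = self.get_call_from_text(input)
        return call[call.index('(') + 1:call.rindex(')')]
```

Usage would then be `Toolformer().fit(dataset, [EchoTool()])`, mirroring `src/scripts/train_gsm8k.py`.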