├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── setup.py ├── toolformer.png ├── toolformer_pytorch ├── __init__.py ├── optimizer.py ├── palm.py ├── prompts.py ├── toolformer_pytorch.py └── tools.py └── tools-requirements.txt /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Toolformer - Pytorch (wip) 4 | 5 | Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI 6 | 7 | ## Appreciation 8 | 9 | - Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research 10 | 11 | - Enrico for getting the ball rolling with the initial commit of different tools! 12 | 13 | - Thanks goes out to ChatGPT for doing all the regular expressions in this repository for parsing the functions and parameters for the API calls. I am terrible at regular expressions, so this was an enormous help from the AI (with no hitches, it was perfect). 14 | 15 | ## Install 16 | 17 | ```bash 18 | $ pip install toolformer-pytorch 19 | ``` 20 | 21 | ## Usage 22 | 23 | Example usage, giving a language model awareness of the current date and time. 24 | 25 | ```python 26 | import torch 27 | from toolformer_pytorch import Toolformer, PaLM 28 | 29 | # simple calendar api call - function that returns a string 30 | 31 | def Calendar(): 32 | import datetime 33 | from calendar import day_name, month_name 34 | now = datetime.datetime.now() 35 | return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.' 36 | 37 | # prompt for teaching it to use the Calendar function from above 38 | 39 | prompt = f""" 40 | Your task is to add calls to a Calendar API to a piece of text.
41 | The API calls should help you get information required to complete the text. 42 | You can call the API by writing "[Calendar()]" 43 | Here are some examples of API calls: 44 | Input: Today is the first Friday of the year. 45 | Output: Today is the first [Calendar()] Friday of the year. 46 | Input: The president of the United States is Joe Biden. 47 | Output: The president of the United States is [Calendar()] Joe Biden. 48 | Input: [input] 49 | Output: 50 | """ 51 | 52 | data = [ 53 | "The store is never open on the weekend, so today it is closed.", 54 | "The number of days from now until Christmas is 30", 55 | "The current day of the week is Wednesday." 56 | ] 57 | 58 | # model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine 59 | 60 | model = PaLM( 61 | dim = 512, 62 | depth = 2, 63 | heads = 8, 64 | dim_head = 64 65 | ).cuda() 66 | 67 | # toolformer 68 | 69 | toolformer = Toolformer( 70 | model = model, 71 | model_seq_len = 256, 72 | teach_tool_prompt = prompt, 73 | tool_id = 'Calendar', 74 | tool = Calendar, 75 | finetune = True 76 | ) 77 | 78 | # invoking this will 79 | # (1) prompt the model with your inputs (data), inserted into [input] tag 80 | # (2) with the sampled outputs, filter out the ones that made proper API calls 81 | # (3) execute the API calls with the `tool` given 82 | # (4) filter with the specialized filter function (which can be used independently as shown in the next section) 83 | # (5) fine-tune on the filtered results 84 | 85 | filtered_stats = toolformer(data) 86 | 87 | # then, once you see the 'finetune complete' message 88 | 89 | response = toolformer.sample_model_with_api_calls("How many days until the next New Year's?") 90 | 91 | # hopefully you see it invoke the calendar and utilize the response of the api call... 92 | 93 | ``` 94 | 95 | The main novelty of the paper is defining a fitness score for the outputs from a transformer instructed to insert API calls. The score is used to filter the sampled outputs for finetuning the transformer to make API calls that decrease perplexity of the text that follows it.
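Concretely, as implemented here: let L+ be the weighted cross-entropy loss over the tokens that follow the API call when both the call and its response are included, and let L- be the minimum of the losses with the call alone and with no call at all. A sampled API call is kept for finetuning only when L- − L+ ≥ `filter_threshold` (the `loss_minus - loss_plus >= filter_threshold` check inside `filter_tokens_with_api_response`, demonstrated below on mock token ids).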
96 | 97 | ```python 98 | import torch 99 | 100 | from toolformer_pytorch import ( 101 | Toolformer, 102 | PaLM, 103 | filter_tokens_with_api_response 104 | ) 105 | 106 | # model 107 | 108 | palm = PaLM( 109 | dim = 512, 110 | num_tokens = 20000, 111 | depth = 2, 112 | heads = 8, 113 | dim_head = 64 114 | ).cuda() 115 | 116 | # mock some tokens 117 | 118 | mock_start_pos = 512 119 | mock_api_call_length = 10 120 | mock_api_start_id = 19998 121 | mock_api_stop_id = 19999 122 | 123 | tokens = torch.randint(0, 20000, (10, 1024)).cuda() 124 | tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda() 125 | tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda() 126 | 127 | tokens_with_api_response[:, mock_start_pos] = mock_api_start_id 128 | tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id 129 | 130 | tokens_without_api_response[:, mock_start_pos] = mock_api_start_id 131 | tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id 132 | 133 | # filter 134 | 135 | filtered_results = filter_tokens_with_api_response( 136 | model = palm, 137 | tokens = tokens, 138 | tokens_with_api_response = tokens_with_api_response, 139 | tokens_without_api_response = tokens_without_api_response, 140 | filter_threshold = 1., 141 | api_start_token_id = mock_api_start_id, 142 | api_end_token_id = mock_api_stop_id 143 | ) 144 | ``` 145 | 146 | To invoke the tools on a string generated by the language model, use `invoke_tools`. 147 | 148 | ```python 149 | from toolformer_pytorch import invoke_tools 150 | 151 | def inc(i): 152 | return i + 1 153 | 154 | def dec(i): 155 | return i - 1 156 | 157 | function_registry = dict( 158 | inc = inc, 159 | dec = dec 160 | ) 161 | 162 | text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]' 163 | 164 | invoke_tools(function_registry, text) 165 | 166 | # make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)] 167 | ``` 168 | 169 | ## Todo 170 | 171 | - [x] create custom generate function for palm that can do external API calls 172 | - [x] allow for generating tokens at different cursor indices 173 | - [x] api token (which was left and right brackets in paper) needs to be customizable 174 | - [ ] allow for customizing how to handle errors in function name, parameters, or execution and output 175 | - [ ] Toolformer should eventually calculate all statistics (how many properly sampled, filtered out by different criteria, the distribution of scores, as well as how many were rejected) before the final fine-tuning 176 | - [ ] do end-to-end training in `Toolformer` 177 | - [x] doing the prompting and bootstrapping the data 178 | - [x] prefiltering of bootstrapped data followed by api calls and then another round of filtering 179 | - [ ] keep track of all stats 180 | - [x] take care of fine-tuning 181 | - [ ] interleaving of datasets + optimizer hyperparams 182 | - [ ] hook up gpt-j 183 | - [ ] test for a simple calculator eval dataset 184 | - [ ] add a default callback within the Toolformer that automatically aligns the text and checks for validity before the filtering step - if the text was not copied correctly, the filtering step is not valid.
185 | - [ ] make sure final model, trained on many `Toolformer` instances, can be invoked with multiple tools - start with batch size of 1 and work way up 186 | 187 | ## Citations 188 | 189 | ```bibtex 190 | @inproceedings{Schick2023ToolformerLM, 191 | title = {Toolformer: Language Models Can Teach Themselves to Use Tools}, 192 | author = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom}, 193 | year = {2023} 194 | } 195 | ``` 196 | 197 | ```bibtex 198 | @article{Gao2022PALPL, 199 | title = {PAL: Program-aided Language Models}, 200 | author = {Luyu Gao and Aman Madaan and Shuyan Zhou and Uri Alon and Pengfei Liu and Yiming Yang and Jamie Callan and Graham Neubig}, 201 | journal = {ArXiv}, 202 | year = {2022}, 203 | volume = {abs/2211.10435} 204 | } 205 | ``` 206 | 207 | *Reality is that which, when you stop believing it, doesn't go away.* – Philip K. Dick. 208 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'toolformer-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.0.30', 7 | license='MIT', 8 | description = 'Toolformer - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/toolformer-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'automated-tool-use' 19 | ], 20 | install_requires=[ 21 | 'beartype', 22 | 'einops>=0.4', 23 | 'torch>=1.6', 24 | 'tqdm', 25 | 'x-clip>=0.14.3' 26 | ], 27 | classifiers=[ 28 | 'Development Status :: 4 - Beta', 29 | 'Intended Audience :: Developers', 30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3.6', 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /toolformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/toolformer-pytorch/27e633217f11bb56a277436b584f2347442869c9/toolformer.png -------------------------------------------------------------------------------- /toolformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from toolformer_pytorch.palm import PaLM 2 | 3 | from toolformer_pytorch.toolformer_pytorch import ( 4 | Toolformer, 5 | filter_tokens_with_api_response, 6 | sample, 7 | sample_with_api_call, 8 | has_api_calls, 9 | invoke_tools, 10 | replace_all_but_first 11 | ) 12 | -------------------------------------------------------------------------------- /toolformer_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW, Adam 2 | 3 | def separate_weight_decayable_params(params): 4 | wd_params, no_wd_params = [], [] 5 | for param in params: 6 | param_list = no_wd_params if param.ndim < 2 else wd_params 7 | param_list.append(param) 8 | return wd_params, no_wd_params 9 | 10 | def get_optimizer( 11 | params, 12 | lr = 1e-4, 13 | wd = 1e-2, 14 | betas = (0.9, 0.99), 15 | eps = 1e-8, 16 | filter_by_requires_grad = False, 17 | group_wd_params = True, 18 | **kwargs 
19 | ): 20 | has_weight_decay = wd > 0 21 | 22 | if filter_by_requires_grad: 23 | params = list(filter(lambda t: t.requires_grad, params)) 24 | 25 | if group_wd_params and has_weight_decay: 26 | wd_params, no_wd_params = separate_weight_decayable_params(params) 27 | 28 | params = [ 29 | {'params': wd_params}, 30 | {'params': no_wd_params, 'weight_decay': 0}, 31 | ] 32 | 33 | adam_kwargs = dict(lr = lr, betas = betas, eps = eps) 34 | 35 | if not has_weight_decay: 36 | return Adam(params, **adam_kwargs) 37 | 38 | return AdamW(params, weight_decay = wd, **adam_kwargs) 39 | -------------------------------------------------------------------------------- /toolformer_pytorch/palm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | from x_clip.tokenizer import tokenizer 6 | 7 | # helpers 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | 13 | # normalization 14 | 15 | 16 | class RMSNorm(nn.Module): 17 | def __init__(self, dim, eps = 1e-8): 18 | super().__init__() 19 | self.scale = dim ** -0.5 20 | self.eps = eps 21 | self.g = nn.Parameter(torch.ones(dim)) 22 | 23 | def forward(self, x): 24 | norm = torch.norm(x, dim = -1, keepdim = True) * self.scale 25 | return x / norm.clamp(min = self.eps) * self.g 26 | 27 | 28 | # rotary positional embedding 29 | # https://arxiv.org/abs/2104.09864 30 | 31 | 32 | class RotaryEmbedding(nn.Module): 33 | def __init__(self, dim): 34 | super().__init__() 35 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 36 | self.register_buffer("inv_freq", inv_freq) 37 | 38 | def forward(self, max_seq_len, *, device): 39 | seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) 40 | freqs = einsum("i , j -> i j", seq, self.inv_freq) 41 | return torch.cat((freqs, freqs), dim=-1) 42 | 43 | 44 | def rotate_half(x): 45 | x = rearrange(x, "... (j d) -> ... 
j d", j=2) 46 | x1, x2 = x.unbind(dim=-2) 47 | return torch.cat((-x2, x1), dim=-1) 48 | 49 | 50 | def apply_rotary_pos_emb(pos, t): 51 | return (t * pos.cos()) + (rotate_half(t) * pos.sin()) 52 | 53 | 54 | # all we need 55 | 56 | 57 | class ParallelTransformerBlock(nn.Module): 58 | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): 59 | super().__init__() 60 | self.norm = RMSNorm(dim) 61 | 62 | attn_inner_dim = dim_head * heads 63 | ff_inner_dim = dim * ff_mult 64 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim)) 65 | 66 | self.heads = heads 67 | self.scale = dim_head**-0.5 68 | self.rotary_emb = RotaryEmbedding(dim_head) 69 | 70 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) 71 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) 72 | 73 | self.ff_out = nn.Sequential( 74 | nn.GELU(), 75 | nn.Linear(ff_inner_dim, dim, bias=False) 76 | ) 77 | 78 | # for caching causal mask and rotary embeddings 79 | 80 | self.register_buffer("mask", None, persistent=False) 81 | self.register_buffer("pos_emb", None, persistent=False) 82 | 83 | def get_mask(self, n, device): 84 | if self.mask is not None and self.mask.shape[-1] >= n: 85 | return self.mask[:n, :n] 86 | 87 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) 88 | self.register_buffer("mask", mask, persistent=False) 89 | return mask 90 | 91 | def get_rotary_embedding(self, n, device): 92 | if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: 93 | return self.pos_emb[:n] 94 | 95 | pos_emb = self.rotary_emb(n, device=device) 96 | self.register_buffer("pos_emb", pos_emb, persistent=False) 97 | return pos_emb 98 | 99 | def forward(self, x): 100 | """ 101 | einstein notation 102 | b - batch 103 | h - heads 104 | n, i, j - sequence length (base sequence length, source, target) 105 | d - feature dimension 106 | """ 107 | 108 | n, device, h = x.shape[1], x.device, self.heads 109 | 110 | # pre layernorm 111 | 112 | x = self.norm(x) 113 | 114 | # attention queries, keys, values, and feedforward inner 115 | 116 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) 117 | 118 | # split heads 119 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper 120 | # they found no performance loss past a certain scale, and more efficient decoding obviously 121 | # https://arxiv.org/abs/1911.02150 122 | 123 | q = rearrange(q, "b n (h d) -> b h n d", h=h) 124 | 125 | # rotary embeddings 126 | 127 | positions = self.get_rotary_embedding(n, device) 128 | q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) 129 | 130 | # scale 131 | 132 | q = q * self.scale 133 | 134 | # similarity 135 | 136 | sim = einsum("b h i d, b j d -> b h i j", q, k) 137 | 138 | # causal mask 139 | 140 | causal_mask = self.get_mask(n, device) 141 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 142 | 143 | # attention 144 | 145 | attn = sim.softmax(dim=-1) 146 | 147 | # aggregate values 148 | 149 | out = einsum("b h i j, b j d -> b h i d", attn, v) 150 | 151 | # merge heads 152 | 153 | out = rearrange(out, "b h n d -> b n (h d)") 154 | return self.attn_out(out) + self.ff_out(ff) 155 | 156 | 157 | # Transformer 158 | 159 | 160 | class Transformer(nn.Module): 161 | def __init__( 162 | self, 163 | dim, 164 | depth, 165 | heads, 166 | dim_head, 167 | ff_mult = 4, 168 | ): 169 | super().__init__() 170 | self.layers = nn.ModuleList([]) 171 | 172 | for _ in range(depth): 173 | self.layers.append( 174 | ParallelTransformerBlock(dim, 
dim_head, heads, ff_mult), 175 | ) 176 | 177 | def forward(self, x): 178 | for block in self.layers: 179 | x = block(x) + x 180 | return x 181 | 182 | 183 | # classes 184 | 185 | class PaLM(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | depth, 190 | num_tokens=tokenizer.vocab_size, 191 | dim_head=64, 192 | heads=8, 193 | ff_mult=4, 194 | ): 195 | super().__init__() 196 | self.emb = nn.Embedding(num_tokens, dim) 197 | 198 | self.transformer = Transformer(dim, depth, heads, dim_head, ff_mult) 199 | 200 | self.to_logits = nn.Sequential( 201 | RMSNorm(dim), 202 | nn.Linear(dim, num_tokens) 203 | ) 204 | 205 | def forward(self, x): 206 | x = self.emb(x) 207 | x = self.transformer(x) 208 | return self.to_logits(x) 209 | 210 | if __name__ == "__main__": 211 | palm = PaLM( 212 | num_tokens = 20000, 213 | dim = 512, 214 | depth = 6, 215 | dim_head = 64, 216 | heads = 8, 217 | ff_mult = 4, 218 | ) 219 | 220 | tokens = torch.randint(0, 20000, (1, 512)) 221 | logits = palm(tokens) 222 | print(logits.shape) 223 | -------------------------------------------------------------------------------- /toolformer_pytorch/prompts.py: -------------------------------------------------------------------------------- 1 | DEFAULT_PROMPT_INPUT_TAG = '[input]' 2 | 3 | calculator_prompt = f""" 4 | Your task is to add calls to a Calculator API to a piece of text. 5 | The calls should help you get information required to complete the text. 6 | You can call the API by writing "[Calculator(expression)]" where "expression" is the expression to be computed. 7 | Here are some examples of API calls: 8 | Input: The number in the next term is 18 + 12 x 3 = 54. 9 | Output: The number in the next term is 18 + 12 x 3 = [Calculator(18 + 12 * 3)] 54. 10 | Input: The population is 658,893 people. This is 11.4% of the national average of 5,763,868 people. 11 | Output: The population is 658,893 people. This is 11.4% of the national average of [Calculator(658,893 / 11.4%)] 5,763,868 people. 12 | Input: A total of 252 qualifying matches were played, and 723 goals were scored (an average of 2.87 per match). This is twenty goals more than the 703 goals last year. 13 | Output: A total of 252 qualifying matches were played, and 723 goals were scored (an average of [Calculator(723 / 252)] 2.87 per match). This is twenty goals more than the [Calculator(723 - 20)] 703 goals last year. 14 | Input: I went to Paris in 1994 and stayed there until 2011, so in total, it was 17 years. 15 | Output: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011 - 1994)] 17 years. 16 | Input: From this, we have 4 * 30 minutes = 120 minutes. 17 | Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes. 18 | Input: {DEFAULT_PROMPT_INPUT_TAG} 19 | Output: 20 | """ 21 | 22 | wikipedia_search_prompt = f""" 23 | Your task is to complete a given piece of text. 24 | You can use a Wikipedia Search API to look up information. 25 | You can do so by writing "[WikiSearch(term)]" where "term" is the search term you want to look up. 26 | Here are some examples of API calls: 27 | Input: The colors on the flag of Ghana have the following meanings: red is for the blood of martyrs, green for forests, and gold for mineral wealth. 28 | Output: The colors on the flag of Ghana have the following meanings: red is for [WikiSearch("Ghana flag red meaning")] the blood of martyrs, green for forests, and gold for mineral wealth. 29 | Input: But what are the risks during production of nanomaterials?
Some nanomaterials may give rise to various kinds of lung damage. 30 | Output: But what are the risks during production of nanomaterials? [WikiSearch("nanomaterial production risks")] Some nanomaterials may give rise to various kinds of lung damage. 31 | Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity. 32 | Output: Metformin is the first-line drug for [WikiSearch("Metformin first-line drug")] patients with type 2 diabetes and obesity. 33 | Input: {DEFAULT_PROMPT_INPUT_TAG} 34 | Output: 35 | """ 36 | 37 | machine_translation_prompt = f""" 38 | Your task is to complete a given piece of text by using a Machine Translation API. 39 | You can do so by writing "[MT(text)]" where text is the text to be translated into English. 40 | Here are some examples: 41 | Input: He has published one book: O homem suprimido (“The Suppressed Man”) 42 | Output: He has published one book: O homem suprimido [MT(O homem suprimido)] (“The Suppressed Man”) 43 | Input: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann, there is a description of a Jewish writer 44 | Output: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann [MT(der klassische jüdische Mann)], there is a description of a Jewish writer 45 | Input: 南 京 高 淳 县 住 房 和 城 乡 建 设 局 城 市 新 区 设 计 a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing 46 | Output: [MT(南京高淳县住房和城乡建设局 城市新 区 设 计)] a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing 47 | Input: {DEFAULT_PROMPT_INPUT_TAG} 48 | Output: 49 | """ 50 | 51 | calendar_prompt = f""" 52 | Your task is to add calls to a Calendar API to a piece of text. 53 | The API calls should help you get information required to complete the text. 54 | You can call the API by writing "[Calendar()]" 55 | Here are some examples of API calls: 56 | Input: Today is the first Friday of the year. 57 | Output: Today is the first [Calendar()] Friday of the year. 58 | Input: The president of the United States is Joe Biden. 59 | Output: The president of the United States is [Calendar()] Joe Biden. 60 | Input: The current day of the week is Wednesday. 61 | Output: The current day of the week is [Calendar()] Wednesday. 62 | Input: The number of days from now until Christmas is 30. 63 | Output: The number of days from now until Christmas is [Calendar()] 30. 64 | Input: The store is never open on the weekend, so today it is closed. 65 | Output: The store is never open on the weekend, so today [Calendar()] it is closed.
66 | Input: {DEFAULT_PROMPT_INPUT_TAG} 67 | Output: 68 | """ 69 | -------------------------------------------------------------------------------- /toolformer_pytorch/toolformer_pytorch.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | from functools import partial, wraps 4 | from collections import namedtuple 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from torch.utils.data import Dataset, DataLoader 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | from einops import rearrange, reduce 14 | 15 | from toolformer_pytorch.palm import PaLM 16 | from toolformer_pytorch.optimizer import get_optimizer 17 | from toolformer_pytorch.prompts import DEFAULT_PROMPT_INPUT_TAG 18 | 19 | from beartype import beartype 20 | from beartype.typing import Callable, Optional, Union, List, Tuple 21 | 22 | from tqdm import tqdm 23 | from x_clip.tokenizer import tokenizer 24 | 25 | pad_sequence = partial(pad_sequence, batch_first = True) 26 | 27 | # helpers 28 | 29 | def exists(val): 30 | return val is not None 31 | 32 | def default(val, d): 33 | return val if exists(val) else d 34 | 35 | def identity(t): 36 | return t 37 | 38 | def always(val): 39 | def inner(*args, **kwargs): 40 | return val 41 | return inner 42 | 43 | def try_except(fn, callback = identity): 44 | @wraps(fn) 45 | def inner(*args): 46 | try: 47 | return fn(*args) 48 | except Exception as e: 49 | return callback(e) 50 | return inner 51 | 52 | # tensor helpers 53 | 54 | def log(t, eps = 1e-20): 55 | return t.clamp(min = eps).log() 56 | 57 | def gumbel_noise(t): 58 | noise = torch.zeros_like(t).uniform_(0, 1) 59 | return -log(-log(noise)) 60 | 61 | def gumbel_sample(t, temperature = 1., dim = -1, eps = 1e-10): 62 | if temperature == 0: 63 | return t.argmax(dim = dim) 64 | 65 | return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim) 66 | 67 | def top_k(logits, thres = 0.9): 68 | k = math.ceil((1 - thres) * logits.shape[-1]) 69 | val, indices = torch.topk(logits, k) 70 | probs = torch.full_like(logits, -torch.finfo(logits.dtype).max) 71 | probs.scatter_(1, indices, val) 72 | return probs 73 | 74 | def all_contains_id(t: torch.Tensor, token_id: int): 75 | mask = t == token_id 76 | return mask.any(dim = -1).all() 77 | 78 | def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1): 79 | assert occurrence > 0 80 | mask = (t == token_id) 81 | 82 | has_occurred = mask.cumsum(dim = -1) 83 | has_occurred = F.pad(has_occurred, (1, 0), value = 0.)
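    # e.g. for t = [[5, 9, 7, 9]] and token_id = 9: the left-padded cumulative sum is
    # [0, 0, 1, 1, 2], so with occurrence = 1 the count of entries < 1 is 2, i.e. the
    # position just past the first occurrence of the token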
84 | 85 | return (has_occurred < occurrence).sum(dim = -1).long() 86 | 87 | # invoking api call functions 88 | 89 | def is_valid_string(s): 90 | return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s)) 91 | 92 | def is_valid_integer(s): 93 | return exists(re.fullmatch(r"[+-]?\d+", s)) 94 | 95 | def is_valid_float(s): 96 | return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s)) 97 | 98 | def parse_param(s: str) -> Optional[Union[int, float, str]]: 99 | if is_valid_string(s): 100 | return str(s) 101 | elif is_valid_integer(s): 102 | return int(s) 103 | elif is_valid_float(s): 104 | return float(s) 105 | 106 | return None 107 | 108 | @beartype 109 | def replace_fn( 110 | registry: dict[str, Callable], 111 | matches, 112 | delimiter = '→' 113 | ): 114 | orig_text = matches.group(0) 115 | 116 | text_without_end_api_token = matches.group(1) 117 | end_api_token = matches.group(4) 118 | function_name = matches.group(2) 119 | 120 | # unable to find function in registry 121 | 122 | if function_name not in registry: 123 | return orig_text 124 | 125 | fn = registry[function_name] 126 | 127 | params = matches.group(3).split(',') 128 | params = list(map(lambda s: s.strip(), params)) 129 | params = list(filter(len, params)) 130 | params = list(map(parse_param, params)) 131 | 132 | # if any of the parameters are not parseable, return 133 | 134 | if any([(not exists(p)) for p in params]): 135 | return orig_text 136 | 137 | # just return original text if there is some error with the function 138 | 139 | out = try_except(fn, always(None))(*params) 140 | 141 | # the api calling function can also arrest the process, by returning None 142 | 143 | if not exists(out): 144 | return orig_text 145 | 146 | # return original text with the output delimiter and the stringified output 147 | 148 | return f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}' 149 | 150 | # main function, which takes a registry of functions and the text in question, makes all the appropriate api calls, and appends the outputs 151 | 152 | def create_function_regex( 153 | api_start = ' [', 154 | api_stop = ']' 155 | ): 156 | api_start_regex, api_stop_regex = map(re.escape, (api_start, api_stop)) 157 | return rf'({api_start_regex}(\w+)\(([^)]*)\))({api_stop_regex})' 158 | 159 | def num_matches(substr: str, text: str): 160 | return len(re.findall(re.escape(substr), text)) 161 | 162 | def has_api_calls( 163 | text, 164 | api_start = ' [', 165 | api_stop = ']' 166 | ): 167 | regex = create_function_regex(api_start, api_stop) 168 | matches = re.findall(regex, text) 169 | return len(matches) > 0 170 | 171 | def replace_all_but_first( 172 | text: str, 173 | api_start = ' [', 174 | api_stop = ']' 175 | ) -> str: 176 | regex = create_function_regex(api_start, api_stop) 177 | 178 | count = 0 179 | 180 | def replace_(matches): 181 | orig_text = matches.group(0) 182 | nonlocal count 183 | count += 1 184 | if count > 1: 185 | return '' 186 | return orig_text 187 | 188 | return re.sub(regex, replace_, text) 189 | 190 | def invoke_tools( 191 | registry: dict[str, Callable], 192 | text: str, 193 | delimiter: str = '→', 194 | api_start = ' [', 195 | api_stop = ']' # no leading space, matching the default in create_function_regex, so calls like [inc(1)] are matched 196 | ) -> str: 197 | regex = create_function_regex(api_start, api_stop) 198 | replace_ = partial(replace_fn, registry, delimiter = delimiter) 199 | return re.sub(regex, replace_, text) 200 | 201 | def invoke_tools_on_batch_sequences( 202 | registry: dict[str, Callable], 203 | token_ids: torch.Tensor, 204 | *, 205 | encode: Callable, 206 | decode: Callable, 207 | delimiter: str = '→',
api_start = ' [', 209 | api_stop = ']' 210 | ) -> torch.Tensor: 211 | regex = create_function_regex(api_start, api_stop) 212 | all_texts = [decode(one_seq_token_ids) for one_seq_token_ids in token_ids] 213 | 214 | invoke_tools_ = partial(invoke_tools, api_start = api_start, api_stop = api_stop) 215 | all_texts_with_api_calls = [invoke_tools_(registry, text, delimiter) for text in all_texts] 216 | 217 | return encode(all_texts_with_api_calls) 218 | 219 | # sampling api related functions 220 | # they do greedy sampling, but encourage sampling api calls by auto-selecting the <api> token when it is in the top k = 10 221 | 222 | @beartype 223 | @torch.no_grad() 224 | def sample( 225 | model: nn.Module, 226 | *, 227 | seq_len, 228 | prime: Optional[torch.Tensor] = None, 229 | positions: Optional[torch.Tensor] = None, 230 | batch_size = 1, 231 | eos_token_id = None, 232 | sos_token_id = 1, 233 | temperature = 0., 234 | pad_id = 0, 235 | call_api_only_once = False, 236 | api_start_token_id = None, 237 | auto_select_api_start_token_when_topk = False, 238 | select_api_start_id_top_k = 10, 239 | ): 240 | device = next(model.parameters()).device 241 | max_seq_len = seq_len + 1 242 | 243 | # validate 244 | 245 | if call_api_only_once: 246 | assert exists(api_start_token_id) 247 | 248 | # prime 249 | 250 | if exists(prime): 251 | batch_size, prime_length = prime.shape 252 | else: 253 | prime_length = 1 254 | prime = torch.full((batch_size, 1), sos_token_id, device = device, dtype = torch.long) 255 | 256 | prime = prime.to(device) 257 | 258 | # sampling positions - different sequences have different cursors 259 | 260 | if exists(positions): 261 | positions = positions.clone() 262 | else: 263 | positions = torch.zeros((batch_size,), device = device, dtype = torch.long) 264 | 265 | assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), 'all positions must be less than initial prime length as well as the total sequence length + 1 (plus one for noop if one sequence finished sampling before the other)' 266 | 267 | # eval model 268 | 269 | model.eval() 270 | 271 | # lengthen the prime to the entire sequence length 272 | 273 | remain_iterations = seq_len - prime_length 274 | output = F.pad(prime, (0, max_seq_len - prime_length), value = 0.)
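    # the output buffer is one slot longer than seq_len - a sequence that finishes early has its
    # cursor clamped to that final slot (see the clamp below), which serves as a noop scratch
    # position, and the extra slot is sliced off before returning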
275 | 276 | batch_indices = torch.arange(batch_size, device = device) 277 | batch_indices = rearrange(batch_indices, 'b -> b 1') 278 | position_indices = rearrange(positions, 'b -> b 1') 279 | 280 | # determine the token mask, for making sure api is called only once, masking out logit to prevent it from being selected for those rows which already contain an <api> token 281 | 282 | api_token_mask = None # lazily created, since do not know logit dimensions 283 | 284 | def create_api_token_mask(num_tokens, api_start_token_id): 285 | mask = torch.zeros((1, 1, num_tokens), dtype = torch.bool) 286 | assert api_start_token_id < num_tokens 287 | mask[..., api_start_token_id] = True 288 | return mask 289 | 290 | # start iterating 291 | 292 | for iteration in tqdm(range(remain_iterations)): 293 | logits = model(output) 294 | last_logits = logits[batch_indices, position_indices] 295 | 296 | # this will ensure that each batch token sequence will have at most one <api> token 297 | 298 | if call_api_only_once: 299 | if not exists(api_token_mask): 300 | num_tokens = last_logits.shape[-1] 301 | api_token_mask = create_api_token_mask(num_tokens, api_start_token_id) 302 | api_token_mask = api_token_mask.to(device) 303 | 304 | api_called = (output == api_start_token_id).any(dim = -1) 305 | 306 | logit_mask = api_token_mask & rearrange(api_called, 'b -> b 1 1') 307 | last_logits = last_logits.masked_fill(logit_mask, -torch.finfo(last_logits.dtype).max) 308 | 309 | # greedy sample (but could be made non-greedy) 310 | 311 | sampled = gumbel_sample(last_logits, temperature = temperature) 312 | 313 | # for those sequences without an api call, if the api_start_token_id is within top k (set to 10 in paper) of logits, just auto-select 314 | 315 | # seems to be an important hack in the paper 316 | # it seems like this paper will take a lot more follow up research to be viable 317 | 318 | if auto_select_api_start_token_when_topk: 319 | top_token_ids = last_logits.topk(select_api_start_id_top_k, dim = -1).indices 320 | has_api_token_in_topk = (top_token_ids == api_start_token_id).any(dim = -1) 321 | should_auto_select_api_token = has_api_token_in_topk & ~rearrange(api_called, 'b -> b 1') 322 | 323 | sampled = sampled.masked_fill(should_auto_select_api_token, api_start_token_id) 324 | 325 | # set the sampled tokens at the right cursor positions 326 | 327 | output[batch_indices, position_indices] = sampled 328 | 329 | # increment positions 330 | 331 | position_indices += 1 332 | position_indices.clamp_(max = seq_len) # noop if one sequence is further along and near the end 333 | 334 | # if using <eos> tokens, look for all sequences having it and terminate, also anything after will be padded 335 | 336 | if exists(eos_token_id): 337 | eos_mask = (output == eos_token_id) 338 | all_rows_have_eos = eos_mask.any(dim = -1).all() 339 | 340 | if all_rows_have_eos: 341 | keep_mask = eos_mask.cumsum(dim = -1) == 0 342 | keep_mask = F.pad(keep_mask, (1, -1), value = True) # shift the mask right by one so the eos token itself is kept and the mask width matches output 343 | output = output.masked_fill(~keep_mask, pad_id) 344 | break 345 | 346 | # remove the last token in output (used as noop placeholder) 347 | 348 | output = output[:, :-1] 349 | return output 350 | 351 | @beartype 352 | @torch.no_grad() 353 | def sample_with_api_call( 354 | model: nn.Module, 355 | *, 356 | seq_len, 357 | call_apis: Callable, 358 | prime: torch.Tensor, 359 | api_end_token_id: int, 360 | occurrence = 1, 361 | **kwargs 362 | ): 363 | sampled = sample( 364 | model = model, 365 | prime = prime, 366 | seq_len = seq_len, 367 | **kwargs 368 | ) 369 | 370 | sampled =
call_apis(sampled) 371 | 372 | sampled_seq_len = sampled.shape[-1] 373 | null_positions = sampled_seq_len # handle sequences that do not have api calls 374 | 375 | pos_starting_at_end_of_api = find_indices_of( 376 | sampled, 377 | api_end_token_id, 378 | occurrence = occurrence 379 | ) 380 | 381 | resample_after_api_calls = sample( 382 | model = model, 383 | prime = sampled, 384 | seq_len = sampled_seq_len, 385 | positions = (pos_starting_at_end_of_api + 1).clamp(max = null_positions), # start at the position right after the </api> token 386 | **kwargs 387 | ) 388 | 389 | return resample_after_api_calls 390 | 391 | # the main contribution of the paper is simply the filtering equations presented in section 2 392 | 393 | def default_weight_fn(t): 394 | # following the formula in section 4.1 - however, not sure what w_s is in the denominator 395 | # if t stands for each timestep, this would also mean within 5 tokens it would diminish to 0? 396 | return (1. - t * 0.2).clamp(min = 0.) 397 | 398 | def get_pred_prob(token_ids, logits): 399 | logits = logits[:, :-1] # logits of each token... (omit last logit) 400 | token_ids = token_ids[:, 1:] # predicts the next token id (omit first token id) 401 | 402 | token_ids = rearrange(token_ids, 'b n -> b n 1') 403 | probs = logits.softmax(dim = -1) 404 | correct_token_id_pred_prob = probs.gather(-1, token_ids) 405 | return rearrange(correct_token_id_pred_prob, 'b n 1 -> b n') 406 | 407 | def get_arange_start_at_token_id( 408 | token_ids: torch.Tensor, 409 | token_id: int, 410 | pad_id = -1 411 | ): 412 | is_token_id_mask = token_ids == token_id 413 | arange = (is_token_id_mask.cumsum(dim = -1) > 0).cumsum(dim = -1) 414 | before_token_mask = arange == 0 415 | arange = arange - 1 416 | arange = arange.masked_fill(before_token_mask, pad_id) 417 | return arange 418 | 419 | def weight_and_mask( 420 | token_ids: torch.Tensor, 421 | token_id: int, 422 | pad_id = -1, 423 | weighting_fn: Callable = default_weight_fn 424 | ): 425 | t = get_arange_start_at_token_id(token_ids, token_id, pad_id) 426 | weights = weighting_fn(t) 427 | return weights.masked_fill(t == pad_id, 0.)
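# worked example of the weighting: if the <api> token sits at position 2 of a 7 token sequence,
# get_arange_start_at_token_id yields [-1, -1, 0, 1, 2, 3, 4], and weight_and_mask turns this into
# weights [0, 0, 1.0, 0.8, 0.6, 0.4, 0.2] (positions before the call are masked to zero) - tokens
# further from the api call contribute less to the loss, hitting zero five tokens out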
428 | 429 | FilteredResults = namedtuple('FilteredResults', [ 430 | 'num_passed', 431 | 'num_failed', 432 | 'selected_indices', 433 | 'selected_mask', 434 | 'filtered_tokens', 435 | 'filtered_tokens_without_api_response', 436 | 'filtered_tokens_with_api_response' 437 | ]) 438 | 439 | @beartype 440 | def filter_tokens_with_api_response( 441 | model: nn.Module, # the language model should accept the token ids below and return the logits in shape (batch, seq, num tokens) 442 | *, 443 | tokens: torch.Tensor, # token ids (batch, seq) of the original passage, without api calls 444 | tokens_without_api_response: torch.Tensor, # token ids (batch, seq) of the passage, but with the api call (but without a response filled in) - <api>tool1(x, y)</api> 445 | tokens_with_api_response: torch.Tensor, # token ids (batch, seq) of the passage with api call and the response - <api>tool1(x, y) → {response}</api> 446 | api_start_token_id: int, # token id of the <api> tag 447 | api_end_token_id: int, # token id of the </api> tag 448 | filter_threshold: float = 1., # the threshold at which to accept the sampled api call (tokens_with_api_response) for fine-tuning 449 | weighting_fn: Callable = default_weight_fn # weighting function 450 | ) -> FilteredResults: 451 | 452 | # validations 453 | 454 | assert all([*map(lambda t: t.dtype == torch.long, (tokens, tokens_with_api_response, tokens_without_api_response))]) 455 | 456 | assert all_contains_id(tokens_without_api_response, api_start_token_id) 457 | assert all_contains_id(tokens_without_api_response, api_end_token_id) 458 | 459 | assert all_contains_id(tokens_with_api_response, api_start_token_id) 460 | assert all_contains_id(tokens_with_api_response, api_end_token_id) 461 | 462 | # auto set devices 463 | 464 | device = next(model.parameters()).device 465 | tokens, tokens_without_api_response, tokens_with_api_response = map(lambda t: t.to(device), (tokens, tokens_without_api_response, tokens_with_api_response)) 466 | 467 | # get all the logits 468 | 469 | with torch.no_grad(): 470 | model.eval() 471 | logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_without_api_response, tokens_with_api_response)) 472 | 473 | # derive all predicted prob of the actual next token id in sequence 474 | 475 | probs = get_pred_prob(tokens, logits) 476 | probs_without_api_response = get_pred_prob(tokens_without_api_response, logits_without_api_response) 477 | probs_with_api_response = get_pred_prob(tokens_with_api_response, logits_with_api_response) 478 | 479 | weight_and_mask_fn = partial(weight_and_mask, weighting_fn = weighting_fn) 480 | 481 | # derive the weighting 482 | 483 | weight_without_api_response = weight_and_mask_fn(tokens_without_api_response[:, :-1], api_end_token_id) 484 | weight_with_api_response = weight_and_mask_fn(tokens_with_api_response[:, :-1], api_end_token_id) 485 | 486 | # deriving the weighting for the original passage is more tricky 487 | # would need to start counting up from start token location 488 | # this would also assume that the language model perfectly copied the passage over and that both token ids are aligned except for the inserted API call - but this can be done with the custom filtering functions eventually 489 | 490 | weight = weight_and_mask_fn(tokens_without_api_response[:, 1:], api_start_token_id) # shift to the left by one since <api> does not exist in the original sequence 491 | weight = weight[:, :probs.shape[-1]] 492 | 493 | # get the loss L for all three types of sequences 494 | 495 | def loss_fn(weight, probs): 496 | return (weight *
-log(probs)).sum(dim = -1) 497 | 498 | loss = loss_fn(weight, probs) 499 | loss_without_api_response = loss_fn(weight_without_api_response, probs_without_api_response) 500 | loss_with_api_response = loss_fn(weight_with_api_response, probs_with_api_response) 501 | 502 | # calculate the main formula in the paper 503 | 504 | # loss+ = loss with api response 505 | # loss- = min(loss without api response, loss without api at all) 506 | 507 | loss_plus = loss_with_api_response 508 | loss_minus = torch.minimum(loss_without_api_response, loss) 509 | 510 | selected_mask = (loss_minus - loss_plus) >= filter_threshold 511 | 512 | # now we can select and return the entries that survived the filtering stage 513 | # also returning the selected indices of the batch being processed 514 | # for finetuning the model into toolformer 515 | 516 | batch = tokens.shape[0] 517 | indices = torch.arange(batch, device = tokens.device) 518 | 519 | selected_indices = indices[selected_mask] 520 | 521 | ret = FilteredResults( 522 | selected_mask.sum().item(), 523 | (~selected_mask).sum().item(), 524 | selected_indices, 525 | selected_mask, 526 | tokens[selected_mask], 527 | tokens_without_api_response[selected_mask], 528 | tokens_with_api_response[selected_mask] 529 | ) 530 | 531 | return ret 532 | 533 | # datasets and dataloaders 534 | 535 | # for bootstrapping the initial datasets with api calls 536 | # as well as for the final finetuning 537 | 538 | @beartype 539 | class PromptDataset(Dataset): 540 | def __init__( 541 | self, 542 | prompt: str, 543 | prompt_input_tag: str, 544 | data: List[str], 545 | tokenizer_encode: Callable 546 | ): 547 | self.data = data 548 | self.prompt = prompt 549 | self.prompt_input_tag_regex = re.escape(prompt_input_tag) 550 | self.tokenizer_encode = tokenizer_encode 551 | 552 | def __len__(self): 553 | return len(self.data) 554 | 555 | def __getitem__(self, idx): 556 | data_string = self.data[idx] 557 | data_with_prompt = re.sub(self.prompt_input_tag_regex, data_string, self.prompt) 558 | token_ids = self.tokenizer_encode(data_with_prompt) 559 | return torch.tensor(token_ids).long(), torch.tensor(len(token_ids)).long() 560 | 561 | def prompt_collate_fn(data, padding_value = 0): 562 | prompts, prompt_lengths = zip(*data) 563 | prompts = pad_sequence(prompts, padding_value = padding_value) 564 | return prompts, torch.stack(prompt_lengths) 565 | 566 | def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs): 567 | collate_fn = partial(prompt_collate_fn, padding_value = padding_value) 568 | return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs) 569 | 570 | class FinetuneDataset(Dataset): 571 | def __init__( 572 | self, 573 | tokens: torch.Tensor 574 | ): 575 | self.tokens = tokens 576 | 577 | def __len__(self): 578 | return len(self.tokens) 579 | 580 | def __getitem__(self, idx): 581 | return self.tokens[idx] 582 | 583 | def FinetuneDataloader(ds: Dataset, *args, padding_value = 0, **kwargs): 584 | return DataLoader(ds, *args, collate_fn = partial(pad_sequence, padding_value = padding_value), **kwargs) 585 | 586 | # classes 587 | 588 | @beartype 589 | class Toolformer(nn.Module): 590 | def __init__( 591 | self, 592 | model: nn.Module, 593 | *, 594 | tool_id: str, 595 | tool: Callable, 596 | api_start_str = ' [', 597 | api_stop_str = ']', 598 | api_response_delimiter = '→', 599 | api_start_id = None, 600 | api_stop_id = None, 601 | teach_tool_prompt: str, 602 | filter_threshold = 1., 603 | pad_id = 0, 604 | prompt_batch_size = 4, 605 | model_seq_len = 2048, 606 | 
tokenizer_encode: Callable = tokenizer.encode, 607 | tokenizer_decode: Callable = tokenizer.decode, 608 | post_prompt_callback: Callable = identity, 609 | prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG, 610 | exclude_filters: dict[str, Callable[[str], bool]] = dict(), 611 | finetune = False, 612 | finetune_lr = 1e-4, 613 | finetune_wd = 1e-2, 614 | finetune_betas = (0.9, 0.99), 615 | finetune_eps = 1e-8, 616 | finetune_epochs = 3, 617 | finetune_batch_size = 16 618 | ): 619 | super().__init__() 620 | self.model = model 621 | self.model_seq_len = model_seq_len 622 | 623 | self.teach_tool_prompt = teach_tool_prompt 624 | self.prompt_batch_size = prompt_batch_size 625 | self.prompt_input_tag = prompt_input_tag 626 | 627 | self.post_prompt_callback = post_prompt_callback # for easy mocking 628 | 629 | self.tokenizer_encode = tokenizer_encode 630 | self.tokenizer_decode = tokenizer_decode 631 | self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long() 632 | 633 | self.filter_threshold = filter_threshold 634 | 635 | self.api_start_str = api_start_str 636 | self.api_stop_str = api_stop_str 637 | self.api_response_delimiter = api_response_delimiter 638 | 639 | if not exists(api_start_id): 640 | api_start_id = tokenizer_encode(api_start_str) 641 | assert len(api_start_id) == 1 642 | api_start_id = api_start_id[0] 643 | 644 | self.api_start_id = api_start_id 645 | 646 | if not exists(api_stop_id): 647 | api_stop_id = tokenizer_encode(api_stop_str) 648 | assert len(api_stop_id) == 1 649 | api_stop_id = api_stop_id[0] 650 | 651 | self.api_stop_id = api_stop_id 652 | 653 | self.pad_id = pad_id 654 | 655 | self.tool_id = tool_id 656 | self.tool = tool 657 | self.registry = {tool_id: tool} 658 | 659 | assert num_matches(prompt_input_tag, teach_tool_prompt) == 1, f'there must be exactly one prompt input tag `{prompt_input_tag}` in your prompt to encourage the language model to use the designated tool' 660 | 661 | self.teach_tool_prompt = teach_tool_prompt 662 | self.exclude_filters = exclude_filters 663 | 664 | self.should_finetune = finetune 665 | 666 | if not finetune: 667 | return 668 | 669 | self.finetune_batch_size = finetune_batch_size 670 | self.finetune_epochs = finetune_epochs 671 | 672 | self.optimizer = get_optimizer( 673 | model.parameters(), 674 | lr = finetune_lr, 675 | wd = finetune_wd, 676 | betas = finetune_betas, 677 | eps = finetune_eps 678 | ) 679 | 680 | def generate_data_with_api_calls( 681 | self, 682 | data: List[str], 683 | temperature: float = 0.9 684 | ) -> List[str]: 685 | 686 | dataset = PromptDataset( 687 | data = data, 688 | prompt_input_tag = self.prompt_input_tag, 689 | prompt = self.teach_tool_prompt, 690 | tokenizer_encode = self.tokenizer_encode 691 | ) 692 | 693 | dl = PromptDataloader( 694 | dataset, 695 | batch_size = self.prompt_batch_size 696 | ) 697 | 698 | prompted_outputs = [] 699 | 700 | for prime, positions in dl: 701 | 702 | sampled_outputs = sample( 703 | model = self.model, 704 | prime = prime, 705 | positions = positions, 706 | seq_len = self.model_seq_len, 707 | pad_id = self.pad_id, 708 | temperature = temperature 709 | ) 710 | 711 | for sample_output, position in zip(sampled_outputs, positions): 712 | start_position = position.item() 713 | 714 | prompted_output = self.tokenizer_decode(sample_output[start_position:]) 715 | prompted_outputs.append(prompted_output) 716 | 717 | return self.post_prompt_callback(prompted_outputs) 718 | 719 | def filter_and_keep_only_first_api_call( 720 | self, 721 | data, 722 | 
data_with_api_calls: List[str], 723 | return_excluded = False 724 | ): 725 | included_data = [] 726 | included_data_with_api_calls = [] 727 | 728 | included = (included_data, included_data_with_api_calls) 729 | 730 | excluded_data = [] 731 | excluded_data_with_api_calls = [] 732 | 733 | excluded = (excluded_data, excluded_data_with_api_calls) 734 | 735 | api_start_stop_kwargs = dict(api_start = self.api_start_str, api_stop = self.api_stop_str) 736 | 737 | has_api_calls_ = partial(has_api_calls, **api_start_stop_kwargs) 738 | replace_all_but_first_ = partial(replace_all_but_first, **api_start_stop_kwargs) 739 | 740 | for datum, data_with_api_call in zip(data, data_with_api_calls): 741 | if has_api_calls_(data_with_api_call): 742 | data_with_api_call = replace_all_but_first_(data_with_api_call) 743 | 744 | included_data.append(datum) 745 | included_data_with_api_calls.append(data_with_api_call) 746 | else: 747 | excluded_data.append(datum) 748 | excluded_data_with_api_calls.append(data_with_api_call) 749 | 750 | if not return_excluded: 751 | return included 752 | 753 | return included, excluded 754 | 755 | @torch.no_grad() 756 | def sample_model_with_api_calls( 757 | self, 758 | prime: Union[torch.Tensor, str], 759 | occurrence = 1, 760 | **kwargs 761 | ): 762 | self.model.eval() 763 | 764 | prime_is_str = isinstance(prime, str) 765 | 766 | if prime_is_str: 767 | prime = self.tokenizer_encode(prime) 768 | prime = torch.tensor(prime).long() 769 | prime = rearrange(prime, 'n -> 1 n') 770 | 771 | assert prime.shape[0] == 1, 'only one at a time for now' 772 | 773 | invoke_tools_ = partial(invoke_tools, self.registry) 774 | 775 | def call_apis(t: torch.Tensor): 776 | t = self.tokenizer_decode(t[0]) 777 | t = invoke_tools_(t) 778 | t = self.tokenizer_encode_to_tensor(t) 779 | return rearrange(t, 'n -> 1 n') 780 | 781 | output = sample_with_api_call( 782 | model = self.model, 783 | prime = prime, 784 | seq_len = self.model_seq_len, 785 | call_apis = call_apis, 786 | api_end_token_id = self.api_stop_id, 787 | occurrence = occurrence, 788 | **kwargs 789 | ) 790 | 791 | if not prime_is_str: 792 | return output 793 | 794 | return self.tokenizer_decode(output[0]) 795 | 796 | def make_api_calls( 797 | self, 798 | filtered_data_with_api_calls: List[str] 799 | ): 800 | invoke_tools_ = partial( 801 | invoke_tools, 802 | self.registry, 803 | api_start = self.api_start_str, 804 | api_stop = self.api_stop_str, delimiter = self.api_response_delimiter 805 | ) 806 | 807 | data_with_api_responses = [] 808 | for data in filtered_data_with_api_calls: 809 | output = invoke_tools_(data) 810 | data_with_api_responses.append(output) 811 | 812 | return data_with_api_responses 813 | 814 | def filter_by_api_responses( 815 | self, 816 | data: List[str], 817 | data_with_api_calls: List[str], 818 | data_with_api_responses: List[str] 819 | ) -> FilteredResults: 820 | 821 | to_token_ids = lambda l: pad_sequence([*map(self.tokenizer_encode_to_tensor, l)], padding_value = self.pad_id) 822 | 823 | tokens, tokens_without_api_response, tokens_with_api_response = map(to_token_ids, (data, data_with_api_calls, data_with_api_responses)) 824 | 825 | filtered_results = filter_tokens_with_api_response( 826 | model = self.model, 827 | tokens = tokens, 828 | tokens_with_api_response = tokens_with_api_response, 829 | tokens_without_api_response = tokens_without_api_response, 830 | filter_threshold = self.filter_threshold, 831 | api_start_token_id = self.api_start_id, 832 | api_end_token_id = self.api_stop_id 833 | ) 834 | 835 | return 
filtered_results 836 | 837 | def finetune( 838 | self, 839 | filtered_results: Union[FilteredResults, torch.Tensor] 840 | ): 841 | self.model.train() 842 | 843 | if isinstance(filtered_results, FilteredResults): 844 | filtered_results = filtered_results.filtered_tokens_without_api_response 845 | 846 | dataset = FinetuneDataset(tokens = filtered_results) 847 | dl = FinetuneDataloader(dataset, batch_size = self.finetune_batch_size, shuffle = True) 848 | 849 | for epoch in tqdm(range(self.finetune_epochs), desc = 'finetune epochs'): 850 | for batch in dl: 851 | inp, labels = batch[:, :-1], batch[:, 1:] 852 | 853 | logits = self.model(inp) 854 | logits = rearrange(logits, 'b n c -> b c n') 855 | 856 | loss = F.cross_entropy(logits, labels, ignore_index = self.pad_id) 857 | loss.backward() 858 | 859 | print(f'loss: {loss.item()}') 860 | self.optimizer.step() 861 | self.optimizer.zero_grad() 862 | 863 | print(f'finished finetuning on {len(dataset)} filtered samples') 864 | 865 | def forward( 866 | self, 867 | data: List[str], 868 | return_after_generating_api_calls = False, 869 | return_after_making_api_calls = False, 870 | return_after_filtering_api_calls = False, 871 | return_after_filtering_by_api_response = False 872 | ): 873 | data_with_api_calls = self.generate_data_with_api_calls(data) 874 | 875 | if return_after_generating_api_calls: 876 | return data_with_api_calls 877 | 878 | filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls) 879 | 880 | if return_after_filtering_api_calls: 881 | return filtered_data, filtered_data_with_api_calls 882 | 883 | assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering' 884 | 885 | data_with_responses = self.make_api_calls(filtered_data_with_api_calls) 886 | 887 | if return_after_making_api_calls: 888 | return filtered_data, filtered_data_with_api_calls, data_with_responses 889 | 890 | filtered_results = self.filter_by_api_responses(filtered_data, filtered_data_with_api_calls, data_with_responses) 891 | 892 | if return_after_filtering_by_api_response: 893 | return filtered_results 894 | 895 | if self.should_finetune: 896 | assert filtered_results.num_passed > 0, f'none of the sequences with API calls passed the filtering criteria with threshold {self.filter_threshold}' 897 | 898 | self.finetune(filtered_results) 899 | 900 | return filtered_results 901 | -------------------------------------------------------------------------------- /toolformer_pytorch/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | try: 4 | from dotenv import load_dotenv 5 | load_dotenv() 6 | 7 | import requests 8 | import calendar 9 | import wolframalpha 10 | import datetime 11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 12 | from operator import pow, truediv, mul, add, sub 13 | 14 | # Optional imports 15 | from googleapiclient.discovery import build 16 | 17 | except ImportError: 18 | print('please run `pip install -r tools-requirements.txt` first at the project directory') 19 | exit() 20 | 21 | ''' 22 | Calendar 23 | 24 | Uses Python's datetime and calendar libraries to retrieve the current date. 25 | 26 | input - None 27 | 28 | output - A string, the current date.
/toolformer_pytorch/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | try:
4 |     from dotenv import load_dotenv
5 |     load_dotenv()
6 |
7 |     import requests
8 |     import calendar
9 |     import wolframalpha
10 |     import datetime
11 |     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12 |     from operator import pow, truediv, mul, add, sub
13 |
14 |     # Optional imports
15 |     from googleapiclient.discovery import build
16 |
17 | except ImportError:
18 |     print('please run `pip install -r tools-requirements.txt` from the project directory first')
19 |     exit()
20 |
21 | '''
22 | Calendar
23 |
24 | Uses Python's datetime and calendar libraries to retrieve the current date.
25 |
26 | input - None
27 |
28 | output - A string, the current date.
29 | '''
30 | def Calendar():
31 |     now = datetime.datetime.now()
32 |     return f'Today is {calendar.day_name[now.weekday()]}, {calendar.month_name[now.month]} {now.day}, {now.year}.'
33 |
34 |
35 | '''
36 | Wikipedia Search
37 |
38 | Uses ColBERTv2 to retrieve Wikipedia documents.
39 |
40 | input_query - A string, the input query (e.g. "what is a dog?")
41 | k - The number of documents to retrieve
42 |
43 | output - A list of strings, each string is a Wikipedia document
44 |
45 | Adapted from Stanford's DSP: https://github.com/stanfordnlp/dsp/
46 | Also see: https://github.com/lucabeetz/dsp
47 | '''
48 | class ColBERTv2:
49 |     def __init__(self, url: str):
50 |         self.url = url
51 |
52 |     def __call__(self, query, k=10):
53 |         topk = colbertv2_get_request(self.url, query, k)
54 |
55 |         topk = [doc['text'] for doc in topk]
56 |         return topk
57 |
58 | def colbertv2_get_request(url: str, query: str, k: int):
59 |     payload = {'query': query, 'k': k}
60 |     res = requests.get(url, params=payload)
61 |
62 |     topk = res.json()['topk'][:k]
63 |     return topk
64 |
65 | def WikiSearch(
66 |     input_query: str,
67 |     url: str = 'http://ec2-44-228-128-229.us-west-2.compute.amazonaws.com:8893/api/search',
68 |     k: int = 10
69 | ):
70 |     retrieval_model = ColBERTv2(url)
71 |     output = retrieval_model(input_query, k)
72 |     return output
73 |
74 | '''
75 | Machine Translation - NLLB-600M
76 |
77 | Uses HuggingFace's transformers library to translate the input query to English.
78 |
79 | input_query - A string, the input query (e.g. "what is a dog?")
80 |
81 | output - A string, the translated input query.
82 | '''
83 | def MT(input_query: str, model_name: str = "facebook/nllb-200-distilled-600M"):
84 |     tokenizer = AutoTokenizer.from_pretrained(model_name)
85 |     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
86 |     input_ids = tokenizer(input_query, return_tensors='pt')
87 |     outputs = model.generate(
88 |         **input_ids,
89 |         forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"],
90 |     )
91 |     output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
92 |     return output
93 |
94 |
95 | '''
96 | Calculator
97 |
98 | Calculates the result of a simple arithmetic expression.
99 |
100 | input_query - A string, the input query (e.g. "400/1400")
101 |
102 | output - A float, the result of the calculation, rounded to two decimal places (naive recursive parser; no parentheses or unary minus)
103 |
104 | Adapted from: https://levelup.gitconnected.com/3-ways-to-write-a-calculator-in-python-61642f2e4a9a
105 | '''
106 | def Calculator(input_query: str):
107 |     operators = {
108 |         '+': add,
109 |         '-': sub,
110 |         '*': mul,
111 |         '/': truediv
112 |     }
113 |     if input_query.replace('.', '', 1).isdigit():  # bare non-negative number, int or float
114 |         return float(input_query)
115 |     for c in operators.keys():
116 |         left, operator, right = input_query.partition(c)  # split on the first occurrence of this operator
117 |         if operator in operators:
118 |             return round(operators[operator](Calculator(left), Calculator(right)), 2)
119 |
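Because `Calculator` recursively splits on the first occurrence of each operator, in the fixed order `+`, `-`, `*`, `/`, multiplication and division bind tighter than addition and subtraction, but chains of the same operator evaluate right-to-left. A quick sanity check (expected outputs shown as comments, assuming the definition above):

```python
print(Calculator('400/1400'))  # 0.29 -- the docstring example
print(Calculator('2+3*4'))     # 14.0 -- '+' is split first, so '*' binds tighter
print(Calculator('8-2-3'))     # 9.0  -- right-associative: 8 - (2 - 3), not (8 - 2) - 3
```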
"what is 2 + 2?") 132 | 133 | output - A string, the answer to the input query 134 | 135 | wolfarm_alpha_appid - your Wolfram Alpha API key 136 | ''' 137 | def WolframAlphaCalculator(input_query: str): 138 | wolfram_alpha_appid = os.environ.get('WOLFRAM_ALPHA_APPID') 139 | wolfram_client = wolframalpha.Client(wolfram_alpha_appid) 140 | res = wolfram_client.query(input_query) 141 | assumption = next(res.pods).text 142 | answer = next(res.results).text 143 | return f'Assumption: {assumption} \nAnswer: {answer}' 144 | 145 | 146 | ''' 147 | Google Search 148 | 149 | Uses Google's Custom Search API to retrieve Google Search results. 150 | 151 | input_query - The query to search for. 152 | num_results - The number of results to return. 153 | api_key - Your Google API key. 154 | cse_id - Your Google Custom Search Engine ID. 155 | 156 | output - A list of dictionaries, each dictionary is a Google Search result 157 | ''' 158 | def custom_search(query, api_key, cse_id, **kwargs): 159 | service = build("customsearch", "v1", developerKey=api_key) 160 | res = service.cse().list(q=query, cx=cse_id, **kwargs).execute() 161 | return res['items'] 162 | 163 | def google_search(input_query: str, num_results: int = 10): 164 | api_key = os.environ.get('GOOGLE_API_KEY') 165 | cse_id = os.environ.get('GOOGLE_CSE_ID') 166 | 167 | metadata_results = [] 168 | results = custom_search(input_query, num=num_results, api_key=api_key, cse_id=cse_id) 169 | for result in results: 170 | metadata_result = { 171 | "snippet": result["snippet"], 172 | "title": result["title"], 173 | "link": result["link"], 174 | } 175 | metadata_results.append(metadata_result) 176 | return metadata_results 177 | 178 | 179 | ''' 180 | Bing Search 181 | 182 | Uses Bing's Custom Search API to retrieve Bing Search results. 183 | 184 | input_query: The query to search for. 185 | bing_subscription_key: Your Bing API key. 186 | num_results: The number of results to return. 187 | 188 | output: A list of dictionaries, each dictionary is a Bing Search result 189 | ''' 190 | def _bing_search_results( 191 | search_term: str, 192 | bing_subscription_key: str, 193 | count: int, 194 | url: str = "https://api.bing.microsoft.com/v7.0/search" 195 | ): 196 | headers = {"Ocp-Apim-Subscription-Key": bing_subscription_key} 197 | params = { 198 | "q": search_term, 199 | "count": count, 200 | "textDecorations": True, 201 | "textFormat": "HTML", 202 | } 203 | response = requests.get( 204 | url, headers=headers, params=params 205 | ) 206 | response.raise_for_status() 207 | search_results = response.json() 208 | return search_results["webPages"]["value"] 209 | 210 | def bing_search( 211 | input_query: str, 212 | num_results: int = 10 213 | ): 214 | bing_subscription_key = os.environ.get("BING_API_KEY") 215 | metadata_results = [] 216 | results = _bing_search_results(input_query, bing_subscription_key, count=num_results) 217 | for result in results: 218 | metadata_result = { 219 | "snippet": result["snippet"], 220 | "title": result["name"], 221 | "link": result["url"], 222 | } 223 | metadata_results.append(metadata_result) 224 | return metadata_results 225 | 226 | 227 | if __name__ == '__main__': 228 | 229 | print(Calendar()) # Outputs a string, the current date 230 | 231 | print(Calculator('400/1400')) # For Optional Basic Calculator 232 | 233 | print(WikiSearch('What is a dog?')) # Outputs a list of strings, each string is a Wikipedia document 234 | 235 | print(MT("Un chien c'est quoi?")) # What is a dog? 
227 | if __name__ == '__main__':
228 |
229 |     print(Calendar()) # Outputs a string, the current date
230 |
231 |     print(Calculator('400/1400')) # For the optional basic calculator
232 |
233 |     print(WikiSearch('What is a dog?')) # Outputs a list of strings, each string is a Wikipedia document
234 |
235 |     print(MT("Un chien c'est quoi?")) # What is a dog?
236 |
237 |     # Optional Tools
238 |
239 |     print(WolframAlphaCalculator('What is 2 + 2?')) # 4
240 |
241 |     print(google_search('What is a dog?'))
242 |     # Outputs a list of dictionaries, each dictionary is a Google Search result
243 |
244 |     print(bing_search('What is a dog?'))
245 |     # Outputs a list of dictionaries, each dictionary is a Bing Search result
246 |
--------------------------------------------------------------------------------
/tools-requirements.txt:
--------------------------------------------------------------------------------
1 | google-api-python-client
2 | python-dotenv
3 | requests
4 | transformers
5 | wolframalpha
--------------------------------------------------------------------------------
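Once a model has been finetuned, `sample_model_with_api_calls` in `toolformer_pytorch.py` above generates from a string prime and executes any API call the model emits mid-decode, splicing the tool's response back into the sequence. A minimal sketch, assuming `toolformer` is a trained `Toolformer` instance:

```python
# hypothetical usage of the inference-time sampler defined above
response = toolformer.sample_model_with_api_calls(
    prime = 'The current day of the week is',
    occurrence = 1  # only the first API call encountered is executed
)
print(response)  # decoded text, with the tool's response inserted after the call
```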