├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── setup.py
├── toolformer.png
├── toolformer_pytorch
│   ├── __init__.py
│   ├── optimizer.py
│   ├── palm.py
│   ├── prompts.py
│   ├── toolformer_pytorch.py
│   └── tools.py
└── tools-requirements.txt
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Toolformer - Pytorch (wip)
4 |
5 | Implementation of Toolformer, Language Models Can Teach Themselves to Use Tools, by MetaAI
6 |
7 | ## Appreciation
8 |
9 | - Stability.ai for the generous sponsorship to work on and open source cutting edge artificial intelligence research
10 |
11 | - Enrico for getting the ball rolling with the initial commit of different tools!
12 |
13 | - Thanks go out to ChatGPT for writing all the regular expressions in this repository that parse the functions and parameters of the API calls. I am terrible at regular expressions, so this was an enormous help from the AI (no hitches, it was perfect).
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install toolformer-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | Example usage with giving language models awareness of current date and time.
24 |
25 | ```python
26 | import torch
27 | from toolformer_pytorch import Toolformer, PaLM
28 |
29 | # simple calendar api call - function that returns a string
30 |
31 | def Calendar():
32 | import datetime
33 | from calendar import day_name, month_name
34 | now = datetime.datetime.now()
35 | return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'
36 |
37 | # prompt for teaching it to use the Calendar function from above
38 |
39 | prompt = f"""
40 | Your task is to add calls to a Calendar API to a piece of text.
41 | The API calls should help you get information required to complete the text.
42 | You can call the API by writing "[Calendar()]"
43 | Here are some examples of API calls:
44 | Input: Today is the first Friday of the year.
45 | Output: Today is the first [Calendar()] Friday of the year.
46 | Input: The president of the United States is Joe Biden.
47 | Output: The president of the United States is [Calendar()] Joe Biden.
48 | Input: [input]
49 | Output:
50 | """
51 |
52 | data = [
53 | "The store is never open on the weekend, so today it is closed.",
54 | "The number of days from now until Christmas is 30",
55 | "The current day of the week is Wednesday."
56 | ]
57 |
58 | # model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine
59 |
60 | model = PaLM(
61 | dim = 512,
62 | depth = 2,
63 | heads = 8,
64 | dim_head = 64
65 | ).cuda()
66 |
67 | # toolformer
68 |
69 | toolformer = Toolformer(
70 | model = model,
71 | model_seq_len = 256,
72 | teach_tool_prompt = prompt,
73 | tool_id = 'Calendar',
74 | tool = Calendar,
75 | finetune = True
76 | )
77 |
78 | # invoking this will
79 | # (1) prompt the model with your inputs (data), inserted into the [input] tag
80 | # (2) from the sampled outputs, keep only the ones that made proper API calls
81 | # (3) execute the API calls with the `tool` given
82 | # (4) filter with the specialized filter function (which can be used independently as shown in the next section)
83 | # (5) fine-tune on the filtered results
84 |
85 | filtered_stats = toolformer(data)
86 |
87 | # then, once you see the 'finished finetuning' message
88 |
89 | response = toolformer.sample_model_with_api_calls("How many days until the next new years?")
90 |
91 | # hopefully you see it invoke the calendar and utilize the response of the api call...
92 |
93 | ```
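For debugging, `Toolformer.forward` also accepts early-exit flags, so you can inspect each intermediate stage before committing to fine-tuning. A minimal sketch reusing the `toolformer` and `data` from above (each call re-samples from scratch, and stages (3) and (4) assume at least one sampled output contained a proper API call):

```python
# (1) the raw sampled outputs with attempted API calls
data_with_api_calls = toolformer(data, return_after_generating_api_calls = True)

# (2) only the outputs that contained a proper API call, trimmed to the first call
filtered_data, filtered_data_with_api_calls = toolformer(data, return_after_filtering_api_calls = True)

# (3) the same, plus the texts with the executed API responses spliced in
filtered_data, filtered_data_with_api_calls, data_with_responses = toolformer(data, return_after_making_api_calls = True)

# (4) the scored FilteredResults, without fine-tuning
filtered_results = toolformer(data, return_after_filtering_by_api_response = True)
```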
94 |
95 | The main novelty of the paper is defining a fitness score for the outputs of a transformer instructed to insert API calls. The score is used to filter the sampled outputs when finetuning the transformer to make API calls that decrease the perplexity of the text that follows them. A sampled call survives only if the weighted loss on the following text, with the API response included, is lower by at least the filter threshold than both the loss without any API call and the loss with the call but no response.
96 |
97 | ```python
98 | import torch
99 |
100 | from toolformer_pytorch import (
101 | Toolformer,
102 | PaLM,
103 | filter_tokens_with_api_response
104 | )
105 |
106 | # model
107 |
108 | palm = PaLM(
109 | dim = 512,
110 | num_tokens = 20000,
111 | depth = 2,
112 | heads = 8,
113 | dim_head = 64
114 | ).cuda()
115 |
116 | # mock some tokens
117 |
118 | mock_start_pos = 512
119 | mock_api_call_length = 10
120 | mock_api_start_id = 19998
121 | mock_api_stop_id = 19999
122 |
123 | tokens = torch.randint(0, 20000, (10, 1024)).cuda()
124 | tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
125 | tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
126 |
127 | tokens_with_api_response[:, mock_start_pos] = mock_api_start_id
128 | tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id
129 |
130 | tokens_without_api_response[:, mock_start_pos] = mock_api_start_id
131 | tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id
132 |
133 | # filter
134 |
135 | filtered_results = filter_tokens_with_api_response(
136 | model = palm,
137 | tokens = tokens,
138 | tokens_with_api_response = tokens_with_api_response,
139 | tokens_without_api_response = tokens_without_api_response,
140 | filter_threshold = 1.,
141 | api_start_token_id = mock_api_start_id,
142 | api_end_token_id = mock_api_stop_id
143 | )
144 | ```
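The returned `filtered_results` is a `FilteredResults` namedtuple carrying the pass/fail counts along with the sequences that survived:

```python
print(filtered_results.num_passed) # number of sequences whose score cleared filter_threshold
print(filtered_results.num_failed) # number of sequences rejected

# selected_indices / selected_mask identify the survivors within the batch, while
# filtered_tokens, filtered_tokens_without_api_response, and filtered_tokens_with_api_response
# are the corresponding token tensors, already subset to the survivors
```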
145 |
146 | To invoke the tools on a string generated by the language model, use `invoke_tools`.
147 |
148 | ```python
149 | from toolformer_pytorch import invoke_tools
150 |
151 | def inc(i):
152 | return i + 1
153 |
154 | def dec(i):
155 | return i - 1
156 |
157 | function_registry = dict(
158 | inc = inc,
159 | dec = dec
160 | )
161 |
162 | text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'
163 |
164 | invoke_tools(function_registry, text)
165 |
166 | # make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]
167 | ```
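A call is left untouched whenever it cannot be resolved: the function name is missing from the registry (as with `ignored` above), a parameter fails to parse as a string, integer, or float, or the tool raises an exception or returns `None` (returning `None` lets a tool deliberately veto its own call). For instance:

```python
invoke_tools(function_registry, 'the parameter cannot be parsed: [inc(not_a_literal)]')

# the parameter cannot be parsed: [inc(not_a_literal)]
```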
168 |
169 | ## Todo
170 |
171 | - [x] create custom generate function for palm that can do external API calls
172 | - [x] allow for generating tokens at different cursor indices
173 | - [x] api token (which was left and right brackets in paper) needs to be customizable
174 | - [ ] allow for customizing how to handle errors in function name, parameters, or execution and output
175 | - [ ] Toolformer should eventually calculate all statistics (how many were properly sampled, how many were filtered out by different criteria, the distribution of scores, as well as how many were rejected) before the final fine-tuning
176 | - [ ] do end-to-end training in `Toolformer`
177 | - [x] doing the prompting and bootstrapping the data
178 | - [x] prefiltering of bootstrapped data followed by api calls and then another round of filtering
179 | - [ ] keep track of all stats
180 | - [x] take care of fine-tuning
181 | - [ ] interleaving of datasets + optimizer hyperparams
182 | - [ ] hook up gpt-j
183 | - [ ] test for a simple calculator eval dataset
184 | - [ ] add a default callback within the Toolformer that automatically aligns the text and checks for validity before the filtering step - if the text was not copied correctly, the filtering step is not valid.
185 | - [ ] make sure the final model, trained on many `Toolformer` instances, can be invoked with multiple tools - start with a batch size of 1 and work up
186 |
187 | ## Citations
188 |
189 | ```bibtex
190 | @inproceedings{Schick2023ToolformerLM,
191 | title = {Toolformer: Language Models Can Teach Themselves to Use Tools},
192 | author = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
193 | year = {2023}
194 | }
195 | ```
196 |
197 | ```bibtex
198 | @article{Gao2022PALPL,
199 | title = {PAL: Program-aided Language Models},
200 | author = {Luyu Gao and Aman Madaan and Shuyan Zhou and Uri Alon and Pengfei Liu and Yiming Yang and Jamie Callan and Graham Neubig},
201 | journal = {ArXiv},
202 | year = {2022},
203 | volume = {abs/2211.10435}
204 | }
205 | ```
206 |
207 | *Reality is that which, when you stop believing it, doesn't go away.* – Philip K. Dick.
208 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'toolformer-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.30',
7 | license='MIT',
8 | description = 'Toolformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/toolformer-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'automated-tool-use'
19 | ],
20 | install_requires=[
21 | 'beartype',
22 | 'einops>=0.4',
23 | 'torch>=1.6',
24 | 'tqdm',
25 | 'x-clip>=0.14.3'
26 | ],
27 | classifiers=[
28 | 'Development Status :: 4 - Beta',
29 | 'Intended Audience :: Developers',
30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
31 | 'License :: OSI Approved :: MIT License',
32 | 'Programming Language :: Python :: 3.6',
33 | ],
34 | )
35 |
--------------------------------------------------------------------------------
/toolformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/toolformer-pytorch/27e633217f11bb56a277436b584f2347442869c9/toolformer.png
--------------------------------------------------------------------------------
/toolformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from toolformer_pytorch.palm import PaLM
2 |
3 | from toolformer_pytorch.toolformer_pytorch import (
4 | Toolformer,
5 | filter_tokens_with_api_response,
6 | sample,
7 | sample_with_api_call,
8 | has_api_calls,
9 | invoke_tools,
10 | replace_all_but_first
11 | )
12 |
--------------------------------------------------------------------------------
/toolformer_pytorch/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW, Adam
2 |
3 | def separate_weight_decayable_params(params):
4 | wd_params, no_wd_params = [], []
5 | for param in params:
6 | param_list = no_wd_params if param.ndim < 2 else wd_params
7 | param_list.append(param)
8 | return wd_params, no_wd_params
9 |
10 | def get_optimizer(
11 | params,
12 | lr = 1e-4,
13 | wd = 1e-2,
14 | betas = (0.9, 0.99),
15 | eps = 1e-8,
16 | filter_by_requires_grad = False,
17 | group_wd_params = True,
18 | **kwargs
19 | ):
20 | has_weight_decay = wd > 0
21 |
22 | if filter_by_requires_grad:
23 | params = list(filter(lambda t: t.requires_grad, params))
24 |
25 | if group_wd_params and has_weight_decay:
26 | wd_params, no_wd_params = separate_weight_decayable_params(params)
27 |
28 | params = [
29 | {'params': wd_params},
30 | {'params': no_wd_params, 'weight_decay': 0},
31 | ]
32 |
33 | adam_kwargs = dict(lr = lr, betas = betas, eps = eps)
34 |
35 | if not has_weight_decay:
36 | return Adam(params, **adam_kwargs)
37 |
38 | return AdamW(params, weight_decay = wd, **adam_kwargs)
39 |
--------------------------------------------------------------------------------
/toolformer_pytorch/palm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 | from x_clip.tokenizer import tokenizer
6 |
7 | # helpers
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 |
13 | # normalization
14 |
15 |
16 | class RMSNorm(nn.Module):
17 | def __init__(self, dim, eps = 1e-8):
18 | super().__init__()
19 | self.scale = dim ** -0.5
20 | self.eps = eps
21 | self.g = nn.Parameter(torch.ones(dim))
22 |
23 | def forward(self, x):
24 | norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
25 | return x / norm.clamp(min = self.eps) * self.g
26 |
27 |
28 | # rotary positional embedding
29 | # https://arxiv.org/abs/2104.09864
30 |
31 |
32 | class RotaryEmbedding(nn.Module):
33 | def __init__(self, dim):
34 | super().__init__()
35 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
36 | self.register_buffer("inv_freq", inv_freq)
37 |
38 | def forward(self, max_seq_len, *, device):
39 | seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
40 | freqs = einsum("i , j -> i j", seq, self.inv_freq)
41 | return torch.cat((freqs, freqs), dim=-1)
42 |
43 |
44 | def rotate_half(x):
45 | x = rearrange(x, "... (j d) -> ... j d", j=2)
46 | x1, x2 = x.unbind(dim=-2)
47 | return torch.cat((-x2, x1), dim=-1)
48 |
49 |
50 | def apply_rotary_pos_emb(pos, t):
51 | return (t * pos.cos()) + (rotate_half(t) * pos.sin())
52 |
53 |
54 | # all we need
55 |
56 |
57 | class ParallelTransformerBlock(nn.Module):
58 | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
59 | super().__init__()
60 | self.norm = RMSNorm(dim)
61 |
62 | attn_inner_dim = dim_head * heads
63 | ff_inner_dim = dim * ff_mult
64 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, ff_inner_dim)
65 |
66 | self.heads = heads
67 | self.scale = dim_head**-0.5
68 | self.rotary_emb = RotaryEmbedding(dim_head)
69 |
70 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
71 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
72 |
73 | self.ff_out = nn.Sequential(
74 | nn.GELU(),
75 | nn.Linear(ff_inner_dim, dim, bias=False)
76 | )
77 |
78 | # for caching causal mask and rotary embeddings
79 |
80 | self.register_buffer("mask", None, persistent=False)
81 | self.register_buffer("pos_emb", None, persistent=False)
82 |
83 | def get_mask(self, n, device):
84 | if self.mask is not None and self.mask.shape[-1] >= n:
85 | return self.mask[:n, :n]
86 |
87 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
88 | self.register_buffer("mask", mask, persistent=False)
89 | return mask
90 |
91 | def get_rotary_embedding(self, n, device):
92 | if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
93 | return self.pos_emb[:n]
94 |
95 | pos_emb = self.rotary_emb(n, device=device)
96 | self.register_buffer("pos_emb", pos_emb, persistent=False)
97 | return pos_emb
98 |
99 | def forward(self, x):
100 | """
101 | einstein notation
102 | b - batch
103 | h - heads
104 | n, i, j - sequence length (base sequence length, source, target)
105 | d - feature dimension
106 | """
107 |
108 | n, device, h = x.shape[1], x.device, self.heads
109 |
110 | # pre layernorm
111 |
112 | x = self.norm(x)
113 |
114 | # attention queries, keys, values, and feedforward inner
115 |
116 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
117 |
118 | # split heads
119 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper
120 | # they found no performance loss past a certain scale, and more efficient decoding obviously
121 | # https://arxiv.org/abs/1911.02150
122 |
123 | q = rearrange(q, "b n (h d) -> b h n d", h=h)
124 |
125 | # rotary embeddings
126 |
127 | positions = self.get_rotary_embedding(n, device)
128 | q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
129 |
130 | # scale
131 |
132 | q = q * self.scale
133 |
134 | # similarity
135 |
136 | sim = einsum("b h i d, b j d -> b h i j", q, k)
137 |
138 | # causal mask
139 |
140 | causal_mask = self.get_mask(n, device)
141 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
142 |
143 | # attention
144 |
145 | attn = sim.softmax(dim=-1)
146 |
147 | # aggregate values
148 |
149 | out = einsum("b h i j, b j d -> b h i d", attn, v)
150 |
151 | # merge heads
152 |
153 | out = rearrange(out, "b h n d -> b n (h d)")
154 | return self.attn_out(out) + self.ff_out(ff)
155 |
156 |
157 | # Transformer
158 |
159 |
160 | class Transformer(nn.Module):
161 | def __init__(
162 | self,
163 | dim,
164 | depth,
165 | heads,
166 | dim_head,
167 | ff_mult = 4,
168 | ):
169 | super().__init__()
170 | self.layers = nn.ModuleList([])
171 |
172 | for _ in range(depth):
173 | self.layers.append(
174 | ParallelTransformerBlock(dim, dim_head, heads, ff_mult),
175 | )
176 |
177 | def forward(self, x):
178 | for block in self.layers:
179 | x = block(x) + x
180 | return x
181 |
182 |
183 | # classes
184 |
185 | class PaLM(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | depth,
190 | num_tokens=tokenizer.vocab_size,
191 | dim_head=64,
192 | heads=8,
193 | ff_mult=4,
194 | ):
195 | super().__init__()
196 | self.emb = nn.Embedding(num_tokens, dim)
197 |
198 | self.transformer = Transformer(dim, depth, heads, dim_head, ff_mult)
199 |
200 | self.to_logits = nn.Sequential(
201 | RMSNorm(dim),
202 | nn.Linear(dim, num_tokens)
203 | )
204 |
205 | def forward(self, x):
206 | x = self.emb(x)
207 | x = self.transformer(x)
208 | return self.to_logits(x)
209 |
210 | if __name__ == "__main__":
211 | palm = PaLM(
212 | num_tokens = 20000,
213 | dim = 512,
214 | depth = 6,
215 | dim_head = 64,
216 | heads = 8,
217 | ff_mult = 4,
218 | )
219 |
220 | tokens = torch.randint(0, 20000, (1, 512))
221 | logits = palm(tokens)
222 | print(logits.shape)
223 |
--------------------------------------------------------------------------------
/toolformer_pytorch/prompts.py:
--------------------------------------------------------------------------------
1 | DEFAULT_PROMPT_INPUT_TAG = '[input]'
2 |
3 | calculator_prompt = f"""
4 | Your task is to add calls to a Calculator API to a piece of text.
5 | The calls should help you get information required to complete the text.
6 | You can call the API by writing "[Calculator(expression)]" where "expression" is the expression to be computed.
7 | Here are some examples of API calls:
8 | Input: The number in the next term is 18 + 12 x 3 = 54.
9 | Output: The number in the next term is 18 + 12 x 3 = [Calculator(18 + 12 * 3)] 54.
10 | Input: The population is 658,893 people. This is 11.4% of the national average of 5,763,868 people.
11 | Output: The population is 658,893 people. This is 11.4% of the national average of [Calculator(658,893 / 11.4%)] 5,763,868 people.
12 | Input: A total of 252 qualifying matches were played, and 723 goals were scored (an average of 2.87 per match). This is twenty goals more than the 703 goals last year.
13 | Output: A total of 252 qualifying matches were played, and 723 goals were scored (an average of [Calculator(723 / 252)] 2.87 per match). This is twenty goals more than the [Calculator(723 - 20)] 703 goals last year.
14 | Input: I went to Paris in 1994 and stayed there until 2011, so in total, it was 17 years.
15 | Output: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011 - 1994)] 17 years.
16 | Input: From this, we have 4 * 30 minutes = 120 minutes.
17 | Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.
18 | Input: {DEFAULT_PROMPT_INPUT_TAG}
19 | Output:
20 | """
21 |
22 | wikipedia_search_prompt = f"""
23 | Your task is to complete a given piece of text.
24 | You can use a Wikipedia Search API to look up information.
25 | You can do so by writing "[WikiSearch(term)]" where "term" is the search term you want to look up.
26 | Here are some examples of API calls:
27 | Input: The colors on the flag of Ghana have the following meanings: red is for the blood of martyrs, green for forests, and gold for mineral wealth.
28 | Output: The colors on the flag of Ghana have the following meanings: red is for [WikiSearch("Ghana flag red meaning")] the blood of martyrs, green for forests, and gold for mineral wealth.
29 | Input: But what are the risks during production of nanomaterials? Some nanomaterials may give rise to various kinds of lung damage.
30 | Output: But what are the risks during production of nanomaterials? [WikiSearch("nanomaterial production risks")] Some nanomaterials may give rise to various kinds of lung damage.
31 | Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity.
32 | Output: Metformin is the first-line drug for [WikiSearch("Metformin first-line drug")] patients with type 2 diabetes and obesity.
33 | Input: {DEFAULT_PROMPT_INPUT_TAG}
34 | Output:
35 | """
36 |
37 | machine_translation_prompt = f"""
38 | Your task is to complete a given piece of text by using a Machine Translation API.
39 | You can do so by writing "[MT(text)]" where text is the text to be translated into English.
40 | Here are some examples:
41 | Input: He has published one book: O homem suprimido (“The Supressed Man”)
42 | Output: He has published one book: O homem suprimido [MT(O homem suprimido)] (“The Supressed Man”)
43 | Input: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann, there is a description of a Jewish writer
44 | Output: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann [MT(der klassische jüdische Mann)], there is a description of a Jewish writer
45 | Input: 南 京 高 淳 县 住 房 和 城 乡 建 设 局 城 市 新 区 设 计 a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
46 | Output: [MT(南京高淳县住房和城乡建设局 城市新 区 设 计)] a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
47 | Input: {DEFAULT_PROMPT_INPUT_TAG}
48 | Output:
49 | """
50 |
51 | calendar_prompt = f"""
52 | Your task is to add calls to a Calendar API to a piece of text.
53 | The API calls should help you get information required to complete the text.
54 | You can call the API by writing "[Calendar()]"
55 | Here are some examples of API calls:
56 | Input: Today is the first Friday of the year.
57 | Output: Today is the first [Calendar()] Friday of the year.
58 | Input: The president of the United States is Joe Biden.
59 | Output: The president of the United States is [Calendar()] Joe Biden.
60 | Input: The current day of the week is Wednesday.
61 | Output: The current day of the week is [Calendar()] Wednesday.
62 | Input: The number of days from now until Christmas is 30.
63 | Output: The number of days from now until Christmas is [Calendar()] 30.
64 | Input: The store is never open on the weekend, so today it is closed.
65 | Output: The store is never open on the weekend, so today [Calendar()] it is closed.
66 | Input: {DEFAULT_PROMPT_INPUT_TAG}
67 | Output:
68 | """
69 |
--------------------------------------------------------------------------------
/toolformer_pytorch/toolformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 | from functools import partial, wraps
4 | from collections import namedtuple
5 |
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from torch.utils.data import Dataset, DataLoader
11 | from torch.nn.utils.rnn import pad_sequence
12 |
13 | from einops import rearrange, reduce
14 |
15 | from toolformer_pytorch.palm import PaLM
16 | from toolformer_pytorch.optimizer import get_optimizer
17 | from toolformer_pytorch.prompts import DEFAULT_PROMPT_INPUT_TAG
18 |
19 | from beartype import beartype
20 | from beartype.typing import Callable, Optional, Union, List, Tuple
21 |
22 | from tqdm import tqdm
23 | from x_clip.tokenizer import tokenizer
24 |
25 | pad_sequence = partial(pad_sequence, batch_first = True)
26 |
27 | # helpers
28 |
29 | def exists(val):
30 | return val is not None
31 |
32 | def default(val, d):
33 | return val if exists(val) else d
34 |
35 | def identity(t):
36 | return t
37 |
38 | def always(val):
39 | def inner(*args, **kwargs):
40 | return val
41 | return inner
42 |
43 | def try_except(fn, callback = identity):
44 | @wraps(fn)
45 | def inner(*args):
46 | try:
47 | return fn(*args)
48 | except Exception as e:
49 | return callback(e)
50 | return inner
51 |
52 | # tensor helpers
53 |
54 | def log(t, eps = 1e-20):
55 | return t.clamp(min = eps).log()
56 |
57 | def gumbel_noise(t):
58 | noise = torch.zeros_like(t).uniform_(0, 1)
59 | return -log(-log(noise))
60 |
61 | def gumbel_sample(t, temperature = 1., dim = -1, eps = 1e-10):
62 | if temperature == 0:
63 | return t.argmax(dim = dim)
64 |
65 | return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim)
66 |
67 | def top_k(logits, thres = 0.9):
68 | k = math.ceil((1 - thres) * logits.shape[-1])
69 | val, indices = torch.topk(logits, k)
70 | probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
71 | probs.scatter_(1, indices, val)
72 | return probs
73 |
74 | def all_contains_id(t: torch.Tensor, token_id: int):
75 | mask = t == token_id
76 | return mask.any(dim = -1).all()
77 |
78 | def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1):
79 | assert occurrence > 0
80 | mask = (t == token_id)
81 |
82 | has_occurred = mask.cumsum(dim = -1)
83 | has_occurred = F.pad(has_occurred, (1, 0), value = 0.)
84 |
85 | return (has_occurred < occurrence).sum(dim = -1).long()
86 |
87 | # invoking api call functions
88 |
89 | def is_valid_string(s):
90 | return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s))
91 |
92 | def is_valid_integer(s):
93 | return exists(re.fullmatch(r"[+-]?\d+", s))
94 |
95 | def is_valid_float(s):
96 | return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s))
97 |
98 | def parse_param(s: str) -> Optional[Union[int, float, str]]:
99 | if is_valid_string(s):
100 | return str(s)
101 | elif is_valid_integer(s):
102 | return int(s)
103 | elif is_valid_float(s):
104 | return float(s)
105 |
106 | return None
107 |
108 | @beartype
109 | def replace_fn(
110 | registry: dict[str, Callable],
111 | matches,
112 | delimiter = '→'
113 | ):
114 | orig_text = matches.group(0)
115 |
116 | text_without_end_api_token = matches.group(1)
117 | end_api_token = matches.group(4)
118 | function_name = matches.group(2)
119 |
120 | # unable to find function in registry
121 |
122 | if function_name not in registry:
123 | return orig_text
124 |
125 | fn = registry[function_name]
126 |
127 | params = matches.group(3).split(',')
128 | params = list(map(lambda s: s.strip(), params))
129 | params = list(filter(len, params))
130 | params = list(map(parse_param, params))
131 |
132 | # if any of the parameters are not parseable, return
133 |
134 | if any([(not exists(p)) for p in params]):
135 | return orig_text
136 |
137 | # just return original text if there is some error with the function
138 |
139 | out = try_except(fn, always(None))(*params)
140 |
141 | # the api calling function can also arrest the process, by returning None
142 |
143 | if not exists(out):
144 | return orig_text
145 |
146 | # return original text with the output delimiter and the stringified output
147 |
148 | return f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}'
149 |
150 | # main function, which takes a registry of functions and the text in question, makes all the appropriate api calls, and appends the outputs
151 |
152 | def create_function_regex(
153 | api_start = ' [',
154 | api_stop = ']'
155 | ):
156 | api_start_regex, api_stop_regex = map(re.escape, (api_start, api_stop))
157 | return rf'({api_start_regex}(\w+)\(([^)]*)\))({api_stop_regex})'
158 |
159 | def num_matches(substr: str, text: str):
160 | return len(re.findall(re.escape(substr), text))
161 |
162 | def has_api_calls(
163 | text,
164 | api_start = ' [',
165 | api_stop = ']'
166 | ):
167 | regex = create_function_regex(api_start, api_stop)
168 | matches = re.findall(regex, text)
169 | return len(matches) > 0
170 |
171 | def replace_all_but_first(
172 | text: str,
173 | api_start = ' [',
174 | api_stop = ']'
175 | ) -> str:
176 | regex = create_function_regex(api_start, api_stop)
177 |
178 | count = 0
179 |
180 | def replace_(matches):
181 | orig_text = matches.group(0)
182 | nonlocal count
183 | count += 1
184 | if count > 1:
185 | return ''
186 | return orig_text
187 |
188 | return re.sub(regex, replace_, text)
189 |
190 | def invoke_tools(
191 | registry: dict[str, Callable],
192 | text: str,
193 | delimiter: str = '→',
194 | api_start = ' [',
195 | api_stop = ']'
196 | ) -> str:
197 | regex = create_function_regex(api_start, api_stop)
198 | replace_ = partial(replace_fn, registry, delimiter = delimiter)
199 | return re.sub(regex, replace_, text)
200 |
201 | def invoke_tools_on_batch_sequences(
202 | registry: dict[str, Callable],
203 | token_ids: torch.Tensor,
204 | *,
205 | encode: Callable,
206 | decode: Callable,
207 | delimiter: str = '→',
208 | api_start = ' [',
209 | api_stop = ']'
210 | ) -> torch.Tensor:
211 | regex = create_function_regex(api_start, api_stop)
212 | all_texts = [decode(one_seq_token_ids) for one_seq_token_ids in token_ids]
213 |
214 | invoke_tools_ = partial(invoke_tools, api_start = api_start, api_stop = api_stop)
215 | all_texts_with_api_calls = [invoke_tools_(registry, text, delimiter) for text in all_texts]
216 |
217 | return encode(all_texts_with_api_calls)
218 |
219 | # sampling api related functions
220 | # they do greedy sampling, but encourage sampling api calls by auto-selecting when that token is in the top k = 10
221 |
222 | @beartype
223 | @torch.no_grad()
224 | def sample(
225 | model: nn.Module,
226 | *,
227 | seq_len,
228 | prime: Optional[torch.Tensor] = None,
229 | positions: Optional[torch.Tensor] = None,
230 | batch_size = 1,
231 | eos_token_id = None,
232 | sos_token_id = 1,
233 | temperature = 0.,
234 | pad_id = 0,
235 | call_api_only_once = False,
236 | api_start_token_id = None,
237 | auto_select_api_start_token_when_topk = False,
238 | select_api_start_id_top_k = 10,
239 | ):
240 | device = next(model.parameters()).device
241 | max_seq_len = seq_len + 1
242 |
243 | # validate
244 |
245 | if call_api_only_once:
246 | assert exists(api_start_token_id)
247 |
248 | # prime
249 |
250 | if exists(prime):
251 | batch_size, prime_length = prime.shape
252 | else:
253 | prime_length = 1
254 | prime = torch.full((batch_size, 1), sos_token_id, device = device, dtype = torch.long)
255 |
256 | prime = prime.to(device)
257 |
258 | # sampling positions - different sequences have different cursors
259 |
260 | if exists(positions):
261 | positions = positions.clone()
262 | else:
263 | positions = torch.zeros((batch_size,), device = device, dtype = torch.long)
264 |
265 | assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), 'all positions must be less than the initial prime length as well as the total sequence length + 1 (plus one for a noop if one sequence finished sampling before the other)'
266 |
267 | # eval model
268 |
269 | model.eval()
270 |
271 | # lengthen the prime to the entire sequence length
272 |
273 | remain_iterations = seq_len - prime_length
274 | output = F.pad(prime, (0, max_seq_len - prime_length), value = 0.)
275 |
276 | batch_indices = torch.arange(batch_size, device = device)
277 | batch_indices = rearrange(batch_indices, 'b -> b 1')
278 | position_indices = rearrange(positions, 'b -> b 1')
279 |
280 | # determine the token mask, for making sure the api is called only once, masking out the logit to prevent it from being selected for those rows which already contain an api start token
281 |
282 | api_token_mask = None # lazily created, since do not know logit dimensions
283 |
284 | def create_api_token_mask(num_tokens, api_start_token_id):
285 | mask = torch.zeros((1, 1, num_tokens), dtype = torch.bool)
286 | assert api_start_token_id < num_tokens
287 | mask[..., api_start_token_id] = True
288 | return mask
289 |
290 | # start iterating
291 |
292 | for iteration in tqdm(range(remain_iterations)):
293 | logits = model(output)
294 | last_logits = logits[batch_indices, position_indices]
295 |
296 | # this will ensure that each sequence in the batch will have at most one api start token
297 |
298 | if call_api_only_once:
299 | if not exists(api_token_mask):
300 | num_tokens = last_logits.shape[-1]
301 | api_token_mask = create_api_token_mask(num_tokens, api_start_token_id)
302 | api_token_mask = api_token_mask.to(device)
303 |
304 | api_called = (output == api_start_token_id).any(dim = -1)
305 |
306 | logit_mask = api_token_mask & rearrange(api_called, 'b -> b 1 1')
307 | last_logits = last_logits.masked_fill(logit_mask, -torch.finfo(last_logits.dtype).max)
308 |
309 | # greedy sample (but could be made non-greedy)
310 |
311 | sampled = gumbel_sample(last_logits, temperature = temperature)
312 |
313 | # for those sequences without an api call, if the api_start_token_id is within top k (set to 10 in paper) of logits, just auto-select
314 |
315 | # seems to be an important hack in the paper
316 | # it seems like this paper will take a lot more follow up research to be viable
317 |
318 | if auto_select_api_start_token_when_topk:
319 | api_called = (output == api_start_token_id).any(dim = -1) # recompute here, since it is otherwise only defined when call_api_only_once is set
320 | top_token_ids = last_logits.topk(select_api_start_id_top_k, dim = -1).indices
321 | has_api_token_in_topk = (top_token_ids == api_start_token_id).any(dim = -1)
322 | should_auto_select_api_token = has_api_token_in_topk & ~rearrange(api_called, 'b -> b 1')
323 | sampled = sampled.masked_fill(should_auto_select_api_token, api_start_token_id)
324 |
325 | # set the sampled tokens at the right cursor positions
326 |
327 | output[batch_indices, position_indices] = sampled
328 |
329 | # increment positions
330 |
331 | position_indices += 1
332 | position_indices.clamp_(max = seq_len) # noop if one sequence is further along and near the end
333 |
334 | # if using an eos token, terminate once all sequences contain it, and pad out everything after it
335 |
336 | if exists(eos_token_id):
337 | eos_mask = (output == eos_token_id)
338 | all_rows_have_eos = eos_mask.any(dim = -1).all()
339 |
340 | if all_rows_have_eos:
341 | keep_mask = eos_mask.cumsum(dim = -1) == 0
342 | keep_mask = F.pad(keep_mask, (1, 0), value = True)
343 | output = output.masked_fill(~keep_mask, pad_id)
344 | break
345 |
346 | # remove the last token in output (use as noop placeholder)
347 |
348 | output = output[:, :-1]
349 | return output
350 |
351 | @beartype
352 | @torch.no_grad()
353 | def sample_with_api_call(
354 | model: nn.Module,
355 | *,
356 | seq_len,
357 | call_apis: Callable,
358 | prime: torch.Tensor,
359 | api_end_token_id: int,
360 | occurrence = 1,
361 | **kwargs
362 | ):
363 | sampled = sample(
364 | model = model,
365 | prime = prime,
366 | seq_len = seq_len,
367 | **kwargs
368 | )
369 |
370 | sampled = call_apis(sampled)
371 |
372 | sampled_seq_len = sampled.shape[-1]
373 | null_positions = sampled_seq_len # handle sequences that do not have api calls
374 |
375 | pos_starting_at_end_of_api = find_indices_of(
376 | sampled,
377 | api_end_token_id,
378 | occurrence = occurrence
379 | )
380 |
381 | resample_after_api_calls = sample(
382 | model = model,
383 | prime = sampled,
384 | seq_len = sampled_seq_len,
385 | positions = (pos_starting_at_end_of_api + 1).clamp(max = null_positions), # start at the position right after the api end token
386 | **kwargs
387 | )
388 |
389 | return resample_after_api_calls
390 |
391 | # the main contribution of the paper is simply the filtering equations presented in section 2
392 |
393 | def default_weight_fn(t):
394 | # following the formula in section 4.1 - however, not sure what w_s is in the denominator
395 | # if t stands for each timestep, this would also mean within 5 tokens it would diminish to 0?
396 | return (1. - t * 0.2).clamp(min = 0.)
397 |
398 | def get_pred_prob(token_ids, logits):
399 | logits = logits[:, :-1] # logits of each token... (omit last logit)
400 | token_ids = token_ids[:, 1:] # predicts the next token id (omit first token id)
401 |
402 | token_ids = rearrange(token_ids, 'b n -> b n 1')
403 | probs = logits.softmax(dim = -1)
404 | correct_token_id_pred_prob = probs.gather(-1, token_ids)
405 | return rearrange(correct_token_id_pred_prob, 'b n 1 -> b n')
406 |
407 | def get_arange_start_at_token_id(
408 | token_ids: torch.Tensor,
409 | token_id: int,
410 | pad_id = -1
411 | ):
412 | is_token_id_mask = token_ids == token_id
413 | arange = (is_token_id_mask.cumsum(dim = -1) > 0).cumsum(dim = -1)
414 | before_token_mask = arange == 0
415 | arange = arange - 1
416 | arange = arange.masked_fill(before_token_mask, pad_id)
417 | return arange
418 |
419 | def weight_and_mask(
420 | token_ids: torch.Tensor,
421 | token_id: int,
422 | pad_id = -1,
423 | weighting_fn: Callable = default_weight_fn
424 | ):
425 | t = get_arange_start_at_token_id(token_ids, token_id, pad_id)
426 | weights = weighting_fn(t)
427 | return weights.masked_fill(t == pad_id, 0.)
428 |
429 | FilteredResults = namedtuple('FilteredResults', [
430 | 'num_passed',
431 | 'num_failed',
432 | 'selected_indices',
433 | 'selected_mask',
434 | 'filtered_tokens',
435 | 'filtered_tokens_without_api_response',
436 | 'filtered_tokens_with_api_response'
437 | ])
438 |
439 | @beartype
440 | def filter_tokens_with_api_response(
441 | model: nn.Module, # the language model should accept the token ids below and return the logits in shape (batch, seq, num tokens)
442 | *,
443 | tokens: torch.Tensor, # token ids (batch, seq) of the original passage, without api calls
444 | tokens_without_api_response: torch.Tensor, # token ids (batch, seq) of the passage, but with the api call (but without a response filled in) - tool1(x, y)
445 | tokens_with_api_response: torch.Tensor, # token ids (batch, seq) of the passage with api call and the response - tool1(x, y) → {response}
446 | api_start_token_id: int, # token id of the tag
447 | api_end_token_id: int, # token id of the tag
448 | filter_threshold: float = 1., # the threshold at which to accept the sampled api call (tokens_with_api_response) for fine-tuning
449 | weighting_fn: Callable = default_weight_fn # weighting function
450 | ) -> FilteredResults:
451 |
452 | # validations
453 |
454 | assert all([*map(lambda t: t.dtype == torch.long, (tokens, tokens_with_api_response, tokens_without_api_response))])
455 |
456 | assert all_contains_id(tokens_without_api_response, api_start_token_id)
457 | assert all_contains_id(tokens_without_api_response, api_end_token_id)
458 |
459 | assert all_contains_id(tokens_with_api_response, api_start_token_id)
460 | assert all_contains_id(tokens_with_api_response, api_end_token_id)
461 |
462 | # auto set devices
463 |
464 | device = next(model.parameters()).device
465 | tokens, tokens_without_api_response, tokens_with_api_response = map(lambda t: t.to(device), (tokens, tokens_without_api_response, tokens_with_api_response))
466 |
467 | # get all the logits
468 |
469 | with torch.no_grad():
470 | model.eval()
471 | logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_without_api_response, tokens_with_api_response))
472 |
473 | # derive all predicted prob of the actual next token id in sequence
474 |
475 | probs = get_pred_prob(tokens, logits)
476 | probs_without_api_response = get_pred_prob(tokens_without_api_response, logits_without_api_response)
477 | probs_with_api_response = get_pred_prob(tokens_with_api_response, logits_with_api_response)
478 |
479 | weight_and_mask_fn = partial(weight_and_mask, weighting_fn = weighting_fn)
480 |
481 | # derive the weighting
482 |
483 | weight_without_api_response = weight_and_mask_fn(tokens_without_api_response[:, :-1], api_end_token_id)
484 | weight_with_api_response = weight_and_mask_fn(tokens_with_api_response[:, :-1], api_end_token_id)
485 |
486 | # deriving the weighting for the original passage is more tricky
487 | # would need to start counting up from start token location
488 | # this would also assume that the language model perfectly copied the passage over and that both token ids are aligned except for the inserted API call - but this can be done with the custom filtering functions eventually
489 |
490 | weight = weight_and_mask_fn(tokens_without_api_response[:, 1:], api_start_token_id) # shift to the left by one since the api start token does not exist in the original sequence
491 | weight = weight[:, :probs.shape[-1]]
492 |
493 | # get the loss L for all three types of sequences
494 |
495 | def loss_fn(weight, probs):
496 | return (weight * -log(probs)).sum(dim = -1)
497 |
498 | loss = loss_fn(weight, probs)
499 | loss_without_api_response = loss_fn(weight_without_api_response, probs_without_api_response)
500 | loss_with_api_response = loss_fn(weight_with_api_response, probs_with_api_response)
501 |
502 | # calculate the main formula in the paper
503 |
504 | # loss+ = loss with api response
505 | # loss- = min(loss without api response, loss without api at all)
506 |
507 | loss_plus = loss_with_api_response
508 | loss_minus = torch.minimum(loss_without_api_response, loss)
509 |
510 | selected_mask = (loss_minus - loss_plus) >= filter_threshold
511 |
512 | # now we can select and return the entries that survived the filtering stage
513 | # also returning the selected indices of the batch being processed
514 | # for finetuning the model into toolformer
515 |
516 | batch = tokens.shape[0]
517 | indices = torch.arange(batch, device = tokens.device)
518 |
519 | selected_indices = indices[selected_mask]
520 |
521 | ret = FilteredResults(
522 | selected_mask.sum().item(),
523 | (~selected_mask).sum().item(),
524 | selected_indices,
525 | selected_mask,
526 | tokens[selected_mask],
527 | tokens_without_api_response[selected_mask],
528 | tokens_with_api_response[selected_mask]
529 | )
530 |
531 | return ret
532 |
533 | # datasets and dataloaders
534 |
535 | # for bootstrapping the initial datasets with api calls
536 | # as well as for the final finetuning
537 |
538 | @beartype
539 | class PromptDataset(Dataset):
540 | def __init__(
541 | self,
542 | prompt: str,
543 | prompt_input_tag: str,
544 | data: List[str],
545 | tokenizer_encode: Callable
546 | ):
547 | self.data = data
548 | self.prompt = prompt
549 | self.prompt_input_tag_regex = re.escape(prompt_input_tag)
550 | self.tokenizer_encode = tokenizer_encode
551 |
552 | def __len__(self):
553 | return len(self.data)
554 |
555 | def __getitem__(self, idx):
556 | data_string = self.data[idx]
557 | data_with_prompt = re.sub(self.prompt_input_tag_regex, data_string, self.prompt)
558 | token_ids = self.tokenizer_encode(data_with_prompt)
559 | return torch.tensor(token_ids).long(), torch.tensor(len(token_ids)).long()
560 |
561 | def prompt_collate_fn(data, padding_value = 0):
562 | prompts, prompt_lengths = zip(*data)
563 | prompts = pad_sequence(prompts, padding_value = padding_value)
564 | return prompts, torch.stack(prompt_lengths)
565 |
566 | def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
567 | collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
568 | return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)
569 |
570 | class FinetuneDataset(Dataset):
571 | def __init__(
572 | self,
573 | tokens: torch.Tensor
574 | ):
575 | self.tokens = tokens
576 |
577 | def __len__(self):
578 | return len(self.tokens)
579 |
580 | def __getitem__(self, idx):
581 | return self.tokens[idx]
582 |
583 | def FinetuneDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
584 | return DataLoader(ds, *args, collate_fn = partial(pad_sequence, padding_value = padding_value), **kwargs)
585 |
586 | # classes
587 |
588 | @beartype
589 | class Toolformer(nn.Module):
590 | def __init__(
591 | self,
592 | model: nn.Module,
593 | *,
594 | tool_id: str,
595 | tool: Callable,
596 | api_start_str = ' [',
597 | api_stop_str = ']',
598 | api_response_delimiter = '→',
599 | api_start_id = None,
600 | api_stop_id = None,
601 | teach_tool_prompt: str,
602 | filter_threshold = 1.,
603 | pad_id = 0,
604 | prompt_batch_size = 4,
605 | model_seq_len = 2048,
606 | tokenizer_encode: Callable = tokenizer.encode,
607 | tokenizer_decode: Callable = tokenizer.decode,
608 | post_prompt_callback: Callable = identity,
609 | prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
610 | exclude_filters: dict[str, Callable[[str], bool]] = dict(),
611 | finetune = False,
612 | finetune_lr = 1e-4,
613 | finetune_wd = 1e-2,
614 | finetune_betas = (0.9, 0.99),
615 | finetune_eps = 1e-8,
616 | finetune_epochs = 3,
617 | finetune_batch_size = 16
618 | ):
619 | super().__init__()
620 | self.model = model
621 | self.model_seq_len = model_seq_len
622 |
623 | self.teach_tool_prompt = teach_tool_prompt
624 | self.prompt_batch_size = prompt_batch_size
625 | self.prompt_input_tag = prompt_input_tag
626 |
627 | self.post_prompt_callback = post_prompt_callback # for easy mocking
628 |
629 | self.tokenizer_encode = tokenizer_encode
630 | self.tokenizer_decode = tokenizer_decode
631 | self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()
632 |
633 | self.filter_threshold = filter_threshold
634 |
635 | self.api_start_str = api_start_str
636 | self.api_stop_str = api_stop_str
637 | self.api_response_delimiter = api_response_delimiter
638 |
639 | if not exists(api_start_id):
640 | api_start_id = tokenizer_encode(api_start_str)
641 | assert len(api_start_id) == 1
642 | api_start_id = api_start_id[0]
643 |
644 | self.api_start_id = api_start_id
645 |
646 | if not exists(api_stop_id):
647 | api_stop_id = tokenizer_encode(api_stop_str)
648 | assert len(api_stop_id) == 1
649 | api_stop_id = api_stop_id[0]
650 |
651 | self.api_stop_id = api_stop_id
652 |
653 | self.pad_id = pad_id
654 |
655 | self.tool_id = tool_id
656 | self.tool = tool
657 | self.registry = {tool_id: tool}
658 |
659 | assert num_matches(prompt_input_tag, teach_tool_prompt) == 1, f'there must be exactly one prompt input tag `{prompt_input_tag}` in your prompt to encourage the language model to use the designated tool'
660 |
661 | self.teach_tool_prompt = teach_tool_prompt
662 | self.exclude_filters = exclude_filters
663 |
664 | self.should_finetune = finetune
665 |
666 | if not finetune:
667 | return
668 |
669 | self.finetune_batch_size = finetune_batch_size
670 | self.finetune_epochs = finetune_epochs
671 |
672 | self.optimizer = get_optimizer(
673 | model.parameters(),
674 | lr = finetune_lr,
675 | wd = finetune_wd,
676 | betas = finetune_betas,
677 | eps = finetune_eps
678 | )
679 |
680 | def generate_data_with_api_calls(
681 | self,
682 | data: List[str],
683 | temperature: float = 0.9
684 | ) -> List[str]:
685 |
686 | dataset = PromptDataset(
687 | data = data,
688 | prompt_input_tag = self.prompt_input_tag,
689 | prompt = self.teach_tool_prompt,
690 | tokenizer_encode = self.tokenizer_encode
691 | )
692 |
693 | dl = PromptDataloader(
694 | dataset,
695 | batch_size = self.prompt_batch_size
696 | )
697 |
698 | prompted_outputs = []
699 |
700 | for prime, positions in dl:
701 |
702 | sampled_outputs = sample(
703 | model = self.model,
704 | prime = prime,
705 | positions = positions,
706 | seq_len = self.model_seq_len,
707 | pad_id = self.pad_id,
708 | temperature = temperature
709 | )
710 |
711 | for sample_output, position in zip(sampled_outputs, positions):
712 | start_position = position.item()
713 |
714 | prompted_output = self.tokenizer_decode(sample_output[start_position:])
715 | prompted_outputs.append(prompted_output)
716 |
717 | return self.post_prompt_callback(prompted_outputs)
718 |
719 | def filter_and_keep_only_first_api_call(
720 | self,
721 | data,
722 | data_with_api_calls: List[str],
723 | return_excluded = False
724 | ):
725 | included_data = []
726 | included_data_with_api_calls = []
727 |
728 | included = (included_data, included_data_with_api_calls)
729 |
730 | excluded_data = []
731 | excluded_data_with_api_calls = []
732 |
733 | excluded = (excluded_data, excluded_data_with_api_calls)
734 |
735 | api_start_stop_kwargs = dict(api_start = self.api_start_str, api_stop = self.api_stop_str)
736 |
737 | has_api_calls_ = partial(has_api_calls, **api_start_stop_kwargs)
738 | replace_all_but_first_ = partial(replace_all_but_first, **api_start_stop_kwargs)
739 |
740 | for datum, data_with_api_call in zip(data, data_with_api_calls):
741 | if has_api_calls_(data_with_api_call):
742 | data_with_api_call = replace_all_but_first_(data_with_api_call)
743 |
744 | included_data.append(datum)
745 | included_data_with_api_calls.append(data_with_api_call)
746 | else:
747 | excluded_data.append(datum)
748 | excluded_data_with_api_calls.append(data_with_api_call)
749 |
750 | if not return_excluded:
751 | return included
752 |
753 | return included, excluded
754 |
755 | @torch.no_grad()
756 | def sample_model_with_api_calls(
757 | self,
758 | prime: Union[torch.Tensor, str],
759 | occurrence = 1,
760 | **kwargs
761 | ):
762 | self.model.eval()
763 |
764 | prime_is_str = isinstance(prime, str)
765 |
766 | if prime_is_str:
767 | prime = self.tokenizer_encode(prime)
768 | prime = torch.tensor(prime).long()
769 | prime = rearrange(prime, 'n -> 1 n')
770 |
771 | assert prime.shape[0] == 1, 'only one at a time for now'
772 |
773 | invoke_tools_ = partial(invoke_tools, self.registry)
774 |
775 | def call_apis(t: torch.Tensor):
776 | t = self.tokenizer_decode(t[0])
777 | t = invoke_tools_(t)
778 | t = self.tokenizer_encode_to_tensor(t)
779 | return rearrange(t, 'n -> 1 n')
780 |
781 | output = sample_with_api_call(
782 | model = self.model,
783 | prime = prime,
784 | seq_len = self.model_seq_len,
785 | call_apis = call_apis,
786 | api_end_token_id = self.api_stop_id,
787 | occurrence = occurrence,
788 | **kwargs
789 | )
790 |
791 | if not prime_is_str:
792 | return output
793 |
794 | return self.tokenizer_decode(output[0])
795 |
796 | def make_api_calls(
797 | self,
798 | filtered_data_with_api_calls: List[str]
799 | ):
800 | invoke_tools_ = partial(
801 | invoke_tools,
802 | self.registry,
803 | api_start = self.api_start_str,
804 | api_stop = self.api_stop_str, delimiter = self.api_response_delimiter
805 | )
806 |
807 | data_with_api_responses = []
808 | for data in filtered_data_with_api_calls:
809 | output = invoke_tools_(data)
810 | data_with_api_responses.append(output)
811 |
812 | return data_with_api_responses
813 |
814 | def filter_by_api_responses(
815 | self,
816 | data: List[str],
817 | data_with_api_calls: List[str],
818 | data_with_api_responses: List[str]
819 | ) -> FilteredResults:
820 |
821 | to_token_ids = lambda l: pad_sequence([*map(self.tokenizer_encode_to_tensor, l)], padding_value = self.pad_id)
822 |
823 | tokens, tokens_without_api_response, tokens_with_api_response = map(to_token_ids, (data, data_with_api_calls, data_with_api_responses))
824 |
825 | filtered_results = filter_tokens_with_api_response(
826 | model = self.model,
827 | tokens = tokens,
828 | tokens_with_api_response = tokens_with_api_response,
829 | tokens_without_api_response = tokens_without_api_response,
830 | filter_threshold = self.filter_threshold,
831 | api_start_token_id = self.api_start_id,
832 | api_end_token_id = self.api_stop_id
833 | )
834 |
835 | return filtered_results
836 |
837 | def finetune(
838 | self,
839 | filtered_results: Union[FilteredResults, torch.Tensor]
840 | ):
841 | self.model.train()
842 |
843 | if isinstance(filtered_results, FilteredResults):
844 | filtered_results = filtered_results.filtered_tokens_without_api_response
845 |
846 | dataset = FinetuneDataset(tokens = filtered_results)
847 | dl = FinetuneDataloader(dataset, batch_size = self.finetune_batch_size, shuffle = True)
848 |
849 | for epoch in tqdm(range(self.finetune_epochs), desc = 'finetune epochs'):
850 | for batch in dl:
851 | inp, labels = batch[:, :-1], batch[:, 1:]
852 |
853 | logits = self.model(inp)
854 | logits = rearrange(logits, 'b n c -> b c n')
855 |
856 | loss = F.cross_entropy(logits, labels, ignore_index = self.pad_id)
857 | loss.backward()
858 |
859 | print(f'loss: {loss.item()}')
860 | self.optimizer.step()
861 | self.optimizer.zero_grad()
862 |
863 | print(f'finished finetuning on {len(dataset)} filtered samples')
864 |
865 | def forward(
866 | self,
867 | data: List[str],
868 | return_after_generating_api_calls = False,
869 | return_after_making_api_calls = False,
870 | return_after_filtering_api_calls = False,
871 | return_after_filtering_by_api_response = False
872 | ):
873 | data_with_api_calls = self.generate_data_with_api_calls(data)
874 |
875 | if return_after_generating_api_calls:
876 | return data_with_api_calls
877 |
878 | filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)
879 |
880 | if return_after_filtering_api_calls:
881 | return filtered_data, filtered_data_with_api_calls
882 |
883 | assert len(filtered_data_with_api_calls) > 0, 'the model did not generate any well-formed API calls. try a stronger base model or better prompt engineering'
884 |
885 | data_with_responses = self.make_api_calls(filtered_data_with_api_calls)
886 |
887 | if return_after_making_api_calls:
888 | return filtered_data, filtered_data_with_api_calls, data_with_responses
889 |
890 | filtered_results = self.filter_by_api_responses(filtered_data, filtered_data_with_api_calls, data_with_responses)
891 |
892 | if return_after_filtering_by_api_response:
893 | return filtered_results
894 |
895 | if self.should_finetune:
896 | assert filtered_results.num_passed > 0, f'none of the sequences with API calls passed the filtering criteria with threshold {self.filter_threshold}'
897 |
898 | self.finetune(filtered_results)
899 |
900 | return filtered_results
901 |
--------------------------------------------------------------------------------
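
For orientation, a minimal sketch of driving the pipeline that forward ties together. It assumes `toolformer` is an already-constructed Toolformer instance (the constructor is defined earlier in this file and is not restated here) and that, as is the convention in this repository, calling the instance dispatches to forward; the sample strings are illustrative only.

data = [
    'The store is never open on the weekend, so today it is closed.',
    'The population of the town is 2 + 2 = 4 thousand people.'
]

# full loop: generate API calls, execute them, keep only the calls whose
# responses reduce the language modelling loss, then finetune on the survivors
filtered_stats = toolformer(data)

# or stop early at any stage to inspect intermediate artifacts
data_with_api_calls = toolformer(data, return_after_generating_api_calls = True)
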
/toolformer_pytorch/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | try:
4 | from dotenv import load_dotenv
5 | load_dotenv()
6 |
7 | import requests
8 | import calendar
9 | import wolframalpha
10 | import datetime
11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12 | from operator import truediv, mul, add, sub
13 |
14 | # Optional imports
15 | from googleapiclient.discovery import build
16 |
17 | except ImportError:
18 | print('please run `pip install -r tools-requirements.txt` from the project directory first')
19 | exit()
20 |
21 | '''
22 | Calendar
23 |
24 | Uses Python's datetime and calendar libraries to retrieve the current date.
25 |
26 | input - None
27 |
28 | output - A string, the current date.
29 | '''
30 | def Calendar():
31 | now = datetime.datetime.now()
32 | return f'Today is {calendar.day_name[now.weekday()]}, {calendar.month_name[now.month]} {now.day}, {now.year}.'
33 |
34 |
35 | '''
36 | Wikipedia Search
37 |
38 | Uses ColBERTv2 to retrieve Wikipedia documents.
39 |
40 | input_query - A string, the input query (e.g. "what is a dog?")
41 | k - The number of documents to retrieve
42 |
43 | output - A list of strings, each string is a Wikipedia document
44 |
45 | Adapted from Stanford's DSP: https://github.com/stanfordnlp/dsp/
46 | Also see: https://github.com/lucabeetz/dsp
47 | '''
48 | class ColBERTv2:
49 | def __init__(self, url: str):
50 | self.url = url
51 |
52 | def __call__(self, query, k=10):
53 | topk = colbertv2_get_request(self.url, query, k)
54 |
55 | topk = [doc['text'] for doc in topk]
56 | return topk
57 |
58 | def colbertv2_get_request(url: str, query: str, k: int):
59 | payload = {'query': query, 'k': k}
60 | res = requests.get(url, params=payload)
61 |
62 | topk = res.json()['topk'][:k]
63 | return topk
64 |
65 | def WikiSearch(
66 | input_query: str,
67 | url: str = 'http://ec2-44-228-128-229.us-west-2.compute.amazonaws.com:8893/api/search',
68 | k: int = 10
69 | ):
70 | retrieval_model = ColBERTv2(url)
71 | output = retrieval_model(input_query, k)
72 | return output
73 |
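# usage sketch: the default url above points at a public demo server that may
# no longer be reachable; substitute your own ColBERTv2 endpoint if needed
#
#   docs = WikiSearch('What is a dog?', k = 3) # -> list of 3 passage strings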
74 | '''
75 | Machine Translation - NLLB-600M
76 |
77 | Uses HuggingFace's transformers library to translate the input query into English.
78 |
79 | input_query - A string, the input query in any NLLB-supported language (e.g. "Un chien c'est quoi?")
80 |
81 | output - A string, the translated input query.
82 | '''
83 | def MT(input_query: str, model_name: str = "facebook/nllb-200-distilled-600M"):
84 | tokenizer = AutoTokenizer.from_pretrained(model_name)
85 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
86 | input_ids = tokenizer(input_query, return_tensors='pt')
87 | outputs = model.generate(
88 | **input_ids,
89 | forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"],
90 | )
91 | output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
92 | return output
93 |
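# note: the NLLB tokenizer assumes the source language is eng_Latn unless told
# otherwise; when the source language is known, pass src_lang at load time, e.g.
#
#   tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang = 'fra_Latn')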
94 |
95 | '''
96 | Calculator
97 |
98 | Calculates the result of a mathematical expression.
99 |
100 | input_query - A string, the input query (e.g. "400/1400")
101 |
102 | output - A float, the result of the calculation, rounded to 2 decimal places (supports +, -, *, / on non-negative integer operands; no parentheses)
103 |
104 | Adapted from: https://levelup.gitconnected.com/3-ways-to-write-a-calculator-in-python-61642f2e4a9a
105 | '''
106 | def Calculator(input_query: str):
107 | operators = {
108 | '+': add,
109 | '-': sub,
110 | '*': mul,
111 | '/': truediv
112 | }
113 | if input_query.isdigit():
114 | return float(input_query)
115 | for c in operators.keys():
116 | left, operator, right = input_query.rpartition(c) # split on the last occurrence so chains like 8-2-3 evaluate left to right
117 | if operator in operators:
118 | return round(operators[operator](Calculator(left), Calculator(right)), 2)
119 |
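# worked examples (rounding is applied at every recursive step):
#
#   Calculator('400/1400') # -> 0.29
#   Calculator('8-2-3')    # -> 3.0, left-associative thanks to rpartition
#   Calculator('1+2*3')    # -> 7.0, '+' and '-' are split before '*' and '/', so the latter bind tighter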
120 |
121 | # Other Optional Tools
122 |
123 |
124 | '''
125 | Wolfram Alpha Calculator
126 |
127 | pip install wolframalpha
128 |
129 | Uses Wolfram Alpha API to calculate input query.
130 |
131 | input_query - A string, the input query (e.g. "what is 2 + 2?")
132 |
133 | output - A string, the answer to the input query
134 |
135 | wolfram_alpha_appid - your Wolfram Alpha API key, read from the WOLFRAM_ALPHA_APPID environment variable
136 | '''
137 | def WolframAlphaCalculator(input_query: str):
138 | wolfram_alpha_appid = os.environ.get('WOLFRAM_ALPHA_APPID')
139 | wolfram_client = wolframalpha.Client(wolfram_alpha_appid)
140 | res = wolfram_client.query(input_query)
141 | assumption = next(res.pods).text
142 | answer = next(res.results).text
143 | return f'Assumption: {assumption} \nAnswer: {answer}'
144 |
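# setup sketch: put the app id in a .env file at the project root, which
# python-dotenv loads at import time above, e.g.
#
#   WOLFRAM_ALPHA_APPID=your-app-id-here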
145 |
146 | '''
147 | Google Search
148 |
149 | Uses Google's Custom Search API to retrieve Google Search results.
150 |
151 | input_query - The query to search for.
152 | num_results - The number of results to return.
153 | api_key - Your Google API key, read from the GOOGLE_API_KEY environment variable.
154 | cse_id - Your Google Custom Search Engine ID, read from the GOOGLE_CSE_ID environment variable.
155 |
156 | output - A list of dictionaries, each dictionary is a Google Search result
157 | '''
158 | def custom_search(query, api_key, cse_id, **kwargs):
159 | service = build("customsearch", "v1", developerKey=api_key)
160 | res = service.cse().list(q=query, cx=cse_id, **kwargs).execute()
161 | return res['items']
162 |
163 | def google_search(input_query: str, num_results: int = 10):
164 | api_key = os.environ.get('GOOGLE_API_KEY')
165 | cse_id = os.environ.get('GOOGLE_CSE_ID')
166 |
167 | metadata_results = []
168 | results = custom_search(input_query, num=num_results, api_key=api_key, cse_id=cse_id)
169 | for result in results:
170 | metadata_result = {
171 | "snippet": result["snippet"],
172 | "title": result["title"],
173 | "link": result["link"],
174 | }
175 | metadata_results.append(metadata_result)
176 | return metadata_results
177 |
178 |
179 | '''
180 | Bing Search
181 |
182 | Uses Bing's Custom Search API to retrieve Bing Search results.
183 |
184 | input_query: The query to search for.
185 | bing_subscription_key: Your Bing API key, read from the BING_API_KEY environment variable.
186 | num_results: The number of results to return.
187 |
188 | output: A list of dictionaries, each dictionary is a Bing Search result
189 | '''
190 | def _bing_search_results(
191 | search_term: str,
192 | bing_subscription_key: str,
193 | count: int,
194 | url: str = "https://api.bing.microsoft.com/v7.0/search"
195 | ):
196 | headers = {"Ocp-Apim-Subscription-Key": bing_subscription_key}
197 | params = {
198 | "q": search_term,
199 | "count": count,
200 | "textDecorations": True,
201 | "textFormat": "HTML",
202 | }
203 | response = requests.get(
204 | url, headers=headers, params=params
205 | )
206 | response.raise_for_status()
207 | search_results = response.json()
208 | return search_results["webPages"]["value"]
209 |
210 | def bing_search(
211 | input_query: str,
212 | num_results: int = 10
213 | ):
214 | bing_subscription_key = os.environ.get("BING_API_KEY")
215 | metadata_results = []
216 | results = _bing_search_results(input_query, bing_subscription_key, count=num_results)
217 | for result in results:
218 | metadata_result = {
219 | "snippet": result["snippet"],
220 | "title": result["name"],
221 | "link": result["url"],
222 | }
223 | metadata_results.append(metadata_result)
224 | return metadata_results
225 |
226 |
227 | if __name__ == '__main__':
228 |
229 | print(Calendar()) # Outputs a string, the current date
230 |
231 | print(Calculator('400/1400')) # Outputs 0.29
232 |
233 | print(WikiSearch('What is a dog?')) # Outputs a list of strings, each string is a Wikipedia document
234 |
235 | print(MT("Un chien c'est quoi?")) # What is a dog?
236 |
237 | # Optional Tools
238 |
239 | print(WolframAlphaCalculator('What is 2 + 2?')) # 4
240 |
241 | print(google_search('What is a dog?'))
242 | # Outputs a list of dictionaries, each dictionary is a Google Search result
243 |
244 | print(bing_search('What is a dog?'))
245 | # Outputs a list of dictionaries, each dictionary is a Bing Search result
246 |
--------------------------------------------------------------------------------
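
The tools above are plain callables, so handing them to the Toolformer wrapper is just a matter of building the tool registry. The dict-of-callables shape below is an assumption inferred from the invoke_tools(self.registry, ...) calls in toolformer_pytorch.py; consult the Toolformer constructor defined earlier in that file for the canonical format.

from toolformer_pytorch.tools import Calendar, Calculator, WikiSearch

tool_registry = dict(
    Calendar = Calendar,
    Calculator = Calculator,
    WikiSearch = WikiSearch
)
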
/tools-requirements.txt:
--------------------------------------------------------------------------------
1 | google-api-python-client
2 | python-dotenv
3 | requests
4 | torch # needed by the transformers-based MT tool
5 | transformers
6 | wolframalpha
--------------------------------------------------------------------------------