├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── mask.png
├── reasoning-tokens.png
├── self_reasoning_tokens_pytorch
│   ├── __init__.py
│   ├── attention_with_stop_graddable_qkv.py
│   └── self_reasoning_tokens.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![reasoning tokens](./reasoning-tokens.png)
2 |
3 | ## Self Reasoning Tokens - Pytorch (wip)
4 |
5 | Exploration into the proposed Self Reasoning Tokens by Felipe Bonetto. The blog post is a bit light on details, but the idea of stop gradients from the next token(s) is an interesting one.
6 |
7 | My initial thought was to apply a stop-gradient mask to the attention matrix, but I then realized that the values of the "reasoning" tokens could not be stop-gradiented correctly without memory issues.
8 |
9 | While walking the dog and meditating on this, I came to the realization that one can create independent stop-gradient masks for queries, keys, and values in either flash attention or a custom attention backward pass, and there may be a whole array of possibilities there. If any experiments come back positive from this exploration, I will build out a concrete implementation.
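   |
   | For intuition, the most direct way to stop gradients through chosen entries of the attention matrix is to mix detached and non-detached entries elementwise. Below is a minimal sketch of that naive approach (illustrative shapes, not the implementation in this repo). Note that it only gates the gradients flowing into the queries and keys through the attention scores; gradients into the values are untouched, which is why the custom backward here accepts separate stop-gradient masks for queries, keys, and values.
   |
   | ```python
   | import torch
   |
   | q = torch.randn(1, 8, 128, 64, requires_grad = True)
   | k = torch.randn(1, 8, 128, 64, requires_grad = True)
   | v = torch.randn(1, 8, 128, 64, requires_grad = True)
   |
   | # True = block gradients through that attention entry
   | stop_grad_mask = torch.randint(0, 2, (128, 128)).bool()
   |
   | scale = q.shape[-1] ** -0.5
   | sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
   | attn = sim.softmax(dim = -1)
   |
   | # elementwise mix of detached and graddable attention entries
   | attn = torch.where(stop_grad_mask, attn.detach(), attn)
   |
   | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
   | ```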
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install self-reasoning-tokens-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | import torch
21 | from self_reasoning_tokens_pytorch import Transformer
22 |
23 | model = Transformer(
24 | dim = 512,
25 | depth = 4,
26 | num_tokens = 256,
27 | stop_grad_next_tokens_to_reason = True
28 | )
29 |
30 | x = torch.randint(0, 256, (1, 4))
31 |
32 | loss = model(
33 | x,
34 | num_reason_tokens = 4, # number of reasoning tokens per time step
35 | num_steps_future_can_use_reason = 16, # say you wish for reason tokens to be only attended to by tokens 16 time steps into the future
36 | return_loss = True
37 | )
38 |
39 | loss.backward()
40 |
41 | logits = model(x, num_reason_tokens = 4)
42 | ```
43 |
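   | `forward` also takes a `remove_reason_tokens_at_end` flag, which strips the reasoning-token positions from the returned logits and, when computing the loss, keeps only the main token positions.
   |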
44 | Or use the novel attention function directly, with the ability to pass separate stop-gradient masks for queries, keys, and values
45 |
46 | ```python
47 | import torch
48 | from self_reasoning_tokens_pytorch import stop_graddable_attn
49 |
50 | q = torch.randn(2, 8, 1024, 64)
51 | k = torch.randn(2, 8, 1024, 64)
52 | v = torch.randn(2, 8, 1024, 64)
53 |
54 | stop_grad_mask = torch.randint(0, 2, (8, 1024, 1024)).bool()
55 |
56 | out = stop_graddable_attn(
57 | q, k, v, causal = True,
58 | q_stop_grad_mask = stop_grad_mask,
59 | k_stop_grad_mask = stop_grad_mask,
60 | v_stop_grad_mask = stop_grad_mask
61 | )
62 |
63 | out.shape # (2, 8, 1024, 64)
64 | ```
65 |
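   | Judging from the backward pass, the stop-gradient masks may also be float tensors, in which case they scale the corresponding gradients instead of zeroing them out.
   |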
66 | The mask should look something like
67 |
68 | ![mask](./mask.png)
69 |
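   | Concretely, the boolean mask is built inside `Transformer.forward`; a simplified standalone sketch of that construction (with made-up sizes):
   |
   | ```python
   | import torch
   | from einops import rearrange, repeat
   |
   | seq = 6                              # number of main tokens
   | num_reason_tokens = 2                # reasoning tokens appended per time step
   | num_steps_future_can_use_reason = 2  # how far in the future a token must be to use a reason token
   |
   | tokens_per_step = num_reason_tokens + 1
   |
   | timesteps = repeat(torch.arange(seq), 'n -> (n r)', r = tokens_per_step)
   | positions = torch.arange(timesteps.shape[-1])
   | is_reason_key = ~(positions % tokens_per_step == 0)
   |
   | q_range = rearrange(timesteps, 'n -> n 1')
   | k_range = rearrange(timesteps, 'n -> 1 n')
   |
   | # False entries are either fully masked out, or stop-gradiented when
   | # stop_grad_next_tokens_to_reason = True
   | attn_mask = ~(
   |     is_reason_key &
   |     (q_range > k_range) &
   |     ((q_range - num_steps_future_can_use_reason) <= k_range)
   | )
   | ```
   |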
70 | ## Todo
71 |
72 | - [ ] deviating from blog post, also try optimizing only a subset of attention heads by tokens far into the future
73 |
74 | ## Citations
75 |
76 | ```bibtex
77 | @misc{Bonetto2024,
78 | author = {Felipe Bonetto},
79 | url = {https://reasoning-tokens.ghost.io/reasoning-tokens/}
80 | }
81 | ```
82 |
--------------------------------------------------------------------------------
/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/self-reasoning-tokens-pytorch/cfc93ec82d05ba462ab0022cda152d8f5b0abc29/mask.png
--------------------------------------------------------------------------------
/reasoning-tokens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/self-reasoning-tokens-pytorch/cfc93ec82d05ba462ab0022cda152d8f5b0abc29/reasoning-tokens.png
--------------------------------------------------------------------------------
/self_reasoning_tokens_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from self_reasoning_tokens_pytorch.self_reasoning_tokens import (
2 | Transformer,
3 | CausalAttention
4 | )
5 |
6 | from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import (
7 | stop_graddable_attn_,
8 | stop_graddable_attn
9 | )
10 |
--------------------------------------------------------------------------------
/self_reasoning_tokens_pytorch/attention_with_stop_graddable_qkv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.function import Function
3 |
4 | from einops import einsum, rearrange
5 |
6 | def exists(val):
7 | return val is not None
8 |
9 | # custom function
10 |
11 | class StopGraddableAttentionFunction(Function):
12 |
13 | @staticmethod
14 | @torch.no_grad()
15 | def forward(
16 | ctx,
17 | q,
18 | k,
19 | v,
20 | mask,
21 | attn_mask,
22 | causal: bool,
23 | q_stop_grad_mask,
24 | k_stop_grad_mask,
25 | v_stop_grad_mask,
26 | ):
27 | scale = q.shape[-1] ** -0.5
28 |
29 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
30 |
31 | max_neg_value = -torch.finfo(sim.dtype).max
32 |
33 | if exists(mask):
34 | mask = rearrange(mask, 'b j -> b 1 1 j')
35 | sim.masked_fill_(~mask, max_neg_value)
36 |
37 | if exists(attn_mask):
38 | sim.masked_fill_(~attn_mask, max_neg_value)
39 |
40 | if causal:
41 | i, j = sim.shape[-2:]
42 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
43 | sim = sim.masked_fill(causal_mask, max_neg_value)
44 |
45 | attn = sim.softmax(dim = -1)
46 |
47 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
48 |
49 | ctx.args = (
50 | causal,
51 | scale,
52 | mask,
53 | q_stop_grad_mask,
54 | k_stop_grad_mask,
55 | v_stop_grad_mask
56 | )
57 |
58 | ctx.save_for_backward(
59 | q, k, v,
60 | attn,
61 | out
62 | )
63 |
64 | return out
65 |
66 | @staticmethod
67 | @torch.no_grad()
68 | def backward(ctx, do):
69 |
70 | (
71 | causal,
72 | scale,
73 | mask,
74 | q_stop_grad_mask,
75 | k_stop_grad_mask,
76 | v_stop_grad_mask
77 | ) = ctx.args
78 |
79 | q, k, v, p, o = ctx.saved_tensors
80 |
81 | # stop grad masks are either bool, with True indicating stop grad, or float, in which case they scale the gradients
82 |
83 | if exists(q_stop_grad_mask) and q_stop_grad_mask.dtype == torch.bool:
84 | q_stop_grad_mask = (~q_stop_grad_mask).float()
85 |
86 | if exists(k_stop_grad_mask) and k_stop_grad_mask.dtype == torch.bool:
87 | k_stop_grad_mask = (~k_stop_grad_mask).float()
88 |
89 | if exists(v_stop_grad_mask) and v_stop_grad_mask.dtype == torch.bool:
90 | v_stop_grad_mask = (~v_stop_grad_mask).float()
91 |
92 | # softmax D
93 |
94 | D = (do * o).sum(dim = -1, keepdims = True)
95 |
96 | # stop grad for values
97 |
98 | # mask out of place so the saved attention matrix `p` is not mutated before computing ds below
99 | p_v = p
100 | if exists(v_stop_grad_mask):
101 | p_v = p_v * v_stop_grad_mask
102 |
103 | # dv
104 |
105 | dv = einsum(p_v, do, 'b h i j, b h i d -> b h j d')
106 |
107 | # prep for dq and dk
108 |
109 | dp = einsum(do, v, 'b h i d, b h j d -> b h i j')
110 | ds = p * scale * (dp - D)
111 |
112 | # handle stop grad masking for queries and keys
113 |
114 | # mask out of place - ds_q and ds_k must not be modified in place, since both alias ds
115 | ds_q = ds_k = ds
116 | if exists(q_stop_grad_mask):
117 | ds_q = ds * q_stop_grad_mask
118 |
119 | if exists(k_stop_grad_mask):
120 | ds_k = ds * k_stop_grad_mask
121 |
122 | # dq and dk
123 |
124 | dq = einsum(ds_q, k, 'b h i j, b h j d -> b h i d')
125 | dk = einsum(ds_k, q, 'b h i j, b h i d -> b h j d')
126 |
127 | return dq, dk, dv, None, None, None, None, None, None
128 |
129 | # convenience method with defaults
130 |
131 | stop_graddable_attn_ = StopGraddableAttentionFunction.apply
132 |
133 | def stop_graddable_attn(
134 | q, k, v,
135 | mask = None,
136 | attn_mask = None,
137 | causal = False,
138 | q_stop_grad_mask = None,
139 | k_stop_grad_mask = None,
140 | v_stop_grad_mask = None
141 | ):
142 | return stop_graddable_attn_(q, k, v, mask, attn_mask, causal, q_stop_grad_mask, k_stop_grad_mask, v_stop_grad_mask)
143 |
--------------------------------------------------------------------------------
/self_reasoning_tokens_pytorch/self_reasoning_tokens.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import Module, ModuleList
5 |
6 | from einops import einsum, rearrange, repeat, reduce
7 | from einops.layers.torch import Rearrange
8 |
9 | from x_transformers import (
10 | RMSNorm,
11 | FeedForward
12 | )
13 |
14 | from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import (
15 | stop_graddable_attn
16 | )
17 |
18 | # helper functions
19 |
20 | def exists(v):
21 | return v is not None
22 |
23 | def default(v, d):
24 | return v if exists(v) else d
25 |
26 | # attention
27 |
28 | class CausalAttention(Module):
29 | def __init__(
30 | self,
31 | dim,
32 | dim_head = 64,
33 | heads = 8
34 | ):
35 | super().__init__()
36 | self.scale = dim_head ** -0.5
37 | dim_inner = dim_head * heads
38 |
39 | self.to_qkv = nn.Sequential(
40 | RMSNorm(dim),
41 | nn.Linear(dim, dim_inner * 3, bias = False),
42 | Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
43 | )
44 |
45 | self.to_out = nn.Sequential(
46 | Rearrange('b h n d -> b n (h d)'),
47 | nn.Linear(dim_inner, dim, bias = False)
48 | )
49 |
50 | def forward(
51 | self,
52 | x,
53 | attn_mask = None,
54 | stop_grad_attn_mask = None
55 | ):
56 | seq, device = x.shape[-2], x.device
57 |
58 | q, k, v = self.to_qkv(x)
59 |
60 | if exists(stop_grad_attn_mask):
61 | if not isinstance(stop_grad_attn_mask, tuple):
62 | stop_grad_attn_mask = (None, stop_grad_attn_mask, stop_grad_attn_mask)
63 |
64 | assert len(stop_grad_attn_mask) == 3, 'stop_grad_attn_mask must be either a single stop grad mask (applied to keys and values) or a tuple of 3 Tensors for individual stop grads of queries, keys, values'
65 |
66 | q_stop_grad, k_stop_grad, v_stop_grad = stop_grad_attn_mask
67 |
68 | out = stop_graddable_attn(
69 | q, k, v, causal = True,
70 | attn_mask = attn_mask,
71 | q_stop_grad_mask = q_stop_grad,
72 | k_stop_grad_mask = k_stop_grad,
73 | v_stop_grad_mask = v_stop_grad
74 | )
75 |
76 | else:
77 | q = q * self.scale
78 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
79 |
80 | causal_mask = torch.ones((seq, seq), device = device, dtype = torch.bool).triu(1)
81 |
82 | mask_value = -torch.finfo(sim.dtype).max
83 | sim = sim.masked_fill(causal_mask, mask_value)
84 |
85 | if exists(attn_mask):
86 | sim = sim.masked_fill(~attn_mask, mask_value)
87 |
88 | attn = sim.softmax(dim = -1)
89 |
90 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
91 |
92 | # combine heads
93 |
94 | return self.to_out(out)
95 |
96 | # transformer
97 |
98 | class Transformer(Module):
99 | def __init__(
100 | self,
101 | *,
102 | dim,
103 | num_tokens,
104 | depth,
105 | max_seq_len = 2048,
106 | max_reason_seq_len = 4,
107 | dim_head = 64,
108 | heads = 8,
109 | ignore_index = -1,
110 | stop_grad_next_tokens_to_reason = False
111 | ):
112 | super().__init__()
113 | self.max_seq_len = max_seq_len
114 |
115 | # embed
116 |
117 | self.token_emb = nn.Embedding(num_tokens, dim)
118 | self.pos_emb = nn.Embedding(max_seq_len, dim)
119 |
120 | # reasoning tokens
121 |
122 | self.max_reason_seq_len = max_reason_seq_len
123 | self.reason_tokens = nn.Parameter(torch.randn(max_reason_seq_len, dim))
124 | nn.init.normal_(self.reason_tokens, std = 0.02)
125 |
126 | # transformer layers
127 |
128 | self.layers = ModuleList([])
129 | for _ in range(depth):
130 |
131 | attn = CausalAttention(
132 | dim = dim,
133 | dim_head = dim_head,
134 | heads = heads
135 | )
136 |
137 | ff = nn.Sequential(
138 | RMSNorm(dim),
139 | FeedForward(dim = dim)
140 | )
141 |
142 | self.layers.append(ModuleList([attn, ff]))
143 |
144 | self.norm = RMSNorm(dim)
145 | self.to_logits = nn.Linear(dim, num_tokens, bias = False)
146 |
147 | # loss related
148 |
149 | self.ignore_index = ignore_index
150 |
151 | # stop gradient settings
152 |
153 | self.stop_grad_next_tokens_to_reason = stop_grad_next_tokens_to_reason
154 |
155 | def forward(
156 | self,
157 | x,
158 | num_reason_tokens = 0,
159 | num_steps_future_can_use_reason = 2, # how many positions into the future until a reason token can be attended to
160 | remove_reason_tokens_at_end = False,
161 | return_loss = False
162 | ):
163 |
164 | if return_loss:
165 | x, labels = x[:, :-1], x[:, 1:]
166 |
167 | batch, seq, device = *x.shape, x.device
168 |
169 | assert seq <= self.max_seq_len
170 |
171 | x = self.token_emb(x)
172 |
173 | seq_arange = torch.arange(seq, device = device)
174 | pos = self.pos_emb(seq_arange)
175 |
176 | attn_kwargs = dict()
177 |
178 | # intersperse reasoning tokens if needed
179 |
180 | has_reason_tokens = num_reason_tokens > 0
181 |
182 | if has_reason_tokens:
183 | assert num_reason_tokens <= self.max_reason_seq_len
184 |
185 | x = rearrange(x, 'b n d -> b n 1 d')
186 |
187 | reason_tokens = self.reason_tokens[:num_reason_tokens]
188 | reason_tokens = repeat(reason_tokens, 'r d -> b n r d', b = batch, n = seq)
189 |
190 | x = torch.cat((x, reason_tokens), dim = -2)
191 | x = rearrange(x, 'b n r d -> b (n r) d')
192 |
193 | # handle absolute positions
194 | # applied axially to reasoning tokens and main token
195 |
196 | num_tokens_per_timestep = num_reason_tokens + 1
197 | pos = repeat(pos, 'n d -> (n r) d', r = num_tokens_per_timestep)
198 |
199 | # handle masking for reasoning tokens
200 | # each reason token can only be attended to by tokens (and their reasoning tokens) more than {num_steps_future_can_use_reason} time steps into the future
201 |
202 | seq_timesteps = repeat(seq_arange, 'n -> (n r)', r = num_tokens_per_timestep)
203 |
204 | seq_with_reason_range = torch.arange(seq_timesteps.shape[-1], device = device)
205 | is_reason_token_mask = ~(seq_with_reason_range % num_tokens_per_timestep == 0)
206 |
207 | q_range = rearrange(seq_timesteps, 'n -> n 1')
208 | k_range = rearrange(seq_timesteps, 'n -> 1 n')
209 |
210 | attn_mask = ~(
211 | is_reason_token_mask &
212 | (q_range > k_range) &
213 | ((q_range - num_steps_future_can_use_reason) <= k_range)
214 | )
215 |
216 | # whether to fully mask out or stop gradient on attention matrix
217 |
218 | if self.stop_grad_next_tokens_to_reason:
219 | attn_kwargs = dict(stop_grad_attn_mask = ~attn_mask)
220 | else:
221 | attn_kwargs = dict(attn_mask = attn_mask)
222 |
223 | # attention and feedforward, passing in reason tokens mask from above
224 |
225 | x = x + pos
226 |
227 | for attn, ff in self.layers:
228 | x = attn(x, **attn_kwargs) + x
229 | x = ff(x) + x
230 |
231 | embed = self.norm(x)
232 |
233 | logits = self.to_logits(embed)
234 |
235 | # whether to remove reason tokens at the very end
236 |
237 | if has_reason_tokens and remove_reason_tokens_at_end:
238 | logits = rearrange(logits, 'b (n r) c -> b n r c', r = num_tokens_per_timestep)
239 | logits = logits[..., 0, :]
240 |
241 | if not return_loss:
242 | return logits
243 |
244 | if has_reason_tokens and not remove_reason_tokens_at_end:
245 | labels = repeat(labels, 'b n -> b (n r)', r = num_tokens_per_timestep)
246 |
247 | loss = F.cross_entropy(
248 | rearrange(logits, 'b n c -> b c n'),
249 | labels,
250 | ignore_index = self.ignore_index
251 | )
252 |
253 | return loss
254 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'self-reasoning-tokens-pytorch',
5 | packages = find_packages(exclude = []),
6 | version = '0.0.4',
7 | license='MIT',
8 | description = 'Self Reasoning Tokens',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/self-reasoning-tokens-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'adaptive computation'
19 | ],
20 | install_requires=[
21 | 'einops>=0.8.0',
22 | 'x-transformers>=1.28.4',
23 | 'torch>=2.0',
24 | ],
25 | classifiers=[
26 | 'Development Status :: 4 - Beta',
27 | 'Intended Audience :: Developers',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
29 | 'License :: OSI Approved :: MIT License',
30 | 'Programming Language :: Python :: 3.7',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------