├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── mask.png
├── reasoning-tokens.png
├── self_reasoning_tokens_pytorch
│   ├── __init__.py
│   ├── attention_with_stop_graddable_qkv.py
│   └── self_reasoning_tokens.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 | 
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 | 
11 | name: Upload Python Package
12 | 
13 | on:
14 |   release:
15 |     types: [published]
16 | 
17 | jobs:
18 |   deploy:
19 | 
20 |     runs-on: ubuntu-latest
21 | 
22 |     steps:
23 |     - uses: actions/checkout@v2
24 |     - name: Set up Python
25 |       uses: actions/setup-python@v2
26 |       with:
27 |         python-version: '3.x'
28 |     - name: Install dependencies
29 |       run: |
30 |         python -m pip install --upgrade pip
31 |         pip install build
32 |     - name: Build package
33 |       run: python -m build
34 |     - name: Publish package
35 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 |       with:
37 |         user: __token__
38 |         password: ${{ secrets.PYPI_API_TOKEN }}
39 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | #  Usually these files are written by a python script from a template
31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | #   For a library or package, you might want to ignore these files since the code is
87 | #   intended to run in multiple environments; otherwise, check them in:
88 | #   .python-version
89 | 
90 | # pipenv
91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Self Reasoning Tokens - Pytorch (wip)
4 | 
5 | Exploration into the Self Reasoning Tokens proposed by Felipe Bonetto. The blog post seems a bit light on details, but the idea of stopping gradients from the next token(s) is an interesting one.
6 | 
7 | My initial thought was to apply a stop-gradient mask to the attention matrix, but I then realized that the values of the "reasoning" tokens could not be stop-gradiented correctly without running into memory issues.
8 | 
9 | While walking the dog and meditating on this, I came to the realization that one can create independent stop-gradient masks for queries, keys, and values, either in flash attention or in a custom attention backwards pass, and there may be a whole array of possibilities there. If any experiments from this exploration come back positive, I will build out a concrete implementation.
10 | 
11 | ## Install
12 | 
13 | ```bash
14 | $ pip install self-reasoning-tokens-pytorch
15 | ```
16 | 
17 | ## Usage
18 | 
19 | ```python
20 | import torch
21 | from self_reasoning_tokens_pytorch import Transformer
22 | 
23 | model = Transformer(
24 |     dim = 512,
25 |     depth = 4,
26 |     num_tokens = 256,
27 |     stop_grad_next_tokens_to_reason = True
28 | )
29 | 
30 | x = torch.randint(0, 256, (1, 4))
31 | 
32 | loss = model(
33 |     x,
34 |     num_reason_tokens = 4,                 # number of reasoning tokens per time step
35 |     num_steps_future_can_use_reason = 16,  # say you wish for reason tokens to be attended to only by tokens 16 time steps into the future
36 |     return_loss = True
37 | )
38 | 
39 | loss.backward()
40 | 
41 | logits = model(x, num_reason_tokens = 4)
42 | ```
43 | 
44 | Or use the novel attention function directly, with the ability to pass separate stop-gradient masks for queries, keys, and values
45 | 
46 | ```python
47 | import torch
48 | from self_reasoning_tokens_pytorch import stop_graddable_attn
49 | 
50 | q = torch.randn(2, 8, 1024, 64)
51 | k = torch.randn(2, 8, 1024, 64)
52 | v = torch.randn(2, 8, 1024, 64)
53 | 
54 | stop_grad_mask = torch.randint(0, 2, (8, 1024, 1024)).bool()
55 | 
56 | out = stop_graddable_attn(
57 |     q, k, v, causal = True,
58 |     q_stop_grad_mask = stop_grad_mask,
59 |     k_stop_grad_mask = stop_grad_mask,
60 |     v_stop_grad_mask = stop_grad_mask
61 | )
62 | 
63 | out.shape # (2, 8, 1024, 64)
64 | ```
65 | 
66 | The mask should look something like the one pictured in `mask.png`
67 | 
68 | 
69 | 
70 | ## Todo
71 | 
72 | - [ ] deviating from the blog post, also try optimizing only a subset of attention heads by tokens far into the future
73 | 
74 | ## Citations
75 | 
76 | ```bibtex
77 | @misc{Bonetto2024,
78 |     author  = {Felipe Bonetto},
79 |     url     = {https://reasoning-tokens.ghost.io/reasoning-tokens/}
80 | }
81 | ```
82 | 
--------------------------------------------------------------------------------
/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/self-reasoning-tokens-pytorch/cfc93ec82d05ba462ab0022cda152d8f5b0abc29/mask.png
--------------------------------------------------------------------------------
/reasoning-tokens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/self-reasoning-tokens-pytorch/cfc93ec82d05ba462ab0022cda152d8f5b0abc29/reasoning-tokens.png
--------------------------------------------------------------------------------
/self_reasoning_tokens_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from self_reasoning_tokens_pytorch.self_reasoning_tokens import (
2 |     Transformer,
3 |     CausalAttention
4 | )
5 | 
6 | from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import (
7 |     stop_graddable_attn_,
8 |     stop_graddable_attn
9 | )
10 | 
--------------------------------------------------------------------------------
/self_reasoning_tokens_pytorch/attention_with_stop_graddable_qkv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.function import Function
3 | 
4 | from einops import einsum, rearrange
5 | 
6 | def exists(val):
7 |     return val is not None
8 | 
9 | # custom function
10 | 
11 | class StopGraddableAttentionFunction(Function):
12 | 
13 |     @staticmethod
14 |     @torch.no_grad()
15 |     def forward(
16 |         ctx,
17 |         q,
18 |         k,
19 |         v,
20 |         mask,
21 |         attn_mask,
22 |         causal: bool,
23 |         q_stop_grad_mask,
24 |         k_stop_grad_mask,
25 |         v_stop_grad_mask,
26 |     ):
27 |         scale = q.shape[-1] ** -0.5
28 | 
29 |         sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
30 | 
31 |         max_neg_value = -torch.finfo(sim.dtype).max
32 | 
33 |         if exists(mask):
34 |             mask = rearrange(mask, 'b j -> b 1 1 j')
35 |             sim.masked_fill_(~mask, max_neg_value)
36 | 
37 |         if exists(attn_mask):
38 |             sim.masked_fill_(~attn_mask, max_neg_value)
39 | 
40 |         if causal:
41 |             i, j = sim.shape[-2:]
42 |             causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
43 |             sim = sim.masked_fill(causal_mask, max_neg_value)
44 | 
45 |         attn = sim.softmax(dim = -1)
46 | 
47 |         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
48 | 
49 |         ctx.args = (
50 |             causal,
51 |             scale,
52 |             mask,
53 |             q_stop_grad_mask,
54 |             k_stop_grad_mask,
55 |             v_stop_grad_mask
56 |         )
57 | 
58 |         ctx.save_for_backward(
59 |             q, k, v,
60 |             attn,
61 |             out
62 |         )
63 | 
64 |         return out
65 | 
66 |     @staticmethod
67 |     @torch.no_grad()
68 |     def backward(ctx, do):
69 | 
70 |         (
71 |             causal,
72 |             scale,
73 |             mask,
74 |             q_stop_grad_mask,
75 |             k_stop_grad_mask,
76 |             v_stop_grad_mask
77 |         ) = ctx.args
78 | 
79 |         q, k, v, p, o = ctx.saved_tensors
80 | 
81 |         # stop grad masks are either of type bool, with True indicating stop grad, or of type float, in which case they scale the gradients
82 | 
83 |         if exists(q_stop_grad_mask) and q_stop_grad_mask.dtype == torch.bool:
84 |             q_stop_grad_mask = (~q_stop_grad_mask).float()
85 | 
86 |         if exists(k_stop_grad_mask) and k_stop_grad_mask.dtype == torch.bool:
87 |             k_stop_grad_mask = (~k_stop_grad_mask).float()
88 | 
89 |         if exists(v_stop_grad_mask) and v_stop_grad_mask.dtype == torch.bool:
90 |             v_stop_grad_mask = (~v_stop_grad_mask).float()
91 | 
92 |         # softmax D
93 | 
94 |         D = (do * o).sum(dim = -1, keepdims = True)
95 | 
96 |         # stop grad for values - multiply out of place so the unmasked attention matrix is kept for dq and dk
97 | 
98 |         p_v = p
99 | 
100 |         if exists(v_stop_grad_mask):
101 |             p_v = p_v * v_stop_grad_mask
102 | 
103 |         # dv
104 | 
105 |         dv = einsum(p_v, do, 'b h i j, b h i d -> b h j d')
106 | 
107 |         # prep for dq and dk
108 | 
109 |         dp = einsum(do, v, 'b h i d, b h j d -> b h i j')
110 |         ds = p * scale * (dp - D)
111 | 
112 |         # handle stop grad masking for queries and keys, again out of place so the two masks stay independent
113 | 
114 |         ds_q = ds_k = ds
115 | 
116 |         if exists(q_stop_grad_mask):
117 |             ds_q = ds_q * q_stop_grad_mask
118 | 
119 |         if exists(k_stop_grad_mask):
120 |             ds_k = ds_k * k_stop_grad_mask
121 | 
122 |         # dq and dk
123 | 
124 |         dq = einsum(ds_q, k, 'b h i j, b h j d -> b h i d')
125 |         dk = einsum(ds_k, q, 'b h i j, b h i d -> b h j d')
126 | 
127 |         return dq, dk, dv, None, None, None, None, None, None
128 | 
129 | # convenience method with defaults
130 | 
131 | stop_graddable_attn_ = StopGraddableAttentionFunction.apply
132 | 
133 | def stop_graddable_attn(
134 |     q, k, v,
135 |     mask = None,
136 |     attn_mask = None,
137 |     causal = False,
138 |     q_stop_grad_mask = None,
139 |     k_stop_grad_mask = None,
140 |     v_stop_grad_mask = None
141 | ):
142 |     return stop_graddable_attn_(q, k, v, mask, attn_mask, causal, q_stop_grad_mask, k_stop_grad_mask, v_stop_grad_mask)
143 | 
--------------------------------------------------------------------------------
/self_reasoning_tokens_pytorch/self_reasoning_tokens.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import Module, ModuleList
5 | 
6 | from einops import einsum, rearrange, repeat, reduce
7 | from einops.layers.torch import Rearrange
8 | 
9 | from x_transformers import (
10 |     RMSNorm,
11 |     FeedForward
12 | )
13 | 
14 | from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import (
15 |     stop_graddable_attn
16 | )
17 | 
18 | # helper functions
19 | 
20 | def exists(v):
21 |     return v is not None
22 | 
23 | def default(v, d):
24 |     return v if exists(v) else d
25 | 
26 | # attention
27 | 
28 | class CausalAttention(Module):
29 |     def __init__(
30 |         self,
31 |         dim,
32 |         dim_head = 64,
33 |         heads = 8
34 |     ):
35 |         super().__init__()
36 |         self.scale = dim_head ** -0.5
37 |         dim_inner = dim_head * heads
38 | 
39 |         self.to_qkv = nn.Sequential(
40 |             RMSNorm(dim),
41 |             nn.Linear(dim, dim_inner * 3, bias = False),
42 |             Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
43 |         )
44 | 
45 |         self.to_out = nn.Sequential(
46 |             Rearrange('b h n d -> b n (h d)'),
47 |             nn.Linear(dim_inner, dim, bias = False)
48 |         )
49 | 
50 |     def forward(
51 |         self,
52 |         x,
53 |         attn_mask = None,
54 |         stop_grad_attn_mask = None
55 |     ):
56 |         seq, device = x.shape[-2], x.device
57 | 
58 |         q, k, v = self.to_qkv(x)
59 | 
60 |         if exists(stop_grad_attn_mask):
61 |             if not isinstance(stop_grad_attn_mask, tuple):
62 |                 stop_grad_attn_mask = (None, stop_grad_attn_mask, stop_grad_attn_mask)
63 | 
64 |             assert len(stop_grad_attn_mask) == 3, 'stop_grad_attn_mask must be either a stop grad mask (implicit for keys / values) or a tuple of 3 Tensors for individual stop grads of queries, keys, values'
65 | 
66 |             q_stop_grad, k_stop_grad, v_stop_grad = stop_grad_attn_mask
67 | 
68 |             out = stop_graddable_attn(
69 |                 q, k, v, causal = True,
70 |                 attn_mask = attn_mask,
71 |                 q_stop_grad_mask = q_stop_grad,
72 |                 k_stop_grad_mask = k_stop_grad,
73 |                 v_stop_grad_mask = v_stop_grad
74 |             )
75 | 
76 |         else:
77 |             q = q * self.scale
78 |             sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
79 | 
80 |             causal_mask = torch.ones((seq, seq), device = device, dtype = torch.bool).triu(1)
81 | 
82 |             mask_value = -torch.finfo(sim.dtype).max
83 |             sim = sim.masked_fill(causal_mask, mask_value)
84 | 
85 |             if exists(attn_mask):
86 |                 sim = sim.masked_fill(~attn_mask, mask_value)
87 | 
88 |             attn = sim.softmax(dim = -1)
89 | 
90 |             out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
91 | 
92 |         # combine heads
93 | 
94 |         return self.to_out(out)
95 | 
96 | # transformer
97 | 
98 | class Transformer(Module):
99 |     def __init__(
100 |         self,
101 |         *,
102 |         dim,
103 |         num_tokens,
104 |         depth,
105 |         max_seq_len = 2048,
106 |         max_reason_seq_len = 4,
107 |         dim_head = 64,
108 |         heads = 8,
109 |         ignore_index = -1,
110 |         stop_grad_next_tokens_to_reason = False
111 |     ):
112 |         super().__init__()
113 |         self.max_seq_len = max_seq_len
114 | 
115 |         # embed
116 | 
117 |         self.token_emb = nn.Embedding(num_tokens,
dim) 118 | self.pos_emb = nn.Embedding(max_seq_len, dim) 119 | 120 | # reasoning tokens 121 | 122 | self.max_reason_seq_len = max_reason_seq_len 123 | self.reason_tokens = nn.Parameter(torch.randn(max_reason_seq_len, dim)) 124 | nn.init.normal_(self.reason_tokens, std = 0.02) 125 | 126 | # transformer layers 127 | 128 | self.layers = ModuleList([]) 129 | for _ in range(depth): 130 | 131 | attn = CausalAttention( 132 | dim = dim, 133 | dim_head = dim_head, 134 | heads = heads 135 | ) 136 | 137 | ff = nn.Sequential( 138 | RMSNorm(dim), 139 | FeedForward(dim = dim) 140 | ) 141 | 142 | self.layers.append(ModuleList([attn, ff])) 143 | 144 | self.norm = RMSNorm(dim) 145 | self.to_logits = nn.Linear(dim, num_tokens, bias = False) 146 | 147 | # loss related 148 | 149 | self.ignore_index = ignore_index 150 | 151 | # stop gradient settings 152 | 153 | self.stop_grad_next_tokens_to_reason = stop_grad_next_tokens_to_reason 154 | 155 | def forward( 156 | self, 157 | x, 158 | num_reason_tokens = 0, 159 | num_steps_future_can_use_reason = 2, # how many positions into the future until a reason token can be attended to 160 | remove_reason_tokens_at_end = False, 161 | return_loss = False 162 | ): 163 | 164 | if return_loss: 165 | x, labels = x[:, :-1], x[:, 1:] 166 | 167 | batch, seq, device = *x.shape, x.device 168 | 169 | assert seq <= self.max_seq_len 170 | 171 | x = self.token_emb(x) 172 | 173 | seq_arange = torch.arange(seq, device = device) 174 | pos = self.pos_emb(seq_arange) 175 | 176 | attn_kwargs = dict() 177 | 178 | # intersperse reasoning tokens if needed 179 | 180 | has_reason_tokens = num_reason_tokens > 0 181 | 182 | if has_reason_tokens: 183 | assert num_reason_tokens <= self.max_reason_seq_len 184 | 185 | x = rearrange(x, 'b n d -> b n 1 d') 186 | 187 | reason_tokens = self.reason_tokens[:num_reason_tokens] 188 | reason_tokens = repeat(reason_tokens, 'r d -> b n r d', b = batch, n = seq) 189 | 190 | x = torch.cat((x, reason_tokens), dim = -2) 191 | x = rearrange(x, 'b n r d -> b (n r) d') 192 | 193 | # handle absolute positions 194 | # applied axially to reasoning tokens and main token 195 | 196 | num_tokens_per_timestep = num_reason_tokens + 1 197 | pos = repeat(pos, 'n d -> (n r) d', r = num_tokens_per_timestep) 198 | 199 | # handle masking for reasoning tokens 200 | # each reason token can only be attended to by tokens (+ future reasoning tokens) that are {num_steps_future_can_use_reason} 201 | 202 | seq_timesteps = repeat(seq_arange, 'n -> (n r)', r = num_tokens_per_timestep) 203 | 204 | seq_with_reason_range = torch.arange(seq_timesteps.shape[-1], device = device) 205 | is_reason_token_mask = ~(seq_with_reason_range % num_tokens_per_timestep == 0) 206 | 207 | q_range = rearrange(seq_timesteps, 'n -> n 1') 208 | k_range = rearrange(seq_timesteps, 'n -> 1 n') 209 | 210 | attn_mask = ~( 211 | is_reason_token_mask & 212 | (q_range > k_range) & 213 | ((q_range - num_steps_future_can_use_reason) <= k_range) 214 | ) 215 | 216 | # whether to fully mask out or stop gradient on attention matrix 217 | 218 | if self.stop_grad_next_tokens_to_reason: 219 | attn_kwargs = dict(stop_grad_attn_mask = ~attn_mask) 220 | else: 221 | attn_kwargs = dict(attn_mask = attn_mask) 222 | 223 | # attention and feedforward, passing in reason tokens mask from above 224 | 225 | x = x + pos 226 | 227 | for attn, ff in self.layers: 228 | x = attn(x, **attn_kwargs) + x 229 | x = ff(x) + x 230 | 231 | embed = self.norm(x) 232 | 233 | logits = self.to_logits(embed) 234 | 235 | # whether to remove reason tokens at the very 
end 236 | 237 | if has_reason_tokens and remove_reason_tokens_at_end: 238 | logits = rearrange(logits, 'b (n r) c -> b n r c', r = num_tokens_per_timestep) 239 | logits = logits[..., 0, :] 240 | 241 | if not return_loss: 242 | return logits 243 | 244 | if has_reason_tokens and not remove_reason_tokens_at_end: 245 | labels = repeat(labels, 'b n -> b (n r)', r = num_tokens_per_timestep) 246 | 247 | loss = F.cross_entropy( 248 | rearrange(logits, 'b n c -> b c n'), 249 | labels, 250 | ignore_index = self.ignore_index 251 | ) 252 | 253 | return loss 254 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'self-reasoning-tokens-pytorch', 5 | packages = find_packages(exclude = []), 6 | version = '0.0.4', 7 | license='MIT', 8 | description = 'Self Reasoning Tokens', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/self-reasoning-tokens-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'adaptive computation' 19 | ], 20 | install_requires=[ 21 | 'einops>=0.8.0', 22 | 'x-transformers>=1.28.4', 23 | 'torch>=2.0', 24 | ], 25 | classifiers=[ 26 | 'Development Status :: 4 - Beta', 27 | 'Intended Audience :: Developers', 28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Programming Language :: Python :: 3.7', 31 | ], 32 | ) 33 | --------------------------------------------------------------------------------
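
As a quick sanity check of the custom backward in `attention_with_stop_graddable_qkv.py`, the short script below may help. It is not part of the repository; it assumes the package has been pip installed as shown in the README, and it uses the backward's own convention that a boolean stop-grad mask with `True` means "stop the gradient". It verifies that a value stop-grad mask zeroes the gradients flowing into the masked value positions while leaving the rest untouched.

```python
import torch
from self_reasoning_tokens_pytorch import stop_graddable_attn

# small dimensions for a quick check: batch 1, 2 heads, 8 tokens, head dim 16

q = torch.randn(1, 2, 8, 16, requires_grad = True)
k = torch.randn(1, 2, 8, 16, requires_grad = True)
v = torch.randn(1, 2, 8, 16, requires_grad = True)

# stop gradients flowing into the values at the last 4 key positions
# the (i, j) mask broadcasts over batch and heads; True means "stop the gradient"

v_stop_grad_mask = torch.zeros(8, 8, dtype = torch.bool)
v_stop_grad_mask[:, 4:] = True

out = stop_graddable_attn(q, k, v, causal = True, v_stop_grad_mask = v_stop_grad_mask)
out.sum().backward()

print(v.grad[:, :, :4].abs().sum())  # non-zero, gradients still reach the first 4 value positions
print(v.grad[:, :, 4:].abs().sum())  # zero, gradients blocked by the stop-grad mask
```

The same check can be repeated with `q_stop_grad_mask` and `k_stop_grad_mask`; leaving them as `None` means the queries and keys receive their usual gradients.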