├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── PEER_pytorch ├── ChunkedPEER.py ├── PEER.py ├── PEERLora.py ├── PK.py ├── PKAttention.py └── __init__.py ├── README.md ├── peer.png ├── peer2.png └── pyproject.toml /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PEER_pytorch/ChunkedPEER.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from functools import partial 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | from PEER_pytorch.PEER import PEER 8 | from PEER_pytorch.PEERLora import PEERLora 9 | 10 | class ChunkedPEER(Module): 11 | def __init__( 12 | self, 13 | peer: PEER | PEERLora, 14 | seq_chunk_size: int = 128 15 | ): 16 | super().__init__() 17 | self.peer = peer 18 | self.seq_chunk_size = seq_chunk_size 19 | 20 | def forward( 21 | self, 22 | x 23 | ): 24 | peer = self.peer 25 | 26 | if self.training and x.requires_grad: 27 | peer = partial(checkpoint, peer) 28 | 29 | out = [] 30 | for chunk in x.split(self.seq_chunk_size, dim = 1): 31 | chunk_out = peer(chunk) 32 | out.append(chunk_out) 33 | 34 | return torch.cat(out, dim = 1) 35 | 36 | # quick test 37 | 38 | if __name__ == '__main__': 39 | peer = PEER(dim = 512, heads = 8).cuda() 40 | 41 | peer = ChunkedPEER(peer) 42 | 43 | x = torch.randn(1, 1024, 512).cuda().requires_grad_() 44 | 45 | out = peer(x) + x 46 | 47 | out.sum().backward() 48 | -------------------------------------------------------------------------------- /PEER_pytorch/PEER.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Module, ModuleList 7 | 8 | import einx 9 | from einops import einsum 10 | from einops.layers.torch import Rearrange 11 | 12 | # helper functions 13 | 14 | def exists(v): 15 | return v is not None 16 | 17 | def default(v, d): 18 | return v if exists(v) else d 19 | 20 | # rmsnorm 21 | 22 | class RMSNorm(Module): 23 | def __init__(self, dim): 24 | super().__init__() 25 | self.scale = dim ** 0.5 26 | self.gamma = 
nn.Parameter(torch.zeros(dim)) 27 | 28 | def forward(self, x): 29 | return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1) 30 | 31 | # main class 32 | 33 | class PEER(Module): 34 | """ 35 | following Algorithm 1 in the paper 36 | """ 37 | 38 | def __init__( 39 | self, 40 | dim, 41 | *, 42 | heads = 8, # tested up to 32 - (hk = heads * num_experts_per_head (16)) - for non-competing scores, increase number of heads to desired value for the inner dimension of the hypernetwork mlp 43 | num_experts = 1_000_000, # he chose 1 million 44 | num_experts_per_head = 16, # he settled on 16, but was 32 in PKM paper 45 | activation = nn.GELU, 46 | dim_key = None, 47 | product_key_topk = None, 48 | separate_embed_per_head = False, # @smerky notes that heads may retrieve same redundant neurons. this setting would allow for separate embeds per head and prevent that 49 | pre_rmsnorm = False, 50 | non_competing_scores = True, 51 | dropout = 0. 52 | ): 53 | """ 54 | einops notation 55 | b - batch 56 | n - sequence 57 | d - dimension 58 | h - heads 59 | p - 2 for product key 60 | k - number of keys 61 | """ 62 | 63 | super().__init__() 64 | 65 | self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity() 66 | 67 | # whether to do separate embedding per head 68 | 69 | num_expert_sets = 1 if not separate_embed_per_head else heads 70 | 71 | self.heads = heads 72 | self.separate_embed_per_head = separate_embed_per_head 73 | self.num_experts = num_experts 74 | 75 | # experts that will form the mlp project in / out weights 76 | 77 | self.weight_down_embed = nn.Embedding(num_experts * num_expert_sets, dim) 78 | self.weight_up_embed = nn.Embedding(num_experts * num_expert_sets, dim) 79 | 80 | # activation function, defaults to gelu 81 | 82 | self.activation = activation() 83 | 84 | # queries and keys for product-key 85 | 86 | assert sqrt(num_experts).is_integer(), '`num_experts` needs to be a square' 87 | assert (dim % 2) == 0, 'feature dimension should be divisible by 2' 88 | 89 | 
dim_key = default(dim_key, dim // 2) 90 | self.num_keys = int(sqrt(num_experts)) 91 | 92 | self.to_queries = nn.Sequential( 93 | nn.Linear(dim, dim_key * heads * 2, bias = False), 94 | Rearrange('b n (p h d) -> p b n h d', p = 2, h = heads) 95 | ) 96 | 97 | self.product_key_topk = default(product_key_topk, num_experts_per_head) 98 | self.num_experts_per_head_topk = num_experts_per_head if not non_competing_scores else 1 99 | 100 | self.keys = nn.Parameter(torch.zeros(heads, self.num_keys, 2, dim_key)) 101 | nn.init.normal_(self.keys, std = 0.02) 102 | 103 | # dropout 104 | 105 | self.dropout = nn.Dropout(dropout) 106 | 107 | # whether to use softmax on scores 108 | 109 | # Csordas et al claims non-competing activation helps in PKM setting 110 | # https://arxiv.org/pdf/2310.10837 - Table 2 in Section 6.2 111 | 112 | self.score_activation = nn.Softmax(dim = -1) if not non_competing_scores else nn.ReLU() 113 | 114 | def forward( 115 | self, 116 | x 117 | ): 118 | 119 | x = self.norm(x) 120 | 121 | # queries 122 | 123 | queries = self.to_queries(x) 124 | 125 | # first get similarity with keys 126 | 127 | sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k') 128 | 129 | # product key logic 130 | 131 | (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.product_key_topk, dim = -1) 132 | 133 | all_scores = einx.add('... i, ... j -> ... (i j)', scores_x, scores_y) 134 | all_indices = einx.add('... i, ... j -> ... 
(i j)', indices_x * self.num_keys, indices_y) 135 | 136 | scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim = -1) 137 | 138 | indices = all_indices.gather(-1, pk_indices) 139 | 140 | # if separate embeds per head, add appropriate offsets per head 141 | 142 | if self.separate_embed_per_head: 143 | head_expert_offsets = torch.arange(self.heads, device = x.device) * self.num_experts 144 | indices = einx.add('b n h k, h -> b n h k', indices, head_expert_offsets) 145 | 146 | # build the weight matrices for projecting in and out 147 | # basically the experts are the gathered parameters for an MLP 148 | 149 | weights_down = self.weight_down_embed(indices) 150 | weights_up = self.weight_up_embed(indices) 151 | 152 | # below is basically Algorithm 1 in paper 153 | 154 | x = einsum(x, weights_down, 'b n d, b n h k d -> b n h k') 155 | 156 | x = self.activation(x) 157 | x = self.dropout(x) 158 | 159 | x = x * self.score_activation(scores) 160 | 161 | x = einsum(x, weights_up, 'b n h k, b n h k d -> b n d') 162 | 163 | return x 164 | -------------------------------------------------------------------------------- /PEER_pytorch/PEERLora.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Module, ModuleList 7 | 8 | import einx 9 | from einops import einsum 10 | from einops.layers.torch import Rearrange 11 | 12 | # helper functions 13 | 14 | def exists(v): 15 | return v is not None 16 | 17 | def default(v, d): 18 | return v if exists(v) else d 19 | 20 | # rmsnorm 21 | 22 | class RMSNorm(Module): 23 | def __init__(self, dim): 24 | super().__init__() 25 | self.scale = dim ** 0.5 26 | self.gamma = nn.Parameter(torch.zeros(dim)) 27 | 28 | def forward(self, x): 29 | return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1) 30 | 31 | # main class 32 | 33 | class PEERLora(Module): 34 | """ 35 | 
Same as PEER, except it retrieves LORA weights and adds them to a usual feedforward weight1 and weight2 matrices 36 | """ 37 | 38 | def __init__( 39 | self, 40 | dim, 41 | *, 42 | expansion_factor = 2., 43 | num_experts = 1_000_000, # 1 million experts 44 | heads = 4, # the lora k dimension is kept at 16 (heads [4] * num_experts_per_head [4]) 45 | num_experts_per_head = 4, 46 | activation = nn.GELU, 47 | dim_key = None, 48 | product_key_topk = None, 49 | pre_rmsnorm = False, 50 | non_competing_scores = True, 51 | dropout = 0. 52 | ): 53 | """ 54 | einops notation 55 | b - batch 56 | n - sequence 57 | d - dimension 58 | h - heads 59 | p - 2 for product key 60 | k - number of keys 61 | """ 62 | 63 | super().__init__() 64 | dim_inner = int(dim * expansion_factor) 65 | 66 | self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity() 67 | 68 | # heads and num experts 69 | 70 | self.heads = heads 71 | self.num_experts_per_head = num_experts_per_head 72 | 73 | self.num_experts = num_experts 74 | 75 | # usual feedforward weights without bias 76 | 77 | self.proj_in = nn.Linear(dim, dim_inner, bias = False) 78 | self.proj_out = nn.Linear(dim_inner, dim, bias = False) 79 | 80 | # experts that will form the mlp project in / out weights 81 | 82 | self.proj_in_lora_a = nn.Embedding(num_experts, dim) 83 | self.proj_in_lora_b = nn.Embedding(num_experts, dim_inner) 84 | 85 | self.proj_out_lora_a = nn.Embedding(num_experts, dim_inner) 86 | self.proj_out_lora_b = nn.Embedding(num_experts, dim) 87 | 88 | # activation function, defaults to gelu 89 | 90 | self.activation = activation() 91 | 92 | # queries and keys for product-key 93 | 94 | assert sqrt(num_experts).is_integer(), '`num_experts` needs to be a square' 95 | assert (dim % 2) == 0, 'feature dimension should be divisible by 2' 96 | 97 | dim_key = default(dim_key, dim // 2) 98 | self.num_keys = int(sqrt(num_experts)) 99 | 100 | self.to_queries = nn.Sequential( 101 | nn.Linear(dim, dim_key * heads * 2, bias = False), 102 | 
Rearrange('b n (p h d) -> p b n h d', p = 2, h = heads) 103 | ) 104 | 105 | self.product_key_topk = default(product_key_topk, num_experts_per_head) 106 | self.num_experts_per_head_topk = num_experts_per_head if not non_competing_scores else 1 107 | 108 | self.keys = nn.Parameter(torch.zeros(heads, self.num_keys, 2, dim_key)) 109 | nn.init.normal_(self.keys, std = 0.02) 110 | 111 | # dropout 112 | 113 | self.dropout = nn.Dropout(dropout) 114 | 115 | # whether to use softmax on scores 116 | 117 | # Csordas et al claims non-competing activation helps in PKM setting 118 | # https://arxiv.org/pdf/2310.10837 - Table 2 in Section 6.2 119 | 120 | self.score_activation = nn.Softmax(dim = -1) if not non_competing_scores else nn.ReLU() 121 | 122 | # init 123 | 124 | nn.init.normal_(self.proj_in_lora_a.weight, std = 0.02) 125 | nn.init.normal_(self.proj_out_lora_b.weight, std = 0.02) 126 | nn.init.normal_(self.proj_in_lora_b.weight, std = 0.02) 127 | nn.init.normal_(self.proj_out_lora_a.weight, std = 0.02) 128 | 129 | @property 130 | def lora_k(self): 131 | return self.heads * self.num_experts_per_head 132 | 133 | def forward( 134 | self, 135 | x 136 | ): 137 | 138 | x = self.norm(x) 139 | 140 | # queries 141 | 142 | queries = self.to_queries(x) 143 | 144 | # first get similarity with keys 145 | 146 | sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k') 147 | 148 | # product key logic 149 | 150 | (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.product_key_topk, dim = -1) 151 | 152 | all_scores = einx.add('... i, ... j -> ... (i j)', scores_x, scores_y) 153 | all_indices = einx.add('... i, ... j -> ... 
(i j)', indices_x * self.num_keys, indices_y) 154 | 155 | scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim = -1) 156 | 157 | indices = all_indices.gather(-1, pk_indices) 158 | 159 | # build the loras for projecting in and out weights of a feedforward 160 | 161 | proj_in_lora_a = self.proj_in_lora_a(indices) 162 | proj_in_lora_b = self.proj_in_lora_b(indices) 163 | 164 | proj_out_lora_a = self.proj_out_lora_a(indices) 165 | proj_out_lora_b = self.proj_out_lora_b(indices) 166 | 167 | # feedforward, but with expert loras chosen by pk 168 | 169 | # project in 170 | 171 | hidden = self.proj_in(x) 172 | 173 | lora_in_hidden = einsum(x, proj_in_lora_a, 'b n d, b n h k d -> b n h k') 174 | lora_in_hidden = lora_in_hidden * self.score_activation(scores) 175 | lora_in_hidden = einsum(lora_in_hidden, proj_in_lora_b, 'b n h k, b n h k d -> b n d') 176 | 177 | hidden = hidden + lora_in_hidden 178 | 179 | # gelu and dropout 180 | 181 | hidden = self.activation(hidden) 182 | hidden = self.dropout(hidden) 183 | 184 | # project out 185 | 186 | out = self.proj_out(hidden) 187 | 188 | lora_out_hidden = einsum(hidden, proj_out_lora_a, 'b n d, b n h k d -> b n h k') 189 | lora_out_hidden = lora_out_hidden * self.score_activation(scores) 190 | lora_out_hidden = einsum(lora_out_hidden, proj_out_lora_b, 'b n h k, b n h k d -> b n d') 191 | 192 | out = out + lora_out_hidden 193 | 194 | return out 195 | -------------------------------------------------------------------------------- /PEER_pytorch/PK.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module 4 | 5 | import einx 6 | from einops.layers.torch import Rearrange 7 | from einops import einsum 8 | 9 | # helper functions 10 | 11 | def exists(v): 12 | return v is not None 13 | 14 | def default(v, d): 15 | return v if exists(v) else d 16 | 17 | # main class 18 | 19 | class PK(Module): 20 | def __init__( 21 | self, 
22 | dim, 23 | *, 24 | heads = 8, 25 | dim_key = None, 26 | num_keys = 1_000, 27 | product_keys = 2, 28 | product_key_topk = None, 29 | final_topk = 16, 30 | num_experts_per_head = 16 31 | ): 32 | """ 33 | einops notation 34 | b - batch 35 | n - sequence 36 | d - dimension 37 | h - heads 38 | p - product keys 39 | k - number of keys 40 | """ 41 | 42 | super().__init__() 43 | assert (dim % 2) == 0 44 | dim_key = default(dim_key, dim // 2) 45 | 46 | self.to_queries = nn.Sequential( 47 | nn.Linear(dim, dim_key * product_keys * heads, bias = False), 48 | Rearrange('b n (p h d) -> p b n h d', h = heads, p = product_keys) 49 | ) 50 | 51 | self.num_keys = num_keys 52 | self.product_keys = product_keys 53 | 54 | self.keys = nn.Parameter(torch.zeros(product_keys, num_keys, heads, dim_key)) 55 | nn.init.normal_(self.keys, std = 0.02) 56 | 57 | product_key_topk = default(product_key_topk, final_topk) 58 | assert final_topk <= (product_key_topk ** product_keys) 59 | 60 | self.topk = product_key_topk 61 | self.final_topk = final_topk 62 | 63 | # the maximum index, or the total space being indexed into 64 | 65 | self.max_index = int(num_keys ** product_keys) 66 | 67 | def forward( 68 | self, 69 | x, 70 | softmax_scores = False 71 | ): 72 | 73 | queries = self.to_queries(x) 74 | 75 | sim = einsum(queries, self.keys, 'p b n h d, p k h d -> p b n h k') 76 | 77 | scores, indices = sim.topk(self.topk, dim = -1) 78 | 79 | # cartesian product indices 80 | 81 | strides = self.num_keys ** torch.arange(self.product_keys, device = x.device) 82 | indices = einx.multiply('p ..., p -> p ...', indices, strides) 83 | 84 | index, *rest_indices = indices 85 | 86 | for rest_index in rest_indices: 87 | index = einx.add('... i, ... j -> ... (i j)', index, rest_index) 88 | 89 | # cartesian product score 90 | 91 | score, *rest_scores = scores 92 | 93 | for rest_score in rest_scores: 94 | score = einx.add('... i, ... j -> ... 
(i j)', score, rest_score) 95 | 96 | final_scores, final_indices = score, index 97 | 98 | # final topk 99 | 100 | final_scores, pk_indices = final_scores.topk(self.final_topk, dim = -1) 101 | 102 | final_indices = final_indices.gather(-1, pk_indices) 103 | 104 | if softmax_scores: 105 | final_scores = final_scores.softmax(dim = -1) 106 | 107 | return final_scores, final_indices 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | pk = PK( 113 | dim = 512, 114 | num_keys = 100, 115 | final_topk = 10, 116 | product_keys = 3 117 | ) 118 | 119 | x = torch.randn(2, 1024, 512) 120 | score, indices = pk(x) 121 | 122 | assert score.shape == (2, 1024, 8, 10) 123 | assert indices.shape == (2, 1024, 8, 10) 124 | -------------------------------------------------------------------------------- /PEER_pytorch/PKAttention.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Module, ModuleList 7 | 8 | import einx 9 | from einops import einsum, pack, unpack 10 | from einops.layers.torch import Rearrange 11 | 12 | from PEER_pytorch.PK import PK 13 | 14 | # helper functions 15 | 16 | def exists(v): 17 | return v is not None 18 | 19 | def default(v, d): 20 | return v if exists(v) else d 21 | 22 | def pack_one(t, pattern): 23 | return pack([t], pattern) 24 | 25 | def unpack_one(t, ps, pattern): 26 | return unpack(t, ps, pattern)[0] 27 | 28 | # rmsnorm 29 | 30 | class RMSNorm(Module): 31 | def __init__(self, dim): 32 | super().__init__() 33 | self.scale = dim ** 0.5 34 | self.gamma = nn.Parameter(torch.zeros(dim)) 35 | 36 | def forward(self, x): 37 | return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1) 38 | 39 | # main class 40 | 41 | class PKAttention(Module): 42 | def __init__( 43 | self, 44 | dim, 45 | *, 46 | causal = True, 47 | heads = 8, 48 | num_key_values = 1_000_000, 49 | key_value_pk_topk = 16, 50 | 
dim_key = None, 51 | product_keys = 2, 52 | pre_rmsnorm = False, 53 | dropout = 0. 54 | ): 55 | """ 56 | einops notation 57 | b - batch 58 | n - sequence 59 | d - dimension 60 | h - heads 61 | p - 2 for product key 62 | k - number of keys 63 | """ 64 | 65 | super().__init__() 66 | self.causal = causal 67 | self.heads = heads 68 | self.num_key_values = num_key_values 69 | 70 | self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity() 71 | 72 | # experts that will form the mlp project in / out weights 73 | 74 | self.to_queries = nn.Sequential( 75 | nn.Linear(dim, dim * heads, bias = False), 76 | Rearrange('b n (h d) -> b n h d', h = heads) 77 | ) 78 | 79 | # keys and values selected using product-key 80 | 81 | self.keys = nn.EmbeddingBag(num_key_values * heads, dim, mode = 'sum') 82 | self.values = nn.EmbeddingBag(num_key_values * heads, dim, mode = 'sum') 83 | 84 | assert sqrt(num_key_values).is_integer(), '`num_key_values` needs to be a square' 85 | assert (dim % 2) == 0, 'feature dimension should be divisible by 2' 86 | 87 | self.to_kv_pk_indices = PK( 88 | dim = dim, 89 | num_keys = int(sqrt(num_key_values)), 90 | final_topk = key_value_pk_topk, 91 | product_keys = product_keys 92 | ) 93 | 94 | # dropout 95 | 96 | self.dropout = nn.Dropout(dropout) 97 | 98 | # output 99 | 100 | self.to_out = nn.Sequential( 101 | Rearrange('b h n d -> b n (h d)'), 102 | nn.Linear(dim * heads, dim, bias = False) 103 | ) 104 | 105 | def forward( 106 | self, 107 | x, 108 | mask = None 109 | ): 110 | device = x.device 111 | 112 | x = self.norm(x) 113 | 114 | # queries 115 | 116 | q = self.to_queries(x) 117 | 118 | q = q * (q.shape[-1] ** -0.5) 119 | 120 | # keys and values 121 | 122 | kv_scores, indices = self.to_kv_pk_indices(x, softmax_scores = True) 123 | 124 | offsets = torch.arange(self.heads, device = device) * self.num_key_values 125 | indices = einx.add('b n h k, h -> b n h k', indices, offsets) 126 | 127 | indices, packed_shape = pack_one(indices, '* k') 128 | kv_scores, _ = 
pack_one(kv_scores, '* k') 129 | 130 | k, v = self.keys(indices, per_sample_weights = kv_scores), self.values(indices, per_sample_weights = kv_scores) 131 | 132 | k = unpack_one(k, packed_shape, '* d') 133 | v = unpack_one(v, packed_shape, '* d') 134 | 135 | # usual multihead self attention 136 | 137 | sim = einsum(q, k, 'b i h d, b j h d -> b h i j') 138 | 139 | # whether causal or not 140 | 141 | if self.causal: 142 | assert not exists(mask) 143 | i, j, device = *sim.shape[-2:], x.device 144 | causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) 145 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 146 | 147 | elif exists(mask): 148 | sim = einx.where('b j, b h i j, -> b h i j', mask, sim, -torch.finfo(sim.dtype).max) 149 | 150 | # attention 151 | 152 | attn = sim.softmax(dim = -1) 153 | attn = self.dropout(attn) 154 | 155 | # aggregate 156 | 157 | out = einsum(attn, v, 'b h i j, b j h d -> b h i d') 158 | 159 | # combine heads 160 | 161 | return self.to_out(out) 162 | 163 | # main 164 | 165 | if __name__ == '__main__': 166 | peer_attn = PKAttention( 167 | dim = 256, 168 | causal = True, 169 | heads = 8, 170 | num_key_values = int(1e4), 171 | pre_rmsnorm = True 172 | ) 173 | 174 | x = torch.randn(2, 512, 256) 175 | 176 | out = peer_attn(x) + x 177 | 178 | assert x.shape == out.shape 179 | -------------------------------------------------------------------------------- /PEER_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from PEER_pytorch.PEER import PEER 2 | from PEER_pytorch.PEERLora import PEERLora 3 | 4 | from PEER_pytorch.ChunkedPEER import ChunkedPEER 5 | 6 | from PEER_pytorch.PK import PK 7 | from PEER_pytorch.PKAttention import PKAttention 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ## PEER - 
Pytorch 6 | 7 | Pytorch implementation of the PEER block from the DeepMind paper, Mixture of A Million Experts, by Xu Owen He. 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install PEER-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | from PEER_pytorch import PEER 20 | 21 | peer = PEER( 22 | dim = 512, 23 | heads = 8, # tested up to 32 - (hk = heads * num_experts_per_head (16)) 24 | num_experts = 1_000_000, # he chose 1 million 25 | num_experts_per_head = 16, # he settled on 16, but was 32 in PKM paper 26 | dim_key = 128, 27 | pre_rmsnorm = True 28 | ).cuda() 29 | 30 | x = torch.randn(2, 1024, 512).cuda() 31 | 32 | out = peer(x) + x 33 | 34 | assert x.shape == out.shape 35 | ``` 36 | 37 | ## Citations 38 | 39 | ```bibtex 40 | @inproceedings{He2024MixtureOA, 41 | title = {Mixture of A Million Experts}, 42 | author = {Xu Owen He}, 43 | year = {2024}, 44 | url = {https://api.semanticscholar.org/CorpusID:271038610} 45 | } 46 | ``` 47 | 48 | ```bibtex 49 | @article{Csordas2023ApproximatingTF, 50 | title = {Approximating Two-Layer Feedforward Networks for Efficient Transformers}, 51 | author = {R{\'o}bert Csord{\'a}s and Kazuki Irie and J{\"u}rgen Schmidhuber}, 52 | journal = {ArXiv}, 53 | year = {2023}, 54 | volume = {abs/2310.10837}, 55 | url = {https://api.semanticscholar.org/CorpusID:264172384} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /peer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/PEER-pytorch/6cb671f997c38ad19b27838f3f1a6aacb0a373f4/peer.png -------------------------------------------------------------------------------- /peer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/PEER-pytorch/6cb671f997c38ad19b27838f3f1a6aacb0a373f4/peer2.png
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "PEER-pytorch" 3 | version = "0.2.1" 4 | description = "PEER - Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'product key', 15 | 'mixture of experts', 16 | ] 17 | classifiers=[ 18 | 'Development Status :: 4 - Beta', 19 | 'Intended Audience :: Developers', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.8', 23 | ] 24 | 25 | dependencies = [ 26 | 'einops>=0.8.0', 27 | 'einx>=0.3.0', 28 | 'torch>=2.0', 29 | ] 30 | 31 | [project.urls] 32 | Homepage = "https://pypi.org/project/PEER-pytorch/" 33 | Repository = "https://github.com/lucidrains/PEER-pytorch" 34 | 35 | [build-system] 36 | requires = ["hatchling"] 37 | build-backend = "hatchling.build" 38 | 39 | 40 | [tool.hatch.metadata] 41 | allow-direct-references = true 42 | 43 | [tool.hatch.build.targets.wheel] 44 | packages = ["PEER_pytorch"] 45 | --------------------------------------------------------------------------------