├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── PEER_pytorch
├── ChunkedPEER.py
├── PEER.py
├── PEERLora.py
├── PK.py
├── PKAttention.py
└── __init__.py
├── README.md
├── peer.png
├── peer2.png
└── pyproject.toml
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # checkout@v2 / setup-python@v2 target Node.js runtimes GitHub has deprecated; v4 / v5 are the maintained majors
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      # pinned to a commit SHA rather than a mutable tag
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PEER_pytorch/ChunkedPEER.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module
3 |
4 | from functools import partial
5 | from torch.utils.checkpoint import checkpoint
6 |
7 | from PEER_pytorch.PEER import PEER
8 | from PEER_pytorch.PEERLora import PEERLora
9 |
class ChunkedPEER(Module):
    """
    Wraps a PEER / PEERLora layer and applies it to the input in chunks of
    `seq_chunk_size` along dim 1 (the sequence axis), concatenating the chunk
    outputs back along the same axis.

    While training, if the input requires grad, every chunk is run through
    `torch.utils.checkpoint.checkpoint`, trading recomputation in the backward
    pass for not storing each chunk's intermediate activations.
    """

    def __init__(
        self,
        peer: 'PEER | PEERLora',  # quoted: `X | Y` between classes is only evaluable on Python 3.10+, package supports 3.8
        seq_chunk_size: int = 128
    ):
        super().__init__()
        self.peer = peer
        self.seq_chunk_size = seq_chunk_size

    def forward(
        self,
        x
    ):
        """
        x - tensor chunked along dim 1; each chunk is passed to the wrapped module as-is

        returns the per-chunk outputs concatenated along dim 1
        """
        peer = self.peer

        if self.training and x.requires_grad:
            # use_reentrant must be passed explicitly on recent PyTorch; the non-reentrant variant is the recommended one
            peer = partial(checkpoint, peer, use_reentrant = False)

        out = [peer(chunk) for chunk in x.split(self.seq_chunk_size, dim = 1)]

        return torch.cat(out, dim = 1)
35 |
36 | # quick test
37 |
if __name__ == '__main__':
    # smoke test (requires CUDA): chunked PEER used residually, followed by a backward pass
    chunked = ChunkedPEER(PEER(dim = 512, heads = 8).cuda())

    inputs = torch.randn(1, 1024, 512).cuda().requires_grad_()

    loss = (chunked(inputs) + inputs).sum()
    loss.backward()
48 |
--------------------------------------------------------------------------------
/PEER_pytorch/PEER.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.nn import Module, ModuleList
7 |
8 | import einx
9 | from einops import einsum
10 | from einops.layers.torch import Rearrange
11 |
12 | # helper functions
13 |
def exists(v):
    """True for any value other than None (falsy values such as 0 or '' count as existing)."""
    return not (v is None)

def default(v, d):
    """Return `v`, or the fallback `d` when `v` is None."""
    if v is None:
        return d
    return v
19 |
20 | # rmsnorm
21 |
class RMSNorm(Module):
    """
    L2-normalises the last axis and rescales by sqrt(dim), which equals
    x / rms(x) up to F.normalize's epsilon, then applies a learned gain (gamma + 1).
    """

    def __init__(self, dim):
        super().__init__()
        # gamma is zero-initialised, so the effective gain starts at exactly 1
        self.gamma = nn.Parameter(torch.zeros(dim))
        self.scale = dim ** 0.5

    def forward(self, x):
        gain = self.gamma + 1
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * gain
30 |
31 | # main class
32 |
class PEER(Module):
    """
    following Algorithm 1 in the paper

    For every token and head, product-key retrieval selects expert ids. Each
    selected expert supplies a "down" vector (weight_down_embed) and an "up"
    vector (weight_up_embed). The token is projected onto the down vectors,
    then activated, dropped out and weighted by the activated retrieval scores.
    The result mixes the up vectors back into a `dim`-sized output that is
    summed over heads and experts.
    """

    def __init__(
        self,
        dim,
        *,
        heads = 8, # tested up to 32 - (hk = heads * num_experts_per_head (16)) - for non-competing scores, increase number of heads to desired value for the inner dimension of the hypernetwork mlp
        num_experts = 1_000_000, # he chose 1 million
        num_experts_per_head = 16, # he settled on 16, but was 32 in PKM paper
        activation = nn.GELU,
        dim_key = None,
        product_key_topk = None,
        separate_embed_per_head = False, # @smerky notes that heads may retrieve same redundant neurons. this setting would allow for separate embeds per head and prevent that
        pre_rmsnorm = False,
        non_competing_scores = True,
        dropout = 0.
    ):
        """
        einops notation
        b - batch
        n - sequence
        d - dimension
        h - heads
        p - 2 for product key
        k - number of keys

        arguments
        dim                     - input / output feature dimension, also the width of every expert vector
        heads                   - number of retrieval heads, each with its own queries and sub-keys
        num_experts             - number of experts; must be a perfect square (sqrt(num_experts) sub-keys per product half)
        num_experts_per_head    - experts kept per head after the final top-k when scores are competing (softmax);
                                  also the default for `product_key_topk`
        activation              - activation class, instantiated with no arguments
        dim_key                 - dimension of queries / sub-keys, defaults to dim // 2
        product_key_topk        - candidates kept per product half before they are combined, defaults to num_experts_per_head
        separate_embed_per_head - give every head its own block of expert rows instead of sharing one table
        pre_rmsnorm             - RMS-normalise the input first
        non_competing_scores    - if True, scores go through ReLU and each head keeps a single expert;
                                  if False, scores are softmaxed over the num_experts_per_head kept experts
        dropout                 - dropout on the per-expert hidden activations
        """

        super().__init__()

        self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        # whether to do separate embedding per head

        num_expert_sets = 1 if not separate_embed_per_head else heads

        self.heads = heads
        self.separate_embed_per_head = separate_embed_per_head
        self.num_experts = num_experts

        # experts that will form the mlp project in / out weights
        # one row per expert; with separate embeds, head h owns rows [h * num_experts, (h + 1) * num_experts)

        self.weight_down_embed = nn.Embedding(num_experts * num_expert_sets, dim)
        self.weight_up_embed = nn.Embedding(num_experts * num_expert_sets, dim)

        # activation function, defaults to gelu

        self.activation = activation()

        # queries and keys for product-key

        assert sqrt(num_experts).is_integer(), '`num_experts` needs to be a square'
        assert (dim % 2) == 0, 'feature dimension should be divisible by 2'

        dim_key = default(dim_key, dim // 2)
        # each product half indexes num_keys sub-keys, so num_keys ** 2 == num_experts
        self.num_keys = int(sqrt(num_experts))

        # one query per product half (p = 2) and head
        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim_key * heads * 2, bias = False),
            Rearrange('b n (p h d) -> p b n h d', p = 2, h = heads)
        )

        self.product_key_topk = default(product_key_topk, num_experts_per_head)
        # with non-competing (ReLU) scores each head keeps one expert - see the `heads` note in the signature
        self.num_experts_per_head_topk = num_experts_per_head if not non_competing_scores else 1

        # sub-keys laid out as (heads, num_keys, product half, dim_key)
        self.keys = nn.Parameter(torch.zeros(heads, self.num_keys, 2, dim_key))
        nn.init.normal_(self.keys, std = 0.02)

        # dropout

        self.dropout = nn.Dropout(dropout)

        # whether to use softmax on scores

        # Csordas et al claims non-competing activation helps in PKM setting
        # https://arxiv.org/pdf/2310.10837 - Table 2 in Section 6.2

        self.score_activation = nn.Softmax(dim = -1) if not non_competing_scores else nn.ReLU()

    def forward(
        self,
        x
    ):
        """
        x - (batch, seq, dim) input; the query Rearrange pattern requires exactly three axes

        returns a (batch, seq, dim) tensor
        """

        x = self.norm(x)

        # queries

        queries = self.to_queries(x)

        # first get similarity with keys
        # (p, b, n, h, num_keys): each half-query against every sub-key of the same half and head

        sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k')

        # product key logic
        # top candidates per half; unpacking the leading p axis splits the two halves

        (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.product_key_topk, dim = -1)

        # cartesian product of the two candidate lists: pair (i, j) scores x_i + y_j and maps to expert i * num_keys + j
        all_scores = einx.add('... i, ... j -> ... (i j)', scores_x, scores_y)
        all_indices = einx.add('... i, ... j -> ... (i j)', indices_x * self.num_keys, indices_y)

        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim = -1)

        # expert ids of the surviving pairs, aligned with `scores`
        indices = all_indices.gather(-1, pk_indices)

        # if separate embeds per head, add appropriate offsets per head

        if self.separate_embed_per_head:
            head_expert_offsets = torch.arange(self.heads, device = x.device) * self.num_experts
            indices = einx.add('b n h k, h -> b n h k', indices, head_expert_offsets)

        # build the weight matrices for projecting in and out
        # basically the experts are the gathered parameters for an MLP

        weights_down = self.weight_down_embed(indices)
        weights_up = self.weight_up_embed(indices)

        # below is basically Algorithm 1 in paper

        # dot product of the token with each selected expert's down vector -> one hidden unit per expert
        x = einsum(x, weights_down, 'b n d, b n h k d -> b n h k')

        x = self.activation(x)
        x = self.dropout(x)

        x = x * self.score_activation(scores)

        # weighted sum of the up vectors over heads and experts
        x = einsum(x, weights_up, 'b n h k, b n h k d -> b n d')

        return x
164 |
--------------------------------------------------------------------------------
/PEER_pytorch/PEERLora.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.nn import Module, ModuleList
7 |
8 | import einx
9 | from einops import einsum
10 | from einops.layers.torch import Rearrange
11 |
12 | # helper functions
13 |
def exists(v):
    """Whether `v` is set, i.e. is not None."""
    return not (v is None)

def default(v, d):
    """`v` when it is not None, otherwise the fallback `d`."""
    if v is None:
        return d
    return v
19 |
20 | # rmsnorm
21 |
class RMSNorm(Module):
    """
    Unit-L2-normalises the last axis, rescales by sqrt(dim), giving x / rms(x)
    up to F.normalize's epsilon, and applies a learned per-feature gain (gamma + 1).
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))  # zero init -> gain of exactly 1 at start

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * (self.gamma + 1)
30 |
31 | # main class
32 |
class PEERLora(Module):
    """
    Same as PEER, except it retrieves LORA weights and adds them to a usual feedforward weight1 and weight2 matrices

    The block is a dense two-layer feedforward (proj_in -> activation -> dropout -> proj_out).
    Its input and output projections each receive a per-token low-rank correction.
    Product-key retrieval picks expert ids, and each expert e contributes
    score_e * (h . a_e) * b_e. Here a_e and b_e are rows of the *_lora_a / *_lora_b
    embedding tables and h is the projection's input.
    """

    def __init__(
        self,
        dim,
        *,
        expansion_factor = 2.,
        num_experts = 1_000_000, # 1 million experts
        heads = 4, # the lora k dimension is kept at 16 (heads [4] * num_experts_per_head [4])
        num_experts_per_head = 4,
        activation = nn.GELU,
        dim_key = None,
        product_key_topk = None,
        pre_rmsnorm = False,
        non_competing_scores = True,
        dropout = 0.
    ):
        """
        einops notation
        b - batch
        n - sequence
        d - dimension
        h - heads
        p - 2 for product key
        k - number of keys

        arguments
        dim                  - input / output feature dimension
        expansion_factor     - hidden width of the feedforward is int(dim * expansion_factor)
        num_experts          - number of experts; must be a perfect square
        heads                - number of retrieval heads
        num_experts_per_head - experts kept per head when scores are competing (softmax);
                               also the default for `product_key_topk`
        activation           - activation class, instantiated with no arguments
        dim_key              - dimension of queries / sub-keys, defaults to dim // 2
        product_key_topk     - candidates kept per product half before they are combined
        pre_rmsnorm          - RMS-normalise the input first
        non_competing_scores - if True, scores go through ReLU and each head keeps a single expert;
                               if False, scores are softmaxed over the kept experts
        dropout              - dropout on the hidden activations
        """

        super().__init__()
        dim_inner = int(dim * expansion_factor)

        self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        # heads and num experts

        self.heads = heads
        self.num_experts_per_head = num_experts_per_head

        self.num_experts = num_experts

        # usual feedforward weights without bias

        self.proj_in = nn.Linear(dim, dim_inner, bias = False)
        self.proj_out = nn.Linear(dim_inner, dim, bias = False)

        # experts that will form the mlp project in / out weights
        # lora_a rows are dotted with the projection input, lora_b rows are the directions added to its output

        self.proj_in_lora_a = nn.Embedding(num_experts, dim)
        self.proj_in_lora_b = nn.Embedding(num_experts, dim_inner)

        self.proj_out_lora_a = nn.Embedding(num_experts, dim_inner)
        self.proj_out_lora_b = nn.Embedding(num_experts, dim)

        # activation function, defaults to gelu

        self.activation = activation()

        # queries and keys for product-key

        assert sqrt(num_experts).is_integer(), '`num_experts` needs to be a square'
        assert (dim % 2) == 0, 'feature dimension should be divisible by 2'

        dim_key = default(dim_key, dim // 2)
        self.num_keys = int(sqrt(num_experts))

        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim_key * heads * 2, bias = False),
            Rearrange('b n (p h d) -> p b n h d', p = 2, h = heads)
        )

        self.product_key_topk = default(product_key_topk, num_experts_per_head)
        self.num_experts_per_head_topk = num_experts_per_head if not non_competing_scores else 1

        self.keys = nn.Parameter(torch.zeros(heads, self.num_keys, 2, dim_key))
        nn.init.normal_(self.keys, std = 0.02)

        # dropout

        self.dropout = nn.Dropout(dropout)

        # whether to use softmax on scores

        # Csordas et al claims non-competing activation helps in PKM setting
        # https://arxiv.org/pdf/2310.10837 - Table 2 in Section 6.2

        self.score_activation = nn.Softmax(dim = -1) if not non_competing_scores else nn.ReLU()

        # init

        nn.init.normal_(self.proj_in_lora_a.weight, std = 0.02)
        nn.init.normal_(self.proj_out_lora_b.weight, std = 0.02)
        nn.init.normal_(self.proj_in_lora_b.weight, std = 0.02)
        nn.init.normal_(self.proj_out_lora_a.weight, std = 0.02)

    @property
    def lora_k(self):
        # heads * num_experts_per_head; note that with non_competing_scores=True only one expert per head is retrieved
        return self.heads * self.num_experts_per_head

    def forward(
        self,
        x
    ):
        """
        x - (batch, seq, dim) input; the query Rearrange pattern requires exactly three axes

        returns a (batch, seq, dim) tensor
        """

        x = self.norm(x)

        # queries

        queries = self.to_queries(x)

        # first get similarity with keys

        sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k')

        # product key logic

        (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.product_key_topk, dim = -1)

        # cartesian product of both halves: pair (i, j) scores x_i + y_j and maps to expert i * num_keys + j
        all_scores = einx.add('... i, ... j -> ... (i j)', scores_x, scores_y)
        all_indices = einx.add('... i, ... j -> ... (i j)', indices_x * self.num_keys, indices_y)

        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim = -1)

        indices = all_indices.gather(-1, pk_indices)

        # the activated scores weight both lora corrections - compute once, reuse twice

        score_weights = self.score_activation(scores)

        # build the loras for projecting in and out weights of a feedforward

        proj_in_lora_a = self.proj_in_lora_a(indices)
        proj_in_lora_b = self.proj_in_lora_b(indices)

        proj_out_lora_a = self.proj_out_lora_a(indices)
        proj_out_lora_b = self.proj_out_lora_b(indices)

        # feedforward, but with expert loras chosen by pk

        # project in

        hidden = self.proj_in(x)

        lora_in_hidden = einsum(x, proj_in_lora_a, 'b n d, b n h k d -> b n h k')
        lora_in_hidden = lora_in_hidden * score_weights
        lora_in_hidden = einsum(lora_in_hidden, proj_in_lora_b, 'b n h k, b n h k d -> b n d')

        hidden = hidden + lora_in_hidden

        # gelu and dropout

        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)

        # project out

        out = self.proj_out(hidden)

        lora_out_hidden = einsum(hidden, proj_out_lora_a, 'b n d, b n h k d -> b n h k')
        lora_out_hidden = lora_out_hidden * score_weights
        lora_out_hidden = einsum(lora_out_hidden, proj_out_lora_b, 'b n h k, b n h k d -> b n d')

        out = out + lora_out_hidden

        return out
195 |
--------------------------------------------------------------------------------
/PEER_pytorch/PK.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Module
4 |
5 | import einx
6 | from einops.layers.torch import Rearrange
7 | from einops import einsum
8 |
9 | # helper functions
10 |
def exists(v):
    """True unless `v` is None."""
    return not (v is None)

def default(v, d):
    """Fall back to `d` only when `v` is None."""
    if v is None:
        return d
    return v
16 |
17 | # main class
18 |
class PK(Module):
    """
    Product-key retrieval. Each per-head query is split into `product_keys`
    sub-queries, and the top-k sub-keys are found for each. These are combined
    by cartesian product into indices over num_keys ** product_keys entries,
    and the best `final_topk` are kept.
    """

    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_key = None,
        num_keys = 1_000,
        product_keys = 2,
        product_key_topk = None,
        final_topk = 16,
        num_experts_per_head = 16
    ):
        """
        einops notation
        b - batch
        n - sequence
        d - dimension
        h - heads
        p - product keys
        k - number of keys

        arguments
        dim                  - input feature dimension
        heads                - number of independent retrieval heads
        dim_key              - dimension of each sub-query / sub-key, defaults to dim // 2
        num_keys             - sub-keys per product set
        product_keys         - number of sub-key sets combined by cartesian product
        product_key_topk     - candidates kept per product set, defaults to final_topk
        final_topk           - number of (score, index) pairs returned per head
        num_experts_per_head - not used by this class (NOTE(review): presumably kept for signature compatibility)
        """

        super().__init__()
        assert (dim % 2) == 0
        dim_key = default(dim_key, dim // 2)

        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim_key * product_keys * heads, bias = False),
            Rearrange('b n (p h d) -> p b n h d', h = heads, p = product_keys)
        )

        self.num_keys = num_keys
        self.product_keys = product_keys

        # num_keys sub-keys for every (product slot, head)
        self.keys = nn.Parameter(torch.zeros(product_keys, num_keys, heads, dim_key))
        nn.init.normal_(self.keys, std = 0.02)

        product_key_topk = default(product_key_topk, final_topk)
        # the cartesian product yields product_key_topk ** product_keys candidates, final_topk must fit within them
        assert final_topk <= (product_key_topk ** product_keys)

        self.topk = product_key_topk
        self.final_topk = final_topk

        # the maximum index, or the total space being indexed into

        self.max_index = int(num_keys ** product_keys)

    def forward(
        self,
        x,
        softmax_scores = False
    ):
        """
        x              - (batch, seq, dim) input; the query Rearrange pattern requires exactly three axes
        softmax_scores - if True, softmax the returned scores over the final_topk axis

        returns (scores, indices), both (batch, seq, heads, final_topk);
        indices lie in [0, num_keys ** product_keys)
        """

        # (p, b, n, h, dim_key)
        queries = self.to_queries(x)

        # similarity of each sub-query with every sub-key of its own product slot and head: (p, b, n, h, num_keys)
        sim = einsum(queries, self.keys, 'p b n h d, p k h d -> p b n h k')

        scores, indices = sim.topk(self.topk, dim = -1)

        # cartesian product indices

        # slot i is weighted by num_keys ** i, so the summed index is a base-num_keys number whose i-th digit is slot i's sub-key
        strides = self.num_keys ** torch.arange(self.product_keys, device = x.device)
        indices = einx.multiply('p ..., p -> p ...', indices, strides)

        index, *rest_indices = indices

        # each step pairs every running candidate with every candidate of the next slot (last axis grows by a factor of self.topk)
        for rest_index in rest_indices:
            index = einx.add('... i, ... j -> ... (i j)', index, rest_index)

        # cartesian product score

        # same pairing order as the indices, so scores stay aligned with them; a candidate's score is the sum of its sub-key scores
        score, *rest_scores = scores

        for rest_score in rest_scores:
            score = einx.add('... i, ... j -> ... (i j)', score, rest_score)

        final_scores, final_indices = score, index

        # final topk

        final_scores, pk_indices = final_scores.topk(self.final_topk, dim = -1)

        final_indices = final_indices.gather(-1, pk_indices)

        if softmax_scores:
            final_scores = final_scores.softmax(dim = -1)

        return final_scores, final_indices
108 |
109 |
if __name__ == '__main__':
    # smoke test: three product-key sets of 100 sub-keys each address 100 ** 3 ids
    pk = PK(
        dim = 512,
        num_keys = 100,
        final_topk = 10,
        product_keys = 3
    )

    batch, seq_len = 2, 1024
    x = torch.randn(batch, seq_len, 512)

    score, indices = pk(x)

    # default of 8 heads, 10 results per head
    expected_shape = (batch, seq_len, 8, 10)
    assert score.shape == expected_shape
    assert indices.shape == expected_shape
124 |
--------------------------------------------------------------------------------
/PEER_pytorch/PKAttention.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.nn import Module, ModuleList
7 |
8 | import einx
9 | from einops import einsum, pack, unpack
10 | from einops.layers.torch import Rearrange
11 |
12 | from PEER_pytorch.PK import PK
13 |
14 | # helper functions
15 |
def exists(v):
    """Report whether `v` carries a value (anything except None)."""
    return not (v is None)

def default(v, d):
    """Substitute `d` for a None `v`; any other value is returned unchanged."""
    if v is None:
        return d
    return v
21 |
def pack_one(t, pattern):
    """Pack a single tensor with einops `pack`; returns (packed tensor, packed shapes)."""
    packed = pack([t], pattern)
    return packed

def unpack_one(t, ps, pattern):
    """Counterpart of `pack_one`: unpack with einops `unpack` and return the first tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
27 |
28 | # rmsnorm
29 |
class RMSNorm(Module):
    """
    Normalises the last axis to unit L2 norm, then multiplies by sqrt(dim), so the
    result is x / rms(x) up to F.normalize's epsilon, and by a learned gain of (gamma + 1).
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(dim))  # starts at 0, so the gain starts at 1
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        scaled = unit * self.scale
        return scaled * (self.gamma + 1)
38 |
39 | # main class
40 |
class PKAttention(Module):
    """
    Multi-head attention whose keys and values are retrieved per token instead
    of being projected from the input. Product-key search selects
    `key_value_pk_topk` rows from a learned key table and a learned value table
    for every head. Their softmax-score-weighted sums become that token's key
    and value for the head.
    """

    def __init__(
        self,
        dim,
        *,
        causal = True,
        heads = 8,
        num_key_values = 1_000_000,
        key_value_pk_topk = 16,
        dim_key = None,
        product_keys = 2,
        pre_rmsnorm = False,
        dropout = 0.
    ):
        """
        einops notation
        b - batch
        n - sequence
        d - dimension
        h - heads
        p - product keys
        k - number of keys

        arguments
        dim               - feature dimension of the input, of each query / retrieved key / value, and of the output
        causal            - apply a causal mask (a padding `mask` is then not supported)
        heads             - number of attention heads, also used as the number of product-key retrieval heads
        num_key_values    - rows per head in the key / value tables; must equal num_keys ** product_keys for an integer num_keys
        key_value_pk_topk - number of table rows summed into each retrieved key / value
        dim_key           - sub-query / sub-key dimension inside the product-key retrieval (PK defaults it to dim // 2)
        product_keys      - number of sub-key sets combined by the cartesian product
        pre_rmsnorm       - RMS-normalise the input first
        dropout           - dropout on the attention matrix
        """

        super().__init__()
        self.causal = causal
        self.heads = heads
        self.num_key_values = num_key_values

        self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        # per-head queries, each of width `dim`

        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim * heads, bias = False),
            Rearrange('b n (h d) -> b n h d', h = heads)
        )

        # keys and values selected using product-key
        # one table of num_key_values rows per head, stored back to back (head h owns rows [h * num_key_values, (h + 1) * num_key_values))

        self.keys = nn.EmbeddingBag(num_key_values * heads, dim, mode = 'sum')
        self.values = nn.EmbeddingBag(num_key_values * heads, dim, mode = 'sum')

        # PK indexes num_keys ** product_keys entries, which must cover exactly one per-head table -
        # a plain square root is only correct when product_keys == 2

        num_keys = round(num_key_values ** (1 / product_keys))
        assert num_keys ** product_keys == num_key_values, f'`num_key_values` needs to be a perfect power of `product_keys` ({product_keys})'
        assert (dim % 2) == 0, 'feature dimension should be divisible by 2'

        # heads and dim_key are forwarded so the retrieval produces one index set per attention head

        self.to_kv_pk_indices = PK(
            dim = dim,
            heads = heads,
            dim_key = dim_key,
            num_keys = num_keys,
            final_topk = key_value_pk_topk,
            product_keys = product_keys
        )

        # dropout

        self.dropout = nn.Dropout(dropout)

        # output

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim * heads, dim, bias = False)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        """
        x    - (batch, seq, dim) input; the query Rearrange pattern requires exactly three axes
        mask - optional boolean (batch, seq) key mask, True = may be attended to;
               only allowed when causal is False

        returns a (batch, seq, dim) tensor
        """
        device = x.device

        x = self.norm(x)

        # queries

        q = self.to_queries(x)

        q = q * (q.shape[-1] ** -0.5)

        # keys and values

        kv_scores, indices = self.to_kv_pk_indices(x, softmax_scores = True)

        # shift each head's indices into its own block of table rows

        offsets = torch.arange(self.heads, device = device) * self.num_key_values
        indices = einx.add('b n h k, h -> b n h k', indices, offsets)

        # EmbeddingBag takes 2-D (bags, k) input, so flatten b n h into one bag axis

        indices, packed_shape = pack_one(indices, '* k')
        kv_scores, _ = pack_one(kv_scores, '* k')

        # score-weighted sum of the retrieved rows (mode = 'sum' with per_sample_weights)

        k, v = self.keys(indices, per_sample_weights = kv_scores), self.values(indices, per_sample_weights = kv_scores)

        k = unpack_one(k, packed_shape, '* d')
        v = unpack_one(v, packed_shape, '* d')

        # usual multihead self attention

        sim = einsum(q, k, 'b i h d, b j h d -> b h i j')

        # whether causal or not

        if self.causal:
            assert not exists(mask)
            i, j, device = *sim.shape[-2:], x.device
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        elif exists(mask):
            sim = einx.where('b j, b h i j, -> b h i j', mask, sim, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate

        out = einsum(attn, v, 'b h i j, b j h d -> b h i d')

        # combine heads

        return self.to_out(out)
162 |
163 | # main
164 |
if __name__ == '__main__':
    # smoke test: causal PK attention with a 10k-row key / value table per head, used residually
    config = dict(
        dim = 256,
        causal = True,
        heads = 8,
        num_key_values = int(1e4),
        pre_rmsnorm = True
    )
    peer_attn = PKAttention(**config)

    x = torch.randn(2, 512, 256)

    residual = peer_attn(x) + x

    assert x.shape == residual.shape
179 |
--------------------------------------------------------------------------------
/PEER_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from PEER_pytorch.PEER import PEER
2 | from PEER_pytorch.PEERLora import PEERLora
3 |
4 | from PEER_pytorch.ChunkedPEER import ChunkedPEER
5 |
6 | from PEER_pytorch.PK import PK
7 | from PEER_pytorch.PKAttention import PKAttention
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## PEER - Pytorch
6 |
7 | PyTorch implementation of the PEER (Parameter Efficient Expert Retrieval) block from the DeepMind paper *Mixture of A Million Experts* by Xu Owen He.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install PEER-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from PEER_pytorch import PEER
20 |
21 | peer = PEER(
22 | dim = 512,
23 | heads = 8, # tested up to 32 - (hk = heads * num_experts_per_head (16))
24 | num_experts = 1_000_000, # he chose 1 million
25 | num_experts_per_head = 16, # he settled on 16, but was 32 in PKM paper
26 | dim_key = 128,
27 | pre_rmsnorm = True
28 | ).cuda()
29 |
30 | x = torch.randn(2, 1024, 512).cuda()
31 |
32 | out = peer(x) + x
33 |
34 | assert x.shape == out.shape
35 | ```
36 |
37 | ## Citations
38 |
39 | ```bibtex
40 | @inproceedings{He2024MixtureOA,
41 | title = {Mixture of A Million Experts},
42 | author = {Xu Owen He},
43 | year = {2024},
44 | url = {https://api.semanticscholar.org/CorpusID:271038610}
45 | }
46 | ```
47 |
48 | ```bibtex
49 | @article{Csordas2023ApproximatingTF,
50 | title = {Approximating Two-Layer Feedforward Networks for Efficient Transformers},
51 | author = {R{\'o}bert Csord{\'a}s and Kazuki Irie and J{\"u}rgen Schmidhuber},
52 | journal = {ArXiv},
53 | year = {2023},
54 | volume = {abs/2310.10837},
55 | url = {https://api.semanticscholar.org/CorpusID:264172384}
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/peer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/PEER-pytorch/6cb671f997c38ad19b27838f3f1a6aacb0a373f4/peer.png
--------------------------------------------------------------------------------
/peer2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/PEER-pytorch/6cb671f997c38ad19b27838f3f1a6aacb0a373f4/peer2.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "PEER-pytorch"
version = "0.2.1"
description = "PEER - Pytorch"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" },
]
readme = "README.md"
requires-python = ">= 3.8"
license = { file = "LICENSE" }
keywords = [
    "artificial intelligence",
    "deep learning",
    "product key",
    "mixture of experts",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.8",
]
dependencies = [
    "einops>=0.8.0",
    "einx>=0.3.0",
    "torch>=2.0",
]

[project.urls]
Homepage = "https://pypi.org/project/PEER-pytorch/"
Repository = "https://github.com/lucidrains/PEER-pytorch"

[tool.hatch.metadata]
allow-direct-references = true

# the wheel ships only the PEER_pytorch package directory
[tool.hatch.build.targets.wheel]
packages = ["PEER_pytorch"]
45 |
--------------------------------------------------------------------------------