├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── agent-attention.png
├── agent_attention_pytorch
│   ├── __init__.py
│   ├── agent_attention_pytorch.py
│   └── agent_transformer.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------


# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------


## Agent Attention - Pytorch

Implementation of Agent Attention in Pytorch.

This work appears to be an elegant simplification of the `ISAB` architecture from the Set Transformer paper, requiring only one attention block rather than two. While ISAB works, I have found it to be a bit unstable, so I am curious whether the simplification in this work resolves that issue.

This repository adds support for variable sequence lengths (masking) and post-softmax talking heads.

## Appreciation

- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

## Install

```bash
$ pip install agent-attention-pytorch
```

## Usage

```python
import torch
from agent_attention_pytorch import AgentSelfAttention

attn = AgentSelfAttention(
    dim = 512,
    num_agent_tokens = 256,       # number of "agent" tokens
    dim_head = 64,                # attention head dimension
    heads = 8                     # number of heads
)

x = torch.randn(3, 65536, 512)
mask = torch.ones(3, 65536).bool()

out = attn(x, mask = mask)

assert out.shape == x.shape
```

For a full-fledged linear transformer based on agent tokens, just import `AgentTransformer`.

```python
import torch
from agent_attention_pytorch import AgentTransformer

transformer = AgentTransformer(
    dim = 512,
    depth = 6,
    num_agent_tokens = 128,
    dim_head = 64,
    heads = 8
)

x = torch.randn(3, 65536, 512)
mask = torch.ones(3, 65536).bool()

out, agent_tokens = transformer(x, mask = mask, return_agent_tokens = True)

# (3, 65536, 512), (3, 128, 512)

assert out.shape == x.shape
```

## Citations

```bibtex
@inproceedings{Han2023AgentAO,
    title   = {Agent Attention: On the Integration of Softmax and Linear Attention},
    author  = {Dongchen Han and Tianzhu Ye and Yizeng Han and Zhuofan Xia and Shiji Song and Gao Huang},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:266210414}
}
```

```bibtex
@misc{shazeer2020talkingheads,
    title         = {Talking-Heads Attention},
    author        = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year          = {2020},
    eprint        = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```

```bibtex
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
```

```bibtex
@article{Wang2022FoundationT,
    title   = {Foundation Transformers},
    author  = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.06423},
    url     = {https://api.semanticscholar.org/CorpusID:252846241}
}
```
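
## Reusing agent tokens

The standalone `AgentSelfAttention` can also accept externally supplied agent tokens and return the per-head agent-pooled values (the `agent_tokens` and `return_agent_tokens` arguments in `agent_attention_pytorch.py`). The following is only a minimal sketch; the batch size, sequence length, and shapes are illustrative.

```python
import torch
from agent_attention_pytorch import AgentSelfAttention

attn = AgentSelfAttention(
    dim = 512,
    num_agent_tokens = 256,
    dim_head = 64,
    heads = 8
)

x = torch.randn(2, 2048, 512)
mask = torch.ones(2, 2048).bool()

# return the per-head agent-pooled values alongside the output

out, agent_tokens = attn(x, mask = mask, return_agent_tokens = True)

# (2, 2048, 512), (2, 8, 256, 64)

# the pooled tokens can be passed back in as the agents on a subsequent call

out = attn(x, mask = mask, agent_tokens = agent_tokens)
```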
--------------------------------------------------------------------------------
/agent-attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/agent-attention-pytorch/091f6d447d5005cd8c8e16843688b5cc2a9b8cd2/agent-attention.png
--------------------------------------------------------------------------------
/agent_attention_pytorch/__init__.py:
--------------------------------------------------------------------------------
from agent_attention_pytorch.agent_attention_pytorch import (
    AgentSelfAttention
)

from agent_attention_pytorch.agent_transformer import (
    AgentTransformer
)

--------------------------------------------------------------------------------
/agent_attention_pytorch/agent_attention_pytorch.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import Module
from torch import nn, einsum, Tensor

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# functions

def exists(v):
    return v is not None

# main class

class AgentSelfAttention(Module):
    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        combine_agent_tokens = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        ) if gate else None

        self.agent_tokens = nn.Parameter(torch.zeros(heads, num_agent_tokens, dim_head))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
        self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        mask = None,
        agent_tokens = None,
        return_agent_tokens = False
    ):
        batch = x.shape[0]

        q, k, v = self.to_qkv(x)

        if exists(agent_tokens):
            a = agent_tokens
        else:
            a = repeat(self.agent_tokens, 'h m d -> b h m d', b = batch)

        # scaling the agent tokens once covers both similarity matrices below

        a = a * self.scale

        # queries attend to the agent tokens, and the agent tokens attend to the keys

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, a)
        ak_sim = einsum('b h i d, b h j d -> b h i j', a, k)

        if exists(mask):
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)

        qa_attn = qa_sim.softmax(dim = -1)
        ak_attn = ak_sim.softmax(dim = -1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)

        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)
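
        # the full n x n attention matrix is never materialized: the values are first
        # pooled into the m agent tokens (ak_attn), then redistributed back to the
        # n queries (qa_attn), so the cost is O(n * m) rather than O(n ^ 2)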

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)

        out = self.to_out(out)

        if not return_agent_tokens:
            return out

        return out, agent_gathered_tokens

--------------------------------------------------------------------------------
/agent_attention_pytorch/agent_transformer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch import nn, einsum, Tensor

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# functions

def exists(v):
    return v is not None

# norm

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# feedforward

def FeedForward(dim, mult = 4):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Linear(dim_inner, dim)
    )

# main class

class AgentSelfAttention(Module):
    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        sub_layernorm = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        ) if gate else None

        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
        self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        self.to_agent_out = nn.Sequential(
            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

        self.to_out = nn.Sequential(
            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        *,
        agent_tokens,
        mask = None,
        return_agent_tokens = False
    ):
        x = self.norm(x)
        a = self.norm(agent_tokens)

        x_and_agents, xa_ps = pack([a, x], 'b * d')
        qkv = self.to_qkv(x_and_agents)

        qkv_agent, qkv_input = unpack(qkv, xa_ps, 'qkv b h * d')

        q, k, v = qkv_input
        agent_queries, agent_keys, _ = qkv_agent

        q = q * self.scale
        agent_queries = agent_queries * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, agent_keys)
        ak_sim = einsum('b h i d, b h j d -> b h i j', agent_queries, k)

        if exists(mask):
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)
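            # only the agent -> input similarities need the padding mask, since the agent
            # tokens are always valid; padded query positions are zeroed out after attention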

        qa_attn = qa_sim.softmax(dim = -1)
        ak_attn = ak_sim.softmax(dim = -1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        agent_out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)

        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_out)

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)
            agent_out = agent_out * self.to_gates(a)

        out = self.to_out(out)
        agent_out = self.to_agent_out(agent_out)

        if not return_agent_tokens:
            return out

        return out, agent_out

# transformer

class AgentTransformer(Module):
    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        final_norm = True,
        **attn_kwargs: dict
    ):
        super().__init__()

        self.agent_tokens = nn.Parameter(torch.zeros(num_agent_tokens, dim))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                AgentSelfAttention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    num_agent_tokens = num_agent_tokens,
                    **attn_kwargs
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.final_norm = RMSNorm(dim) if final_norm else None

    def forward(
        self,
        x,
        mask = None,
        return_agent_tokens = False
    ):
        batch = x.shape[0]
        a = repeat(self.agent_tokens, 'm d -> b m d', b = batch)

        for attn, ff in self.layers:
            attn_out, agent_out = attn(
                x,
                agent_tokens = a,
                mask = mask,
                return_agent_tokens = True
            )

            a = a + agent_out
            x = x + attn_out

            x, ps = pack([a, x], 'b * d')

            x = ff(x) + x

            a, x = unpack(x, ps, 'b * d')

        if exists(self.final_norm):
            x = self.final_norm(x)
            a = self.final_norm(a)

        if not return_agent_tokens:
            return x

        return x, a

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'agent-attention-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.7',
  license='MIT',
  description = 'Agent Attention - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/agent-attention-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention',
    'linear attention'
  ],
  install_requires=[
    'einops>=0.7.0',
    'torch>=2.0'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------
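
A quick smoke test of the package on a variable-length (padded) batch might look like the sketch below. It is not part of the repository; the hyperparameters and shapes are arbitrary, chosen only to keep the check fast.

```python
import torch
from agent_attention_pytorch import AgentTransformer

model = AgentTransformer(
    dim = 64,
    depth = 2,
    num_agent_tokens = 16,
    dim_head = 32,
    heads = 4
)

x = torch.randn(2, 256, 64)

# the second sequence in the batch only has 128 valid tokens
mask = torch.zeros(2, 256).bool()
mask[0, :] = True
mask[1, :128] = True

out, agent_tokens = model(x, mask = mask, return_agent_tokens = True)

assert out.shape == (2, 256, 64)
assert agent_tokens.shape == (2, 16, 64)

# dummy objective, just to confirm gradients flow through both streams
(out.mean() + agent_tokens.mean()).backward()
```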