├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── glom1.png
├── glom2.png
├── glom_pytorch
│   ├── __init__.py
│   └── glom_pytorch.py
└── setup.py

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## GLOM - Pytorch

An implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down and bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.

Yannic Kilcher's video was instrumental in helping me understand this paper.

## Install

```bash
$ pip install glom-pytorch
```

## Usage

```python
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch, patches, levels, dimension)
```

Pass the `return_all = True` keyword argument on forward, and the model will return the column and level states for every iteration, including the initial state, for `iters + 1` entries in total. You can then attach losses to the outputs of any level at any time step.

This also gives you access to the level embeddings across all iterations for clustering, from which one can inspect the islands of agreement theorized in the paper.
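As a minimal sketch of such a probe (the `neighbor_agreement` helper below is hypothetical, not part of this library, and assumes the `(time, batch, patches, levels, dimension)` output shape shown in the full example further down), one can measure the cosine agreement between each column and its spatial neighbors at a chosen level; contiguous regions of near-1 agreement are candidate islands.

```python
import torch
import torch.nn.functional as F

# hypothetical helper, not part of glom-pytorch: cosine agreement between
# spatially adjacent columns at one level, at the final time step
def neighbor_agreement(all_levels, level = -1, side = 16):
    final = all_levels[-1, :, :, level]                       # (batch, patches, dim)
    final = F.normalize(final, dim = -1)                      # unit vectors, so dot products are cosine similarities
    grid = final.reshape(final.shape[0], side, side, -1)      # (batch, h, w, dim) patch grid, side = image_size // patch_size
    horiz = (grid[:, :, :-1] * grid[:, :, 1:]).sum(dim = -1)  # agreement with the right-hand neighbor
    vert = (grid[:, :-1] * grid[:, 1:]).sum(dim = -1)         # agreement with the neighbor below
    return horiz, vert                                        # values near 1 over a region suggest an island at that level
```

The full `return_all` example: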
```python
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after the seventh iteration (index 0 is the initial state)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)
```

Denoising self-supervised learning for encouraging emergence, as described by Hinton:

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1] # top level embeddings after the seventh iteration
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()
```

You can pass the state of the columns and levels back into the model to continue where you left off (perhaps if you are processing consecutive frames of a slow video, as mentioned in the paper):

```python
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations
```

### Appreciation

Thanks go out to Cfoster0 for reviewing the code.

### Todo

- [ ] contrastive / consistency regularization of top-ish levels

## Citations

```bibtex
@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network},
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

--------------------------------------------------------------------------------
/glom1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/glom-pytorch/f30f62165d0c9f9ccdc0330b0005c35ffaaa1635/glom1.png
--------------------------------------------------------------------------------
/glom2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/glom-pytorch/f30f62165d0c9f9ccdc0330b0005c35ffaaa1635/glom2.png
--------------------------------------------------------------------------------
/glom_pytorch/__init__.py:
--------------------------------------------------------------------------------
from glom_pytorch.glom_pytorch import Glom

--------------------------------------------------------------------------------
/glom_pytorch/glom_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# constants

TOKEN_ATTEND_SELF_VALUE = -5e-4

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# classes

class GroupedFeedForward(nn.Module):
    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()
        total_dim = dim * groups # levels * dim
        # grouped 1x1 convolutions act as one independent feedforward network per level
        self.net = nn.Sequential(
            Rearrange('b n l d -> b (l d) n'),
            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),
            Rearrange('b (l d) n -> b n l d', l = groups)
        )

    def forward(self, levels):
        return self.net(levels)

class ConsensusAttention(nn.Module):
    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):
        super().__init__()
        self.attend_self = attend_self
        self.local_consensus_radius = local_consensus_radius

        if self.local_consensus_radius > 0:
            # precompute a mask restricting attention to columns within the given spatial radius
            coors = torch.stack(torch.meshgrid(
                torch.arange(num_patches_side),
                torch.arange(num_patches_side)
            )).float()

            coors = rearrange(coors, 'c h w -> (h w) c')
            dist = torch.cdist(coors, coors)
            mask_non_local = dist > self.local_consensus_radius
            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')
            self.register_buffer('non_local_mask', mask_non_local)

    def forward(self, levels):
        _, n, _, d, device = *levels.shape, levels.device
        q, k, v = levels, F.normalize(levels, dim = -1), levels # keys are l2-normalized; queries and values are not

        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)

        if not self.attend_self:
            self_mask = torch.eye(n, device = device, dtype = torch.bool)
            self_mask = rearrange(self_mask, 'i j -> () () i j')
            sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b l i j, b j l d -> b i l d', attn, levels)
        return out
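# a hypothetical shape walkthrough (not from the original source), to make the
# two modules above concrete: columns are laid out as (batch, num_patches, num_levels, dim),
# e.g. a 224x224 image with 14x14 patches gives a 16x16 grid of 256 columns
#
#   levels = torch.randn(1, 256, 6, 512)
#   ff = GroupedFeedForward(dim = 512, groups = 6)   # one independent feedforward per level
#   attn = ConsensusAttention(num_patches_side = 16, attend_self = False, local_consensus_radius = 2)
#   ff(levels).shape, attn(levels).shape             # both remain (1, 256, 6, 512)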
# main class

class Glom(nn.Module):
    def __init__(
        self,
        *,
        dim = 512,
        levels = 6,
        image_size = 224,
        patch_size = 14,
        consensus_self = False,
        local_consensus_radius = 0
    ):
        super().__init__()
        # bottom level - incoming image, tokenize and add position
        num_patches_side = (image_size // patch_size)
        num_patches = num_patches_side ** 2
        self.levels = levels

        self.image_to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size ** 2 * 3, dim)
        )
        self.pos_emb = nn.Embedding(num_patches, dim)

        # initial embeddings for all levels of a column
        self.init_levels = nn.Parameter(torch.randn(levels, dim))

        # bottom-up and top-down networks
        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

        # consensus attention
        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)

    def forward(self, img, iters = None, levels = None, return_all = False):
        b, device = img.shape[0], img.device
        iters = default(iters, self.levels * 2) # need twice the number of levels of iterations for information to propagate up and back down; can be overridden

        tokens = self.image_to_tokens(img)
        n = tokens.shape[1]

        pos_embs = self.pos_emb(torch.arange(n, device = device))
        pos_embs = rearrange(pos_embs, 'n d -> () n () d')

        bottom_level = tokens
        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

        if not exists(levels):
            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

        hiddens = [levels]

        num_contributions = torch.empty(self.levels, device = device).fill_(4)
        num_contributions[-1] = 3 # the top level receives no top-down contribution, so account for this when taking the weighted mean

        for _ in range(iters):
            levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach the original input at the bottom-most level, to be fed to the bottom-up networks

            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to the top-down networks
            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

            consensus = self.attention(levels)

            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0) # hinton said to use the weighted mean of (1) bottom-up (2) top-down (3) previous level value at {t - 1} (4) consensus value
            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')

            levels = levels_mean # set for the next iteration
            hiddens.append(levels)

        if return_all:
            return torch.stack(hiddens) # (time step, batch, num columns, levels, dimension)

        return levels

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'glom-pytorch',
  packages = find_packages(),
  version = '0.0.14',
  license = 'MIT',
  description = 'Glom - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/glom-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning'
  ],
  install_requires = [
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------