├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── glom1.png
├── glom2.png
├── glom_pytorch
│   ├── __init__.py
│   └── glom_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## GLOM - Pytorch

An implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down and bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.

Yannic Kilcher's video was instrumental in helping me to understand this paper.

## Install

```bash
$ pip install glom-pytorch
```

## Usage

```python
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch, patches, levels, dimension)
```

Pass the `return_all = True` keyword argument on forward, and you will be returned all the column and level states for every iteration, including the initial state (`iters + 1` entries in total). You can then use this to attach losses to the output of any level at any time step.

It also gives you access to the level embeddings across all iterations for clustering, which you can inspect for the islands of agreement theorized in the paper (see the sketch after the example below).

```python
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after the seventh iteration (index 0 is the initial state)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)
```
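
As a rough example of inspecting for islands: one could measure the cosine similarity between spatially adjacent column embeddings at a given level, since patches that have settled into the same island should have nearly identical embeddings. This is only a minimal sketch continuing from the example above - the `island_similarity` helper and the choice of iteration and level are illustrative, not part of the library.

```python
import torch
import torch.nn.functional as F

# hypothetical helper, not part of glom-pytorch
def island_similarity(all_levels, iteration, level, num_patches_side = 16): # 16 = image_size // patch_size
    embeds = all_levels[iteration, :, :, level] # (batch, patches, dim)
    embeds = embeds.reshape(embeds.shape[0], num_patches_side, num_patches_side, -1)
    embeds = F.normalize(embeds, dim = -1)

    # cosine similarity between each patch and its right / bottom neighbor
    sim_right = (embeds[:, :, :-1] * embeds[:, :, 1:]).sum(dim = -1) # (batch, h, w - 1)
    sim_down  = (embeds[:, :-1, :] * embeds[:, 1:, :]).sum(dim = -1) # (batch, h - 1, w)
    return sim_right, sim_down

sim_right, sim_down = island_similarity(all_levels, iteration = 12, level = 4)
# contiguous high-similarity regions suggest islands of agreement
```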

Denoising self-supervised learning for encouraging emergence, as described by Hinton:

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1] # get the top level embeddings after the seventh iteration
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()
```

You can pass the state of the columns and levels back into the model to continue where you left off (for example, if you are processing consecutive frames of a slow video, as mentioned in the paper).

```python
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations
```
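
For instance, a minimal sketch of running over consecutive video frames while carrying the settled states forward (the random `video` tensor is just an illustrative stand-in for real frames):

```python
import torch
from glom_pytorch import Glom

model = Glom(dim = 512, levels = 6, image_size = 224, patch_size = 14)

video = torch.randn(10, 3, 224, 224) # 10 frames

levels = None
for frame in video:
    # reuse the previous frame's levels as the starting state for the next frame
    levels = model(frame[None, ...], levels = levels, iters = 6)
```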

### Appreciation

Thanks go out to Cfoster0 for reviewing the code.

### Todo

- [ ] contrastive / consistency regularization of top-ish levels (a hypothetical sketch of one possible formulation follows below)

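This is not implemented yet - purely as an illustration of one possible formulation, a consistency regularizer could penalize the upper levels for still changing between the final iterations:

```python
import torch
import torch.nn.functional as F
from glom_pytorch import Glom

model = Glom(dim = 512, levels = 6, image_size = 224, patch_size = 14)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True)

# hypothetical consistency loss on the top two levels of the last two time steps
last = all_levels[-1, ..., -2:, :]
prev = all_levels[-2, ..., -2:, :]
consistency_loss = F.mse_loss(last, prev)
consistency_loss.backward()
```
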
## Citations

```bibtex
@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network},
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```
--------------------------------------------------------------------------------
/glom1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/glom-pytorch/f30f62165d0c9f9ccdc0330b0005c35ffaaa1635/glom1.png
--------------------------------------------------------------------------------
/glom2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/glom-pytorch/f30f62165d0c9f9ccdc0330b0005c35ffaaa1635/glom2.png
--------------------------------------------------------------------------------
/glom_pytorch/__init__.py:
--------------------------------------------------------------------------------
from glom_pytorch.glom_pytorch import Glom
--------------------------------------------------------------------------------
/glom_pytorch/glom_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# constants

TOKEN_ATTEND_SELF_VALUE = -5e-4 # used to discourage a column from attending to itself when attend_self is False

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# classes

class GroupedFeedForward(nn.Module):
    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()
        total_dim = dim * groups # levels * dim
        # grouped 1x1 convolutions act as an independent feedforward network per level
        self.net = nn.Sequential(
            Rearrange('b n l d -> b (l d) n'),
            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),
            Rearrange('b (l d) n -> b n l d', l = groups)
        )

    def forward(self, levels):
        return self.net(levels)

class ConsensusAttention(nn.Module):
    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):
        super().__init__()
        self.attend_self = attend_self
        self.local_consensus_radius = local_consensus_radius

        if self.local_consensus_radius > 0:
            coors = torch.stack(torch.meshgrid(
                torch.arange(num_patches_side),
                torch.arange(num_patches_side)
            )).float()

            coors = rearrange(coors, 'c h w -> (h w) c')
            dist = torch.cdist(coors, coors)
            mask_non_local = dist > self.local_consensus_radius
            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')
            self.register_buffer('non_local_mask', mask_non_local)

    def forward(self, levels):
        _, n, _, d, device = *levels.shape, levels.device
        # keys are l2-normalized, so the attention logits are scaled cosine similarities
        q, k, v = levels, F.normalize(levels, dim = -1), levels

        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)

        if not self.attend_self:
            self_mask = torch.eye(n, device = device, dtype = torch.bool)
            self_mask = rearrange(self_mask, 'i j -> () () i j')
            sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b l i j, b j l d -> b i l d', attn, v)
        return out

# main class

class Glom(nn.Module):
    def __init__(
        self,
        *,
        dim = 512,
        levels = 6,
        image_size = 224,
        patch_size = 14,
        consensus_self = False,
        local_consensus_radius = 0
    ):
        super().__init__()
        # bottom level - incoming image, tokenize and add position
        num_patches_side = (image_size // patch_size)
        num_patches = num_patches_side ** 2
        self.levels = levels

        self.image_to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size ** 2 * 3, dim)
        )
        self.pos_emb = nn.Embedding(num_patches, dim)

        # initial embeddings for all levels of a column
        self.init_levels = nn.Parameter(torch.randn(levels, dim))

        # bottom-up and top-down networks
        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

        # consensus attention
        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)

    def forward(self, img, iters = None, levels = None, return_all = False):
        b, device = img.shape[0], img.device
        iters = default(iters, self.levels * 2) # need to have twice the number of levels of iterations in order for information to propagate up and back down. can be overridden

        tokens = self.image_to_tokens(img)
        n = tokens.shape[1]

        pos_embs = self.pos_emb(torch.arange(n, device = device))
        pos_embs = rearrange(pos_embs, 'n d -> () n () d')

        bottom_level = tokens
        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

        if not exists(levels):
            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

        hiddens = [levels]

        num_contributions = torch.empty(self.levels, device = device).fill_(4)
        num_contributions[-1] = 3 # the top level does not receive a top-down contribution, so account for this when taking the weighted mean

        for _ in range(iters):
            levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach the original input at the bottom-most level, to be fed upwards

            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to the top-down networks
            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

            consensus = self.attention(levels)

            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0) # hinton said to use the weighted mean of (1) bottom up (2) top down (3) previous level value {t - 1} (4) consensus value
            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')

            levels = levels_mean # set for the next iteration
            hiddens.append(levels)

        if return_all:
            return torch.stack(hiddens) # (time step, batch, num columns, levels, dimension)

        return levels
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'glom-pytorch',
  packages = find_packages(),
  version = '0.0.14',
  license = 'MIT',
  description = 'Glom - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/glom-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning'
  ],
  install_requires = [
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------