├── diagram.png
├── diagram-2.png
├── lie_transformer_pytorch
│   ├── __init__.py
│   ├── reversible.py
│   ├── se3.py
│   └── lie_transformer_pytorch.py
├── setup.cfg
├── tests.py
├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── python-package.yml
├── setup.py
├── LICENSE
├── .gitignore
└── README.md
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lie-transformer-pytorch/HEAD/diagram.png
--------------------------------------------------------------------------------
/diagram-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lie-transformer-pytorch/HEAD/diagram-2.png
--------------------------------------------------------------------------------
/lie_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from lie_transformer_pytorch.lie_transformer_pytorch import LieTransformer
2 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --verbose
6 | python_files = tests.py
7 |
--------------------------------------------------------------------------------
/tests.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lie_transformer_pytorch import LieTransformer
3 |
4 | def test_transformer():
5 | model = LieTransformer(
6 | dim = 512,
7 | depth = 1
8 | )
9 |
10 | feats = torch.randn(1, 64, 512)
11 | coors = torch.randn(1, 64, 3)
12 | mask = torch.ones(1, 64).bool()
13 |
14 | out = model(feats, coors, mask = mask)
15 |     assert out.shape == (1, 256, 512), 'output shape must be (1, seq * liftsamples, dim)'
16 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.7, 3.8, 3.9]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 | - name: Test with pytest
31 | run: |
32 | python setup.py test
33 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'lie-transformer-pytorch',
5 | packages = find_packages(),
6 | version = '0.0.17',
7 | license='MIT',
8 | description = 'Lie Transformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/lie-transformer-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'attention mechanism',
15 | 'transformers',
16 | 'equivariance',
17 | 'lifting',
18 | 'lie groups'
19 | ],
20 | install_requires=[
21 | 'torch>=1.6',
22 | 'einops>=0.3'
23 | ],
24 | setup_requires=[
25 | 'pytest-runner',
26 | ],
27 | tests_require=[
28 | 'pytest'
29 | ],
30 | classifiers=[
31 | 'Development Status :: 4 - Beta',
32 | 'Intended Audience :: Developers',
33 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
34 | 'License :: OSI Approved :: MIT License',
35 | 'Programming Language :: Python :: 3.6',
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./diagram.png"></img>
2 | 
3 | <img src="./diagram-2.png"></img>
4 | 
5 | ## Lie Transformer - Pytorch
6 |
7 | Implementation of the Lie Transformer, equivariant self-attention, in Pytorch. Only the SE3 version will be present in this repository, as it may be needed for Alphafold2 replication.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install lie-transformer-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from lie_transformer_pytorch import LieTransformer
20 |
21 | model = LieTransformer(
22 | dim = 512,
23 | depth = 2,
24 | heads = 8,
25 | dim_head = 64,
26 | liftsamples = 4
27 | )
28 |
29 | coors = torch.randn(1, 64, 3)
30 | features = torch.randn(1, 64, 512)
31 | mask = torch.ones(1, 64).bool()
32 |
33 | out = model(features, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
34 | ```
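
The forward pass can also return a single pooled representation instead of per-point features. Passing `return_pooled = True` masks out padding and mean-pools over all lifted locations (see `GlobalPool` in the source), which is convenient for molecule- or set-level predictions. A minimal sketch, reusing the model and inputs from above:

```python
pooled = model(features, coors, mask = mask, return_pooled = True) # (1, 512)
```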
35 |
36 | To let the Lie Transformer take care of embedding the features, just specify the number of unique tokens (node types).
37 |
38 | ```python
39 | import torch
40 | from lie_transformer_pytorch import LieTransformer
41 |
42 | model = LieTransformer(
43 | num_tokens = 28, # say 28 different types of atoms
44 | dim = 512,
45 | depth = 2,
46 | heads = 8,
47 | dim_head = 64,
48 | liftsamples = 4
49 | )
50 |
51 | atoms = torch.randint(0, 28, (1, 64))
52 | coors = torch.randn(1, 64, 3)
53 | mask = torch.ones(1, 64).bool()
54 |
55 | out = model(atoms, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
56 | ```
57 |
58 | Although it was not in the paper, I decided to allow for passing in edge information as well (bond types). The edge information will be embedded to the dimension specified, concatenated with the pairwise lifted locations, and passed through an MLP before being summed with the attention matrix.
59 |
60 | Simply set two more keyword arguments on initialization of the transformer, and then pass in the bond types as an integer tensor of shape `b x seq x seq`.
61 |
62 | ```python
63 | import torch
64 | from lie_transformer_pytorch import LieTransformer
65 |
66 | model = LieTransformer(
67 | num_tokens = 28, # say 28 different types of atoms
68 | num_edge_types = 4, # number of different edge types
69 | edge_dim = 16, # dimension of edges
70 | dim = 512,
71 | depth = 2,
72 | heads = 8,
73 | dim_head = 64,
74 | liftsamples = 4
75 | )
76 |
77 | atoms = torch.randint(0, 28, (1, 64))
78 | bonds = torch.randint(0, 4, (1, 64, 64))
79 | coors = torch.randn(1, 64, 3)
80 | mask = torch.ones(1, 64).bool()
81 |
82 | out = model(atoms, coors, edges = bonds, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
83 | ```
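
For deeper networks, the transformer can be made reversible, so that activations are recomputed during the backward pass rather than stored. Note the source asserts that reversibility cannot be combined with downsampling (`ds_frac < 1`). A minimal sketch:

```python
import torch
from lie_transformer_pytorch import LieTransformer

model = LieTransformer(
    dim = 512,
    depth = 6,
    reversible = True # trades compute for memory
)

feats = torch.randn(1, 64, 512)
coors = torch.randn(1, 64, 3)
mask = torch.ones(1, 64).bool()

out = model(feats, coors, mask = mask)
out.sum().backward()
```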
84 |
85 | ## Credit
86 |
87 | This repository is largely adapted from LieConv, cited below.
88 |
89 | ## Citations
90 |
91 | ```bibtex
92 | @misc{hutchinson2020lietransformer,
93 | title = {LieTransformer: Equivariant self-attention for Lie Groups},
94 | author = {Michael Hutchinson and Charline Le Lan and Sheheryar Zaidi and Emilien Dupont and Yee Whye Teh and Hyunjik Kim},
95 | year = {2020},
96 | eprint = {2012.10885},
97 | archivePrefix = {arXiv},
98 | primaryClass = {cs.LG}
99 | }
100 | ```
101 |
102 | ```bibtex
103 | @misc{finzi2020generalizing,
104 | title = {Generalizing Convolutional Neural Networks for Equivariance to Lie Groups on Arbitrary Continuous Data},
105 | author = {Marc Finzi and Samuel Stanton and Pavel Izmailov and Andrew Gordon Wilson},
106 | year = {2020},
107 | eprint = {2002.12880},
108 | archivePrefix = {arXiv},
109 | primaryClass = {stat.ML}
110 | }
111 | ```
112 |
--------------------------------------------------------------------------------
/lie_transformer_pytorch/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd.function import Function
4 | from torch.utils.checkpoint import get_device_states, set_device_states
5 |
6 | # helpers
7 |
8 | def sum_tuple(x, y, dim = 1):
9 | x = list(x)
10 | x[dim] += y[dim]
11 | return tuple(x)
12 |
13 | def subtract_tuple(x, y, dim = 1):
14 | x = list(x)
15 | x[dim] -= y[dim]
16 | return tuple(x)
17 |
18 | def set_tuple(x, dim, value):
19 |     x = list(x)
20 | x[dim] = value
21 | return tuple(x)
22 |
23 | def map_tuple(fn, x, dim = 1):
24 | x = list(x)
25 | x[dim] = fn(x[dim])
26 | return tuple(x)
27 |
28 | def chunk_tuple(fn, x, dim = 1):
29 | x = list(x)
30 | value = x[dim]
31 | chunks = fn(value)
32 |     return tuple(map(lambda t: set_tuple(x, dim, t), chunks))
33 |
34 | def cat_tuple(x, y, dim = 1, cat_dim = -1):
35 | x = list(x)
36 | y = list(y)
37 | x[dim] = torch.cat((x[dim], y[dim]), dim = cat_dim)
38 | return tuple(x)
39 |
40 | def del_tuple(x):
41 | for el in x:
42 | if el is not None:
43 | del el
44 |
45 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
46 | class Deterministic(nn.Module):
47 | def __init__(self, net):
48 | super().__init__()
49 | self.net = net
50 | self.cpu_state = None
51 | self.cuda_in_fwd = None
52 | self.gpu_devices = None
53 | self.gpu_states = None
54 |
55 | def record_rng(self, *args):
56 | self.cpu_state = torch.get_rng_state()
57 | if torch.cuda._initialized:
58 | self.cuda_in_fwd = True
59 | self.gpu_devices, self.gpu_states = get_device_states(*args)
60 |
61 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
62 | if record_rng:
63 | self.record_rng(*args)
64 |
65 | if not set_rng:
66 | return self.net(*args, **kwargs)
67 |
68 | rng_devices = []
69 | if self.cuda_in_fwd:
70 | rng_devices = self.gpu_devices
71 |
72 | with torch.random.fork_rng(devices=rng_devices, enabled=True):
73 | torch.set_rng_state(self.cpu_state)
74 | if self.cuda_in_fwd:
75 | set_device_states(self.gpu_devices, self.gpu_states)
76 | return self.net(*args, **kwargs)
77 |
78 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
79 | # once multi-GPU is confirmed working, refactor and send PR back to source
80 | class ReversibleBlock(nn.Module):
81 | def __init__(self, f, g):
82 | super().__init__()
83 | self.f = Deterministic(f)
84 | self.g = Deterministic(g)
85 |
86 | def forward(self, x, f_args = {}, g_args = {}):
87 | training = self.training
88 | x1, x2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), x)
89 | y1, y2 = None, None
90 |
91 | with torch.no_grad():
92 | y1 = sum_tuple(self.f(x2, record_rng = training, **f_args), x1)
93 | y2 = sum_tuple(self.g(y1, record_rng = training, **g_args), x2)
94 |
95 | return cat_tuple(y1, y2, cat_dim = 2)
96 |
97 | def backward_pass(self, y, dy, f_args = {}, g_args = {}):
98 | y1, y2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), y)
99 | del_tuple(y)
100 |
101 | dy1, dy2 = torch.chunk(dy, 2, dim=2)
102 | del dy
103 |
104 | with torch.enable_grad():
105 | y1[1].requires_grad = True
106 | gy1 = self.g(y1, set_rng=True, **g_args)
107 | torch.autograd.backward(gy1[1], dy2)
108 |
109 | with torch.no_grad():
110 | x2 = subtract_tuple(y2, gy1)
111 | del_tuple(y2)
112 | del gy1
113 |
114 | dx1 = dy1 + y1[1].grad
115 | del dy1
116 | y1[1].grad = None
117 |
118 | with torch.enable_grad():
119 | x2[1].requires_grad = True
120 | fx2 = self.f(x2, set_rng = True, **f_args)
121 | torch.autograd.backward(fx2[1], dx1)
122 |
123 | with torch.no_grad():
124 | x1 = subtract_tuple(y1, fx2)
125 | del fx2
126 | del_tuple(y1)
127 |
128 | dx2 = dy2 + x2[1].grad
129 | del dy2
130 | x2[1].grad = None
131 |
132 | x2 = map_tuple(lambda t: t.detach(), x2)
133 | x = cat_tuple(x1, x2, cat_dim = -1)
134 | dx = torch.cat((dx1, dx2), dim=2)
135 |
136 | return x, dx
137 |
138 | class _ReversibleFunction(Function):
139 | @staticmethod
140 | def forward(ctx, x, blocks, kwargs):
141 | ctx.kwargs = kwargs
142 | x = (kwargs.pop('coords'), x, kwargs.pop('mask'), kwargs.pop('edges'))
143 |
144 | for block in blocks:
145 | x = block(x, **kwargs)
146 |
147 | ctx.y = map_tuple(lambda t: t.detach(), x, dim = 1)
148 | ctx.blocks = blocks
149 | return x[1]
150 |
151 | @staticmethod
152 | def backward(ctx, dy):
153 | y = ctx.y
154 | kwargs = ctx.kwargs
155 |
156 | for block in ctx.blocks[::-1]:
157 | y, dy = block.backward_pass(y, dy, **kwargs)
158 | return dy, None, None
159 |
160 | class SequentialSequence(nn.Module):
161 | def __init__(self, blocks):
162 | super().__init__()
163 | self.blocks = blocks
164 |
165 | def forward(self, x):
166 | for (f, g) in self.blocks:
167 | x = sum_tuple(f(x), x, dim = 1)
168 | x = sum_tuple(g(x), x, dim = 1)
169 | return x
170 |
171 | class ReversibleSequence(nn.Module):
172 | def __init__(self, blocks):
173 | super().__init__()
174 | self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])
175 |
176 | def forward(self, x, **kwargs):
177 | x = map_tuple(lambda t: torch.cat((t, t), dim = -1), x)
178 |
179 | blocks = self.blocks
180 |
181 | coords, values, mask, edges = x
182 | kwargs = {'coords': coords, 'mask': mask, 'edges': edges, **kwargs}
183 | x = _ReversibleFunction.apply(values, blocks, kwargs)
184 |
185 | x = (coords, x, mask, edges)
186 | return map_tuple(lambda t: sum(t.chunk(2, dim = -1)) * 0.5, x)
187 |
--------------------------------------------------------------------------------
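
A note on the `Deterministic` wrapper above: it is what makes the reversible blocks safe to use with stochastic layers such as dropout, by recording the RNG state on the forward pass and replaying it when activations are recomputed. A minimal sketch of that behavior, assuming CPU-only execution:

```python
import torch
from lie_transformer_pytorch.reversible import Deterministic

net = Deterministic(torch.nn.Dropout(p = 0.5))

x = torch.ones(1, 8)
y1 = net(x, record_rng = True) # saves the RNG state, then runs the net
y2 = net(x, set_rng = True)    # restores the saved state: identical dropout mask
assert torch.equal(y1, y2)
```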
/lie_transformer_pytorch/se3.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | import torch
3 | from functools import wraps
4 | from torch import acos, atan2, cos, sin
5 | from einops import rearrange, repeat
6 |
7 | # constants
8 |
9 | THRES = 7e-2
10 |
11 | # helper functions
12 |
13 | def exists(val):
14 | return val is not None
15 |
16 | def to(t):
17 | return {'device': t.device, 'dtype': t.dtype}
18 |
19 | def taylor(thres):
20 | def outer(fn):
21 | @wraps(fn)
22 | def inner(x):
23 |             usetaylor = x.abs() < thres
24 | taylor_expanded, full = fn(x, x * x)
25 | return torch.where(usetaylor, taylor_expanded, full)
26 | return inner
27 | return outer
28 |
29 | # Helper functions for analytic exponential maps. Uses taylor expansions near x=0
30 | # See http://ethaneade.com/lie_groups.pdf for derivations.
31 |
32 | @taylor(THRES)
33 | def sinc(x, x2):
34 | """ sin(x)/x """
35 | texpand = 1-x2/6*(1-x2/20*(1-x2/42))
36 | full = sin(x) / x
37 | return texpand, full
38 |
39 | @taylor(THRES)
40 | def sincc(x, x2):
41 | """ (1-sinc(x))/x^2"""
42 | texpand = 1/6*(1-x2/20*(1-x2/42*(1-x2/72)))
43 | full = (x-sin(x)) / x**3
44 | return texpand, full
45 |
46 | @taylor(THRES)
47 | def cosc(x, x2):
48 | """ (1-cos(x))/x^2"""
49 | texpand = 1/2*(1-x2/12*(1-x2/30*(1-x2/56)))
50 | full = (1-cos(x)) / x2
51 | return texpand, full
52 |
53 | @taylor(THRES)
54 | def coscc(x, x2):
55 | #assert not torch.any(torch.isinf(x2)), f"infs in x2 log"
56 | texpand = 1/12*(1+x2/60*(1+x2/42*(1+x2/40)))
57 | costerm = (2*(1-cos(x))).clamp(min=1e-6)
58 | full = (1-x*sin(x)/costerm) / x2 #Nans can come up here when cos = 1
59 | return texpand, full
60 |
61 | @taylor(THRES)
62 | def sinc_inv(x, _):
63 | texpand = 1+(1/6)*x**2 +(7/360)*x**4
64 | full = x / sin(x)
65 |         assert not torch.any(torch.isinf(texpand) | torch.isnan(texpand)), 'sinc_inv taylor expansion contains inf or nan'
66 | return texpand, full
67 |
68 | ## Lie Groups acting on R3
69 |
70 | # Hodge star on R3
71 | def cross_matrix(k):
72 | """Application of hodge star on R3, mapping Λ^1 R3 -> Λ^2 R3"""
73 | K = torch.zeros(*k.shape[:-1], 3, 3, **to(k))
74 | K[...,0,1] = -k[...,2]
75 | K[...,0,2] = k[...,1]
76 | K[...,1,0] = k[...,2]
77 | K[...,1,2] = -k[...,0]
78 | K[...,2,0] = -k[...,1]
79 | K[...,2,1] = k[...,0]
80 | return K
81 |
82 | def uncross_matrix(K):
83 | """Application of hodge star on R3, mapping Λ^2 R3 -> Λ^1 R3"""
84 | k = torch.zeros(*K.shape[:-1], **to(K))
85 | k[...,0] = (K[...,2,1] - K[...,1,2])/2
86 | k[...,1] = (K[...,0,2] - K[...,2,0])/2
87 | k[...,2] = (K[...,1,0] - K[...,0,1])/2
88 | return k
89 |
90 | class SO3:
91 | lie_dim = 3
92 | rep_dim = 3
93 | q_dim = 1
94 |
95 | def __init__(self, alpha = .2):
96 | super().__init__()
97 | self.alpha = alpha
98 |
99 | def exp(self,w):
100 | """ Computes (matrix) exponential Lie algebra elements (in a given basis).
101 | ie out = exp(\sum_i a_i A_i) where A_i are the exponential generators of G.
102 | Input: [a (*,lie_dim)] where * is arbitrarily shaped
103 | Output: [exp(a) (*,rep_dim,rep_dim)] returns the matrix for each."""
104 |
105 | """ Rodriguez's formula, assuming shape (*,3)
106 | where components 1,2,3 are the generators for xrot,yrot,zrot"""
107 | theta = w.norm(dim=-1)[..., None, None]
108 | K = cross_matrix(w)
109 | I = torch.eye(3, **to(K))
110 | Rs = I + K * sinc(theta) + (K @ K) * cosc(theta)
111 | return Rs
112 |
113 | def log(self,R):
114 | """ Computes components in terms of generators rx,ry,rz. Shape (*,3,3)"""
115 |
116 | """ Computes (matrix) logarithm for collection of matrices and converts to Lie algebra basis.
117 | Input [u (*,rep_dim,rep_dim)]
118 | Output [coeffs of log(u) in basis (*,d)] """
119 | trR = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
120 | costheta = ((trR-1) / 2).clamp(max=1, min=-1).unsqueeze(-1)
121 | theta = acos(costheta)
122 | logR = uncross_matrix(R) * sinc_inv(theta)
123 | return logR
124 |
125 | def inv(self,g):
126 | """ We can compute the inverse of elements g (*,rep_dim,rep_dim) as exp(-log(g))"""
127 | return self.exp(-self.log(g))
128 |
129 | def elems2pairs(self,a):
130 | """ computes log(e^-b e^a) for all a b pairs along n dimension of input.
131 | inputs: [a (bs,n,d)] outputs: [pairs_ab (bs,n,n,d)] """
132 | vinv = self.exp(-a.unsqueeze(-3))
133 | u = self.exp(a.unsqueeze(-2))
134 | return self.log(vinv@u) # ((bs,1,n,d) -> (bs,1,n,r,r))@((bs,n,1,d) -> (bs,n,1,r,r))
135 |
136 | def lift(self, x, nsamples, **kwargs):
137 | """assumes p has shape (*,n,2), vals has shape (*,n,c), mask has shape (*,n)
138 | returns (a,v) with shapes [(*,n*nsamples,lie_dim),(*,n*nsamples,c)"""
139 | p, v, m, e = x
140 | expanded_a = self.lifted_elems(p,nsamples,**kwargs) # (bs,n*ns,d), (bs,n*ns,qd)
141 | nsamples = expanded_a.shape[-2]//m.shape[-1]
142 | # expand v and mask like q
143 | expanded_v = repeat(v, 'b n c -> b (n m) c', m = nsamples) # (bs,n,c) -> (bs,n,1,c) -> (bs,n,ns,c) -> (bs,n*ns,c)
144 | expanded_mask = repeat(m, 'b n -> b (n m)', m = nsamples) # (bs,n) -> (bs,n,ns) -> (bs,n*ns)
145 | expanded_e = repeat(e, 'b n1 n2 c -> b (n1 m1) (n2 m2) c', m1 = nsamples, m2 = nsamples) if exists(e) else None
146 |
147 | # convert from elems to pairs
148 | paired_a = self.elems2pairs(expanded_a) #(bs,n*ns,d) -> (bs,n*ns,n*ns,d)
149 | embedded_locations = paired_a
150 | return (embedded_locations,expanded_v,expanded_mask, expanded_e)
151 |
152 | class SE3(SO3):
153 | lie_dim = 6
154 | rep_dim = 4
155 | q_dim = 0
156 |
157 | def __init__(self, alpha=.2, per_point=True):
158 | super().__init__()
159 | self.alpha = alpha
160 | self.per_point = per_point
161 |
162 | def exp(self,w):
163 | dd_kwargs = to(w)
164 | theta = w[...,:3].norm(dim=-1)[...,None,None]
165 | K = cross_matrix(w[...,:3])
166 | R = super().exp(w[...,:3])
167 | I = torch.eye(3, **dd_kwargs)
168 | V = I + cosc(theta)*K + sincc(theta)*(K@K)
169 | U = torch.zeros(*w.shape[:-1],4,4, **dd_kwargs)
170 | U[...,:3,:3] = R
171 | U[...,:3,3] = (V@w[...,3:].unsqueeze(-1)).squeeze(-1)
172 | U[...,3,3] = 1
173 | return U
174 |
175 | def log(self,U):
176 | w = super().log(U[..., :3, :3])
177 | I = torch.eye(3, **to(w))
178 | K = cross_matrix(w[..., :3])
179 | theta = w.norm(dim=-1)[..., None, None]#%(2*pi)
180 | #theta[theta>pi] -= 2*pi
181 | cosccc = coscc(theta)
182 | Vinv = I - K/2 + cosccc*(K@K)
183 | u = (Vinv @ U[..., :3, 3].unsqueeze(-1)).squeeze(-1)
184 | #assert not torch.any(torch.isnan(u)), f"nans in u log {torch.isnan(u).sum()}, {torch.where(torch.isnan(u))}"
185 | return torch.cat([w, u], dim=-1)
186 |
187 | def lifted_elems(self,pt,nsamples):
188 | """ pt (bs,n,D) mask (bs,n), per_point specifies whether to
189 | use a different group element per atom in the molecule"""
190 | #return farthest_lift(self,pt,mask,nsamples,alpha)
191 | # same lifts for each point right now
192 | bs,n = pt.shape[:2]
193 | dd_kwargs = to(pt)
194 |
195 | q = torch.randn(bs, (n if self.per_point else 1), nsamples, 4, **dd_kwargs)
196 | q /= q.norm(dim=-1).unsqueeze(-1)
197 |
198 | theta_2 = atan2(q[..., 1:].norm(dim=-1),q[..., 0])[..., None]
199 | so3_elem = 2 * sinc_inv(theta_2) * q[...,1:] # (sin(x/2)u -> xu) for x angle and u direction
200 | se3_elem = torch.cat([so3_elem, torch.zeros_like(so3_elem)], dim=-1)
201 | R = self.exp(se3_elem)
202 |
203 | T = torch.zeros(bs, n, nsamples, 4, 4, **dd_kwargs) # (bs,n,nsamples,4,4)
204 | T[..., :, :] = torch.eye(4, **dd_kwargs)
205 | T[..., :3, 3] = pt[..., None, :] # (bs,n,1,3)
206 |
207 | a = self.log(T @ R) # bs, n, nsamples, 6
208 | return a.reshape(bs, n * nsamples, 6)
209 |
210 | def distance(self,abq_pairs):
211 | dist_rot = abq_pairs[...,:3].norm(dim=-1)
212 | dist_trans = abq_pairs[...,3:].norm(dim=-1)
213 | return dist_rot * self.alpha + (1-self.alpha) * dist_trans
214 |
--------------------------------------------------------------------------------
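
To make the lifting step concrete: `SE3.lift` turns each of the `n` input coordinates into `nsamples` group elements, then forms all pairwise relative elements via the log map, which is where the `(seq len * liftsamples)` output length in the README comes from. A minimal sketch of the shapes, assuming no edge features:

```python
import torch
from lie_transformer_pytorch.se3 import SE3

group = SE3()

coors = torch.randn(2, 8, 3)
feats = torch.randn(2, 8, 16)
mask = torch.ones(2, 8).bool()

pairs, vals, mask, edges = group.lift((coors, feats, mask, None), nsamples = 4)
# pairs: (2, 32, 32, 6) - log-map coordinates of relative SE(3) elements
# vals:  (2, 32, 16)    - features repeated over the 4 lifted samples per point
# mask:  (2, 32)        - mask expanded the same way
```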
/lie_transformer_pytorch/lie_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, einsum
7 |
8 | from lie_transformer_pytorch.se3 import SE3
9 | from einops import rearrange, repeat
10 |
11 | from lie_transformer_pytorch.reversible import SequentialSequence, ReversibleSequence
12 |
13 | # helpers
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def cast_tuple(val, depth):
19 | return val if isinstance(val, tuple) else ((val,) * depth)
20 |
21 | def default(val, d):
22 | return val if exists(val) else d
23 |
24 | def batched_index_select(values, indices, dim = 1):
25 | value_dims = values.shape[(dim + 1):]
26 | values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
27 | indices = indices[(..., *((None,) * len(value_dims)))]
28 | indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
29 | value_expand_len = len(indices_shape) - (dim + 1)
30 | values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
31 |
32 | value_expand_shape = [-1] * len(values.shape)
33 | expand_slice = slice(dim, (dim + value_expand_len))
34 | value_expand_shape[expand_slice] = indices.shape[expand_slice]
35 | values = values.expand(*value_expand_shape)
36 |
37 | dim += value_expand_len
38 | return values.gather(dim, indices)
39 |
40 | # helper classes
41 |
42 | class Pass(nn.Module):
43 | def __init__(self, fn, dim = 1):
44 | super().__init__()
45 | self.fn = fn
46 | self.dim = dim
47 |
48 | def forward(self,x):
49 | dim = self.dim
50 | xs = list(x)
51 | xs[dim] = self.fn(xs[dim])
52 | return xs
53 |
54 | class Lambda(nn.Module):
55 | def __init__(self, fn):
56 | super().__init__()
57 | self.fn = fn
58 |
59 | def forward(self, x):
60 | return self.fn(x)
61 |
62 |
63 | class GlobalPool(nn.Module):
64 | """computes values reduced over all spatial locations (& group elements) in the mask"""
65 | def __init__(self, mean = False):
66 | super().__init__()
67 | self.mean = mean
68 |
69 | def forward(self, x):
70 |         coords, vals, mask, edges = x
71 | 
72 |         if not exists(mask):
73 |             return vals.mean(dim = 1)
74 | 
75 |         masked_vals = vals.masked_fill(~mask[..., None], 0.)
76 | summed = masked_vals.sum(dim = 1)
77 |
78 | if not self.mean:
79 | return summed
80 |
81 | count = mask.sum(-1).unsqueeze(-1)
82 | return summed / count
83 |
84 | # subsampling code
85 |
86 | def FPSindices(dists, frac, mask):
87 | """ inputs: pairwise distances DISTS (bs,n,n), downsample_frac (float), valid atom mask (bs,n)
88 | outputs: chosen_indices (bs,m) """
89 | m = int(round(frac * dists.shape[1]))
90 | bs, n, device = *dists.shape[:2], dists.device
91 | dd_kwargs = {'device': device, 'dtype': torch.long}
92 | B = torch.arange(bs, **dd_kwargs)
93 |
94 | chosen_indices = torch.zeros(bs, m, **dd_kwargs)
95 | distances = torch.ones(bs, n, device=device) * 1e8
96 | a = torch.randint(0, n, (bs,), **dd_kwargs) # choose random start
97 | idx = a % mask.sum(-1) + torch.cat([torch.zeros(1, **dd_kwargs), torch.cumsum(mask.sum(-1), dim=0)[:-1]], dim=0)
98 | farthest = torch.where(mask)[1][idx]
99 |
100 | for i in range(m):
101 | chosen_indices[:, i] = farthest # add point that is farthest to chosen
102 | dist = dists[B, farthest].masked_fill(~mask, -100) # (bs,n) compute distance from new point to all others
103 | closer = dist < distances # if dist from new point is smaller than chosen points so far
104 | distances[closer] = dist[closer] # update the chosen set's distance to all other points
105 | farthest = torch.max(distances, -1)[1] # select the point that is farthest from the set
106 |
107 | return chosen_indices
108 |
109 |
110 | class FPSsubsample(nn.Module):
111 | def __init__(self, ds_frac, cache = False, group = None):
112 | super().__init__()
113 | self.ds_frac = ds_frac
114 | self.cache = cache
115 | self.cached_indices = None
116 | self.group = default(group, SE3())
117 |
118 | def get_query_indices(self, abq_pairs, mask):
119 | if self.cache and exists(self.cached_indices):
120 | return self.cached_indices
121 |
122 | dist = self.group.distance if self.group else lambda ab: ab.norm(dim=-1)
123 | value = FPSindices(dist(abq_pairs), self.ds_frac, mask).detach()
124 |
125 | if self.cache:
126 | self.cached_indices = value
127 |
128 | return value
129 |
130 | def forward(self, inp, withquery=False):
131 | abq_pairs, vals, mask, edges = inp
132 | device = vals.device
133 |
134 | if self.ds_frac != 1:
135 | query_idx = self.get_query_indices(abq_pairs, mask)
136 |
137 | B = torch.arange(query_idx.shape[0], device = device).long()[:,None]
138 | subsampled_abq_pairs = abq_pairs[B, query_idx][B, :, query_idx]
139 | subsampled_values = batched_index_select(vals, query_idx, dim = 1)
140 | subsampled_mask = batched_index_select(mask, query_idx, dim = 1)
141 | subsampled_edges = edges[B, query_idx][B, :, query_idx] if exists(edges) else None
142 | else:
143 | subsampled_abq_pairs = abq_pairs
144 | subsampled_values = vals
145 | subsampled_mask = mask
146 | subsampled_edges = edges
147 | query_idx = None
148 |
149 | ret = (
150 | subsampled_abq_pairs,
151 | subsampled_values,
152 | subsampled_mask,
153 | subsampled_edges
154 | )
155 |
156 | if withquery:
157 | ret = (*ret, query_idx)
158 |
159 | return ret
160 |
161 | # lie attention
162 |
163 | class LieSelfAttention(nn.Module):
164 | def __init__(
165 | self,
166 | dim,
167 | edge_dim = None,
168 | group = None,
169 | mc_samples = 32,
170 | ds_frac = 1,
171 | fill = 1 / 3,
172 | dim_head = 64,
173 | heads = 8,
174 | cache = False
175 | ):
176 | super().__init__()
177 | self.dim = dim
178 |
179 | self.mc_samples = mc_samples # number of samples to use to estimate convolution
180 | self.group = default(group, SE3()) # Equivariance group for LieConv
181 | self.register_buffer('r',torch.tensor(2.)) # Internal variable for local_neighborhood radius, set by fill
182 | self.fill_frac = min(fill, 1.) # Average Fraction of the input which enters into local_neighborhood, determines r
183 |
184 | self.subsample = FPSsubsample(ds_frac, cache = cache, group = self.group)
185 | self.coeff = .5 # Internal coefficient used for updating r
186 |
187 | self.fill_frac_ema = fill # Keeps track of average fill frac, used for logging only
188 |
189 | # attention related parameters
190 |
191 | inner_dim = dim_head * heads
192 | self.heads = heads
193 |
194 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
195 | self.to_k = nn.Linear(dim, inner_dim, bias = False)
196 | self.to_v = nn.Linear(dim, inner_dim, bias = False)
197 | self.to_out = nn.Linear(inner_dim, dim)
198 |
199 | edge_dim = default(edge_dim, 0)
200 | edge_dim_in = self.group.lie_dim + edge_dim
201 |
202 | self.loc_attn_mlp = nn.Sequential(
203 | nn.Linear(edge_dim_in, edge_dim_in * 4),
204 | nn.ReLU(),
205 | nn.Linear(edge_dim_in * 4, 1),
206 | )
207 |
208 | def extract_neighborhood(self, inp, query_indices):
209 | """ inputs: [pairs_abq (bs,n,n,d), inp_vals (bs,n,c), mask (bs,n), query_indices (bs,m)]
210 | outputs: [neighbor_abq (bs,m,mc_samples,d), neighbor_vals (bs,m,mc_samples,c)]"""
211 |
212 | # Subsample pairs_ab, inp_vals, mask to the query_indices
213 | pairs_abq, inp_vals, mask, edges = inp
214 | device = inp_vals.device
215 |
216 | if exists(query_indices):
217 | abq_at_query = batched_index_select(pairs_abq, query_indices, dim = 1)
218 | mask_at_query = batched_index_select(mask, query_indices, dim = 1)
219 | edges_at_query = batched_index_select(edges, query_indices, dim = 1) if exists(edges) else None
220 | else:
221 | abq_at_query = pairs_abq
222 | mask_at_query = mask
223 | edges_at_query = edges
224 |
225 | mask_at_query = mask_at_query[..., None]
226 |
227 | vals_at_query = inp_vals
228 | dists = self.group.distance(abq_at_query) #(bs,m,n,d) -> (bs,m,n)
229 | mask_value = torch.finfo(dists.dtype).max
230 |         dists = dists.masked_fill(~mask[:,None,:], mask_value) # invalid points get a huge distance, so they are never selected
231 |
232 | k = min(self.mc_samples, inp_vals.shape[1])
233 |
234 | # NBHD: Sampled Distance Ball
235 | bs, m, n = dists.shape
236 | within_ball = (dists < self.r) & mask[:,None,:] & mask_at_query # (bs,m,n)
237 | noise = torch.zeros((bs, m, n), device = device).uniform_(0, 1)
238 | valid_within_ball, nbhd_idx = torch.topk(within_ball + noise, k, dim=-1, sorted=False)
239 | valid_within_ball = (valid_within_ball > 1)
240 |
241 | # Retrieve abq_pairs, values, and mask at the nbhd locations
242 |
243 | nbhd_abq = batched_index_select(abq_at_query, nbhd_idx, dim = 2)
244 | nbhd_vals = batched_index_select(vals_at_query, nbhd_idx, dim = 1)
245 | nbhd_mask = batched_index_select(mask, nbhd_idx, dim = 1)
246 | nbhd_edges = batched_index_select(edges_at_query, nbhd_idx, dim = 2) if exists(edges) else None
247 |
248 | if self.training: # update ball radius to match fraction fill_frac inside
249 | navg = (within_ball.float()).sum(-1).sum() / mask_at_query.sum()
250 | avg_fill = (navg / mask.sum(-1).float().mean()).cpu().item()
251 | self.r += self.coeff * (self.fill_frac - avg_fill)
252 | self.fill_frac_ema += .1 * (avg_fill-self.fill_frac_ema)
253 |
254 | nbhd_mask &= valid_within_ball.bool()
255 |
256 | return nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_idx
257 |
258 | def forward(self, inp):
259 | """inputs: [pairs_abq (bs,n,n,d)], [inp_vals (bs,n,ci)]), [query_indices (bs,m)]
260 | outputs [subsampled_abq (bs,m,m,d)], [convolved_vals (bs,m,co)]"""
261 | sub_abq, sub_vals, sub_mask, sub_edges, query_indices = self.subsample(inp, withquery = True)
262 | nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_indices = self.extract_neighborhood(inp, query_indices)
263 |
264 | h, b, n, d, device = self.heads, *sub_vals.shape, sub_vals.device
265 |
266 | q, k, v = self.to_q(sub_vals), self.to_k(nbhd_vals), self.to_v(nbhd_vals)
267 |
268 | q = rearrange(q, 'b n (h d) -> b h n d', h = h)
269 | k, v = map(lambda t: rearrange(t, 'b n m (h d) -> b h n m d', h = h), (k, v))
270 |
271 | sim = einsum('b h i d, b h i j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
272 |
273 | edges = nbhd_abq
274 | if exists(nbhd_edges):
275 | edges = torch.cat((nbhd_abq, nbhd_edges), dim = -1)
276 |
277 | loc_attn = self.loc_attn_mlp(edges)
278 | loc_attn = rearrange(loc_attn, 'b i j () -> b () i j')
279 | sim = sim + loc_attn
280 |
281 | mask_value = -torch.finfo(sim.dtype).max
282 |
283 | sim.masked_fill_(~rearrange(nbhd_mask, 'b n m -> b () n m'), mask_value)
284 |
285 | attn = sim.softmax(dim = -1)
286 | out = einsum('b h i j, b h i j d -> b h i d', attn, v)
287 | out = rearrange(out, 'b h n d -> b n (h d)', h = h)
288 | combined = self.to_out(out)
289 |
290 | return sub_abq, combined, sub_mask, sub_edges
291 |
292 | class LieSelfAttentionWrapper(nn.Module):
293 | def __init__(self, dim, attn):
294 | super().__init__()
295 | self.dim = dim
296 | self.attn = attn
297 |
298 | self.net = nn.Sequential(
299 | Pass(nn.LayerNorm(dim)),
300 | self.attn
301 | )
302 |
303 | def forward(self, inp):
304 | sub_coords, sub_values, mask, edges = self.attn.subsample(inp)
305 | new_coords, new_values, mask, edges = self.net(inp)
306 | new_values[..., :self.dim] += sub_values
307 | return new_coords, new_values, mask, edges
308 |
309 | class FeedForward(nn.Module):
310 | def __init__(self, dim, mult = 4):
311 | super().__init__()
312 | self.dim = dim
313 |
314 | self.net = nn.Sequential(
315 | Pass(nn.LayerNorm(dim)),
316 | Pass(nn.Linear(dim, mult * dim)),
317 | Pass(nn.GELU()),
318 | Pass(nn.Linear(mult * dim, dim)),
319 | )
320 |
321 | def forward(self,inp):
322 | sub_coords, sub_values, mask, edges = inp
323 | new_coords, new_values, mask, edges = self.net(inp)
324 | new_values = new_values + sub_values
325 | return new_coords, new_values, mask, edges
326 |
327 | # transformer class
328 |
329 | class LieTransformer(nn.Module):
330 | """
331 | [Fill] specifies the fraction of the input which is included in local neighborhood.
332 | (can be array to specify a different value for each layer)
333 | [nbhd] number of samples to use for Monte Carlo estimation (p)
334 | [dim] number of input channels: 1 for MNIST, 3 for RGB images, other for non images
335 | [ds_frac] total downsampling to perform throughout the layers of the net. In (0,1)
336 | [num_layers] number of BottleNeck Block layers in the network
337 | [k] channel width for the network. Can be int (same for all) or array to specify individually.
338 | [liftsamples] number of samples to use in lifting. 1 for all groups with trivial stabilizer. Otherwise 2+
339 | [Group] Chosen group to be equivariant to.
340 | """
341 | def __init__(
342 | self,
343 | dim,
344 | num_tokens = None,
345 | num_edge_types = None,
346 | edge_dim = None,
347 | heads = 8,
348 | dim_head = 64,
349 | depth = 2,
350 | ds_frac = 1.,
351 | dim_out = None,
352 | k = 1536,
353 | nbhd = 128,
354 | mean = True,
355 | per_point = True,
356 | liftsamples = 4,
357 | fill = 1 / 4,
358 | cache = False,
359 | reversible = False,
360 | **kwargs
361 | ):
362 | super().__init__()
363 | assert not (ds_frac < 1 and reversible), 'must not downsample if network is reversible'
364 |
365 | dim_out = default(dim_out, dim)
366 | self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
367 | self.edge_emb = nn.Embedding(num_edge_types, edge_dim) if exists(num_edge_types) else None
368 |
369 | group = SE3()
370 | self.group = group
371 | self.liftsamples = liftsamples
372 |
373 | layers_fill = cast_tuple(fill, depth)
374 | layers = nn.ModuleList([])
375 |
376 | for _, layer_fill in zip(range(depth), layers_fill):
377 | layers.append(nn.ModuleList([
378 |                 LieSelfAttentionWrapper(dim, LieSelfAttention(dim, heads = heads, dim_head = dim_head, edge_dim = edge_dim, mc_samples = nbhd, ds_frac = ds_frac, group = group, fill = layer_fill, cache = cache, **kwargs)),
379 | FeedForward(dim)
380 | ]))
381 |
382 | execute_class = ReversibleSequence if reversible else SequentialSequence
383 | self.net = execute_class(layers)
384 |
385 | self.to_logits = nn.Sequential(
386 | Pass(nn.LayerNorm(dim)),
387 | Pass(nn.Linear(dim, dim_out))
388 | )
389 |
390 | self.pool = GlobalPool(mean = mean)
391 |
392 | def forward(self, feats, coors, edges = None, mask = None, return_pooled = False):
393 | b, n, *_ = feats.shape
394 |
395 | if exists(self.token_emb):
396 | feats = self.token_emb(feats)
397 |
398 | if exists(self.edge_emb):
399 | assert exists(edges), 'edges must be passed in on forward'
400 | assert edges.shape[1] == edges.shape[2] and edges.shape[1] == n, f'edges must be of the shape ({b}, {n}, {n})'
401 | edges = self.edge_emb(edges)
402 |
403 | inps = (coors, feats, mask, edges)
404 |
405 | lifted_x = self.group.lift(inps, self.liftsamples)
406 | out = self.net(lifted_x)
407 |
408 | out = self.to_logits(out)
409 |
410 | if not return_pooled:
411 | features = out[1]
412 | return features
413 |
414 | return self.pool(out)
415 |
--------------------------------------------------------------------------------
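
As a closing note on the subsampling path: `FPSindices` implements farthest point sampling over the group distance matrix, and is what `FPSsubsample` invokes whenever `ds_frac < 1`. A minimal sketch of calling it directly, using plain Euclidean distances as a stand-in for the group distance:

```python
import torch
from lie_transformer_pytorch.lie_transformer_pytorch import FPSindices

pts = torch.randn(2, 16, 3)
dists = torch.cdist(pts, pts) # (bs, n, n) pairwise distances
mask = torch.ones(2, 16).bool()

idx = FPSindices(dists, frac = 0.25, mask = mask) # (2, 4) chosen indices
```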