├── docs
├── pywarm-logo.png
├── pywarm-logo-small-dark.gif
├── pywarm-logo-small-light.gif
├── github-icon.svg
├── text.mako
├── tutorial.md
└── example.md
├── warm
├── __init__.py
├── util.py
├── module.py
├── engine.py
└── functional.py
├── tests
├── test_warm.py
├── test_util.py
├── test_module.py
├── test_functional.py
└── test_engine.py
├── CONTRIBUTING.md
├── LICENSE.md
├── pyproject.toml
├── .gitignore
├── examples
├── transformer.py
├── resnet.py
├── mobilenet.py
├── lstm.py
├── efficientnet.py
└── mnist.py
└── README.md
/docs/pywarm-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-season/pywarm/HEAD/docs/pywarm-logo.png
--------------------------------------------------------------------------------
/docs/pywarm-logo-small-dark.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-season/pywarm/HEAD/docs/pywarm-logo-small-dark.gif
--------------------------------------------------------------------------------
/docs/pywarm-logo-small-light.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-season/pywarm/HEAD/docs/pywarm-logo-small-light.gif
--------------------------------------------------------------------------------
/warm/__init__.py:
--------------------------------------------------------------------------------
1 | # 09-10-2019;
2 |
3 | """ `warm.up` is an alias of
4 | [`warm.engine.prepare_model_`](https://blue-season.github.io/pywarm/reference/warm/engine/#prepare_model_). """
5 | from warm.engine import prepare_model_ as up
6 |
--------------------------------------------------------------------------------
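
A one-line usage sketch of the alias (the dummy input shape here is arbitrary):

```Python
import torch.nn as nn
import warm

m = nn.Identity()
warm.up(m, [1, 2, 3])  # same as warm.engine.prepare_model_(m, [1, 2, 3])
```
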
/tests/test_warm.py:
--------------------------------------------------------------------------------
1 | # 09-10-2019;
2 | """
3 | Test cases for the warm module.
4 | """
5 | import torch.nn as nn
6 | from pathlib import Path
7 | import sys
8 | sys.path.append(str(Path(__file__).parent.parent))
9 | import warm
10 |
11 |
12 | def test_warm_up():
13 | m = nn.Identity()
14 | assert not warm.engine.is_ready(m), 'is_ready did not work correctly.'
15 | warm.up(m, [1, 2, 3])
16 | assert warm.engine.is_ready(m), 'warm.up did not work correctly.'
17 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to PyWarm
2 |
3 | PyWarm is developed on [GitHub](https://github.com/blue-season/pywarm).
4 |
5 | Please use GitHub to file bug reports and submit pull requests.
6 |
7 | Please document and test your changes before submitting.
8 |
9 | PyWarm is developed with Python 3.7, but has been tested to work with Python 3.6+.
10 |
11 | # Coding Style
12 |
13 | For the rationale behind the distinct coding style used in PyWarm, please check
14 |
15 | [A Coding Style for Python](https://blue-season.github.io/a-coding-style-for-python/).
16 |
--------------------------------------------------------------------------------
/docs/github-icon.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
1 | # 08-31-2019;
2 | """
3 | Test cases for warm.util.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | from pathlib import Path
9 | import sys
10 | sys.path.append(str(Path(__file__).parent.parent))
11 | from warm import util
12 |
13 |
14 | def test_camel_to_snake():
15 | assert util.camel_to_snake('CamelAndSnake') == 'camel_and_snake'
16 | assert util.camel_to_snake('camelAndSnake') == 'camel_and_snake'
17 | assert util.camel_to_snake('camelANDSnake') == 'camel_and_snake'
18 | assert util.camel_to_snake('CAMELAndSnake') == 'camel_and_snake'
19 | assert util.camel_to_snake('CAMELAndSNAKE') == 'camel_and_snake'
20 | assert util.camel_to_snake('CamelAndSnake_') == 'camel_and_snake_'
21 | assert util.camel_to_snake('_CamelAndSnake') == '__camel_and_snake'
22 |
23 |
24 | def test_summary_str():
25 | from examples.resnet import WarmResNet
26 | m = WarmResNet()
27 | s = util.summary_str(m)
28 | assert len(s) > 0
29 |
30 |
31 | def test_summary():
32 | from examples.resnet import WarmResNet
33 | m = WarmResNet()
34 | util.summary(m)
35 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 blue-season
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = 'PyWarm'
3 | version = '0.4.1'
4 | description = 'A cleaner way to build neural networks for PyTorch.'
5 | license = 'MIT'
6 | authors = ['blue-season ']
7 | readme = 'README.md'
8 | repository = 'https://github.com/blue-season/pywarm'
9 | homepage = 'https://github.com/blue-season/pywarm'
10 | keywords = ['pywarm', 'pytorch', 'neural network', 'deep learning']
11 | packages = [ { include='warm' }, ]
12 |
13 |
14 | [tool.poetry.dependencies]
15 | python = '>=3.6'
16 |
17 |
18 | [tool.poetry.dev-dependencies]
19 | toml = '>=0.9'
20 | pytest = '>=3.0'
21 | torch = '>=1.0'
22 | torchvision = '>=0.4'
23 |
24 |
25 | [tool.portray]
26 | modules = ['warm']
27 |
28 |
29 | [tool.portray.mkdocs]
30 | markdown_extensions = ['pymdownx.superfences']
31 |
32 |
33 | [tool.portray.mkdocs.theme]
34 | logo = 'docs/pywarm-logo-small-light.gif'
35 | favicon = 'docs/pywarm-logo-small-dark.gif'
36 | name = 'material'
37 | palette = {primary='deep orange', accent='pink'}
38 |
39 |
40 | [tool.portray.pdoc3]
41 | config = ['show_source_code=False',
42 | 'show_type_annotations=False',
43 | 'sort_identifiers=True',
44 | 'show_inherited_members=False']
45 | template_dir = 'docs'
46 |
--------------------------------------------------------------------------------
/tests/test_module.py:
--------------------------------------------------------------------------------
1 | # 08-31-2019;
2 | """
3 | Test cases for warm.module.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from pathlib import Path
9 | import sys
10 | sys.path.append(str(Path(__file__).parent.parent))
11 | import warm.module as mm
12 | import warm.functional as W
13 |
14 |
15 | def test_lambda():
16 | f = lambda x: x*2
17 | m = mm.Lambda(f)
18 | x = torch.randn(1, 2)
19 | assert torch.equal(f(x), m(x)), 'lambda did not work correctly.'
20 | def f(x, w, b=5):
21 | return x*w+b
22 | m = mm.Lambda(f, 2, b=1)
23 | assert torch.equal(f(x, 2, 1), m(x)), 'function with args and kwargs did not work correctly.'
24 | x = torch.randn(3, 2, 4)
25 | m = mm.Lambda(W.permute, 'BDC', 'BCD')
26 | assert list(m(x).shape) == [3, 4, 2], 'lambda permute did not work correctly.'
27 |
28 |
29 | def test_sequential():
30 | s = mm.Sequential(
31 | nn.Linear(1, 2),
32 | nn.LSTM(2, 3, batch_first=True), # lstm and gru return multiple outputs
33 | nn.GRU(3, 4, batch_first=True),
34 | mm.Lambda(W.permute, 'BDC', 'BCD'),
35 | nn.Conv1d(4, 5, 1), )
36 | x = torch.randn(3, 2, 1)
37 | assert list(s(x).shape) == [3, 5, 2]
38 |
39 |
40 | def test_shortcut():
41 | l = nn.Linear(1, 1, bias=False)
42 | nn.init.constant_(l.weight, 2.0)
43 | s = mm.Shortcut(l)
44 | x = torch.ones(1, 1)
45 | assert torch.allclose(s(x), torch.Tensor([3.0]))
46 |
--------------------------------------------------------------------------------
/warm/util.py:
--------------------------------------------------------------------------------
1 | # 08-28-2019;
2 | """
3 | Short utilities.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | import re
9 |
10 |
11 | """ Create a property for class torch.Tensor called ndim, for pytorch earlier than 1.2. """
12 | if not hasattr(torch.Tensor, 'ndim'):
13 | torch.Tensor.ndim = property(lambda x: x.dim())
14 |
15 |
16 | def camel_to_snake(name):
17 | """ Convert a camelCaseString to its snake_case_equivalent. """
18 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
19 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
20 |
21 |
22 | def summary_str(model):
23 | """ Get a string representation of model building blocks and parameter counts. """
24 | indent_list, name_list, count_list = [], [], []
25 | def module_info(m, name, indent_level):
26 | count_list.append(sum([np.prod(list(p.size())) for p in m.parameters()]))
27 | indent_list.append(indent_level)
28 | name_list.append(name)
29 | for name, child in m.named_children():
30 | if name.isdigit():
31 | name = child._get_name()
32 | module_info(child, name, indent_level+1)
33 | module_info(model, model._get_name(), 0)
34 | max_indent = max(indent_list)*4
35 | max_name = max(len(x) for x in name_list)+max_indent+2
36 | max_param = len(str(count_list[0]))+max_name+2
37 | out = ['Blocks{:>{w}}'.format('Params', w=max_param-6)]
38 | out += ['-'*max_param]
39 | for indent, name, param in zip(indent_list, name_list, count_list):
40 | s0 = ' '*indent
41 | s1 = '{:{w}}'.format(name, w=max_name-len(s0))
42 | s2 = '{:>{w}}'.format(param, w=max_param-len(s1)-len(s0))
43 | out += [s0+s1+s2]
44 | return '\n'.join(out)
45 |
46 |
47 | def summary(model):
48 | """ Print a summary about model building blocks and parameter counts. """
49 | print(summary_str(model))
50 |
--------------------------------------------------------------------------------
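
A small usage sketch of these utilities (the layers below are arbitrary, purely for illustration):

```Python
import torch.nn as nn
from warm import util

print(util.camel_to_snake('CamelAndSnake'))  # -> camel_and_snake

m = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.BatchNorm2d(16))
util.summary(m)  # prints an indented table of building blocks and parameter counts
```
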
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # Auto-generated content above this. Manually added content below.
107 | .vscode/
108 | .cache/
109 | *cache*
110 | /.project
111 | tmp/
112 | data/
113 | site/
114 |
--------------------------------------------------------------------------------
/examples/transformer.py:
--------------------------------------------------------------------------------
1 | # 09-05-2019;
2 | """
3 | The Transformer model from the paper *Attention Is All You Need*.
4 | """
5 | from pathlib import Path
6 | import sys
7 | sys.path.append(str(Path(__file__).parent.parent))
8 | sys.path.append('..')
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import warm
13 | import warm.util
14 | import warm.functional as W
15 |
16 |
17 | def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw):
18 | def split_heads(t): # (B, C, L) -> (B, N, H, L) where N*H == C
19 | return t.reshape(batch, num_head, size//num_head, t.shape[-1])
20 | def merge_heads(t): # (B, N, H, L) -> (B, C, L)
21 | return t.reshape(batch, -1, t.shape[-1]) # (B, C, L)
22 | if y is None:
23 | y = x # self attention
24 | batch, size = x.shape[:2] # B, C, Lx
25 | assert size%num_head == 0, 'num_head must be a divisor of size.'
26 | assert y.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.'
27 | q = W.linear(x, size) # query
28 | k = W.linear(y, size) # key
29 | v = W.linear(y, size) # value
30 | q = split_heads(q) # (B, N, H, Lx)
31 | k = split_heads(k) # (B, N, H, Ly)
32 | v = split_heads(v) # (B, N, H, Ly)
33 | q *= (size//num_head)**(-0.5)
34 | a = q.transpose(2, 3).contiguous().matmul(k) # attention weights, (B, N, Lx, Ly)
35 | if mask is not None:
36 | a += mask
37 | a = F.softmax(a, dim=-1)
38 | a = W.dropout(a, dropout)
39 | x = v.matmul(a.transpose(2, 3).contiguous()) # (B, N, H, Lx)
40 | x = merge_heads(x) # (B, C, Lx)
41 | return W.linear(x, size)
42 |
43 |
44 | def feed_forward(x, size_ff=2048, dropout=0.1, **kw):
45 | y = W.linear(x, size_ff, activation='relu')
46 | y = W.dropout(y, dropout)
47 | return W.linear(y, x.shape[1])
48 |
49 |
50 | def residual_add(x, layer, dropout=0.1, **kw):
51 | y = W.layer_norm(x)
52 | y = layer(y, **kw)
53 | y = W.dropout(y, dropout)
54 | return x+y
55 |
56 |
57 | def encoder(x, num_encoder=6, **kw):
58 | for i in range(num_encoder):
59 | x = residual_add(x, multi_head_attention, **kw)
60 | x = residual_add(x, feed_forward, **kw)
61 | return W.layer_norm(x)
62 |
63 |
64 | def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw):
65 | for i in range(num_decoder):
66 | y = residual_add(y, multi_head_attention, mask=mask_y, **kw)
67 | y = residual_add(x, multi_head_attention, y=y, mask=mask_x, **kw)
68 | y = residual_add(y, feed_forward, **kw)
69 | return W.layer_norm(y)
70 |
71 |
72 | def transformer(x, y, **kw):
73 | x = encoder(x, **kw)
74 | x = decoder(x, y, **kw)
75 | return x
76 |
77 |
78 | class Transformer(nn.Module):
79 | def __init__(self, *shape, **kw):
80 | super().__init__()
81 | self.kw = kw
82 | warm.up(self, *shape)
83 | def forward(self, x, y):
84 | return transformer(x, y, **self.kw)
85 |
--------------------------------------------------------------------------------
/warm/module.py:
--------------------------------------------------------------------------------
1 | # 08-27-2019;
2 | """
3 | Custom modules to enhance the nn Sequential experience.
4 |
5 | PyWarm's core concept is to use a functional interface to simplify network building.
6 | However, if you still prefer the classical way of defining child modules in `__init__()`,
7 | PyWarm provides some utilities to help organize child modules better.
8 |
9 | - `Lambda` can be used to wrap one line data transformations, like `x.view()`, `x.permute()` etc, into modules.
10 |
11 | - `Sequential` is an extension to `nn.Sequential` that better accommodates PyTorch RNNs.
12 |
13 | - `Shortcut` is another extension to `nn.Sequential` that also performs a shortcut addition (AKA residual connection)
14 | of the input with the output, so that residual blocks can be written in an entirely sequential way.
15 |
16 | For example, to define the basic block type for resnet:
17 |
18 |
19 | ```Python
20 | import torch.nn as nn
21 | import warm.module as wm
22 |
23 |
24 | def basic_block(size_in, size_out, stride=1):
25 | block = wm.Shortcut(
26 | nn.Conv2d(size_in, size_out, 3, stride, 1, bias=False),
27 | nn.BatchNorm2d(size_out),
28 | nn.ReLU(),
29 | nn.Conv2d(size_out, size_out, 3, 1, 1, bias=False),
30 | nn.BatchNorm2d(size_out),
31 | projection=wm.Lambda(
32 | lambda x: x if x.shape[1] == size_out else nn.Sequential(
33 | nn.Conv2d(size_in, size_out, 1, stride, bias=False),
34 | nn.BatchNorm2d(size_out), )(x), ), )
35 | return block
36 | ```
37 | """
38 |
39 |
40 | import torch.nn as nn
41 |
42 |
43 | class Lambda(nn.Module):
44 | """ Wraps a callable and all its call arguments.\n
45 | - `fn: callable`; The callable being wrapped.
46 | - `*arg: list`; Arguments to be passed to `fn`.
47 | - `**kw: dict`; KWargs to be passed to `fn`. """
48 | def __init__(self, fn, *arg, **kw):
49 | super().__init__()
50 | self.fn = fn
51 | self.arg = arg
52 | self.kw = kw
53 | def forward(self, x):
54 | """ forward. """
55 | return self.fn(x, *self.arg, **self.kw)
56 |
57 |
58 | class Sequential(nn.Sequential):
59 | """ Similar to `nn.Sequential`, except that child modules can have multiple outputs (e.g. `nn.RNN`).\n
60 | - `*arg: list of Modules`; Same as `nn.Sequential`. """
61 | def forward(self, x):
62 | """ forward. """
63 | for module in self._modules.values():
64 | if isinstance(x, tuple):
65 | try:
66 | x = module(x)
67 | except Exception:
68 | x = module(x[0])
69 | else:
70 | x = module(x)
71 | return x
72 |
73 |
74 | class Shortcut(Sequential):
75 | """ Similar to `nn.Sequential`, except that it performs a shortcut addition for the input and output.\n
76 | - `*arg: list of Modules`; Same as `nn.Sequential`.
77 | - `projection: None or callable`; If `None`, the input will be added directly to the output;
78 | otherwise the input will be passed to `projection` first, usually to make the shapes match. """
79 | def __init__(self, *arg, projection=None):
80 | super().__init__(*arg)
81 | self.projection = projection or nn.Identity()
82 | def forward(self, x):
83 | """ forward. """
84 | return super().forward(x)+self.projection(x)
85 |
--------------------------------------------------------------------------------
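
As a quick usage sketch (mirroring `tests/test_module.py`; sizes are only illustrative), `Lambda` and `Sequential` can be combined so that the extra state returned by RNN layers is dropped before reaching modules that expect a plain tensor:

```Python
import torch
import torch.nn as nn
import warm.module as mm
import warm.functional as W

seq = mm.Sequential(
    nn.Linear(1, 2),
    nn.LSTM(2, 3, batch_first=True),     # returns (output, (h, c))
    mm.Lambda(W.permute, 'BDC', 'BCD'),  # Sequential falls back to the output tensor here
    nn.Conv1d(3, 5, 1), )
x = torch.randn(3, 2, 1)                 # (batch, time, feature)
print(seq(x).shape)                      # torch.Size([3, 5, 2])
```
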
/examples/resnet.py:
--------------------------------------------------------------------------------
1 | # 08-29-2019;
2 | """
3 | Construct a WarmResNet() using PyWarm, then copy state dicts
4 | from torchvision.models.resnet18() into WarmResNet(),
5 | and compare whether it produces identical results to the official one.
6 | """
7 | from pathlib import Path
8 | import sys
9 | sys.path.append(str(Path(__file__).parent.parent))
10 | sys.path.append('..')
11 | import time
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import warm
16 | import warm.util
17 | import warm.functional as W
18 |
19 |
20 | def basic(x, size, stride, stack_index, block_index):
21 | """ The basic block. """
22 | prefix = f'layer{stack_index+1}-{block_index}-'
23 | y = W.conv(x, size, 3, stride=stride, padding=1, bias=False, name=prefix+'conv1')
24 | y = W.batch_norm(y, activation='relu', name=prefix+'bn1')
25 | y = W.conv(y, size, 3, stride=1, padding=1, bias=False, name=prefix+'conv2')
26 | y = W.batch_norm(y, name=prefix+'bn2')
27 | if y.shape[1] != x.shape[1]:
28 | x = W.conv(x, y.shape[1], 1, stride=stride, bias=False, name=prefix+'downsample-0')
29 | x = W.batch_norm(x, name=prefix+'downsample-1')
30 | return F.relu(y+x)
31 |
32 |
33 | def stack(x, num_block, size, stride, stack_index, block=basic):
34 | """ A stack of num_block blocks. """
35 | for block_index, s in enumerate([stride]+[1]*(num_block-1)):
36 | x = block(x, size, s, stack_index, block_index)
37 | return x
38 |
39 |
40 | class WarmResNet(nn.Module):
41 | def __init__(self, block=basic, stack_spec=((2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2))):
42 | super().__init__()
43 | self.block = block
44 | self.stack_spec = stack_spec
45 | warm.up(self, [2, 3, 32, 32])
46 | def forward(self, x):
47 | y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
48 | y = W.batch_norm(y, activation='relu', name='bn1')
49 | y = F.max_pool2d(y, 3, stride=2, padding=1)
50 | for i, spec in enumerate(self.stack_spec):
51 | y = stack(y, *spec, i, block=self.block)
52 | y = F.adaptive_avg_pool2d(y, 1)
53 | y = torch.flatten(y, 1)
54 | y = W.linear(y, 1000, name='fc')
55 | return y
56 |
57 |
58 | def test_time(fn, *arg, repeat=10, **kw):
59 | dur = 0.0
60 | for i in range(repeat):
61 | start = time.time()
62 | y = fn(*arg, **kw)
63 | dur += time.time()-start
64 | return dur
65 |
66 |
67 | def test():
68 | """ Compare the classification result of WarmResNet versus torchvision resnet18. """
69 | new = WarmResNet()
70 | from torchvision.models import resnet18
71 | old = resnet18()
72 | state = old.state_dict()
73 | for k in list(state.keys()): # Map parameters of old, e.g. layer2.0.conv1.weight
74 | s = k.split('.') # to parameters of new, e.g. layer2-0-conv1.weight
75 | s = '-'.join(s[:-1])+'.'+s[-1]
76 | state[s] = state.pop(k)
77 | new.load_state_dict(state)
78 | warm.util.summary(old)
79 | warm.util.summary(new)
80 | x = torch.randn(2, 3, 224, 224)
81 | with torch.no_grad():
82 | old.eval()
83 | y_old = old(x)
84 | new.eval()
85 | y_new = new(x)
86 | if torch.equal(y_old, y_new):
87 | print('Success! Same results from old and new.')
88 | else:
89 | print('Warning! New and old produce different results.')
90 | t_old = test_time(old, x)
91 | t_new = test_time(new, x)
92 | print('Total forward time for old:', t_old, 'seconds.')
93 | print('Total forward time for new:', t_new, 'seconds.')
94 |
95 |
96 | if __name__ == '__main__':
97 | test()
98 |
--------------------------------------------------------------------------------
/examples/mobilenet.py:
--------------------------------------------------------------------------------
1 | # 09-03-2019;
2 | """
3 | Construct a WarmMobileNetV2() using PyWarm, then copy state dicts
4 | from torchvision.models.mobilenet_v2() into WarmMobileNetV2(),
5 | and compare whether it produces identical results to the official one.
6 | """
7 | from pathlib import Path
8 | import sys
9 | sys.path.append(str(Path(__file__).parent.parent))
10 | sys.path.append('..')
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import warm
15 | import warm.util
16 | import warm.functional as W
17 |
18 |
19 | def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''):
20 | x = W.conv(x, size, kernel, padding=(kernel-1)//2, stride=stride, groups=groups, bias=False,
21 | name=f'{name}-0', )
22 | return W.batch_norm(x, activation='relu6', name=f'{name}-1')
23 |
24 |
25 | def bottleneck(x, size_out, stride, expand, name=''):
26 | size_in = x.shape[1]
27 | size_mid = size_in*expand
28 | y = conv_bn_relu(x, size_mid, kernel=1, name=f'{name}-conv-0') if expand > 1 else x
29 | y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid, name=f'{name}-conv-{1 if expand > 1 else 0}')
30 | y = W.conv(y, size_out, kernel=1, bias=False, name=f'{name}-conv-{2 if expand > 1 else 1}')
31 | y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}')
32 | if stride == 1 and size_in == size_out:
33 | y += x # residual shortcut
34 | return y
35 |
36 |
37 | def conv1x1(x, *arg, **kw):
38 | return conv_bn_relu(x, *arg, kernel=1, **kw)
39 |
40 |
41 | def pool(x, *arg, **kw):
42 | return x.mean([2, 3])
43 |
44 |
45 | def classify(x, size, *arg, **kw):
46 | x = W.dropout(x, rate=0.2, name='classifier-0')
47 | return W.linear(x, size, name='classifier-1')
48 |
49 |
50 | default_spec = (
51 | (None, 32, 1, 2, conv_bn_relu), # t, c, n, s, operator
52 | (1, 16, 1, 1, bottleneck),
53 | (6, 24, 2, 2, bottleneck),
54 | (6, 32, 3, 2, bottleneck),
55 | (6, 64, 4, 2, bottleneck),
56 | (6, 96, 3, 1, bottleneck),
57 | (6, 160, 3, 2, bottleneck),
58 | (6, 320, 1, 1, bottleneck),
59 | (None, 1280, 1, 1, conv1x1),
60 | (None, None, 1, None, pool),
61 | (None, 1000, 1, None, classify), )
62 |
63 |
64 | class WarmMobileNetV2(nn.Module):
65 | def __init__(self):
66 | super().__init__()
67 | warm.up(self, [2, 3, 224, 224])
68 | def forward(self, x):
69 | count = 0
70 | for t, c, n, s, op in default_spec:
71 | for i in range(n):
72 | stride = s if i == 0 else 1
73 | x = op(x, c, stride, t, name=f'features-{count}')
74 | count += 1
75 | return x
76 |
77 |
78 | def test():
79 | """ Compare the classification result of WarmMobileNetV2 versus torchvision mobilenet_v2. """
80 | new = WarmMobileNetV2()
81 | from torchvision.models import mobilenet_v2
82 | old = mobilenet_v2()
83 | state = old.state_dict()
84 | for k in list(state.keys()): # Map parameters of old, e.g. layer2.0.conv1.weight
85 | s = k.split('.') # to parameters of new, e.g. layer2-0-conv1.weight
86 | s = '-'.join(s[:-1])+'.'+s[-1]
87 | state[s] = state.pop(k)
88 | new.load_state_dict(state)
89 | warm.util.summary(old)
90 | warm.util.summary(new)
91 | x = torch.randn(1, 3, 224, 224)
92 | with torch.no_grad():
93 | old.eval()
94 | y_old = old(x)
95 | new.eval()
96 | y_new = new(x)
97 | if torch.equal(y_old, y_new):
98 | print('Success! Same results from old and new.')
99 | else:
100 | print('Warning! New and old produce different results.')
101 |
102 |
103 | if __name__ == '__main__':
104 | test()
105 |
--------------------------------------------------------------------------------
/docs/text.mako:
--------------------------------------------------------------------------------
1 | ## Define mini-templates for each portion of the doco.
2 |
3 | <%!
4 | def indent(s, spaces=4):
5 | new = s.replace('\n', '\n' + ' ' * spaces)
6 | return ' ' * spaces + new.strip()
7 | %>
8 |
9 | <%def name="deflist(s)">:${indent(s)[1:]}</%def>
10 |
11 | <%def name="h3(s)">### ${s}
12 | </%def>
13 |
14 | <%def name="function(func)" buffered="True">
15 | <%
16 | returns = show_type_annotations and func.return_annotation() or ''
17 | if returns:
18 | returns = ' -> ' + returns
19 | %>
20 | ${"---"}
21 | ${"### " + func.name}
22 |
23 |
24 | ```python3
25 | def :
26 | ${",\n ".join(func.params(annotate=show_type_annotations))} ${returns}
27 | ```
28 | ${func.docstring}
29 |
30 | % if show_source_code and func.source and func.obj is not getattr(func.inherits, 'obj', None):
31 |
32 | ??? example "View Source"
33 | ${"\n ".join(func.source.split("\n"))}
34 |
35 | % endif
36 | </%def>
37 |
38 | <%def name="variable(var)" buffered="True">
39 | ```python3
40 | ${var.name}
41 | ```
42 | ${var.docstring | deflist}
43 | </%def>
44 |
45 | <%def name="class_(cls)" buffered="True">
46 | ${"---"}
47 | ${"### " + cls.name}
48 |
49 | ```python3
50 | def :
51 | ${",\n ".join(cls.params(annotate=show_type_annotations))}
52 | ```
53 |
54 | ${cls.docstring}
55 |
56 | % if show_source_code and cls.source:
57 |
58 | ??? example "View Source"
59 | ${"\n ".join(cls.source.split("\n"))}
60 |
61 | ------
62 |
63 | % endif
64 |
65 | <%
66 | class_vars = cls.class_variables(show_inherited_members, sort=sort_identifiers)
67 | static_methods = cls.functions(show_inherited_members, sort=sort_identifiers)
68 | inst_vars = cls.instance_variables(show_inherited_members, sort=sort_identifiers)
69 | methods = cls.methods(show_inherited_members, sort=sort_identifiers)
70 | mro = cls.mro()
71 | subclasses = cls.subclasses()
72 | %>
73 | % if mro:
74 | ${h3('Ancestors (in MRO)')}
75 | % for c in mro:
76 | * ${c.refname}
77 | % endfor
78 | % endif
79 |
80 | % if subclasses:
81 | ${h3('Descendants')}
82 | % for c in subclasses:
83 | * ${c.refname}
84 | % endfor
85 | % endif
86 |
87 | % if class_vars:
88 | ${h3('Class variables')}
89 | % for v in class_vars:
90 | ${variable(v)}
91 |
92 | % endfor
93 | % endif
94 |
95 | % if static_methods:
96 | ${h3('Static methods')}
97 | % for f in static_methods:
98 | ${function(f)}
99 |
100 | % endfor
101 | % endif
102 |
103 | % if inst_vars:
104 | ${h3('Instance variables')}
105 | % for v in inst_vars:
106 | ${variable(v)}
107 |
108 | % endfor
109 | % endif
110 | % if methods:
111 | ${h3('Methods')}
112 | % for m in methods:
113 | ${function(m)}
114 |
115 | % endfor
116 | % endif
117 |
118 | </%def>
119 |
120 | ## Start the output logic for an entire module.
121 |
122 | <%
123 | variables = module.variables()
124 | classes = module.classes()
125 | functions = module.functions()
126 | submodules = module.submodules()
127 | heading = 'Namespace' if module.is_namespace else 'Module'
128 | %>
129 |
130 | ${"# " + heading} ${module.name}
131 |
132 | ${module.docstring}
133 |
134 | % if show_source_code:
135 |
136 | ??? example "View Source"
137 | ${"\n ".join(module.source.split("\n"))}
138 |
139 | % endif
140 |
141 |
142 | % if submodules:
143 | Sub-modules
144 | -----------
145 | % for m in submodules:
146 | * [${m.name}](${m.name.split(".")[-1]}/)
147 | % endfor
148 | % endif
149 |
150 | % if variables:
151 | Variables
152 | ---------
153 | % for v in variables:
154 | ${variable(v)}
155 |
156 | % endfor
157 | % endif
158 |
159 | % if functions:
160 | Functions
161 | ---------
162 | % for f in functions:
163 | ${function(f)}
164 |
165 | % endfor
166 | % endif
167 |
168 | % if classes:
169 | Classes
170 | -------
171 | % for c in classes:
172 | ${class_(c)}
173 |
174 | % endfor
175 | % endif
176 |
--------------------------------------------------------------------------------
/examples/lstm.py:
--------------------------------------------------------------------------------
1 | # 09-07-2019;
2 | """
3 | LSTM sequence model example, based on
4 | https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
5 | """
6 | import argparse
7 | from pathlib import Path
8 | import sys
9 | sys.path.append(str(Path(__file__).parent.parent))
10 | sys.path.append('..')
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | import warm
16 | import warm.functional as W
17 |
18 |
19 | training_data = [
20 | ('The dog ate the apple'.split(), ['DET', 'NN', 'V', 'DET', 'NN']),
21 | ('Everybody read that book'.split(), ['NN', 'V', 'DET', 'NN']), ]
22 | testing_data = [('The dog ate the book'.split(), ['DET', 'NN', 'V', 'DET', 'NN'])]
23 | word_to_ix = {}
24 | for sent, tags in training_data:
25 | for word in sent:
26 | if word not in word_to_ix:
27 | word_to_ix[word] = len(word_to_ix)
28 | tag_to_ix = {'DET': 0, 'NN': 1, 'V': 2}
29 | ix_to_tag = {v:k for k, v in tag_to_ix.items()}
30 |
31 |
32 | class WarmTagger(nn.Module):
33 | def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
34 | super().__init__()
35 | self.arg = (embedding_dim, hidden_dim, vocab_size, tagset_size)
36 | warm.up(self, torch.tensor([0, 1], dtype=torch.long))
37 | def forward(self, x): # D
38 | embedding_dim, hidden_dim, vocab_size, tagset_size = self.arg
39 | y = W.embedding(x, embedding_dim, vocab_size) # D->DC
40 | y = W.lstm(y.T[None, ...], hidden_dim) # DC->BCD
41 | y = W.linear(y, tagset_size) # BCD
42 | y = F.log_softmax(y, dim=1) # BCD
43 | return y[0].T # DC
44 |
45 |
46 | class TorchTagger(nn.Module):
47 | def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
48 | super().__init__()
49 | self.hidden_dim = hidden_dim
50 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
51 | self.lstm = nn.LSTM(embedding_dim, hidden_dim)
52 | self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
53 | def forward(self, sentence):
54 | embeds = self.word_embeddings(sentence)
55 | lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
56 | tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
57 | tag_scores = F.log_softmax(tag_space, dim=1)
58 | return tag_scores
59 |
60 |
61 | def prepare_sequence(seq, to_ix):
62 | idxs = [to_ix[w] for w in seq]
63 | return torch.tensor(idxs, dtype=torch.long)
64 |
65 |
66 | def main():
67 | parser = argparse.ArgumentParser(description='LSTM sequence tagger example')
68 | parser.add_argument(
69 | '--warm', action='store_true', help='use warm instead of vanilla pytorch.')
70 | p = parser.parse_args()
71 | torch.manual_seed(1)
72 | #
73 | arg = (6, 6, len(word_to_ix), len(tag_to_ix))
74 | model = WarmTagger(*arg) if p.warm else TorchTagger(*arg)
75 | print(f'Using {model._get_name()}.')
76 | loss_function = nn.NLLLoss()
77 | optimizer = optim.SGD(model.parameters(), lr=0.1)
78 | #
79 | for epoch in range(300):
80 | for sentence, tags in training_data:
81 | model.zero_grad()
82 | sentence_in = prepare_sequence(sentence, word_to_ix)
83 | targets = prepare_sequence(tags, tag_to_ix)
84 | tag_scores = model(sentence_in)
85 | loss = loss_function(tag_scores, targets)
86 | loss.backward()
87 | optimizer.step()
88 | #
89 | with torch.no_grad():
90 | inputs = prepare_sequence(testing_data[0][0], word_to_ix)
91 | tag_scores = model(inputs)
92 | ix = torch.argmax(tag_scores, -1).numpy()
93 | print(testing_data[0][0])
94 | print('Network tags:\n', [ix_to_tag[i] for i in ix])
95 | print('True tags:\n', testing_data[0][1])
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
--------------------------------------------------------------------------------
/examples/efficientnet.py:
--------------------------------------------------------------------------------
1 |
2 | # 09-20-2019;
3 | """
4 | EfficientNet
5 | """
6 | from pathlib import Path
7 | import sys
8 | sys.path.append(str(Path(__file__).parent.parent))
9 | sys.path.append('..')
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import warm
14 | import warm.util
15 | import warm.functional as W
16 | from warm.engine import namespace
17 |
18 |
19 | def swish(x):
20 | return x*torch.sigmoid(x)
21 |
22 |
23 | def conv_pad_same(x, size, kernel=1, stride=1, **kw):
24 | pad = 0
25 | if kernel != 1 or stride != 1:
26 | in_size, s, k = [torch.as_tensor(v) for v in (x.shape[2:], stride, kernel)]
27 | pad = torch.max(((in_size+s-1)//s-1)*s+k-in_size, torch.tensor(0))
28 | left, right = pad//2, pad-pad//2
29 | if torch.all(left == right):
30 | pad = tuple(left.tolist())
31 | else:
32 | left, right = left.tolist(), right.tolist()
33 | pad = sum(zip(left[::-1], right[::-1]), ())
34 | x = F.pad(x, pad)
35 | pad = 0
36 | return W.conv(x, size, kernel, stride=stride, padding=pad, **kw)
37 |
38 |
39 | @namespace
40 | def conv_bn_act(x, size, kernel=1, stride=1, groups=1, bias=False, eps=1e-3, momentum=1e-2, act=swish, name='', **kw):
41 | x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias, name=name+'-conv')
42 | return W.batch_norm(x, eps=eps, momentum=momentum, activation=act, name=name+'-bn')
43 |
44 |
45 | @namespace
46 | def mb_block(x, size_out, expand=1, kernel=1, stride=1, se_ratio=0.25, dc_ratio=0.2, **kw):
47 | """ MobileNet Bottleneck Block. """
48 | size_in = x.shape[1]
49 | size_mid = size_in*expand
50 | y = conv_bn_act(x, size_mid, 1, **kw) if expand > 1 else x
51 | y = conv_bn_act(y, size_mid, kernel, stride=stride, groups=size_mid, **kw)
52 | y = squeeze_excitation(y, int(size_in*se_ratio), **kw)
53 | y = conv_bn_act(y, size_out, 1, act=None, **kw)
54 | if stride == 1 and size_in == size_out:
55 | y = drop_connect(y, dc_ratio)
56 | y += x
57 | return y
58 |
59 |
60 | @namespace
61 | def squeeze_excitation(x, size_se, name='', **kw):
62 | if size_se == 0:
63 | return x
64 | size_in = x.shape[1]
65 | x = F.adaptive_avg_pool2d(x, 1)
66 | x = W.conv(x, size_se, 1, activation=swish, name=name+'-conv1')
67 | return W.conv(x, size_in, 1, activation=swish, name=name+'-conv2')
68 |
69 |
70 | def drop_connect(x, rate):
71 | """ Randomly set entire batch to 0. """
72 | if rate == 0:
73 | return x
74 | rate = 1.0-rate
75 | drop_mask = torch.rand([x.shape[0], 1, 1, 1], device=x.device, requires_grad=False)+rate # per-sample mask; floor() keeps a sample with probability rate
76 | return x/rate*drop_mask.floor() # zero dropped samples, rescale kept ones by 1/rate
77 |
78 |
79 | spec_b0 = (
80 | (16, 1, 3, 1, 1, 0.25, 0.2), # size, expand, kernel, stride, repeat, se_ratio, dc_ratio
81 | (24, 6, 3, 2, 2, 0.25, 0.2),
82 | (40, 6, 5, 2, 2, 0.25, 0.2),
83 | (80, 6, 3, 2, 3, 0.25, 0.2),
84 | (112, 6, 5, 1, 3, 0.25, 0.2),
85 | (192, 6, 5, 2, 4, 0.25, 0.2),
86 | (320, 6, 3, 1, 1, 0.25, 0.2), )
87 |
88 |
89 | class WarmEfficientNet(nn.Module):
90 | def __init__(self):
91 | super().__init__()
92 | warm.up(self, [2, 3, 32, 32])
93 | def forward(self, x):
94 | x = conv_bn_act(x, 32, kernel=3, stride=2, name='head')
95 | for size, expand, kernel, stride, repeat, se_ratio, dc_ratio in spec_b0:
96 | for i in range(repeat):
97 | stride = stride if i == 0 else 1
98 | x = mb_block(x, size, expand, kernel, stride, se_ratio, dc_ratio)
99 | x = conv_bn_act(x, 1280, name='tail')
100 | x = F.adaptive_avg_pool2d(x, 1)
101 | x = W.dropout(x, 0.2)
102 | x = x.view(x.shape[0], -1)
103 | x = W.linear(x, 1000)
104 | return x
105 |
106 |
107 | if __name__ == '__main__':
108 | m = WarmEfficientNet()
109 | warm.util.summary(m)
110 |
--------------------------------------------------------------------------------
/tests/test_functional.py:
--------------------------------------------------------------------------------
1 | # 08-31-2019;
2 | """
3 | Test cases for warm.functional.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | from pathlib import Path
8 | import sys
9 | sys.path.append(str(Path(__file__).parent.parent))
10 | import warm.module as mm
11 | import warm.functional as W
12 |
13 |
14 | def test_conv():
15 | m = nn.Module()
16 | x = torch.randn(1, 2, 8) # BCD
17 | torch.manual_seed(100)
18 | y0 = nn.Conv1d(2, 3, 3)(x)
19 | torch.manual_seed(100)
20 | y1 = W.conv(x, 3, 3, parent=m)
21 | assert torch.equal(y0, y1), 'conv incorrect output on 1d signal.'
22 | m = nn.Module()
23 | x = torch.randn(1, 2, 3, 4) # BCD
24 | torch.manual_seed(100)
25 | y0 = nn.Conv2d(2, 3, 3)(x)
26 | torch.manual_seed(100)
27 | y1 = W.conv(x, 3, 3, parent=m)
28 | assert torch.equal(y0, y1), 'conv incorrect output on 2d signal.'
29 |
30 |
31 | def test_linear():
32 | m = nn.Module()
33 | x = torch.randn(1, 2, 3) # BDC
34 | torch.manual_seed(100)
35 | y0 = nn.Linear(3, 4)(x)
36 | torch.manual_seed(100)
37 | y1 = W.linear(x, 4, parent=m, in_shape='BDC', out_shape='BDC')
38 | assert torch.equal(y0, y1), 'linear incorrect output on 1d signal.'
39 | m = nn.Module()
40 | x = torch.randn(1, 2, 3, 4) # BDC
41 | torch.manual_seed(100)
42 | y0 = nn.Linear(4, 3)(x)
43 | torch.manual_seed(100)
44 | y1 = W.linear(x, 3, parent=m, in_shape='BDC', out_shape='BDC')
45 | assert torch.equal(y0, y1), 'linear incorrect output on 2d signal.'
46 |
47 |
48 | def test_batch_norm():
49 | m = nn.Module()
50 | x = torch.randn(1, 2, 3) # BCD
51 | torch.manual_seed(100)
52 | y0 = nn.BatchNorm1d(2)(x)
53 | torch.manual_seed(100)
54 | y1 = W.batch_norm(x, parent=m)
55 | m = nn.Module()
56 | assert torch.equal(y0, y1), 'batch_norm incorrect output on 1d signal.'
57 | x = torch.randn(1, 2, 3, 4) # BCD
58 | torch.manual_seed(100)
59 | y0 = nn.BatchNorm2d(2)(x)
60 | torch.manual_seed(100)
61 | y1 = W.batch_norm(x, parent=m)
62 | assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.'
63 |
64 |
65 | def test_lstm():
66 | m = nn.Module()
67 | x = torch.randn(3, 2, 1) # DBC
68 | torch.manual_seed(100)
69 | y0, *_ = nn.LSTM(1, 2, num_layers=2)(x)
70 | torch.manual_seed(100)
71 | y1 = W.lstm(x, 2, num_layers=2, parent=m, init_weight_hh=None, in_shape='DBC', out_shape='DBC')
72 | assert torch.equal(y0, y1)
73 | y1, s1 = W.lstm(x, 2, parent=m, tuple_out=True) # test tuple out
74 | assert len(s1) == 2
75 | y2 = W.lstm((y1, s1), 2, parent=m) # test tuple in
76 | assert torch.is_tensor(y2)
77 |
78 |
79 | def test_gru():
80 | m = nn.Module()
81 | x = torch.randn(3, 2, 1) # DBC
82 | torch.manual_seed(100)
83 | y0, *_ = nn.GRU(1, 2, num_layers=2)(x)
84 | torch.manual_seed(100)
85 | y1 = W.gru(x, 2, num_layers=2, parent=m, init_weight_hh=None, in_shape='DBC', out_shape='DBC')
86 | assert torch.equal(y0, y1)
87 |
88 |
89 | def test_identity():
90 | x = torch.randn(1, 2, 3)
91 | assert torch.equal(W.identity(x, 7, 8, a='b'), x)
92 |
93 |
94 | def test_dropout():
95 | m = nn.Module()
96 | x = torch.ones(2, 6, 6, 6)
97 | torch.manual_seed(100)
98 | y0 = nn.Dropout(0.3)(x)
99 | torch.manual_seed(100)
100 | y1 = W.dropout(x, 0.3, parent=m)
101 | assert torch.equal(y0, y1)
102 | torch.manual_seed(100)
103 | y0 = nn.Dropout2d(0.3)(x)
104 | torch.manual_seed(100)
105 | y1 = W.dropout(x, 0.3, by_channel=True, parent=m)
106 | assert torch.equal(y0, y1)
107 |
108 |
109 | def test_transformer():
110 | m = nn.Module()
111 | x = torch.randn(10, 2, 4)
112 | y = torch.randn(6, 2, 4)
113 | torch.manual_seed(100)
114 | z0 = nn.Transformer(4, 2, 1, 1, dim_feedforward=8)(x, y)
115 | torch.manual_seed(100)
116 | z1 = W.transformer(x, y, 1, 1, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m)
117 | assert torch.equal(z0, z1)
118 | torch.manual_seed(100)
119 | z1 = W.transformer(x, y, 1, 1, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m, causal=True)
120 | assert not torch.equal(z0, z1)
121 | z1 = W.transformer(x, None, 2, 0, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m)
122 | assert z1.shape == x.shape
123 |
124 |
125 | def test_layer_norm():
126 | m = nn.Module()
127 | x = torch.randn(1, 2, 3, 4, 5)
128 | y0 = nn.LayerNorm([3, 4, 5])(x)
129 | y1 = W.layer_norm(x, [2, -2, -1], parent=m)
130 | assert torch.equal(y0, y1)
131 | y0 = nn.LayerNorm(5)(x)
132 | y1 = W.layer_norm(x, dim=-1, parent=m)
133 | assert torch.equal(y0, y1)
134 | x0 = x.permute(0, 4, 2, 1, 3)
135 | y0 = nn.LayerNorm([2, 4])(x0)
136 | y0 = y0.permute(0, 3, 2, 4, 1)
137 | y1 = W.layer_norm(x, dim=[1, -2], parent=m)
138 | assert torch.equal(y0, y1)
139 |
140 |
141 | def test_embedding():
142 | m = nn.Module()
143 | x = torch.randint(0, 20, (1, 2, 3, 4, 5))
144 | torch.manual_seed(10)
145 | y0 = nn.Embedding(20, 8)(x)
146 | torch.manual_seed(10)
147 | y1 = W.embedding(x, 8, 20, parent=m)
148 | assert torch.equal(y0, y1)
149 | torch.manual_seed(10)
150 | y1 = W.embedding(x, 8, 20, in_shape='DCB', parent=m) # shapes should have no effect
151 | assert torch.equal(y0, y1)
152 | torch.manual_seed(10)
153 | y1 = W.embedding(x, 8, 20, out_shape='CBD', parent=m) # shapes should have no effect
154 | assert torch.equal(y0, y1)
155 | y1 = W.embedding(x, 8, parent=m) # should work without an explicit vocabulary size
156 | torch.manual_seed(10)
157 | y1 = W.embedding(x.double(), 8, parent=m) # should work with non-integer tensors.
158 | assert torch.equal(y0, y1)
159 |
--------------------------------------------------------------------------------
/examples/mnist.py:
--------------------------------------------------------------------------------
1 | # 08-27-2019;
2 | """
3 | MNIST training example.
4 | Use `python mnist.py` to run with PyTorch NN.
5 | Use `python mnist.py --warm` to run with PyWarm NN.
6 | Use `python mnist.py --help` to see a list of cli argument options.
7 | """
8 | from pathlib import Path
9 | import sys
10 | sys.path.append(str(Path(__file__).parent.parent))
11 | sys.path.append('..')
12 | import argparse
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | from torchvision import datasets, transforms
18 | import warm
19 | import warm.functional as W
20 |
21 |
22 | class WarmNet(nn.Module):
23 | def __init__(self):
24 | super().__init__()
25 | warm.up(self, [1, 1, 28, 28])
26 | def forward(self, x):
27 | x = W.conv(x, 20, 5, activation='relu')
28 | x = F.max_pool2d(x, 2)
29 | x = W.conv(x, 50, 5, activation='relu')
30 | x = F.max_pool2d(x, 2)
31 | x = x.view(-1, 800)
32 | x = W.linear(x, 500, activation='relu')
33 | x = W.linear(x, 10)
34 | return F.log_softmax(x, dim=1)
35 |
36 |
37 | class TorchNet(nn.Module):
38 | def __init__(self):
39 | super().__init__()
40 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
41 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
42 | self.fc1 = nn.Linear(4*4*50, 500)
43 | self.fc2 = nn.Linear(500, 10)
44 | def forward(self, x):
45 | x = F.relu(self.conv1(x))
46 | x = F.max_pool2d(x, 2, 2)
47 | x = F.relu(self.conv2(x))
48 | x = F.max_pool2d(x, 2, 2)
49 | x = x.view(-1, 4*4*50)
50 | x = F.relu(self.fc1(x))
51 | x = self.fc2(x)
52 | return F.log_softmax(x, dim=1)
53 |
54 |
55 | def train(p, model, device, train_loader, optimizer, epoch):
56 | model.train()
57 | for batch_idx, (data, target) in enumerate(train_loader):
58 | data, target = data.to(device), target.to(device)
59 | optimizer.zero_grad()
60 | output = model(data)
61 | loss = F.nll_loss(output, target)
62 | loss.backward()
63 | optimizer.step()
64 | if batch_idx%p.log_interval == 0:
65 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
66 | epoch, batch_idx*len(data), len(train_loader.dataset),
67 | 100.*batch_idx/len(train_loader), loss.item()))
68 |
69 |
70 | def test(p, model, device, test_loader):
71 | model.eval()
72 | test_loss = 0
73 | correct = 0
74 | size = len(test_loader.dataset)
75 | with torch.no_grad():
76 | for data, target in test_loader:
77 | data, target = data.to(device), target.to(device)
78 | output = model(data)
79 | test_loss += F.nll_loss(output, target, reduction='sum').item()
80 | pred = output.argmax(dim=1, keepdim=True)
81 | correct += pred.eq(target.view_as(pred)).sum().item()
82 | test_loss /= size
83 | print(f'\nTest loss: {test_loss:.4f}, Accuracy: {correct}/{size} ({100*correct/size:.2f}%)\n')
84 |
85 |
86 | def main():
87 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
88 | parser.add_argument(
89 | '--warm', action='store_true', help='use warm instead of vanilla pytorch.')
90 | parser.add_argument(
91 | '--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)')
92 | parser.add_argument(
93 | '--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)')
94 | parser.add_argument(
95 | '--epochs', type=int, default=3, metavar='N', help='number of epochs to train (default: 3)')
96 | parser.add_argument(
97 | '--lr', type=float, default=0.02, metavar='LR', help='learning rate (default: 0.02)')
98 | parser.add_argument(
99 | '--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
100 | parser.add_argument(
101 | '--no-cuda', action='store_true', default=False, help='disables CUDA training')
102 | parser.add_argument(
103 | '--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
104 | parser.add_argument(
105 | '--log-interval', type=int, default=10, metavar='N', help='number of batches between logging training status')
106 | parser.add_argument(
107 | '--save-model', action='store_true', default=False, help='For Saving the current Model')
108 | p = parser.parse_args()
109 | #
110 | torch.manual_seed(p.seed)
111 | use_cuda = not p.no_cuda and torch.cuda.is_available()
112 | device = 'cuda' if use_cuda else 'cpu'
113 | kw = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
114 | data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ])
115 | train_data = datasets.MNIST('../data', train=True, download=True, transform=data_transform)
116 | test_data = datasets.MNIST('../data', train=False, download=True, transform=data_transform)
117 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=p.batch_size, shuffle=True, **kw)
118 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=p.test_batch_size, shuffle=True, **kw)
119 | model = WarmNet() if p.warm else TorchNet()
120 | print(f'Using {model._get_name()}.')
121 | model = model.to(device)
122 | optimizer = optim.SGD(model.parameters(), lr=p.lr, momentum=p.momentum)
123 | print(f'Training with {p.epochs} epochs on {device} device.')
124 | #
125 | for i in range(p.epochs):
126 | train(p, model, device, train_loader, optimizer, i)
127 | test(p, model, device, test_loader)
128 | #
129 | if p.save_model:
130 | torch.save(model.state_dict(), 'mnist_cnn.pt')
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/tests/test_engine.py:
--------------------------------------------------------------------------------
1 | # 08-31-2019;
2 | """
3 | Test cases for warm.engine.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import copy
9 | from pathlib import Path
10 | import sys
11 | sys.path.append(str(Path(__file__).parent.parent))
12 | from warm import engine
13 |
14 |
15 | def test_set_get_default_parent():
16 | a = nn.Identity()
17 | b = nn.Identity()
18 | engine.set_default_parent(a)
19 | assert engine.get_default_parent() is a, 'get_default_parent result mismatches set_default_parent.'
20 | engine.set_default_parent(b)
21 | assert engine.get_default_parent() is b, 'get_default_parent result mismatches set_default_parent.'
22 |
23 |
24 | def test_auto_name():
25 | a = nn.Identity()
26 | for i in range(10):
27 | assert engine._auto_name('test', a) == f'test_{i+1}', 'new calls to _auto_name failed to increment name count.'
28 | a(None) # test if forward pre hook is triggered to reset names
29 | assert engine._auto_name('test', a) == 'test_1', 'forward_pre_hook did not work.'
30 |
31 |
32 | def test_initialize():
33 | a = nn.Parameter(torch.zeros(3, 4))
34 | b = nn.Parameter(torch.zeros(3, 4))
35 | c = nn.Parameter(torch.zeros(3, 4))
36 | torch.manual_seed(1)
37 | engine.initialize_(a, 'normal_')
38 | torch.manual_seed(1)
39 | nn.init.normal_(b)
40 | assert torch.equal(a, b), 'initialize_ with str spec did not work correctly.'
41 | assert not torch.equal(a, c), 'initialize_ with str spec did not work.'
42 | torch.manual_seed(1)
43 | engine.initialize_(c, nn.init.normal_)
44 | assert torch.equal(a, c), 'initialize_ with function spec did not work correctly.'
45 |
46 |
47 | def test_activate():
48 | a = torch.randn(3, 4)
49 | b = copy.deepcopy(a)
50 | a = engine.activate(a, 'hardshrink')
51 | b = F.hardshrink(b)
52 | assert torch.equal(a, b), 'activate with str spec did not work correctly.'
53 | a = engine.activate(a, 'relu')
54 | b = F.relu(b)
55 | assert torch.equal(a, b), 'activate with str spec did not work correctly.'
56 |
57 |
58 | def test_permute():
59 | x = torch.randn(1, 2, 3)
60 | y = engine.permute(x, 'BCD', 'DCB')
61 | assert list(y.shape) == [3, 2, 1], 'permute 3d tensor with str in_shape and str out_shape did not work correctly.'
62 | y = engine.permute(x, 'BCD', None)
63 | assert list(y.shape) == [1, 2, 3], 'permute tensor with None out_shape did not work correctly.'
64 | y = engine.permute(x, 'BCD', [1, 0, 2])
65 | assert list(y.shape) == [2, 1, 3], 'permute tensor with list out_shape did not work correctly.'
66 | x = torch.randn(1, 2, 3, 4)
67 | y = engine.permute(x, 'BCD', 'DCB')
68 | assert list(y.shape) == [3, 4, 2, 1], 'permute 4d tensor with str in_shape and str out_shape did not work correctly.'
69 | y = engine.permute(x, 'DBC', 'CDB')
70 | assert list(y.shape) == [4, 1, 2, 3], 'permute 4d tensor with str in_shape and str out_shape did not work correctly.'
71 | x = torch.randn(1, 2, 3, 4, 5)
72 | y = engine.permute(x, 'BDC', 'BCD')
73 | assert list(y.shape) == [1, 5, 2, 3, 4], 'permute 5d tensor with str in_shape and str out_shape did not work correctly.'
74 | x = torch.randn(1, 2)
75 | y = engine.permute(x, 'BDC', 'BCD')
76 | assert list(y.shape) == [1, 2], 'permute 2d tensor with str in_shape and str out_shape did not work correctly.'
77 | y = engine.permute(x, 'CBD', 'DBC')
78 | assert list(y.shape) == [2, 1], 'permute 2d tensor with str in_shape and str out_shape did not work correctly.'
79 |
80 |
81 | def test_unused_kwargs():
82 | kw = {'unused1':0, 'unused2':0, 'base_class':0}
83 | unused = engine.unused_kwargs(kw)
84 | assert 'base_class' not in unused, 'unused_kwargs leaked a used kwarg.'
85 | assert set(unused.keys()) == {'unused1', 'unused2'}, 'unused_kwargs did not filter kw correctly.'
86 |
87 |
88 | def test_prepare_model_is_ready():
89 | class TestModel(nn.Module):
90 | def forward(self, x):
91 | x = engine.forward(x, nn.Linear, 'linear',
92 | base_arg=(x.shape[-1], 4, False), # in_features, out_features, bias
93 | in_shape=None, out_shape=None, base_shape=None,
94 | initialization={'weight':'ones_'}, activation=(F.dropout, {'p':1.0}), )
95 | return x
96 | x = torch.randn(1, 2, 3)
97 | m = TestModel()
98 | assert not engine.is_ready(m), 'is_ready did not work correctly.'
99 | engine.prepare_model_(m, x)
100 | assert engine.is_ready(m), 'prepare_model_ did not work correctly.'
101 | assert m.linear_1.bias is None, 'linear_1 should not have bias.'
102 | assert torch.allclose(m.linear_1.weight, torch.Tensor([1.0])), 'linear_1.weight should be initialized to all 1s.'
103 | y = m(x)
104 | assert torch.allclose(y, torch.Tensor([0.0])), 'y should be all 0s because we dropout everything.'
105 | assert list(y.shape) == [1, 2, 4], 'y should have shape [1, 2, 4] after linear projection.'
106 |
107 |
108 | def test_forward():
109 | x = torch.randn(1, 2, 3)
110 | m = nn.Module()
111 | engine.set_default_parent(m)
112 | class TripleOut(nn.Module): # to test tuple_out
113 | def forward(self, x, b=1, c='2'):
114 | return x+b, x, c
115 | y = engine.forward(x, base_class=TripleOut, base_name='tri', tuple_out=False)
116 | assert isinstance(y, torch.Tensor), 'tuple_out did not work correctly.'
117 | y = engine.forward(x, base_class=TripleOut, base_name='tri', tuple_out=True)
118 | assert isinstance(y, tuple) and len(y) == 3 and y[-1] == '2', 'tuple_out did not work correctly.'
119 | y = engine.forward(x, base_class=TripleOut, base_name='tri', forward_kw={'c':3}, tuple_out=True)
120 | assert y[-1] == 3, 'forward_kw did not work correctly.'
121 | y = engine.forward(x, base_class=TripleOut, base_name='tri', forward_arg=(2.0,))
122 | assert torch.allclose(y-x, torch.Tensor([2.0])), 'forward_arg did not work correctly.'
123 | y = engine.forward(x, base_class=TripleOut, activation=(F.dropout, {'p':1.0}))
124 | assert torch.allclose(y, torch.Tensor([0.0])), 'activation did not work correctly.'
125 | y = engine.forward(
126 | x, base_class=nn.Linear, base_kw={'out_features':4}, infer_kw={'in_features':'C'}, base_shape='BDC')
127 | assert y.shape[1] == 4, 'base_kw, infer_kw did not work correctly.'
128 |
129 |
130 | def test_namespace():
131 | m = nn.Module()
132 | engine.set_default_parent(m)
133 | @engine.namespace
134 | def f1(name=''):
135 | return ';'.join([f2(name=name) for i in range(2)])
136 | @engine.namespace
137 | def f2(name=''):
138 | return name
139 | s0, s1, s2 = [f1() for i in range(3)]
140 | assert s0 == 'f1_1-f2_1;f1_1-f2_2'
141 | assert s1 == 'f1_2-f2_1;f1_2-f2_2'
142 | assert s2 == 'f1_3-f2_1;f1_3-f2_2'
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [![PyWarm logo](docs/pywarm-logo.png)](https://blue-season.github.io/pywarm/)
3 |
4 | # PyWarm
5 |
6 | A cleaner way to build neural networks for PyTorch.
7 |
8 | [](https://github.com/blue-season/pywarm)
9 | [](https://pypi.org/project/pywarm/)
10 | [](https://github.com/blue-season/pywarm/blob/master/LICENSE)
11 |
12 | [Examples](https://blue-season.github.io/pywarm/docs/example/) | [Tutorial](https://blue-season.github.io/pywarm/docs/tutorial/) | [API reference](https://blue-season.github.io/pywarm/reference/warm/functional/)
13 |
14 | ----
15 |
16 | ## Introduction
17 |
18 | PyWarm is a lightweight, high-level neural network construction API for PyTorch.
19 | It enables defining all parts of NNs in the functional way.
20 |
21 | With PyWarm, you can put *all* network data flow logic in the `forward()` method of
22 | your model, without the need to define children modules in the `__init__()` method
23 | and then call it again in the `forward()`.
24 | This results in a much more readable model definition in fewer lines of code.
25 |
26 | PyWarm only aims to simplify the network definition, and does not attempt to cover
27 | model training, validation or data handling.
28 |
29 | ----
30 |
31 | For example, a convnet for MNIST:
32 | (If needed, click the tabs to switch between Warm and Torch versions)
33 |
34 |
35 | ``` Python tab="Warm" linenums="1"
36 | # powered by PyWarm
37 | import torch.nn as nn
38 | import torch.nn.functional as F
39 | import warm
40 | import warm.functional as W
41 |
42 |
43 | class ConvNet(nn.Module):
44 |
45 | def __init__(self):
46 | super().__init__()
47 | warm.up(self, [2, 1, 28, 28])
48 |
49 | def forward(self, x):
50 | x = W.conv(x, 20, 5, activation='relu')
51 | x = F.max_pool2d(x, 2)
52 | x = W.conv(x, 50, 5, activation='relu')
53 | x = F.max_pool2d(x, 2)
54 | x = x.view(-1, 800)
55 | x = W.linear(x, 500, activation='relu')
56 | x = W.linear(x, 10)
57 | return F.log_softmax(x, dim=1)
58 | ```
59 |
60 | ``` Python tab="Torch" linenums="1"
61 | # vanilla PyTorch version, taken from
62 | # pytorch tutorials/beginner_source/blitz/neural_networks_tutorial.py
63 | import torch.nn as nn
64 | import torch.nn.functional as F
65 |
66 |
67 | class ConvNet(nn.Module):
68 |
69 | def __init__(self):
70 | super().__init__()
71 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
72 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
73 | self.fc1 = nn.Linear(4*4*50, 500)
74 | self.fc2 = nn.Linear(500, 10)
75 |
76 | def forward(self, x):
77 | x = F.relu(self.conv1(x))
78 | x = F.max_pool2d(x, 2, 2)
79 | x = F.relu(self.conv2(x))
80 | x = F.max_pool2d(x, 2, 2)
81 | x = x.view(-1, 4*4*50)
82 | x = F.relu(self.fc1(x))
83 | x = self.fc2(x)
84 | return F.log_softmax(x, dim=1)
85 | ```
86 |
87 | ----
88 |
89 | A couple of things you may have noticed:
90 |
91 | - First of all, in the PyWarm version, the entire network definition and
92 | data flow logic resides in the `forward()` method. You don't have to look
93 | up and down repeatedly to understand what `self.conv1`, `self.fc1` etc.
94 |   are doing.
95 |
96 | - You do not need to track and specify `in_channels` (or `in_features`, etc.)
97 |   for network layers; PyWarm can infer this information for you, e.g.
98 |
99 | ```Python
100 | # Warm
101 | x = W.conv(x, 20, 5, activation='relu')
102 | x = W.conv(x, 50, 5, activation='relu')
103 |
104 |
105 | # Torch
106 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
107 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
108 | ```
109 |
110 | - One unified `W.conv` for all 1D, 2D, and 3D cases. Fewer things to keep track of!
111 |
112 | - `activation='relu'`. All `warm.functional` APIs accept an optional `activation` keyword,
113 | which is basically equivalent to `F.relu(W.conv(...))`. The keyword `activation` can also
114 |   take in a callable, for example `activation=torch.nn.ReLU(inplace=True)` or `activation=swish` (see the sketch below).
115 |
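For instance, `swish` above is not something PyWarm ships; it can be any plain function you define yourself (the EfficientNet example in the docs uses the same helper). A minimal sketch, used inside a warmed model's `forward()`:

```Python
import torch
import warm.functional as W

def swish(x): # plain helper function, not part of PyWarm
    return x*torch.sigmoid(x)

x = W.conv(x, 20, 5, activation=swish) # roughly swish(W.conv(x, 20, 5))
```
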
116 | For deeper neural networks, see additional [examples](https://blue-season.github.io/pywarm/docs/example/).
117 |
118 | ----
119 | ## Installation
120 |
121 | pip3 install pywarm
122 |
123 | ----
124 | ## Quick start: 30 seconds to PyWarm
125 |
126 | If you already have experience with PyTorch, using PyWarm is very straightforward:
127 |
128 | - First, import PyWarm in your model file:
129 | ```Python
130 | import warm
131 | import warm.functional as W
132 | ```
133 |
134 | - Second, remove child module definitions in the model's `__init__()` method.
135 |   Instead, use `W.conv`, `W.linear` ... etc. in the model's `forward()` method,
136 |   just as you would use the `torch.nn.functional` calls `F.max_pool2d`, `F.relu` ... etc.
137 |
138 | For example, instead of writing:
139 |
140 | ```Python
141 | # Torch
142 | class MyModule(nn.Module):
143 | def __init__(self):
144 | super().__init__()
145 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
146 | # other child module definitions
147 | def forward(self, x):
148 | x = self.conv1(x)
149 | # more forward steps
150 | ```
151 |
152 | - You can now write in the warm way:
153 |
154 | ```Python
155 | # Warm
156 | class MyWarmModule(nn.Module):
157 | def __init__(self):
158 | super().__init__()
159 | warm.up(self, input_shape_or_data)
160 | def forward(self, x):
161 | x = W.conv(x, out_channels, kernel_size) # no in_channels needed
162 | # more forward steps
163 | ```
164 |
165 | - Finally, don't forget to warmify the model by adding
166 |
167 | `warm.up(self, input_shape_or_data)`
168 |
169 | at the end of the model's `__init__()` method. You need to supply
170 | `input_shape_or_data`, which is either a tensor of input data,
171 | or just its shape, e.g. `[2, 1, 28, 28]` for MNIST inputs.
172 |
173 | The model is now ready to use, just like any other PyTorch models.
174 |
175 | Check out the [tutorial](https://blue-season.github.io/pywarm/docs/tutorial/)
176 | and [examples](https://blue-season.github.io/pywarm/docs/example/) if you want to learn more!
177 |
178 | ----
179 | ## Testing
180 |
181 | Clone the repository first, then
182 |
183 | cd pywarm
184 | pytest -v
185 |
186 | ----
187 | ## Documentation
188 |
189 | Documentations are generated using the excellent [Portray](https://timothycrosley.github.io/portray/) package.
190 |
191 | - [Examples](https://blue-season.github.io/pywarm/docs/example/)
192 |
193 | - [Tutorial](https://blue-season.github.io/pywarm/docs/tutorial/)
194 |
195 | - [API reference](https://blue-season.github.io/pywarm/reference/warm/functional/)
196 |
--------------------------------------------------------------------------------
/docs/tutorial.md:
--------------------------------------------------------------------------------
1 |
2 | # PyWarm Basic Tutorial
3 |
4 | ## Import
5 |
6 | To get started, first import PyWarm in your project:
7 |
8 | ```Python
9 | import warm
10 | import warm.functional as W
11 | ```
12 |
13 | ## Rewrite
14 |
15 | Now you can replace child module definitions with function calls.
16 | For example, instead of:
17 |
18 | ```Python
19 | # Torch
20 | class MyModule(nn.Module):
21 |
22 | def __init__(self):
23 | super().__init__()
24 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
25 | # other child module definitions
26 |
27 | def forward(self, x):
28 | x = self.conv1(x)
29 | # more forward steps
30 | ```
31 |
32 | You now use the warm functions:
33 |
34 | ```Python
35 | # Warm
36 | class MyWarmModule(nn.Module):
37 |
38 | def __init__(self):
39 | super().__init__()
40 | warm.up(self, input_shape_or_data)
41 |
42 | def forward(self, x):
43 | x = W.conv(x, out_channels, kernel_size) # no in_channels needed
44 | # more forward steps
45 | ```
46 |
47 | Notice the `warm.up(self, input_shape_or_data)` at the end of the `__init__()` method.
48 | It is required so that PyWarm can infer the shapes of all intermediate steps and set up trainable parameters.
49 | The only argument `input_shape_or_data` can either be a tensor, e.g. `torch.randn(2, 1, 28, 28)`,
50 | or just the shape, e.g. `[2, 1, 28, 28]` for the model inputs. If the model has multiple inputs,
51 | you may supply them in a list or a dictionary, or pass one shape (or tensor) per input, as in the sketch below.
52 |
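Below is a minimal sketch of a two-input model, passing one shape per input to `warm.up` (the same pattern the Transformer example in the docs uses via `warm.up(self, *shape)`). The module and shapes are made up for illustration:

```Python
import torch.nn as nn
import warm
import warm.functional as W

class TwoInput(nn.Module):

    def __init__(self):
        super().__init__()
        warm.up(self, [2, 3, 32], [2, 5, 32]) # one shape per forward() input

    def forward(self, x, y):
        x = W.conv(x, 8, 3, padding=1) # becomes conv_1, in_channels inferred as 3
        y = W.conv(y, 8, 3, padding=1) # becomes conv_2, in_channels inferred as 5
        return x+y
```
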
53 | Although it is recommended that you attach `warm.up()` to the end of the `__init__()` of your model, you can actually
54 | use it on a model instance outside of the class definition, like a normal function call:
55 |
56 | ```Python
57 | class MyWarmModule(nn.Module):
58 |
59 | def __init__(self):
60 | super().__init__() # no warm.up here
61 |
62 | def forward(self, x):
63 | x = W.conv(x, 10, 3)
64 | # forward step, powered by PyWarm
65 |
66 |
67 | model = MyWarmModule() # call warm.up outside of the module definition
68 |
69 | warm.up(model, [2, 1, 28, 28])
70 | ```
71 |
72 | **Note**: If the model contains `batch_norm` layers, you need to set the `Batch` dimension to at least 2.
73 |
74 | # Advanced Topics
75 |
76 | ## Default shapes
77 |
78 | PyWarm has a unified functional interface: by default, all functions accept and return tensors with shape
79 | `(Batch, Channel, *)`, where `*` is any number of additional dimensions. For example, for 2d images,
80 | the `*` usually stands for `(Height, Width)`, and for 1d time series, the `*` means `(Time,)`.
81 |
82 | This convention is optimized for the performance of Convolutional networks. It may become less efficient if your
83 | model relies heavily on dense (Linear) or recurrent (LSTM, GRU) layers. You can use different input and
84 | output shapes by specifying `in_shape`, `out_shape` keyword arguments in the function calls. These keywords
85 | accept only letters `'B'`, `'C'` and `'D'`, which stand for `Batch`, `Channel`, and `*` (extra Dimensions)
86 | respectively. So for example if for a 1d time series you want to have `(Time, Batch, Channel)` as the output shape,
87 | you can specify `out_shape='DBC'`.
88 |
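A minimal sketch of the `(Time, Batch, Channel)` case described above (the module and sizes are made up for illustration):

```Python
import torch
import torch.nn as nn
import warm
import warm.functional as W

class SeqModel(nn.Module):

    def __init__(self):
        super().__init__()
        warm.up(self, [7, 2, 16]) # input laid out as (Time, Batch, Channel)

    def forward(self, x):
        # declare the input layout, and request the same layout back
        return W.linear(x, 64, in_shape='DBC', out_shape='DBC')

y = SeqModel()(torch.randn(7, 2, 16))
assert list(y.shape) == [7, 2, 64]
```
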
89 | ## Dimensional awareness
90 |
91 | PyWarm functions can automatically identify 1d, 2d and 3d input data, so the same function can be used on different
92 | dimensional cases. For example, the single `W.conv` is enough to replace `nn.Conv1d, nn.Conv2d, nn.Conv3d`.
93 | Similarly, you don't need `nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d` for different inputs; a single `W.batch_norm`
94 | can replace them all.
95 |
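A small sketch of one call handling both 1d and 2d data (the class and sizes are made up):

```Python
import torch.nn as nn
import warm
import warm.functional as W

class AnyDim(nn.Module):

    def __init__(self, shape):
        super().__init__()
        warm.up(self, shape)

    def forward(self, x):
        return W.conv(x, 8, 3, padding=1) # picks Conv1d/2d/3d from the number of dims of x

m1 = AnyDim([2, 3, 32]) # 1d data, nn.Conv1d under the hood
m2 = AnyDim([2, 3, 32, 32]) # 2d data, nn.Conv2d under the hood
```
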
96 | ## Shape inference
97 |
98 | Many neural network layers will perform a transformation of shapes. For example, after a convolution operation,
99 | the shape is changed from `(Batch, ChannelIn, *)` to `(Batch, ChannelOut, *)`. PyTorch nn Modules require the user to
100 | keep track of both `in_channels` and `out_channels`. PyWarm relieves this pain by inferring the `in_channels` for you,
101 | so you can focus more on the nature of your tasks, rather than chores.
102 |
103 | ## Argument passdown
104 |
105 | If the signature of a PyWarm function does not cover all possible arguments of its torch nn Module counterpart, it will pass down
106 | additional keyword arguments to the underlying nn Module. For example, if you want to specify strides of 2 for a conv layer,
107 | just use `W.conv(..., stride=2)`. The only thing to remember is that you have to specify the full keyword, instead of
108 | relying on the position of arguments.
109 |
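For instance, inside a warmed model's `forward()` (all of these keywords are standard `nn.Conv2d` arguments that PyWarm simply passes through):

```Python
x = W.conv(x, 64, 3, stride=2, padding=1, dilation=1, groups=1)
```
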
110 | ## Parameter initialization per layer
111 |
112 | Unlike PyTorch's approach, parameter initialization can be specified directly in PyWarm's functional interface.
113 | For example:
114 |
115 | ```Python
116 | x = W.conv(x, 20, 1, init_weight='kaiming_uniform_')
117 | ```
118 | This makes it easier to create layer-specific initializations in PyWarm. You no longer need to go through
119 | `self.modules()` and `self.parameters()` to create custom initializations.
120 |
121 | By default, PyWarm will look into `torch.nn.init` for initialization function names.
122 | Alternatively, you may just specify a callable, or a tuple `(fn, kwargs)` if the callable takes additional keyword arguments, as in the sketch below.
123 |
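For instance, inside a warmed model's `forward()`, both of these are valid (a sketch; the numbers are arbitrary):

```Python
import torch.nn as nn

x = W.conv(x, 20, 3, init_weight=nn.init.orthogonal_) # a callable
x = W.conv(x, 20, 3, init_weight=(nn.init.normal_, {'mean':0.0, 'std':0.02})) # a (fn, kwargs) tuple
```
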
124 | If the initialization is not specified or `None` is used, the corresponding layer will get default initializations as used
125 | in torch nn modules.
126 |
127 | ## Apply activation nonlinearity to the output
128 |
129 | PyWarm's functional interface accepts an optional keyword argument `activation`, whose value is either a callable,
130 | or the name of an activation (nonlinearity) function found in `torch.nn.functional` or in `torch`.
131 | By default no activation is applied.
132 |
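A few equivalent forms, used inside a warmed model's `forward()` (a sketch):

```Python
import torch
import torch.nn.functional as F

x = W.linear(x, 64, activation='relu') # name looked up in torch.nn.functional, then torch
x = W.linear(x, 64, activation=torch.tanh) # any callable works
x = W.linear(x, 64, activation=(F.leaky_relu, {'negative_slope':0.1})) # a (fn, kwargs) tuple
```
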
133 | ## Mix and Match
134 |
135 | You are not limited to only use PyWarm's functional interface. It is completely ok to mix and match the old
136 | PyTorch way of child module definitions with PyWarm's function API. For example:
137 |
138 | ```Python
139 | class MyModel(nn.Module):
140 |
141 | def __init__(self):
142 | super().__init__()
143 | # other stuff
144 | self.conv1 = nn.Conv2d(2, 30, 7, padding=3)
145 | # other stuff
146 |
147 | def forward(self, x):
148 | y = F.relu(self.conv1(x))
149 | y = W.conv(y, 40, 3, activation='relu')
150 | ```
151 |
152 | ## Custom layer names
153 |
154 | Normally you do not have to specify layer names when using the functional API.
155 | PyWarm will track and count usage for the layer type and automatically assign names for you. For example,
156 | subsequent convolutional layer calls via `W.conv` will create `conv_1`, `conv_2`, ... etc. in the parent module.
157 |
158 | Nevertheless, if you want to ensure certain layers have particular names, you can pass the `name='my_name'`
159 | keyword argument in the call.
160 |
161 | Alternatively, if you still want PyWarm to count usage and increment the ordinal for you, but only want to customize
162 | the base type name, you can use the `base_name='my_prefix'` keyword instead. The created modules will then have
163 | names like `my_prefix_1`, `my_prefix_2` in the parent module.
164 |
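A short sketch, inside a warmed model's `forward()`:

```Python
x = W.conv(x, 32, 3, name='stem') # registered on the parent exactly as 'stem'
x = W.conv(x, 64, 3, base_name='block') # auto-counted: 'block_1'
x = W.conv(x, 64, 3, base_name='block') # 'block_2'
```
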
165 | See the PyWarm [resnet example in the examples folder](https://github.com/blue-season/pywarm/blob/master/examples/resnet.py)
166 | on how to use these features to load pre-trained model parameters into PyWarm models.
167 |
--------------------------------------------------------------------------------
/warm/engine.py:
--------------------------------------------------------------------------------
1 | # 08-26-2019;
2 | """
3 | PyWarm engine to the functional interface.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | from warm import util
9 |
10 |
11 | _DEFAULT_PARENT_MODULE = None
12 |
13 |
14 | def set_default_parent(parent):
15 | """ Set the default `parent` module. """
16 | global _DEFAULT_PARENT_MODULE
17 | _DEFAULT_PARENT_MODULE = parent
18 |
19 |
20 | def get_default_parent():
21 | """ Get the default `parent` module. """
22 | global _DEFAULT_PARENT_MODULE
23 | return _DEFAULT_PARENT_MODULE
24 |
25 |
26 | def _auto_name(name, parent):
27 | """ Track the count of reference to `name` from `parent`. """
28 | if not is_ready(parent):
29 | parent._pywarm_auto_name_dict = {}
30 | def _hook(model, x):
31 | model._pywarm_auto_name_dict = {}
32 | parent._pywarm_forward_pre_hook = parent.register_forward_pre_hook(_hook)
33 | track = parent._pywarm_auto_name_dict
34 | if name not in track:
35 | track[name] = 0
36 | track[name] += 1
37 | return f'{name}_{track[name]}'
38 |
39 |
40 | def prepare_model_(model, *data, device='cpu'):
41 | """ Initialize all childen modules defined by `warm` in a parent `model`.\n
42 | - `model: Module`; The parent model to be prepared.
43 | - `data: Tensor, or list of int`; A batch of data with the correct shape and type to be forwarded by model.
44 | `data` can also be a list of `int`, in which case it is interpreted as the shape of the input data.
45 | - `device: str, or torch.device`; Should be the same for `model` and `data`. Default: `'cpu'`.
46 | - `return: Module`; The prepared model, with all children modules defined by `warm` initialized. """
47 | _auto_name('', model)
48 | set_default_parent(model)
49 | def _prep_data(d):
50 | if isinstance(d, (np.ndarray, torch.Tensor)):
51 | return torch.as_tensor(d).to(device)
52 | elif isinstance(d, (list, tuple)):
53 | if all(isinstance(x, int) for x in d):
54 | return torch.randn(*d, device=device)
55 | return [_prep_data(x) for x in d]
56 | elif isinstance(d, dict):
57 | return {k:_prep_data(v) for k, v in d.items()}
58 | with torch.no_grad():
59 | is_training = model.training
60 | data = [_prep_data(d) for d in data]
61 | model.eval()
62 | model.to(device)
63 | model(*data)
64 | model.train(is_training)
65 | return model
66 |
67 |
68 | def is_ready(model):
69 | """ Check if a `model` is prepared. """
70 | return hasattr(model, '_pywarm_forward_pre_hook')
71 |
72 |
73 | def activate(x, spec, lookup=None):
74 | """ Activate tensors with given nonlinearity `spec`ification.\n
75 | - `x: Tensor or list of Tensor`; The tensors to be initialized.
76 | - `spec: str or callable or 2-tuple`; If a `str`, should be one of the nonlinearity functions contained in
77 | `torch.nn.functional` or `torch`. If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`.
78 | If a 2-`tuple`, it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
79 | - `lookup: None or list of module`; Parent modules to look for `spec`. If `None`, `[nn.functional, torch]` is used.
80 | - `return: Tensor or list of Tensor`; Activation results. """
81 | if spec is None:
82 | return x
83 | lookup = lookup or [nn.functional, torch]
84 | if isinstance(spec, str):
85 | for look in lookup:
86 | try:
87 | spec = getattr(look, spec)
88 | break
89 |             except AttributeError:
90 | pass
91 | if isinstance(spec, str):
92 | raise ValueError(f'Unknown spec {spec}.')
93 | if callable(spec):
94 | spec = (spec, {})
95 | fn, kw = spec
96 | if isinstance(x, (list, tuple)):
97 | return [fn(y, **kw) for y in x]
98 | return fn(x, **kw)
99 |
100 |
101 | def initialize_(x, spec):
102 | """ Initialize parameters with given nonlinearity `spec`ification.\n
103 | - `x: Tensor or list of Tensor`; The tensors to be initialized.
104 | - `spec: str or callable or 2-tuple`; If a `str`, should be one of the nonlinearity functions contained in
105 | `torch.nn.init`. If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
106 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. """
107 | activate(x, spec, lookup=[nn.init])
108 |
109 |
110 | def permute(x, in_shape='BCD', out_shape='BCD', **kw):
111 | """ Permute the dimensions of a tensor.\n
112 | - `x: Tensor`; The nd-tensor to be permuted.
113 | - `in_shape: str`; The dimension shape of `x`. Can only have characters `'B'` or `'C'` or `'D'`,
114 | which stand for Batch, Channel, or extra Dimensions. The default value `'BCD'` means
115 |         the input tensor `x` should be at least 2-d with shape `(Batch, Channel, Dim0, Dim1, Dim2, ...)`,
116 | where `Dim0, Dim1, Dim2 ...` stand for any number of extra dimensions.
117 | - `out_shape: str or tuple or None`; The dimension shape of returned tensor. Default: `'BCD'`.
118 | If a `str`, it is restricted to the same three characters `'B'`, `'C'` or `'D'` as the `in_shape`.
119 | If a `tuple`, `in_shape` is ignored, and simply `x.permute(out_shape)` is returned.
120 |         If `None`, no permutation will be performed.
121 | - `return: Tensor`; Permuted nd-tensor. """
122 | if (in_shape == out_shape) or (out_shape is None):
123 | return x
124 | if isinstance(out_shape, (list, tuple, torch.Size)):
125 | return x.permute(*out_shape)
126 | if isinstance(in_shape, str) and isinstance(out_shape, str) :
127 |         assert set(in_shape) == set(out_shape) <= {'B', 'C', 'D'}, 'In and out shapes must have the same set of chars among B, C, and D.'
128 | in_shape = in_shape.lower().replace('d', '...')
129 | out_shape = out_shape.lower().replace('d', '...')
130 | return torch.einsum(f'{in_shape}->{out_shape}', x)
131 | return x
132 |
133 |
134 | def unused_kwargs(kw):
135 | """ Filter out entries used by `forward` and return the rest. """
136 | fn_kw = dict(base_class=None,
137 | base_name=None, name=None, base_arg=None, base_kw=None, parent=None,
138 | infer_kw=None, in_shape='BCD', base_shape=None, out_shape='BCD', tuple_out=False,
139 | forward_arg=None, forward_kw=None, initialization=None, activation=None, )
140 | return {k:v for k, v in kw.items() if k not in fn_kw}
141 |
142 |
143 | def forward(x, base_class,
144 | base_name=None, name=None, base_arg=None, base_kw=None, parent=None,
145 | infer_kw=None, in_shape='BCD', base_shape='BCD', out_shape='BCD', tuple_out=False,
146 | forward_arg=None, forward_kw=None, initialization=None, activation=None, **kw):
147 | """ A forward template that creates child instances at the first time it is called.\n
148 | - `x: Tensor`; The nd-tensor to be forwarded.
149 | - `base_class: Module`; A child `torch.nn.Module` that will be created at the first time this function is called.
150 | - `base_name: str`; Name for the `base_class`. Default: base_class name.
151 | - `name: str`; Name for the child module instance. Default: class name plus ordinal.
152 | - `base_arg: tuple`; Positional args to be passed to create the child module instance. Default: None.
153 | - `base_kw: dict`; KWargs to be passed to create the child module instance. Default: None.
154 | - `parent: Module`; The parent of the child instance. Default: None. If `None`, will use `get_default_parent`.
155 |     - `infer_kw: dict`; Key should be valid for the child instance. Value should be a character,
156 | one of `'B'`, `'C'`, or `'D'` (see `permute`), to substitute for a dimension of `x`. Default: None.
157 | - `in_shape: str`; The dimension shape of `x`. See also `permute`. Default: `'BCD'`.
158 | - `base_shape: str`; The dimension shape required by the child module. See also `permute`. Default: `'BCD'`.
159 | - `out_shape: str or tuple or None`; The dimension shape of returned tensor. See also `permute`. Default: `'BCD'`.
160 | - `tuple_out: bool`; Whether the child module will return more than 1 outputs (e.g. `nn.RNN`).
161 | If `True`, the returned value of the function will be a tuple containing all outputs. Default: False.
162 | - `forward_arg: tuple`; positional args to be passed when calling the child module instance. Default: None.
163 | - `forward_kw: dict`; KWargs to be passed when calling the child module instance. Default: None.
164 | - `initialization: dict`; Keys are name of parameters to initialize. Values are init specs, which can be
165 | a, `str`, a `callable`, or `2-tuple`; See the `spec` argument of `initialize_` for details. Default: None.
166 | - `activation: str or callable or 2-tuple`; See the `spec` argument of `activate`. Default: None.
167 | - `return: Tensor or tuple`; If `tuple_out` is `True`, the returned value will be a `tuple`. """
168 | parent = parent or get_default_parent()
169 | if name is None:
170 | base_name = base_name or util.camel_to_snake(base_class.__name__)
171 | name = _auto_name(base_name, parent)
172 | if name not in parent._modules:
173 | if infer_kw is not None:
174 | shape = in_shape
175 | if 'D' in shape:
176 | shape = list(shape)
177 | shape[shape.index('D')] = 'D'*(x.ndim-len(shape)+1)
178 | shape = ''.join(shape)
179 | infer_kw = {
180 | k:x.shape[shape.find(v) if isinstance(v, str) else v]
181 | for k, v in infer_kw.items()}
182 | base = base_class(*(base_arg or []), **(infer_kw or {}), **(base_kw or {}), )
183 | parent.add_module(name, base)
184 | if initialization is not None:
185 | s = parent.state_dict()
186 | for k, v in initialization.items():
187 | initialize_(s[name+'.'+k], v)
188 | x = permute(x, in_shape, base_shape)
189 | y = parent._modules[name](x, *(forward_arg or []), **(forward_kw or {}))
190 | r = []
191 | if isinstance(y, tuple):
192 | y, *r = y
193 | y = permute(y, base_shape, out_shape)
194 | y = activate(y, activation)
195 | if tuple_out:
196 | return (y, *r)
197 | return y
198 |
199 |
200 | import functools
201 | def namespace(f):
202 | """ After decoration, the function name and call count will be appended to the `name` kw. """
203 | @functools.wraps(f)
204 | def _wrapped(*arg, **kw):
205 | parent = kw.get('parent', get_default_parent())
206 | name = kw.get('name', '')
207 | name = '_warmns_' + name + ('-' if name else '') + f.__name__
208 | name = _auto_name(name, parent)
209 | kw['name'] = name.replace('_warmns_', '')
210 | return f(*arg, **kw)
211 | return _wrapped
212 |
--------------------------------------------------------------------------------
/docs/example.md:
--------------------------------------------------------------------------------
1 |
2 | # PyWarm Examples
3 |
4 | ## ResNet
5 |
6 | A more detailed example, the ResNet18 network defined in PyWarm and vanilla PyTorch:
7 |
8 | ``` Python tab="Warm" linenums="1"
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import warm
13 | import warm.functional as W
14 |
15 |
16 | def basic(x, size, stride):
17 | y = W.conv(x, size, 3, stride=stride, padding=1, bias=False)
18 | y = W.batch_norm(y, activation='relu')
19 | y = W.conv(y, size, 3, stride=1, padding=1, bias=False)
20 | y = W.batch_norm(y)
21 | if y.shape[1] != x.shape[1]: # channel size mismatch, needs projection
22 | x = W.conv(x, y.shape[1], 1, stride=stride, bias=False)
23 | x = W.batch_norm(x)
24 | y = y+x # residual shortcut connection
25 | return F.relu(y)
26 |
27 |
28 | def stack(x, num_block, size, stride, block=basic):
29 | for s in [stride]+[1]*(num_block-1):
30 | x = block(x, size, s)
31 | return x
32 |
33 |
34 | class ResNet(nn.Module):
35 |
36 | def __init__(self, block=basic,
37 | stack_spec=((2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2))):
38 | super().__init__()
39 | self.block = block
40 | self.stack_spec = stack_spec
41 | warm.up(self, [2, 3, 32, 32])
42 |
43 | def forward(self, x):
44 | y = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
45 | y = W.batch_norm(y, activation='relu')
46 | y = F.max_pool2d(y, 3, stride=2, padding=1)
47 | for spec in self.stack_spec:
48 | y = stack(y, *spec, block=self.block)
49 | y = F.adaptive_avg_pool2d(y, 1)
50 | y = torch.flatten(y, 1)
51 | y = W.linear(y, 1000)
52 | return y
53 |
54 |
55 | resnet18 = ResNet()
56 | ```
57 |
58 | ``` Python tab="Torch" linenums="1"
59 | # code based on torchvision/models/resnet.py
60 | import torch
61 | import torch.nn as nn
62 | import torch.nn.functional as F
63 |
64 |
65 | def conv3x3(size_in, size_out, stride=1):
66 | return nn.Conv2d(size_in, size_out, kernel_size=3, stride=stride,
67 | padding=1, groups=1, bias=False, dilation=1, )
68 |
69 |
70 | def conv1x1(size_in, size_out, stride=1):
71 | return nn.Conv2d(size_in, size_out, kernel_size=1, stride=stride,
72 | padding=0, groups=1, bias=False, dilation=1, )
73 |
74 |
75 | class BasicBlock(nn.Module):
76 |
77 | expansion = 1
78 |
79 | def __init__(self, size_in, size_out, stride=1, downsample=None):
80 | super().__init__()
81 | self.conv1 = conv3x3(size_in, size_out, stride)
82 | self.bn1 = nn.BatchNorm2d(size_out)
83 | self.relu = nn.ReLU(inplace=True)
84 | self.conv2 = conv3x3(size_out, size_out)
85 | self.bn2 = nn.BatchNorm2d(size_out)
86 | self.downsample = downsample
87 |
88 | def forward(self, x):
89 | identity = x
90 | y = self.conv1(x)
91 | y = self.bn1(y)
92 | y = self.relu(y)
93 | y = self.conv2(y)
94 | y = self.bn2(y)
95 | if self.downsample is not None:
96 | identity = self.downsample(x)
97 | y += identity
98 | y = self.relu(y)
99 | return y
100 |
101 |
102 | class ResNet(nn.Module):
103 |
104 | def __init__(self,
105 | block=BasicBlock, num_block=[2, 2, 2, 2]):
106 | super().__init__()
107 | self.size_in = 64
108 | self.conv1 = nn.Conv2d(3, self.size_in, kernel_size=7, stride=2,
109 | padding=3, bias=False)
110 | self.bn1 = nn.BatchNorm2d(self.size_in)
111 | self.relu = nn.ReLU(inplace=True)
112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
113 | self.stack1 = self._make_stack(block, 64, num_block[0], 1)
114 | self.stack2 = self._make_stack(block, 128, num_block[1], 2)
115 | self.stack3 = self._make_stack(block, 256, num_block[2], 2)
116 | self.stack4 = self._make_stack(block, 512, num_block[3], 2)
117 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
118 | self.fc = nn.Linear(512, 1000)
119 |
120 | def _make_stack(self, block, size_out, num_blocks, stride):
121 | downsample = None
122 | if stride != 1 or self.size_in != size_out:
123 | downsample = nn.Sequential(
124 | conv1x1(self.size_in, size_out, stride),
125 | nn.BatchNorm2d(size_out), )
126 | stacks = []
127 |         for stride in [stride]+[1]*(num_blocks-1):
128 |             stacks.append(
129 |                 block(self.size_in, size_out, stride, downsample))
130 |             self.size_in, downsample = size_out, None # downsample only for the first block
131 | return nn.Sequential(*stacks)
132 |
133 | def forward(self, x):
134 | y = self.conv1(x)
135 | y = self.bn1(y)
136 | y = self.relu(y)
137 | y = self.maxpool(y)
138 | y = self.stack1(y)
139 | y = self.stack2(y)
140 | y = self.stack3(y)
141 | y = self.stack4(y)
142 | y = self.avg_pool(y)
143 | y = torch.flatten(y, 1)
144 | y = self.fc(y)
145 | return y
146 |
147 |
148 | resnet18 = ResNet()
149 | ```
150 |
151 | - The PyWarm version significantly reduces code repetition compared to the vanilla PyTorch version.
152 |
153 | - Note that when warming the model via `warm.up(self, [2, 3, 32, 32])`,
154 |   we set the first `Batch` dimension to 2 because the model uses `batch_norm`,
155 |   which will not work when `Batch` is 1.
156 |
157 | ----
158 |
159 | ## MobileNet
160 |
161 | ``` Python tab="Warm" linenums="1"
162 | import torch
163 | import torch.nn as nn
164 | import torch.nn.functional as F
165 | import warm
166 | import warm.functional as W
167 |
168 |
169 | def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1):
170 | x = W.conv(x, size, kernel, padding=(kernel-1)//2,
171 | stride=stride, groups=groups, bias=False, )
172 | return W.batch_norm(x, activation='relu6')
173 |
174 |
175 | def bottleneck(x, size_out, stride, expand):
176 | size_in = x.shape[1]
177 | size_mid = size_in*expand
178 | y = conv_bn_relu(x, size_mid, kernel=1) if expand > 1 else x
179 | y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid)
180 | y = W.conv(y, size_out, kernel=1, bias=False)
181 | y = W.batch_norm(y)
182 | if stride == 1 and size_in == size_out:
183 | y += x # residual shortcut
184 | return y
185 |
186 |
187 | def conv1x1(x, *arg):
188 | return conv_bn_relu(x, *arg, kernel=1)
189 |
190 |
191 | def pool(x, *arg):
192 | return x.mean([2, 3])
193 |
194 |
195 | def classify(x, size, *arg):
196 | x = W.dropout(x, rate=0.2)
197 | return W.linear(x, size)
198 |
199 |
200 | default_spec = (
201 | (None, 32, 1, 2, conv_bn_relu), # t, c, n, s, operator
202 | (1, 16, 1, 1, bottleneck),
203 | (6, 24, 2, 2, bottleneck),
204 | (6, 32, 3, 2, bottleneck),
205 | (6, 64, 4, 2, bottleneck),
206 | (6, 96, 3, 1, bottleneck),
207 | (6, 160, 3, 2, bottleneck),
208 | (6, 320, 1, 1, bottleneck),
209 | (None, 1280, 1, 1, conv1x1),
210 | (None, None, 1, None, pool),
211 | (None, 1000, 1, None, classify), )
212 |
213 |
214 | class MobileNetV2(nn.Module):
215 |
216 | def __init__(self):
217 | super().__init__()
218 | warm.up(self, [2, 3, 224, 224])
219 |
220 | def forward(self, x):
221 | for t, c, n, s, op in default_spec:
222 | for i in range(n):
223 | stride = s if i == 0 else 1
224 | x = op(x, c, stride, t)
225 | return x
226 |
227 |
228 | net = MobileNetV2()
229 | ```
230 |
231 | ``` Python tab="Torch" linenums="1"
232 | # code based on torchvision/models/mobilenet.py
233 | import torch
234 | import torch.nn as nn
235 | import torch.nn.functional as F
236 |
237 |
238 | class ConvBNReLU(nn.Sequential):
239 |
240 | def __init__(self, in_planes, out_planes,
241 | kernel_size=3, stride=1, groups=1):
242 | padding = (kernel_size-1)//2
243 | super(ConvBNReLU, self).__init__(
244 | nn.Conv2d(in_planes, out_planes, kernel_size,
245 | stride, padding, groups=groups, bias=False),
246 | nn.BatchNorm2d(out_planes),
247 | nn.ReLU6(inplace=True), )
248 |
249 |
250 | class BottleNeck(nn.Module):
251 |
252 | def __init__(self, inp, oup, stride, expand_ratio):
253 | super().__init__()
254 | self.stride = stride
255 | assert stride in [1, 2]
256 | hidden_dim = int(round(inp * expand_ratio))
257 | self.use_res_connect = self.stride == 1 and inp == oup
258 | layers = []
259 | if expand_ratio != 1:
260 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
261 | layers.extend([
262 | ConvBNReLU(hidden_dim, hidden_dim,
263 | stride=stride, groups=hidden_dim),
264 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
265 | nn.BatchNorm2d(oup), ])
266 | self.conv = nn.Sequential(*layers)
267 |
268 | def forward(self, x):
269 | if self.use_res_connect:
270 | return x + self.conv(x)
271 | else:
272 | return self.conv(x)
273 |
274 |
275 | default_spec = [
276 | [1, 16, 1, 1], # t, c, n, s
277 | [6, 24, 2, 2],
278 | [6, 32, 3, 2],
279 | [6, 64, 4, 2],
280 | [6, 96, 3, 1],
281 | [6, 160, 3, 2],
282 | [6, 320, 1, 1], ]
283 |
284 |
285 | class MobileNetV2(nn.Module):
286 |
287 | def __init__(self):
288 | super().__init__()
289 | input_channel = 32
290 | last_channel = 1280
291 | features = [ConvBNReLU(3, input_channel, stride=2)]
292 | for t, c, n, s in default_spec:
293 | output_channel = c
294 | for i in range(n):
295 | stride = s if i == 0 else 1
296 | features.append(
297 | BottleNeck(
298 | input_channel, output_channel,
299 | stride, expand_ratio=t))
300 | input_channel = output_channel
301 | features.append(ConvBNReLU(input_channel,
302 | last_channel, kernel_size=1))
303 | self.features = nn.Sequential(*features)
304 | self.classifier = nn.Sequential(
305 | nn.Dropout(0.2),
306 | nn.Linear(last_channel, 1000), )
307 |
308 | def forward(self, x):
309 | x = self.features(x)
310 | x = x.mean([2, 3])
311 | x = self.classifier(x)
312 | return x
313 |
314 |
315 | net = MobileNetV2()
316 | ```
317 |
318 | ## Transformer
319 |
320 | ```Python
321 | """
322 | The Transformer model from paper Attention is all you need.
323 | The Transformer instance accepts two inputs:
324 | x is Tensor with shape (Batch, Channel, LengthX).
325 | usually a source sequence from embedding (in such cases,
326 | Channel equals the embedding size).
327 | y is Tensor with shape (Batch, Channel, lengthY).
328 | usually a target sequence, also from embedding.
329 | **kw is passed down to inner components.
330 | """
331 | import torch
332 | import torch.nn as nn
333 | import torch.nn.functional as F
334 | import warm
335 | import warm.functional as W
336 |
337 |
338 | def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw):
339 | def split_heads(t):
340 | return t.reshape(batch, num_head, size//num_head, t.shape[-1])
341 | def merge_heads(t):
342 | return t.reshape(batch, -1, t.shape[-1])
343 | if y is None:
344 | y = x # self attention
345 | batch, size = x.shape[:2]
346 | assert size%num_head == 0, 'num_head must be a divisor of size.'
347 | assert y.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.'
348 | q = W.linear(x, size) # query
349 | k = W.linear(y, size) # key
350 | v = W.linear(y, size) # value
351 | q = split_heads(q)
352 | k = split_heads(k)
353 | v = split_heads(v)
354 | q *= (size//num_head)**(-0.5)
355 | a = q.transpose(2, 3).contiguous().matmul(k) # attention weights
356 | if mask is not None:
357 | a += mask
358 | a = F.softmax(a, dim=-1)
359 | a = W.dropout(a, dropout)
360 | x = v.matmul(a.transpose(2, 3).contiguous())
361 | x = merge_heads(x)
362 | return W.linear(x, size)
363 |
364 |
365 | def feed_forward(x, size_ff=2048, dropout=0.1, **kw):
366 | y = W.linear(x, size_ff, activation='relu')
367 | y = W.dropout(y, dropout)
368 | return W.linear(y, x.shape[1])
369 |
370 |
371 | def residual_add(x, layer, dropout=0.1, **kw):
372 | y = W.layer_norm(x)
373 | y = layer(y, **kw)
374 | y = W.dropout(y, dropout)
375 | return x+y
376 |
377 |
378 | def encoder(x, num_encoder=6, **kw):
379 | for i in range(num_encoder):
380 | x = residual_add(x, multi_head_attention, **kw)
381 | x = residual_add(x, feed_forward, **kw)
382 | return W.layer_norm(x)
383 |
384 |
385 | def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw):
386 | for i in range(num_decoder):
387 | y = residual_add(y, multi_head_attention, mask=mask_y, **kw)
388 | y = residual_add(x, multi_head_attention, y=y, mask=mask_x, **kw)
389 | y = residual_add(y, feed_forward, **kw)
390 | return W.layer_norm(y)
391 |
392 |
393 | def transformer(x, y, **kw):
394 | x = encoder(x, **kw)
395 | x = decoder(x, y, **kw)
396 | return x
397 |
398 |
399 | class Transformer(nn.Module):
400 |
401 | def __init__(self, *shape, **kw):
402 | super().__init__()
403 | self.kw = kw
404 | warm.up(self, *shape)
405 |
406 | def forward(self, x, y):
407 | return transformer(x, y, **self.kw)
408 |
409 | ```
410 |
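A usage sketch, assuming the definitions above are in scope; the shapes and `num_head` value are illustrative, and each input is `(Batch, Channel, Length)`:

```Python
model = Transformer([2, 512, 7], [2, 512, 11], num_head=8)
out = model(torch.randn(2, 512, 7), torch.randn(2, 512, 11))
assert list(out.shape) == [2, 512, 7] # with this decoder, the output follows the source length
```
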
411 |
412 | ## EfficientNet
413 |
414 | For a brief overview, check the [blog post](https://blue-season.github.io/efficientnet-in-5-minutes/).
415 |
416 | ```python
417 | """
418 | EfficientNet model from https://arxiv.org/abs/1905.11946
419 | """
420 | import torch
421 | import torch.nn as nn
422 | import torch.nn.functional as F
423 | import warm
424 | import warm.functional as W
425 |
426 |
427 | def swish(x):
428 | return x*torch.sigmoid(x)
429 |
430 |
431 | def squeeze_excitation(x, size_se):
432 | if size_se == 0:
433 | return x
434 | size_in = x.shape[1]
435 |     y = F.adaptive_avg_pool2d(x, 1) # squeeze to (Batch, Channel, 1, 1)
436 |     y = W.conv(y, size_se, 1, activation=swish)
437 |     return x*W.conv(y, size_in, 1, activation='sigmoid') # excite: gate the input channels
438 |
439 |
440 | def drop_connect(x, rate):
441 | if rate == 0:
442 | return x
443 | rate = 1.0-rate
444 | drop_mask = rate + torch.rand([x.shape[0], 1, 1, 1],
445 | device=x.device, requires_grad=False)
446 | return x/rate*drop_mask.floor()
447 |
448 |
449 | def conv_pad_same(x, size, kernel=1, stride=1, **kw):
450 | """ Same padding so that out_size*stride == in_size. """
451 | pad = 0
452 | if kernel != 1 or stride != 1:
453 | in_size, s, k = [torch.as_tensor(v)
454 | for v in (x.shape[2:], stride, kernel)]
455 | pad = torch.max(((in_size+s-1)//s-1)*s+k-in_size, torch.tensor(0))
456 | left, right = pad//2, pad-pad//2
457 | if torch.all(left == right):
458 | pad = tuple(left.tolist())
459 | else:
460 | left, right = left.tolist(), right.tolist()
461 | pad = sum(zip(left[::-1], right[::-1]), ())
462 | x = F.pad(x, pad)
463 | pad = 0
464 | return W.conv(x, size, kernel, stride=stride, padding=pad, **kw)
465 |
466 |
467 | def conv_bn_act(x, size, kernel=1, stride=1, groups=1,
468 | bias=False, eps=1e-3, momentum=1e-2, act=swish):
469 | x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias)
470 | return W.batch_norm(x, eps=eps, momentum=momentum, activation=act)
471 |
472 |
473 | def mb_block(x, size_out, expand=1, kernel=1, stride=1,
474 | se_ratio=0.25, dc_ratio=0.2):
475 | """ Mobilenet Bottleneck Block. """
476 | size_in = x.shape[1]
477 | size_mid = size_in*expand
478 | y = conv_bn_act(x, size_mid, 1) if expand > 1 else x
479 | y = conv_bn_act(y, size_mid, kernel, stride=stride, groups=size_mid)
480 | y = squeeze_excitation(y, int(size_in*se_ratio))
481 | y = conv_bn_act(y, size_out, 1, act=None)
482 | if stride == 1 and size_in == size_out:
483 | y = drop_connect(y, dc_ratio)
484 | y += x
485 | return y
486 |
487 |
488 | spec_b0 = (
489 | # size, expand, kernel, stride, repeat, squeeze_excitation, drop_connect
490 | (16, 1, 3, 1, 1, 0.25, 0.2),
491 | (24, 6, 3, 2, 2, 0.25, 0.2),
492 | (40, 6, 5, 2, 2, 0.25, 0.2),
493 | (80, 6, 3, 2, 3, 0.25, 0.2),
494 | (112, 6, 5, 1, 3, 0.25, 0.2),
495 | (192, 6, 5, 2, 4, 0.25, 0.2),
496 | (320, 6, 3, 1, 1, 0.25, 0.2), )
497 |
498 |
499 | class WarmEfficientNet(nn.Module):
500 | def __init__(self):
501 | super().__init__()
502 | warm.up(self, [2, 3, 32, 32])
503 | def forward(self, x):
504 | x = conv_bn_act(x, 32, kernel=3, stride=2)
505 | for size, expand, kernel, stride, repeat, se, dc in spec_b0:
506 | for i in range(repeat):
507 | stride = stride if i == 0 else 1
508 | x = mb_block(x, size, expand, kernel, stride, se, dc)
509 | x = conv_bn_act(x, 1280)
510 | x = F.adaptive_avg_pool2d(x, 1)
511 | x = W.dropout(x, 0.2)
512 | x = x.view(x.shape[0], -1)
513 | x = W.linear(x, 1000)
514 | return x
515 | ```
516 |
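A usage sketch, assuming the definitions above are in scope (the input shape matches the one passed to `warm.up`):

```Python
net = WarmEfficientNet()
out = net(torch.randn(2, 3, 32, 32))
assert list(out.shape) == [2, 1000]
```
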
--------------------------------------------------------------------------------
/warm/functional.py:
--------------------------------------------------------------------------------
1 | # 08-27-2019;
2 | """
3 | Wraps around various torch.nn Modules to fit into a functional interface.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from warm import engine
10 | from warm import util
11 |
12 |
13 | permute = engine.permute
14 |
15 |
16 | def conv(x, size, kernel, init_weight=None, init_bias=None, bias=True, **kw):
17 | """ Convolution layer.\n
18 | - `x: Tensor`; With shape `(Batch, Channel, *)` where `*` Can be 1d or 2d or 3d.
19 | If 3d, shapes are `(Batch, Channel, Length)`.
20 | If 4d, shapes are `(Batch, Channel, Height, Width)`.
21 | If 5d, shapes are `(Batch, Channel, Depth, Height, Width)`.
22 | - `size: int`; Size of hidden filters, and size of the output channel.
23 | - `kernel: int or tuple`; Size of the convolution kernel.
24 | - `init_weight: None or str or callable`; Initialization specification for the weight tensor.
25 | If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`.
26 | If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
27 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
28 | Default: `None`, and the weight tensor is initialized using `torch.nn.ConvNd`s default scheme.
29 | - `init_bias: None or str or callable`; Same as `init_weight`, but for the bias tensor.
30 | - `bias: bool`; If `True`, adds a learnable bias to the output. Default: `True`.
31 |     - `**kw:dict`; Any additional KWargs are passed down to `torch.nn.ConvNd`, where N can be 1, 2 or 3,
32 | as well as `warm.engine.forward`. Refer to their docs for details. Some of the additional ConvNd arguments:
33 | `stride, padding, dilation, groups`.
34 | - `return: Tensor`; With shape `(Batch, Size, *)` where `*` can be 1d, 2d, 3d that depends on `x`. """
35 | d = x.ndim-3
36 | assert d in [0, 1, 2], 'Incompatible number of dims for input x.'
37 | inferred_kw = dict(
38 | base_name='conv',
39 | base_class=[nn.Conv1d, nn.Conv2d, nn.Conv3d][d],
40 | base_kw={
41 | 'out_channels':size,
42 | 'kernel_size':kernel,
43 | 'bias':bias,
44 | **engine.unused_kwargs(kw), },
45 | infer_kw={'in_channels':'C'},
46 | initialization={'weight':init_weight, **({'bias':init_bias} if bias else {})}, )
47 | return engine.forward(x, **{**inferred_kw, **kw})
48 |
49 |
50 | def linear(x, size, init_weight=None, init_bias=None, bias=True, **kw):
51 | """ Linear transformation layer.\n
52 | - `x: Tensor`; 2d or more, with shapes `(Batch, Channel, *)` where `*` means any number of additional dimensions.
53 | - `size: int`; Size of hidden features, and size of the output channel.
54 | - `init_weight: None or str or callable`; Initialization specification for the weight tensor.
55 |         If a `str`, should be one of the initialization functions contained in `torch.nn.init`.
56 |         If a `callable`, it will be applied to the weight tensor directly, i.e. `spec(weight)`. If a 2-`tuple`,
57 |         it must be of format `(callable, kwargs)`, i.e. `callable(weight, **kwargs)`.
58 | Default: `None`, and the weight tensor is initialized using `torch.nn.Linear`s default scheme.
59 | - `init_bias: None or str or callable`; Same as `init_weight`, but for the bias tensor.
60 | - `bias: bool`; If `True`, adds a learnable bias to the output. Default: `True`.
61 | - `**kw:dict`; Any additional KWargs are passed down to `warm.engine.forward`. Refer to its docs for details.
62 | - `return: Tensor`; With shape `(Batch, Size, *)` where `*` can be 1d, 2d, 3d that depends on `x`. """
63 | inferred_kw = dict(
64 | base_name='linear',
65 | base_class=nn.Linear,
66 | base_kw={'out_features':size, 'bias':bias},
67 | base_shape='BDC',
68 | infer_kw={'in_features':'C'},
69 | initialization={'weight':init_weight, **({'bias':init_bias} if bias else {})}, )
70 | return engine.forward(x, **{**inferred_kw, **kw})
71 |
72 |
73 | def batch_norm(x, **kw):
74 | """ Batch Normalization layer.\n
75 | - `x: Tensor`; 2d or more, with shapes `(Batch, Channel, *)` where `*` means any number of additional dimensions.
76 |     - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.BatchNormNd`, where N can be 1, 2 or 3,
77 | as well as `warm.engine.forward`. Refer to their docs for details. Some of the additional BatchNorm arguments:
78 | `eps, momentum, affine, track_running_stats`.
79 | - `return: Tensor`; Same shape as input `x`. """
80 | d = x.ndim-3
81 | assert d in [0, 1, 2], 'Incompatible number of dims for input x.'
82 | inferred_kw = dict(
83 | base_name='batch_norm',
84 | base_class=[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d][d],
85 | base_kw={'num_features':x.shape[1]}, )
86 | return engine.forward(x, **{**inferred_kw, **kw})
87 |
88 |
89 | def lstm(x, size,
90 | init_weight_hh='orthogonal_', init_weight_ih=None, init_bias_hh=None, init_bias_ih=None,
91 | bias=True, num_layers=1, **kw):
92 | """ Long Short Term Memory layer.\n
93 | - `x: Tensor or tuple`; If tuple, must be of format `(x, (h_0, c_0))`, where `x` is a 3d tensor,
94 | with shapes `(Batch, Channel, Length)`.
95 | - `size: int`; Size of hidden features, and size of the output channel.
96 | - `init_weight_hh: None or str or callable`; Initialization specification for the hidden-hidden weight tensor.
97 |         If a `str`, should be one of the initialization functions contained in `torch.nn.init`.
98 |         If a `callable`, it will be applied to the weight tensor directly, i.e. `spec(weight)`. If a 2-`tuple`,
99 |         it must be of format `(callable, kwargs)`, i.e. `callable(weight, **kwargs)`.
100 | Default: `'orthogonal_'`.
101 | - `init_weight_ih: None or str or callable`; Initialization specification for the input-hidden weight tensor.
102 | Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme.
103 | - `init_bias_hh: None or str or callable`; Initialization specification for the hidden-hidden bias tensor.
104 | Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme.
105 | - `init_bias_ih: None or str or callable`; Initialization specification for the input-hidden bias tensor.
106 | Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme.
107 | - `bias: bool`; If `False`, then the layer does not use `bias_ih` and `bias_hh`. Default: `True`.
108 | - `num_layers: int`; Number of the recurrent layers. Default: 1.
109 | - `tuple_out: bool`; If `True`, the returned value will be a tuple `(out, (h_n, c_n))`. Default: False.
110 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LSTM`, as well as `warm.engine.forward`.
111 | Refer to their docs for details. Some of the additional LSTM arguments: `dropout, bidirectional, batch_first`.
112 |     - `return: Tensor or tuple`; If `tuple_out` is set to true, will return `(out, (h_n, c_n))`, otherwise just `out`.
113 |         `out` has shape `(Batch, Size*Directions, Length)`,
114 | where Directions = 2 if `bidirectional` else 1.
115 | `h_n` is the hidden states with shape `(num_layers*Directions, Batch, Size)`.
116 | `c_n` is the cell states with shape `(num_layers*Directions, Batch, Size)`. """
117 | states = None
118 | if isinstance(x, tuple):
119 | x, *states = x
120 | init = dict(
121 | weight_hh=init_weight_hh,
122 | weight_ih=init_weight_ih,
123 | bias_hh=init_bias_hh,
124 | bias_ih=init_bias_ih, )
125 | inferred_kw = dict(
126 | base_name='lstm',
127 | base_class=nn.LSTM,
128 | base_kw={
129 | 'hidden_size':size,
130 | 'num_layers':num_layers,
131 | **engine.unused_kwargs(kw), },
132 | base_shape='DBC',
133 | infer_kw={'input_size':'C'},
134 | forward_arg=states,
135 | initialization={
136 | f'{k}_l{l}':init[k] for k in ['weight_hh', 'weight_ih']+(['bias_hh', 'bias_ih'] if bias else [])
137 | for l in range(num_layers)}, )
138 | return engine.forward(x, **{**inferred_kw, **kw})
139 |
140 |
141 | def gru(*arg, **kw):
142 | """ Gated Recurrent Unit layer.\n
143 |     - `x: Tensor or tuple`; If tuple, must be of format `(x, h_0)`, where `x` is a 3d tensor,
144 | with shapes `(Batch, Channel, Length)`.
145 | - `size: int`; Size of hidden features, and size of the output channel.
146 | - `init_weight_hh: None or str or callable`; Initialization specification for the hidden-hidden weight tensor.
147 |         If a `str`, should be one of the initialization functions contained in `torch.nn.init`.
148 |         If a `callable`, it will be applied to the weight tensor directly, i.e. `spec(weight)`. If a 2-`tuple`,
149 |         it must be of format `(callable, kwargs)`, i.e. `callable(weight, **kwargs)`.
150 | Default: `'orthogonal_'`.
151 | - `init_weight_ih: None or str or callable`; Initialization specification for the input-hidden weight tensor.
152 | Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme.
153 | - `init_bias_hh: None or str or callable`; Initialization specification for the hidden-hidden bias tensor.
154 | Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme.
155 | - `init_bias_ih: None or str or callable`; Initialization specification for the input-hidden bias tensor.
156 | Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme.
157 | - `bias: bool`; If `False`, then the layer does not use `bias_ih` and `bias_hh`. Default: `True`.
158 | - `num_layers: int`; Number of the recurrent layers. Default: 1.
159 |     - `tuple_out: bool`; If `True`, the returned value will be a tuple `(out, h_n)`. Default: False.
160 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.GRU`, as well as `warm.engine.forward`.
161 | Refer to their docs for details. Some of the additional GRU arguments: `dropout, bidirectional, batch_first`.
162 |     - `return: Tensor or tuple`; If `tuple_out` is set to true, will return `(out, h_n)`, otherwise just `out`.
163 |         `out` has shape `(Batch, Size*Directions, Length)`,
164 |         where Directions = 2 if `bidirectional` else 1.
165 |         `h_n` is the hidden states with shape `(num_layers*Directions, Batch, Size)`.
166 |         Unlike `lstm`, a GRU has no cell state. """
167 | return lstm(*arg, base_name='gru', base_class=nn.GRU, **kw)
168 |
169 |
170 | def identity(x, *arg, **kw):
171 | """ Identity layer that returns the first input, ignores the rest arguments. """
172 | return x
173 |
174 |
175 | def dropout(x, rate=0.5, by_channel=False, **kw):
176 | """ Dropout layer.\n
177 | During training, randomly zeros part of input tensor `x`, at probability `rate`.\n
178 | - `x: Tensor`; Can be of any shape if `by_channel` is false, or 2d and up if `by_channel` is true.
179 | - `rate: float`; The probability of dropout. Default 0.5.
180 | - `by_channel: bool`; If true, will dropout entire channels (all `'D'` dimensions will be 0 if x is `'BCD'`).
181 | `by_channel` true requires `x` to be 2d or more.
182 | - `inplace: bool`; If true, the operation will be in-place and the input `x` will be altered.
183 | - `return: Tensor`; Same shape as `x`. """
184 | inferred_kw = dict(
185 | base_name='dropout',
186 | base_class=[nn.Dropout, nn.Dropout2d][by_channel],
187 | base_kw={'p':rate},
188 | base_shape=[None, 'BCD'][by_channel], )
189 | return engine.forward(x, **{**inferred_kw, **kw})
190 |
191 |
192 | def transformer(x, y=None, num_encoder=6, num_decoder=6, num_head=8,
193 | mask=None, causal=False, in_shape='BCD', **kw):
194 | """ Transformer layer.\n
195 | This layer covers functionality of `Transformer`, `TransformerEncoder`, and `TransformerDecoder`.
196 | See [`torch.nn.Transformer`](https://pytorch.org/docs/stable/nn.html#transformer) for more details.\n
197 | - `x: Tensor`; The source sequence, with shape `(Batch, Channel, LengthX)`.
198 | `Channel` is usually from embedding.
199 | - `y: None or Tensor`; The target sequence. Also with shape `(Batch, Channel, LengthY)`.
200 |         If not present, defaults to `x`.
201 | - `num_encoder: int`; Number of encoder layers. Set to 0 to disable encoder and use only decoder. Default 6.
202 | - `num_decoder: int`; Number of decoder layers. Set to 0 to disable decoder and use only encoder. Default 6.
203 | - `num_head: int`; Number of heads for multi-headed attention. Default 8.
204 | - `mask: None or dict`; Keys are among: `src_mask`, `tgt_mask`, `memory_mask`,
205 | `src_key_padding_mask`, `tgt_key_padding_mask`, `memory_key_padding_mask`.
206 | See the `forward` method of `torch.nn.Transformer` for details.
207 |     - `causal: bool`; Default false. If true, will add causal masks to source and target, so that
208 | current value only depends on the past, not the future, in the sequences.
209 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.Transformer`, as well as `warm.engine.forward`.
210 | - `return: Tensor`; Same shape as `y`, if `num_decoder` > 0. Otherwise same shape as `x`. """
211 | def _causal_mask(n):
212 | mask = (torch.triu(torch.ones(n, n)) == 1).transpose(0, 1)
213 | return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
214 | if y is None:
215 | y = x
216 | y = permute(y, in_shape, 'DBC')
217 | mask = mask or {}
218 | if causal:
219 | i = in_shape.find('D')
220 | mx = _causal_mask(x.shape[i])
221 | mask['src_mask'] = mask.pop('src_mask', 0.0)+mx
222 | my = _causal_mask(y.shape[0])
223 | mask['tgt_mask'] = mask.pop('tgt_mask', 0.0)+my
224 | encoder = identity if num_encoder == 0 else None
225 | decoder = identity if num_decoder == 0 else None
226 | inferred_kw = dict(
227 | base_name='transformer',
228 | base_class=nn.Transformer,
229 | base_shape='DBC',
230 | base_kw=dict(
231 | d_model=x.shape[in_shape.find('C')],
232 | custom_encoder=encoder,
233 | custom_decoder=decoder,
234 | nhead=num_head,
235 | num_encoder_layers=num_encoder,
236 | num_decoder_layers=num_decoder,
237 | **engine.unused_kwargs(kw), ),
238 | in_shape=in_shape,
239 | forward_kw=mask,
240 | forward_arg=(y, ), )
241 | return engine.forward(x, **{**inferred_kw, **kw})
242 |
243 |
244 | def layer_norm(x, dim=1, **kw):
245 | """ Layer Normalization.\n
246 | - `x: Tensor`; Can be of any shape.
247 | - `dim: int or list of int`; Dimensions to be normalized. Default: 1.
248 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LayerNorm`, as well as `warm.engine.forward`.
249 | - `return: Tensor`; Same shape as `x`. """
250 | if dim != -1:
251 | if isinstance(dim, int):
252 | dim = [dim]
253 | dim_norm = [x.ndim+i if i < 0 else i for i in dim]
254 | order = [i for i in range(x.ndim) if i not in dim_norm]+dim_norm
255 | x = x.permute(order)
256 | norm_shape = x.shape[-len(dim_norm):]
257 | else:
258 | norm_shape = [x.shape[-1]]
259 | inferred_kw = dict(
260 | base_name='layer_norm',
261 | base_class=nn.LayerNorm,
262 | base_kw={'normalized_shape':norm_shape}, )
263 | x = engine.forward(x, **{**inferred_kw, **kw})
264 | if dim != -1:
265 | x = x.permute(np.argsort(order).tolist())
266 | return x
267 |
268 |
269 | def embedding(x, size, vocabulary=None, **kw):
270 | """ Embedding layer.\n
271 | The input is usually a list of indices (integers), and the output is a dense matrix which
272 | maps indices to dense vectors. Thus the output will have 1 more dimension than the input.\n
273 | **Note**: The output of this function is always one more dimension than the input. For input with shape `(*)`,
274 | The output will be `(*, size)`. Any shape specifications in the KWargs are ignored. \n
275 | - `x: Tensor`; Contains indices into the vocabulary. Will be converted to `LongTensor` of integers.
276 | Can be of any shape.
277 | - `size: int`; The size of embedding vector.
278 | - `vocabulary: int or None`; The size of vocabulary of embedding, or max number of unique indices in `x`.
279 | By default it is set to `max(x)-min(x)+1`.
280 |     - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.Embedding`, as well as `warm.engine.forward`.
281 | - `return: Tensor`; With the embedded dim appended to the shape of x.
282 | Thus with shape `(*, Size)`, where `*` is the shape of `x`. """
283 | x = x.type(torch.LongTensor)
284 | if vocabulary is None:
285 | vocabulary = x.max()-x.min()+1
286 | kw.pop('in_shape', None)
287 | kw.pop('out_shape', None)
288 | kw.pop('base_shape', None)
289 | inferred_kw = dict(
290 | base_name='embedding',
291 | base_class=nn.Embedding,
292 | base_kw=dict(
293 | num_embeddings=vocabulary,
294 | embedding_dim=size,
295 | **engine.unused_kwargs(kw), ),
296 | base_shape=None,
297 | in_shape=None,
298 | out_shape=None, )
299 | return engine.forward(x, **{**inferred_kw, **kw})
300 |
--------------------------------------------------------------------------------