├── docs ├── pywarm-logo.png ├── pywarm-logo-small-dark.gif ├── pywarm-logo-small-light.gif ├── github-icon.svg ├── text.mako ├── tutorial.md └── example.md ├── warm ├── __init__.py ├── util.py ├── module.py ├── engine.py └── functional.py ├── tests ├── test_warm.py ├── test_util.py ├── test_module.py ├── test_functional.py └── test_engine.py ├── CONTRIBUTING.md ├── LICENSE.md ├── pyproject.toml ├── .gitignore ├── examples ├── transformer.py ├── resnet.py ├── mobilenet.py ├── lstm.py ├── efficientnet.py └── mnist.py └── README.md /docs/pywarm-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blue-season/pywarm/HEAD/docs/pywarm-logo.png -------------------------------------------------------------------------------- /docs/pywarm-logo-small-dark.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blue-season/pywarm/HEAD/docs/pywarm-logo-small-dark.gif -------------------------------------------------------------------------------- /docs/pywarm-logo-small-light.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blue-season/pywarm/HEAD/docs/pywarm-logo-small-light.gif -------------------------------------------------------------------------------- /warm/__init__.py: -------------------------------------------------------------------------------- 1 | # 09-10-2019; 2 | 3 | """ `warm.up` is an alias of 4 | [`warm.engine.prepare_model_`](https://blue-season.github.io/pywarm/reference/warm/engine/#prepare_model_). """ 5 | from warm.engine import prepare_model_ as up 6 | -------------------------------------------------------------------------------- /tests/test_warm.py: -------------------------------------------------------------------------------- 1 | # 09-10-2019; 2 | """ 3 | Test cases for the warm module. 4 | """ 5 | import torch.nn as nn 6 | from pathlib import Path 7 | import sys 8 | sys.path.append(str(Path(__file__).parent.parent)) 9 | import warm 10 | 11 | 12 | def test_warm_up(): 13 | m = nn.Identity() 14 | assert not warm.engine.is_ready(m), 'is_ready did not work correctly.' 15 | warm.up(m, [1, 2, 3]) 16 | assert warm.engine.is_ready(m), 'warm.up did not work correctly.' 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PyWarm 2 | 3 | PyWarm is developed on [GitHub](https://github.com/blue-season/pywarm). 4 | 5 | Please use GitHub to file Bug reports and submit pull requests. 6 | 7 | Please document and test before submissions. 8 | 9 | PyWarm is developed with Python 3.7, but has been tested to work with Python 3.6+. 10 | 11 | # Coding Style 12 | 13 | For the rational behind the distinct coding style use in PyWarm, please check 14 | 15 | [A Coding Style for Python](https://blue-season.github.io/a-coding-style-for-python/). 16 | -------------------------------------------------------------------------------- /docs/github-icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | # 08-31-2019; 2 | """ 3 | Test cases for warm.util. 
4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from pathlib import Path 9 | import sys 10 | sys.path.append(str(Path(__file__).parent.parent)) 11 | from warm import util 12 | 13 | 14 | def test_camel_to_snake(): 15 | assert util.camel_to_snake('CamelAndSnake') == 'camel_and_snake' 16 | assert util.camel_to_snake('camelAndSnake') == 'camel_and_snake' 17 | assert util.camel_to_snake('camelANDSnake') == 'camel_and_snake' 18 | assert util.camel_to_snake('CAMELAndSnake') == 'camel_and_snake' 19 | assert util.camel_to_snake('CAMELAndSNAKE') == 'camel_and_snake' 20 | assert util.camel_to_snake('CamelAndSnake_') == 'camel_and_snake_' 21 | assert util.camel_to_snake('_CamelAndSnake') == '__camel_and_snake' 22 | 23 | 24 | def test_summary_str(): 25 | from examples.resnet import WarmResNet 26 | m = WarmResNet() 27 | s = util.summary_str(m) 28 | assert len(s) > 0 29 | 30 | 31 | def test_summary(): 32 | from examples.resnet import WarmResNet 33 | m = WarmResNet() 34 | util.summary(m) 35 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 blue-season 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 'PyWarm' 3 | version = '0.4.1' 4 | description = 'A cleaner way to build neural networks for PyTorch.' 
5 | license = 'MIT' 6 | authors = ['blue-season '] 7 | readme = 'README.md' 8 | repository = 'https://github.com/blue-season/pywarm' 9 | homepage = 'https://github.com/blue-season/pywarm' 10 | keywords = ['pywarm', 'pytorch', 'neural network', 'deep learning'] 11 | packages = [ { include='warm' }, ] 12 | 13 | 14 | [tool.poetry.dependencies] 15 | python = '>=3.6' 16 | 17 | 18 | [tool.poetry.dev-dependencies] 19 | toml = '>=0.9' 20 | pytest = '>=3.0' 21 | torch = '>=1.0' 22 | torchvision = '>=0.4' 23 | 24 | 25 | [tool.portray] 26 | modules = ['warm'] 27 | 28 | 29 | [tool.portray.mkdocs] 30 | markdown_extensions = ['pymdownx.superfences'] 31 | 32 | 33 | [tool.portray.mkdocs.theme] 34 | logo = 'docs/pywarm-logo-small-light.gif' 35 | favicon = 'docs/pywarm-logo-small-dark.gif' 36 | name = 'material' 37 | palette = {primary='deep orange', accent='pink'} 38 | 39 | 40 | [tool.portray.pdoc3] 41 | config = ['show_source_code=False', 42 | 'show_type_annotations=False', 43 | 'sort_identifiers=True', 44 | 'show_inherited_members=False'] 45 | template_dir = 'docs' 46 | -------------------------------------------------------------------------------- /tests/test_module.py: -------------------------------------------------------------------------------- 1 | # 08-31-2019; 2 | """ 3 | Test cases for warm.module. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from pathlib import Path 9 | import sys 10 | sys.path.append(str(Path(__file__).parent.parent)) 11 | import warm.module as mm 12 | import warm.functional as W 13 | 14 | 15 | def test_lambda(): 16 | f = lambda x: x*2 17 | m = mm.Lambda(f) 18 | x = torch.randn(1, 2) 19 | assert torch.equal(f(x), m(x)), 'lambda did not work correctly.' 20 | def f(x, w, b=5): 21 | return x*w+b 22 | m = mm.Lambda(f, 2, b=1) 23 | assert torch.equal(f(x, 2, 1), m(x)), 'function with args and kwargs did not work correctly.' 24 | x = torch.randn(3, 2, 4) 25 | m = mm.Lambda(W.permute, 'BDC', 'BCD') 26 | assert list(m(x).shape) == [3, 4, 2], 'lambda permute did not work correctly.' 27 | 28 | 29 | def test_sequential(): 30 | s = mm.Sequential( 31 | nn.Linear(1, 2), 32 | nn.LSTM(2, 3, batch_first=True), # lstm and gru return multiple outputs 33 | nn.GRU(3, 4, batch_first=True), 34 | mm.Lambda(W.permute, 'BDC', 'BCD'), 35 | nn.Conv1d(4, 5, 1), ) 36 | x = torch.randn(3, 2, 1) 37 | assert list(s(x).shape) == [3, 5, 2] 38 | 39 | 40 | def test_shortcut(): 41 | l = nn.Linear(1, 1, bias=False) 42 | nn.init.constant_(l.weight, 2.0) 43 | s = mm.Shortcut(l) 44 | x = torch.ones(1, 1) 45 | assert torch.allclose(s(x), torch.Tensor([3.0])) 46 | -------------------------------------------------------------------------------- /warm/util.py: -------------------------------------------------------------------------------- 1 | # 08-28-2019; 2 | """ 3 | Short utilities. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import re 9 | 10 | 11 | """ Create a property for class torch.Tensor called ndim, for pytorch earlier than 1.2. """ 12 | if not hasattr(torch.Tensor, 'ndim'): 13 | torch.Tensor.ndim = property(lambda x: x.dim()) 14 | 15 | 16 | def camel_to_snake(name): 17 | """ Convert a camelCaseString to its snake_case_equivalent. """ 18 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 19 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 20 | 21 | 22 | def summary_str(model): 23 | """ Get a string representation of model building blocks and parameter counts. 
""" 24 | indent_list, name_list, count_list = [], [], [] 25 | def module_info(m, name, indent_level): 26 | count_list.append(sum([np.prod(list(p.size())) for p in m.parameters()])) 27 | indent_list.append(indent_level) 28 | name_list.append(name) 29 | for name, child in m.named_children(): 30 | if name.isdigit(): 31 | name = child._get_name() 32 | module_info(child, name, indent_level+1) 33 | module_info(model, model._get_name(), 0) 34 | max_indent = max(indent_list)*4 35 | max_name = max(len(x) for x in name_list)+max_indent+2 36 | max_param = len(str(count_list[0]))+max_name+2 37 | out = ['Blocks{:>{w}}'.format('Params', w=max_param-6)] 38 | out += ['-'*max_param] 39 | for indent, name, param in zip(indent_list, name_list, count_list): 40 | s0 = ' '*indent 41 | s1 = '{:{w}}'.format(name, w=max_name-len(s0)) 42 | s2 = '{:>{w}}'.format(param, w=max_param-len(s1)-len(s0)) 43 | out += [s0+s1+s2] 44 | return '\n'.join(out) 45 | 46 | 47 | def summary(model): 48 | """ Print a summary about model building blocks and parameter counts. """ 49 | print(summary_str(model)) 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Auto-generated content above this. Manually added content below. 107 | .vscode/ 108 | .cache/ 109 | *cache* 110 | /.project 111 | tmp/ 112 | data/ 113 | site/ 114 | -------------------------------------------------------------------------------- /examples/transformer.py: -------------------------------------------------------------------------------- 1 | # 09-05-2019; 2 | """ 3 | The Transformer model from paper *Attention is all you need*. 
4 | """ 5 | from pathlib import Path 6 | import sys 7 | sys.path.append(str(Path(__file__).parent.parent)) 8 | sys.path.append('..') 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import warm 13 | import warm.util 14 | import warm.functional as W 15 | 16 | 17 | def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw): 18 | def split_heads(t): # (B, C, L) -> (B, N, H, L) where N*H == C 19 | return t.reshape(batch, num_head, size//num_head, t.shape[-1]) 20 | def merge_heads(t): # (B, N, H, L) -> (B, C, L) 21 | return t.reshape(batch, -1, t.shape[-1]) # (B, C, L) 22 | if y is None: 23 | y = x # self attention 24 | batch, size = x.shape[:2] # B, C, Lx 25 | assert size%num_head == 0, 'num_head must be a divisor of size.' 26 | assert y.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.' 27 | q = W.linear(x, size) # query 28 | k = W.linear(y, size) # key 29 | v = W.linear(y, size) # value 30 | q = split_heads(q) # (B, N, H, Lx) 31 | k = split_heads(k) # (B, N, H, Ly) 32 | v = split_heads(v) # (B, N, H, Ly) 33 | q *= (size//num_head)**(-0.5) 34 | a = q.transpose(2, 3).contiguous().matmul(k) # attention weights, (B, N, Lx, Ly) 35 | if mask is not None: 36 | a += mask 37 | a = F.softmax(a, dim=-1) 38 | a = W.dropout(a, dropout) 39 | x = v.matmul(a.transpose(2, 3).contiguous()) # (B, N, H, Lx) 40 | x = merge_heads(x) # (B, C, Lx) 41 | return W.linear(x, size) 42 | 43 | 44 | def feed_forward(x, size_ff=2048, dropout=0.1, **kw): 45 | y = W.linear(x, size_ff, activation='relu') 46 | y = W.dropout(y, dropout) 47 | return W.linear(y, x.shape[1]) 48 | 49 | 50 | def residual_add(x, layer, dropout=0.1, **kw): 51 | y = W.layer_norm(x) 52 | y = layer(y, **kw) 53 | y = W.dropout(y, dropout) 54 | return x+y 55 | 56 | 57 | def encoder(x, num_encoder=6, **kw): 58 | for i in range(num_encoder): 59 | x = residual_add(x, multi_head_attention, **kw) 60 | x = residual_add(x, feed_forward, **kw) 61 | return W.layer_norm(x) 62 | 63 | 64 | def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw): 65 | for i in range(num_decoder): 66 | y = residual_add(y, multi_head_attention, mask=mask_y, **kw) 67 | y = residual_add(x, multi_head_attention, y=y, mask=mask_x, **kw) 68 | y = residual_add(y, feed_forward, **kw) 69 | return W.layer_norm(y) 70 | 71 | 72 | def transformer(x, y, **kw): 73 | x = encoder(x, **kw) 74 | x = decoder(x, y, **kw) 75 | return x 76 | 77 | 78 | class Transformer(nn.Module): 79 | def __init__(self, *shape, **kw): 80 | super().__init__() 81 | self.kw = kw 82 | warm.up(self, *shape) 83 | def forward(self, x, y): 84 | return transformer(x, y, **self.kw) 85 | -------------------------------------------------------------------------------- /warm/module.py: -------------------------------------------------------------------------------- 1 | # 08-27-2019; 2 | """ 3 | Custom modules to enhance the nn Sequential experience. 4 | 5 | PyWarm's core concept is to use a functional interface to simplify network building. 6 | However, if you still prefer the classical way of defining child modules in `__init__()`, 7 | PyWarm provides some utilities to help organize child modules better. 8 | 9 | - `Lambda` can be used to wrap one line data transformations, like `x.view()`, `x.permute()` etc, into modules. 10 | 11 | - `Sequential` is an extension to `nn.Sequential` that better accomodates PyTorch RNNs. 
12 | 13 | - `Shortcut` is another extension to `nn.Sequential` that will also perform a shortcut addition (AKA residual connection) 14 | for the input with output, so that residual blocks can be written in an entire sequential way. 15 | 16 | For example, to define the basic block type for resnet: 17 | 18 | 19 | ```Python 20 | import torch.nn as nn 21 | import warm.module as wm 22 | 23 | 24 | def basic_block(size_in, size_out, stride=1): 25 | block = wm.Shortcut( 26 | nn.Conv2d(size_in, size_out, 3, stride, 1, bias=False), 27 | nn.BatchNorm2d(size_out), 28 | nn.ReLU(), 29 | nn.Conv2d(size_out, size_out, 3, 1, 1, bias=False), 30 | nn.BatchNorm2d(size_out), 31 | projection=wm.Lambda( 32 | lambda x: x if x.shape[1] == size_out else nn.Sequential( 33 | nn.Conv2d(size_in, size_out, 1, stride, bias=False), 34 | nn.BatchNorm2d(size_out), )(x), ), ) 35 | return block 36 | ``` 37 | """ 38 | 39 | 40 | import torch.nn as nn 41 | 42 | 43 | class Lambda(nn.Module): 44 | """ Wraps a callable and all its call arguments.\n 45 | - `fn: callable`; The callable being wrapped. 46 | - `*arg: list`; Arguments to be passed to `fn`. 47 | - `**kw: dict`; KWargs to be passed to `fn`. """ 48 | def __init__(self, fn, *arg, **kw): 49 | super().__init__() 50 | self.fn = fn 51 | self.arg = arg 52 | self.kw = kw 53 | def forward(self, x): 54 | """ forward. """ 55 | return self.fn(x, *self.arg, **self.kw) 56 | 57 | 58 | class Sequential(nn.Sequential): 59 | """ Similar to `nn.Sequential`, except that child modules can have multiple outputs (e.g. `nn.RNN`).\n 60 | - `*arg: list of Modules`; Same as `nn.Sequential`. """ 61 | def forward(self, x): 62 | """ forward. """ 63 | for module in self._modules.values(): 64 | if isinstance(x, tuple): 65 | try: 66 | x = module(x) 67 | except Exception: 68 | x = module(x[0]) 69 | else: 70 | x = module(x) 71 | return x 72 | 73 | 74 | class Shortcut(Sequential): 75 | """ Similar to `nn.Sequential`, except that it performs a shortcut addition for the input and output.\n 76 | - `*arg: list of Modules`; Same as `nn.Sequential`. 77 | - `projection: None or callable`; If `None`, input with be added directly to the output. 78 | otherwise input will be passed to the `projection` first, usually to make the shapes match. """ 79 | def __init__(self, *arg, projection=None): 80 | super().__init__(*arg) 81 | self.projection = projection or nn.Identity() 82 | def forward(self, x): 83 | """ forward. """ 84 | return super().forward(x)+self.projection(x) 85 | -------------------------------------------------------------------------------- /examples/resnet.py: -------------------------------------------------------------------------------- 1 | # 08-29-2019; 2 | """ 3 | Construct a WarmResNet() using PyWarm, then copy state dicts 4 | from torchvision.models.resnet18() into WarmResNet(), 5 | compare if it produce identical results as the official one. 6 | """ 7 | from pathlib import Path 8 | import sys 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | sys.path.append('..') 11 | import time 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import warm 16 | import warm.util 17 | import warm.functional as W 18 | 19 | 20 | def basic(x, size, stride, stack_index, block_index): 21 | """ The basic block. 
""" 22 | prefix = f'layer{stack_index+1}-{block_index}-' 23 | y = W.conv(x, size, 3, stride=stride, padding=1, bias=False, name=prefix+'conv1') 24 | y = W.batch_norm(y, activation='relu', name=prefix+'bn1') 25 | y = W.conv(y, size, 3, stride=1, padding=1, bias=False, name=prefix+'conv2') 26 | y = W.batch_norm(y, name=prefix+'bn2') 27 | if y.shape[1] != x.shape[1]: 28 | x = W.conv(x, y.shape[1], 1, stride=stride, bias=False, name=prefix+'downsample-0') 29 | x = W.batch_norm(x, name=prefix+'downsample-1') 30 | return F.relu(y+x) 31 | 32 | 33 | def stack(x, num_block, size, stride, stack_index, block=basic): 34 | """ A stack of num_block blocks. """ 35 | for block_index, s in enumerate([stride]+[1]*(num_block-1)): 36 | x = block(x, size, s, stack_index, block_index) 37 | return x 38 | 39 | 40 | class WarmResNet(nn.Module): 41 | def __init__(self, block=basic, stack_spec=((2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2))): 42 | super().__init__() 43 | self.block = block 44 | self.stack_spec = stack_spec 45 | warm.up(self, [2, 3, 32, 32]) 46 | def forward(self, x): 47 | y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1') 48 | y = W.batch_norm(y, activation='relu', name='bn1') 49 | y = F.max_pool2d(y, 3, stride=2, padding=1) 50 | for i, spec in enumerate(self.stack_spec): 51 | y = stack(y, *spec, i, block=self.block) 52 | y = F.adaptive_avg_pool2d(y, 1) 53 | y = torch.flatten(y, 1) 54 | y = W.linear(y, 1000, name='fc') 55 | return y 56 | 57 | 58 | def test_time(fn, *arg, repeat=10, **kw): 59 | dur = 0.0 60 | for i in range(repeat): 61 | start = time.time() 62 | y = fn(*arg, **kw) 63 | dur += time.time()-start 64 | return dur 65 | 66 | 67 | def test(): 68 | """ Compare the classification result of WarmResNet versus torchvision resnet18. """ 69 | new = WarmResNet() 70 | from torchvision.models import resnet18 71 | old = resnet18() 72 | state = old.state_dict() 73 | for k in list(state.keys()): # Map parameters of old, e.g. layer2.0.conv1.weight 74 | s = k.split('.') # to parameters of new, e.g. layer2-0-conv1.weight 75 | s = '-'.join(s[:-1])+'.'+s[-1] 76 | state[s] = state.pop(k) 77 | new.load_state_dict(state) 78 | warm.util.summary(old) 79 | warm.util.summary(new) 80 | x = torch.randn(2, 3, 224, 224) 81 | with torch.no_grad(): 82 | old.eval() 83 | y_old = old(x) 84 | new.eval() 85 | y_new = new(x) 86 | if torch.equal(y_old, y_new): 87 | print('Success! Same results from old and new.') 88 | else: 89 | print('Warning! New and old produce different results.') 90 | t_old = test_time(old, x) 91 | t_new = test_time(new, x) 92 | print('Total forward time for old:', t_old, 'seconds.') 93 | print('Total forward time for new:', t_new, 'seconds.') 94 | 95 | 96 | if __name__ == '__main__': 97 | test() 98 | -------------------------------------------------------------------------------- /examples/mobilenet.py: -------------------------------------------------------------------------------- 1 | # 09-03-2019; 2 | """ 3 | Construct a WarmMobileNetV2() using PyWarm, then copy state dicts 4 | from torchvision.models.mobilenet_v2() into WarmMobileNetV2(), 5 | compare if it produce identical results as the official one. 
6 | """ 7 | from pathlib import Path 8 | import sys 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | sys.path.append('..') 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import warm 15 | import warm.util 16 | import warm.functional as W 17 | 18 | 19 | def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''): 20 | x = W.conv(x, size, kernel, padding=(kernel-1)//2, stride=stride, groups=groups, bias=False, 21 | name=f'{name}-0', ) 22 | return W.batch_norm(x, activation='relu6', name=f'{name}-1') 23 | 24 | 25 | def bottleneck(x, size_out, stride, expand, name=''): 26 | size_in = x.shape[1] 27 | size_mid = size_in*expand 28 | y = conv_bn_relu(x, size_mid, kernel=1, name=f'{name}-conv-0') if expand > 1 else x 29 | y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid, name=f'{name}-conv-{1 if expand > 1 else 0}') 30 | y = W.conv(y, size_out, kernel=1, bias=False, name=f'{name}-conv-{2 if expand > 1 else 1}') 31 | y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}') 32 | if stride == 1 and size_in == size_out: 33 | y += x # residual shortcut 34 | return y 35 | 36 | 37 | def conv1x1(x, *arg, **kw): 38 | return conv_bn_relu(x, *arg, kernel=1, **kw) 39 | 40 | 41 | def pool(x, *arg, **kw): 42 | return x.mean([2, 3]) 43 | 44 | 45 | def classify(x, size, *arg, **kw): 46 | x = W.dropout(x, rate=0.2, name='classifier-0') 47 | return W.linear(x, size, name='classifier-1') 48 | 49 | 50 | default_spec = ( 51 | (None, 32, 1, 2, conv_bn_relu), # t, c, n, s, operator 52 | (1, 16, 1, 1, bottleneck), 53 | (6, 24, 2, 2, bottleneck), 54 | (6, 32, 3, 2, bottleneck), 55 | (6, 64, 4, 2, bottleneck), 56 | (6, 96, 3, 1, bottleneck), 57 | (6, 160, 3, 2, bottleneck), 58 | (6, 320, 1, 1, bottleneck), 59 | (None, 1280, 1, 1, conv1x1), 60 | (None, None, 1, None, pool), 61 | (None, 1000, 1, None, classify), ) 62 | 63 | 64 | class WarmMobileNetV2(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | warm.up(self, [2, 3, 224, 224]) 68 | def forward(self, x): 69 | count = 0 70 | for t, c, n, s, op in default_spec: 71 | for i in range(n): 72 | stride = s if i == 0 else 1 73 | x = op(x, c, stride, t, name=f'features-{count}') 74 | count += 1 75 | return x 76 | 77 | 78 | def test(): 79 | """ Compare the classification result of WarmMobileNetV2 versus torchvision mobilenet_v2. """ 80 | new = WarmMobileNetV2() 81 | from torchvision.models import mobilenet_v2 82 | old = mobilenet_v2() 83 | state = old.state_dict() 84 | for k in list(state.keys()): # Map parameters of old, e.g. layer2.0.conv1.weight 85 | s = k.split('.') # to parameters of new, e.g. layer2-0-conv1.weight 86 | s = '-'.join(s[:-1])+'.'+s[-1] 87 | state[s] = state.pop(k) 88 | new.load_state_dict(state) 89 | warm.util.summary(old) 90 | warm.util.summary(new) 91 | x = torch.randn(1, 3, 224, 224) 92 | with torch.no_grad(): 93 | old.eval() 94 | y_old = old(x) 95 | new.eval() 96 | y_new = new(x) 97 | if torch.equal(y_old, y_new): 98 | print('Success! Same results from old and new.') 99 | else: 100 | print('Warning! New and old produce different results.') 101 | 102 | 103 | if __name__ == '__main__': 104 | test() 105 | -------------------------------------------------------------------------------- /docs/text.mako: -------------------------------------------------------------------------------- 1 | ## Define mini-templates for each portion of the doco. 2 | 3 | <%! 
4 | def indent(s, spaces=4): 5 | new = s.replace('\n', '\n' + ' ' * spaces) 6 | return ' ' * spaces + new.strip() 7 | %> 8 | 9 | <%def name="deflist(s)">:${indent(s)[1:]} 10 | 11 | <%def name="h3(s)">### ${s} 12 | 13 | 14 | <%def name="function(func)" buffered="True"> 15 | <% 16 | returns = show_type_annotations and func.return_annotation() or '' 17 | if returns: 18 | returns = ' -> ' + returns 19 | %> 20 | ${"---"} 21 | ${"### " + func.name} 22 | 23 | 24 | ```python3 25 | def : 26 | ${",\n ".join(func.params(annotate=show_type_annotations))} ${returns} 27 | ``` 28 | ${func.docstring} 29 | 30 | % if show_source_code and func.source and func.obj is not getattr(func.inherits, 'obj', None): 31 | 32 | ??? example "View Source" 33 | ${"\n ".join(func.source.split("\n"))} 34 | 35 | % endif 36 | 37 | 38 | <%def name="variable(var)" buffered="True"> 39 | ```python3 40 | ${var.name} 41 | ``` 42 | ${var.docstring | deflist} 43 | 44 | 45 | <%def name="class_(cls)" buffered="True"> 46 | ${"---"} 47 | ${"### " + cls.name} 48 | 49 | ```python3 50 | def : 51 | ${",\n ".join(cls.params(annotate=show_type_annotations))} 52 | ``` 53 | 54 | ${cls.docstring} 55 | 56 | % if show_source_code and cls.source: 57 | 58 | ??? example "View Source" 59 | ${"\n ".join(cls.source.split("\n"))} 60 | 61 | ------ 62 | 63 | % endif 64 | 65 | <% 66 | class_vars = cls.class_variables(show_inherited_members, sort=sort_identifiers) 67 | static_methods = cls.functions(show_inherited_members, sort=sort_identifiers) 68 | inst_vars = cls.instance_variables(show_inherited_members, sort=sort_identifiers) 69 | methods = cls.methods(show_inherited_members, sort=sort_identifiers) 70 | mro = cls.mro() 71 | subclasses = cls.subclasses() 72 | %> 73 | % if mro: 74 | ${h3('Ancestors (in MRO)')} 75 | % for c in mro: 76 | * ${c.refname} 77 | % endfor 78 | % endif 79 | 80 | % if subclasses: 81 | ${h3('Descendants')} 82 | % for c in subclasses: 83 | * ${c.refname} 84 | % endfor 85 | % endif 86 | 87 | % if class_vars: 88 | ${h3('Class variables')} 89 | % for v in class_vars: 90 | ${variable(v)} 91 | 92 | % endfor 93 | % endif 94 | 95 | % if static_methods: 96 | ${h3('Static methods')} 97 | % for f in static_methods: 98 | ${function(f)} 99 | 100 | % endfor 101 | % endif 102 | 103 | % if inst_vars: 104 | ${h3('Instance variables')} 105 | % for v in inst_vars: 106 | ${variable(v)} 107 | 108 | % endfor 109 | % endif 110 | % if methods: 111 | ${h3('Methods')} 112 | % for m in methods: 113 | ${function(m)} 114 | 115 | % endfor 116 | % endif 117 | 118 | 119 | 120 | ## Start the output logic for an entire module. 121 | 122 | <% 123 | variables = module.variables() 124 | classes = module.classes() 125 | functions = module.functions() 126 | submodules = module.submodules() 127 | heading = 'Namespace' if module.is_namespace else 'Module' 128 | %> 129 | 130 | ${"# " + heading} ${module.name} 131 | 132 | ${module.docstring} 133 | 134 | % if show_source_code: 135 | 136 | ??? 
example "View Source" 137 | ${"\n ".join(module.source.split("\n"))} 138 | 139 | % endif 140 | 141 | 142 | % if submodules: 143 | Sub-modules 144 | ----------- 145 | % for m in submodules: 146 | * [${m.name}](${m.name.split(".")[-1]}/) 147 | % endfor 148 | % endif 149 | 150 | % if variables: 151 | Variables 152 | --------- 153 | % for v in variables: 154 | ${variable(v)} 155 | 156 | % endfor 157 | % endif 158 | 159 | % if functions: 160 | Functions 161 | --------- 162 | % for f in functions: 163 | ${function(f)} 164 | 165 | % endfor 166 | % endif 167 | 168 | % if classes: 169 | Classes 170 | ------- 171 | % for c in classes: 172 | ${class_(c)} 173 | 174 | % endfor 175 | % endif 176 | -------------------------------------------------------------------------------- /examples/lstm.py: -------------------------------------------------------------------------------- 1 | # 09-07-2019; 2 | """ 3 | LSTM sequence model example, based on 4 | https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html 5 | """ 6 | import argparse 7 | from pathlib import Path 8 | import sys 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | sys.path.append('..') 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import warm 16 | import warm.functional as W 17 | 18 | 19 | training_data = [ 20 | ('The dog ate the apple'.split(), ['DET', 'NN', 'V', 'DET', 'NN']), 21 | ('Everybody read that book'.split(), ['NN', 'V', 'DET', 'NN']), ] 22 | testing_data = [('The dog ate the book'.split(), ['DET', 'NN', 'V', 'DET', 'NN'])] 23 | word_to_ix = {} 24 | for sent, tags in training_data: 25 | for word in sent: 26 | if word not in word_to_ix: 27 | word_to_ix[word] = len(word_to_ix) 28 | tag_to_ix = {'DET': 0, 'NN': 1, 'V': 2} 29 | ix_to_tag = {v:k for k, v in tag_to_ix.items()} 30 | 31 | 32 | class WarmTagger(nn.Module): 33 | def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size): 34 | super().__init__() 35 | self.arg = (embedding_dim, hidden_dim, vocab_size, tagset_size) 36 | warm.up(self, torch.tensor([0, 1], dtype=torch.long)) 37 | def forward(self, x): # D 38 | embedding_dim, hidden_dim, vocab_size, tagset_size = self.arg 39 | y = W.embedding(x, embedding_dim, vocab_size) # D->DC 40 | y = W.lstm(y.T[None, ...], hidden_dim) # DC->BCD 41 | y = W.linear(y, tagset_size) # BCD 42 | y = F.log_softmax(y, dim=1) # BCD 43 | return y[0].T # DC 44 | 45 | 46 | class TorchTagger(nn.Module): 47 | def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size): 48 | super().__init__() 49 | self.hidden_dim = hidden_dim 50 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) 51 | self.lstm = nn.LSTM(embedding_dim, hidden_dim) 52 | self.hidden2tag = nn.Linear(hidden_dim, tagset_size) 53 | def forward(self, sentence): 54 | embeds = self.word_embeddings(sentence) 55 | lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1)) 56 | tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1)) 57 | tag_scores = F.log_softmax(tag_space, dim=1) 58 | return tag_scores 59 | 60 | 61 | def prepare_sequence(seq, to_ix): 62 | idxs = [to_ix[w] for w in seq] 63 | return torch.tensor(idxs, dtype=torch.long) 64 | 65 | 66 | def main(): 67 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 68 | parser.add_argument( 69 | '--warm', action='store_true', help='use warm instead of vanilla pytorch.') 70 | p = parser.parse_args() 71 | torch.manual_seed(1) 72 | # 73 | arg = (6, 6, len(word_to_ix), len(tag_to_ix)) 74 | model = 
WarmTagger(*arg) if p.warm else TorchTagger(*arg) 75 | print(f'Using {model._get_name()}.') 76 | loss_function = nn.NLLLoss() 77 | optimizer = optim.SGD(model.parameters(), lr=0.1) 78 | # 79 | for epoch in range(300): 80 | for sentence, tags in training_data: 81 | model.zero_grad() 82 | sentence_in = prepare_sequence(sentence, word_to_ix) 83 | targets = prepare_sequence(tags, tag_to_ix) 84 | tag_scores = model(sentence_in) 85 | loss = loss_function(tag_scores, targets) 86 | loss.backward() 87 | optimizer.step() 88 | # 89 | with torch.no_grad(): 90 | inputs = prepare_sequence(testing_data[0][0], word_to_ix) 91 | tag_scores = model(inputs) 92 | ix = torch.argmax(tag_scores, -1).numpy() 93 | print(testing_data[0][0]) 94 | print('Network tags:\n', [ix_to_tag[i] for i in ix]) 95 | print('True tags:\n', testing_data[0][1]) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /examples/efficientnet.py: -------------------------------------------------------------------------------- 1 | 2 | # 09-20-2019; 3 | """ 4 | EfficientNet 5 | """ 6 | from pathlib import Path 7 | import sys 8 | sys.path.append(str(Path(__file__).parent.parent)) 9 | sys.path.append('..') 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import warm 14 | import warm.util 15 | import warm.functional as W 16 | from warm.engine import namespace 17 | 18 | 19 | def swish(x): 20 | return x*torch.sigmoid(x) 21 | 22 | 23 | def conv_pad_same(x, size, kernel=1, stride=1, **kw): 24 | pad = 0 25 | if kernel != 1 or stride != 1: 26 | in_size, s, k = [torch.as_tensor(v) for v in (x.shape[2:], stride, kernel)] 27 | pad = torch.max(((in_size+s-1)//s-1)*s+k-in_size, torch.tensor(0)) 28 | left, right = pad//2, pad-pad//2 29 | if torch.all(left == right): 30 | pad = tuple(left.tolist()) 31 | else: 32 | left, right = left.tolist(), right.tolist() 33 | pad = sum(zip(left[::-1], right[::-1]), ()) 34 | x = F.pad(x, pad) 35 | pad = 0 36 | return W.conv(x, size, kernel, stride=stride, padding=pad, **kw) 37 | 38 | 39 | @namespace 40 | def conv_bn_act(x, size, kernel=1, stride=1, groups=1, bias=False, eps=1e-3, momentum=1e-2, act=swish, name='', **kw): 41 | x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias, name=name+'-conv') 42 | return W.batch_norm(x, eps=eps, momentum=momentum, activation=act, name=name+'-bn') 43 | 44 | 45 | @namespace 46 | def mb_block(x, size_out, expand=1, kernel=1, stride=1, se_ratio=0.25, dc_ratio=0.2, **kw): 47 | """ MobileNet Bottleneck Block. """ 48 | size_in = x.shape[1] 49 | size_mid = size_in*expand 50 | y = conv_bn_act(x, size_mid, 1, **kw) if expand > 1 else x 51 | y = conv_bn_act(y, size_mid, kernel, stride=stride, groups=size_mid, **kw) 52 | y = squeeze_excitation(y, int(size_in*se_ratio), **kw) 53 | y = conv_bn_act(y, size_out, 1, act=None, **kw) 54 | if stride == 1 and size_in == size_out: 55 | y = drop_connect(y, dc_ratio) 56 | y += x 57 | return y 58 | 59 | 60 | @namespace 61 | def squeeze_excitation(x, size_se, name='', **kw): 62 | if size_se == 0: 63 | return x 64 | size_in = x.shape[1] 65 | x = F.adaptive_avg_pool2d(x, 1) 66 | x = W.conv(x, size_se, 1, activation=swish, name=name+'-conv1') 67 | return W.conv(x, size_in, 1, activation=swish, name=name+'-conv2') 68 | 69 | 70 | def drop_connect(x, rate): 71 | """ Randomly set entire batch to 0. 
""" 72 | if rate == 0: 73 | return x 74 | rate = 1.0-rate 75 | drop_mask = torch.rand([x.shape[0], 1, 1, 1], device=x.device, requires_grad=False)+rate 76 | return x/rate*drop_mask.floor() 77 | 78 | 79 | spec_b0 = ( 80 | (16, 1, 3, 1, 1, 0.25, 0.2), # size, expand, kernel, stride, repeat, se_ratio, dc_ratio 81 | (24, 6, 3, 2, 2, 0.25, 0.2), 82 | (40, 6, 5, 2, 2, 0.25, 0.2), 83 | (80, 6, 3, 2, 3, 0.25, 0.2), 84 | (112, 6, 5, 1, 3, 0.25, 0.2), 85 | (192, 6, 5, 2, 4, 0.25, 0.2), 86 | (320, 6, 3, 1, 1, 0.25, 0.2), ) 87 | 88 | 89 | class WarmEfficientNet(nn.Module): 90 | def __init__(self): 91 | super().__init__() 92 | warm.up(self, [2, 3, 32, 32]) 93 | def forward(self, x): 94 | x = conv_bn_act(x, 32, kernel=3, stride=2, name='head') 95 | for size, expand, kernel, stride, repeat, se_ratio, dc_ratio in spec_b0: 96 | for i in range(repeat): 97 | stride = stride if i == 0 else 1 98 | x = mb_block(x, size, expand, kernel, stride, se_ratio, dc_ratio) 99 | x = conv_bn_act(x, 1280, name='tail') 100 | x = F.adaptive_avg_pool2d(x, 1) 101 | x = W.dropout(x, 0.2) 102 | x = x.view(x.shape[0], -1) 103 | x = W.linear(x, 1000) 104 | return x 105 | 106 | 107 | if __name__ == '__main__': 108 | m = WarmEfficientNet() 109 | warm.util.summary(m) 110 | -------------------------------------------------------------------------------- /tests/test_functional.py: -------------------------------------------------------------------------------- 1 | # 08-31-2019; 2 | """ 3 | Test cases for warm.functional. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | from pathlib import Path 8 | import sys 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | import warm.module as mm 11 | import warm.functional as W 12 | 13 | 14 | def test_conv(): 15 | m = nn.Module() 16 | x = torch.randn(1, 2, 8) # BCD 17 | torch.manual_seed(100) 18 | y0 = nn.Conv1d(2, 3, 3)(x) 19 | torch.manual_seed(100) 20 | y1 = W.conv(x, 3, 3, parent=m) 21 | assert torch.equal(y0, y1), 'conv incorrect output on 1d signal.' 22 | m = nn.Module() 23 | x = torch.randn(1, 2, 3, 4) # BCD 24 | torch.manual_seed(100) 25 | y0 = nn.Conv2d(2, 3, 3)(x) 26 | torch.manual_seed(100) 27 | y1 = W.conv(x, 3, 3, parent=m) 28 | assert torch.equal(y0, y1), 'conv incorrect output on 2d signal.' 29 | 30 | 31 | def test_linear(): 32 | m = nn.Module() 33 | x = torch.randn(1, 2, 3) # BDC 34 | torch.manual_seed(100) 35 | y0 = nn.Linear(3, 4)(x) 36 | torch.manual_seed(100) 37 | y1 = W.linear(x, 4, parent=m, in_shape='BDC', out_shape='BDC') 38 | assert torch.equal(y0, y1), 'linear incorrect output on 1d signal.' 39 | m = nn.Module() 40 | x = torch.randn(1, 2, 3, 4) # BDC 41 | torch.manual_seed(100) 42 | y0 = nn.Linear(4, 3)(x) 43 | torch.manual_seed(100) 44 | y1 = W.linear(x, 3, parent=m, in_shape='BDC', out_shape='BDC') 45 | assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.' 46 | 47 | 48 | def test_batch_norm(): 49 | m = nn.Module() 50 | x = torch.randn(1, 2, 3) # BCD 51 | torch.manual_seed(100) 52 | y0 = nn.BatchNorm1d(2)(x) 53 | torch.manual_seed(100) 54 | y1 = W.batch_norm(x, parent=m) 55 | m = nn.Module() 56 | assert torch.equal(y0, y1), 'batch_norm incorrect output on 1d signal.' 57 | x = torch.randn(1, 2, 3, 4) # BCD 58 | torch.manual_seed(100) 59 | y0 = nn.BatchNorm2d(2)(x) 60 | torch.manual_seed(100) 61 | y1 = W.batch_norm(x, parent=m) 62 | assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.' 
63 | 64 | 65 | def test_lstm(): 66 | m = nn.Module() 67 | x = torch.randn(3, 2, 1) # DBC 68 | torch.manual_seed(100) 69 | y0, *_ = nn.LSTM(1, 2, num_layers=2)(x) 70 | torch.manual_seed(100) 71 | y1 = W.lstm(x, 2, num_layers=2, parent=m, init_weight_hh=None, in_shape='DBC', out_shape='DBC') 72 | assert torch.equal(y0, y1) 73 | y1, s1 = W.lstm(x, 2, parent=m, tuple_out=True) # test tuple out 74 | assert len(s1) == 2 75 | y2 = W.lstm((y1, s1), 2, parent=m) # test tuple in 76 | assert torch.is_tensor(y2) 77 | 78 | 79 | def test_gru(): 80 | m = nn.Module() 81 | x = torch.randn(3, 2, 1) # DBC 82 | torch.manual_seed(100) 83 | y0, *_ = nn.GRU(1, 2, num_layers=2)(x) 84 | torch.manual_seed(100) 85 | y1 = W.gru(x, 2, num_layers=2, parent=m, init_weight_hh=None, in_shape='DBC', out_shape='DBC') 86 | assert torch.equal(y0, y1) 87 | 88 | 89 | def test_identity(): 90 | x = torch.randn(1, 2, 3) 91 | assert torch.equal(W.identity(x, 7, 8, a='b'), x) 92 | 93 | 94 | def test_dropout(): 95 | m = nn.Module() 96 | x = torch.ones(2, 6, 6, 6) 97 | torch.manual_seed(100) 98 | y0 = nn.Dropout(0.3)(x) 99 | torch.manual_seed(100) 100 | y1 = W.dropout(x, 0.3, parent=m) 101 | assert torch.equal(y0, y1) 102 | torch.manual_seed(100) 103 | y0 = nn.Dropout2d(0.3)(x) 104 | torch.manual_seed(100) 105 | y1 = W.dropout(x, 0.3, by_channel=True, parent=m) 106 | assert torch.equal(y0, y1) 107 | 108 | 109 | def test_transformer(): 110 | m = nn.Module() 111 | x = torch.randn(10, 2, 4) 112 | y = torch.randn(6, 2, 4) 113 | torch.manual_seed(100) 114 | z0 = nn.Transformer(4, 2, 1, 1, dim_feedforward=8)(x, y) 115 | torch.manual_seed(100) 116 | z1 = W.transformer(x, y, 1, 1, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m) 117 | assert torch.equal(z0, z1) 118 | torch.manual_seed(100) 119 | z1 = W.transformer(x, y, 1, 1, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m, causal=True) 120 | assert not torch.equal(z0, z1) 121 | z1 = W.transformer(x, None, 2, 0, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m) 122 | assert z1.shape == x.shape 123 | 124 | 125 | def test_layer_norm(): 126 | m = nn.Module() 127 | x = torch.randn(1, 2, 3, 4, 5) 128 | y0 = nn.LayerNorm([3, 4, 5])(x) 129 | y1 = W.layer_norm(x, [2, -2, -1], parent=m) 130 | assert torch.equal(y0, y1) 131 | y0 = nn.LayerNorm(5)(x) 132 | y1 = W.layer_norm(x, dim=-1, parent=m) 133 | assert torch.equal(y0, y1) 134 | x0 = x.permute(0, 4, 2, 1, 3) 135 | y0 = nn.LayerNorm([2, 4])(x0) 136 | y0 = y0.permute(0, 3, 2, 4, 1) 137 | y1 = W.layer_norm(x, dim=[1, -2], parent=m) 138 | assert torch.equal(y0, y1) 139 | 140 | 141 | def test_embedding(): 142 | m = nn.Module() 143 | x = torch.randint(0, 20, (1, 2, 3, 4, 5)) 144 | torch.manual_seed(10) 145 | y0 = nn.Embedding(20, 8)(x) 146 | torch.manual_seed(10) 147 | y1 = W.embedding(x, 8, 20, parent=m) 148 | assert torch.equal(y0, y1) 149 | torch.manual_seed(10) 150 | y1 = W.embedding(x, 8, 20, in_shape='DCB', parent=m) # shapes should have no effect 151 | assert torch.equal(y0, y1) 152 | torch.manual_seed(10) 153 | y1 = W.embedding(x, 8, 20, out_shape='CBD', parent=m) # shapes should have no effect 154 | assert torch.equal(y0, y1) 155 | y1 = W.embedding(x, 8, parent=m) # should work without a explicit vocabulary size 156 | torch.manual_seed(10) 157 | y1 = W.embedding(x.double(), 8, parent=m) # should work with non integer tensors. 
158 | assert torch.equal(y0, y1) 159 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | # 08-27-2019; 2 | """ 3 | MNIST training example. 4 | Use `python mnist.py` to run with PyTorch NN. 5 | Use `python mnist.py --warm` to run with PyWarm NN. 6 | Use `python mnist.py --help` to see a list of cli argument options. 7 | """ 8 | from pathlib import Path 9 | import sys 10 | sys.path.append(str(Path(__file__).parent.parent)) 11 | sys.path.append('..') 12 | import argparse 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | from torchvision import datasets, transforms 18 | import warm 19 | import warm.functional as W 20 | 21 | 22 | class WarmNet(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | warm.up(self, [1, 1, 28, 28]) 26 | def forward(self, x): 27 | x = W.conv(x, 20, 5, activation='relu') 28 | x = F.max_pool2d(x, 2) 29 | x = W.conv(x, 50, 5, activation='relu') 30 | x = F.max_pool2d(x, 2) 31 | x = x.view(-1, 800) 32 | x = W.linear(x, 500, activation='relu') 33 | x = W.linear(x, 10) 34 | return F.log_softmax(x, dim=1) 35 | 36 | 37 | class TorchNet(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 41 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 42 | self.fc1 = nn.Linear(4*4*50, 500) 43 | self.fc2 = nn.Linear(500, 10) 44 | def forward(self, x): 45 | x = F.relu(self.conv1(x)) 46 | x = F.max_pool2d(x, 2, 2) 47 | x = F.relu(self.conv2(x)) 48 | x = F.max_pool2d(x, 2, 2) 49 | x = x.view(-1, 4*4*50) 50 | x = F.relu(self.fc1(x)) 51 | x = self.fc2(x) 52 | return F.log_softmax(x, dim=1) 53 | 54 | 55 | def train(p, model, device, train_loader, optimizer, epoch): 56 | model.train() 57 | for batch_idx, (data, target) in enumerate(train_loader): 58 | data, target = data.to(device), target.to(device) 59 | optimizer.zero_grad() 60 | output = model(data) 61 | loss = F.nll_loss(output, target) 62 | loss.backward() 63 | optimizer.step() 64 | if batch_idx%p.log_interval == 0: 65 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 66 | epoch, batch_idx*len(data), len(train_loader.dataset), 67 | 100.*batch_idx/len(train_loader), loss.item())) 68 | 69 | 70 | def test(p, model, device, test_loader): 71 | model.eval() 72 | test_loss = 0 73 | correct = 0 74 | size = len(test_loader.dataset) 75 | with torch.no_grad(): 76 | for data, target in test_loader: 77 | data, target = data.to(device), target.to(device) 78 | output = model(data) 79 | test_loss += F.nll_loss(output, target, reduction='sum').item() 80 | pred = output.argmax(dim=1, keepdim=True) 81 | correct += pred.eq(target.view_as(pred)).sum().item() 82 | test_loss /= size 83 | print(f'\nTest loss: {test_loss:.4f}, Accuracy: {correct}/{size} ({100*correct/size:.2f}%)\n') 84 | 85 | 86 | def main(): 87 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 88 | parser.add_argument( 89 | '--warm', action='store_true', help='use warm instead of vanilla pytorch.') 90 | parser.add_argument( 91 | '--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)') 92 | parser.add_argument( 93 | '--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') 94 | parser.add_argument( 95 | '--epochs', type=int, default=3, metavar='N', help='number of epochs to train (default: 3)') 96 | 
parser.add_argument( 97 | '--lr', type=float, default=0.02, metavar='LR', help='learning rate (default: 0.02)') 98 | parser.add_argument( 99 | '--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)') 100 | parser.add_argument( 101 | '--no-cuda', action='store_true', default=False, help='disables CUDA training') 102 | parser.add_argument( 103 | '--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 104 | parser.add_argument( 105 | '--log-interval', type=int, default=10, metavar='N', help='number of batchs between logging training status') 106 | parser.add_argument( 107 | '--save-model', action='store_true', default=False, help='For Saving the current Model') 108 | p = parser.parse_args() 109 | # 110 | torch.manual_seed(p.seed) 111 | use_cuda = not p.no_cuda and torch.cuda.is_available() 112 | device = 'cuda' if use_cuda else 'cpu' 113 | kw = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 114 | data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) 115 | train_data = datasets.MNIST('../data', train=True, download=True, transform=data_transform) 116 | test_data = datasets.MNIST('../data', train=False, download=True, transform=data_transform) 117 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=p.batch_size, shuffle=True, **kw) 118 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=p.test_batch_size, shuffle=True, **kw) 119 | model = WarmNet() if p.warm else TorchNet() 120 | print(f'Using {model._get_name()}.') 121 | model = model.to(device) 122 | optimizer = optim.SGD(model.parameters(), lr=p.lr, momentum=p.momentum) 123 | print(f'Training with {p.epochs} epochs on {device} device.') 124 | # 125 | for i in range(p.epochs): 126 | train(p, model, device, train_loader, optimizer, i) 127 | test(p, model, device, test_loader) 128 | # 129 | if p.save_model: 130 | torch.save(model.state_dict(), 'mnist_cnn.pt') 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /tests/test_engine.py: -------------------------------------------------------------------------------- 1 | # 08-31-2019; 2 | """ 3 | Test cases for warm.engine. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import copy 9 | from pathlib import Path 10 | import sys 11 | sys.path.append(str(Path(__file__).parent.parent)) 12 | from warm import engine 13 | 14 | 15 | def test_set_get_default_parent(): 16 | a = nn.Identity() 17 | b = nn.Identity() 18 | engine.set_default_parent(a) 19 | assert engine.get_default_parent() is a, 'get_default_parent result mismatchs set_default_parent.' 20 | engine.set_default_parent(b) 21 | assert engine.get_default_parent() is b, 'get_default_parent result mismatchs set_default_parent.' 22 | 23 | 24 | def test_auto_name(): 25 | a = nn.Identity() 26 | for i in range(10): 27 | assert engine._auto_name('test', a) == f'test_{i+1}', 'new calls to _auto_name failed to increment name count.' 28 | a(None) # test if forward pre hook is triggered to reset names 29 | assert engine._auto_name('test', a) == 'test_1', 'forward_pre_hook did not work.' 
30 | 31 | 32 | def test_initialize(): 33 | a = nn.Parameter(torch.zeros(3, 4)) 34 | b = nn.Parameter(torch.zeros(3, 4)) 35 | c = nn.Parameter(torch.zeros(3, 4)) 36 | torch.manual_seed(1) 37 | engine.initialize_(a, 'normal_') 38 | torch.manual_seed(1) 39 | nn.init.normal_(b) 40 | assert torch.equal(a, b), 'initialize_ with str spec did not work correctly.' 41 | assert not torch.equal(a, c), 'initialize_ with str spec did not work.' 42 | torch.manual_seed(1) 43 | engine.initialize_(c, nn.init.normal_) 44 | assert torch.equal(a, c), 'initialize_ with function spec did not work correctly.' 45 | 46 | 47 | def test_activate(): 48 | a = torch.randn(3, 4) 49 | b = copy.deepcopy(a) 50 | a = engine.activate(a, 'hardshrink') 51 | b = F.hardshrink(b) 52 | assert torch.equal(a, b), 'activate with str spec did not work correctly.' 53 | a = engine.activate(a, 'relu') 54 | b = F.relu(b) 55 | assert torch.equal(a, b), 'activate with str spec did not work correctly.' 56 | 57 | 58 | def test_permute(): 59 | x = torch.randn(1, 2, 3) 60 | y = engine.permute(x, 'BCD', 'DCB') 61 | assert list(y.shape) == [3, 2, 1], 'permute 3d tensor with str in_shape and str out_shape did not work correctly.' 62 | y = engine.permute(x, 'BCD', None) 63 | assert list(y.shape) == [1, 2, 3], 'permute tensor with None out_shape did not work corretly.' 64 | y = engine.permute(x, 'BCD', [1, 0, 2]) 65 | assert list(y.shape) == [2, 1, 3], 'permute tensor with list out_shape did not work corretly.' 66 | x = torch.randn(1, 2, 3, 4) 67 | y = engine.permute(x, 'BCD', 'DCB') 68 | assert list(y.shape) == [3, 4, 2, 1], 'permute 4d tensor with str in_shape and str out_shape did not work correctly.' 69 | y = engine.permute(x, 'DBC', 'CDB') 70 | assert list(y.shape) == [4, 1, 2, 3], 'permute 4d tensor with str in_shape and str out_shape did not work correctly.' 71 | x = torch.randn(1, 2, 3, 4, 5) 72 | y = engine.permute(x, 'BDC', 'BCD') 73 | assert list(y.shape) == [1, 5, 2, 3, 4], 'permute 5d tensor with str in_shape and str out_shape did not work correctly.' 74 | x = torch.randn(1, 2) 75 | y = engine.permute(x, 'BDC', 'BCD') 76 | assert list(y.shape) == [1, 2], 'permute 2d tensor with str in_shape and str out_shape did not work correctly.' 77 | y = engine.permute(x, 'CBD', 'DBC') 78 | assert list(y.shape) == [2, 1], 'permute 2d tensor with str in_shape and str out_shape did not work correctly.' 79 | 80 | 81 | def test_unused_kwargs(): 82 | kw = {'unused1':0, 'unused2':0, 'base_class':0} 83 | unused = engine.unused_kwargs(kw) 84 | assert 'base_class' not in unused, 'unused_kwargs leaks used.' 85 | assert set(unused.keys()) == {'unused1', 'unused2'}, 'unused_kwargs did not filter kw correctly.' 86 | 87 | 88 | def test_prepare_model_is_ready(): 89 | class TestModel(nn.Module): 90 | def forward(self, x): 91 | x = engine.forward(x, nn.Linear, 'linear', 92 | base_arg=(x.shape[-1], 4, False), # in_features, out_features, bias 93 | in_shape=None, out_shape=None, base_shape=None, 94 | initialization={'weight':'ones_'}, activation=(F.dropout, {'p':1.0}), ) 95 | return x 96 | x = torch.randn(1, 2, 3) 97 | m = TestModel() 98 | assert not engine.is_ready(m), 'is_ready did not work correctly.' 99 | engine.prepare_model_(m, x) 100 | assert engine.is_ready(m), 'prepare_model_ did not work correctly.' 101 | assert m.linear_1.bias is None, 'linear_1 should not have bias.' 102 | assert torch.allclose(m.linear_1.weight, torch.Tensor([1.0])), 'linear_1.weight should be initialized to all 1s.' 
103 |     y = m(x)
104 |     assert torch.allclose(y, torch.Tensor([0.0])), 'y should be all 0s because we dropout everything.'
105 |     assert list(y.shape) == [1, 2, 4], 'y should have shape [1, 2, 4] after linear projection.'
106 | 
107 | 
108 | def test_forward():
109 |     x = torch.randn(1, 2, 3)
110 |     m = nn.Module()
111 |     engine.set_default_parent(m)
112 |     class TripleOut(nn.Module): # to test tuple_out
113 |         def forward(self, x, b=1, c='2'):
114 |             return x+b, x, c
115 |     y = engine.forward(x, base_class=TripleOut, base_name='tri', tuple_out=False)
116 |     assert isinstance(y, torch.Tensor), 'tuple_out did not work correctly.'
117 |     y = engine.forward(x, base_class=TripleOut, base_name='tri', tuple_out=True)
118 |     assert isinstance(y, tuple) and len(y) == 3 and y[-1] == '2', 'tuple_out did not work correctly.'
119 |     y = engine.forward(x, base_class=TripleOut, base_name='tri', forward_kw={'c':3}, tuple_out=True)
120 |     assert y[-1] == 3, 'forward_kw did not work correctly.'
121 |     y = engine.forward(x, base_class=TripleOut, base_name='tri', forward_arg=(2.0,))
122 |     assert torch.allclose(y-x, torch.Tensor([2.0])), 'forward_arg did not work correctly.'
123 |     y = engine.forward(x, base_class=TripleOut, activation=(F.dropout, {'p':1.0}))
124 |     assert torch.allclose(y, torch.Tensor([0.0])), 'activation did not work correctly.'
125 |     y = engine.forward(
126 |         x, base_class=nn.Linear, base_kw={'out_features':4}, infer_kw={'in_features':'C'}, base_shape='BDC')
127 |     assert y.shape[1] == 4, 'base_kw, infer_kw did not work correctly.'
128 | 
129 | 
130 | def test_namespace():
131 |     m = nn.Module()
132 |     engine.set_default_parent(m)
133 |     @engine.namespace
134 |     def f1(name=''):
135 |         return ';'.join([f2(name=name) for i in range(2)])
136 |     @engine.namespace
137 |     def f2(name=''):
138 |         return name
139 |     s0, s1, s2 = [f1() for i in range(3)]
140 |     assert s0 == 'f1_1-f2_1;f1_1-f2_2'
141 |     assert s1 == 'f1_2-f2_1;f1_2-f2_2'
142 |     assert s2 == 'f1_3-f2_1;f1_3-f2_2'
143 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [![PyWarm - A cleaner way to build neural networks for PyTorch](https://github.com/blue-season/pywarm/raw/gh-pages/docs/pywarm-logo.png)](https://blue-season.github.io/pywarm/)
3 | 
4 | # PyWarm
5 | 
6 | A cleaner way to build neural networks for PyTorch.
7 | 
8 | [![PyPI Python Version](https://img.shields.io/pypi/pyversions/pywarm)](https://github.com/blue-season/pywarm)
9 | [![PyPI Version](https://img.shields.io/pypi/v/pywarm)](https://pypi.org/project/pywarm/)
10 | [![License](https://img.shields.io/github/license/blue-season/pywarm)](https://github.com/blue-season/pywarm/blob/master/LICENSE)
11 | 
12 | [Examples](https://blue-season.github.io/pywarm/docs/example/) | [Tutorial](https://blue-season.github.io/pywarm/docs/tutorial/) | [API reference](https://blue-season.github.io/pywarm/reference/warm/functional/)
13 | 
14 | ----
15 | 
16 | ## Introduction
17 | 
18 | PyWarm is a lightweight, high-level neural network construction API for PyTorch.
19 | It enables defining all parts of NNs in the functional way.
20 | 
21 | With PyWarm, you can put *all* network data flow logic in the `forward()` method of
22 | your model, without the need to define child modules in the `__init__()` method
23 | and then call them again in the `forward()`.
24 | This results in a much more readable model definition in fewer lines of code.
25 | 26 | PyWarm only aims to simplify the network definition, and does not attempt to cover 27 | model training, validation or data handling. 28 | 29 | ---- 30 | 31 | For example, a convnet for MNIST: 32 | (If needed, click the tabs to switch between Warm and Torch versions) 33 | 34 | 35 | ``` Python tab="Warm" linenums="1" 36 | # powered by PyWarm 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | import warm 40 | import warm.functional as W 41 | 42 | 43 | class ConvNet(nn.Module): 44 | 45 | def __init__(self): 46 | super().__init__() 47 | warm.up(self, [2, 1, 28, 28]) 48 | 49 | def forward(self, x): 50 | x = W.conv(x, 20, 5, activation='relu') 51 | x = F.max_pool2d(x, 2) 52 | x = W.conv(x, 50, 5, activation='relu') 53 | x = F.max_pool2d(x, 2) 54 | x = x.view(-1, 800) 55 | x = W.linear(x, 500, activation='relu') 56 | x = W.linear(x, 10) 57 | return F.log_softmax(x, dim=1) 58 | ``` 59 | 60 | ``` Python tab="Torch" linenums="1" 61 | # vanilla PyTorch version, taken from 62 | # pytorch tutorials/beginner_source/blitz/neural_networks_tutorial.py 63 | import torch.nn as nn 64 | import torch.nn.functional as F 65 | 66 | 67 | class ConvNet(nn.Module): 68 | 69 | def __init__(self): 70 | super().__init__() 71 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 72 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 73 | self.fc1 = nn.Linear(4*4*50, 500) 74 | self.fc2 = nn.Linear(500, 10) 75 | 76 | def forward(self, x): 77 | x = F.relu(self.conv1(x)) 78 | x = F.max_pool2d(x, 2, 2) 79 | x = F.relu(self.conv2(x)) 80 | x = F.max_pool2d(x, 2, 2) 81 | x = x.view(-1, 4*4*50) 82 | x = F.relu(self.fc1(x)) 83 | x = self.fc2(x) 84 | return F.log_softmax(x, dim=1) 85 | ``` 86 | 87 | ---- 88 | 89 | A couple of things you may have noticed: 90 | 91 | - First of all, in the PyWarm version, the entire network definition and 92 | data flow logic resides in the `forward()` method. You don't have to look 93 | up and down repeatedly to understand what `self.conv1`, `self.fc1` etc. 94 | are doing. 95 | 96 | - You do not need to track and specify `in_channels` (or `in_features`, etc.) 97 | for network layers. PyWarm can infer the information for you, e.g. 98 | 99 | ```Python 100 | # Warm 101 | x = W.conv(x, 20, 5, activation='relu') 102 | x = W.conv(x, 50, 5, activation='relu') 103 | 104 | 105 | # Torch 106 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 107 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 108 | ``` 109 | 110 | - One unified `W.conv` for all 1D, 2D, and 3D cases. Fewer things to keep track of! 111 | 112 | - `activation='relu'`. All `warm.functional` APIs accept an optional `activation` keyword, 113 | which is basically equivalent to `F.relu(W.conv(...))`. The keyword `activation` can also 114 | take in a callable, for example `activation=torch.nn.ReLU(inplace=True)` or `activation=swish`. 115 | 116 | For deeper neural networks, see additional [examples](https://blue-season.github.io/pywarm/docs/example/). 117 | 118 | ---- 119 | ## Installation 120 | 121 | pip3 install pywarm 122 | 123 | ---- 124 | ## Quick start: 30 seconds to PyWarm 125 | 126 | If you already have experience with PyTorch, using PyWarm is very straightforward: 127 | 128 | - First, import PyWarm in your model file: 129 | ```Python 130 | import warm 131 | import warm.functional as W 132 | ``` 133 | 134 | - Second, remove child module definitions in the model's `__init__()` method. 135 | Instead, use `W.conv`, `W.linear` ... etc. in the model's `forward()` method, 136 | just like how you would use torch nn functional `F.max_pool2d`, `F.relu` ... etc.
137 | 138 | For example, instead of writing: 139 | 140 | ```Python 141 | # Torch 142 | class MyModule(nn.Module): 143 | def __init__(self): 144 | super().__init__() 145 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size) 146 | # other child module definitions 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | # more forward steps 150 | ``` 151 | 152 | - You can now write in the warm way: 153 | 154 | ```Python 155 | # Warm 156 | class MyWarmModule(nn.Module): 157 | def __init__(self): 158 | super().__init__() 159 | warm.up(self, input_shape_or_data) 160 | def forward(self, x): 161 | x = W.conv(x, out_channels, kernel_size) # no in_channels needed 162 | # more forward steps 163 | ``` 164 | 165 | - Finally, don't forget to warmify the model by adding 166 | 167 | `warm.up(self, input_shape_or_data)` 168 | 169 | at the end of the model's `__init__()` method. You need to supply 170 | `input_shape_or_data`, which is either a tensor of input data, 171 | or just its shape, e.g. `[2, 1, 28, 28]` for MNIST inputs. 172 | 173 | The model is now ready to use, just like any other PyTorch models. 174 | 175 | Check out the [tutorial](https://blue-season.github.io/pywarm/docs/tutorial/) 176 | and [examples](https://blue-season.github.io/pywarm/docs/example/) if you want to learn more! 177 | 178 | ---- 179 | ## Testing 180 | 181 | Clone the repository first, then 182 | 183 | cd pywarm 184 | pytest -v 185 | 186 | ---- 187 | ## Documentation 188 | 189 | Documentations are generated using the excellent [Portray](https://timothycrosley.github.io/portray/) package. 190 | 191 | - [Examples](https://blue-season.github.io/pywarm/docs/example/) 192 | 193 | - [Tutorial](https://blue-season.github.io/pywarm/docs/tutorial/) 194 | 195 | - [API reference](https://blue-season.github.io/pywarm/reference/warm/functional/) 196 | -------------------------------------------------------------------------------- /docs/tutorial.md: -------------------------------------------------------------------------------- 1 | 2 | # PyWarm Basic Tutorial 3 | 4 | ## Import 5 | 6 | To get started, first import PyWarm in your project: 7 | 8 | ```Python 9 | import warm 10 | import warm.functional as W 11 | ``` 12 | 13 | ## Rewrite 14 | 15 | Now you can replace child module definitions with function calls. 16 | For example, instead of: 17 | 18 | ```Python 19 | # Torch 20 | class MyModule(nn.Module): 21 | 22 | def __init__(self): 23 | super().__init__() 24 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size) 25 | # other child module definitions 26 | 27 | def forward(self, x): 28 | x = self.conv1(x) 29 | # more forward steps 30 | ``` 31 | 32 | You now use the warm functions: 33 | 34 | ```Python 35 | # Warm 36 | class MyWarmModule(nn.Module): 37 | 38 | def __init__(self): 39 | super().__init__() 40 | warm.up(self, input_shape_or_data) 41 | 42 | def forward(self, x): 43 | x = W.conv(x, out_channels, kernel_size) # no in_channels needed 44 | # more forward steps 45 | ``` 46 | 47 | Notice the `warm.up(self, input_shape_or_data)` at the end of the `__init__()` method. 48 | It is required so that PyWarm can infer all shapes of itermediate steps and set up trainable parameters. 49 | The only argument `input_shape_or_data` can either be a tensor, e.g. `torch.randn(2, 1, 28, 28)`, 50 | or just the shape, e.g. `[2, 1, 28, 28]` for the model inputs. If the model has multiple inputs, 51 | you may supple them in a list or a dictionary. 
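For instance, here is a minimal sketch of warming a model with two inputs; the model, layer sizes and shapes below are purely illustrative and not part of the library. The shapes are supplied as one list, and `forward()` receives the matching list of tensors:

```Python
import torch
import torch.nn as nn
import warm
import warm.functional as W


class TwoInputNet(nn.Module):

    def __init__(self):
        super().__init__()
        # shapes for the two inputs, given as one list;
        # PyWarm turns each into a random tensor and calls forward() once to set up parameters
        warm.up(self, [[2, 3, 32, 32], [2, 8, 100]])

    def forward(self, inputs):
        image, series = inputs                         # a 2d image input and a 1d sequence input
        a = W.conv(image, 16, 3, activation='relu')    # in_channels (3) is inferred
        b = W.conv(series, 16, 3, activation='relu')   # the same W.conv handles the 1d case
        a = a.mean([2, 3])                             # global average pool to (Batch, 16)
        b = b.mean([2])
        return W.linear(torch.cat([a, b], dim=1), 10)
```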
52 | 53 | Although it is recommended that you attach `warm.up()` to the end of the `__init__()` of your model, you can actually 54 | use it on class instances outside of the definition, like a normal function call: 55 | 56 | ```Python 57 | class MyWarmModule(nn.Module): 58 | 59 | def __init__(self): 60 | super().__init__() # no warm.up here 61 | 62 | def forward(self, x): 63 | x = W.conv(x, 10, 3) 64 | # forward step, powered by PyWarm 65 | 66 | 67 | model = MyWarmModule() # call warm.up outside of the module definition 68 | 69 | warm.up(model, [2, 1, 28, 28]) 70 | ``` 71 | 72 | **Note**: If the model contains `batch_norm` layers, you need to set the `Batch` dimension to at least 2. 73 | 74 | # Advanced Topics 75 | 76 | ## Default shapes 77 | 78 | PyWarm has a unified functional interface: by default all functions accept and return tensors with shape 79 | `(Batch, Channel, *)`, where `*` is any number of additional dimensions. For example, for 2d images, 80 | the `*` usually stands for `(Height, Width)`, and for 1d time series, the `*` means `(Time,)`. 81 | 82 | This convention is optimized for the performance of convolutional networks. It may become less efficient if your 83 | model relies heavily on dense (Linear) or recurrent (LSTM, GRU) layers. You can use different input and 84 | output shapes by specifying the `in_shape`, `out_shape` keyword arguments in the function calls. These keywords 85 | accept only the letters `'B'`, `'C'` and `'D'`, which stand for `Batch`, `Channel`, and `*` (extra Dimensions) 86 | respectively. For example, if for a 1d time series you want `(Time, Batch, Channel)` as the output shape, 87 | you can specify `out_shape='DBC'`. 88 | 89 | ## Dimensional awareness 90 | 91 | PyWarm functions can automatically identify 1d, 2d and 3d input data, so the same function can be used on different 92 | dimensional cases. For example, the single `W.conv` is enough to replace `nn.Conv1d, nn.Conv2d, nn.Conv3d`. 93 | Similarly, you don't need `nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d` for different inputs; a single `W.batch_norm` 94 | can replace them all. 95 | 96 | ## Shape inference 97 | 98 | Many neural network layers perform a transformation of shapes. For example, after a convolution operation, 99 | the shape is changed from `(Batch, ChannelIn, *)` to `(Batch, ChannelOut, *)`. PyTorch nn Modules require the user to 100 | keep track of both `in_channels` and `out_channels`. PyWarm relieves this pain by inferring the `in_channels` for you, 101 | so you can focus more on the nature of your tasks, rather than chores. 102 | 103 | ## Argument passdown 104 | 105 | If the signature of a PyWarm function does not specify all possible arguments of its torch nn Module counterpart, it will pass down 106 | additional keyword arguments to the underlying nn Module. For example, if you want to specify a stride of 2 for a conv layer, 107 | just use `W.conv(..., stride=2)`. The only thing to remember is that you have to specify the full keyword, instead of 108 | relying on the position of arguments. 109 | 110 | ## Parameter initialization per layer 111 | 112 | Unlike PyTorch's approach, parameter initialization can be specified directly in PyWarm's functional interface. 113 | For example: 114 | 115 | ```Python 116 | x = W.conv(x, 20, 1, init_weight='kaiming_uniform_') 117 | ``` 118 | This makes it easier to create layer-specific initialization in PyWarm. You no longer need to go through 119 | `self.modules()` and `self.parameters()` to create custom initializations.
120 | 121 | By default, PyWarm will look into `torch.nn.init` for initialization function names. 122 | Alternatively, you may just specify a callable, or a tuple `(fn, kwargs)` if the callable accepts more than 1 input. 123 | 124 | If the initialization is not specified or `None` is used, the corresponding layer will get default initializations as used 125 | in torch nn modules. 126 | 127 | ## Apply activation nonlinearity to the output 128 | 129 | PyWarm's functional interface supports adding an optional keyword argument `activation=name`, where 130 | name is a callable or just its name, which represents an activation (nonlinearity) function 131 | in `torch.nn.functional` or just `torch`. By default no activation is used. 132 | 133 | ## Mix and Match 134 | 135 | You are not limited to only use PyWarm's functional interface. It is completely ok to mix and match the old 136 | PyTorch way of child module definitions with PyWarm's function API. For example: 137 | 138 | ```Python 139 | class MyModel(nn.Module): 140 | 141 | def __init__(self): 142 | super().__init__() 143 | # other stuff 144 | self.conv1 = nn.Conv2d(2, 30, 7, padding=3) 145 | # other stuff 146 | 147 | def forward(self, x): 148 | y = F.relu(self.conv1(x)) 149 | y = W.conv(y, 40, 3, activation='relu') 150 | ``` 151 | 152 | ## Custom layer names 153 | 154 | Normally you do not have to specify layer names when using the functional API. 155 | PyWarm will track and count usage for the layer type and automatically assign names for you. For example, 156 | subsequent convolutional layer calls via `W.conv` will create `conv_1`, `conv_2`, ... etc. in the parent module. 157 | 158 | Nevertheless, if you want to ensure certain layer have particular names, you can specify `name='my_name'` 159 | keyword arguments in the call. 160 | 161 | Alternatively, if you still want PyWarm to count usage and increment ordinal for you, but only want to customize 162 | the base type name, you can use `base_name='my_prefix'` keyword instead. The PyWarm modules will then have 163 | names like `my_prefix_1`, `my_prefix_2` in the parent module. 164 | 165 | See the PyWarm [resnet example in the examples folder](https://github.com/blue-season/pywarm/blob/master/examples/resnet.py) 166 | on how to use these features to load pre-trained model parameters into PyWarm models. 167 | -------------------------------------------------------------------------------- /warm/engine.py: -------------------------------------------------------------------------------- 1 | # 08-26-2019; 2 | """ 3 | PyWarm engine to the functional interface. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from warm import util 9 | 10 | 11 | _DEFAULT_PARENT_MODULE = None 12 | 13 | 14 | def set_default_parent(parent): 15 | """ Set the default `parent` module. """ 16 | global _DEFAULT_PARENT_MODULE 17 | _DEFAULT_PARENT_MODULE = parent 18 | 19 | 20 | def get_default_parent(): 21 | """ Get the default `parent` module. """ 22 | global _DEFAULT_PARENT_MODULE 23 | return _DEFAULT_PARENT_MODULE 24 | 25 | 26 | def _auto_name(name, parent): 27 | """ Track the count of reference to `name` from `parent`. 
""" 28 | if not is_ready(parent): 29 | parent._pywarm_auto_name_dict = {} 30 | def _hook(model, x): 31 | model._pywarm_auto_name_dict = {} 32 | parent._pywarm_forward_pre_hook = parent.register_forward_pre_hook(_hook) 33 | track = parent._pywarm_auto_name_dict 34 | if name not in track: 35 | track[name] = 0 36 | track[name] += 1 37 | return f'{name}_{track[name]}' 38 | 39 | 40 | def prepare_model_(model, *data, device='cpu'): 41 | """ Initialize all childen modules defined by `warm` in a parent `model`.\n 42 | - `model: Module`; The parent model to be prepared. 43 | - `data: Tensor, or list of int`; A batch of data with the correct shape and type to be forwarded by model. 44 | `data` can also be a list of `int`, in which case it is interpreted as the shape of the input data. 45 | - `device: str, or torch.device`; Should be the same for `model` and `data`. Default: `'cpu'`. 46 | - `return: Module`; The prepared model, with all children modules defined by `warm` initialized. """ 47 | _auto_name('', model) 48 | set_default_parent(model) 49 | def _prep_data(d): 50 | if isinstance(d, (np.ndarray, torch.Tensor)): 51 | return torch.as_tensor(d).to(device) 52 | elif isinstance(d, (list, tuple)): 53 | if all(isinstance(x, int) for x in d): 54 | return torch.randn(*d, device=device) 55 | return [_prep_data(x) for x in d] 56 | elif isinstance(d, dict): 57 | return {k:_prep_data(v) for k, v in d.items()} 58 | with torch.no_grad(): 59 | is_training = model.training 60 | data = [_prep_data(d) for d in data] 61 | model.eval() 62 | model.to(device) 63 | model(*data) 64 | model.train(is_training) 65 | return model 66 | 67 | 68 | def is_ready(model): 69 | """ Check if a `model` is prepared. """ 70 | return hasattr(model, '_pywarm_forward_pre_hook') 71 | 72 | 73 | def activate(x, spec, lookup=None): 74 | """ Activate tensors with given nonlinearity `spec`ification.\n 75 | - `x: Tensor or list of Tensor`; The tensors to be initialized. 76 | - `spec: str or callable or 2-tuple`; If a `str`, should be one of the nonlinearity functions contained in 77 | `torch.nn.functional` or `torch`. If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. 78 | If a 2-`tuple`, it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. 79 | - `lookup: None or list of module`; Parent modules to look for `spec`. If `None`, `[nn.functional, torch]` is used. 80 | - `return: Tensor or list of Tensor`; Activation results. """ 81 | if spec is None: 82 | return x 83 | lookup = lookup or [nn.functional, torch] 84 | if isinstance(spec, str): 85 | for look in lookup: 86 | try: 87 | spec = getattr(look, spec) 88 | break 89 | except: 90 | pass 91 | if isinstance(spec, str): 92 | raise ValueError(f'Unknown spec {spec}.') 93 | if callable(spec): 94 | spec = (spec, {}) 95 | fn, kw = spec 96 | if isinstance(x, (list, tuple)): 97 | return [fn(y, **kw) for y in x] 98 | return fn(x, **kw) 99 | 100 | 101 | def initialize_(x, spec): 102 | """ Initialize parameters with given nonlinearity `spec`ification.\n 103 | - `x: Tensor or list of Tensor`; The tensors to be initialized. 104 | - `spec: str or callable or 2-tuple`; If a `str`, should be one of the nonlinearity functions contained in 105 | `torch.nn.init`. If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`, 106 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. 
""" 107 | activate(x, spec, lookup=[nn.init]) 108 | 109 | 110 | def permute(x, in_shape='BCD', out_shape='BCD', **kw): 111 | """ Permute the dimensions of a tensor.\n 112 | - `x: Tensor`; The nd-tensor to be permuted. 113 | - `in_shape: str`; The dimension shape of `x`. Can only have characters `'B'` or `'C'` or `'D'`, 114 | which stand for Batch, Channel, or extra Dimensions. The default value `'BCD'` means 115 | the input tensor `x` should be at lest 2-d with shape `(Batch, Channel, Dim0, Dim1, Dim2, ...)`, 116 | where `Dim0, Dim1, Dim2 ...` stand for any number of extra dimensions. 117 | - `out_shape: str or tuple or None`; The dimension shape of returned tensor. Default: `'BCD'`. 118 | If a `str`, it is restricted to the same three characters `'B'`, `'C'` or `'D'` as the `in_shape`. 119 | If a `tuple`, `in_shape` is ignored, and simply `x.permute(out_shape)` is returned. 120 | If `None`, no permution will be performed. 121 | - `return: Tensor`; Permuted nd-tensor. """ 122 | if (in_shape == out_shape) or (out_shape is None): 123 | return x 124 | if isinstance(out_shape, (list, tuple, torch.Size)): 125 | return x.permute(*out_shape) 126 | if isinstance(in_shape, str) and isinstance(out_shape, str) : 127 | assert set(in_shape) == set(out_shape) <= {'B', 'C', 'D'}, 'In and out shapes must have save set of chars among B, C, and D.' 128 | in_shape = in_shape.lower().replace('d', '...') 129 | out_shape = out_shape.lower().replace('d', '...') 130 | return torch.einsum(f'{in_shape}->{out_shape}', x) 131 | return x 132 | 133 | 134 | def unused_kwargs(kw): 135 | """ Filter out entries used by `forward` and return the rest. """ 136 | fn_kw = dict(base_class=None, 137 | base_name=None, name=None, base_arg=None, base_kw=None, parent=None, 138 | infer_kw=None, in_shape='BCD', base_shape=None, out_shape='BCD', tuple_out=False, 139 | forward_arg=None, forward_kw=None, initialization=None, activation=None, ) 140 | return {k:v for k, v in kw.items() if k not in fn_kw} 141 | 142 | 143 | def forward(x, base_class, 144 | base_name=None, name=None, base_arg=None, base_kw=None, parent=None, 145 | infer_kw=None, in_shape='BCD', base_shape='BCD', out_shape='BCD', tuple_out=False, 146 | forward_arg=None, forward_kw=None, initialization=None, activation=None, **kw): 147 | """ A forward template that creates child instances at the first time it is called.\n 148 | - `x: Tensor`; The nd-tensor to be forwarded. 149 | - `base_class: Module`; A child `torch.nn.Module` that will be created at the first time this function is called. 150 | - `base_name: str`; Name for the `base_class`. Default: base_class name. 151 | - `name: str`; Name for the child module instance. Default: class name plus ordinal. 152 | - `base_arg: tuple`; Positional args to be passed to create the child module instance. Default: None. 153 | - `base_kw: dict`; KWargs to be passed to create the child module instance. Default: None. 154 | - `parent: Module`; The parent of the child instance. Default: None. If `None`, will use `get_default_parent`. 155 | - `infer_kw: dict`; Key should be valid for the child instance. Value shoud be a character, 156 | one of `'B'`, `'C'`, or `'D'` (see `permute`), to substitute for a dimension of `x`. Default: None. 157 | - `in_shape: str`; The dimension shape of `x`. See also `permute`. Default: `'BCD'`. 158 | - `base_shape: str`; The dimension shape required by the child module. See also `permute`. Default: `'BCD'`. 159 | - `out_shape: str or tuple or None`; The dimension shape of returned tensor. See also `permute`. 
Default: `'BCD'`. 160 | - `tuple_out: bool`; Whether the child module will return more than 1 outputs (e.g. `nn.RNN`). 161 | If `True`, the returned value of the function will be a tuple containing all outputs. Default: False. 162 | - `forward_arg: tuple`; positional args to be passed when calling the child module instance. Default: None. 163 | - `forward_kw: dict`; KWargs to be passed when calling the child module instance. Default: None. 164 | - `initialization: dict`; Keys are name of parameters to initialize. Values are init specs, which can be 165 | a, `str`, a `callable`, or `2-tuple`; See the `spec` argument of `initialize_` for details. Default: None. 166 | - `activation: str or callable or 2-tuple`; See the `spec` argument of `activate`. Default: None. 167 | - `return: Tensor or tuple`; If `tuple_out` is `True`, the returned value will be a `tuple`. """ 168 | parent = parent or get_default_parent() 169 | if name is None: 170 | base_name = base_name or util.camel_to_snake(base_class.__name__) 171 | name = _auto_name(base_name, parent) 172 | if name not in parent._modules: 173 | if infer_kw is not None: 174 | shape = in_shape 175 | if 'D' in shape: 176 | shape = list(shape) 177 | shape[shape.index('D')] = 'D'*(x.ndim-len(shape)+1) 178 | shape = ''.join(shape) 179 | infer_kw = { 180 | k:x.shape[shape.find(v) if isinstance(v, str) else v] 181 | for k, v in infer_kw.items()} 182 | base = base_class(*(base_arg or []), **(infer_kw or {}), **(base_kw or {}), ) 183 | parent.add_module(name, base) 184 | if initialization is not None: 185 | s = parent.state_dict() 186 | for k, v in initialization.items(): 187 | initialize_(s[name+'.'+k], v) 188 | x = permute(x, in_shape, base_shape) 189 | y = parent._modules[name](x, *(forward_arg or []), **(forward_kw or {})) 190 | r = [] 191 | if isinstance(y, tuple): 192 | y, *r = y 193 | y = permute(y, base_shape, out_shape) 194 | y = activate(y, activation) 195 | if tuple_out: 196 | return (y, *r) 197 | return y 198 | 199 | 200 | import functools 201 | def namespace(f): 202 | """ After decoration, the function name and call count will be appended to the `name` kw. 
""" 203 | @functools.wraps(f) 204 | def _wrapped(*arg, **kw): 205 | parent = kw.get('parent', get_default_parent()) 206 | name = kw.get('name', '') 207 | name = '_warmns_' + name + ('-' if name else '') + f.__name__ 208 | name = _auto_name(name, parent) 209 | kw['name'] = name.replace('_warmns_', '') 210 | return f(*arg, **kw) 211 | return _wrapped 212 | -------------------------------------------------------------------------------- /docs/example.md: -------------------------------------------------------------------------------- 1 | 2 | # PyWarm Examples 3 | 4 | ## ResNet 5 | 6 | A more detailed example, the ResNet18 network defined in PyWarm and vanilla PyTorch: 7 | 8 | ``` Python tab="Warm" linenums="1" 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import warm 13 | import warm.functional as W 14 | 15 | 16 | def basic(x, size, stride): 17 | y = W.conv(x, size, 3, stride=stride, padding=1, bias=False) 18 | y = W.batch_norm(y, activation='relu') 19 | y = W.conv(y, size, 3, stride=1, padding=1, bias=False) 20 | y = W.batch_norm(y) 21 | if y.shape[1] != x.shape[1]: # channel size mismatch, needs projection 22 | x = W.conv(x, y.shape[1], 1, stride=stride, bias=False) 23 | x = W.batch_norm(x) 24 | y = y+x # residual shortcut connection 25 | return F.relu(y) 26 | 27 | 28 | def stack(x, num_block, size, stride, block=basic): 29 | for s in [stride]+[1]*(num_block-1): 30 | x = block(x, size, s) 31 | return x 32 | 33 | 34 | class ResNet(nn.Module): 35 | 36 | def __init__(self, block=basic, 37 | stack_spec=((2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2))): 38 | super().__init__() 39 | self.block = block 40 | self.stack_spec = stack_spec 41 | warm.up(self, [2, 3, 32, 32]) 42 | 43 | def forward(self, x): 44 | y = W.conv(x, 64, 7, stride=2, padding=3, bias=False) 45 | y = W.batch_norm(y, activation='relu') 46 | y = F.max_pool2d(y, 3, stride=2, padding=1) 47 | for spec in self.stack_spec: 48 | y = stack(y, *spec, block=self.block) 49 | y = F.adaptive_avg_pool2d(y, 1) 50 | y = torch.flatten(y, 1) 51 | y = W.linear(y, 1000) 52 | return y 53 | 54 | 55 | resnet18 = ResNet() 56 | ``` 57 | 58 | ``` Python tab="Torch" linenums="1" 59 | # code based on torchvision/models/resnet.py 60 | import torch 61 | import torch.nn as nn 62 | import torch.nn.functional as F 63 | 64 | 65 | def conv3x3(size_in, size_out, stride=1): 66 | return nn.Conv2d(size_in, size_out, kernel_size=3, stride=stride, 67 | padding=1, groups=1, bias=False, dilation=1, ) 68 | 69 | 70 | def conv1x1(size_in, size_out, stride=1): 71 | return nn.Conv2d(size_in, size_out, kernel_size=1, stride=stride, 72 | padding=0, groups=1, bias=False, dilation=1, ) 73 | 74 | 75 | class BasicBlock(nn.Module): 76 | 77 | expansion = 1 78 | 79 | def __init__(self, size_in, size_out, stride=1, downsample=None): 80 | super().__init__() 81 | self.conv1 = conv3x3(size_in, size_out, stride) 82 | self.bn1 = nn.BatchNorm2d(size_out) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.conv2 = conv3x3(size_out, size_out) 85 | self.bn2 = nn.BatchNorm2d(size_out) 86 | self.downsample = downsample 87 | 88 | def forward(self, x): 89 | identity = x 90 | y = self.conv1(x) 91 | y = self.bn1(y) 92 | y = self.relu(y) 93 | y = self.conv2(y) 94 | y = self.bn2(y) 95 | if self.downsample is not None: 96 | identity = self.downsample(x) 97 | y += identity 98 | y = self.relu(y) 99 | return y 100 | 101 | 102 | class ResNet(nn.Module): 103 | 104 | def __init__(self, 105 | block=BasicBlock, num_block=[2, 2, 2, 2]): 106 | super().__init__() 107 | 
self.size_in = 64 108 | self.conv1 = nn.Conv2d(3, self.size_in, kernel_size=7, stride=2, 109 | padding=3, bias=False) 110 | self.bn1 = nn.BatchNorm2d(self.size_in) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.stack1 = self._make_stack(block, 64, num_block[0], 1) 114 | self.stack2 = self._make_stack(block, 128, num_block[1], 2) 115 | self.stack3 = self._make_stack(block, 256, num_block[2], 2) 116 | self.stack4 = self._make_stack(block, 512, num_block[3], 2) 117 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512, 1000) 119 | 120 | def _make_stack(self, block, size_out, num_blocks, stride): 121 | downsample = None 122 | if stride != 1 or self.size_in != size_out: 123 | downsample = nn.Sequential( 124 | conv1x1(self.size_in, size_out, stride), 125 | nn.BatchNorm2d(size_out), ) 126 | stacks = [] 127 | for stride in strides: 128 | stacks.append( 129 | block(self.size_in, size_out, stride, downsample)) 130 | self.size_in = size_out 131 | return nn.Sequential(*stacks) 132 | 133 | def forward(self, x): 134 | y = self.conv1(x) 135 | y = self.bn1(y) 136 | y = self.relu(y) 137 | y = self.maxpool(y) 138 | y = self.stack1(y) 139 | y = self.stack2(y) 140 | y = self.stack3(y) 141 | y = self.stack4(y) 142 | y = self.avg_pool(y) 143 | y = torch.flatten(y, 1) 144 | y = self.fc(y) 145 | return y 146 | 147 | 148 | resnet18 = ResNet() 149 | ``` 150 | 151 | - The PyWarm version significantly reduces self-repititions of code as in the vanilla PyTorch version. 152 | 153 | - Note that when warming the model via `warm.up(self, [2, 3, 32, 32])` 154 | We set the first `Batch` dimension to 2 because the model uses `batch_norm`, 155 | which will not work when `Batch` is 1. 156 | 157 | ---- 158 | 159 | ## MobileNet 160 | 161 | ``` Python tab="Warm" linenums="1" 162 | import torch 163 | import torch.nn as nn 164 | import torch.nn.functional as F 165 | import warm 166 | import warm.functional as W 167 | 168 | 169 | def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1): 170 | x = W.conv(x, size, kernel, padding=(kernel-1)//2, 171 | stride=stride, groups=groups, bias=False, ) 172 | return W.batch_norm(x, activation='relu6') 173 | 174 | 175 | def bottleneck(x, size_out, stride, expand): 176 | size_in = x.shape[1] 177 | size_mid = size_in*expand 178 | y = conv_bn_relu(x, size_mid, kernel=1) if expand > 1 else x 179 | y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid) 180 | y = W.conv(y, size_out, kernel=1, bias=False) 181 | y = W.batch_norm(y) 182 | if stride == 1 and size_in == size_out: 183 | y += x # residual shortcut 184 | return y 185 | 186 | 187 | def conv1x1(x, *arg): 188 | return conv_bn_relu(x, *arg, kernel=1) 189 | 190 | 191 | def pool(x, *arg): 192 | return x.mean([2, 3]) 193 | 194 | 195 | def classify(x, size, *arg): 196 | x = W.dropout(x, rate=0.2) 197 | return W.linear(x, size) 198 | 199 | 200 | default_spec = ( 201 | (None, 32, 1, 2, conv_bn_relu), # t, c, n, s, operator 202 | (1, 16, 1, 1, bottleneck), 203 | (6, 24, 2, 2, bottleneck), 204 | (6, 32, 3, 2, bottleneck), 205 | (6, 64, 4, 2, bottleneck), 206 | (6, 96, 3, 1, bottleneck), 207 | (6, 160, 3, 2, bottleneck), 208 | (6, 320, 1, 1, bottleneck), 209 | (None, 1280, 1, 1, conv1x1), 210 | (None, None, 1, None, pool), 211 | (None, 1000, 1, None, classify), ) 212 | 213 | 214 | class MobileNetV2(nn.Module): 215 | 216 | def __init__(self): 217 | super().__init__() 218 | warm.up(self, [2, 3, 224, 224]) 219 | 220 | def forward(self, x): 221 | for 
t, c, n, s, op in default_spec: 222 | for i in range(n): 223 | stride = s if i == 0 else 1 224 | x = op(x, c, stride, t) 225 | return x 226 | 227 | 228 | net = MobileNetV2() 229 | ``` 230 | 231 | ``` Python tab="Torch" linenums="1" 232 | # code based on torchvision/models/mobilenet.py 233 | import torch 234 | import torch.nn as nn 235 | import torch.nn.functional as F 236 | 237 | 238 | class ConvBNReLU(nn.Sequential): 239 | 240 | def __init__(self, in_planes, out_planes, 241 | kernel_size=3, stride=1, groups=1): 242 | padding = (kernel_size-1)//2 243 | super(ConvBNReLU, self).__init__( 244 | nn.Conv2d(in_planes, out_planes, kernel_size, 245 | stride, padding, groups=groups, bias=False), 246 | nn.BatchNorm2d(out_planes), 247 | nn.ReLU6(inplace=True), ) 248 | 249 | 250 | class BottleNeck(nn.Module): 251 | 252 | def __init__(self, inp, oup, stride, expand_ratio): 253 | super().__init__() 254 | self.stride = stride 255 | assert stride in [1, 2] 256 | hidden_dim = int(round(inp * expand_ratio)) 257 | self.use_res_connect = self.stride == 1 and inp == oup 258 | layers = [] 259 | if expand_ratio != 1: 260 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 261 | layers.extend([ 262 | ConvBNReLU(hidden_dim, hidden_dim, 263 | stride=stride, groups=hidden_dim), 264 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 265 | nn.BatchNorm2d(oup), ]) 266 | self.conv = nn.Sequential(*layers) 267 | 268 | def forward(self, x): 269 | if self.use_res_connect: 270 | return x + self.conv(x) 271 | else: 272 | return self.conv(x) 273 | 274 | 275 | default_spec = [ 276 | [1, 16, 1, 1], # t, c, n, s 277 | [6, 24, 2, 2], 278 | [6, 32, 3, 2], 279 | [6, 64, 4, 2], 280 | [6, 96, 3, 1], 281 | [6, 160, 3, 2], 282 | [6, 320, 1, 1], ] 283 | 284 | 285 | class MobileNetV2(nn.Module): 286 | 287 | def __init__(self): 288 | super().__init__() 289 | input_channel = 32 290 | last_channel = 1280 291 | features = [ConvBNReLU(3, input_channel, stride=2)] 292 | for t, c, n, s in default_spec: 293 | output_channel = c 294 | for i in range(n): 295 | stride = s if i == 0 else 1 296 | features.append( 297 | BottleNeck( 298 | input_channel, output_channel, 299 | stride, expand_ratio=t)) 300 | input_channel = output_channel 301 | features.append(ConvBNReLU(input_channel, 302 | last_channel, kernel_size=1)) 303 | self.features = nn.Sequential(*features) 304 | self.classifier = nn.Sequential( 305 | nn.Dropout(0.2), 306 | nn.Linear(last_channel, 1000), ) 307 | 308 | def forward(self, x): 309 | x = self.features(x) 310 | x = x.mean([2, 3]) 311 | x = self.classifier(x) 312 | return x 313 | 314 | 315 | net = MobileNetV2() 316 | ``` 317 | 318 | ## Transformer 319 | 320 | ```Python 321 | """ 322 | The Transformer model from paper Attention is all you need. 323 | The Transformer instance accepts two inputs: 324 | x is Tensor with shape (Batch, Channel, LengthX). 325 | usually a source sequence from embedding (in such cases, 326 | Channel equals the embedding size). 327 | y is Tensor with shape (Batch, Channel, lengthY). 328 | usually a target sequence, also from embedding. 329 | **kw is passed down to inner components. 
330 | """ 331 | import torch 332 | import torch.nn as nn 333 | import torch.nn.functional as F 334 | import warm 335 | import warm.functional as W 336 | 337 | 338 | def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw): 339 | def split_heads(t): 340 | return t.reshape(batch, num_head, size//num_head, t.shape[-1]) 341 | def merge_heads(t): 342 | return t.reshape(batch, -1, t.shape[-1]) 343 | if y is None: 344 | y = x # self attention 345 | batch, size = x.shape[:2] 346 | assert size%num_head == 0, 'num_head must be a divisor of size.' 347 | assert y.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.' 348 | q = W.linear(x, size) # query 349 | k = W.linear(y, size) # key 350 | v = W.linear(y, size) # value 351 | q = split_heads(q) 352 | k = split_heads(k) 353 | v = split_heads(v) 354 | q *= (size//num_head)**(-0.5) 355 | a = q.transpose(2, 3).contiguous().matmul(k) # attention weights 356 | if mask is not None: 357 | a += mask 358 | a = F.softmax(a, dim=-1) 359 | a = W.dropout(a, dropout) 360 | x = v.matmul(a.transpose(2, 3).contiguous()) 361 | x = merge_heads(x) 362 | return W.linear(x, size) 363 | 364 | 365 | def feed_forward(x, size_ff=2048, dropout=0.1, **kw): 366 | y = W.linear(x, size_ff, activation='relu') 367 | y = W.dropout(y, dropout) 368 | return W.linear(y, x.shape[1]) 369 | 370 | 371 | def residual_add(x, layer, dropout=0.1, **kw): 372 | y = W.layer_norm(x) 373 | y = layer(y, **kw) 374 | y = W.dropout(y, dropout) 375 | return x+y 376 | 377 | 378 | def encoder(x, num_encoder=6, **kw): 379 | for i in range(num_encoder): 380 | x = residual_add(x, multi_head_attention, **kw) 381 | x = residual_add(x, feed_forward, **kw) 382 | return W.layer_norm(x) 383 | 384 | 385 | def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw): 386 | for i in range(num_decoder): 387 | y = residual_add(y, multi_head_attention, mask=mask_y, **kw) 388 | y = residual_add(x, multi_head_attention, y=y, mask=mask_x, **kw) 389 | y = residual_add(y, feed_forward, **kw) 390 | return W.layer_norm(y) 391 | 392 | 393 | def transformer(x, y, **kw): 394 | x = encoder(x, **kw) 395 | x = decoder(x, y, **kw) 396 | return x 397 | 398 | 399 | class Transformer(nn.Module): 400 | 401 | def __init__(self, *shape, **kw): 402 | super().__init__() 403 | self.kw = kw 404 | warm.up(self, *shape) 405 | 406 | def forward(self, x, y): 407 | return transformer(x, y, **self.kw) 408 | 409 | ``` 410 | 411 | 412 | ## EfficientNet 413 | 414 | For a brief overview, check the [blog post](https://blue-season.github.io/efficientnet-in-5-minutes/). 415 | 416 | ```python 417 | """ 418 | EfficientNet model from https://arxiv.org/abs/1905.11946 419 | """ 420 | import torch 421 | import torch.nn as nn 422 | import torch.nn.functional as F 423 | import warm 424 | import warm.functional as W 425 | 426 | 427 | def swish(x): 428 | return x*torch.sigmoid(x) 429 | 430 | 431 | def squeeze_excitation(x, size_se): 432 | if size_se == 0: 433 | return x 434 | size_in = x.shape[1] 435 | x = F.adaptive_avg_pool2d(x, 1) 436 | x = W.conv(x, size_se, 1, activation=swish) 437 | return W.conv(x, size_in, 1, activation=swish) 438 | 439 | 440 | def drop_connect(x, rate): 441 | if rate == 0: 442 | return x 443 | rate = 1.0-rate 444 | drop_mask = rate + torch.rand([x.shape[0], 1, 1, 1], 445 | device=x.device, requires_grad=False) 446 | return x/rate*drop_mask.floor() 447 | 448 | 449 | def conv_pad_same(x, size, kernel=1, stride=1, **kw): 450 | """ Same padding so that out_size*stride == in_size. 
""" 451 | pad = 0 452 | if kernel != 1 or stride != 1: 453 | in_size, s, k = [torch.as_tensor(v) 454 | for v in (x.shape[2:], stride, kernel)] 455 | pad = torch.max(((in_size+s-1)//s-1)*s+k-in_size, torch.tensor(0)) 456 | left, right = pad//2, pad-pad//2 457 | if torch.all(left == right): 458 | pad = tuple(left.tolist()) 459 | else: 460 | left, right = left.tolist(), right.tolist() 461 | pad = sum(zip(left[::-1], right[::-1]), ()) 462 | x = F.pad(x, pad) 463 | pad = 0 464 | return W.conv(x, size, kernel, stride=stride, padding=pad, **kw) 465 | 466 | 467 | def conv_bn_act(x, size, kernel=1, stride=1, groups=1, 468 | bias=False, eps=1e-3, momentum=1e-2, act=swish): 469 | x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias) 470 | return W.batch_norm(x, eps=eps, momentum=momentum, activation=act) 471 | 472 | 473 | def mb_block(x, size_out, expand=1, kernel=1, stride=1, 474 | se_ratio=0.25, dc_ratio=0.2): 475 | """ Mobilenet Bottleneck Block. """ 476 | size_in = x.shape[1] 477 | size_mid = size_in*expand 478 | y = conv_bn_act(x, size_mid, 1) if expand > 1 else x 479 | y = conv_bn_act(y, size_mid, kernel, stride=stride, groups=size_mid) 480 | y = squeeze_excitation(y, int(size_in*se_ratio)) 481 | y = conv_bn_act(y, size_out, 1, act=None) 482 | if stride == 1 and size_in == size_out: 483 | y = drop_connect(y, dc_ratio) 484 | y += x 485 | return y 486 | 487 | 488 | spec_b0 = ( 489 | # size, expand, kernel, stride, repeat, squeeze_excitation, drop_connect 490 | (16, 1, 3, 1, 1, 0.25, 0.2), 491 | (24, 6, 3, 2, 2, 0.25, 0.2), 492 | (40, 6, 5, 2, 2, 0.25, 0.2), 493 | (80, 6, 3, 2, 3, 0.25, 0.2), 494 | (112, 6, 5, 1, 3, 0.25, 0.2), 495 | (192, 6, 5, 2, 4, 0.25, 0.2), 496 | (320, 6, 3, 1, 1, 0.25, 0.2), ) 497 | 498 | 499 | class WarmEfficientNet(nn.Module): 500 | def __init__(self): 501 | super().__init__() 502 | warm.up(self, [2, 3, 32, 32]) 503 | def forward(self, x): 504 | x = conv_bn_act(x, 32, kernel=3, stride=2) 505 | for size, expand, kernel, stride, repeat, se, dc in spec_b0: 506 | for i in range(repeat): 507 | stride = stride if i == 0 else 1 508 | x = mb_block(x, size, expand, kernel, stride, se, dc) 509 | x = conv_bn_act(x, 1280) 510 | x = F.adaptive_avg_pool2d(x, 1) 511 | x = W.dropout(x, 0.2) 512 | x = x.view(x.shape[0], -1) 513 | x = W.linear(x, 1000) 514 | return x 515 | ``` 516 | -------------------------------------------------------------------------------- /warm/functional.py: -------------------------------------------------------------------------------- 1 | # 08-27-2019; 2 | """ 3 | Wraps around various torch.nn Modules to fit into a functional interface. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from warm import engine 10 | from warm import util 11 | 12 | 13 | permute = engine.permute 14 | 15 | 16 | def conv(x, size, kernel, init_weight=None, init_bias=None, bias=True, **kw): 17 | """ Convolution layer.\n 18 | - `x: Tensor`; With shape `(Batch, Channel, *)` where `*` Can be 1d or 2d or 3d. 19 | If 3d, shapes are `(Batch, Channel, Length)`. 20 | If 4d, shapes are `(Batch, Channel, Height, Width)`. 21 | If 5d, shapes are `(Batch, Channel, Depth, Height, Width)`. 22 | - `size: int`; Size of hidden filters, and size of the output channel. 23 | - `kernel: int or tuple`; Size of the convolution kernel. 24 | - `init_weight: None or str or callable`; Initialization specification for the weight tensor. 25 | If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`. 
26 | If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`, 27 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. 28 | Default: `None`, and the weight tensor is initialized using `torch.nn.ConvNd`s default scheme. 29 | - `init_bias: None or str or callable`; Same as `init_weight`, but for the bias tensor. 30 | - `bias: bool`; If `True`, adds a learnable bias to the output. Default: `True`. 31 | - `**kw:dict`; Any additional KWargs are passed down to `torch.nn.ConvNd`, where N can be 1, 2 or 3. 32 | as well as `warm.engine.forward`. Refer to their docs for details. Some of the additional ConvNd arguments: 33 | `stride, padding, dilation, groups`. 34 | - `return: Tensor`; With shape `(Batch, Size, *)` where `*` can be 1d, 2d, 3d that depends on `x`. """ 35 | d = x.ndim-3 36 | assert d in [0, 1, 2], 'Incompatible number of dims for input x.' 37 | inferred_kw = dict( 38 | base_name='conv', 39 | base_class=[nn.Conv1d, nn.Conv2d, nn.Conv3d][d], 40 | base_kw={ 41 | 'out_channels':size, 42 | 'kernel_size':kernel, 43 | 'bias':bias, 44 | **engine.unused_kwargs(kw), }, 45 | infer_kw={'in_channels':'C'}, 46 | initialization={'weight':init_weight, **({'bias':init_bias} if bias else {})}, ) 47 | return engine.forward(x, **{**inferred_kw, **kw}) 48 | 49 | 50 | def linear(x, size, init_weight=None, init_bias=None, bias=True, **kw): 51 | """ Linear transformation layer.\n 52 | - `x: Tensor`; 2d or more, with shapes `(Batch, Channel, *)` where `*` means any number of additional dimensions. 53 | - `size: int`; Size of hidden features, and size of the output channel. 54 | - `init_weight: None or str or callable`; Initialization specification for the weight tensor. 55 | If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`. 56 | If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`, 57 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. 58 | Default: `None`, and the weight tensor is initialized using `torch.nn.Linear`s default scheme. 59 | - `init_bias: None or str or callable`; Same as `init_weight`, but for the bias tensor. 60 | - `bias: bool`; If `True`, adds a learnable bias to the output. Default: `True`. 61 | - `**kw:dict`; Any additional KWargs are passed down to `warm.engine.forward`. Refer to its docs for details. 62 | - `return: Tensor`; With shape `(Batch, Size, *)` where `*` can be 1d, 2d, 3d that depends on `x`. """ 63 | inferred_kw = dict( 64 | base_name='linear', 65 | base_class=nn.Linear, 66 | base_kw={'out_features':size, 'bias':bias}, 67 | base_shape='BDC', 68 | infer_kw={'in_features':'C'}, 69 | initialization={'weight':init_weight, **({'bias':init_bias} if bias else {})}, ) 70 | return engine.forward(x, **{**inferred_kw, **kw}) 71 | 72 | 73 | def batch_norm(x, **kw): 74 | """ Batch Normalization layer.\n 75 | - `x: Tensor`; 2d or more, with shapes `(Batch, Channel, *)` where `*` means any number of additional dimensions. 76 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.BatchNormNd`, where N can be 1, 2 or 3. 77 | as well as `warm.engine.forward`. Refer to their docs for details. Some of the additional BatchNorm arguments: 78 | `eps, momentum, affine, track_running_stats`. 79 | - `return: Tensor`; Same shape as input `x`. """ 80 | d = x.ndim-3 81 | assert d in [0, 1, 2], 'Incompatible number of dims for input x.' 
82 | inferred_kw = dict( 83 | base_name='batch_norm', 84 | base_class=[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d][d], 85 | base_kw={'num_features':x.shape[1]}, ) 86 | return engine.forward(x, **{**inferred_kw, **kw}) 87 | 88 | 89 | def lstm(x, size, 90 | init_weight_hh='orthogonal_', init_weight_ih=None, init_bias_hh=None, init_bias_ih=None, 91 | bias=True, num_layers=1, **kw): 92 | """ Long Short Term Memory layer.\n 93 | - `x: Tensor or tuple`; If tuple, must be of format `(x, (h_0, c_0))`, where `x` is a 3d tensor, 94 | with shapes `(Batch, Channel, Length)`. 95 | - `size: int`; Size of hidden features, and size of the output channel. 96 | - `init_weight_hh: None or str or callable`; Initialization specification for the hidden-hidden weight tensor. 97 | If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`. 98 | If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`, 99 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. 100 | Default: `'orthogonal_'`. 101 | - `init_weight_ih: None or str or callable`; Initialization specification for the input-hidden weight tensor. 102 | Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme. 103 | - `init_bias_hh: None or str or callable`; Initialization specification for the hidden-hidden bias tensor. 104 | Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme. 105 | - `init_bias_ih: None or str or callable`; Initialization specification for the input-hidden bias tensor. 106 | Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme. 107 | - `bias: bool`; If `False`, then the layer does not use `bias_ih` and `bias_hh`. Default: `True`. 108 | - `num_layers: int`; Number of the recurrent layers. Default: 1. 109 | - `tuple_out: bool`; If `True`, the returned value will be a tuple `(out, (h_n, c_n))`. Default: False. 110 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LSTM`, as well as `warm.engine.forward`. 111 | Refer to their docs for details. Some of the additional LSTM arguments: `dropout, bidirectional, batch_first`. 112 | - `return: Tensor or tuple`; If `tuple_out` set to true, will return `(out, (h_n, c_n)`, otherwise just `out`. 113 | `out` has shape `(Batch, Size, Length*Directions)`, 114 | where Directions = 2 if `bidirectional` else 1. 115 | `h_n` is the hidden states with shape `(num_layers*Directions, Batch, Size)`. 116 | `c_n` is the cell states with shape `(num_layers*Directions, Batch, Size)`. """ 117 | states = None 118 | if isinstance(x, tuple): 119 | x, *states = x 120 | init = dict( 121 | weight_hh=init_weight_hh, 122 | weight_ih=init_weight_ih, 123 | bias_hh=init_bias_hh, 124 | bias_ih=init_bias_ih, ) 125 | inferred_kw = dict( 126 | base_name='lstm', 127 | base_class=nn.LSTM, 128 | base_kw={ 129 | 'hidden_size':size, 130 | 'num_layers':num_layers, 131 | **engine.unused_kwargs(kw), }, 132 | base_shape='DBC', 133 | infer_kw={'input_size':'C'}, 134 | forward_arg=states, 135 | initialization={ 136 | f'{k}_l{l}':init[k] for k in ['weight_hh', 'weight_ih']+(['bias_hh', 'bias_ih'] if bias else []) 137 | for l in range(num_layers)}, ) 138 | return engine.forward(x, **{**inferred_kw, **kw}) 139 | 140 | 141 | def gru(*arg, **kw): 142 | """ Gated Recurrent Unit layer.\n 143 | - `x: Tensor or tuple`; If tuple, must be of format `(x, (h_0, c_0))`, where `x` is a 3d tensor, 144 | with shapes `(Batch, Channel, Length)`. 
145 | - `size: int`; Size of hidden features, and size of the output channel. 146 | - `init_weight_hh: None or str or callable`; Initialization specification for the hidden-hidden weight tensor. 147 | If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`. 148 | If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`, 149 | it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. 150 | Default: `'orthogonal_'`. 151 | - `init_weight_ih: None or str or callable`; Initialization specification for the input-hidden weight tensor. 152 | Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme. 153 | - `init_bias_hh: None or str or callable`; Initialization specification for the hidden-hidden bias tensor. 154 | Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme. 155 | - `init_bias_ih: None or str or callable`; Initialization specification for the input-hidden bias tensor. 156 | Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme. 157 | - `bias: bool`; If `False`, then the layer does not use `bias_ih` and `bias_hh`. Default: `True`. 158 | - `num_layers: int`; Number of the recurrent layers. Default: 1. 159 | - `tuple_out: bool`; If `True`, the returned value will be a tuple `(out, (h_n, c_n))`. Default: False. 160 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.GRU`, as well as `warm.engine.forward`. 161 | Refer to their docs for details. Some of the additional GRU arguments: `dropout, bidirectional, batch_first`. 162 | - `return: Tensor or tuple`; If `tuple_out` set to true, will return `(out, (h_n, c_n)`, otherwise just `out`. 163 | `out` has shape `(Batch, Size, Length*Directions)`, 164 | where Directions = 2 if `bidirectional` else 1. 165 | `h_n` is the hidden states with shape `(num_layers*Directions, Batch, Size)`. 166 | `c_n` is the cell states with shape `(num_layers*Directions, Batch, Size)`. """ 167 | return lstm(*arg, base_name='gru', base_class=nn.GRU, **kw) 168 | 169 | 170 | def identity(x, *arg, **kw): 171 | """ Identity layer that returns the first input, ignores the rest arguments. """ 172 | return x 173 | 174 | 175 | def dropout(x, rate=0.5, by_channel=False, **kw): 176 | """ Dropout layer.\n 177 | During training, randomly zeros part of input tensor `x`, at probability `rate`.\n 178 | - `x: Tensor`; Can be of any shape if `by_channel` is false, or 2d and up if `by_channel` is true. 179 | - `rate: float`; The probability of dropout. Default 0.5. 180 | - `by_channel: bool`; If true, will dropout entire channels (all `'D'` dimensions will be 0 if x is `'BCD'`). 181 | `by_channel` true requires `x` to be 2d or more. 182 | - `inplace: bool`; If true, the operation will be in-place and the input `x` will be altered. 183 | - `return: Tensor`; Same shape as `x`. """ 184 | inferred_kw = dict( 185 | base_name='dropout', 186 | base_class=[nn.Dropout, nn.Dropout2d][by_channel], 187 | base_kw={'p':rate}, 188 | base_shape=[None, 'BCD'][by_channel], ) 189 | return engine.forward(x, **{**inferred_kw, **kw}) 190 | 191 | 192 | def transformer(x, y=None, num_encoder=6, num_decoder=6, num_head=8, 193 | mask=None, causal=False, in_shape='BCD', **kw): 194 | """ Transformer layer.\n 195 | This layer covers functionality of `Transformer`, `TransformerEncoder`, and `TransformerDecoder`. 
196 | See [`torch.nn.Transformer`](https://pytorch.org/docs/stable/nn.html#transformer) for more details.\n 197 | - `x: Tensor`; The source sequence, with shape `(Batch, Channel, LengthX)`. 198 | `Channel` is usually from embedding. 199 | - `y: None or Tensor`; The target sequence. Also with shape `(Batch, Channel, LengthY)`. 200 | If not present, default to equal `x`. 201 | - `num_encoder: int`; Number of encoder layers. Set to 0 to disable encoder and use only decoder. Default 6. 202 | - `num_decoder: int`; Number of decoder layers. Set to 0 to disable decoder and use only encoder. Default 6. 203 | - `num_head: int`; Number of heads for multi-headed attention. Default 8. 204 | - `mask: None or dict`; Keys are among: `src_mask`, `tgt_mask`, `memory_mask`, 205 | `src_key_padding_mask`, `tgt_key_padding_mask`, `memory_key_padding_mask`. 206 | See the `forward` method of `torch.nn.Transformer` for details. 207 | - `causal: bool`; Default false. if true, will add causal masks to source and target, so that 208 | current value only depends on the past, not the future, in the sequences. 209 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.Transformer`, as well as `warm.engine.forward`. 210 | - `return: Tensor`; Same shape as `y`, if `num_decoder` > 0. Otherwise same shape as `x`. """ 211 | def _causal_mask(n): 212 | mask = (torch.triu(torch.ones(n, n)) == 1).transpose(0, 1) 213 | return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 214 | if y is None: 215 | y = x 216 | y = permute(y, in_shape, 'DBC') 217 | mask = mask or {} 218 | if causal: 219 | i = in_shape.find('D') 220 | mx = _causal_mask(x.shape[i]) 221 | mask['src_mask'] = mask.pop('src_mask', 0.0)+mx 222 | my = _causal_mask(y.shape[0]) 223 | mask['tgt_mask'] = mask.pop('tgt_mask', 0.0)+my 224 | encoder = identity if num_encoder == 0 else None 225 | decoder = identity if num_decoder == 0 else None 226 | inferred_kw = dict( 227 | base_name='transformer', 228 | base_class=nn.Transformer, 229 | base_shape='DBC', 230 | base_kw=dict( 231 | d_model=x.shape[in_shape.find('C')], 232 | custom_encoder=encoder, 233 | custom_decoder=decoder, 234 | nhead=num_head, 235 | num_encoder_layers=num_encoder, 236 | num_decoder_layers=num_decoder, 237 | **engine.unused_kwargs(kw), ), 238 | in_shape=in_shape, 239 | forward_kw=mask, 240 | forward_arg=(y, ), ) 241 | return engine.forward(x, **{**inferred_kw, **kw}) 242 | 243 | 244 | def layer_norm(x, dim=1, **kw): 245 | """ Layer Normalization.\n 246 | - `x: Tensor`; Can be of any shape. 247 | - `dim: int or list of int`; Dimensions to be normalized. Default: 1. 248 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LayerNorm`, as well as `warm.engine.forward`. 249 | - `return: Tensor`; Same shape as `x`. 
""" 250 | if dim != -1: 251 | if isinstance(dim, int): 252 | dim = [dim] 253 | dim_norm = [x.ndim+i if i < 0 else i for i in dim] 254 | order = [i for i in range(x.ndim) if i not in dim_norm]+dim_norm 255 | x = x.permute(order) 256 | norm_shape = x.shape[-len(dim_norm):] 257 | else: 258 | norm_shape = [x.shape[-1]] 259 | inferred_kw = dict( 260 | base_name='layer_norm', 261 | base_class=nn.LayerNorm, 262 | base_kw={'normalized_shape':norm_shape}, ) 263 | x = engine.forward(x, **{**inferred_kw, **kw}) 264 | if dim != -1: 265 | x = x.permute(np.argsort(order).tolist()) 266 | return x 267 | 268 | 269 | def embedding(x, size, vocabulary=None, **kw): 270 | """ Embedding layer.\n 271 | The input is usually a list of indices (integers), and the output is a dense matrix which 272 | maps indices to dense vectors. Thus the output will have 1 more dimension than the input.\n 273 | **Note**: The output of this function is always one more dimension than the input. For input with shape `(*)`, 274 | The output will be `(*, size)`. Any shape specifications in the KWargs are ignored. \n 275 | - `x: Tensor`; Contains indices into the vocabulary. Will be converted to `LongTensor` of integers. 276 | Can be of any shape. 277 | - `size: int`; The size of embedding vector. 278 | - `vocabulary: int or None`; The size of vocabulary of embedding, or max number of unique indices in `x`. 279 | By default it is set to `max(x)-min(x)+1`. 280 | - `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LayerNorm`, as well as `warm.engine.forward`. 281 | - `return: Tensor`; With the embedded dim appended to the shape of x. 282 | Thus with shape `(*, Size)`, where `*` is the shape of `x`. """ 283 | x = x.type(torch.LongTensor) 284 | if vocabulary is None: 285 | vocabulary = x.max()-x.min()+1 286 | kw.pop('in_shape', None) 287 | kw.pop('out_shape', None) 288 | kw.pop('base_shape', None) 289 | inferred_kw = dict( 290 | base_name='embedding', 291 | base_class=nn.Embedding, 292 | base_kw=dict( 293 | num_embeddings=vocabulary, 294 | embedding_dim=size, 295 | **engine.unused_kwargs(kw), ), 296 | base_shape=None, 297 | in_shape=None, 298 | out_shape=None, ) 299 | return engine.forward(x, **{**inferred_kw, **kw}) 300 | --------------------------------------------------------------------------------