├── .gitignore ├── LICENSE ├── README.md ├── assets └── architecture.png ├── requirements.txt └── segmenter ├── __init__.py ├── segmenter.py └── vision_transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 isaac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # segmenter-pytorch 2 | PyTorch implementation of ["Segmenter: Transformer for Semantic Segmentation" by Strudel et al. (2021)](https://arxiv.org/abs/2105.05633) 3 | 4 | 5 | 6 | Currently a work in progress -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/segmenter-pytorch/9b4d0b0f1a74eaec8b509ff25566fd420f3bb930/assets/architecture.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | timm 4 | einops -------------------------------------------------------------------------------- /segmenter/__init__.py: -------------------------------------------------------------------------------- 1 | from . import vision_transformer 2 | from .segmenter import Segmenter 3 | -------------------------------------------------------------------------------- /segmenter/segmenter.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import timm 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops.layers.torch import Rearrange 8 | 9 | import segmenter.vision_transformer 10 | 11 | 12 | class MaskTransformer(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | num_classes: int, 17 | emb_dim: int, 18 | hidden_dim: int, 19 | num_layers: int, 20 | num_heads: int, 21 | ): 22 | self.num_classes = num_classes 23 | self.cls_tokens = nn.Parameter(torch.randn(1, num_classes, emb_dim)) 24 | layer = nn.TransformerEncoderLayer( 25 | d_model=emb_dim, 26 | nhead=num_heads, 27 | dim_feedforward=hidden_dim, 28 | activation="gelu" 29 | ) 30 | self.transformer = nn.TransformerEncoder( 31 | layer, 32 | num_layers=num_layers 33 | ) 34 | 35 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 36 | b = x.shape[0] 37 | cls_tokens = self.cls_tokens.repeat(b, 1, 1) 38 | x = torch.cat([cls_tokens, x], dim=1) 39 | x = self.transformer(x) 40 | c = x[:, :self.num_classes] 41 | z = x[:, self.num_classes:] 42 | return z, c 43 | 44 | 45 | class Upsample(nn.Module): 46 | 47 | def __init__(self, image_size: int, patch_size: Tuple[int, int]): 48 | self.model = nn.Sequential( 49 | Rearrange("b (p1 p2) c -> b c p1 p2", p1=image_size//patch_size[0], p2=image_size//patch_size[1]), 50 | nn.Upsample(scale_factor=patch_size, mode="bilinear") 51 | ) 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | return self.model(x) 55 | 56 | 57 | class Segmenter(nn.Module): 58 | 59 | def __init__( 60 | self, 61 | backbone: str, 62 | num_classes: int, 63 | image_size: int, 64 | emb_dim: int, 65 | hidden_dim: int, 66 | num_layers: int, 67 | num_heads: int, 68 | ): 69 | self.encoder = timm.create_model(backbone, img_size=image_size, pretrained=True) 70 | patch_size = self.encoder.patch_embed.patch_size 71 | self.mask_transformer = MaskTransformer( 72 | num_classes, 73 | emb_dim, 74 | hidden_dim, 75 | num_layers, 76 | num_heads, 77 | ) 78 | self.upsample = Upsample(image_size, patch_size) 79 | self.scale = emb_dim ** -0.5 80 | 81 | def forward(self, x: torch.Tensor): 82 | x = self.encoder(x) 83 | z, c = self.mask_transformer(x) 84 | masks = z @ c.transpose(1, 2) 85 | masks = torch.softmax(masks / self.scale, dim=-1) 86 | return self.upsample(masks) 87 | -------------------------------------------------------------------------------- /segmenter/vision_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Override timm.models.vision_transformer.VisionTransformer to 3 | output all output tokens (excluding class or distill tokens) 4 | 5 | This works for the vision_transformer_hybrid models as well 6 | """ 7 | import torch 8 | import timm.models.vision_transformer 9 | 10 | 11 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 12 | 13 | def forward_features(self, x): 14 | x = self.patch_embed(x) 15 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 16 | if self.dist_token is None: 17 | x = torch.cat((cls_token, x), dim=1) 18 | else: 19 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 20 | x = self.pos_drop(x + self.pos_embed) 21 | x = self.blocks(x) 22 | """ 23 | x = self.norm(x) 24 | if self.dist_token is None: 25 | return self.pre_logits(x[:, 0]) 26 | else: 27 | return x[:, 0], x[:, 1] 28 | """ 29 | if self.dist_token is None: 30 | return x[:, 1:] 31 | else: 32 | return x[:, 2:] 33 | 34 | def forward(self, x): 35 | """ 36 | x = self.forward_features(x) 37 | if self.head_dist is not None: 38 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 39 | if self.training and not torch.jit.is_scripting(): 40 | # during inference, return the average of both classifier predictions 41 | return x, x_dist 42 | else: 43 | return (x + x_dist) / 2 44 | else: 45 | x = self.head(x) 46 | return x 47 | """ 48 | x = self.forward_features(x) 49 | return x 50 | 51 | 52 | timm.models.vision_transformer.VisionTransformer = VisionTransformer 53 | --------------------------------------------------------------------------------