├── .editorconfig ├── .gitignore ├── LICENSE.txt ├── README.md ├── demo └── fastseg-semantic-segmentation.ipynb ├── fastseg ├── __init__.py ├── image │ ├── __init__.py │ ├── colorize.py │ └── palette.py └── model │ ├── base.py │ ├── efficientnet.py │ ├── lraspp.py │ ├── mobilenetv3.py │ └── utils.py ├── infer.py ├── onnx_export.py ├── onnx_infer.py ├── onnx_optimize.py ├── requirements.txt └── setup.py /.editorconfig: -------------------------------------------------------------------------------- 1 | [*] 2 | end_of_line = lf 3 | insert_final_newline = true 4 | trim_trailing_whitespace = true 5 | 6 | [*.py] 7 | indent_style = space 8 | indent_size = 4 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | *.onnx 4 | 5 | /build 6 | /dist 7 | /*.egg-info 8 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Eric Zhang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast Semantic Segmentation 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ekzhang/fastseg/blob/master/demo/fastseg-semantic-segmentation.ipynb) 4 | 5 | This repository aims to provide accurate _real-time semantic segmentation_ code for mobile devices in PyTorch, with pretrained weights on Cityscapes. This can be used for efficient segmentation on a variety of real-world street images, including datasets like Mapillary Vistas, KITTI, and CamVid. 6 | 7 | ```python 8 | from fastseg import MobileV3Large 9 | model = MobileV3Large.from_pretrained().cuda().eval() 10 | model.predict(images) 11 | ``` 12 | 13 | ![Example image segmentation video](https://i.imgur.com/vOApT8N.gif) 14 | 15 | The models are implementations of **MobileNetV3** (both large and small variants) with a modified segmentation head based on **LR-ASPP**. The top model was able to achieve **72.3%** mIoU accuracy on Cityscapes _val_, while running at up to **37.3 FPS** on a GPU. Please see below for detailed benchmarks.
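If you want a rough estimate of raw PyTorch throughput on your own hardware, a minimal timing sketch along these lines works (this only measures the plain PyTorch model; the TensorRT numbers in the benchmark table below additionally rely on an optimized ONNX export with fp16):

```python
import time
import torch
from fastseg import MobileV3Large

model = MobileV3Large.from_pretrained().cuda().eval()
dummy = torch.randn(1, 3, 1024, 2048, device='cuda')  # full-resolution Cityscapes-sized input

with torch.no_grad():
    for _ in range(10):  # warm-up iterations before timing
        model(dummy)
    torch.cuda.synchronize()
    start = time.perf_counter()
    iters = 100
    for _ in range(iters):
        model(dummy)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f'~{iters / elapsed:.1f} FPS at 1024 x 2048')
```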
16 | 17 | Currently, you can do the following: 18 | 19 | - Load pretrained MobileNetV3 semantic segmentation models. 20 | - Easily generate hard segmentation labels or soft probabilities for street image scenes. 21 | - Evaluate MobileNetV3 models on Cityscapes, or your own dataset. 22 | - Export models for production with ONNX. 23 | 24 | If you have any feature requests or questions, feel free to leave them as GitHub issues! 25 | 26 | ## Table of Contents 27 | 28 | * [What's New?](#whats-new) 29 | + [September 29th, 2020](#september-29th-2020) 30 | + [August 12th, 2020](#august-12th-2020) 31 | + [August 11th, 2020](#august-11th-2020) 32 | * [Overview](#overview) 33 | * [Requirements](#requirements) 34 | * [Pretrained Models and Metrics](#pretrained-models-and-metrics) 35 | * [Usage](#usage) 36 | + [Running Inference](#running-inference) 37 | + [Exporting to ONNX](#exporting-to-onnx) 38 | * [Training from Scratch](#training-from-scratch) 39 | * [Contributions](#contributions) 40 | 41 | ## What's New? 42 | 43 | ### September 29th, 2020 44 | 45 | - Released [training code](https://github.com/ekzhang/semantic-segmentation) for semantic segmentation models 46 | 47 | ### August 12th, 2020 48 | 49 | - Added pretrained weights for `MobileV3Small` with 256 filters 50 | 51 | ### August 11th, 2020 52 | 53 | - Initial release 54 | - Implementations of `MobileV3Large` and `MobileV3Small` with LR-ASPP 55 | - Pretrained weights for `MobileV3Large` with 128/256 filters, and `MobileV3Small` with 64/128 filters 56 | - Inference, ONNX export, and optimization scripts 57 | 58 | ## Overview 59 | 60 | Here's an excerpt from the [original paper](https://arxiv.org/abs/1905.02244) introducing MobileNetV3: 61 | 62 | > This paper starts the exploration of how automated search algorithms and network design can work together to harness complementary approaches improving the overall state of the art. Through this process we create two new MobileNet models for release: MobileNetV3-Large and MobileNetV3-Small, which are targeted for high and low resource use cases. These models are then adapted and applied to the tasks of object detection and semantic segmentation. 63 | > 64 | > For the task of semantic segmentation (or any dense pixel prediction), we propose a new efficient segmentation decoder Lite Reduced Atrous Spatial Pyramid Pooling (LR-ASPP). **We achieve new state of the art results for mobile classification, detection and segmentation.** 65 | > 66 | > **MobileNetV3-Large LRASPP is 34% faster than MobileNetV2 R-ASPP at similar accuracy for Cityscapes segmentation.** 67 | > 68 | > ![MobileNetV3 Comparison](https://i.imgur.com/E9IYp0c.png?1) 69 | 70 | This project tries to faithfully implement MobileNetV3 for real-time semantic segmentation, with the aims of being efficient, easy to use, and extensible. 71 | 72 | ## Requirements 73 | 74 | This code requires Python 3.7 or later. It has been tested to work with PyTorch versions 1.5 and 1.6. To install the package, simply run `pip install fastseg`. 
Then you can get started with a pretrained model: 75 | 76 | ```python 77 | # Load a pretrained MobileNetV3 segmentation model in inference mode 78 | from fastseg import MobileV3Large 79 | model = MobileV3Large.from_pretrained().cuda() 80 | model.eval() 81 | 82 | # Open a local image as input 83 | from PIL import Image 84 | image = Image.open('street_image.png') 85 | 86 | # Predict numeric labels [0-18] for each pixel of the image 87 | labels = model.predict_one(image) 88 | ``` 89 | 90 | ![Example image segmentation](https://i.imgur.com/WspmlwN.jpg) 91 | 92 | More detailed examples are given below. As an alternative, instead of installing `fastseg` from pip, you can clone this repository and install the [`geffnet` package](https://github.com/rwightman/gen-efficientnet-pytorch) (along with other dependencies) by running `pip install -r requirements.txt` in the project root. 93 | 94 | ## Pretrained Models and Metrics 95 | 96 | I was able to train a few models close to or exceeding the accuracy described in the original [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244) paper. Each was trained only on the `gtFine` labels from Cityscapes for around 12 hours on an Nvidia DGX-1 node, with 8 V100 GPUs. 97 | 98 | | Model | Segmentation Head | Parameters | mIoU | Inference | TensorRT | Weights? | 99 | | :-------------: | :---------------: | :--------: | :---: | :-------: | :------: | :------: | 100 | | `MobileV3Large` | LR-ASPP, F=256 | 3.6M | 72.3% | 21.1 FPS | 30.7 FPS | ✔ | 101 | | `MobileV3Large` | LR-ASPP, F=128 | 3.2M | 72.3% | 25.7 FPS | 37.3 FPS | ✔ | 102 | | `MobileV3Small` | LR-ASPP, F=256 | 1.4M | 67.8% | 30.3 FPS | 39.4 FPS | ✔ | 103 | | `MobileV3Small` | LR-ASPP, F=128 | 1.1M | 67.4% | 38.2 FPS | 52.4 FPS | ✔ | 104 | | `MobileV3Small` | LR-ASPP, F=64 | 1.0M | 66.9% | 46.5 FPS | 61.9 FPS | ✔ | 105 | 106 | The accuracy is within **0.3%** of the original paper, which reported 72.6% mIoU and 3.6M parameters on the Cityscapes _val_ set. Inference was tested on a single V100 GPU with full-resolution 2MP images (1024 x 2048) as input. It runs roughly 4x faster on half-resolution (512 x 1024) images. 107 | 108 | The "TensorRT" column shows benchmarks I ran after exporting optimized ONNX models to [Nvidia TensorRT](https://developer.nvidia.com/tensorrt) with fp16 precision. Performance is measured by taking average GPU latency over 100 iterations. 109 | 110 | ## Usage 111 | 112 | ### Running Inference 113 | 114 | The easiest way to get started with inference is to clone this repository and use the `infer.py` script. For example, if you have street images named `city_1.png` and `city_2.png`, then you can generate segmentation labels for them with the following command. 
115 | 116 | ```shell 117 | $ python infer.py city_1.png city_2.png 118 | ``` 119 | 120 | Output: 121 | ``` 122 | ==> Creating PyTorch MobileV3Large model 123 | ==> Loading images and running inference 124 | Loading city_1.png 125 | Generated colorized_city_1.png 126 | Generated composited_city_1.png 127 | Loading city_2.png 128 | Generated colorized_city_2.png 129 | Generated composited_city_2.png 130 | ``` 131 | 132 | | Original | Colorized | Composited | 133 | | :----------------------------------: | :----------------------------------: | :----------------------------------: | 134 | | ![](https://i.imgur.com/74vqz0q.png) | ![](https://i.imgur.com/HRr16YC.png) | ![](https://i.imgur.com/WVd5a6Z.png) | 135 | | ![](https://i.imgur.com/MJA7VMN.png) | ![](https://i.imgur.com/FqoxHzR.png) | ![](https://i.imgur.com/fVMvbRv.png) | 136 | 137 | To interact with the models programmatically, first install the `fastseg` package with pip, as described above. Then, you can import and construct models in your own Python code, which are instances of PyTorch `nn.Module`. 138 | 139 | ```python 140 | from fastseg import MobileV3Large, MobileV3Small 141 | 142 | # Load a pretrained segmentation model 143 | model = MobileV3Large.from_pretrained() 144 | 145 | # Load a segmentation model from a local checkpoint 146 | model = MobileV3Small.from_pretrained('path/to/weights.pt') 147 | 148 | # Create a custom model with random initialization 149 | model = MobileV3Large(num_classes=19, use_aspp=False, num_filters=256) 150 | ``` 151 | 152 | To run inference on an image or batch of images, you can use the methods `model.predict_one()` and `model.predict()`, respectively. These methods take care of the preprocessing and output interpretation for you; they take PIL Images or NumPy arrays as input and return a NumPy array. 153 | 154 | (You can also run inference directly with `model.forward()`, which will return a tensor containing logits, but be sure to normalize the inputs to have mean 0 and variance 1.) 155 | 156 | ```python 157 | import torch 158 | from PIL import Image 159 | from fastseg import MobileV3Large, MobileV3Small 160 | 161 | # Construct a new model with pretrained weights, in evaluation mode 162 | model = MobileV3Large.from_pretrained().cuda() 163 | model.eval() 164 | 165 | # Run inference on an image 166 | img = Image.open('city_1.png') 167 | labels = model.predict_one(img) # returns a NumPy array containing integer labels 168 | assert labels.shape == (1024, 2048) 169 | 170 | # Run inference on a batch of images 171 | img2 = Image.open('city_2.png') 172 | batch_labels = model.predict([img, img2]) # returns a NumPy array containing integer labels 173 | assert batch_labels.shape == (2, 1024, 2048) 174 | 175 | # Run forward pass directly 176 | dummy_input = torch.randn(1, 3, 1024, 2048, device='cuda') 177 | with torch.no_grad(): 178 | dummy_output = model(dummy_input) 179 | assert dummy_output.shape == (1, 19, 1024, 2048) 180 | ``` 181 | 182 | The output labels can be visualized with colorized and composited images. 183 | 184 | ```python 185 | from fastseg.image import colorize, blend 186 | 187 | colorized = colorize(labels) # returns a PIL Image 188 | colorized.show() 189 | 190 | composited = blend(img, colorized) # returns a PIL Image 191 | composited.show() 192 | ``` 193 | 194 | ### Exporting to ONNX 195 | 196 | The `onnx_export.py` script can be used to convert a pretrained segmentation model to ONNX. You should specify the image input dimensions when exporting. 
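For example, the following command exports the default `MobileV3Large` model with 1024 x 2048 inputs (both values are also the script defaults, shown here only for clarity):

```shell
$ python onnx_export.py mobilenetv3_large.onnx --model MobileV3Large --size 1024,2048
```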
See the usage instructions below: 197 | 198 | ``` 199 | $ python onnx_export.py --help 200 | usage: onnx_export.py [-h] [--model MODEL] [--num_filters NUM_FILTERS] 201 | [--size SIZE] [--checkpoint CHECKPOINT] 202 | OUTPUT_FILENAME 203 | 204 | Command line script to export a pretrained segmentation model to ONNX. 205 | 206 | positional arguments: 207 | OUTPUT_FILENAME filename of output model (e.g., 208 | mobilenetv3_large.onnx) 209 | 210 | optional arguments: 211 | -h, --help show this help message and exit 212 | --model MODEL, -m MODEL 213 | the model to export (default MobileV3Large) 214 | --num_filters NUM_FILTERS, -F NUM_FILTERS 215 | the number of filters in the segmentation head 216 | (default 128) 217 | --size SIZE, -s SIZE the image dimensions to set as input (default 218 | 1024,2048) 219 | --checkpoint CHECKPOINT, -c CHECKPOINT 220 | filename of the weights checkpoint .pth file (uses 221 | pretrained by default) 222 | ``` 223 | 224 | The `onnx_optimize.py` script optimizes exported models. If you're looking to deploy a model to TensorRT or a mobile device, you might also want to run it through [onnx-simplifier](https://github.com/daquexian/onnx-simplifier). 225 | 226 | ## Training from Scratch 227 | 228 | Please see the [ekzhang/semantic-segmentation](https://github.com/ekzhang/semantic-segmentation) repository for the training code used in this project, as well as documentation about how to train your own custom models. 229 | 230 | ## Contributions 231 | 232 | Pull requests are always welcome! A big thanks to Andrew Tao and Karan Sapra from [NVIDIA ADLR](https://nv-adlr.github.io/) for helpful discussions and for lending me their training code, as well as Branislav Kisacanin, without whom this wouldn't be possible. 233 | 234 | I'm grateful for advice from: Ching Hung, Eric Viscito, Franklyn Wang, Jagadeesh Sankaran, and Zoran Nikolic. 235 | 236 | Licensed under the MIT License. 237 | -------------------------------------------------------------------------------- /fastseg/__init__.py: -------------------------------------------------------------------------------- 1 | from .model.lraspp import MobileV3Large, MobileV3Small 2 | -------------------------------------------------------------------------------- /fastseg/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .colorize import * 2 | -------------------------------------------------------------------------------- /fastseg/image/colorize.py: -------------------------------------------------------------------------------- 1 | """Utilities for generating a colorized segmentation image. 2 | 3 | Parts of this code were modified from https://github.com/NVIDIA/semantic-segmentation. 4 | """ 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from .palette import all_palettes 10 | 11 | def colorize(mask_array, palette='cityscapes'): 12 | """Colorize a segmentation mask. 
13 | 14 | Keyword arguments: 15 | mask_array -- the segmentation as a 2D numpy array of integers [0..classes - 1] 16 | palette -- the palette to use (default 'cityscapes') 17 | """ 18 | mask_img = Image.fromarray(mask_array.astype(np.uint8)).convert('P') 19 | mask_img.putpalette(all_palettes[palette]) 20 | return mask_img.convert('RGB') 21 | 22 | def blend(input_img, seg_img): 23 | """Blend an input image with its colorized segmentation labels.""" 24 | return Image.blend(input_img, seg_img, 0.4) 25 | -------------------------------------------------------------------------------- /fastseg/image/palette.py: -------------------------------------------------------------------------------- 1 | """Various RGB palettes for coloring segmentation labels.""" 2 | 3 | cityscapes = ( 4 | 128, 64, 128, 5 | 244, 35, 232, 6 | 70, 70, 70, 7 | 102, 102, 156, 8 | 190, 153, 153, 9 | 153, 153, 153, 10 | 250, 170, 30, 11 | 220, 220, 0, 12 | 107, 142, 35, 13 | 152, 251, 152, 14 | 70, 130, 180, 15 | 220, 20, 60, 16 | 255, 0, 0, 17 | 0, 0, 142, 18 | 0, 0, 70, 19 | 0, 60, 100, 20 | 0, 80, 100, 21 | 0, 0, 230, 22 | 119, 11, 32, 23 | ) 24 | 25 | all_palettes = { 26 | 'cityscapes': cityscapes, 27 | } 28 | -------------------------------------------------------------------------------- /fastseg/model/base.py: -------------------------------------------------------------------------------- 1 | """The `BaseSegmentation` class provides useful convenience functions for inference.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import transforms 6 | 7 | MODEL_WEIGHTS_URL = { 8 | ('mobilev3large-lraspp', 256): 'https://github.com/ekzhang/fastseg/releases/download/v0.1-weights/mobilev3large-lraspp-f256-9b613ffd.pt', 9 | ('mobilev3large-lraspp', 128): 'https://github.com/ekzhang/fastseg/releases/download/v0.1-weights/mobilev3large-lraspp-f128-9cbabfde.pt', 10 | ('mobilev3small-lraspp', 256): 'https://github.com/ekzhang/fastseg/releases/download/v0.1-weights/mobilev3small-lraspp-f256-d853f901.pt', 11 | ('mobilev3small-lraspp', 128): 'https://github.com/ekzhang/fastseg/releases/download/v0.1-weights/mobilev3small-lraspp-f128-a39a1e4b.pt', 12 | ('mobilev3small-lraspp', 64): 'https://github.com/ekzhang/fastseg/releases/download/v0.1-weights/mobilev3small-lraspp-f64-114fc23b.pt', 13 | } 14 | 15 | class BaseSegmentation(nn.Module): 16 | """Module subclass providing useful convenience functions for inference.""" 17 | 18 | @classmethod 19 | def from_pretrained(cls, filename=None, num_filters=128, **kwargs): 20 | """Load a pretrained model from a .pth checkpoint given by `filename`.""" 21 | if filename is None: 22 | # Pull default pretrained model from internet 23 | name = (cls.model_name, num_filters) 24 | if name in MODEL_WEIGHTS_URL: 25 | weights_url = MODEL_WEIGHTS_URL[name] 26 | print(f'Loading pretrained model {name[0]} with F={name[1]}...') 27 | checkpoint = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') 28 | else: 29 | raise ValueError(f'pretrained weights not found for model {name}, please specify a checkpoint') 30 | else: 31 | checkpoint = torch.load(filename, map_location='cpu') 32 | net = cls(checkpoint['num_classes'], num_filters=num_filters, **kwargs) 33 | net.load_checkpoint(checkpoint) 34 | return net 35 | 36 | def load_checkpoint(self, checkpoint): 37 | """Load weights given a checkpoint object from training.""" 38 | state_dict = {} 39 | for k, v in checkpoint['state_dict'].items(): 40 | if k.startswith('module.'): 41 | state_dict[k[len('module.'):]] = v 42 | 
self.load_state_dict(state_dict) 43 | 44 | def predict_one(self, image, return_prob=False, device=None): 45 | """Generate and return segmentation for a single image. 46 | 47 | See the documentation of the `predict()` function for more details. This function 48 | is a convenience wrapper that only returns predictions for a single image, rather 49 | than an entire batch. 50 | """ 51 | return self.predict([image], return_prob, device)[0] 52 | 53 | def predict(self, images, return_prob=False, device=None): 54 | """Generate and return segmentations for a batch of images. 55 | 56 | Keyword arguments: 57 | images -- a list of PIL images or NumPy arrays to run segmentation on 58 | return_prob -- whether to return the output probabilities (default False) 59 | device -- the device to use when running evaluation, defaults to 'cuda' or 'cpu' 60 | (this must match the device that the model is currently on) 61 | 62 | Returns: 63 | if `return_prob == False`, a NumPy array of shape (len(images), height, width) 64 | containing the predicted classes 65 | if `return_prob == True`, a NumPy array of shape (len(images), num_classes, height, width) 66 | containing the log-probabilities of each class 67 | """ 68 | # Determine the device 69 | if device is None: 70 | if torch.cuda.is_available(): 71 | device = torch.device('cuda') 72 | else: 73 | device = torch.device('cpu') 74 | 75 | # Preprocess images by normalizing and turning into `torch.tensor`s 76 | tfms = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 79 | ]) 80 | ipt = torch.stack([tfms(im) for im in images]).to(device) 81 | 82 | # Run inference 83 | with torch.no_grad(): 84 | out = self.forward(ipt) 85 | 86 | # Return the output as a `np.ndarray` on the CPU 87 | if not return_prob: 88 | out = out.argmax(dim=1) 89 | return out.cpu().numpy() 90 | -------------------------------------------------------------------------------- /fastseg/model/efficientnet.py: -------------------------------------------------------------------------------- 1 | """Modified EfficientNets for use as semantic segmentation feature extractors.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from geffnet import tf_efficientnet_b4, tf_efficientnet_b0 7 | 8 | class EfficientNet_B4(nn.Module): 9 | def __init__(self, output_stride=8, BatchNorm=nn.BatchNorm2d, 10 | pretrained=False): 11 | super(EfficientNet_B4, self).__init__() 12 | net = tf_efficientnet_b4(pretrained=pretrained, 13 | drop_rate=0.25, 14 | drop_connect_rate=0.2, 15 | norm_layer=BatchNorm) 16 | 17 | self.output_stride = output_stride 18 | self.early = nn.Sequential(net.conv_stem, 19 | net.bn1, 20 | net.act1) 21 | if self.output_stride == 8: 22 | block3_stride = 1 23 | block5_stride = 1 24 | dilation_blocks34 = 2 25 | dilation_blocks56 = 4 26 | elif self.output_stride == 16: 27 | block3_stride = 1 28 | block5_stride = 2 29 | dilation_blocks34 = 1 30 | dilation_blocks56 = 1 31 | else: 32 | raise 33 | 34 | net.blocks[3][0].conv_dw.stride = (block3_stride, block3_stride) 35 | net.blocks[5][0].conv_dw.stride = (block5_stride, block5_stride) 36 | 37 | for block_num in (3, 4, 5, 6): 38 | for sub_block in range(len(net.blocks[block_num])): 39 | m = net.blocks[block_num][sub_block].conv_dw 40 | if block_num < 5: 41 | m.dilation = (dilation_blocks34, dilation_blocks34) 42 | pad = dilation_blocks34 43 | else: 44 | m.dilation = (dilation_blocks56, dilation_blocks56) 45 | pad = dilation_blocks56 46 | if m.kernel_size[0] == 3: 47 | pad *= 1 48 | elif 
m.kernel_size[0] == 5: 49 | pad *= 2 50 | else: 51 | raise 52 | m.padding = (pad, pad) 53 | 54 | self.block0 = net.blocks[0] 55 | self.block1 = net.blocks[1] 56 | self.block2 = net.blocks[2] 57 | self.block3 = net.blocks[3] 58 | self.block4 = net.blocks[4] 59 | self.block5 = net.blocks[5] 60 | self.block6 = net.blocks[6] 61 | self.late = nn.Sequential(net.conv_head, 62 | net.bn2, 63 | net.act2) 64 | del net 65 | 66 | def forward(self, x): 67 | x = self.early(x) 68 | x = self.block0(x) 69 | s2 = x 70 | x = self.block1(x) 71 | s4 = x 72 | x = self.block2(x) 73 | s8 = x 74 | x = self.block3(x) 75 | x = self.block4(x) 76 | x = self.block5(x) 77 | x = self.block6(x) 78 | x = self.late(x) 79 | if self.output_stride == 8: 80 | return s2, s4, x 81 | else: 82 | return s4, s8, x 83 | 84 | 85 | class EfficientNet_B0(nn.Module): 86 | def __init__(self, output_stride=8, BatchNorm=nn.BatchNorm2d, 87 | pretrained=False): 88 | super(EfficientNet_B0, self).__init__() 89 | net = tf_efficientnet_b0(pretrained=pretrained, 90 | drop_rate=0.25, 91 | drop_connect_rate=0.2, 92 | norm_layer=BatchNorm) 93 | 94 | self.output_stride = output_stride 95 | self.early = nn.Sequential(net.conv_stem, 96 | net.bn1, 97 | net.act1) 98 | if self.output_stride == 8: 99 | block3_stride = 1 100 | block5_stride = 1 101 | dilation_blocks34 = 2 102 | dilation_blocks56 = 4 103 | elif self.output_stride == 16: 104 | block3_stride = 1 105 | block5_stride = 2 106 | dilation_blocks34 = 1 107 | dilation_blocks56 = 1 108 | else: 109 | raise 110 | 111 | net.blocks[3][0].conv_dw.stride = (block3_stride, block3_stride) 112 | net.blocks[5][0].conv_dw.stride = (block5_stride, block5_stride) 113 | 114 | for block_num in (3, 4, 5, 6): 115 | for sub_block in range(len(net.blocks[block_num])): 116 | m = net.blocks[block_num][sub_block].conv_dw 117 | if block_num < 5: 118 | m.dilation = (dilation_blocks34, dilation_blocks34) 119 | pad = dilation_blocks34 120 | else: 121 | m.dilation = (dilation_blocks56, dilation_blocks56) 122 | pad = dilation_blocks56 123 | if m.kernel_size[0] == 3: 124 | pad *= 1 125 | elif m.kernel_size[0] == 5: 126 | pad *= 2 127 | else: 128 | raise 129 | m.padding = (pad, pad) 130 | 131 | self.block0 = net.blocks[0] 132 | self.block1 = net.blocks[1] 133 | self.block2 = net.blocks[2] 134 | self.block3 = net.blocks[3] 135 | self.block4 = net.blocks[4] 136 | self.block5 = net.blocks[5] 137 | self.block6 = net.blocks[6] 138 | self.late = nn.Sequential(net.conv_head, 139 | net.bn2, 140 | net.act2) 141 | del net 142 | 143 | def forward(self, x): 144 | x = self.early(x) 145 | x = self.block0(x) 146 | s2 = x 147 | x = self.block1(x) 148 | s4 = x 149 | x = self.block2(x) 150 | s8 = x 151 | x = self.block3(x) 152 | x = self.block4(x) 153 | x = self.block5(x) 154 | x = self.block6(x) 155 | x = self.late(x) 156 | if self.output_stride == 8: 157 | return s2, s4, x 158 | else: 159 | return s4, s8, x 160 | 161 | 162 | if __name__ == "__main__": 163 | model = EfficientNet_B0(BatchNorm=nn.BatchNorm2d, pretrained=True, 164 | output_stride=8) 165 | input = torch.rand(1, 3, 512, 512) 166 | low, mid, x = model(input) 167 | print(model) 168 | print(sum(p.numel() for p in model.parameters()), ' parameters') 169 | print(x.size()) 170 | print(low.size()) 171 | print(mid.size()) 172 | -------------------------------------------------------------------------------- /fastseg/model/lraspp.py: -------------------------------------------------------------------------------- 1 | """Lite Reduced Atrous Spatial Pyramid Pooling 2 | 3 | Architecture introduced in 
the MobileNetV3 (2019) paper, as an 4 | efficient semantic segmentation head. 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .utils import get_trunk, ConvBnRelu 12 | from .base import BaseSegmentation 13 | 14 | class LRASPP(BaseSegmentation): 15 | """Lite R-ASPP style segmentation network.""" 16 | def __init__(self, num_classes, trunk, use_aspp=False, num_filters=128): 17 | """Initialize a new segmentation model. 18 | 19 | Keyword arguments: 20 | num_classes -- number of output classes (e.g., 19 for Cityscapes) 21 | trunk -- the name of the trunk to use ('mobilenetv3_large', 'mobilenetv3_small') 22 | use_aspp -- whether to use DeepLabV3+ style ASPP (True) or Lite R-ASPP (False) 23 | (setting this to True may yield better results, at the cost of latency) 24 | num_filters -- the number of filters in the segmentation head 25 | """ 26 | super(LRASPP, self).__init__() 27 | 28 | self.trunk, s2_ch, s4_ch, high_level_ch = get_trunk(trunk_name=trunk) 29 | self.use_aspp = use_aspp 30 | 31 | # Reduced atrous spatial pyramid pooling 32 | if self.use_aspp: 33 | self.aspp_conv1 = nn.Sequential( 34 | nn.Conv2d(high_level_ch, num_filters, 1, bias=False), 35 | nn.BatchNorm2d(num_filters), 36 | nn.ReLU(inplace=True), 37 | ) 38 | self.aspp_conv2 = nn.Sequential( 39 | nn.Conv2d(high_level_ch, num_filters, 1, bias=False), 40 | nn.Conv2d(num_filters, num_filters, 3, dilation=12, padding=12), 41 | nn.BatchNorm2d(num_filters), 42 | nn.ReLU(inplace=True), 43 | ) 44 | self.aspp_conv3 = nn.Sequential( 45 | nn.Conv2d(high_level_ch, num_filters, 1, bias=False), 46 | nn.Conv2d(num_filters, num_filters, 3, dilation=36, padding=36), 47 | nn.BatchNorm2d(num_filters), 48 | nn.ReLU(inplace=True), 49 | ) 50 | self.aspp_pool = nn.Sequential( 51 | nn.AdaptiveAvgPool2d(1), 52 | nn.Conv2d(high_level_ch, num_filters, 1, bias=False), 53 | nn.BatchNorm2d(num_filters), 54 | nn.ReLU(inplace=True), 55 | ) 56 | aspp_out_ch = num_filters * 4 57 | else: 58 | self.aspp_conv1 = nn.Sequential( 59 | nn.Conv2d(high_level_ch, num_filters, 1, bias=False), 60 | nn.BatchNorm2d(num_filters), 61 | nn.ReLU(inplace=True), 62 | ) 63 | self.aspp_conv2 = nn.Sequential( 64 | nn.AvgPool2d(kernel_size=(49, 49), stride=(16, 20)), 65 | nn.Conv2d(high_level_ch, num_filters, 1, bias=False), 66 | nn.Sigmoid(), 67 | ) 68 | aspp_out_ch = num_filters 69 | 70 | self.convs2 = nn.Conv2d(s2_ch, 32, kernel_size=1, bias=False) 71 | self.convs4 = nn.Conv2d(s4_ch, 64, kernel_size=1, bias=False) 72 | self.conv_up1 = nn.Conv2d(aspp_out_ch, num_filters, kernel_size=1) 73 | self.conv_up2 = ConvBnRelu(num_filters + 64, num_filters, kernel_size=1) 74 | self.conv_up3 = ConvBnRelu(num_filters + 32, num_filters, kernel_size=1) 75 | self.last = nn.Conv2d(num_filters, num_classes, kernel_size=1) 76 | 77 | def forward(self, x): 78 | s2, s4, final = self.trunk(x) 79 | if self.use_aspp: 80 | aspp = torch.cat([ 81 | self.aspp_conv1(final), 82 | self.aspp_conv2(final), 83 | self.aspp_conv3(final), 84 | F.interpolate(self.aspp_pool(final), size=final.shape[2:]), 85 | ], 1) 86 | else: 87 | aspp = self.aspp_conv1(final) * F.interpolate( 88 | self.aspp_conv2(final), 89 | final.shape[2:], 90 | mode='bilinear', 91 | align_corners=True 92 | ) 93 | y = self.conv_up1(aspp) 94 | y = F.interpolate(y, size=s4.shape[2:], mode='bilinear', align_corners=False) 95 | 96 | y = torch.cat([y, self.convs4(s4)], 1) 97 | y = self.conv_up2(y) 98 | y = F.interpolate(y, size=s2.shape[2:], mode='bilinear', align_corners=False) 99 | 100 | y = torch.cat([y, 
self.convs2(s2)], 1) 101 | y = self.conv_up3(y) 102 | y = self.last(y) 103 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) 104 | return y 105 | 106 | 107 | class MobileV3Large(LRASPP): 108 | """MobileNetV3-Large segmentation network.""" 109 | model_name = 'mobilev3large-lraspp' 110 | 111 | def __init__(self, num_classes, **kwargs): 112 | super(MobileV3Large, self).__init__( 113 | num_classes, 114 | trunk='mobilenetv3_large', 115 | **kwargs 116 | ) 117 | 118 | 119 | class MobileV3Small(LRASPP): 120 | """MobileNetV3-Small segmentation network.""" 121 | model_name = 'mobilev3small-lraspp' 122 | 123 | def __init__(self, num_classes, **kwargs): 124 | super(MobileV3Small, self).__init__( 125 | num_classes, 126 | trunk='mobilenetv3_small', 127 | **kwargs 128 | ) 129 | -------------------------------------------------------------------------------- /fastseg/model/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | """Modified MobileNetV3 for use as semantic segmentation feature extractors.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from geffnet import tf_mobilenetv3_large_100, tf_mobilenetv3_small_100 7 | from geffnet.efficientnet_builder import InvertedResidual, Conv2dSame, Conv2dSameExport 8 | 9 | class MobileNetV3_Large(nn.Module): 10 | def __init__(self, trunk=tf_mobilenetv3_large_100, pretrained=False): 11 | super(MobileNetV3_Large, self).__init__() 12 | net = trunk(pretrained=pretrained, 13 | norm_layer=nn.BatchNorm2d) 14 | 15 | self.early = nn.Sequential(net.conv_stem, net.bn1, net.act1) 16 | 17 | net.blocks[3][0].conv_dw.stride = (1, 1) 18 | net.blocks[5][0].conv_dw.stride = (1, 1) 19 | 20 | for block_num in (3, 4, 5, 6): 21 | for sub_block in range(len(net.blocks[block_num])): 22 | sb = net.blocks[block_num][sub_block] 23 | if isinstance(sb, InvertedResidual): 24 | m = sb.conv_dw 25 | else: 26 | m = sb.conv 27 | if block_num < 5: 28 | m.dilation = (2, 2) 29 | pad = 2 30 | else: 31 | m.dilation = (4, 4) 32 | pad = 4 33 | # Adjust padding if necessary, but NOT for "same" layers 34 | assert m.kernel_size[0] == m.kernel_size[1] 35 | if not isinstance(m, Conv2dSame) and not isinstance(m, Conv2dSameExport): 36 | pad *= (m.kernel_size[0] - 1) // 2 37 | m.padding = (pad, pad) 38 | 39 | self.block0 = net.blocks[0] 40 | self.block1 = net.blocks[1] 41 | self.block2 = net.blocks[2] 42 | self.block3 = net.blocks[3] 43 | self.block4 = net.blocks[4] 44 | self.block5 = net.blocks[5] 45 | self.block6 = net.blocks[6] 46 | 47 | def forward(self, x): 48 | x = self.early(x) # 2x 49 | x = self.block0(x) 50 | s2 = x 51 | x = self.block1(x) # 4x 52 | s4 = x 53 | x = self.block2(x) # 8x 54 | x = self.block3(x) 55 | x = self.block4(x) 56 | x = self.block5(x) 57 | x = self.block6(x) 58 | return s2, s4, x 59 | 60 | 61 | class MobileNetV3_Small(nn.Module): 62 | def __init__(self, trunk=tf_mobilenetv3_small_100, pretrained=False): 63 | super(MobileNetV3_Small, self).__init__() 64 | net = trunk(pretrained=pretrained, 65 | norm_layer=nn.BatchNorm2d) 66 | 67 | self.early = nn.Sequential(net.conv_stem, net.bn1, net.act1) 68 | 69 | net.blocks[2][0].conv_dw.stride = (1, 1) 70 | net.blocks[4][0].conv_dw.stride = (1, 1) 71 | 72 | for block_num in (2, 3, 4, 5): 73 | for sub_block in range(len(net.blocks[block_num])): 74 | sb = net.blocks[block_num][sub_block] 75 | if isinstance(sb, InvertedResidual): 76 | m = sb.conv_dw 77 | else: 78 | m = sb.conv 79 | if block_num < 4: 80 | m.dilation = (2, 2) 81 | pad = 2 82 | else: 83 | m.dilation 
= (4, 4) 84 | pad = 4 85 | # Adjust padding if necessary, but NOT for "same" layers 86 | assert m.kernel_size[0] == m.kernel_size[1] 87 | if not isinstance(m, Conv2dSame) and not isinstance(m, Conv2dSameExport): 88 | pad *= (m.kernel_size[0] - 1) // 2 89 | m.padding = (pad, pad) 90 | 91 | self.block0 = net.blocks[0] 92 | self.block1 = net.blocks[1] 93 | self.block2 = net.blocks[2] 94 | self.block3 = net.blocks[3] 95 | self.block4 = net.blocks[4] 96 | self.block5 = net.blocks[5] 97 | 98 | def forward(self, x): 99 | x = self.early(x) # 2x 100 | s2 = x 101 | x = self.block0(x) # 4x 102 | s4 = x 103 | x = self.block1(x) # 8x 104 | x = self.block2(x) 105 | x = self.block3(x) 106 | x = self.block4(x) 107 | x = self.block5(x) 108 | return s2, s4, x 109 | 110 | 111 | if __name__ == '__main__': 112 | model = MobileNetV3_Large(pretrained=True) 113 | input = torch.rand(1, 3, 512, 512) 114 | low, mid, x = model(input) 115 | print(model) 116 | print(sum(p.numel() for p in model.parameters()), ' parameters') 117 | print(x.size()) 118 | print(low.size()) 119 | print(mid.size()) 120 | -------------------------------------------------------------------------------- /fastseg/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .efficientnet import EfficientNet_B4, EfficientNet_B0 4 | from .mobilenetv3 import MobileNetV3_Large, MobileNetV3_Small 5 | 6 | def get_trunk(trunk_name): 7 | """Retrieve the pretrained network trunk and channel counts""" 8 | if trunk_name == 'efficientnet_b4': 9 | backbone = EfficientNet_B4(pretrained=True) 10 | s2_ch = 24 11 | s4_ch = 32 12 | high_level_ch = 1792 13 | elif trunk_name == 'efficientnet_b0': 14 | backbone = EfficientNet_B0(pretrained=True) 15 | s2_ch = 16 16 | s4_ch = 24 17 | high_level_ch = 1280 18 | elif trunk_name == 'mobilenetv3_large': 19 | backbone = MobileNetV3_Large(pretrained=True) 20 | s2_ch = 16 21 | s4_ch = 24 22 | high_level_ch = 960 23 | elif trunk_name == 'mobilenetv3_small': 24 | backbone = MobileNetV3_Small(pretrained=True) 25 | s2_ch = 16 26 | s4_ch = 16 27 | high_level_ch = 576 28 | else: 29 | raise ValueError('unknown backbone {}'.format(trunk_name)) 30 | return backbone, s2_ch, s4_ch, high_level_ch 31 | 32 | class ConvBnRelu(nn.Module): 33 | """Convenience layer combining a Conv2d, BatchNorm2d, and a ReLU activation. 
34 | 35 | Original source of this code comes from 36 | https://github.com/lingtengqiu/Deeperlab-pytorch/blob/master/seg_opr/seg_oprs.py 37 | """ 38 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, 39 | norm_layer=nn.BatchNorm2d): 40 | super(ConvBnRelu, self).__init__() 41 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 42 | stride=stride, padding=padding, bias=False) 43 | self.bn = norm_layer(out_planes, eps=1e-5) 44 | self.relu = nn.ReLU(inplace=True) 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | x = self.bn(x) 49 | x = self.relu(x) 50 | return x 51 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | """Command line script to test inference on one or more images.""" 2 | 3 | import argparse 4 | import os.path as path 5 | import sys 6 | 7 | import torch 8 | import torch.nn as nn 9 | from PIL import Image 10 | 11 | from fastseg import MobileV3Large, MobileV3Small 12 | from fastseg.image import colorize, blend 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | parser = argparse.ArgumentParser(description=__doc__) 17 | 18 | parser.add_argument('images', nargs='*', metavar='IMAGES', 19 | help='one or more filenames of images to run inference on') 20 | parser.add_argument('--model', '-m', default='MobileV3Large', 21 | help='the model to use for inference (default MobileV3Large)') 22 | parser.add_argument('--num_filters', '-F', type=int, default=128, 23 | help='the number of filters in the segmentation head (default 128)') 24 | parser.add_argument('--checkpoint', '-c', default=None, 25 | help='filename of the weights checkpoint .pth file (uses pretrained by default)') 26 | parser.add_argument('--show', '-s', action='store_true', 27 | help='display the output segmentation results in the default image viewer') 28 | 29 | args = parser.parse_args() 30 | 31 | if not args.images: 32 | print('Please supply at least one image to run inference on.', file=sys.stderr) 33 | sys.exit(1) 34 | 35 | print(f'==> Creating PyTorch {args.model} model') 36 | if args.model == 'MobileV3Large': 37 | model_cls = MobileV3Large 38 | elif args.model == 'MobileV3Small': 39 | model_cls = MobileV3Small 40 | else: 41 | print(f'Unknown model name: {args.model}', file=sys.stderr) 42 | sys.exit(1) 43 | 44 | model = model_cls.from_pretrained(args.checkpoint, num_filters=args.num_filters).cuda().eval() 45 | 46 | print('==> Loading images and running inference') 47 | 48 | for im_path in args.images: 49 | print('Loading', im_path) 50 | img = Image.open(im_path) 51 | 52 | seg = model.predict_one(img) 53 | 54 | colorized = colorize(seg) 55 | composited = blend(img, colorized) 56 | 57 | basename, filename = path.split(im_path) 58 | colorized_filename = 'colorized_' + filename 59 | composited_filename = 'composited_' + filename 60 | colorized.save(colorized_filename) 61 | composited.save(composited_filename) 62 | print(f'Generated {colorized_filename}') 63 | print(f'Generated {composited_filename}') 64 | 65 | if args.show: 66 | colorized.show() 67 | composited.show() 68 | -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | """Command line script to export a pretrained segmentation model to ONNX.""" 2 | 3 | import argparse 4 | import sys 5 | 6 | import torch 7 | import geffnet 8 | import fastseg 9 | 10 | parser = 
argparse.ArgumentParser(description=__doc__) 11 | 12 | parser.add_argument('output', metavar='OUTPUT_FILENAME', 13 | help='filename of output model (e.g., mobilenetv3_large.onnx)') 14 | parser.add_argument('--model', '-m', default='MobileV3Large', 15 | help='the model to export (default MobileV3Large)') 16 | parser.add_argument('--num_filters', '-F', type=int, default=128, 17 | help='the number of filters in the segmentation head (default 128)') 18 | parser.add_argument('--size', '-s', default='1024,2048', 19 | help='the image dimensions to set as input (default 1024,2048)') 20 | parser.add_argument('--checkpoint', '-c', default=None, 21 | help='filename of the weights checkpoint .pth file (uses pretrained by default)') 22 | 23 | args = parser.parse_args() 24 | 25 | print(f'==> Creating PyTorch {args.model} model') 26 | if args.model == 'MobileV3Large': 27 | model_cls = fastseg.MobileV3Large 28 | elif args.model == 'MobileV3Small': 29 | model_cls = fastseg.MobileV3Small 30 | else: 31 | print(f'Unknown model name: {args.model}', file=sys.stderr) 32 | sys.exit(1) 33 | 34 | geffnet.config.set_exportable(True) 35 | model = model_cls.from_pretrained(args.checkpoint, num_filters=args.num_filters) 36 | model.eval() 37 | 38 | print('==> Exporting to ONNX') 39 | height, width = [int(x) for x in args.size.split(',')] 40 | print(f'Image dimensions: {height} x {width}') 41 | print(f'Output file: {args.output}') 42 | 43 | dummy_input = torch.randn(1, 3, height, width) 44 | input_names = ['input0'] 45 | output_names = ['output0'] 46 | 47 | # Run model once, this is required by geffnet 48 | model(dummy_input) 49 | 50 | torch.onnx.export(model, dummy_input, args.output, verbose=True, 51 | input_names=input_names, output_names=output_names, 52 | opset_version=11, keep_initializers_as_inputs=True) 53 | 54 | # Check the model 55 | print(f'==> Finished export, loading and checking model: {args.output}') 56 | import onnx 57 | onnx_model = onnx.load(args.output) 58 | onnx.checker.check_model(onnx_model) 59 | print('==> Passed check') 60 | -------------------------------------------------------------------------------- /onnx_infer.py: -------------------------------------------------------------------------------- 1 | """Script to test inference of an exported ONNX model.""" 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | import onnxruntime 7 | import torch 8 | 9 | from torchvision import transforms 10 | from PIL import Image 11 | 12 | from fastseg.image import colorize, blend 13 | 14 | parser = argparse.ArgumentParser(description=__doc__) 15 | 16 | parser.add_argument('model', metavar='MODEL', 17 | help='filename of onnx model (e.g., mobilenetv3_large.onnx)') 18 | parser.add_argument('image', metavar='IMAGE', 19 | help='filename of image to run inference on') 20 | 21 | args = parser.parse_args() 22 | 23 | im_path = args.image 24 | img = Image.open(im_path) 25 | 26 | tfms = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | ]) 30 | ipt = torch.stack([tfms(img)]).numpy() 31 | 32 | sess = onnxruntime.InferenceSession(args.model) 33 | 34 | out_ort = sess.run(None, { 35 | 'input0': ipt, 36 | }) 37 | 38 | labels = np.argmax(out_ort[0], axis=1)[0] 39 | print(labels) 40 | 41 | colorized = colorize(labels) 42 | colorized.show() 43 | 44 | composited = blend(img, colorized) 45 | composited.show() 46 | -------------------------------------------------------------------------------- /onnx_optimize.py: 
-------------------------------------------------------------------------------- 1 | """Command line script to optimize an ONNX model.""" 2 | 3 | import argparse 4 | 5 | import onnx 6 | from onnx import optimizer 7 | 8 | parser = argparse.ArgumentParser(description=__doc__) 9 | 10 | parser.add_argument('input', metavar='INPUT_FILENAME', 11 | help='filename of input model (e.g., mobilenetv3_large.onnx)') 12 | parser.add_argument('output', metavar='OUTPUT_FILENAME', 13 | help='filename of output model (e.g., mobilenetv3_large.opt.onnx)') 14 | 15 | args = parser.parse_args() 16 | 17 | print(f'==> Loading model {args.input}') 18 | original_model = onnx.load(args.input) 19 | 20 | print('Number of nodes before optimization:', len(original_model.graph.node)) 21 | 22 | print('==> Optimizing') 23 | passes = [ 24 | 'eliminate_identity', 25 | 'eliminate_nop_dropout', 26 | 'eliminate_nop_monotone_argmax', 27 | 'eliminate_nop_pad', 28 | 'eliminate_nop_transpose', 29 | 'eliminate_unused_initializer', 30 | 'extract_constant_to_initializer', 31 | 'fuse_add_bias_into_conv', 32 | 'fuse_bn_into_conv', 33 | 'fuse_consecutive_concats', 34 | 'fuse_consecutive_log_softmax', 35 | 'fuse_consecutive_reduce_unsqueeze', 36 | 'fuse_consecutive_squeezes', 37 | 'fuse_consecutive_transposes', 38 | 'fuse_pad_into_conv', 39 | 'nop', 40 | ] 41 | 42 | optimized_model = optimizer.optimize(original_model, passes) 43 | 44 | print('Number of nodes after optimization:', len(optimized_model.graph.node)) 45 | 46 | print(f'Output file: {args.output}') 47 | onnx.save(optimized_model, args.output) 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | geffnet==0.9.8 2 | onnx==1.7.0 3 | onnxruntime==1.4.0 4 | Pillow==7.2.0 5 | numpy>=1.18.0 6 | torch>=1.5.0 7 | torchvision>=0.6.0 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup""" 2 | 3 | import pathlib 4 | from setuptools import setup, find_namespace_packages 5 | 6 | HERE = pathlib.Path(__file__).parent 7 | README = (HERE / "README.md").read_text() 8 | 9 | setup( 10 | name="fastseg", 11 | version="0.1.2", 12 | description="Fast Semantic Segmentation for PyTorch", 13 | long_description=README, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/ekzhang/fast-semantic-seg", 16 | author="Eric Zhang", 17 | author_email="ekzhang1@gmail.com", 18 | license="MIT", 19 | classifiers=[ 20 | "Development Status :: 4 - Beta", 21 | "Intended Audience :: Science/Research", 22 | "License :: OSI Approved :: MIT License", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.7", 25 | "Topic :: Scientific/Engineering", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | "Topic :: Scientific/Engineering :: Image Recognition", 28 | "Topic :: Multimedia :: Graphics", 29 | ], 30 | packages=find_namespace_packages(include=["fastseg", "fastseg.*"]), 31 | include_package_data=True, 32 | install_requires=[ 33 | "geffnet >= 0.9.8", 34 | "Pillow >= 7.0.0", 35 | "numpy >= 1.18.0", 36 | "torch >= 1.5.0", 37 | "torchvision >= 0.6.0", 38 | ], 39 | ) 40 | --------------------------------------------------------------------------------