├── .github └── workflows │ └── release.yml ├── .gitignore ├── LICENCE.md ├── README.md ├── alpha.py ├── data_loader.py ├── images └── .gitignore ├── masks └── .gitignore ├── model ├── __init__.py ├── u2net.py └── u2net_refactor.py ├── requirements.txt ├── saved_models └── .gitignore └── u2net_train.py /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | env: 9 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: "3.12" 23 | 24 | - name: Create zip file 25 | run: | 26 | zip -r rembg-trainer.zip . -x .github/\* \*.gitignore 27 | 28 | - name: Create Release 29 | id: create_release 30 | uses: actions/create-release@v1 31 | with: 32 | tag_name: ${{ github.ref }} 33 | release_name: ${{ github.ref }} 34 | draft: false 35 | prerelease: false 36 | 37 | - name: Upload Release Asset 38 | uses: actions/upload-release-asset@v1 39 | with: 40 | upload_url: ${{ steps.create_release.outputs.upload_url }} # This is the URL for uploading assets to the release 41 | asset_path: ./rembg-trainer.zip 42 | asset_name: rembg-trainer.zip 43 | asset_content_type: application/zip 44 | env: 45 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | .DS_Store 4 | __pycache__ 5 | model/__pycache__ 6 | venv -------------------------------------------------------------------------------- /LICENCE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Evgenii Ostrovskii 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rembg trainer 2 | 3 | This code allows you to easily train U2-Net model in [ONNX](https://github.com/onnx/onnx) format to use with [rembg](https://github.com/danielgatis/rembg]) tool. 
4 | 5 | This work is based on the [U2Net](https://github.com/xuebinqin/U-2-Net) repo, which is under the Apache licence. The derivative work is licensed under MIT; do as you please with it. 6 | 7 | A couple of notes on performance: 8 | 9 | - Default parameters are fine-tuned for maximum performance on systems with 32 GB of shared memory, like the Apple M1 Pro. Adjust accordingly. 10 | - Computations are performed in float32, because float16 support on Metal is a bit undercooked at the moment. 11 | - If this is your first time using CUDA on Windows, you'll have to install the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). 12 | - For CUDA, you can easily rewrite this code with half-precision calculations for increased performance. The Apex library can help you with that; I don't have such plans at the moment. 13 | - For acceleration on AMD GPUs, please refer to the installation guide for the [AMD ROCm platform](https://rocm.docs.amd.com/en/latest/how_to/pytorch_install/pytorch_install.html). No code changes will be required. 14 | 15 | If the training is interrupted for any reason, don't worry — the program saves its state regularly, allowing you to resume from where you left off. The saving frequency can be adjusted. 16 | 17 | ## Fancy a go? 18 | 19 | - Download the latest release 20 | - Install the dependencies from `requirements.txt` 21 | - Put your images into the `images` folder 22 | - Put their masks into the `masks` folder; or see [below](#mask-extraction) 23 | - Launch `python3 u2net_train.py --help` for more details on supported command-line flags 24 | - Launch the script with your desired configuration 25 | - Go grab yourself a [nice latte](https://www.youtube.com/shorts/h75W1uhL-iQ) and wait........... and wait..... 26 | - Once you've had your fill of waiting, here's how you use the resulting model with rembg: 27 | 28 | ```bash 29 | rembg p -w input output -m u2net_custom -x '{"model_path": "/saved_models/u2net/27.onnx"}' 30 | # input — folder with images to have their backgrounds removed 31 | # output — folder for resulting images processed with custom model 32 | # adjust path(s) as necessary! 33 | ``` 34 | 35 | ## Mask extraction 36 | 37 | If you already have a bunch of images with removed background, you can create masks from them using the provided `alpha.py` script. Create a directory called `clean`, put your PNGs there, and launch the script. 38 | 39 | But fair warning, mate: the script is very CPU-heavy. Oh, and you'll need the `ImageMagick` tool installed and present in your PATH. 
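For reference, the ImageMagick call that `alpha.py` makes for each file boils down to the one-liner below, so you can also extract a single mask by hand (the filenames here are just placeholders; adjust to your own images):

```bash
# extract the alpha channel of one cut-out PNG into a grayscale mask
magick clean/example.png -strip -alpha extract masks/example.png
```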
40 | 41 | So, at the end of the day, you will end up with the following folder structure: 42 | 43 | - `images` — source images, will be needed for training 44 | - `masks` — required for training, to teach model where the background was 45 | - `clean` — images with removed background, to extract masks (they're not used for actual training) 46 | 47 | ## Leave your mark 👉👈🥺 48 | 49 | Buy me ~~a coffee~~ an alcohol-free cider [here](http://buymeacoffee.com/jonathunky) 50 | -------------------------------------------------------------------------------- /alpha.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import threading 4 | import time 5 | 6 | 7 | def extract_alpha(input_file_path, output_file_path): 8 | if not os.path.exists(output_file_path): 9 | subprocess.run( 10 | ["magick", input_file_path, "-strip", "-alpha", "extract", output_file_path] 11 | ) 12 | 13 | 14 | current_dir = os.getcwd() 15 | input_dir = os.path.join(current_dir, "clean") 16 | output_dir = os.path.join(current_dir, "masks") 17 | 18 | if not os.path.exists(output_dir): 19 | os.makedirs(output_dir) 20 | 21 | threads = [] 22 | 23 | print("This is Pequod, arriving shortly at LZ to extract team Alpha!") 24 | 25 | # Extracting alpha using ImageMagick 26 | files = os.listdir(input_dir) 27 | total_files = len(files) 28 | 29 | for idx, file in enumerate(files, start=1): 30 | if idx % 20 == 0: 31 | print(f"Processing file {idx} out of {total_files}") 32 | if file.endswith(".png"): 33 | input_file_path = os.path.join(input_dir, file) 34 | output_file_path = os.path.join(output_dir, file) 35 | 36 | # Start a new thread to process the image 37 | t = threading.Thread( 38 | target=extract_alpha, args=(input_file_path, output_file_path) 39 | ) 40 | threads.append(t) 41 | t.start() 42 | 43 | # Control the number of threads to prevent overwhelming the system. 44 | while threading.active_count() > 8: # reduce as necessary! 45 | time.sleep(0.5) 46 | 47 | # Wait for all threads to finish 48 | for t in threads: 49 | t.join() 50 | 51 | print("Masks extracted!") 52 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper code for loading and altering images in dataset 3 | """ 4 | import gc 5 | import random 6 | 7 | import imageio.v3 as iio 8 | import numpy as np 9 | import torchvision.transforms.functional as tf 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class RandomCrop: 15 | """ 16 | A class for performing random cropping of given image. 17 | Treats white pixels as empty space, and strives to return as few of them as possible. 18 | """ 19 | 20 | THRESHOLDS = [ 21 | 0.5, 22 | 0.8, 23 | 0.9, 24 | 0.95, 25 | 0.98, 26 | 0.99, 27 | ] # allowed percentages of white pixels 28 | start_threshold_index = 0 # chosen percentage that we aim for 29 | 30 | def __init__(self, output_size, index=0): 31 | """ 32 | Initialize the RandomCrop transformer. 33 | 34 | Parameters: 35 | - output_size (int or tuple): The desired size of the cropped image. 36 | - index (int): The starting threshold index. 
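Note: `index` selects the first entry of THRESHOLDS that the crop search will accept; index 0 starts strict (a crop may be at most 50% white) and relaxes only if no cell qualifies, while index 3 immediately tolerates crops that are up to 95% white.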
37 | """ 38 | assert isinstance(output_size, (int, tuple)) 39 | if isinstance(output_size, int): 40 | self.output_size = (output_size, output_size) 41 | else: 42 | self.output_size = output_size 43 | self.start_threshold_index = index 44 | 45 | @staticmethod 46 | def _calculate_white_percentage(img): 47 | """Calculate the percentage of white pixels in the image.""" 48 | white_pixels = np.sum(img == 255) 49 | total_pixels = img.size 50 | return white_pixels / total_pixels 51 | 52 | def __call__(self, sample): 53 | """ 54 | Apply the random crop to the input image + mask. 55 | 56 | Parameters: 57 | - sample (dict): Dictionary containing an image and its mask. 58 | 59 | Returns: 60 | - Dictionary containing the random crop of image and mask. 61 | """ 62 | image, label = sample["image"], sample["label"] 63 | 64 | w, h = image.size 65 | grid_size = self.output_size[0] 66 | 67 | # splitting image into grid for faster search 68 | cells = [(i, j) for i in range(0, w, grid_size) for j in range(0, h, grid_size)] 69 | random.shuffle(cells) 70 | 71 | # looking for as non-empty cell as possible 72 | # lowering threshold of whiteness if none are found 73 | threshold_sequence = RandomCrop.THRESHOLDS[RandomCrop.start_threshold_index :] 74 | for threshold in threshold_sequence: 75 | for i, j in cells: 76 | if i + self.output_size[0] <= w and j + self.output_size[1] <= h: 77 | cropped_image = tf.crop(image, i, j, *self.output_size) 78 | cropped_label = tf.crop(label, i, j, *self.output_size) 79 | 80 | if ( 81 | self._calculate_white_percentage(np.array(cropped_image)) 82 | <= threshold 83 | ): 84 | return {"image": cropped_image, "label": cropped_label} 85 | 86 | raise ValueError("Fully white image is given :(") 87 | 88 | 89 | class HorizontalFlip: 90 | """ 91 | A class to perform horizontal flipping of images. 92 | """ 93 | 94 | def __call__(self, sample): 95 | """ 96 | Flip the image and mask horizontally. 97 | 98 | Parameters: 99 | - sample (dict): Dictionary containing an image and its mask. 100 | 101 | Returns: 102 | - Dictionary containing the horizontally flipped image and mask. 103 | """ 104 | image, label = sample["image"], sample["label"] 105 | 106 | # Apply horizontal flip 107 | image = tf.hflip(image) 108 | label = tf.hflip(label) 109 | 110 | return {"image": image, "label": label} 111 | 112 | 113 | class VerticalFlip: 114 | """ 115 | A class to perform vertical flipping of images. 116 | """ 117 | 118 | def __call__(self, sample): 119 | """ 120 | Flip the image and mask vertically. 121 | 122 | Parameters: 123 | - sample (dict): Dictionary containing an image and its mask. 124 | 125 | Returns: 126 | - Dictionary containing the vertically flipped image and mask. 127 | """ 128 | image, label = sample["image"], sample["label"] 129 | 130 | # Apply vertical flip 131 | image = tf.vflip(image) 132 | label = tf.vflip(label) 133 | 134 | return {"image": image, "label": label} 135 | 136 | 137 | class Rotation: 138 | """ 139 | A class to rotate images by a given angle. 140 | """ 141 | 142 | def __init__(self, degrees): 143 | """ 144 | Initialize the Rotation transformer. 145 | 146 | Parameters: 147 | - degrees (float): The angle by which the image should be rotated. 148 | """ 149 | self.degrees = degrees 150 | 151 | def __call__(self, sample): 152 | """ 153 | Rotate the image and mask by a specified angle. 154 | 155 | Parameters: 156 | - sample (dict): Dictionary containing an image and its mask. 157 | 158 | Returns: 159 | - Dictionary containing the rotated image and mask. 
160 | """ 161 | image, label = sample["image"], sample["label"] 162 | 163 | # Apply rotation 164 | image = tf.rotate(image, self.degrees) 165 | label = tf.rotate(label, self.degrees) 166 | 167 | return {"image": image, "label": label} 168 | 169 | 170 | class Resize: 171 | """ 172 | A class to resize images to a specified size. 173 | """ 174 | 175 | def __init__(self, size=1024): 176 | """ 177 | Initialize the Resize transformer. 178 | 179 | Parameters: 180 | - size (int): The desired size of the image after resizing. 181 | """ 182 | self.size = size 183 | 184 | def __call__(self, sample): 185 | """ 186 | Resize the image and mask to the specified size. 187 | 188 | Parameters: 189 | - sample (dict): Dictionary containing an image and its mask. 190 | 191 | Returns: 192 | - Dictionary containing the resized image and mask. 193 | """ 194 | image, label = sample["image"], sample["label"] 195 | 196 | # Resize both the image and label 197 | image = tf.resize(image, [self.size, self.size]) 198 | label = tf.resize(label, [self.size, self.size]) 199 | 200 | return {"image": image, "label": label} 201 | 202 | 203 | class ToTensorLab: 204 | """ 205 | A class to convert images from PIL format to PyTorch tensors. 206 | """ 207 | 208 | def __call__(self, sample): 209 | """ 210 | Convert the image and label from PIL format to PyTorch tensors. 211 | 212 | Parameters: 213 | - sample (dict): Dictionary containing an image and its mask. 214 | 215 | Returns: 216 | - Dictionary containing the image and mask as tensors. 217 | """ 218 | from u2net_train import HALF_PRECISION 219 | 220 | image, label = sample["image"], sample["label"] 221 | 222 | # Convert to tensor 223 | image = tf.to_tensor(image) 224 | label = tf.to_tensor(label) 225 | 226 | if HALF_PRECISION: 227 | image, label = image.half(), label.half() 228 | 229 | return {"image": image, "label": label} 230 | 231 | 232 | class SalObjDataset(Dataset): 233 | """ 234 | Custom dataset class for salient object detection. This class helps in 235 | loading images, their corresponding masks, and applying the desired 236 | transformations before feeding them to the network. 237 | """ 238 | 239 | def __init__(self, img_name_list, lbl_name_list, transform=None): 240 | """ 241 | Initialize the SalObjDataset. 242 | 243 | Parameters: 244 | - img_name_list (list): List of paths to the images. 245 | - lbl_name_list (list): List of paths to the corresponding masks. 246 | - transform (callable, optional): Optional transform to be applied to both image & mask. 247 | """ 248 | self.img_name_list = img_name_list 249 | self.lbl_name_list = lbl_name_list 250 | self.transform = transform 251 | 252 | def __len__(self): 253 | """Return the total number of images in the dataset.""" 254 | return len(self.img_name_list) 255 | 256 | def __getitem__(self, idx): 257 | """ 258 | Fetch an image and its corresponding label, apply any transformations if needed, 259 | and return them as a dictionary. 260 | 261 | Parameters: 262 | - idx (int): Index of the desired sample. 263 | 264 | Returns: 265 | - Dictionary containing an image and its label. 
266 | """ 267 | # Read the images 268 | image_array = iio.imread(self.img_name_list[idx]) 269 | label_array = iio.imread(self.lbl_name_list[idx]) 270 | 271 | # Convert arrays to PIL images for compatibility with existing transforms 272 | image = Image.fromarray(image_array) 273 | label = Image.fromarray(label_array).convert("L") # Convert RGB to grayscale 274 | # TODO add a check here if it even needs to be converted, to increase perf 275 | 276 | # Clean up memory 277 | del image_array, label_array 278 | gc.collect() 279 | 280 | sample = {"image": image, "label": label} 281 | 282 | # Apply the transformations 283 | if self.transform: 284 | sample = self.transform(sample) 285 | 286 | return sample 287 | -------------------------------------------------------------------------------- /images/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /masks/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .u2net import U2NET 2 | from .u2net import U2NETP 3 | -------------------------------------------------------------------------------- /model/u2net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class REBNCONV(nn.Module): 7 | def __init__(self, in_ch=3, out_ch=3, dirate=1): 8 | super(REBNCONV, self).__init__() 9 | 10 | self.conv_s1 = nn.Conv2d( 11 | in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate 12 | ) 13 | self.bn_s1 = nn.BatchNorm2d(out_ch) 14 | self.relu_s1 = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | hx = x 18 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 19 | 20 | return xout 21 | 22 | 23 | ## upsample tensor 'src' to have the same spatial size with tensor 'tar' 24 | def _upsample_like(src, tar): 25 | src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=True) 26 | return src 27 | 28 | 29 | ### RSU-7 ### 30 | class RSU7(nn.Module): # UNet07DRES(nn.Module): 31 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 32 | super(RSU7, self).__init__() 33 | 34 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 35 | 36 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 37 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 38 | 39 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 40 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 41 | 42 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 43 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 44 | 45 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 46 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 47 | 48 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 49 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 50 | 51 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) 52 | 53 | self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) 54 | 55 | self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 56 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 57 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 58 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 59 | self.rebnconv2d = REBNCONV(mid_ch * 
2, mid_ch, dirate=1) 60 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 61 | 62 | def forward(self, x): 63 | hx = x 64 | hxin = self.rebnconvin(hx) 65 | 66 | hx1 = self.rebnconv1(hxin) 67 | hx = self.pool1(hx1) 68 | 69 | hx2 = self.rebnconv2(hx) 70 | hx = self.pool2(hx2) 71 | 72 | hx3 = self.rebnconv3(hx) 73 | hx = self.pool3(hx3) 74 | 75 | hx4 = self.rebnconv4(hx) 76 | hx = self.pool4(hx4) 77 | 78 | hx5 = self.rebnconv5(hx) 79 | hx = self.pool5(hx5) 80 | 81 | hx6 = self.rebnconv6(hx) 82 | 83 | hx7 = self.rebnconv7(hx6) 84 | 85 | hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) 86 | hx6dup = _upsample_like(hx6d, hx5) 87 | 88 | hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) 89 | hx5dup = _upsample_like(hx5d, hx4) 90 | 91 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 92 | hx4dup = _upsample_like(hx4d, hx3) 93 | 94 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 95 | hx3dup = _upsample_like(hx3d, hx2) 96 | 97 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 98 | hx2dup = _upsample_like(hx2d, hx1) 99 | 100 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 101 | 102 | return hx1d + hxin 103 | 104 | 105 | ### RSU-6 ### 106 | class RSU6(nn.Module): # UNet06DRES(nn.Module): 107 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 108 | super(RSU6, self).__init__() 109 | 110 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 111 | 112 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 113 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 114 | 115 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 116 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 117 | 118 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 119 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 120 | 121 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 122 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 123 | 124 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 125 | 126 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) 127 | 128 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 129 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 130 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 131 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 132 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 133 | 134 | def forward(self, x): 135 | hx = x 136 | 137 | hxin = self.rebnconvin(hx) 138 | 139 | hx1 = self.rebnconv1(hxin) 140 | hx = self.pool1(hx1) 141 | 142 | hx2 = self.rebnconv2(hx) 143 | hx = self.pool2(hx2) 144 | 145 | hx3 = self.rebnconv3(hx) 146 | hx = self.pool3(hx3) 147 | 148 | hx4 = self.rebnconv4(hx) 149 | hx = self.pool4(hx4) 150 | 151 | hx5 = self.rebnconv5(hx) 152 | 153 | hx6 = self.rebnconv6(hx5) 154 | 155 | hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) 156 | hx5dup = _upsample_like(hx5d, hx4) 157 | 158 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 159 | hx4dup = _upsample_like(hx4d, hx3) 160 | 161 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 162 | hx3dup = _upsample_like(hx3d, hx2) 163 | 164 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 165 | hx2dup = _upsample_like(hx2d, hx1) 166 | 167 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 168 | 169 | return hx1d + hxin 170 | 171 | 172 | ### RSU-5 ### 173 | class RSU5(nn.Module): # UNet05DRES(nn.Module): 174 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 175 | super(RSU5, self).__init__() 176 | 177 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 178 | 179 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, 
dirate=1) 180 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 181 | 182 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 183 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 184 | 185 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 186 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 187 | 188 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 189 | 190 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) 191 | 192 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 193 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 194 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 195 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 196 | 197 | def forward(self, x): 198 | hx = x 199 | 200 | hxin = self.rebnconvin(hx) 201 | 202 | hx1 = self.rebnconv1(hxin) 203 | hx = self.pool1(hx1) 204 | 205 | hx2 = self.rebnconv2(hx) 206 | hx = self.pool2(hx2) 207 | 208 | hx3 = self.rebnconv3(hx) 209 | hx = self.pool3(hx3) 210 | 211 | hx4 = self.rebnconv4(hx) 212 | 213 | hx5 = self.rebnconv5(hx4) 214 | 215 | hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) 216 | hx4dup = _upsample_like(hx4d, hx3) 217 | 218 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 219 | hx3dup = _upsample_like(hx3d, hx2) 220 | 221 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 222 | hx2dup = _upsample_like(hx2d, hx1) 223 | 224 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 225 | 226 | return hx1d + hxin 227 | 228 | 229 | ### RSU-4 ### 230 | class RSU4(nn.Module): # UNet04DRES(nn.Module): 231 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 232 | super(RSU4, self).__init__() 233 | 234 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 235 | 236 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 237 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 238 | 239 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 240 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 241 | 242 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 243 | 244 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) 245 | 246 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 247 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 248 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 249 | 250 | def forward(self, x): 251 | hx = x 252 | 253 | hxin = self.rebnconvin(hx) 254 | 255 | hx1 = self.rebnconv1(hxin) 256 | hx = self.pool1(hx1) 257 | 258 | hx2 = self.rebnconv2(hx) 259 | hx = self.pool2(hx2) 260 | 261 | hx3 = self.rebnconv3(hx) 262 | 263 | hx4 = self.rebnconv4(hx3) 264 | 265 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 266 | hx3dup = _upsample_like(hx3d, hx2) 267 | 268 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 269 | hx2dup = _upsample_like(hx2d, hx1) 270 | 271 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 272 | 273 | return hx1d + hxin 274 | 275 | 276 | ### RSU-4F ### 277 | class RSU4F(nn.Module): # UNet04FRES(nn.Module): 278 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 279 | super(RSU4F, self).__init__() 280 | 281 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 282 | 283 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 284 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) 285 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) 286 | 287 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) 288 | 289 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) 290 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) 291 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, 
dirate=1) 292 | 293 | def forward(self, x): 294 | hx = x 295 | 296 | hxin = self.rebnconvin(hx) 297 | 298 | hx1 = self.rebnconv1(hxin) 299 | hx2 = self.rebnconv2(hx1) 300 | hx3 = self.rebnconv3(hx2) 301 | 302 | hx4 = self.rebnconv4(hx3) 303 | 304 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 305 | hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) 306 | hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) 307 | 308 | return hx1d + hxin 309 | 310 | 311 | ##### U^2-Net #### 312 | class U2NET(nn.Module): 313 | def __init__(self, in_ch=3, out_ch=1): 314 | super(U2NET, self).__init__() 315 | 316 | self.stage1 = RSU7(in_ch, 32, 64) 317 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 318 | 319 | self.stage2 = RSU6(64, 32, 128) 320 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 321 | 322 | self.stage3 = RSU5(128, 64, 256) 323 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 324 | 325 | self.stage4 = RSU4(256, 128, 512) 326 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 327 | 328 | self.stage5 = RSU4F(512, 256, 512) 329 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 330 | 331 | self.stage6 = RSU4F(512, 256, 512) 332 | 333 | # decoder 334 | self.stage5d = RSU4F(1024, 256, 512) 335 | self.stage4d = RSU4(1024, 128, 256) 336 | self.stage3d = RSU5(512, 64, 128) 337 | self.stage2d = RSU6(256, 32, 64) 338 | self.stage1d = RSU7(128, 16, 64) 339 | 340 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 341 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 342 | self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) 343 | self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) 344 | self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) 345 | self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) 346 | 347 | self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1) 348 | 349 | def forward(self, x): 350 | hx = x 351 | 352 | # stage 1 353 | hx1 = self.stage1(hx) 354 | hx = self.pool12(hx1) 355 | 356 | # stage 2 357 | hx2 = self.stage2(hx) 358 | hx = self.pool23(hx2) 359 | 360 | # stage 3 361 | hx3 = self.stage3(hx) 362 | hx = self.pool34(hx3) 363 | 364 | # stage 4 365 | hx4 = self.stage4(hx) 366 | hx = self.pool45(hx4) 367 | 368 | # stage 5 369 | hx5 = self.stage5(hx) 370 | hx = self.pool56(hx5) 371 | 372 | # stage 6 373 | hx6 = self.stage6(hx) 374 | hx6up = _upsample_like(hx6, hx5) 375 | 376 | # -------------------- decoder -------------------- 377 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 378 | hx5dup = _upsample_like(hx5d, hx4) 379 | 380 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 381 | hx4dup = _upsample_like(hx4d, hx3) 382 | 383 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 384 | hx3dup = _upsample_like(hx3d, hx2) 385 | 386 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 387 | hx2dup = _upsample_like(hx2d, hx1) 388 | 389 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 390 | 391 | # side output 392 | d1 = self.side1(hx1d) 393 | 394 | d2 = self.side2(hx2d) 395 | d2 = _upsample_like(d2, d1) 396 | 397 | d3 = self.side3(hx3d) 398 | d3 = _upsample_like(d3, d1) 399 | 400 | d4 = self.side4(hx4d) 401 | d4 = _upsample_like(d4, d1) 402 | 403 | d5 = self.side5(hx5d) 404 | d5 = _upsample_like(d5, d1) 405 | 406 | d6 = self.side6(hx6) 407 | d6 = _upsample_like(d6, d1) 408 | 409 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 410 | 411 | return ( 412 | F.sigmoid(d0), 413 | F.sigmoid(d1), 414 | F.sigmoid(d2), 415 | F.sigmoid(d3), 416 | F.sigmoid(d4), 417 | F.sigmoid(d5), 418 | F.sigmoid(d6), 419 | ) 420 | 421 | 422 | ### U^2-Net small ### 423 | class 
U2NETP(nn.Module): 424 | def __init__(self, in_ch=3, out_ch=1): 425 | super(U2NETP, self).__init__() 426 | 427 | self.stage1 = RSU7(in_ch, 16, 64) 428 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 429 | 430 | self.stage2 = RSU6(64, 16, 64) 431 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 432 | 433 | self.stage3 = RSU5(64, 16, 64) 434 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 435 | 436 | self.stage4 = RSU4(64, 16, 64) 437 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 438 | 439 | self.stage5 = RSU4F(64, 16, 64) 440 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 441 | 442 | self.stage6 = RSU4F(64, 16, 64) 443 | 444 | # decoder 445 | self.stage5d = RSU4F(128, 16, 64) 446 | self.stage4d = RSU4(128, 16, 64) 447 | self.stage3d = RSU5(128, 16, 64) 448 | self.stage2d = RSU6(128, 16, 64) 449 | self.stage1d = RSU7(128, 16, 64) 450 | 451 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 452 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 453 | self.side3 = nn.Conv2d(64, out_ch, 3, padding=1) 454 | self.side4 = nn.Conv2d(64, out_ch, 3, padding=1) 455 | self.side5 = nn.Conv2d(64, out_ch, 3, padding=1) 456 | self.side6 = nn.Conv2d(64, out_ch, 3, padding=1) 457 | 458 | self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1) 459 | 460 | def forward(self, x): 461 | hx = x 462 | 463 | # stage 1 464 | hx1 = self.stage1(hx) 465 | hx = self.pool12(hx1) 466 | 467 | # stage 2 468 | hx2 = self.stage2(hx) 469 | hx = self.pool23(hx2) 470 | 471 | # stage 3 472 | hx3 = self.stage3(hx) 473 | hx = self.pool34(hx3) 474 | 475 | # stage 4 476 | hx4 = self.stage4(hx) 477 | hx = self.pool45(hx4) 478 | 479 | # stage 5 480 | hx5 = self.stage5(hx) 481 | hx = self.pool56(hx5) 482 | 483 | # stage 6 484 | hx6 = self.stage6(hx) 485 | hx6up = _upsample_like(hx6, hx5) 486 | 487 | # decoder 488 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 489 | hx5dup = _upsample_like(hx5d, hx4) 490 | 491 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 492 | hx4dup = _upsample_like(hx4d, hx3) 493 | 494 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 495 | hx3dup = _upsample_like(hx3d, hx2) 496 | 497 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 498 | hx2dup = _upsample_like(hx2d, hx1) 499 | 500 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 501 | 502 | # side output 503 | d1 = self.side1(hx1d) 504 | 505 | d2 = self.side2(hx2d) 506 | d2 = _upsample_like(d2, d1) 507 | 508 | d3 = self.side3(hx3d) 509 | d3 = _upsample_like(d3, d1) 510 | 511 | d4 = self.side4(hx4d) 512 | d4 = _upsample_like(d4, d1) 513 | 514 | d5 = self.side5(hx5d) 515 | d5 = _upsample_like(d5, d1) 516 | 517 | d6 = self.side6(hx6) 518 | d6 = _upsample_like(d6, d1) 519 | 520 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 521 | 522 | return ( 523 | F.sigmoid(d0), 524 | F.sigmoid(d1), 525 | F.sigmoid(d2), 526 | F.sigmoid(d3), 527 | F.sigmoid(d4), 528 | F.sigmoid(d5), 529 | F.sigmoid(d6), 530 | ) 531 | -------------------------------------------------------------------------------- /model/u2net_refactor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | __all__ = ['U2NET_full', 'U2NET_lite'] 7 | 8 | 9 | def _upsample_like(x, size): 10 | return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x) 11 | 12 | 13 | def _size_map(x, height): 14 | # {height: size} for Upsample 15 | size = list(x.shape[-2:]) 16 | sizes = {} 17 | for h in range(1, height): 18 | sizes[h] = size 19 | size = 
[math.ceil(w / 2) for w in size] 20 | return sizes 21 | 22 | 23 | class REBNCONV(nn.Module): 24 | def __init__(self, in_ch=3, out_ch=3, dilate=1): 25 | super(REBNCONV, self).__init__() 26 | 27 | self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate) 28 | self.bn_s1 = nn.BatchNorm2d(out_ch) 29 | self.relu_s1 = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | return self.relu_s1(self.bn_s1(self.conv_s1(x))) 33 | 34 | 35 | class RSU(nn.Module): 36 | def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False): 37 | super(RSU, self).__init__() 38 | self.name = name 39 | self.height = height 40 | self.dilated = dilated 41 | self._make_layers(height, in_ch, mid_ch, out_ch, dilated) 42 | 43 | def forward(self, x): 44 | sizes = _size_map(x, self.height) 45 | x = self.rebnconvin(x) 46 | 47 | # U-Net like symmetric encoder-decoder structure 48 | def unet(x, height=1): 49 | if height < self.height: 50 | x1 = getattr(self, f'rebnconv{height}')(x) 51 | if not self.dilated and height < self.height - 1: 52 | x2 = unet(getattr(self, 'downsample')(x1), height + 1) 53 | else: 54 | x2 = unet(x1, height + 1) 55 | 56 | x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1)) 57 | return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x 58 | else: 59 | return getattr(self, f'rebnconv{height}')(x) 60 | 61 | return x + unet(x) 62 | 63 | def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False): 64 | self.add_module('rebnconvin', REBNCONV(in_ch, out_ch)) 65 | self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True)) 66 | 67 | self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch)) 68 | self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch)) 69 | 70 | for i in range(2, height): 71 | dilate = 1 if not dilated else 2 ** (i - 1) 72 | self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate)) 73 | self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate)) 74 | 75 | dilate = 2 if not dilated else 2 ** (height - 1) 76 | self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate)) 77 | 78 | 79 | class U2NET(nn.Module): 80 | def __init__(self, cfgs, out_ch): 81 | super(U2NET, self).__init__() 82 | self.out_ch = out_ch 83 | self._make_layers(cfgs) 84 | 85 | def forward(self, x): 86 | sizes = _size_map(x, self.height) 87 | maps = [] # storage for maps 88 | 89 | # side saliency map 90 | def unet(x, height=1): 91 | if height < 6: 92 | x1 = getattr(self, f'stage{height}')(x) 93 | x2 = unet(getattr(self, 'downsample')(x1), height + 1) 94 | x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1)) 95 | side(x, height) 96 | return _upsample_like(x, sizes[height - 1]) if height > 1 else x 97 | else: 98 | x = getattr(self, f'stage{height}')(x) 99 | side(x, height) 100 | return _upsample_like(x, sizes[height - 1]) 101 | 102 | def side(x, h): 103 | # side output saliency map (before sigmoid) 104 | x = getattr(self, f'side{h}')(x) 105 | x = _upsample_like(x, sizes[1]) 106 | maps.append(x) 107 | 108 | def fuse(): 109 | # fuse saliency probability maps 110 | maps.reverse() 111 | x = torch.cat(maps, 1) 112 | x = getattr(self, 'outconv')(x) 113 | maps.insert(0, x) 114 | return [torch.sigmoid(x) for x in maps] 115 | 116 | unet(x) 117 | maps = fuse() 118 | return maps 119 | 120 | def _make_layers(self, cfgs): 121 | self.height = int((len(cfgs) + 1) / 2) 122 | self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True)) 123 | for k, v in cfgs.items(): 124 | # build 
rsu block 125 | self.add_module(k, RSU(v[0], *v[1])) 126 | if v[2] > 0: 127 | # build side layer 128 | self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1)) 129 | # build fuse layer 130 | self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)) 131 | 132 | 133 | def U2NET_full(): 134 | full = { 135 | # cfgs for building RSUs and sides 136 | # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]} 137 | 'stage1': ['En_1', (7, 3, 32, 64), -1], 138 | 'stage2': ['En_2', (6, 64, 32, 128), -1], 139 | 'stage3': ['En_3', (5, 128, 64, 256), -1], 140 | 'stage4': ['En_4', (4, 256, 128, 512), -1], 141 | 'stage5': ['En_5', (4, 512, 256, 512, True), -1], 142 | 'stage6': ['En_6', (4, 512, 256, 512, True), 512], 143 | 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512], 144 | 'stage4d': ['De_4', (4, 1024, 128, 256), 256], 145 | 'stage3d': ['De_3', (5, 512, 64, 128), 128], 146 | 'stage2d': ['De_2', (6, 256, 32, 64), 64], 147 | 'stage1d': ['De_1', (7, 128, 16, 64), 64], 148 | } 149 | return U2NET(cfgs=full, out_ch=1) 150 | 151 | 152 | def U2NET_lite(): 153 | lite = { 154 | # cfgs for building RSUs and sides 155 | # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]} 156 | 'stage1': ['En_1', (7, 3, 16, 64), -1], 157 | 'stage2': ['En_2', (6, 64, 16, 64), -1], 158 | 'stage3': ['En_3', (5, 64, 16, 64), -1], 159 | 'stage4': ['En_4', (4, 64, 16, 64), -1], 160 | 'stage5': ['En_5', (4, 64, 16, 64, True), -1], 161 | 'stage6': ['En_6', (4, 64, 16, 64, True), 64], 162 | 'stage5d': ['De_5', (4, 128, 16, 64, True), 64], 163 | 'stage4d': ['De_4', (4, 128, 16, 64), 64], 164 | 'stage3d': ['De_3', (5, 128, 16, 64), 64], 165 | 'stage2d': ['De_2', (6, 128, 16, 64), 64], 166 | 'stage1d': ['De_1', (7, 128, 16, 64), 64], 167 | } 168 | return U2NET(cfgs=lite, out_ch=1) 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | Pillow 5 | imageio 6 | onnx -------------------------------------------------------------------------------- /saved_models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /u2net_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script trains a deep learning model on an image dataset using various augmentations like flips, rotations, and crops. 3 | The model is intended to use with rembg for background removal. 4 | """ 5 | import os 6 | import argparse 7 | import glob 8 | import time 9 | 10 | import torch 11 | from torch import nn 12 | from torch import optim 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | from torch.utils.data import DataLoader 15 | from torchvision.transforms import transforms 16 | 17 | from data_loader import ( 18 | SalObjDataset, 19 | RandomCrop, 20 | Resize, 21 | ToTensorLab, 22 | VerticalFlip, 23 | HorizontalFlip, 24 | Rotation, 25 | ) 26 | from model import U2NET 27 | 28 | SAVE_FRQ = 0 29 | CHECK_FRQ = 0 30 | 31 | #: float16 if true, float32 if false 32 | HALF_PRECISION = False # not tested!! 
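# When HALF_PRECISION is True, main() builds the network with U2NET(3, 1).half() and
# ToTensorLab casts each image/label tensor with .half(), so the whole run happens in
# float16. Leave it at False unless float16 works reliably on your device (see README).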
33 | 34 | # Defining BCE Loss for Binary Cross Entropy 35 | bce_loss = nn.BCELoss(reduction="mean") 36 | 37 | # Definition of different augmentations 38 | train_configs = { 39 | "plain_resized": { 40 | "name": "Plain Images", 41 | "message": "Learning the dataset itself...\n", 42 | "transform": [Resize(1024), ToTensorLab()], 43 | "batch_factor": 1, 44 | }, 45 | "flipped_v": { 46 | "name": "Vertical Flips", 47 | "message": "Learning the vertical flips of dataset images...\n", 48 | "transform": [Resize(512), VerticalFlip(), ToTensorLab()], 49 | "batch_factor": 4, 50 | }, 51 | "flipped_h": { 52 | "name": "Horizontal Flips", 53 | "message": "Learning the horizontal flips of dataset images...\n", 54 | "transform": [Resize(512), HorizontalFlip(), ToTensorLab()], 55 | "batch_factor": 4, 56 | }, 57 | "rotated_l": { 58 | "name": "Left Rotations", 59 | "message": "Learning the left rotations of dataset images...\n", 60 | "transform": [Resize(512), Rotation(90), ToTensorLab()], 61 | "batch_factor": 4, 62 | }, 63 | "rotated_r": { 64 | "name": "Right Rotations", 65 | "message": "Learning the right rotation of dataset images...\n", 66 | "transform": [Resize(512), Rotation(270), ToTensorLab()], 67 | "batch_factor": 4, 68 | }, 69 | "crops": { 70 | "name": "256px Crops", 71 | "message": "Augmenting dataset with random crops...\n", 72 | "transform": [Resize(2304), RandomCrop(256, 0), ToTensorLab()], 73 | "batch_factor": 16, # because they are smaller => we can fit more in memory 74 | }, 75 | "crops_loyal": { 76 | "name": "Different crops", 77 | "message": "Augmenting dataset with different crops...\n", 78 | "transform": [Resize(2304), RandomCrop(256, 3), ToTensorLab()], 79 | "batch_factor": 16, # same here 80 | }, 81 | } 82 | 83 | 84 | def dice_loss(predict, target, smooth=1.0): 85 | """ 86 | Calculates the Dice Loss. 87 | 88 | Parameters: 89 | predict (Tensor): Predicted output. 90 | target (Tensor): Ground truth/target output. 91 | smooth (float, optional): A smoothing factor to prevent division by zero. Defaults to 1.0. 92 | 93 | Returns: 94 | float: Dice Loss value. 95 | """ 96 | predict = predict.contiguous() 97 | target = target.contiguous() 98 | 99 | intersection = (predict * target).sum(dim=2).sum(dim=2) 100 | 101 | loss = 1 - ( 102 | (2.0 * intersection + smooth) 103 | / (predict.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth) 104 | ) 105 | 106 | return loss.mean() 107 | 108 | 109 | def get_args(): 110 | """ 111 | Parses command-line arguments. 112 | 113 | Returns: 114 | argparse.Namespace: Parsed arguments. 115 | """ 116 | parser = argparse.ArgumentParser( 117 | description="A program that trains ONNX model for use with rembg" 118 | ) 119 | 120 | parser.add_argument( 121 | "-i", 122 | "--tra_image_dir", 123 | type=str, 124 | default="images", 125 | help="Directory with images.", 126 | ) 127 | parser.add_argument( 128 | "-m", 129 | "--tra_masks_dir", 130 | type=str, 131 | default="masks", 132 | help="Directory with masks.", 133 | ) 134 | parser.add_argument( 135 | "-s", 136 | "--save_frq", 137 | type=int, 138 | default=5, 139 | help="Frequency of saving onnx model (every X epochs).", 140 | ) 141 | parser.add_argument( 142 | "-c", 143 | "--check_frq", 144 | type=int, 145 | default=5, 146 | help="Frequency of saving checkpoints (every X epochs).", 147 | ) 148 | parser.add_argument( 149 | "-b", 150 | "--batch", 151 | type=int, 152 | default=3, 153 | help="Size of a single batch loaded into memory. 1 is lowest possible; it may run on 8gb GPUs but also may not. 
3 works well on 32gb of shared memory.", 154 | ) 155 | parser.add_argument( 156 | "-p", 157 | "--plain_resized", 158 | type=int, 159 | default=5, 160 | help="Number of training epochs for plain_resized.", 161 | ) 162 | parser.add_argument( 163 | "-vf", 164 | "--vflipped", 165 | type=int, 166 | default=2, 167 | help="Number of training epochs for flipped_v.", 168 | ) 169 | parser.add_argument( 170 | "-hf", 171 | "--hflipped", 172 | type=int, 173 | default=2, 174 | help="Number of training epochs for flipped_h.", 175 | ) 176 | parser.add_argument( 177 | "-left", 178 | "--rotated_l", 179 | type=int, 180 | default=2, 181 | help="Number of training epochs for rotated_l.", 182 | ) 183 | parser.add_argument( 184 | "-right", 185 | "--rotated_r", 186 | type=int, 187 | default=2, 188 | help="Number of training epochs for rotated_r.", 189 | ) 190 | parser.add_argument( 191 | "-r", 192 | "--rand", 193 | type=int, 194 | default=20, 195 | help="Number of training epochs for 256px crops.", 196 | ) 197 | parser.add_argument( 198 | "-l", 199 | "--loyal", 200 | type=int, 201 | default=7, 202 | help="Number of training epochs for different 256px crops.", 203 | ) 204 | 205 | return parser.parse_args() 206 | 207 | 208 | def get_device(): 209 | """ 210 | Determines the device to run the model on (GPU/CPU). 211 | 212 | Returns: 213 | torch.device: Device type ('cuda:0', 'mps', or 'cpu'). 214 | """ 215 | if torch.cuda.is_available(): 216 | print("NVIDIA CUDA acceleration enabled") 217 | torch.multiprocessing.set_start_method("spawn") 218 | return torch.device("cuda:0") 219 | elif torch.backends.mps.is_available(): 220 | print("Apple Metal Performance Shaders acceleration enabled") 221 | torch.multiprocessing.set_start_method("fork") 222 | return torch.device("mps") 223 | else: 224 | print("No GPU acceleration :/") 225 | return torch.device("cpu") 226 | 227 | 228 | def save_model_as_onnx(model, device, ite_num, input_tensor_size=(1, 3, 320, 320)): 229 | """ 230 | Saves the model in ONNX format. 231 | 232 | Parameters: 233 | model (nn.Module): The trained model. 234 | device (torch.device): The device where the model is located. 235 | ite_num (int): Amount of epochs already done. 236 | input_tensor_size (tuple, optional): The size of the input tensor. Defaults to (1, 3, 320, 320). 237 | """ 238 | x = torch.randn(*input_tensor_size, requires_grad=True) 239 | x = x.to(device) 240 | 241 | onnx_file_name = f"saved_models/{ite_num}.onnx" 242 | torch.onnx.export( 243 | model, 244 | x, 245 | onnx_file_name, 246 | export_params=True, 247 | opset_version=16, 248 | do_constant_folding=True, 249 | input_names=["input"], 250 | output_names=["output"], 251 | dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, 252 | ) 253 | print("Model saved to:", onnx_file_name, "\n") 254 | del x 255 | 256 | 257 | def save_checkpoint(state, filename="saved_models/checkpoint.pth.tar"): 258 | """ 259 | Saves the model's state as a checkpoint. 260 | 261 | Parameters: 262 | state (dict): State of the model to save. 263 | filename (str, optional): Path to save the checkpoint. Defaults to "saved_models/checkpoint.pth.tar". 264 | """ 265 | torch.save({"state": state}, filename) 266 | 267 | 268 | def load_checkpoint(net, optimizer, filename="saved_models/checkpoint.pth.tar"): 269 | """ 270 | Loads model state from a checkpoint. 271 | 272 | Parameters: 273 | net (nn.Module): Model architecture. 274 | optimizer (Optimizer): Optimizer used during training. 275 | filename (str, optional): Path to the checkpoint. 
Defaults to "saved_models/checkpoint.pth.tar". 276 | 277 | Returns: 278 | dict: Counts of training epochs for various augmentations. 279 | """ 280 | training_counts = { 281 | "plain_resized": 0, 282 | "flipped_v": 0, 283 | "flipped_h": 0, 284 | "rotated_l": 0, 285 | "rotated_r": 0, 286 | "crops": 0, 287 | "crops_loyal": 0, 288 | } 289 | 290 | if os.path.isfile(filename): 291 | checkpoint = torch.load(filename) 292 | net.load_state_dict(checkpoint["state"]["state_dict"]) 293 | optimizer.load_state_dict(checkpoint["state"]["optimizer"]) 294 | 295 | # Update the dictionary with values from the checkpoint 296 | # Only updates keys that exist in both dictionaries 297 | # This is done for expandability in future 298 | for key in training_counts: 299 | if key in checkpoint["state"]["training_counts"]: 300 | training_counts[key] = checkpoint["state"]["training_counts"][key] 301 | 302 | print(f"Loading checkpoint '{filename}'...") 303 | else: 304 | print(f"No checkpoint file found at '{filename}'. Starting from scratch...") 305 | print("\n———") 306 | 307 | return training_counts 308 | 309 | 310 | def load_dataset(img_dir, lbl_dir, ext): 311 | """ 312 | Loads image and mask filenames from given directories. 313 | 314 | Parameters: 315 | img_dir (str): Directory with images. 316 | lbl_dir (str): Directory with masks. 317 | ext (str): Extension of the image files (e.g., '.png'). 318 | 319 | Returns: 320 | list, list: Lists of image and mask filenames. 321 | """ 322 | img_list = glob.glob(os.path.join(img_dir, "*" + ext)) 323 | lbl_list = [os.path.join(lbl_dir, os.path.basename(img)) for img in img_list] 324 | 325 | return img_list, lbl_list 326 | 327 | 328 | def multi_loss_fusion(d_list, labels_v): 329 | """ 330 | Combines BCE and Dice losses. Gives more weight to dice loss. 331 | 332 | Parameters: 333 | d_list (list): List of predicted outputs. 334 | labels_v (Tensor): Ground truth/target outputs. 335 | 336 | Returns: 337 | float: Combined loss value. 338 | """ 339 | bce_losses = [bce_loss(d, labels_v) for d in d_list] 340 | dice_losses = [dice_loss(d, labels_v) for d in d_list] 341 | w_bce, w_dice = 1 / 3, 2 / 3 342 | combined_losses = [ 343 | w_bce * bce + w_dice * dice for bce, dice in zip(bce_losses, dice_losses) 344 | ] 345 | total_loss = sum(combined_losses) 346 | # return combined_losses[0], total_loss 347 | return total_loss 348 | 349 | 350 | def get_dataloader(tra_img_name_list, tra_lbl_name_list, transform, batch_size): 351 | """ 352 | Creates a DataLoader for the dataset. 353 | 354 | Parameters: 355 | tra_img_name_list (list): List of image filenames. 356 | tra_lbl_name_list (list): List of mask filenames. 357 | transform (transforms.Compose): Transformations to apply. 358 | batch_size (int): Amount of tensors to load into memory at once. 359 | 360 | Returns: 361 | DataLoader: DataLoader object for the dataset. 362 | """ 363 | # Dataset with given transform 364 | dataset = SalObjDataset( 365 | img_name_list=tra_img_name_list, 366 | lbl_name_list=tra_lbl_name_list, 367 | transform=transform, 368 | ) 369 | 370 | cores = 2 # freeing up memory a bit 371 | 372 | # DataLoader for the dataset 373 | dataloader = DataLoader( 374 | dataset, batch_size=max(1, int(batch_size)), shuffle=True, num_workers=cores 375 | ) 376 | 377 | return dataloader 378 | 379 | 380 | def train_model(net, optimizer, scheduler, dataloader, device): 381 | """ 382 | Trains the model for a single epoch. 383 | 384 | Parameters: 385 | net (nn.Module): Model architecture. 
386 | optimizer (Optimizer): Optimizer used during training. 387 | scheduler (lr_scheduler): Learning rate scheduler. 388 | dataloader (DataLoader): DataLoader for the dataset. 389 | device (torch.device): Device to train on (e.g., GPU/CPU). 390 | """ 391 | epoch_loss = 0.0 392 | 393 | for i, data in enumerate(dataloader): 394 | print(f" Iteration: {i + 1:4}/{len(dataloader)}, ", end="") 395 | inputs = data["image"].to(device) 396 | labels = data["label"].to(device) 397 | optimizer.zero_grad() 398 | 399 | outputs = net(inputs) 400 | 401 | combined_loss = multi_loss_fusion(outputs, labels) 402 | combined_loss.backward() 403 | torch.nn.utils.clip_grad_norm_( 404 | net.parameters(), max_norm=1.0 405 | ) # Clip gradients if their norm exceeds 1 406 | optimizer.step() 407 | optimizer.zero_grad() 408 | scheduler.step() 409 | 410 | epoch_loss += combined_loss.item() 411 | 412 | print(f"loss: {epoch_loss / (i + 1):.5f}") 413 | 414 | return epoch_loss 415 | 416 | 417 | def train_epochs( 418 | net, optimizer, scheduler, dataloader, device, epochs, training_counts, key 419 | ): 420 | """ 421 | Train the model for given amount of epochs. Updates training counts. 422 | 423 | Parameters: 424 | net (nn.Module): The model architecture to be trained. 425 | optimizer (Optimizer): The optimizer used during training. 426 | scheduler (lr_scheduler): Scheduler to adjust the learning rate during training. 427 | dataloader (DataLoader): DataLoader object supplying the training data. 428 | device (torch.device): The device on which the training will take place (e.g., GPU/CPU). 429 | epochs (range): Number of epochs for which the model will be trained. 430 | training_counts (dict): Dictionary tracking the number of epochs trained for different configurations. 431 | key (str): Key for the specific training configuration. 432 | 433 | Returns: 434 | nn.Module: Trained model. 435 | """ 436 | for index, epoch in enumerate(epochs): 437 | start_time = time.time() 438 | 439 | # this is where the training occurs! 440 | print(f" Epoch: {epoch + 1}/{epochs[-1] + 1}") 441 | epoch_loss = train_model(net, optimizer, scheduler, dataloader, device) 442 | print(f" Loss per epoch: {epoch_loss}\n") 443 | 444 | if sum(training_counts.values()) == 3: 445 | elapsed_time = time.time() - start_time 446 | minutes, seconds = divmod(elapsed_time, 60) 447 | perf = minutes + (seconds / 60) 448 | print(f" Expected performance is {perf:.1f} minutes per epoch.\n") 449 | # Increment the corresponding training count 450 | training_counts[key] += 1 451 | 452 | # Saves model every save_frq iterations or during the last one 453 | if sum(training_counts.values()) % SAVE_FRQ == 0 or index + 1 == len(epochs): 454 | # in ONNX format! ^_^ UwU 455 | save_model_as_onnx(net, device, sum(training_counts.values())) 456 | 457 | # Saves checkpoint every check_frq epochs or during the last one 458 | if sum(training_counts.values()) % CHECK_FRQ == 0 or index + 1 == len(epochs): 459 | save_checkpoint( 460 | { 461 | "epoch_count": epoch + 1, 462 | "state_dict": net.state_dict(), 463 | "optimizer": optimizer.state_dict(), 464 | "training_counts": training_counts, 465 | } 466 | ) 467 | print("Checkpoint made\n") 468 | 469 | return net 470 | 471 | 472 | def main(): 473 | """ 474 | Main function for initiating training of the model on the dataset. 
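The flow: parse the command-line arguments, build the U2NET network, restore progress from a checkpoint if one exists, then loop over train_configs, training each augmentation for its remaining epochs while exporting ONNX snapshots and checkpoints along the way.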
475 | """ 476 | device = get_device() 477 | 478 | args = get_args() 479 | global SAVE_FRQ, CHECK_FRQ 480 | SAVE_FRQ = args.save_frq 481 | CHECK_FRQ = args.check_frq 482 | tra_image_dir = args.tra_image_dir 483 | tra_label_dir = args.tra_masks_dir 484 | batch = args.batch 485 | 486 | targets = { 487 | "plain_resized": args.plain_resized, 488 | "flipped_h": args.hflipped, 489 | "flipped_v": args.vflipped, 490 | "rotated_l": args.rotated_l, 491 | "rotated_r": args.rotated_r, 492 | "crops": args.rand, 493 | "crops_loyal": args.loyal, 494 | } 495 | 496 | if not os.path.exists("saved_models"): 497 | os.makedirs("saved_models") 498 | 499 | tra_img_name_list, tra_lbl_name_list = load_dataset( 500 | tra_image_dir, tra_label_dir, ".png" 501 | ) 502 | 503 | print(f"Images: {format(len(tra_img_name_list))}, masks: {len(tra_lbl_name_list)}") 504 | 505 | if len(tra_img_name_list) != len(tra_lbl_name_list): 506 | print("Different amounts of images and masks, can't proceed mate") 507 | return 508 | 509 | if HALF_PRECISION: 510 | net = U2NET(3, 1).half() 511 | else: 512 | net = U2NET(3, 1) 513 | net.to(device) 514 | net.train() 515 | 516 | optimizer = optim.Adam( 517 | net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0 518 | ) 519 | 520 | training_counts = load_checkpoint(net, optimizer) 521 | # dealing with negative values, if model was trained for more epochs than in target: 522 | for key, count in training_counts.items(): 523 | if targets[key] < count: 524 | targets[key] = count 525 | print( 526 | f"Task: {train_configs[key]['name']:<17} Epochs done: {count}/{targets[key]}" 527 | ) 528 | 529 | print("———\n") 530 | 531 | scheduler = CosineAnnealingLR(optimizer, T_max=sum(targets.values()), eta_min=1e-6) 532 | 533 | def create_and_train(transform, batch_size, epochs, train_type): 534 | """Creates a dataloader and trains the network using the given parameters.""" 535 | dataloader = get_dataloader( 536 | tra_img_name_list, tra_lbl_name_list, transform, batch_size 537 | ) 538 | train_epochs( 539 | net, 540 | optimizer, 541 | scheduler, 542 | dataloader, 543 | device, 544 | epochs, 545 | training_counts, 546 | train_type, 547 | ) 548 | 549 | # Training loop 550 | for train_type, config in train_configs.items(): 551 | if training_counts[train_type] < targets[train_type]: 552 | print(config["message"]) 553 | epochs = range(training_counts[train_type], targets[train_type]) 554 | transform = transforms.Compose(config["transform"]) 555 | 556 | create_and_train( 557 | transform, batch * config["batch_factor"], epochs, train_type 558 | ) 559 | 560 | training_counts[train_type] = targets[train_type] 561 | 562 | print("Nothing left to do!") 563 | 564 | 565 | if __name__ == "__main__": 566 | main() 567 | --------------------------------------------------------------------------------