├── models
│   └── .gitignore
├── test_images
│   └── irani_043.png
├── README.md
├── requirements.txt
├── src
│   ├── main.py
│   └── models.py
└── .gitignore

/models/.gitignore:
--------------------------------------------------------------------------------
*.pt
--------------------------------------------------------------------------------
/test_images/irani_043.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yisaienkov/target_class_transformation_for_segmentation_task_tool/master/test_images/irani_043.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TARGET CLASS TRANSFORMATION FOR SEGMENTATION TASK

A small Gradio app: draw up to five bounding boxes on an input image, and each boxed crop is re-synthesized by a WGAN generator (a U-Net that produces a generated image plus a blending mask). See the Setup notes below for dependencies and model weights.

## Run

```bash
python src/main.py
```

## App View

![Image](https://i.imgur.com/qotqoV0.jpg)
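## Setup

Install the pinned dependencies with `pip install -r requirements.txt`. The app loads generator weights from `models/210_generator.pt`; `models/.gitignore` excludes `*.pt` checkpoints from version control, so place a trained checkpoint at that path before launching. A sample input image is provided at `test_images/irani_043.png`.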
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
aiofiles==23.2.1
albumentations==1.3.1
altair==5.2.0
annotated-types==0.6.0
anyio==4.2.0
attrs==23.2.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
contourpy==1.2.0
cycler==0.12.1
exceptiongroup==1.2.0
fastapi==0.109.2
ffmpy==0.3.1
filelock==3.13.1
fonttools==4.48.1
fsspec==2024.2.0
gradio==4.18.0
gradio_client==0.10.0
h11==0.14.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.3
idna==3.6
imageio==2.34.0
importlib-resources==6.1.1
Jinja2==3.1.3
joblib==1.3.2
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
lazy_loader==0.3
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.2
mdurl==0.1.2
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.4
opencv-python==4.9.0.80
opencv-python-headless==4.9.0.80
orjson==3.9.13
packaging==23.2
pandas==2.2.0
pillow==10.2.0
pydantic==2.6.1
pydantic_core==2.16.2
pydub==0.25.1
Pygments==2.17.2
pyparsing==3.1.1
python-dateutil==2.8.2
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
qudida==0.0.4
referencing==0.33.0
requests==2.31.0
rich==13.7.0
rpds-py==0.17.1
ruff==0.2.1
scikit-image==0.22.0
scikit-learn==1.4.0
scipy==1.12.0
semantic-version==2.10.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.0
starlette==0.36.3
sympy==1.12
threadpoolctl==3.2.0
tifffile==2024.1.30
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchvision==0.17.0
tqdm==4.66.2
typer==0.9.0
typing_extensions==4.9.0
tzdata==2024.1
urllib3==2.2.0
uvicorn==0.27.1
websockets==11.0.3
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
import gradio as gr
import cv2
import numpy as np

from models import Model


gen = Model()

MAX_BOXES = 5


def draw_rectangle(image, *arr):
    """Draw each requested box and replace the boxed crop with the generator output."""
    output_image = image.copy()
    gen_output_image = image.copy()
    # Coordinates arrive flattened as (x1, y1, x2, y2) groups, one group per box.
    for i in range(0, len(arr), 4):
        x1, y1, x2, y2 = int(arr[i]), int(arr[i + 1]), int(arr[i + 2]), int(arr[i + 3])
        if x1 == 0 and x2 == 0:
            # Skip boxes left at their default coordinates.
            continue

        output_image = cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        tooth = image[y1:y2, x1:x2]
        res = gen(tooth)  # 128x128 float image in [0, 1]
        res = cv2.resize(res, (x2 - x1, y2 - y1))
        gen_output_image[y1:y2, x1:x2] = (cv2.cvtColor(res, cv2.COLOR_GRAY2RGB) * 255).astype(np.uint8)

    return output_image, gen_output_image


def variable_outputs(k):
    """Show the first k coordinate groups (four inputs each) and hide the rest."""
    k = int(k) * 4
    return [gr.Number(visible=True)] * k + [gr.Number(visible=False)] * (MAX_BOXES * 4 - k)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image()
            coords = []
            n_boxes = gr.Number(1, label="Boxes", min_width=80, maximum=MAX_BOXES, minimum=0)
            for i in range(MAX_BOXES):
                with gr.Row():
                    # Only the first coordinate group is visible initially.
                    visible = i == 0
                    coords.append(gr.Number(0, label="xs", min_width=80, visible=visible))
                    coords.append(gr.Number(0, label="ys", min_width=80, visible=visible))
                    coords.append(gr.Number(0, label="xe", min_width=80, visible=visible))
                    coords.append(gr.Number(0, label="ye", min_width=80, visible=visible))

        with gr.Column():
            image_rect_output = gr.Image()
            image_gen_output = gr.Image()

    n_boxes.change(variable_outputs, n_boxes, coords)
    image_button = gr.Button("Generate")

    image_button.click(draw_rectangle, inputs=[image_input, *coords], outputs=[image_rect_output, image_gen_output])

if __name__ == "__main__":
    demo.launch()
--------------------------------------------------------------------------------
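The `variable_outputs` callback in `src/main.py` uses Gradio's standard pattern for a variable number of inputs: pre-create the maximum number of components, then return one visibility update per output component. A minimal self-contained sketch of the same pattern, using a hypothetical three-slot layout that is not part of this repo:

```python
import gradio as gr

MAX_SLOTS = 3  # hypothetical pool size, for illustration only


def toggle(k):
    # One update per component in `slots`: show the first k, hide the rest.
    return [gr.Textbox(visible=(i < int(k))) for i in range(MAX_SLOTS)]


with gr.Blocks() as demo:
    n = gr.Number(1, minimum=0, maximum=MAX_SLOTS, label="Visible slots")
    slots = [gr.Textbox(visible=(i == 0), label=f"slot {i}") for i in range(MAX_SLOTS)]
    n.change(toggle, inputs=n, outputs=slots)

if __name__ == "__main__":
    demo.launch()
```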
/.gitignore:
--------------------------------------------------------------------------------
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
--------------------------------------------------------------------------------
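`src/models.py` below implements the generator as a standard U-Net (encoder `Down` blocks, decoder `Up` blocks with skip connections, and a `Tanh` output head). The 2-channel input stacks a grayscale crop with a broadcast noise map; the 2-channel output stacks the generated image with a blending mask. A hypothetical shape check, assuming `src/` is on the import path:

```python
import torch

from models import GeneratorWGAN

gen = GeneratorWGAN(n_channels=2, n_classes=2)  # bilinear upsampling by default
x = torch.randn(1, 2, 128, 128)  # (batch, crop + noise channels, height, width)
with torch.no_grad():
    y = gen(x)
print(y.shape)  # torch.Size([1, 2, 128, 128]): generated image + mask, both in [-1, 1]
```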
/src/models.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as torch_f
import albumentations as A
from albumentations import pytorch as ATorch
import numpy as np
import cv2


IMAGE_SIZE = 128


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()

        # If bilinear, use the normal convolutions to reduce the number of channels.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is NCHW; pad x1 so it matches the skip connection x2 exactly.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = torch_f.pad(x1, [diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2])
        # If you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution followed by Tanh, squashing outputs to [-1, 1]."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.a = nn.Tanh()

    def forward(self, x):
        return self.a(self.conv(x))


class GeneratorWGAN(nn.Module):
    """U-Net generator: 2-channel input (crop + noise), 2-channel output (image + mask)."""

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


def get_transforms() -> A.Compose:
    return A.Compose(
        [
            A.Resize(p=1.0, height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize((0.5,), (0.5,), p=1.0),
            ATorch.transforms.ToTensorV2(p=1.0),
        ],
        p=1.0,
    )


def get_noise(*, n_samples: int, z_dim: int, device: torch.device):
    return torch.randn(n_samples, z_dim, device=device)


def combine_vectors(x, y):
    """Concatenate two tensors along the channel dimension."""
    return torch.cat([x.float(), y.float()], dim=1)


class Model:
    def __init__(self) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gen = GeneratorWGAN(
            n_channels=2,
            n_classes=2,
        ).to(self.device)
        state_dict = torch.load("models/210_generator.pt", map_location=self.device)["model_state_dict"]
        self.gen.load_state_dict(state_dict)
        self.gen.eval()  # use BatchNorm running stats at inference time

        self.transforms = get_transforms()

    def __call__(self, image):
        with torch.no_grad():
            # Gradio supplies RGB images, so convert from RGB (not BGR) to grayscale.
            img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY).astype(float)
            img = self.transforms(image=img)["image"]  # (1, IMAGE_SIZE, IMAGE_SIZE), normalized to [-1, 1]

            gen_real = img.unsqueeze(0).to(self.device)  # (1, 1, IMAGE_SIZE, IMAGE_SIZE)

            # Lay the z_dim noise vector along the width and repeat it over the
            # height, giving a (1, 1, IMAGE_SIZE, IMAGE_SIZE) noise map.
            fake_noise = get_noise(
                n_samples=1,
                z_dim=128,
                device=self.device,
            ).unsqueeze(1).unsqueeze(2).expand(-1, -1, IMAGE_SIZE, IMAGE_SIZE)

            # Stack crop and noise into the generator's 2-channel input.
            noise_and_labels = combine_vectors(gen_real, fake_noise)
            fake = self.gen(noise_and_labels)

            # Map Tanh outputs from [-1, 1] back to [0, 1].
            real = (gen_real.cpu().numpy()[0][0] + 1) / 2
            image_tensor = (fake[0, 0, :, :].detach().cpu().numpy() + 1) / 2
            mask_unflat = fake[0, 1, :, :].cpu().numpy()

            # Keep generated pixels where the mask channel fires, else the input.
            merged = np.where(mask_unflat > 0.5, image_tensor, real)

            # cv2.imwrite("tmp.png", merged * 255)
            # cv2.imwrite("tmp_mask.png", (mask_unflat > 0.5) * 255)

            return merged
--------------------------------------------------------------------------------
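For completeness, a hypothetical standalone smoke test of `Model` outside the Gradio UI; it assumes a trained checkpoint at `models/210_generator.pt` and is run from the repository root with `src/` on the import path. `Model.__call__` resizes its input to 128×128 and returns a float image in [0, 1]:

```python
import cv2
import numpy as np

from models import Model

model = Model()  # loads models/210_generator.pt

# OpenCV reads BGR, while the model expects RGB crops as supplied by Gradio.
crop = cv2.cvtColor(cv2.imread("test_images/irani_043.png"), cv2.COLOR_BGR2RGB)
out = model(crop)  # 128x128 float array in [0, 1]
cv2.imwrite("generated.png", (out * 255).astype(np.uint8))
```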